From fe4c0aa890599cbd60250d9476f527c24da38bb6 Mon Sep 17 00:00:00 2001 From: Michael Peter Christen Date: Tue, 21 May 2024 00:06:19 +0200 Subject: [PATCH] refactoring of RAG reverse proxy: extracted code for ollama code to their own classes --- source/net/yacy/ai/OllamaClient.java | 117 ++++++++ source/net/yacy/ai/OpenAIClient.java | 165 +++++++++++ .../yacy/http/servlets/RAGProxyServlet.java | 280 +++--------------- 3 files changed, 330 insertions(+), 232 deletions(-) create mode 100644 source/net/yacy/ai/OllamaClient.java create mode 100644 source/net/yacy/ai/OpenAIClient.java diff --git a/source/net/yacy/ai/OllamaClient.java b/source/net/yacy/ai/OllamaClient.java new file mode 100644 index 000000000..742645fee --- /dev/null +++ b/source/net/yacy/ai/OllamaClient.java @@ -0,0 +1,117 @@ +/** + * OllamaClient + * Copyright 2024 by Michael Peter Christen + * First released 17.05.2024 at http://yacy.net + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public License + * along with this program in the file lgpl21.txt + * If not, see . + */ + +package net.yacy.ai; + +import java.io.IOException; +import java.net.URISyntaxException; +import java.util.AbstractMap; +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +import org.json.JSONArray; +import org.json.JSONException; +import org.json.JSONObject; + +public class OllamaClient { + + public static String OLLAMA_API_HOST = "http://localhost:11434"; + + private String hoststub; + + public OllamaClient(String hoststub) { + this.hoststub = hoststub; + } + + public LinkedHashMap listOllamaModels() { + LinkedHashMap sortedMap = new LinkedHashMap<>(); + try { + String response = OpenAIClient.sendGetRequest(this.hoststub + "/api/tags"); + JSONObject responseObject = new JSONObject(response); + JSONArray models = responseObject.getJSONArray("models"); + + List> list = new ArrayList<>(); + for (int i = 0; i < models.length(); i++) { + JSONObject model = models.getJSONObject(i); + String name = model.optString("name", ""); + long size = model.optLong("size", 0); + list.add(new AbstractMap.SimpleEntry(name, size)); + } + + // Sort the list in descending order based on the values + list.sort((o1, o2) -> o2.getValue().compareTo(o1.getValue())); + + // Create a new LinkedHashMap and add the sorted entries + for (Map.Entry entry : list) { + sortedMap.put(entry.getKey(), entry.getValue()); + } + } catch (JSONException | URISyntaxException | IOException e) { + e.printStackTrace(); + } + return sortedMap; + } + + public boolean ollamaModelExists(String name) { + JSONObject data = new JSONObject(); + try { + data.put("name", name); + OpenAIClient.sendPostRequest(this.hoststub + "/api/show", data); + return true; + } catch (JSONException | URISyntaxException | IOException e) { + return false; + } + } + + public boolean pullOllamaModel(String name) { + JSONObject data = new JSONObject(); + try { + data.put("name", name); + data.put("stream", false); + String response = OpenAIClient.sendPostRequest(this.hoststub + "/api/pull", data); + // this sends {"status": "success"} in case of success + JSONObject responseObject = new JSONObject(response); + String status = responseObject.optString("status", ""); + return status.equals("success"); + } catch (JSONException | URISyntaxException | IOException e) { + return false; + } + } + + public static void main(String[] args) { + OllamaClient oc = new OllamaClient(OLLAMA_API_HOST); + + LinkedHashMap models = oc.listOllamaModels(); + System.out.println(models.toString()); + + // check if model exists + String model = "phi3:3.8b"; + if (oc.ollamaModelExists(model)) + System.out.println("model " + model + " exists"); + else + System.out.println("model " + model + " does not exist"); + + // pull a model + boolean success = oc.pullOllamaModel(model); + System.out.println("pulled model + " + model + ": " + success); + + } +} diff --git a/source/net/yacy/ai/OpenAIClient.java b/source/net/yacy/ai/OpenAIClient.java new file mode 100644 index 000000000..d823426ec --- /dev/null +++ b/source/net/yacy/ai/OpenAIClient.java @@ -0,0 +1,165 @@ +/** + * OpenAIClient + * Copyright 2024 by Michael Peter Christen + * First released 17.05.2024 at http://yacy.net + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public License + * along with this program in the file lgpl21.txt + * If not, see . + */ + +package net.yacy.ai; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStreamReader; +import java.io.OutputStream; +import java.net.HttpURLConnection; +import java.net.URI; +import java.net.URISyntaxException; +import java.net.URL; + +import org.json.JSONArray; +import org.json.JSONException; +import org.json.JSONObject; + + +public class OpenAIClient { + + private static String[] STOPTOKENS = new String[]{"[/INST]", "<|im_end|>", "<|end_of_turn|>", "<|eot_id|>", "<|end_header_id|>", "", "", "<|end|>"}; + + private String hoststub; + + public OpenAIClient(String hoststub) { + this.hoststub = hoststub; + } + + + // API Helper Methods + + public static String sendPostRequest(String endpoint, JSONObject data) throws IOException, URISyntaxException { + URL url = new URI(endpoint).toURL(); + HttpURLConnection conn = (HttpURLConnection) url.openConnection(); + conn.setRequestMethod("POST"); + conn.setRequestProperty("Content-Type", "application/json"); + conn.setDoOutput(true); + + try (OutputStream os = conn.getOutputStream()) { + byte[] input = data.toString().getBytes("utf-8"); + os.write(input, 0, input.length); + } + + int responseCode = conn.getResponseCode(); + if (responseCode == HttpURLConnection.HTTP_OK) { + try (BufferedReader br = new BufferedReader(new InputStreamReader(conn.getInputStream(), "utf-8"))) { + StringBuilder response = new StringBuilder(); + String responseLine; + while ((responseLine = br.readLine()) != null) { + response.append(responseLine.trim()); + } + return response.toString(); + } + } else { + throw new IOException("Request failed with response code " + responseCode); + } + } + + public static String sendGetRequest(String endpoint) throws IOException, URISyntaxException { + URL url = new URI(endpoint).toURL(); + HttpURLConnection conn = (HttpURLConnection) url.openConnection(); + conn.setRequestMethod("GET"); + + int responseCode = conn.getResponseCode(); + if (responseCode == HttpURLConnection.HTTP_OK) { + try (BufferedReader br = new BufferedReader(new InputStreamReader(conn.getInputStream(), "utf-8"))) { + StringBuilder response = new StringBuilder(); + String responseLine; + while ((responseLine = br.readLine()) != null) { + response.append(responseLine.trim()); + } + return response.toString(); + } + } else { + throw new IOException("Request failed with response code " + responseCode); + } + } + + // OpenAI chat client, works with llama.cpp and Ollama + + public String chat(String model, String prompt, int max_tokens) throws IOException { + JSONObject data = new JSONObject(); + JSONArray messages = new JSONArray(); + JSONObject systemPrompt = new JSONObject(true); + JSONObject userPrompt = new JSONObject(true); + messages.put(systemPrompt); + messages.put(userPrompt); + try { + systemPrompt.put("role", "system"); + systemPrompt.put("content", "Make short answers."); + userPrompt.put("role", "user"); + userPrompt.put("content", prompt); + data.put("model", model); + data.put("temperature", 0.1); + data.put("max_tokens", max_tokens); + data.put("messages", messages); + data.put("stop", new JSONArray(STOPTOKENS)); + data.put("stream", false); + String response = sendPostRequest(this.hoststub + "/v1/chat/completions", data); + JSONObject responseObject = new JSONObject(response); + JSONArray choices = responseObject.getJSONArray("choices"); + JSONObject choice = choices.getJSONObject(0); + JSONObject message = choice.getJSONObject("message"); + String content = message.optString("content", ""); + return content; + } catch (JSONException | URISyntaxException e) { + throw new IOException(e.getMessage()); + } + } + + public static String[] stringsFromChat(String answer) { + int p = answer.indexOf('['); + int q = answer.indexOf(']'); + if (p < 0 || q < 0 || q < p) return new String[0]; + try { + JSONArray a = new JSONArray(answer.substring(p, q + 1)); + String[] arr = new String[a.length()]; + for (int i = 0; i < a.length(); i++) arr[i] = a.getString(i); + return arr; + } catch (JSONException e) { + return new String[0]; + } + } + + public static void main(String[] args) { + String model = "phi3:3.8b"; + OpenAIClient oaic = new OpenAIClient(OllamaClient.OLLAMA_API_HOST); + // make chat completion with model + String question = "Who invented the wheel?"; + try { + String answer = oaic.chat(model, question, 80); + System.out.println(answer); + } catch (IOException e) { + e.printStackTrace(); + } + + // try the json parser from chat results + question = "Make a list of four names from Star Wars movies. Use a JSON Array."; + try { + String[] a = stringsFromChat(oaic.chat(model, question, 80)); + for (String s: a) System.out.println(s); + } catch (IOException e) { + e.printStackTrace(); + } + } + +} diff --git a/source/net/yacy/http/servlets/RAGProxyServlet.java b/source/net/yacy/http/servlets/RAGProxyServlet.java index 573c1cc2d..aa5c708ab 100644 --- a/source/net/yacy/http/servlets/RAGProxyServlet.java +++ b/source/net/yacy/http/servlets/RAGProxyServlet.java @@ -24,6 +24,8 @@ import org.json.JSONArray; import org.json.JSONException; import org.json.JSONObject; +import net.yacy.ai.OllamaClient; +import net.yacy.ai.OpenAIClient; import net.yacy.cora.federate.solr.connector.EmbeddedSolrConnector; import net.yacy.search.Switchboard; import net.yacy.search.schema.CollectionSchema; @@ -63,7 +65,6 @@ import java.util.Map; public class RAGProxyServlet extends HttpServlet { private static final long serialVersionUID = 3411544789759603107L; - private static String[] STOPTOKENS = new String[]{"[/INST]", "<|im_end|>", "<|end_of_turn|>", "<|eot_id|>", "<|end_header_id|>", "", "", "<|end|>"}; private static Boolean LLM_ENABLED = false; private static Boolean LLM_CONTROL_OLLAMA = true; @@ -118,7 +119,6 @@ public class RAGProxyServlet extends HttpServlet { try { // get system message and user prompt bodyObject = new JSONObject(body); - String model = bodyObject.optString("model", LLM_ANSWER_MODEL); // we need a switch to allow overwriting JSONArray messages = bodyObject.optJSONArray("messages"); JSONObject systemObject = messages.getJSONObject(0); String system = systemObject.optString("content", ""); // the system prompt @@ -161,7 +161,7 @@ public class RAGProxyServlet extends HttpServlet { // write back response of the back-end service to the client; use status of backend-response int status = conn.getResponseCode(); - String rmessage = conn.getResponseMessage(); + //String rmessage = conn.getResponseMessage(); hresponse.setStatus(status); if (status == 200) { @@ -179,6 +179,51 @@ public class RAGProxyServlet extends HttpServlet { throw new IOException(e.getMessage()); } } + + public static LinkedHashMap searchResults(String query, int count) { + Switchboard sb = Switchboard.getSwitchboard(); + EmbeddedSolrConnector connector = sb.index.fulltext().getDefaultEmbeddedConnector(); + // construct query + final SolrQuery params = new SolrQuery(); + params.setQuery(CollectionSchema.text_t.getSolrFieldName() + ":" + query); + params.setRows(count); + params.setStart(0); + params.setFacet(false); + params.clearSorts(); + params.setFields(CollectionSchema.sku.getSolrFieldName(), CollectionSchema.text_t.getSolrFieldName()); + params.setIncludeScore(false); + params.set("df", CollectionSchema.text_t.getSolrFieldName()); + + // query the server + try { + final SolrDocumentList sdl = connector.getDocumentListByParams(params); + LinkedHashMap a = new LinkedHashMap(); + Iterator i = sdl.iterator(); + while (i.hasNext()) { + SolrDocument doc = i.next(); + String url = (String) doc.getFieldValue(CollectionSchema.sku.getSolrFieldName()); + String text = (String) doc.getFieldValue(CollectionSchema.text_t.getSolrFieldName()); + a.put(url, text); + } + return a; + } catch (SolrException | IOException e) { + return new LinkedHashMap(); + } + } + + private String searchWordsForPrompt(String model, String prompt) { + StringBuilder query = new StringBuilder(); + String question = "Make a list of a maximum of four search words for the following question; use a JSON Array: " + prompt; + try { + OpenAIClient oaic = new OpenAIClient(LLM_API_HOST); + String[] a = OpenAIClient.stringsFromChat(oaic.chat(model, question, 80)); + for (String s: a) query.append(s).append(' '); + return query.toString().trim(); + } catch (IOException e) { + e.printStackTrace(); + return ""; + } + } private static JSONObject responseLine(String payload) { JSONObject j = new JSONObject(true); @@ -201,234 +246,5 @@ public class RAGProxyServlet extends HttpServlet { } catch (JSONException e) {} return j; } - - // API Helper Methods for Ollama - - private static String sendPostRequest(String endpoint, JSONObject data) throws IOException, URISyntaxException { - URL url = new URI(endpoint).toURL(); - HttpURLConnection conn = (HttpURLConnection) url.openConnection(); - conn.setRequestMethod("POST"); - conn.setRequestProperty("Content-Type", "application/json"); - conn.setDoOutput(true); - - try (OutputStream os = conn.getOutputStream()) { - byte[] input = data.toString().getBytes("utf-8"); - os.write(input, 0, input.length); - } - - int responseCode = conn.getResponseCode(); - if (responseCode == HttpURLConnection.HTTP_OK) { - try (BufferedReader br = new BufferedReader(new InputStreamReader(conn.getInputStream(), "utf-8"))) { - StringBuilder response = new StringBuilder(); - String responseLine; - while ((responseLine = br.readLine()) != null) { - response.append(responseLine.trim()); - } - return response.toString(); - } - } else { - throw new IOException("Request failed with response code " + responseCode); - } - } - - private static String sendGetRequest(String endpoint) throws IOException, URISyntaxException { - URL url = new URI(endpoint).toURL(); - HttpURLConnection conn = (HttpURLConnection) url.openConnection(); - conn.setRequestMethod("GET"); - - int responseCode = conn.getResponseCode(); - if (responseCode == HttpURLConnection.HTTP_OK) { - try (BufferedReader br = new BufferedReader(new InputStreamReader(conn.getInputStream(), "utf-8"))) { - StringBuilder response = new StringBuilder(); - String responseLine; - while ((responseLine = br.readLine()) != null) { - response.append(responseLine.trim()); - } - return response.toString(); - } - } else { - throw new IOException("Request failed with response code " + responseCode); - } - } - - // OpenAI chat client, works also with llama.cpp and Ollama - - public static String chat(String model, String prompt, int max_tokens) throws IOException { - JSONObject data = new JSONObject(); - JSONArray messages = new JSONArray(); - JSONObject systemPrompt = new JSONObject(true); - JSONObject userPrompt = new JSONObject(true); - messages.put(systemPrompt); - messages.put(userPrompt); - try { - systemPrompt.put("role", "system"); - systemPrompt.put("content", "Make short answers."); - userPrompt.put("role", "user"); - userPrompt.put("content", prompt); - data.put("model", model); - data.put("temperature", 0.1); - data.put("max_tokens", max_tokens); - data.put("messages", messages); - data.put("stop", new JSONArray(STOPTOKENS)); - data.put("stream", false); - String response = sendPostRequest(LLM_API_HOST + "/v1/chat/completions", data); - JSONObject responseObject = new JSONObject(response); - JSONArray choices = responseObject.getJSONArray("choices"); - JSONObject choice = choices.getJSONObject(0); - JSONObject message = choice.getJSONObject("message"); - String content = message.optString("content", ""); - return content; - } catch (JSONException | URISyntaxException e) { - throw new IOException(e.getMessage()); - } - } - - public static String[] stringsFromChat(String answer) { - int p = answer.indexOf('['); - int q = answer.indexOf(']'); - if (p < 0 || q < 0 || q < p) return new String[0]; - try { - JSONArray a = new JSONArray(answer.substring(p, q + 1)); - String[] arr = new String[a.length()]; - for (int i = 0; i < a.length(); i++) arr[i] = a.getString(i); - return arr; - } catch (JSONException e) { - return new String[0]; - } - } - - private static String searchWordsForPrompt(String model, String prompt) { - StringBuilder query = new StringBuilder(); - String question = "Make a list of a maximum of four search words for the following question; use a JSON Array: " + prompt; - try { - String[] a = stringsFromChat(chat(model, question, 80)); - for (String s: a) query.append(s).append(' '); - return query.toString().trim(); - } catch (IOException e) { - e.printStackTrace(); - return ""; - } - } - - private static LinkedHashMap searchResults(String query, int count) { - Switchboard sb = Switchboard.getSwitchboard(); - EmbeddedSolrConnector connector = sb.index.fulltext().getDefaultEmbeddedConnector(); - // construct query - final SolrQuery params = new SolrQuery(); - params.setQuery(CollectionSchema.text_t.getSolrFieldName() + ":" + query); - params.setRows(count); - params.setStart(0); - params.setFacet(false); - params.clearSorts(); - params.setFields(CollectionSchema.sku.getSolrFieldName(), CollectionSchema.text_t.getSolrFieldName()); - params.setIncludeScore(false); - params.set("df", CollectionSchema.text_t.getSolrFieldName()); - - // query the server - try { - final SolrDocumentList sdl = connector.getDocumentListByParams(params); - LinkedHashMap a = new LinkedHashMap(); - Iterator i = sdl.iterator(); - while (i.hasNext()) { - SolrDocument doc = i.next(); - String url = (String) doc.getFieldValue(CollectionSchema.sku.getSolrFieldName()); - String text = (String) doc.getFieldValue(CollectionSchema.text_t.getSolrFieldName()); - a.put(url, text); - } - return a; - } catch (SolrException | IOException e) { - return new LinkedHashMap(); - } - } - - // Ollama client functions - - public static LinkedHashMap listOllamaModels() { - LinkedHashMap sortedMap = new LinkedHashMap<>(); - try { - String response = sendGetRequest(LLM_API_HOST + "/api/tags"); - JSONObject responseObject = new JSONObject(response); - JSONArray models = responseObject.getJSONArray("models"); - - List> list = new ArrayList<>(); - for (int i = 0; i < models.length(); i++) { - JSONObject model = models.getJSONObject(i); - String name = model.optString("name", ""); - long size = model.optLong("size", 0); - list.add(new AbstractMap.SimpleEntry(name, size)); - } - - // Sort the list in descending order based on the values - list.sort((o1, o2) -> o2.getValue().compareTo(o1.getValue())); - - // Create a new LinkedHashMap and add the sorted entries - for (Map.Entry entry : list) { - sortedMap.put(entry.getKey(), entry.getValue()); - } - } catch (JSONException | URISyntaxException | IOException e) { - e.printStackTrace(); - } - return sortedMap; - } - - public static boolean ollamaModelExists(String name) { - JSONObject data = new JSONObject(); - try { - data.put("name", name); - sendPostRequest(LLM_API_HOST + "/api/show", data); - return true; - } catch (JSONException | URISyntaxException | IOException e) { - return false; - } - } - public static boolean pullOllamaModel(String name) { - JSONObject data = new JSONObject(); - try { - data.put("name", name); - data.put("stream", false); - String response = sendPostRequest(LLM_API_HOST + "/api/pull", data); - // this sends {"status": "success"} in case of success - JSONObject responseObject = new JSONObject(response); - String status = responseObject.optString("status", ""); - return status.equals("success"); - } catch (JSONException | URISyntaxException | IOException e) { - return false; - } - } - - public static void main(String[] args) { - LinkedHashMap models = listOllamaModels(); - System.out.println(models.toString()); - - // check if model exists - //String model = "phi3:3.8b"; - String model = "gemma:2b"; - if (ollamaModelExists(model)) - System.out.println("model " + model + " exists"); - else - System.out.println("model " + model + " does not exist"); - - // pull a model - boolean success = pullOllamaModel(model); - System.out.println("pulled model + " + model + ": " + success); - - // make chat completion with model - String question = "Who invented the wheel?"; - try { - String answer = chat(model, question, 80); - System.out.println(answer); - } catch (IOException e) { - e.printStackTrace(); - } - - // try the json parser from chat results - question = "Make a list of four names from Star Wars movies. Use a JSON Array."; - try { - String[] a = stringsFromChat(chat(model, question, 80)); - for (String s: a) System.out.println(s); - } catch (IOException e) { - e.printStackTrace(); - } - } }