diff --git a/source/net/yacy/ai/OllamaClient.java b/source/net/yacy/ai/OllamaClient.java
new file mode 100644
index 000000000..742645fee
--- /dev/null
+++ b/source/net/yacy/ai/OllamaClient.java
@@ -0,0 +1,117 @@
+/**
+ * OllamaClient
+ * Copyright 2024 by Michael Peter Christen
+ * First released 17.05.2024 at http://yacy.net
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public License
+ * along with this program in the file lgpl21.txt
+ * If not, see .
+ */
+
+package net.yacy.ai;
+
+import java.io.IOException;
+import java.net.URISyntaxException;
+import java.util.AbstractMap;
+import java.util.ArrayList;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+
+import org.json.JSONArray;
+import org.json.JSONException;
+import org.json.JSONObject;
+
+public class OllamaClient {
+
+ public static String OLLAMA_API_HOST = "http://localhost:11434";
+
+ private String hoststub;
+
+ public OllamaClient(String hoststub) {
+ this.hoststub = hoststub;
+ }
+
+ public LinkedHashMap listOllamaModels() {
+ LinkedHashMap sortedMap = new LinkedHashMap<>();
+ try {
+ String response = OpenAIClient.sendGetRequest(this.hoststub + "/api/tags");
+ JSONObject responseObject = new JSONObject(response);
+ JSONArray models = responseObject.getJSONArray("models");
+
+ List> list = new ArrayList<>();
+ for (int i = 0; i < models.length(); i++) {
+ JSONObject model = models.getJSONObject(i);
+ String name = model.optString("name", "");
+ long size = model.optLong("size", 0);
+ list.add(new AbstractMap.SimpleEntry(name, size));
+ }
+
+ // Sort the list in descending order based on the values
+ list.sort((o1, o2) -> o2.getValue().compareTo(o1.getValue()));
+
+ // Create a new LinkedHashMap and add the sorted entries
+ for (Map.Entry entry : list) {
+ sortedMap.put(entry.getKey(), entry.getValue());
+ }
+ } catch (JSONException | URISyntaxException | IOException e) {
+ e.printStackTrace();
+ }
+ return sortedMap;
+ }
+
+ public boolean ollamaModelExists(String name) {
+ JSONObject data = new JSONObject();
+ try {
+ data.put("name", name);
+ OpenAIClient.sendPostRequest(this.hoststub + "/api/show", data);
+ return true;
+ } catch (JSONException | URISyntaxException | IOException e) {
+ return false;
+ }
+ }
+
+ public boolean pullOllamaModel(String name) {
+ JSONObject data = new JSONObject();
+ try {
+ data.put("name", name);
+ data.put("stream", false);
+ String response = OpenAIClient.sendPostRequest(this.hoststub + "/api/pull", data);
+ // this sends {"status": "success"} in case of success
+ JSONObject responseObject = new JSONObject(response);
+ String status = responseObject.optString("status", "");
+ return status.equals("success");
+ } catch (JSONException | URISyntaxException | IOException e) {
+ return false;
+ }
+ }
+
+ public static void main(String[] args) {
+ OllamaClient oc = new OllamaClient(OLLAMA_API_HOST);
+
+ LinkedHashMap models = oc.listOllamaModels();
+ System.out.println(models.toString());
+
+ // check if model exists
+ String model = "phi3:3.8b";
+ if (oc.ollamaModelExists(model))
+ System.out.println("model " + model + " exists");
+ else
+ System.out.println("model " + model + " does not exist");
+
+ // pull a model
+ boolean success = oc.pullOllamaModel(model);
+ System.out.println("pulled model + " + model + ": " + success);
+
+ }
+}
diff --git a/source/net/yacy/ai/OpenAIClient.java b/source/net/yacy/ai/OpenAIClient.java
new file mode 100644
index 000000000..d823426ec
--- /dev/null
+++ b/source/net/yacy/ai/OpenAIClient.java
@@ -0,0 +1,165 @@
+/**
+ * OpenAIClient
+ * Copyright 2024 by Michael Peter Christen
+ * First released 17.05.2024 at http://yacy.net
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public License
+ * along with this program in the file lgpl21.txt
+ * If not, see .
+ */
+
+package net.yacy.ai;
+
+import java.io.BufferedReader;
+import java.io.IOException;
+import java.io.InputStreamReader;
+import java.io.OutputStream;
+import java.net.HttpURLConnection;
+import java.net.URI;
+import java.net.URISyntaxException;
+import java.net.URL;
+
+import org.json.JSONArray;
+import org.json.JSONException;
+import org.json.JSONObject;
+
+
+public class OpenAIClient {
+
+ private static String[] STOPTOKENS = new String[]{"[/INST]", "<|im_end|>", "<|end_of_turn|>", "<|eot_id|>", "<|end_header_id|>", "", "", "<|end|>"};
+
+ private String hoststub;
+
+ public OpenAIClient(String hoststub) {
+ this.hoststub = hoststub;
+ }
+
+
+ // API Helper Methods
+
+ public static String sendPostRequest(String endpoint, JSONObject data) throws IOException, URISyntaxException {
+ URL url = new URI(endpoint).toURL();
+ HttpURLConnection conn = (HttpURLConnection) url.openConnection();
+ conn.setRequestMethod("POST");
+ conn.setRequestProperty("Content-Type", "application/json");
+ conn.setDoOutput(true);
+
+ try (OutputStream os = conn.getOutputStream()) {
+ byte[] input = data.toString().getBytes("utf-8");
+ os.write(input, 0, input.length);
+ }
+
+ int responseCode = conn.getResponseCode();
+ if (responseCode == HttpURLConnection.HTTP_OK) {
+ try (BufferedReader br = new BufferedReader(new InputStreamReader(conn.getInputStream(), "utf-8"))) {
+ StringBuilder response = new StringBuilder();
+ String responseLine;
+ while ((responseLine = br.readLine()) != null) {
+ response.append(responseLine.trim());
+ }
+ return response.toString();
+ }
+ } else {
+ throw new IOException("Request failed with response code " + responseCode);
+ }
+ }
+
+ public static String sendGetRequest(String endpoint) throws IOException, URISyntaxException {
+ URL url = new URI(endpoint).toURL();
+ HttpURLConnection conn = (HttpURLConnection) url.openConnection();
+ conn.setRequestMethod("GET");
+
+ int responseCode = conn.getResponseCode();
+ if (responseCode == HttpURLConnection.HTTP_OK) {
+ try (BufferedReader br = new BufferedReader(new InputStreamReader(conn.getInputStream(), "utf-8"))) {
+ StringBuilder response = new StringBuilder();
+ String responseLine;
+ while ((responseLine = br.readLine()) != null) {
+ response.append(responseLine.trim());
+ }
+ return response.toString();
+ }
+ } else {
+ throw new IOException("Request failed with response code " + responseCode);
+ }
+ }
+
+ // OpenAI chat client, works with llama.cpp and Ollama
+
+ public String chat(String model, String prompt, int max_tokens) throws IOException {
+ JSONObject data = new JSONObject();
+ JSONArray messages = new JSONArray();
+ JSONObject systemPrompt = new JSONObject(true);
+ JSONObject userPrompt = new JSONObject(true);
+ messages.put(systemPrompt);
+ messages.put(userPrompt);
+ try {
+ systemPrompt.put("role", "system");
+ systemPrompt.put("content", "Make short answers.");
+ userPrompt.put("role", "user");
+ userPrompt.put("content", prompt);
+ data.put("model", model);
+ data.put("temperature", 0.1);
+ data.put("max_tokens", max_tokens);
+ data.put("messages", messages);
+ data.put("stop", new JSONArray(STOPTOKENS));
+ data.put("stream", false);
+ String response = sendPostRequest(this.hoststub + "/v1/chat/completions", data);
+ JSONObject responseObject = new JSONObject(response);
+ JSONArray choices = responseObject.getJSONArray("choices");
+ JSONObject choice = choices.getJSONObject(0);
+ JSONObject message = choice.getJSONObject("message");
+ String content = message.optString("content", "");
+ return content;
+ } catch (JSONException | URISyntaxException e) {
+ throw new IOException(e.getMessage());
+ }
+ }
+
+ public static String[] stringsFromChat(String answer) {
+ int p = answer.indexOf('[');
+ int q = answer.indexOf(']');
+ if (p < 0 || q < 0 || q < p) return new String[0];
+ try {
+ JSONArray a = new JSONArray(answer.substring(p, q + 1));
+ String[] arr = new String[a.length()];
+ for (int i = 0; i < a.length(); i++) arr[i] = a.getString(i);
+ return arr;
+ } catch (JSONException e) {
+ return new String[0];
+ }
+ }
+
+ public static void main(String[] args) {
+ String model = "phi3:3.8b";
+ OpenAIClient oaic = new OpenAIClient(OllamaClient.OLLAMA_API_HOST);
+ // make chat completion with model
+ String question = "Who invented the wheel?";
+ try {
+ String answer = oaic.chat(model, question, 80);
+ System.out.println(answer);
+ } catch (IOException e) {
+ e.printStackTrace();
+ }
+
+ // try the json parser from chat results
+ question = "Make a list of four names from Star Wars movies. Use a JSON Array.";
+ try {
+ String[] a = stringsFromChat(oaic.chat(model, question, 80));
+ for (String s: a) System.out.println(s);
+ } catch (IOException e) {
+ e.printStackTrace();
+ }
+ }
+
+}
diff --git a/source/net/yacy/http/servlets/RAGProxyServlet.java b/source/net/yacy/http/servlets/RAGProxyServlet.java
index 573c1cc2d..aa5c708ab 100644
--- a/source/net/yacy/http/servlets/RAGProxyServlet.java
+++ b/source/net/yacy/http/servlets/RAGProxyServlet.java
@@ -24,6 +24,8 @@ import org.json.JSONArray;
import org.json.JSONException;
import org.json.JSONObject;
+import net.yacy.ai.OllamaClient;
+import net.yacy.ai.OpenAIClient;
import net.yacy.cora.federate.solr.connector.EmbeddedSolrConnector;
import net.yacy.search.Switchboard;
import net.yacy.search.schema.CollectionSchema;
@@ -63,7 +65,6 @@ import java.util.Map;
public class RAGProxyServlet extends HttpServlet {
private static final long serialVersionUID = 3411544789759603107L;
- private static String[] STOPTOKENS = new String[]{"[/INST]", "<|im_end|>", "<|end_of_turn|>", "<|eot_id|>", "<|end_header_id|>", "", "", "<|end|>"};
private static Boolean LLM_ENABLED = false;
private static Boolean LLM_CONTROL_OLLAMA = true;
@@ -118,7 +119,6 @@ public class RAGProxyServlet extends HttpServlet {
try {
// get system message and user prompt
bodyObject = new JSONObject(body);
- String model = bodyObject.optString("model", LLM_ANSWER_MODEL); // we need a switch to allow overwriting
JSONArray messages = bodyObject.optJSONArray("messages");
JSONObject systemObject = messages.getJSONObject(0);
String system = systemObject.optString("content", ""); // the system prompt
@@ -161,7 +161,7 @@ public class RAGProxyServlet extends HttpServlet {
// write back response of the back-end service to the client; use status of backend-response
int status = conn.getResponseCode();
- String rmessage = conn.getResponseMessage();
+ //String rmessage = conn.getResponseMessage();
hresponse.setStatus(status);
if (status == 200) {
@@ -179,6 +179,51 @@ public class RAGProxyServlet extends HttpServlet {
throw new IOException(e.getMessage());
}
}
+
+ public static LinkedHashMap searchResults(String query, int count) {
+ Switchboard sb = Switchboard.getSwitchboard();
+ EmbeddedSolrConnector connector = sb.index.fulltext().getDefaultEmbeddedConnector();
+ // construct query
+ final SolrQuery params = new SolrQuery();
+ params.setQuery(CollectionSchema.text_t.getSolrFieldName() + ":" + query);
+ params.setRows(count);
+ params.setStart(0);
+ params.setFacet(false);
+ params.clearSorts();
+ params.setFields(CollectionSchema.sku.getSolrFieldName(), CollectionSchema.text_t.getSolrFieldName());
+ params.setIncludeScore(false);
+ params.set("df", CollectionSchema.text_t.getSolrFieldName());
+
+ // query the server
+ try {
+ final SolrDocumentList sdl = connector.getDocumentListByParams(params);
+ LinkedHashMap a = new LinkedHashMap();
+ Iterator i = sdl.iterator();
+ while (i.hasNext()) {
+ SolrDocument doc = i.next();
+ String url = (String) doc.getFieldValue(CollectionSchema.sku.getSolrFieldName());
+ String text = (String) doc.getFieldValue(CollectionSchema.text_t.getSolrFieldName());
+ a.put(url, text);
+ }
+ return a;
+ } catch (SolrException | IOException e) {
+ return new LinkedHashMap();
+ }
+ }
+
+ private String searchWordsForPrompt(String model, String prompt) {
+ StringBuilder query = new StringBuilder();
+ String question = "Make a list of a maximum of four search words for the following question; use a JSON Array: " + prompt;
+ try {
+ OpenAIClient oaic = new OpenAIClient(LLM_API_HOST);
+ String[] a = OpenAIClient.stringsFromChat(oaic.chat(model, question, 80));
+ for (String s: a) query.append(s).append(' ');
+ return query.toString().trim();
+ } catch (IOException e) {
+ e.printStackTrace();
+ return "";
+ }
+ }
private static JSONObject responseLine(String payload) {
JSONObject j = new JSONObject(true);
@@ -201,234 +246,5 @@ public class RAGProxyServlet extends HttpServlet {
} catch (JSONException e) {}
return j;
}
-
- // API Helper Methods for Ollama
-
- private static String sendPostRequest(String endpoint, JSONObject data) throws IOException, URISyntaxException {
- URL url = new URI(endpoint).toURL();
- HttpURLConnection conn = (HttpURLConnection) url.openConnection();
- conn.setRequestMethod("POST");
- conn.setRequestProperty("Content-Type", "application/json");
- conn.setDoOutput(true);
-
- try (OutputStream os = conn.getOutputStream()) {
- byte[] input = data.toString().getBytes("utf-8");
- os.write(input, 0, input.length);
- }
-
- int responseCode = conn.getResponseCode();
- if (responseCode == HttpURLConnection.HTTP_OK) {
- try (BufferedReader br = new BufferedReader(new InputStreamReader(conn.getInputStream(), "utf-8"))) {
- StringBuilder response = new StringBuilder();
- String responseLine;
- while ((responseLine = br.readLine()) != null) {
- response.append(responseLine.trim());
- }
- return response.toString();
- }
- } else {
- throw new IOException("Request failed with response code " + responseCode);
- }
- }
-
- private static String sendGetRequest(String endpoint) throws IOException, URISyntaxException {
- URL url = new URI(endpoint).toURL();
- HttpURLConnection conn = (HttpURLConnection) url.openConnection();
- conn.setRequestMethod("GET");
-
- int responseCode = conn.getResponseCode();
- if (responseCode == HttpURLConnection.HTTP_OK) {
- try (BufferedReader br = new BufferedReader(new InputStreamReader(conn.getInputStream(), "utf-8"))) {
- StringBuilder response = new StringBuilder();
- String responseLine;
- while ((responseLine = br.readLine()) != null) {
- response.append(responseLine.trim());
- }
- return response.toString();
- }
- } else {
- throw new IOException("Request failed with response code " + responseCode);
- }
- }
-
- // OpenAI chat client, works also with llama.cpp and Ollama
-
- public static String chat(String model, String prompt, int max_tokens) throws IOException {
- JSONObject data = new JSONObject();
- JSONArray messages = new JSONArray();
- JSONObject systemPrompt = new JSONObject(true);
- JSONObject userPrompt = new JSONObject(true);
- messages.put(systemPrompt);
- messages.put(userPrompt);
- try {
- systemPrompt.put("role", "system");
- systemPrompt.put("content", "Make short answers.");
- userPrompt.put("role", "user");
- userPrompt.put("content", prompt);
- data.put("model", model);
- data.put("temperature", 0.1);
- data.put("max_tokens", max_tokens);
- data.put("messages", messages);
- data.put("stop", new JSONArray(STOPTOKENS));
- data.put("stream", false);
- String response = sendPostRequest(LLM_API_HOST + "/v1/chat/completions", data);
- JSONObject responseObject = new JSONObject(response);
- JSONArray choices = responseObject.getJSONArray("choices");
- JSONObject choice = choices.getJSONObject(0);
- JSONObject message = choice.getJSONObject("message");
- String content = message.optString("content", "");
- return content;
- } catch (JSONException | URISyntaxException e) {
- throw new IOException(e.getMessage());
- }
- }
-
- public static String[] stringsFromChat(String answer) {
- int p = answer.indexOf('[');
- int q = answer.indexOf(']');
- if (p < 0 || q < 0 || q < p) return new String[0];
- try {
- JSONArray a = new JSONArray(answer.substring(p, q + 1));
- String[] arr = new String[a.length()];
- for (int i = 0; i < a.length(); i++) arr[i] = a.getString(i);
- return arr;
- } catch (JSONException e) {
- return new String[0];
- }
- }
-
- private static String searchWordsForPrompt(String model, String prompt) {
- StringBuilder query = new StringBuilder();
- String question = "Make a list of a maximum of four search words for the following question; use a JSON Array: " + prompt;
- try {
- String[] a = stringsFromChat(chat(model, question, 80));
- for (String s: a) query.append(s).append(' ');
- return query.toString().trim();
- } catch (IOException e) {
- e.printStackTrace();
- return "";
- }
- }
-
- private static LinkedHashMap searchResults(String query, int count) {
- Switchboard sb = Switchboard.getSwitchboard();
- EmbeddedSolrConnector connector = sb.index.fulltext().getDefaultEmbeddedConnector();
- // construct query
- final SolrQuery params = new SolrQuery();
- params.setQuery(CollectionSchema.text_t.getSolrFieldName() + ":" + query);
- params.setRows(count);
- params.setStart(0);
- params.setFacet(false);
- params.clearSorts();
- params.setFields(CollectionSchema.sku.getSolrFieldName(), CollectionSchema.text_t.getSolrFieldName());
- params.setIncludeScore(false);
- params.set("df", CollectionSchema.text_t.getSolrFieldName());
-
- // query the server
- try {
- final SolrDocumentList sdl = connector.getDocumentListByParams(params);
- LinkedHashMap a = new LinkedHashMap();
- Iterator i = sdl.iterator();
- while (i.hasNext()) {
- SolrDocument doc = i.next();
- String url = (String) doc.getFieldValue(CollectionSchema.sku.getSolrFieldName());
- String text = (String) doc.getFieldValue(CollectionSchema.text_t.getSolrFieldName());
- a.put(url, text);
- }
- return a;
- } catch (SolrException | IOException e) {
- return new LinkedHashMap();
- }
- }
-
- // Ollama client functions
-
- public static LinkedHashMap listOllamaModels() {
- LinkedHashMap sortedMap = new LinkedHashMap<>();
- try {
- String response = sendGetRequest(LLM_API_HOST + "/api/tags");
- JSONObject responseObject = new JSONObject(response);
- JSONArray models = responseObject.getJSONArray("models");
-
- List> list = new ArrayList<>();
- for (int i = 0; i < models.length(); i++) {
- JSONObject model = models.getJSONObject(i);
- String name = model.optString("name", "");
- long size = model.optLong("size", 0);
- list.add(new AbstractMap.SimpleEntry(name, size));
- }
-
- // Sort the list in descending order based on the values
- list.sort((o1, o2) -> o2.getValue().compareTo(o1.getValue()));
-
- // Create a new LinkedHashMap and add the sorted entries
- for (Map.Entry entry : list) {
- sortedMap.put(entry.getKey(), entry.getValue());
- }
- } catch (JSONException | URISyntaxException | IOException e) {
- e.printStackTrace();
- }
- return sortedMap;
- }
-
- public static boolean ollamaModelExists(String name) {
- JSONObject data = new JSONObject();
- try {
- data.put("name", name);
- sendPostRequest(LLM_API_HOST + "/api/show", data);
- return true;
- } catch (JSONException | URISyntaxException | IOException e) {
- return false;
- }
- }
- public static boolean pullOllamaModel(String name) {
- JSONObject data = new JSONObject();
- try {
- data.put("name", name);
- data.put("stream", false);
- String response = sendPostRequest(LLM_API_HOST + "/api/pull", data);
- // this sends {"status": "success"} in case of success
- JSONObject responseObject = new JSONObject(response);
- String status = responseObject.optString("status", "");
- return status.equals("success");
- } catch (JSONException | URISyntaxException | IOException e) {
- return false;
- }
- }
-
- public static void main(String[] args) {
- LinkedHashMap models = listOllamaModels();
- System.out.println(models.toString());
-
- // check if model exists
- //String model = "phi3:3.8b";
- String model = "gemma:2b";
- if (ollamaModelExists(model))
- System.out.println("model " + model + " exists");
- else
- System.out.println("model " + model + " does not exist");
-
- // pull a model
- boolean success = pullOllamaModel(model);
- System.out.println("pulled model + " + model + ": " + success);
-
- // make chat completion with model
- String question = "Who invented the wheel?";
- try {
- String answer = chat(model, question, 80);
- System.out.println(answer);
- } catch (IOException e) {
- e.printStackTrace();
- }
-
- // try the json parser from chat results
- question = "Make a list of four names from Star Wars movies. Use a JSON Array.";
- try {
- String[] a = stringsFromChat(chat(model, question, 80));
- for (String s: a) System.out.println(s);
- } catch (IOException e) {
- e.printStackTrace();
- }
- }
}