diff --git a/defaults/web.xml b/defaults/web.xml index 9a6fbec39..43a500cc1 100644 --- a/defaults/web.xml +++ b/defaults/web.xml @@ -50,6 +50,11 @@ net.yacy.http.servlets.SolrServlet + + RAGProxyServlet + net.yacy.http.servlets.RAGProxyServlet + + URLProxyServlet @@ -81,7 +86,12 @@ /solr/webgraph/admin/luke - + + + + RAGProxyServlet + /v1/chat/completions + diff --git a/source/net/yacy/http/servlets/RAGProxyServlet.java b/source/net/yacy/http/servlets/RAGProxyServlet.java new file mode 100644 index 000000000..573c1cc2d --- /dev/null +++ b/source/net/yacy/http/servlets/RAGProxyServlet.java @@ -0,0 +1,434 @@ +/** + * RAGProxyServlet + * Copyright 2024 by Michael Peter Christen + * First released 17.05.2024 at http://yacy.net + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public License + * along with this program in the file lgpl21.txt + * If not, see . + */ + +package net.yacy.http.servlets; + +import org.json.JSONArray; +import org.json.JSONException; +import org.json.JSONObject; + +import net.yacy.cora.federate.solr.connector.EmbeddedSolrConnector; +import net.yacy.search.Switchboard; +import net.yacy.search.schema.CollectionSchema; + +import org.apache.solr.client.solrj.SolrQuery; +import org.apache.solr.common.SolrDocument; +import org.apache.solr.common.SolrDocumentList; +import org.apache.solr.common.SolrException; +import org.apache.solr.servlet.cache.Method; + +import javax.servlet.ServletException; +import javax.servlet.ServletOutputStream; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; +import javax.servlet.http.HttpServlet; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import java.io.IOException; +import java.io.BufferedReader; +import java.io.InputStreamReader; +import java.io.OutputStream; +import java.net.HttpURLConnection; +import java.net.URI; +import java.net.URISyntaxException; +import java.net.URL; +import java.util.AbstractMap; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +/** + * This class implements a Retrieval Augmented Generation ("RAG") proxy which uses a YaCy search index + * to enrich a chat with search results. The + */ +public class RAGProxyServlet extends HttpServlet { + + private static final long serialVersionUID = 3411544789759603107L; + private static String[] STOPTOKENS = new String[]{"[/INST]", "<|im_end|>", "<|end_of_turn|>", "<|eot_id|>", "<|end_header_id|>", "", "", "<|end|>"}; + + private static Boolean LLM_ENABLED = false; + private static Boolean LLM_CONTROL_OLLAMA = true; + private static Boolean LLM_ATTACH_QUERY = false; // instructs the proxy to attach the prompt generated to do the RAG search + private static Boolean LLM_ATTACH_REFERENCES = false; // instructs the proxy to attach a list of sources that had been used in RAG + private static String LLM_LANGUAGE = "en"; // used to select proper language in RAG augmentation + private static String LLM_SYSTEM_PREFIX = "\n\nYou may receive additional expert knowledge in the user prompt after a 'Additional Information' headline to enhance your knowledge. Use it only if applicable."; + private static String LLM_USER_PREFIX = "\n\nAdditional Information:\n\nbelow you find a collection of texts that might be useful to generate a response. Do not discuss these documents, just use them to answer the question above.\n\n"; + private static String LLM_API_HOST = "http://localhost:11434"; // Ollama port; install ollama from https://ollama.com/ + private static String LLM_QUERY_MODEL = "phi3:3.8b"; + private static String LLM_ANSWER_MODEL = "llama3:8b"; // or "phi3:3.8b" i.e. on a Raspberry Pi 5 + private static Boolean LLM_API_MODEL_OVERWRITING = true; // if true, the value configured in YaCy overwrites the client model + private static String LLM_API_KEY = ""; // not required; option to use this class to use a OpenAI API + + @Override + public void service(ServletRequest request, ServletResponse response) throws IOException, ServletException { + response.setContentType("application/json;charset=utf-8"); + + HttpServletResponse hresponse = (HttpServletResponse) response; + HttpServletRequest hrequest = (HttpServletRequest) request; + + // Add CORS headers + hresponse.setHeader("Access-Control-Allow-Origin", "*"); + hresponse.setHeader("Access-Control-Allow-Methods", "POST, GET, OPTIONS, DELETE"); + hresponse.setHeader("Access-Control-Allow-Headers", "Content-Type, Authorization"); + + final Method reqMethod = Method.getMethod(hrequest.getMethod()); + if (reqMethod == Method.OTHER) { + // required to handle CORS + hresponse.setStatus(HttpServletResponse.SC_OK); + return; + } + + // We expect a POST request + if (reqMethod != Method.POST) { + hresponse.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED); + return; + } + + // get the output stream early to be able to generate messages to the user before the actual retrieval starts + ServletOutputStream out = response.getOutputStream(); + + // read the body of the request and parse it as JSON + BufferedReader reader = request.getReader(); + StringBuilder bodyBuilder = new StringBuilder(); + String line; + while ((line = reader.readLine()) != null) { + bodyBuilder.append(line); + } + String body = bodyBuilder.toString(); + JSONObject bodyObject; + try { + // get system message and user prompt + bodyObject = new JSONObject(body); + String model = bodyObject.optString("model", LLM_ANSWER_MODEL); // we need a switch to allow overwriting + JSONArray messages = bodyObject.optJSONArray("messages"); + JSONObject systemObject = messages.getJSONObject(0); + String system = systemObject.optString("content", ""); // the system prompt + JSONObject userObject = messages.getJSONObject(messages.length() - 1); + String user = userObject.optString("content", ""); // this is the latest prompt + + // modify system and user prompt here in bodyObject to enable RAG + String query = searchWordsForPrompt(LLM_QUERY_MODEL, user); + out.print(responseLine("Searching for '" + query + "'\n\n").toString() + "\n"); out.flush(); + LinkedHashMap searchResults = searchResults(query, 4); + out.print(responseLine("Using the following sources for RAG:\n\n").toString() + "\n"); out.flush(); + for (String s: searchResults.keySet()) {out.print(responseLine("- `" + s + "`\n").toString() + "\n"); out.flush();} + out.print(responseLine("\n").toString()); out.flush(); + system += LLM_SYSTEM_PREFIX; + user += LLM_USER_PREFIX; + for (String s: searchResults.values()) user += s + "\n\n"; + systemObject.put("content", system); + userObject.put("content", user); + + if (LLM_API_MODEL_OVERWRITING) bodyObject.put("model", LLM_ANSWER_MODEL); + + // write back modified bodyMap to body + body = bodyObject.toString(); + + // Open request to back-end service + URL url = new URI(LLM_API_HOST + "/v1/chat/completions").toURL(); + HttpURLConnection conn = (HttpURLConnection) url.openConnection(); + conn.setRequestMethod("POST"); + conn.setRequestProperty("Content-Type", "application/json"); + if (!LLM_API_KEY.isEmpty()) { + conn.setRequestProperty("Authorization", "Bearer " + LLM_API_KEY); + } + conn.setDoOutput(true); + + // write the body to back-end LLM + try (OutputStream os = conn.getOutputStream()) { + os.write(body.getBytes()); + os.flush(); + } + + // write back response of the back-end service to the client; use status of backend-response + int status = conn.getResponseCode(); + String rmessage = conn.getResponseMessage(); + hresponse.setStatus(status); + + if (status == 200) { + // read the response of the back-end line-by-line and write it to the client line-by-line + BufferedReader in = new BufferedReader(new InputStreamReader(conn.getInputStream())); + String inputLine; + while ((inputLine = in.readLine()) != null) { + out.print(inputLine); // i.e. data: {"id":"chatcmpl-69","object":"chat.completion.chunk","created":1715908287,"model":"llama3:8b","system_fingerprint":"fp_ollama","choices":[{"index":0,"delta":{"role":"assistant","content":"ߘŠ"},"finish_reason":null}]} + out.flush(); + } + in.close(); + } + out.close(); // close this here to end transmission + } catch (JSONException | URISyntaxException e) { + throw new IOException(e.getMessage()); + } + } + + private static JSONObject responseLine(String payload) { + JSONObject j = new JSONObject(true); + try { + j.put("id", "log"); + j.put("object", "chat.completion.chunk"); + j.put("created", System.currentTimeMillis() / 1000); + j.put("model", "log"); + j.put("system_fingerprint", "YaCy"); + JSONArray choices = new JSONArray(); + JSONObject choice = new JSONObject(true); // {"index":0,"delta":{"role":"assistant","content":"ߘŠ" + choice.put("index", 0); + JSONObject delta = new JSONObject(true); + delta.put("role", "assistant"); + delta.put("content", payload); + choice.put("delta", delta); + choices.put(choice); + j.put("choices", choices); + //j.put("finish_reason", null); // this is problematic with the JSON library + } catch (JSONException e) {} + return j; + } + + // API Helper Methods for Ollama + + private static String sendPostRequest(String endpoint, JSONObject data) throws IOException, URISyntaxException { + URL url = new URI(endpoint).toURL(); + HttpURLConnection conn = (HttpURLConnection) url.openConnection(); + conn.setRequestMethod("POST"); + conn.setRequestProperty("Content-Type", "application/json"); + conn.setDoOutput(true); + + try (OutputStream os = conn.getOutputStream()) { + byte[] input = data.toString().getBytes("utf-8"); + os.write(input, 0, input.length); + } + + int responseCode = conn.getResponseCode(); + if (responseCode == HttpURLConnection.HTTP_OK) { + try (BufferedReader br = new BufferedReader(new InputStreamReader(conn.getInputStream(), "utf-8"))) { + StringBuilder response = new StringBuilder(); + String responseLine; + while ((responseLine = br.readLine()) != null) { + response.append(responseLine.trim()); + } + return response.toString(); + } + } else { + throw new IOException("Request failed with response code " + responseCode); + } + } + + private static String sendGetRequest(String endpoint) throws IOException, URISyntaxException { + URL url = new URI(endpoint).toURL(); + HttpURLConnection conn = (HttpURLConnection) url.openConnection(); + conn.setRequestMethod("GET"); + + int responseCode = conn.getResponseCode(); + if (responseCode == HttpURLConnection.HTTP_OK) { + try (BufferedReader br = new BufferedReader(new InputStreamReader(conn.getInputStream(), "utf-8"))) { + StringBuilder response = new StringBuilder(); + String responseLine; + while ((responseLine = br.readLine()) != null) { + response.append(responseLine.trim()); + } + return response.toString(); + } + } else { + throw new IOException("Request failed with response code " + responseCode); + } + } + + // OpenAI chat client, works also with llama.cpp and Ollama + + public static String chat(String model, String prompt, int max_tokens) throws IOException { + JSONObject data = new JSONObject(); + JSONArray messages = new JSONArray(); + JSONObject systemPrompt = new JSONObject(true); + JSONObject userPrompt = new JSONObject(true); + messages.put(systemPrompt); + messages.put(userPrompt); + try { + systemPrompt.put("role", "system"); + systemPrompt.put("content", "Make short answers."); + userPrompt.put("role", "user"); + userPrompt.put("content", prompt); + data.put("model", model); + data.put("temperature", 0.1); + data.put("max_tokens", max_tokens); + data.put("messages", messages); + data.put("stop", new JSONArray(STOPTOKENS)); + data.put("stream", false); + String response = sendPostRequest(LLM_API_HOST + "/v1/chat/completions", data); + JSONObject responseObject = new JSONObject(response); + JSONArray choices = responseObject.getJSONArray("choices"); + JSONObject choice = choices.getJSONObject(0); + JSONObject message = choice.getJSONObject("message"); + String content = message.optString("content", ""); + return content; + } catch (JSONException | URISyntaxException e) { + throw new IOException(e.getMessage()); + } + } + + public static String[] stringsFromChat(String answer) { + int p = answer.indexOf('['); + int q = answer.indexOf(']'); + if (p < 0 || q < 0 || q < p) return new String[0]; + try { + JSONArray a = new JSONArray(answer.substring(p, q + 1)); + String[] arr = new String[a.length()]; + for (int i = 0; i < a.length(); i++) arr[i] = a.getString(i); + return arr; + } catch (JSONException e) { + return new String[0]; + } + } + + private static String searchWordsForPrompt(String model, String prompt) { + StringBuilder query = new StringBuilder(); + String question = "Make a list of a maximum of four search words for the following question; use a JSON Array: " + prompt; + try { + String[] a = stringsFromChat(chat(model, question, 80)); + for (String s: a) query.append(s).append(' '); + return query.toString().trim(); + } catch (IOException e) { + e.printStackTrace(); + return ""; + } + } + + private static LinkedHashMap searchResults(String query, int count) { + Switchboard sb = Switchboard.getSwitchboard(); + EmbeddedSolrConnector connector = sb.index.fulltext().getDefaultEmbeddedConnector(); + // construct query + final SolrQuery params = new SolrQuery(); + params.setQuery(CollectionSchema.text_t.getSolrFieldName() + ":" + query); + params.setRows(count); + params.setStart(0); + params.setFacet(false); + params.clearSorts(); + params.setFields(CollectionSchema.sku.getSolrFieldName(), CollectionSchema.text_t.getSolrFieldName()); + params.setIncludeScore(false); + params.set("df", CollectionSchema.text_t.getSolrFieldName()); + + // query the server + try { + final SolrDocumentList sdl = connector.getDocumentListByParams(params); + LinkedHashMap a = new LinkedHashMap(); + Iterator i = sdl.iterator(); + while (i.hasNext()) { + SolrDocument doc = i.next(); + String url = (String) doc.getFieldValue(CollectionSchema.sku.getSolrFieldName()); + String text = (String) doc.getFieldValue(CollectionSchema.text_t.getSolrFieldName()); + a.put(url, text); + } + return a; + } catch (SolrException | IOException e) { + return new LinkedHashMap(); + } + } + + // Ollama client functions + + public static LinkedHashMap listOllamaModels() { + LinkedHashMap sortedMap = new LinkedHashMap<>(); + try { + String response = sendGetRequest(LLM_API_HOST + "/api/tags"); + JSONObject responseObject = new JSONObject(response); + JSONArray models = responseObject.getJSONArray("models"); + + List> list = new ArrayList<>(); + for (int i = 0; i < models.length(); i++) { + JSONObject model = models.getJSONObject(i); + String name = model.optString("name", ""); + long size = model.optLong("size", 0); + list.add(new AbstractMap.SimpleEntry(name, size)); + } + + // Sort the list in descending order based on the values + list.sort((o1, o2) -> o2.getValue().compareTo(o1.getValue())); + + // Create a new LinkedHashMap and add the sorted entries + for (Map.Entry entry : list) { + sortedMap.put(entry.getKey(), entry.getValue()); + } + } catch (JSONException | URISyntaxException | IOException e) { + e.printStackTrace(); + } + return sortedMap; + } + + public static boolean ollamaModelExists(String name) { + JSONObject data = new JSONObject(); + try { + data.put("name", name); + sendPostRequest(LLM_API_HOST + "/api/show", data); + return true; + } catch (JSONException | URISyntaxException | IOException e) { + return false; + } + } + + public static boolean pullOllamaModel(String name) { + JSONObject data = new JSONObject(); + try { + data.put("name", name); + data.put("stream", false); + String response = sendPostRequest(LLM_API_HOST + "/api/pull", data); + // this sends {"status": "success"} in case of success + JSONObject responseObject = new JSONObject(response); + String status = responseObject.optString("status", ""); + return status.equals("success"); + } catch (JSONException | URISyntaxException | IOException e) { + return false; + } + } + + public static void main(String[] args) { + LinkedHashMap models = listOllamaModels(); + System.out.println(models.toString()); + + // check if model exists + //String model = "phi3:3.8b"; + String model = "gemma:2b"; + if (ollamaModelExists(model)) + System.out.println("model " + model + " exists"); + else + System.out.println("model " + model + " does not exist"); + + // pull a model + boolean success = pullOllamaModel(model); + System.out.println("pulled model + " + model + ": " + success); + + // make chat completion with model + String question = "Who invented the wheel?"; + try { + String answer = chat(model, question, 80); + System.out.println(answer); + } catch (IOException e) { + e.printStackTrace(); + } + + // try the json parser from chat results + question = "Make a list of four names from Star Wars movies. Use a JSON Array."; + try { + String[] a = stringsFromChat(chat(model, question, 80)); + for (String s: a) System.out.println(s); + } catch (IOException e) { + e.printStackTrace(); + } + } +}