From 13fbff0bffc78aa4676e92eb0ff95e5a9e98872f Mon Sep 17 00:00:00 2001 From: Michael Peter Christen Date: Sun, 19 May 2024 17:19:09 +0200 Subject: [PATCH] Added a RAG Proxy for AI Chat with YaCy RAG (Retrieval Augmented Generation) is a method to combine a search engine with a LLM (Large Language Model). When a new prompt is submitted, a search engine injects knowledge from a search into the content. This is done using a reverse proxy between the Chat Client and the LLM. In this case, we used the following software: LLM Backend - Ollama: https://github.com/ollama/ollama Install ollama and then load two required LLM models with the following commands: ollama pull phi3:3.8b ollama pull llama3:8b Chat Client - susi_chat: https://github.com/susiai/susi_chat just clone the repository and the open the file susi_chat/chat_terminal/index.html in your browser. This displays a chat terminal. In this terminal, run the following command: host http://localhost:8090 This sets the LLM backend to your YaCy peer. Then start YaCy. It will provide the LLM endpoint to the client while using ollama in the backend. It then injects search results only from the local Solr index, not from the p2p network (so far). --- defaults/web.xml | 12 +- .../yacy/http/servlets/RAGProxyServlet.java | 434 ++++++++++++++++++ 2 files changed, 445 insertions(+), 1 deletion(-) create mode 100644 source/net/yacy/http/servlets/RAGProxyServlet.java diff --git a/defaults/web.xml b/defaults/web.xml index 9a6fbec39..43a500cc1 100644 --- a/defaults/web.xml +++ b/defaults/web.xml @@ -50,6 +50,11 @@ net.yacy.http.servlets.SolrServlet + + RAGProxyServlet + net.yacy.http.servlets.RAGProxyServlet + + URLProxyServlet @@ -81,7 +86,12 @@ /solr/webgraph/admin/luke - + + + + RAGProxyServlet + /v1/chat/completions + diff --git a/source/net/yacy/http/servlets/RAGProxyServlet.java b/source/net/yacy/http/servlets/RAGProxyServlet.java new file mode 100644 index 000000000..573c1cc2d --- /dev/null +++ b/source/net/yacy/http/servlets/RAGProxyServlet.java @@ -0,0 +1,434 @@ +/** + * RAGProxyServlet + * Copyright 2024 by Michael Peter Christen + * First released 17.05.2024 at http://yacy.net + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public License + * along with this program in the file lgpl21.txt + * If not, see . + */ + +package net.yacy.http.servlets; + +import org.json.JSONArray; +import org.json.JSONException; +import org.json.JSONObject; + +import net.yacy.cora.federate.solr.connector.EmbeddedSolrConnector; +import net.yacy.search.Switchboard; +import net.yacy.search.schema.CollectionSchema; + +import org.apache.solr.client.solrj.SolrQuery; +import org.apache.solr.common.SolrDocument; +import org.apache.solr.common.SolrDocumentList; +import org.apache.solr.common.SolrException; +import org.apache.solr.servlet.cache.Method; + +import javax.servlet.ServletException; +import javax.servlet.ServletOutputStream; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; +import javax.servlet.http.HttpServlet; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import java.io.IOException; +import java.io.BufferedReader; +import java.io.InputStreamReader; +import java.io.OutputStream; +import java.net.HttpURLConnection; +import java.net.URI; +import java.net.URISyntaxException; +import java.net.URL; +import java.util.AbstractMap; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +/** + * This class implements a Retrieval Augmented Generation ("RAG") proxy which uses a YaCy search index + * to enrich a chat with search results. The + */ +public class RAGProxyServlet extends HttpServlet { + + private static final long serialVersionUID = 3411544789759603107L; + private static String[] STOPTOKENS = new String[]{"[/INST]", "<|im_end|>", "<|end_of_turn|>", "<|eot_id|>", "<|end_header_id|>", "", "", "<|end|>"}; + + private static Boolean LLM_ENABLED = false; + private static Boolean LLM_CONTROL_OLLAMA = true; + private static Boolean LLM_ATTACH_QUERY = false; // instructs the proxy to attach the prompt generated to do the RAG search + private static Boolean LLM_ATTACH_REFERENCES = false; // instructs the proxy to attach a list of sources that had been used in RAG + private static String LLM_LANGUAGE = "en"; // used to select proper language in RAG augmentation + private static String LLM_SYSTEM_PREFIX = "\n\nYou may receive additional expert knowledge in the user prompt after a 'Additional Information' headline to enhance your knowledge. Use it only if applicable."; + private static String LLM_USER_PREFIX = "\n\nAdditional Information:\n\nbelow you find a collection of texts that might be useful to generate a response. Do not discuss these documents, just use them to answer the question above.\n\n"; + private static String LLM_API_HOST = "http://localhost:11434"; // Ollama port; install ollama from https://ollama.com/ + private static String LLM_QUERY_MODEL = "phi3:3.8b"; + private static String LLM_ANSWER_MODEL = "llama3:8b"; // or "phi3:3.8b" i.e. on a Raspberry Pi 5 + private static Boolean LLM_API_MODEL_OVERWRITING = true; // if true, the value configured in YaCy overwrites the client model + private static String LLM_API_KEY = ""; // not required; option to use this class to use a OpenAI API + + @Override + public void service(ServletRequest request, ServletResponse response) throws IOException, ServletException { + response.setContentType("application/json;charset=utf-8"); + + HttpServletResponse hresponse = (HttpServletResponse) response; + HttpServletRequest hrequest = (HttpServletRequest) request; + + // Add CORS headers + hresponse.setHeader("Access-Control-Allow-Origin", "*"); + hresponse.setHeader("Access-Control-Allow-Methods", "POST, GET, OPTIONS, DELETE"); + hresponse.setHeader("Access-Control-Allow-Headers", "Content-Type, Authorization"); + + final Method reqMethod = Method.getMethod(hrequest.getMethod()); + if (reqMethod == Method.OTHER) { + // required to handle CORS + hresponse.setStatus(HttpServletResponse.SC_OK); + return; + } + + // We expect a POST request + if (reqMethod != Method.POST) { + hresponse.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED); + return; + } + + // get the output stream early to be able to generate messages to the user before the actual retrieval starts + ServletOutputStream out = response.getOutputStream(); + + // read the body of the request and parse it as JSON + BufferedReader reader = request.getReader(); + StringBuilder bodyBuilder = new StringBuilder(); + String line; + while ((line = reader.readLine()) != null) { + bodyBuilder.append(line); + } + String body = bodyBuilder.toString(); + JSONObject bodyObject; + try { + // get system message and user prompt + bodyObject = new JSONObject(body); + String model = bodyObject.optString("model", LLM_ANSWER_MODEL); // we need a switch to allow overwriting + JSONArray messages = bodyObject.optJSONArray("messages"); + JSONObject systemObject = messages.getJSONObject(0); + String system = systemObject.optString("content", ""); // the system prompt + JSONObject userObject = messages.getJSONObject(messages.length() - 1); + String user = userObject.optString("content", ""); // this is the latest prompt + + // modify system and user prompt here in bodyObject to enable RAG + String query = searchWordsForPrompt(LLM_QUERY_MODEL, user); + out.print(responseLine("Searching for '" + query + "'\n\n").toString() + "\n"); out.flush(); + LinkedHashMap searchResults = searchResults(query, 4); + out.print(responseLine("Using the following sources for RAG:\n\n").toString() + "\n"); out.flush(); + for (String s: searchResults.keySet()) {out.print(responseLine("- `" + s + "`\n").toString() + "\n"); out.flush();} + out.print(responseLine("\n").toString()); out.flush(); + system += LLM_SYSTEM_PREFIX; + user += LLM_USER_PREFIX; + for (String s: searchResults.values()) user += s + "\n\n"; + systemObject.put("content", system); + userObject.put("content", user); + + if (LLM_API_MODEL_OVERWRITING) bodyObject.put("model", LLM_ANSWER_MODEL); + + // write back modified bodyMap to body + body = bodyObject.toString(); + + // Open request to back-end service + URL url = new URI(LLM_API_HOST + "/v1/chat/completions").toURL(); + HttpURLConnection conn = (HttpURLConnection) url.openConnection(); + conn.setRequestMethod("POST"); + conn.setRequestProperty("Content-Type", "application/json"); + if (!LLM_API_KEY.isEmpty()) { + conn.setRequestProperty("Authorization", "Bearer " + LLM_API_KEY); + } + conn.setDoOutput(true); + + // write the body to back-end LLM + try (OutputStream os = conn.getOutputStream()) { + os.write(body.getBytes()); + os.flush(); + } + + // write back response of the back-end service to the client; use status of backend-response + int status = conn.getResponseCode(); + String rmessage = conn.getResponseMessage(); + hresponse.setStatus(status); + + if (status == 200) { + // read the response of the back-end line-by-line and write it to the client line-by-line + BufferedReader in = new BufferedReader(new InputStreamReader(conn.getInputStream())); + String inputLine; + while ((inputLine = in.readLine()) != null) { + out.print(inputLine); // i.e. data: {"id":"chatcmpl-69","object":"chat.completion.chunk","created":1715908287,"model":"llama3:8b","system_fingerprint":"fp_ollama","choices":[{"index":0,"delta":{"role":"assistant","content":"ߘŠ"},"finish_reason":null}]} + out.flush(); + } + in.close(); + } + out.close(); // close this here to end transmission + } catch (JSONException | URISyntaxException e) { + throw new IOException(e.getMessage()); + } + } + + private static JSONObject responseLine(String payload) { + JSONObject j = new JSONObject(true); + try { + j.put("id", "log"); + j.put("object", "chat.completion.chunk"); + j.put("created", System.currentTimeMillis() / 1000); + j.put("model", "log"); + j.put("system_fingerprint", "YaCy"); + JSONArray choices = new JSONArray(); + JSONObject choice = new JSONObject(true); // {"index":0,"delta":{"role":"assistant","content":"ߘŠ" + choice.put("index", 0); + JSONObject delta = new JSONObject(true); + delta.put("role", "assistant"); + delta.put("content", payload); + choice.put("delta", delta); + choices.put(choice); + j.put("choices", choices); + //j.put("finish_reason", null); // this is problematic with the JSON library + } catch (JSONException e) {} + return j; + } + + // API Helper Methods for Ollama + + private static String sendPostRequest(String endpoint, JSONObject data) throws IOException, URISyntaxException { + URL url = new URI(endpoint).toURL(); + HttpURLConnection conn = (HttpURLConnection) url.openConnection(); + conn.setRequestMethod("POST"); + conn.setRequestProperty("Content-Type", "application/json"); + conn.setDoOutput(true); + + try (OutputStream os = conn.getOutputStream()) { + byte[] input = data.toString().getBytes("utf-8"); + os.write(input, 0, input.length); + } + + int responseCode = conn.getResponseCode(); + if (responseCode == HttpURLConnection.HTTP_OK) { + try (BufferedReader br = new BufferedReader(new InputStreamReader(conn.getInputStream(), "utf-8"))) { + StringBuilder response = new StringBuilder(); + String responseLine; + while ((responseLine = br.readLine()) != null) { + response.append(responseLine.trim()); + } + return response.toString(); + } + } else { + throw new IOException("Request failed with response code " + responseCode); + } + } + + private static String sendGetRequest(String endpoint) throws IOException, URISyntaxException { + URL url = new URI(endpoint).toURL(); + HttpURLConnection conn = (HttpURLConnection) url.openConnection(); + conn.setRequestMethod("GET"); + + int responseCode = conn.getResponseCode(); + if (responseCode == HttpURLConnection.HTTP_OK) { + try (BufferedReader br = new BufferedReader(new InputStreamReader(conn.getInputStream(), "utf-8"))) { + StringBuilder response = new StringBuilder(); + String responseLine; + while ((responseLine = br.readLine()) != null) { + response.append(responseLine.trim()); + } + return response.toString(); + } + } else { + throw new IOException("Request failed with response code " + responseCode); + } + } + + // OpenAI chat client, works also with llama.cpp and Ollama + + public static String chat(String model, String prompt, int max_tokens) throws IOException { + JSONObject data = new JSONObject(); + JSONArray messages = new JSONArray(); + JSONObject systemPrompt = new JSONObject(true); + JSONObject userPrompt = new JSONObject(true); + messages.put(systemPrompt); + messages.put(userPrompt); + try { + systemPrompt.put("role", "system"); + systemPrompt.put("content", "Make short answers."); + userPrompt.put("role", "user"); + userPrompt.put("content", prompt); + data.put("model", model); + data.put("temperature", 0.1); + data.put("max_tokens", max_tokens); + data.put("messages", messages); + data.put("stop", new JSONArray(STOPTOKENS)); + data.put("stream", false); + String response = sendPostRequest(LLM_API_HOST + "/v1/chat/completions", data); + JSONObject responseObject = new JSONObject(response); + JSONArray choices = responseObject.getJSONArray("choices"); + JSONObject choice = choices.getJSONObject(0); + JSONObject message = choice.getJSONObject("message"); + String content = message.optString("content", ""); + return content; + } catch (JSONException | URISyntaxException e) { + throw new IOException(e.getMessage()); + } + } + + public static String[] stringsFromChat(String answer) { + int p = answer.indexOf('['); + int q = answer.indexOf(']'); + if (p < 0 || q < 0 || q < p) return new String[0]; + try { + JSONArray a = new JSONArray(answer.substring(p, q + 1)); + String[] arr = new String[a.length()]; + for (int i = 0; i < a.length(); i++) arr[i] = a.getString(i); + return arr; + } catch (JSONException e) { + return new String[0]; + } + } + + private static String searchWordsForPrompt(String model, String prompt) { + StringBuilder query = new StringBuilder(); + String question = "Make a list of a maximum of four search words for the following question; use a JSON Array: " + prompt; + try { + String[] a = stringsFromChat(chat(model, question, 80)); + for (String s: a) query.append(s).append(' '); + return query.toString().trim(); + } catch (IOException e) { + e.printStackTrace(); + return ""; + } + } + + private static LinkedHashMap searchResults(String query, int count) { + Switchboard sb = Switchboard.getSwitchboard(); + EmbeddedSolrConnector connector = sb.index.fulltext().getDefaultEmbeddedConnector(); + // construct query + final SolrQuery params = new SolrQuery(); + params.setQuery(CollectionSchema.text_t.getSolrFieldName() + ":" + query); + params.setRows(count); + params.setStart(0); + params.setFacet(false); + params.clearSorts(); + params.setFields(CollectionSchema.sku.getSolrFieldName(), CollectionSchema.text_t.getSolrFieldName()); + params.setIncludeScore(false); + params.set("df", CollectionSchema.text_t.getSolrFieldName()); + + // query the server + try { + final SolrDocumentList sdl = connector.getDocumentListByParams(params); + LinkedHashMap a = new LinkedHashMap(); + Iterator i = sdl.iterator(); + while (i.hasNext()) { + SolrDocument doc = i.next(); + String url = (String) doc.getFieldValue(CollectionSchema.sku.getSolrFieldName()); + String text = (String) doc.getFieldValue(CollectionSchema.text_t.getSolrFieldName()); + a.put(url, text); + } + return a; + } catch (SolrException | IOException e) { + return new LinkedHashMap(); + } + } + + // Ollama client functions + + public static LinkedHashMap listOllamaModels() { + LinkedHashMap sortedMap = new LinkedHashMap<>(); + try { + String response = sendGetRequest(LLM_API_HOST + "/api/tags"); + JSONObject responseObject = new JSONObject(response); + JSONArray models = responseObject.getJSONArray("models"); + + List> list = new ArrayList<>(); + for (int i = 0; i < models.length(); i++) { + JSONObject model = models.getJSONObject(i); + String name = model.optString("name", ""); + long size = model.optLong("size", 0); + list.add(new AbstractMap.SimpleEntry(name, size)); + } + + // Sort the list in descending order based on the values + list.sort((o1, o2) -> o2.getValue().compareTo(o1.getValue())); + + // Create a new LinkedHashMap and add the sorted entries + for (Map.Entry entry : list) { + sortedMap.put(entry.getKey(), entry.getValue()); + } + } catch (JSONException | URISyntaxException | IOException e) { + e.printStackTrace(); + } + return sortedMap; + } + + public static boolean ollamaModelExists(String name) { + JSONObject data = new JSONObject(); + try { + data.put("name", name); + sendPostRequest(LLM_API_HOST + "/api/show", data); + return true; + } catch (JSONException | URISyntaxException | IOException e) { + return false; + } + } + + public static boolean pullOllamaModel(String name) { + JSONObject data = new JSONObject(); + try { + data.put("name", name); + data.put("stream", false); + String response = sendPostRequest(LLM_API_HOST + "/api/pull", data); + // this sends {"status": "success"} in case of success + JSONObject responseObject = new JSONObject(response); + String status = responseObject.optString("status", ""); + return status.equals("success"); + } catch (JSONException | URISyntaxException | IOException e) { + return false; + } + } + + public static void main(String[] args) { + LinkedHashMap models = listOllamaModels(); + System.out.println(models.toString()); + + // check if model exists + //String model = "phi3:3.8b"; + String model = "gemma:2b"; + if (ollamaModelExists(model)) + System.out.println("model " + model + " exists"); + else + System.out.println("model " + model + " does not exist"); + + // pull a model + boolean success = pullOllamaModel(model); + System.out.println("pulled model + " + model + ": " + success); + + // make chat completion with model + String question = "Who invented the wheel?"; + try { + String answer = chat(model, question, 80); + System.out.println(answer); + } catch (IOException e) { + e.printStackTrace(); + } + + // try the json parser from chat results + question = "Make a list of four names from Star Wars movies. Use a JSON Array."; + try { + String[] a = stringsFromChat(chat(model, question, 80)); + for (String s: a) System.out.println(s); + } catch (IOException e) { + e.printStackTrace(); + } + } +}