From e48af8e172753141fcb134d19dfa67ac60cda2e5 Mon Sep 17 00:00:00 2001 From: SouthFox Date: Tue, 13 Aug 2024 13:19:37 +0800 Subject: [PATCH] feat: hold session --- app.py | 18 +++++++++++++++--- poetry.lock | 2 +- pyproject.toml | 1 + static/script.js | 6 ++++-- 4 files changed, 21 insertions(+), 6 deletions(-) diff --git a/app.py b/app.py index 9281269..ef9bc1c 100644 --- a/app.py +++ b/app.py @@ -1,15 +1,21 @@ #!/usr/bin/env python3 import os import vertexai +import logging from fastapi import FastAPI from fastapi.responses import StreamingResponse from fastapi.staticfiles import StaticFiles from pydantic import BaseModel from vertexai.generative_models import ChatSession, GenerativeModel, Part, SafetySetting +from cachetools import LRUCache from dotenv import load_dotenv + load_dotenv() PROJECT_ID = os.environ.get('PROJECT_ID') +SESSION_CACHE = LRUCache(15) + +logging.basicConfig(format='%(levelname)s:%(message)s', level=logging.DEBUG) def generate(): @@ -48,7 +54,6 @@ safety_settings = [ ), ] -model = generate() def stream_chat(model : ChatSession, _input): responses = model.send_message( @@ -70,11 +75,18 @@ app = FastAPI( class ChatRequest(BaseModel): prompt: str + session_uuid: str @app.post("/generate_text_stream") -async def generate_text(request : ChatRequest): +async def generate_text(request: ChatRequest): + if model := SESSION_CACHE.get(request.session_uuid): + return StreamingResponse(stream_chat(model, request.prompt),media_type="text/plain") + + model = generate() + + SESSION_CACHE[request.session_uuid] = model + logging.info(f"Build new session: {request.session_uuid}") return StreamingResponse(stream_chat(model, request.prompt),media_type="text/plain") - app.mount("/", StaticFiles(directory="static", html=True), name="static") diff --git a/poetry.lock b/poetry.lock index c0e4dbf..fe30194 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1098,4 +1098,4 @@ standard = ["colorama (>=0.4)", "httptools (>=0.5.0)", "python-dotenv (>=0.13)", [metadata] 
lock-version = "2.0" python-versions = "^3.10" -content-hash = "d957d855dec6fb30cb9b5d438c605b799d7fdfcfb77769c778bf7a0617627524" +content-hash = "dc339ed01958bb7c634e721a47a9a0ba31dd10355285b3bd50594e54bc5d01bb" diff --git a/pyproject.toml b/pyproject.toml index d2a2eb3..b47e5d9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,6 +11,7 @@ google-cloud-aiplatform = "^1.61.0" fastapi = "^0.112.0" uvicorn = "^0.30.5" python-dotenv = "^1.0.1" +cachetools = "^5.4.0" [build-system] diff --git a/static/script.js b/static/script.js index f52ad7c..bcf98d0 100644 --- a/static/script.js +++ b/static/script.js @@ -5,6 +5,8 @@ const chatContainer = document.getElementById("chatContainer"); const typingIndicator = document.getElementById("typingIndicator"); const sendButton = document.getElementById("sendButton"); +const session_uuid = self.crypto.randomUUID(); + sendButton.addEventListener("click", () => { sendMessage(); } @@ -42,7 +44,7 @@ async function generateText(prompt) { headers: { "Content-Type": "application/json", }, - body: JSON.stringify({ prompt }), + body: JSON.stringify({ session_uuid ,prompt }), }); if (!response.ok) { @@ -102,4 +104,4 @@ function handleKeyPress(event) { } } -window.onload = () => addMessage("Hello! How can I assist you today?", 'bot'); +window.onload = () => addMessage("This session is: " + session_uuid, 'bot');