feat: hold session

SouthFox 2024-08-13 13:19:37 +08:00
parent f06d8c80ea
commit e48af8e172
4 changed files with 21 additions and 6 deletions

app.py

@@ -1,15 +1,21 @@
 #!/usr/bin/env python3
 import os
 import vertexai
+import logging
 from fastapi import FastAPI
 from fastapi.responses import StreamingResponse
 from fastapi.staticfiles import StaticFiles
 from pydantic import BaseModel
 from vertexai.generative_models import ChatSession, GenerativeModel, Part, SafetySetting
+from cachetools import LRUCache
 from dotenv import load_dotenv
 load_dotenv()
 PROJECT_ID = os.environ.get('PROJECT_ID')
+SESSION_CACHE = LRUCache(15)
+logging.basicConfig(format='%(levelname)s:%(message)s', level=logging.DEBUG)
 def generate():
@@ -48,7 +54,6 @@ safety_settings = [
     ),
 ]
-model = generate()
 def stream_chat(model : ChatSession, _input):
     responses = model.send_message(
@@ -70,11 +75,18 @@ app = FastAPI(
 class ChatRequest(BaseModel):
     prompt: str
+    session_uuid: str
 @app.post("/generate_text_stream")
-async def generate_text(request : ChatRequest):
+async def generate_text(request: ChatRequest):
+    if model := SESSION_CACHE.get(request.session_uuid):
+        return StreamingResponse(stream_chat(model, request.prompt), media_type="text/plain")
+    model = generate()
+    SESSION_CACHE[request.session_uuid] = model
+    logging.info(f"Build new session: {request.session_uuid}")
     return StreamingResponse(stream_chat(model, request.prompt),media_type="text/plain")
 app.mount("/", StaticFiles(directory="static", html=True), name="static")

poetry.lock (generated)

@@ -1098,4 +1098,4 @@ standard = ["colorama (>=0.4)", "httptools (>=0.5.0)", "python-dotenv (>=0.13)",
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.10"
-content-hash = "d957d855dec6fb30cb9b5d438c605b799d7fdfcfb77769c778bf7a0617627524"
+content-hash = "dc339ed01958bb7c634e721a47a9a0ba31dd10355285b3bd50594e54bc5d01bb"

pyproject.toml

@@ -11,6 +11,7 @@ google-cloud-aiplatform = "^1.61.0"
 fastapi = "^0.112.0"
 uvicorn = "^0.30.5"
 python-dotenv = "^1.0.1"
+cachetools = "^5.4.0"
 [build-system]
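
The cachetools dependency added here provides the LRUCache behind SESSION_CACHE in app.py. A small sketch of its eviction behavior (the maxsize and keys below are illustrative; the app uses maxsize 15): once the cache is full, inserting a new session silently drops the least recently used one, and that client falls back to a fresh ChatSession on its next request.

```python
# Illustrative only: the LRU eviction that SESSION_CACHE (maxsize=15) relies on.
from cachetools import LRUCache

cache = LRUCache(maxsize=2)      # tiny cache so the eviction is visible
cache["session-a"] = "chat-a"
cache["session-b"] = "chat-b"
cache.get("session-a")           # lookups refresh recency, as SESSION_CACHE.get() does
cache["session-c"] = "chat-c"    # cache full: "session-b", least recently used, is evicted

print("session-b" in cache)      # False -> that client would be treated as a new session
print(cache.get("session-a"))    # "chat-a" is still cached
```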


@@ -5,6 +5,8 @@ const chatContainer = document.getElementById("chatContainer");
 const typingIndicator = document.getElementById("typingIndicator");
 const sendButton = document.getElementById("sendButton");
+const session_uuid = self.crypto.randomUUID();
 sendButton.addEventListener("click", () => {
   sendMessage();
 }
@@ -42,7 +44,7 @@ async function generateText(prompt) {
     headers: {
       "Content-Type": "application/json",
     },
-    body: JSON.stringify({ prompt }),
+    body: JSON.stringify({ session_uuid, prompt }),
   });
   if (!response.ok) {
@@ -102,4 +104,4 @@ function handleKeyPress(event) {
   }
 }
-window.onload = () => addMessage("Hello! How can I assist you today?", 'bot');
+window.onload = () => addMessage("This session is: " + session_uuid, 'bot');