feat: hold session
This commit is contained in:
parent
f06d8c80ea
commit
e48af8e172
4 changed files with 21 additions and 6 deletions
18
app.py
18
app.py
|
@ -1,15 +1,21 @@
|
|||
#!/usr/bin/env python3
|
||||
import os
|
||||
import vertexai
|
||||
import logging
|
||||
from fastapi import FastAPI
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from pydantic import BaseModel
|
||||
from vertexai.generative_models import ChatSession, GenerativeModel, Part, SafetySetting
|
||||
from cachetools import LRUCache
|
||||
from dotenv import load_dotenv
|
||||
|
||||
|
||||
# Read a local .env file into the process environment before any lookup below.
load_dotenv()

# Project identifier taken from the environment; None when PROJECT_ID is unset.
PROJECT_ID = os.getenv('PROJECT_ID')

# Least-recently-used cache limited to 15 entries. generate_text stores one
# session object per client-supplied session UUID in it.
SESSION_CACHE = LRUCache(maxsize=15)

# Root logger emits everything from DEBUG upwards as "LEVEL:message".
logging.basicConfig(
    level=logging.DEBUG,
    format='%(levelname)s:%(message)s',
)
|
||||
|
||||
|
||||
def generate():
|
||||
|
@ -48,7 +54,6 @@ safety_settings = [
|
|||
),
|
||||
]
|
||||
|
||||
model = generate()
|
||||
|
||||
def stream_chat(model : ChatSession, _input):
|
||||
responses = model.send_message(
|
||||
|
@ -70,11 +75,18 @@ app = FastAPI(
|
|||
|
||||
class ChatRequest(BaseModel):
    """JSON body accepted by the ``/generate_text_stream`` endpoint."""

    # Text of the user's message; forwarded to stream_chat by generate_text.
    prompt: str
    # Identifier sent by the client; generate_text uses it as the key into
    # SESSION_CACHE so that later requests reuse the same session object.
    session_uuid: str
|
||||
|
||||
|
||||
@app.post("/generate_text_stream")
async def generate_text(request: ChatRequest):
    """Stream the reply to ``request.prompt`` using a per-client session.

    The session object is looked up in ``SESSION_CACHE`` under
    ``request.session_uuid``. On a miss a new one is created with
    ``generate()``, stored in the cache, and the creation is logged.

    Args:
        request: Parsed request body carrying ``prompt`` and ``session_uuid``.

    Returns:
        A ``text/plain`` ``StreamingResponse`` fed by ``stream_chat``.
    """
    session = SESSION_CACHE.get(request.session_uuid)
    if session is None:
        # NOTE(review): generate() is defined outside this view; it is assumed
        # to return the chat session object that stream_chat expects - confirm.
        session = generate()
        SESSION_CACHE[request.session_uuid] = session
        logging.info("Build new session: %s", request.session_uuid)

    # Single exit point: cached and freshly built sessions are served alike.
    return StreamingResponse(
        stream_chat(session, request.prompt),
        media_type="text/plain",
    )
|
||||
|
||||
|
||||
# Serve the contents of the "static" directory at the site root, with
# StaticFiles' html mode switched on.
app.mount(
    "/",
    StaticFiles(directory="static", html=True),
    name="static",
)
|
||||
|
|
2
poetry.lock
generated
2
poetry.lock
generated
|
@ -1098,4 +1098,4 @@ standard = ["colorama (>=0.4)", "httptools (>=0.5.0)", "python-dotenv (>=0.13)",
|
|||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.10"
|
||||
content-hash = "d957d855dec6fb30cb9b5d438c605b799d7fdfcfb77769c778bf7a0617627524"
|
||||
content-hash = "dc339ed01958bb7c634e721a47a9a0ba31dd10355285b3bd50594e54bc5d01bb"
|
||||
|
|
|
@ -11,6 +11,7 @@ google-cloud-aiplatform = "^1.61.0"
|
|||
fastapi = "^0.112.0"
|
||||
uvicorn = "^0.30.5"
|
||||
python-dotenv = "^1.0.1"
|
||||
cachetools = "^5.4.0"
|
||||
|
||||
|
||||
[build-system]
|
||||
|
|
|
@ -5,6 +5,8 @@ const chatContainer = document.getElementById("chatContainer");
|
|||
const typingIndicator = document.getElementById("typingIndicator");
|
||||
const sendButton = document.getElementById("sendButton");
|
||||
|
||||
const session_uuid = self.crypto.randomUUID();
|
||||
|
||||
sendButton.addEventListener("click", () => {
|
||||
sendMessage();
|
||||
}
|
||||
|
@ -42,7 +44,7 @@ async function generateText(prompt) {
|
|||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({ prompt }),
|
||||
body: JSON.stringify({ session_uuid ,prompt }),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
|
@ -102,4 +104,4 @@ function handleKeyPress(event) {
|
|||
}
|
||||
}
|
||||
|
||||
window.onload = () => addMessage("Hello! How can I assist you today?", 'bot');
|
||||
window.onload = () => addMessage("This session is: " + session_uuid, 'bot');
|
||||
|
|
Loading…
Reference in a new issue