feat: hold session

parent f06d8c80ea
commit e48af8e172

4 changed files with 21 additions and 6 deletions
app.py (16)

@@ -1,15 +1,21 @@
 #!/usr/bin/env python3
 import os
 import vertexai
+import logging
 from fastapi import FastAPI
 from fastapi.responses import StreamingResponse
 from fastapi.staticfiles import StaticFiles
 from pydantic import BaseModel
 from vertexai.generative_models import ChatSession, GenerativeModel, Part, SafetySetting
+from cachetools import LRUCache
 from dotenv import load_dotenv
 
 
 load_dotenv()
 PROJECT_ID = os.environ.get('PROJECT_ID')
+SESSION_CACHE = LRUCache(15)
+
+logging.basicConfig(format='%(levelname)s:%(message)s', level=logging.DEBUG)
+
 
 def generate():
@@ -48,7 +54,6 @@ safety_settings = [
     ),
 ]
 
-model = generate()
 
 def stream_chat(model : ChatSession, _input):
     responses = model.send_message(
@@ -70,11 +75,18 @@ app = FastAPI(
 
 class ChatRequest(BaseModel):
     prompt: str
+    session_uuid: str
 
 
 @app.post("/generate_text_stream")
-async def generate_text(request : ChatRequest):
-    return StreamingResponse(stream_chat(model, request.prompt),media_type="text/plain")
+async def generate_text(request: ChatRequest):
+    if model := SESSION_CACHE.get(request.session_uuid):
+        return StreamingResponse(stream_chat(model, request.prompt),media_type="text/plain")
+
+    model = generate()
+    SESSION_CACHE[request.session_uuid] = model
+    logging.info(f"Buid new session: {request.session_uuid}")
+    return StreamingResponse(stream_chat(model, request.prompt),media_type="text/plain")
 
 
 app.mount("/", StaticFiles(directory="static", html=True), name="static")
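The heart of the change is the session cache: the endpoint now keys a ChatSession by the client-supplied session_uuid in a cachetools LRUCache and only calls generate() on a cache miss. Below is a minimal sketch of that lookup pattern outside FastAPI, with a hypothetical FakeChat standing in for the Vertex AI ChatSession; everything except the LRUCache calls is invented for illustration.

# Sketch of the session-holding pattern from the diff above: an LRU cache keyed
# by a client-supplied session UUID, so repeat requests reuse the same chat.
# FakeChat is a stand-in for vertexai's ChatSession (illustration only).
from cachetools import LRUCache

SESSION_CACHE = LRUCache(maxsize=15)  # oldest session is evicted once 15 are held

class FakeChat:
    def __init__(self):
        self.history = []

    def send_message(self, prompt):
        self.history.append(prompt)
        return f"echo({len(self.history)}): {prompt}"

def get_or_create_session(session_uuid: str) -> FakeChat:
    # Same shape as the endpoint: reuse a cached session if present,
    # otherwise build a new one and remember it under the UUID.
    if chat := SESSION_CACHE.get(session_uuid):
        return chat
    chat = FakeChat()
    SESSION_CACHE[session_uuid] = chat
    return chat

if __name__ == "__main__":
    a = get_or_create_session("uuid-1")
    b = get_or_create_session("uuid-1")
    assert a is b  # the second call reuses the cached session

With a maxsize of 15, the sixteenth distinct session evicts the least recently used one, so a long-idle client would silently get a fresh chat on its next request.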
poetry.lock (2, generated)

@@ -1098,4 +1098,4 @@ standard = ["colorama (>=0.4)", "httptools (>=0.5.0)", "python-dotenv (>=0.13)",
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.10"
-content-hash = "d957d855dec6fb30cb9b5d438c605b799d7fdfcfb77769c778bf7a0617627524"
+content-hash = "dc339ed01958bb7c634e721a47a9a0ba31dd10355285b3bd50594e54bc5d01bb"
pyproject.toml

@@ -11,6 +11,7 @@ google-cloud-aiplatform = "^1.61.0"
 fastapi = "^0.112.0"
 uvicorn = "^0.30.5"
 python-dotenv = "^1.0.1"
+cachetools = "^5.4.0"
 
 
 [build-system]
Frontend chat script (under static/)

@@ -5,6 +5,8 @@ const chatContainer = document.getElementById("chatContainer");
 const typingIndicator = document.getElementById("typingIndicator");
 const sendButton = document.getElementById("sendButton");
 
+const session_uuid = self.crypto.randomUUID();
+
 sendButton.addEventListener("click", () => {
     sendMessage();
 }
@@ -42,7 +44,7 @@ async function generateText(prompt) {
     headers: {
       "Content-Type": "application/json",
     },
-    body: JSON.stringify({ prompt }),
+    body: JSON.stringify({ session_uuid ,prompt }),
   });
 
   if (!response.ok) {
@@ -102,4 +104,4 @@ function handleKeyPress(event) {
   }
 }
 
-window.onload = () => addMessage("Hello! How can I assist you today?", 'bot');
+window.onload = () => addMessage("This session is: " + session_uuid, 'bot');
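On the wire, the only contract change is the request body: the frontend now generates one UUID per page load and sends it as session_uuid next to the prompt, and the greeting echoes that UUID back. A small Python client sketch of the same contract follows; the localhost URL and port are assumptions, not part of the commit.

# Hypothetical command-line client mirroring what the frontend now sends: a fixed
# session_uuid per process plus the prompt, with the reply streamed as plain text.
# The base URL is an assumption; point it at wherever uvicorn serves app.py.
import uuid
import requests

BASE_URL = "http://localhost:8000"  # assumed local uvicorn address
session_uuid = str(uuid.uuid4())    # one UUID per client run, like crypto.randomUUID()

def ask(prompt: str) -> None:
    resp = requests.post(
        f"{BASE_URL}/generate_text_stream",
        json={"session_uuid": session_uuid, "prompt": prompt},
        stream=True,
    )
    resp.raise_for_status()
    # The endpoint returns a text/plain StreamingResponse; print chunks as they arrive.
    for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="", flush=True)
    print()

if __name__ == "__main__":
    ask("Hello")        # first call creates and caches the session server-side
    ask("And again?")   # same session_uuid, so the cached ChatSession is reused

Because the UUID stays fixed for the life of the client, both calls land on the same cached ChatSession on the server.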