[feat] add generate_session api

This commit is contained in:
SouthFox 2024-12-23 16:00:03 +08:00
parent 9304e24c8c
commit 45ff987051
2 changed files with 67 additions and 15 deletions

40
app.py
View file

@ -2,7 +2,7 @@
import os import os
import vertexai import vertexai
import logging import logging
from fastapi import FastAPI from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel from pydantic import BaseModel
@ -18,16 +18,15 @@ SESSION_CACHE = LRUCache(15)
logging.basicConfig(format='%(levelname)s:%(message)s', level=logging.DEBUG) logging.basicConfig(format='%(levelname)s:%(message)s', level=logging.DEBUG)
def generate(): def generate(model: str, instruction: str):
vertexai.init(project=PROJECT_ID, location="us-central1") vertexai.init(project=PROJECT_ID, location="us-central1")
model = GenerativeModel( session = GenerativeModel(
"gemini-1.5-flash-001", model,
# system_instruction=[textsi_1] system_instruction=[instruction]
) )
return model.start_chat() return session.start_chat()
# textsi_1 = """你要用中文回答一个德语刚入门的新手的回答,先用德语回答然后在用中文解释这个回答并给出建议"""
generation_config = { generation_config = {
"max_output_tokens": 1024, "max_output_tokens": 1024,
@ -73,6 +72,24 @@ app = FastAPI(
) )
class SessionRequest(BaseModel):
instruction: str = "你现在是一只喜爱 Lisp 并且喜欢用 Eamcs 编辑的狐狸,\
回答问题时尽量使用符合 Lisp 哲学例如递归等概念进行回答"
session_uuid: str
model: str = 'gemini-1.5-flash-001'
@app.post("/generate_session")
async def generate_session(request: SessionRequest):
if SESSION_CACHE.get(request.session_uuid):
raise HTTPException(500)
model = generate(request.model, request.instruction)
SESSION_CACHE[request.session_uuid] = model
logging.info(f"Buid new session: {request.session_uuid}")
class ChatRequest(BaseModel): class ChatRequest(BaseModel):
prompt: str prompt: str
session_uuid: str session_uuid: str
@ -80,13 +97,10 @@ class ChatRequest(BaseModel):
@app.post("/generate_text_stream") @app.post("/generate_text_stream")
async def generate_text(request: ChatRequest): async def generate_text(request: ChatRequest):
if model := SESSION_CACHE.get(request.session_uuid): if session := SESSION_CACHE.get(request.session_uuid):
return StreamingResponse(stream_chat(model, request.prompt),media_type="text/plain") return StreamingResponse(stream_chat(session, request.prompt), media_type="text/plain")
model = generate() raise HTTPException(404)
SESSION_CACHE[request.session_uuid] = model
logging.info(f"Buid new session: {request.session_uuid}")
return StreamingResponse(stream_chat(model, request.prompt),media_type="text/plain")
app.mount("/", StaticFiles(directory="static", html=True), name="static") app.mount("/", StaticFiles(directory="static", html=True), name="static")

View file

@ -43,6 +43,43 @@ async function sendMessage() {
hideTypingIndicator(); hideTypingIndicator();
} }
async function generateSession() {
try {
const response = await fetch("./generate_session", {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({ session_uuid }),
});
if (!response.ok) {
console.error("Error:", response.statusText);
return "Error occurred while generating response.";
}
const reader = response.body.getReader();
const decoder = new TextDecoder();
let isFinished = false;
let generatedTextContent = "";
while (!isFinished) {
const { done, value } = await reader.read();
if (done) {
isFinished = true;
break;
}
generatedTextContent += decoder.decode(value, {stream: true});
}
return generatedTextContent;
} catch (error) {
console.error("Error:", error);
return "An error occurred.";
}
}
async function generateText(prompt) { async function generateText(prompt) {
try { try {
const response = await fetch("./generate_text_stream", { const response = await fetch("./generate_text_stream", {
@ -116,3 +153,4 @@ function handleKeyPress(event) {
} }
window.onload = () => addMessage("This session is: " + session_uuid, 'bot'); window.onload = () => addMessage("This session is: " + session_uuid, 'bot');
generateSession()