[feat] add generate_session api

This commit is contained in:
SouthFox 2024-12-23 16:00:03 +08:00
parent 9304e24c8c
commit 45ff987051
2 changed files with 67 additions and 15 deletions

40
app.py
View file

@ -2,7 +2,7 @@
import os
import vertexai
import logging
from fastapi import FastAPI
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
@ -18,16 +18,15 @@ SESSION_CACHE = LRUCache(15)
logging.basicConfig(format='%(levelname)s:%(message)s', level=logging.DEBUG)
def generate():
def generate(model: str, instruction: str):
vertexai.init(project=PROJECT_ID, location="us-central1")
model = GenerativeModel(
"gemini-1.5-flash-001",
# system_instruction=[textsi_1]
session = GenerativeModel(
model,
system_instruction=[instruction]
)
return model.start_chat()
return session.start_chat()
# textsi_1 = """你要用中文回答一个德语刚入门的新手的回答,先用德语回答然后在用中文解释这个回答并给出建议"""
generation_config = {
"max_output_tokens": 1024,
@ -73,6 +72,24 @@ app = FastAPI(
)
class SessionRequest(BaseModel):
instruction: str = "你现在是一只喜爱 Lisp 并且喜欢用 Eamcs 编辑的狐狸,\
回答问题时尽量使用符合 Lisp 哲学例如递归等概念进行回答"
session_uuid: str
model: str = 'gemini-1.5-flash-001'
@app.post("/generate_session")
async def generate_session(request: SessionRequest):
if SESSION_CACHE.get(request.session_uuid):
raise HTTPException(500)
model = generate(request.model, request.instruction)
SESSION_CACHE[request.session_uuid] = model
logging.info(f"Buid new session: {request.session_uuid}")
class ChatRequest(BaseModel):
prompt: str
session_uuid: str
@ -80,13 +97,10 @@ class ChatRequest(BaseModel):
@app.post("/generate_text_stream")
async def generate_text(request: ChatRequest):
if model := SESSION_CACHE.get(request.session_uuid):
return StreamingResponse(stream_chat(model, request.prompt),media_type="text/plain")
if session := SESSION_CACHE.get(request.session_uuid):
return StreamingResponse(stream_chat(session, request.prompt), media_type="text/plain")
model = generate()
raise HTTPException(404)
SESSION_CACHE[request.session_uuid] = model
logging.info(f"Buid new session: {request.session_uuid}")
return StreamingResponse(stream_chat(model, request.prompt),media_type="text/plain")
app.mount("/", StaticFiles(directory="static", html=True), name="static")

View file

@ -43,6 +43,43 @@ async function sendMessage() {
hideTypingIndicator();
}
async function generateSession() {
try {
const response = await fetch("./generate_session", {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({ session_uuid }),
});
if (!response.ok) {
console.error("Error:", response.statusText);
return "Error occurred while generating response.";
}
const reader = response.body.getReader();
const decoder = new TextDecoder();
let isFinished = false;
let generatedTextContent = "";
while (!isFinished) {
const { done, value } = await reader.read();
if (done) {
isFinished = true;
break;
}
generatedTextContent += decoder.decode(value, {stream: true});
}
return generatedTextContent;
} catch (error) {
console.error("Error:", error);
return "An error occurred.";
}
}
async function generateText(prompt) {
try {
const response = await fetch("./generate_text_stream", {
@ -116,3 +153,4 @@ function handleKeyPress(event) {
}
window.onload = () => addMessage("This session is: " + session_uuid, 'bot');
generateSession()