[feat] add generate_session api
This commit is contained in:
parent
9304e24c8c
commit
45ff987051
2 changed files with 67 additions and 15 deletions
44
app.py
44
app.py
|
@ -2,7 +2,7 @@
|
|||
import os
|
||||
import vertexai
|
||||
import logging
|
||||
from fastapi import FastAPI
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from pydantic import BaseModel
|
||||
|
@ -18,16 +18,15 @@ SESSION_CACHE = LRUCache(15)
|
|||
logging.basicConfig(format='%(levelname)s:%(message)s', level=logging.DEBUG)
|
||||
|
||||
|
||||
def generate():
|
||||
vertexai.init(project=PROJECT_ID, location="us-central1")
|
||||
model = GenerativeModel(
|
||||
"gemini-1.5-flash-001",
|
||||
# system_instruction=[textsi_1]
|
||||
)
|
||||
def generate(model: str, instruction: str):
|
||||
vertexai.init(project=PROJECT_ID, location="us-central1")
|
||||
session = GenerativeModel(
|
||||
model,
|
||||
system_instruction=[instruction]
|
||||
)
|
||||
|
||||
return model.start_chat()
|
||||
return session.start_chat()
|
||||
|
||||
# textsi_1 = """你要用中文回答一个德语刚入门的新手的回答,先用德语回答然后在用中文解释这个回答并给出建议"""
|
||||
|
||||
generation_config = {
|
||||
"max_output_tokens": 1024,
|
||||
|
@ -73,6 +72,24 @@ app = FastAPI(
|
|||
)
|
||||
|
||||
|
||||
class SessionRequest(BaseModel):
|
||||
instruction: str = "你现在是一只喜爱 Lisp 并且喜欢用 Eamcs 编辑的狐狸,\
|
||||
回答问题时尽量使用符合 Lisp 哲学例如递归等概念进行回答。"
|
||||
session_uuid: str
|
||||
model: str = 'gemini-1.5-flash-001'
|
||||
|
||||
|
||||
@app.post("/generate_session")
|
||||
async def generate_session(request: SessionRequest):
|
||||
if SESSION_CACHE.get(request.session_uuid):
|
||||
raise HTTPException(500)
|
||||
|
||||
model = generate(request.model, request.instruction)
|
||||
|
||||
SESSION_CACHE[request.session_uuid] = model
|
||||
logging.info(f"Buid new session: {request.session_uuid}")
|
||||
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
prompt: str
|
||||
session_uuid: str
|
||||
|
@ -80,13 +97,10 @@ class ChatRequest(BaseModel):
|
|||
|
||||
@app.post("/generate_text_stream")
|
||||
async def generate_text(request: ChatRequest):
|
||||
if model := SESSION_CACHE.get(request.session_uuid):
|
||||
return StreamingResponse(stream_chat(model, request.prompt),media_type="text/plain")
|
||||
if session := SESSION_CACHE.get(request.session_uuid):
|
||||
return StreamingResponse(stream_chat(session, request.prompt), media_type="text/plain")
|
||||
|
||||
model = generate()
|
||||
raise HTTPException(404)
|
||||
|
||||
SESSION_CACHE[request.session_uuid] = model
|
||||
logging.info(f"Buid new session: {request.session_uuid}")
|
||||
return StreamingResponse(stream_chat(model, request.prompt),media_type="text/plain")
|
||||
|
||||
app.mount("/", StaticFiles(directory="static", html=True), name="static")
|
||||
|
|
|
@ -43,6 +43,43 @@ async function sendMessage() {
|
|||
hideTypingIndicator();
|
||||
}
|
||||
|
||||
async function generateSession() {
|
||||
try {
|
||||
const response = await fetch("./generate_session", {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({ session_uuid }),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
console.error("Error:", response.statusText);
|
||||
return "Error occurred while generating response.";
|
||||
}
|
||||
|
||||
const reader = response.body.getReader();
|
||||
const decoder = new TextDecoder();
|
||||
let isFinished = false;
|
||||
let generatedTextContent = "";
|
||||
|
||||
while (!isFinished) {
|
||||
const { done, value } = await reader.read();
|
||||
if (done) {
|
||||
isFinished = true;
|
||||
break;
|
||||
}
|
||||
generatedTextContent += decoder.decode(value, {stream: true});
|
||||
}
|
||||
|
||||
return generatedTextContent;
|
||||
} catch (error) {
|
||||
console.error("Error:", error);
|
||||
return "An error occurred.";
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
async function generateText(prompt) {
|
||||
try {
|
||||
const response = await fetch("./generate_text_stream", {
|
||||
|
@ -116,3 +153,4 @@ function handleKeyPress(event) {
|
|||
}
|
||||
|
||||
window.onload = () => addMessage("This session is: " + session_uuid, 'bot');
|
||||
generateSession()
|
||||
|
|
Loading…
Reference in a new issue