[feat] add generate_session api

parent 9304e24c8c
commit 45ff987051

2 changed files with 67 additions and 15 deletions

app.py (44 changes)
@@ -2,7 +2,7 @@
 import os
 import vertexai
 import logging
-from fastapi import FastAPI
+from fastapi import FastAPI, HTTPException
 from fastapi.responses import StreamingResponse
 from fastapi.staticfiles import StaticFiles
 from pydantic import BaseModel
@@ -18,16 +18,15 @@ SESSION_CACHE = LRUCache(15)
 logging.basicConfig(format='%(levelname)s:%(message)s', level=logging.DEBUG)
 
 
-def generate():
+def generate(model: str, instruction: str):
     vertexai.init(project=PROJECT_ID, location="us-central1")
-    model = GenerativeModel(
-        "gemini-1.5-flash-001",
-        # system_instruction=[textsi_1]
+    session = GenerativeModel(
+        model,
+        system_instruction=[instruction]
     )
 
-    return model.start_chat()
+    return session.start_chat()
 
-# textsi_1 = """You are answering, in Chinese, a beginner who has just started learning German: reply in German first, then explain that reply in Chinese and give suggestions"""
 
 generation_config = {
     "max_output_tokens": 1024,
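With this change, generate() builds a fresh ChatSession from a caller-supplied model name and system instruction instead of the hard-coded values. A minimal sketch of driving it directly, not part of the commit; it assumes PROJECT_ID is configured, the google-cloud-aiplatform SDK is installed, and the module is importable as app:

    # A minimal sketch, assuming `from app import generate` works as named.
    from app import generate  # hypothetical import; module name assumed

    chat = generate("gemini-1.5-flash-001", "Answer every question as a haiku.")
    # ChatSession.send_message(stream=True) yields partial responses as they
    # arrive; stream_chat in app.py presumably wraps this same call.
    for chunk in chat.send_message("Why is the sky blue?", stream=True):
        print(chunk.text, end="")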
@@ -73,6 +72,24 @@ app = FastAPI(
 )
 
 
+class SessionRequest(BaseModel):
+    instruction: str = "You are now a fox that loves Lisp and likes editing with Emacs; \
+when answering questions, try to answer with concepts that fit the Lisp philosophy, such as recursion."
+    session_uuid: str
+    model: str = 'gemini-1.5-flash-001'
+
+
+@app.post("/generate_session")
+async def generate_session(request: SessionRequest):
+    if SESSION_CACHE.get(request.session_uuid):
+        raise HTTPException(500)
+
+    model = generate(request.model, request.instruction)
+
+    SESSION_CACHE[request.session_uuid] = model
+    logging.info(f"Build new session: {request.session_uuid}")
+
+
 class ChatRequest(BaseModel):
     prompt: str
     session_uuid: str
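The new endpoint registers a chat session under a client-chosen UUID; posting an already-registered UUID raises HTTP 500. A hedged client-side sketch, assuming the app is served on localhost:8000 and the requests package is available:

    # Hypothetical client call; host, port, and the requests dependency are assumptions.
    import uuid
    import requests

    session_uuid = str(uuid.uuid4())
    resp = requests.post(
        "http://localhost:8000/generate_session",
        json={"session_uuid": session_uuid},  # instruction and model fall back to defaults
    )
    resp.raise_for_status()  # re-posting the same UUID would surface as HTTP 500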
@@ -80,13 +97,10 @@ class ChatRequest(BaseModel):
 
 @app.post("/generate_text_stream")
 async def generate_text(request: ChatRequest):
-    if model := SESSION_CACHE.get(request.session_uuid):
-        return StreamingResponse(stream_chat(model, request.prompt), media_type="text/plain")
+    if session := SESSION_CACHE.get(request.session_uuid):
+        return StreamingResponse(stream_chat(session, request.prompt), media_type="text/plain")
 
-    model = generate()
-    SESSION_CACHE[request.session_uuid] = model
-    logging.info(f"Build new session: {request.session_uuid}")
-    return StreamingResponse(stream_chat(model, request.prompt), media_type="text/plain")
+    raise HTTPException(404)
 
 
 app.mount("/", StaticFiles(directory="static", html=True), name="static")
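/generate_text_stream now assumes the session already exists: an unknown UUID gets a 404 instead of silently creating a new session. A small smoke-test sketch of that path with FastAPI's TestClient, assuming the module is importable as app and httpx is installed:

    # Hypothetical smoke test for the new 404 behavior; module name `app` is assumed.
    from fastapi.testclient import TestClient
    from app import app

    client = TestClient(app)

    def test_unknown_session_is_rejected():
        # No /generate_session call first, so the cache lookup misses.
        r = client.post("/generate_text_stream",
                        json={"prompt": "hi", "session_uuid": "never-registered"})
        assert r.status_code == 404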
@@ -43,6 +43,43 @@ async function sendMessage() {
     hideTypingIndicator();
 }
 
+async function generateSession() {
+    try {
+        const response = await fetch("./generate_session", {
+            method: "POST",
+            headers: {
+                "Content-Type": "application/json",
+            },
+            body: JSON.stringify({ session_uuid }),
+        });
+
+        if (!response.ok) {
+            console.error("Error:", response.statusText);
+            return "Error occurred while generating response.";
+        }
+
+        const reader = response.body.getReader();
+        const decoder = new TextDecoder();
+        let isFinished = false;
+        let generatedTextContent = "";
+
+        while (!isFinished) {
+            const { done, value } = await reader.read();
+            if (done) {
+                isFinished = true;
+                break;
+            }
+            generatedTextContent += decoder.decode(value, { stream: true });
+        }
+
+        return generatedTextContent;
+    } catch (error) {
+        console.error("Error:", error);
+        return "An error occurred.";
+    }
+}
+
+
 async function generateText(prompt) {
     try {
         const response = await fetch("./generate_text_stream", {
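On the client, the new generateSession() consumes the response body chunk by chunk via response.body.getReader() and a TextDecoder. The same pattern can be sketched in Python with requests' streaming interface; host, port, and a previously registered UUID are assumptions here:

    # Hypothetical Python counterpart of the browser-side reader/decoder loop.
    import requests

    with requests.post(
        "http://localhost:8000/generate_text_stream",
        json={"prompt": "Hallo!", "session_uuid": "a-registered-uuid"},
        stream=True,
    ) as resp:
        resp.raise_for_status()
        # iter_content(chunk_size=None, decode_unicode=True) yields decoded text
        # chunks as the server flushes them, like decoder.decode(value) above.
        for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
            print(chunk, end="")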
@@ -116,3 +153,4 @@ function handleKeyPress(event) {
 }
 
 window.onload = () => addMessage("This session is: " + session_uuid, 'bot');
+generateSession()