diff --git a/app.py b/app.py
index 65cd993..109bea0 100644
--- a/app.py
+++ b/app.py
@@ -2,7 +2,7 @@ import os
 import vertexai
 import logging
-from fastapi import FastAPI
+from fastapi import FastAPI, HTTPException
 from fastapi.responses import StreamingResponse
 from fastapi.staticfiles import StaticFiles
 from pydantic import BaseModel
 
@@ -18,16 +18,15 @@ SESSION_CACHE = LRUCache(15)
 
 logging.basicConfig(format='%(levelname)s:%(message)s', level=logging.DEBUG)
 
-def generate():
-    vertexai.init(project=PROJECT_ID, location="us-central1")
-    model = GenerativeModel(
-        "gemini-1.5-flash-001",
-        # system_instruction=[textsi_1]
-    )
+def generate(model: str, instruction: str):
+    vertexai.init(project=PROJECT_ID, location="us-central1")
+    session = GenerativeModel(
+        model,
+        system_instruction=[instruction]
+    )
 
-    return model.start_chat()
+    return session.start_chat()
 
-# textsi_1 = """Respond to a beginner just starting German: answer in German first, then explain the answer in Chinese and offer suggestions."""
 
 generation_config = {
     "max_output_tokens": 1024,
@@ -73,6 +72,24 @@ app = FastAPI(
 )
 
 
+class SessionRequest(BaseModel):
+    instruction: str = "You are now a fox that loves Lisp and enjoys editing \
+with Emacs; when answering questions, favor Lisp-philosophy concepts such as recursion."
+    session_uuid: str
+    model: str = 'gemini-1.5-flash-001'
+
+
+@app.post("/generate_session")
+async def generate_session(request: SessionRequest):
+    if SESSION_CACHE.get(request.session_uuid):
+        raise HTTPException(status_code=409, detail="Session already exists")
+
+    model = generate(request.model, request.instruction)
+
+    SESSION_CACHE[request.session_uuid] = model
+    logging.info(f"Build new session: {request.session_uuid}")
+
+
 class ChatRequest(BaseModel):
     prompt: str
     session_uuid: str
@@ -80,13 +97,10 @@ class ChatRequest(BaseModel):
 
 @app.post("/generate_text_stream")
 async def generate_text(request: ChatRequest):
-    if model := SESSION_CACHE.get(request.session_uuid):
-        return StreamingResponse(stream_chat(model, request.prompt),media_type="text/plain")
+    if session := SESSION_CACHE.get(request.session_uuid):
+        return StreamingResponse(stream_chat(session, request.prompt), media_type="text/plain")
 
-    model = generate()
+    raise HTTPException(status_code=404, detail="Session not found")
 
-    SESSION_CACHE[request.session_uuid] = model
-    logging.info(f"Buid new session: {request.session_uuid}")
-    return StreamingResponse(stream_chat(model, request.prompt),media_type="text/plain")
 
 app.mount("/", StaticFiles(directory="static", html=True), name="static")
diff --git a/static/script.js b/static/script.js
index 79562a8..73447cc 100644
--- a/static/script.js
+++ b/static/script.js
@@ -43,6 +43,43 @@ async function sendMessage() {
     hideTypingIndicator();
 }
 
+async function generateSession() {
+    try {
+        const response = await fetch("./generate_session", {
+            method: "POST",
+            headers: {
+                "Content-Type": "application/json",
+            },
+            body: JSON.stringify({ session_uuid }),
+        });
+
+        if (!response.ok) {
+            console.error("Error:", response.statusText);
+            return "Error occurred while generating the session.";
+        }
+
+        const reader = response.body.getReader();
+        const decoder = new TextDecoder();
+        let isFinished = false;
+        let generatedTextContent = "";
+
+        while (!isFinished) {
+            const { done, value } = await reader.read();
+            if (done) {
+                isFinished = true;
+                break;
+            }
+            generatedTextContent += decoder.decode(value, {stream: true});
+        }
+
+        return generatedTextContent;
+    } catch (error) {
+        console.error("Error:", error);
+        return "An error occurred.";
+    }
+}
+
+
 async function generateText(prompt) {
     try {
         const response = await fetch("./generate_text_stream", {
@@ -116,3 +153,4 @@ function handleKeyPress(event) {
 }
 
 window.onload = () => addMessage("This session is: " + session_uuid, 'bot');
+generateSession();
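
For reviewers, a minimal end-to-end sketch of the new two-step flow this diff introduces (create a session, then stream against it). Splitting session creation out of the streaming endpoint makes the 404 meaningful: a missing session is now a client error rather than a silent re-initialization. The host/port, the `requests` dependency, and the example prompt below are illustrative assumptions, not part of this change:

```python
# Sketch only: assumes the app runs locally, e.g. `uvicorn app:app --port 8000`,
# and uses the third-party `requests` package (not a dependency of the app).
import uuid

import requests

BASE = "http://localhost:8000"  # assumed host/port
session_uuid = str(uuid.uuid4())

# 1. Create the chat session; POSTing the same UUID again returns 409.
requests.post(
    f"{BASE}/generate_session",
    json={"session_uuid": session_uuid},
).raise_for_status()

# 2. Stream a reply; an unknown session_uuid now yields 404 instead of
#    silently creating a fresh session, as the old endpoint did.
with requests.post(
    f"{BASE}/generate_text_stream",
    json={"session_uuid": session_uuid, "prompt": "What is recursion?"},
    stream=True,
) as resp:
    resp.raise_for_status()
    for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="", flush=True)
```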