ソースを参照

1.改用ujson来反序列化,减少文章格式上的验证不通过。
2.tts接口,增加部分特殊符号的忽略。
3.修复bug,文章生成接口,对ai回复的文章校验。除了schema的校验,再加自己的中英文比例校验

xie 1 週間 前
コミット
3931b70531
6 ファイル変更59 行追加172 行削除
  1. 13 2
      gpt/chatgpt.py
  2. 4 3
      gpt/get_article2.py
  3. 20 145
      gpt/gpt_check.py
  4. 16 16
      main.py
  5. 4 4
      mock/mock_request.py
  6. 2 2
      tools/audio.py

+ 13 - 2
gpt/chatgpt.py

@@ -8,7 +8,7 @@ from tools.new_mysql import MySQLUploader
 from typing import Optional, Dict, Any,Union
 from gpt.gpt_check import Article,Annotation
 from pydantic import ValidationError
-
+import ujson
 
 m = MySQLUploader()
 
@@ -118,19 +118,30 @@ def get_article_gpt_pydantic(question, real_ip="localhost", demo_name="无", mod
     d2['messages'].append({"role": "user", "content": question})
 
     for num_count in range(3):
+        answer = ""
         try:
             response = requests.post('http://170.106.108.95/v1/chat/completions', json=d2)
             r_json = response.json() 
 
            
             for choice in r_json["choices"]:
-                Article.model_validate_json(choice["message"]["content"])
+               
+                answer = ujson.dumps(ujson.loads(choice["message"]["content"]),ensure_ascii=False)
+                Article.model_validate_json(answer)
+
+               
+                check_result = check_fucn(str(answer))
+                if not check_result: 
+                    logger.error(answer)
+                    raise Exception(f"忽略,第{num_count + 1}次共3次,GPT的校验没有通过,校验函数:{check_fucn.__name__}")
+
 
             simple_logger.info(f"问题日志task_id:{task_id},exercise_id:{exercise_id}\n回答日志:\n{r_json}")
             return r_json
 
         except ValidationError as e:
             logger.error(f"gpt回复校验失败task_id:{task_id},exercise_id:{exercise_id}:")
+            logger.error(answer)
 
         except requests.exceptions.RequestException as e:
             logger.error(f"HTTP请求错误task_id:{task_id},exercise_id:{exercise_id}: {str(e)}")

+ 4 - 3
gpt/get_article2.py

@@ -122,6 +122,8 @@ class GetArticle:
 
    
     def parser_insert_to_mysql(self, resp_result):
+        if not resp_result:
+            return None
         try:
             for single_article in resp_result['articles']:
                 article = single_article['body']
@@ -272,8 +274,6 @@ class GetArticle:
         shuffle(core_words)
         core_words_meaning_str = "; ".join([f"[{i['meaning_id']}  {i['spell']} {i['meaning']}]" for i in core_words])
 
-        no_escape_code = r"\\n\\n"
-
         sys_prompt = "你是一个专业的英语老师,擅长根据用户提供的词汇生成对应的英语文章和中文翻译和4个配套选择题。"
 
         q = f"""下面我会为你提供一组数据,[单词组](里面包含词义id,英语单词,中文词义),请根据这些单词的中文词义,\
@@ -283,8 +283,9 @@ class GetArticle:
 1.必须用提供的这个词义的单词,其他单词使用{select_diffculty}的单词。{desc2}{choice_desc}
 2.优先保证文章语句通顺,意思不要太生硬。不要为了使用特定的单词,造成文章语义前后不搭,允许不使用个别词义。
 3.文章中使用提供单词,一定要和提供单词的中文词义匹配,尤其是一词多义时,务必使用提供单词的词义。必须要用提供单词的词义。如果用到的词义与提供单词词义不一致,请不要使用这个单词。
-4.生成的文章要求{article_length}词左右,可以用{no_escape_code}字符分段,一般{select_paragraph_count}个段落左右。第一段是文章标题。不需要markdown格式。
+4.生成的文章要求{article_length}词左右,一般{select_paragraph_count}个段落左右。第一段是文章标题。不需要markdown格式。
 5.允许不使用[单词组]的个别单词,优先保证文章整体意思通顺连贯和故事完整。
+6.回复的标准json字符串紧凑,能被直接json.loads()解析。
 
 提供[单词组]:{core_words_meaning_str};
 """

ファイルの差分が大きいため隠しています
+ 20 - 145
gpt/gpt_check.py


+ 16 - 16
main.py

@@ -1,20 +1,21 @@
 # -*- coding: utf-8 -*-
 import time
-from fastapi import FastAPI, Request
-from fastapi.responses import PlainTextResponse
 from threading import Thread
-from typing import Callable
 
-from core.api_article_annotation import router_article_annotation as r7
+from fastapi import FastAPI,Request
+from fastapi.responses import PlainTextResponse
+from typing import Callable
 from core.api_get_article import router as r1
-from core.api_get_article2 import router as r3
-from core.api_get_article3 import router as r6
 from core.api_get_audio import router as r2
-from core.api_get_spoken_language import router as r5
+from core.api_get_article2 import router as r3
 from core.api_get_word import router as r4
-from core.respone_format import *
+from core.api_get_spoken_language import router as r5
+from core.api_get_article3 import router as r6
+from core.api_article_annotation import router_article_annotation as r7
+
+from tools.loglog import logger,log_err_e
 from tools.del_expire_file import run_del_normal
-from tools.loglog import logger, log_err_e
+from core.respone_format import *
 
 app = FastAPI(title="AI相关功能接口", version="1.1")
 
@@ -27,7 +28,6 @@ app.include_router(r5, tags=["口语评测"])
 app.include_router(r6, tags=["deepseek文章"])
 app.include_router(r7, tags=["文章词义标注"])
 
-
 @app.middleware("http")
 async def add_process_time_header(request: Request, call_next: Callable):
     start_time = time.time()
@@ -37,30 +37,30 @@ async def add_process_time_header(request: Request, call_next: Callable):
     try:
         body = await request.json() if request.method in ["POST", "PUT", "PATCH"] else ""
     except:
-        body = ""
+        body =""
     logger.info(f"\n正式接口请求:{real_ip} {request.method} {path}\n查询参数:{params}\n携带参数:{body}")
 
     try:
         response = await call_next(request)
     except Exception as e:
-        log_err_e(e, msg="http中间件错误捕捉")
+        log_err_e(e,msg="http中间件错误捕捉")
         return resp_500(message=f"{type(e).__name__},{e}")
 
     process_time = str(round(time.time() - start_time, 2))
     response.headers["X-Process-Time"] = process_time
 
-    if path not in ['/', '/tts']:
-        with open('log/time_log.txt', encoding='utf-8', mode='a') as f:
+   
+    if path not in ['/','/tts']:
+        with open('log/time_log.txt', encoding='utf-8', mode='a')as f:
             t = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
             f.write(f"{t}  路径:{path} - 用时:{process_time}\n")
     return response
 
-
 @app.get("/")
 @app.post("/")
 def hello():
     return PlainTextResponse("hello world")
 
-
 del_file_thread = Thread(target=run_del_normal, daemon=True)
 del_file_thread.start()
+

+ 4 - 4
mock/mock_request.py

@@ -19,7 +19,7 @@ test_address2 = "http://111.231.167.191:8003"
 
 local_adress = "http://127.0.0.1:9000" 
 
-use_address = product_adress 
+use_address = test_address 
 
 
 class DifficultSentence(BaseModel):
@@ -404,7 +404,7 @@ Although peer pressure is sometimes quite obvious, it can also be so ____13____
     r = requests.post(f"{use_address}/article/meaning/annotation", json=json_data)
     r_json = r.json()
     assert r.status_code == 200 and r_json.get("code")==200
-    print(r_json)
+   
 
 
 
@@ -414,7 +414,7 @@ if __name__ == '__main__':
    
    
 
-    article_annotation()
+   
 
 
    
@@ -423,6 +423,6 @@ if __name__ == '__main__':
    
 
    
-   
+    get_audio()
 
    

ファイルの差分が大きいため隠しています
+ 2 - 2
tools/audio.py


この差分においてかなりの量のファイルが変更されているため、一部のファイルを表示していません