Explorar o código

增加一个字段,articleWordAmount,文章单词数

notepad hai 3 semanas
pai
achega
1922d13167
Modificáronse 1 ficheiros con 16 adicións e 13 borrados
  1. 16 13
      gpt/get_article2.py

+ 16 - 13
gpt/get_article2.py

@@ -6,7 +6,7 @@ from tools.new_mysql import MySQLUploader
 from tools.loglog import logger, log_err_e
 from tools.thread_pool_manager import pool_executor
 from common.common_data import all_exchange_words
-from common.split_text import split_text_to_word
+from common.split_text import split_text_to_word, get_article_words_count
 
 from pydantic import BaseModel
 from cachetools import TTLCache
@@ -138,7 +138,7 @@ class GetArticle:
 
    
     async def submit_task(self, real_ip: str, core_words: list, take_count: int,
-                          demo_name: str,reading_level:int, article_length: int, exercise_id: int):
+                          demo_name: str, reading_level: int, article_length: int, exercise_id: int):
         """
         core_words: 词义数据组
         take_count: 取文章数量 (int类型,正常是2篇,最大8篇)
@@ -155,7 +155,7 @@ class GetArticle:
             self.real_ip_dict[task_id] = real_ip
             self.demo_name[task_id] = demo_name
 
-            resp_result = await self.run_task(core_words, task_id,take_count,reading_level,article_length)
+            resp_result = await self.run_task(core_words, task_id, take_count, reading_level, article_length)
             await self.parser_insert_to_mysql(resp_result) 
             logger.success(f"reading-comprehension 文章2任务完成。学案id:{exercise_id},taskid:{task_id}\n{resp_result}")
             return resp_result
@@ -166,7 +166,7 @@ class GetArticle:
 
    
     @retry(stop=stop_after_attempt(3), wait=wait_fixed(2), reraise=True)
-    async def get_article(self, core_words: list, task_id: int,reading_level, article_length) -> dict:
+    async def get_article(self, core_words: list, task_id: int, reading_level, article_length) -> dict:
        
         if not article_length:
             if 0 < reading_level <= 10:
@@ -176,7 +176,7 @@ class GetArticle:
             else:
                 article_length = 450 + 20 * (reading_level - 20)
 
-        for index,(start,end) in enumerate([(1,8),(9,16),(17,24),(24,30)],start=1):
+        for index, (start, end) in enumerate([(1, 8), (9, 16), (17, 24), (24, 30)], start=1):
             if start <= reading_level <= end:
                 difficulty_control_stage = index
                 break
@@ -236,11 +236,14 @@ class GetArticle:
                                                                check_fucn=CheckArticleResult.get_article_1, max_tokens=4000,
                                                                sys_prompt=sys_prompt, client=self.client))
            
-            allWordAmount = 0
-            allWordAmount += len(split_text_to_word(r_json["englishArticle"]))
+            allWordAmount = 0 
+           
+            articleWordAmount = get_article_words_count(r_json["englishArticle"])
+            allWordAmount += articleWordAmount
+
             for i in r_json["questions"]:
-                count_trunk = len(split_text_to_word(i["trunk"]))
-                count_candidates = sum([len(split_text_to_word(ii["text"])) for ii in i["candidates"]])
+                count_trunk = get_article_words_count(i["trunk"])
+                count_candidates = sum([get_article_words_count(ii["text"]) for ii in i["candidates"]])
                 allWordAmount += count_trunk
                 allWordAmount += count_candidates
 
@@ -275,7 +278,7 @@ class GetArticle:
                     candidate['label'] = labels[index]
                 q['candidates'] = shuffled_candidates
 
-            return {**r_json, "allWordAmount": allWordAmount}
+            return {**r_json, "allWordAmount": allWordAmount, "articleWordAmount": articleWordAmount}
         except httpx.HTTPError as e:
             logger.error(f"HTTP请求错误: {str(e)}")
             raise
@@ -287,7 +290,7 @@ class GetArticle:
             raise
 
    
-    async def run_get_article_task(self, core_words, task_id,take_count,reading_level,article_length) -> dict:
+    async def run_get_article_task(self, core_words, task_id, take_count, reading_level, article_length) -> dict:
         """
         :param core_words: 核心单词数据,优先级1;可能为空
         :param task_id: 任务id
@@ -315,9 +318,9 @@ class GetArticle:
             raise
 
    
-    async def run_task(self, core_words, task_id,take_count,reading_level,article_length):
+    async def run_task(self, core_words, task_id, take_count, reading_level, article_length):
         try:
-            outside_json = await self.run_get_article_task(core_words, task_id,take_count,reading_level,article_length)
+            outside_json = await self.run_get_article_task(core_words, task_id, take_count, reading_level, article_length)
             return outside_json
         except Exception as e:
             log_err_e(e, msg="外层总任务捕获错误")