chatgpt.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262
  1. # -*- coding:utf-8 -*-
  2. import json
  3. import time
  4. from typing import Dict, Any, Union
  5. import requests
  6. from pydantic import ValidationError
  7. from gpt.gpt_check import Article, Annotation
  8. from tools.loglog import logger, simple_logger, log_err_e, temp_logger
  9. from tools.new_mysql import MySQLUploader
  10. m = MySQLUploader()  # Module-level uploader instance; insert_ip_token runs its INSERT through m.execute_.
  11. def get_openai_model(model_text: str):
  12. """模糊获得模型名"""
  13. if "3.5" in model_text or "3.5-turbo" in model_text or "3.5turbo" in model_text:
  14. model = "gpt-3.5-turbo"
  15. elif "4o" in model_text or "gpt4o" in model_text:
  16. model = "gpt-4o"
  17. elif "4turbo" in model_text or "4-turbo" in model_text:
  18. model = "gpt-4-turbo"
  19. else:
  20. model = "gpt-4o"
  21. return model
  22. def insert_ip_token(ip, demo_name, gpt_content, prompt_tokens, completion_tokens, total_tokens):
  23. sql = "insert into consumer_token (ip,demo_name,gpt_content,prompt_tokens,completion_tokens,total_tokens) values (%s,%s,%s,%s,%s,%s)"
  24. m.execute_(sql, (ip, demo_name, str(gpt_content), prompt_tokens, completion_tokens, total_tokens))
  25. def get_answer_from_gpt(question, real_ip="localhost", demo_name="无", model="gpt-4o", max_tokens=3500, temperature: float = 0,
  26. json_resp: Union[Dict[Any, Any], bool] = False, n=1, check_fucn=None, sys_prompt=None):
  27. model = get_openai_model(model)
  28. d2 = {"model": model, "messages": [], "max_tokens": max_tokens, "temperature": temperature, 'n': n}
  29. if sys_prompt:
  30. d2['messages'].append({"role": "system", "content": sys_prompt})
  31. d2['messages'].append({"role": "user", "content": question})
  32. if json_resp is True:
  33. d2["response_format"] = {"type": "json_object"}
  34. elif json_resp is False:
  35. pass
  36. else:
  37. d2["response_format"] = json_resp
  38. for num_count in range(3):
  39. try:
  40. response = requests.post(f'http://170.106.108.95/v1/chat/completions', json=d2)
  41. r_json = response.json()
  42. if r2 := r_json.get("choices", None):
  43. if n > 1:
  44. gpt_res = []
  45. for i in r2:
  46. gpt_res.append(i["message"]["content"])
  47. else:
  48. gpt_res = r2[0]["message"]["content"]
  49. gpt_content = str(gpt_res)
  50. prompt_tokens = r_json["usage"]["prompt_tokens"]
  51. completion_tokens = r_json["usage"]["completion_tokens"]
  52. total_tokens = r_json["usage"]["total_tokens"]
  53. insert_ip_token(real_ip, demo_name, gpt_content, prompt_tokens, completion_tokens, total_tokens)
  54. simple_logger.info(f"问题日志:\n{question}\n回答日志:\n{gpt_res}")
  55. if not check_fucn:
  56. return gpt_res
  57. check_result = check_fucn(str(gpt_res))
  58. if check_result:
  59. return gpt_res
  60. else:
  61. raise Exception(f"第{num_count + 1}次共3次,GPT的校验没有通过,校验函数:{check_fucn.__name__}")
  62. elif r_json.get("message") == "IP address blocked":
  63. print("IP address blocked")
  64. raise Exception("IP address blocked")
  65. else:
  66. print(f"小错误:{question[:10]}")
  67. logger.error(response.text)
  68. except Exception as e:
  69. logger.info(f"小报错忽略{e}")
  70. time.sleep(10)
  71. logger.critical("get_answer_from_gpt 严重错误,3次后都失败了")
  72. def get_article_gpt_pydantic(question, real_ip="localhost", demo_name="无", model="gpt-4.1", max_tokens=3500, temperature: float = 0, n=1,
  73. check_fucn=None, sys_prompt=None, task_id=0, exercise_id=0):
  74. """
  75. 异步获取文章
  76. :param question: 问题
  77. :param real_ip: 真实IP
  78. :param demo_name: 项目名称
  79. :param model: 模型名称
  80. :param max_tokens: 最大token数
  81. :param temperature: 温度
  82. :param n: 生成数量
  83. :param check_fucn: 校验函数
  84. :param sys_prompt: 系统提示
  85. :param task_id: 任务id
  86. :param exercise_id: 学案id
  87. :return: 文章内容
  88. """
  89. d2 = {"model": model, "messages": [], "max_tokens": max_tokens, "temperature": temperature, "n": n,
  90. "response_format": {'type': 'json_schema', 'json_schema': {'name': 'Article', 'schema': {'$defs': {'Candidate': {
  91. 'properties': {'label': {'allOf': [{'$ref': '#/$defs/Options'}], 'description': 'ABCD序号的一种', 'title': '序号'},
  92. 'text': {'description': '英文,ABCD选项的文本', 'title': '选项文本', 'type': 'string'},
  93. 'isRight': {'allOf': [{'$ref': '#/$defs/IsRight'}], 'description': '1是正确,0是错误', 'title': '是否是正确答案'}},
  94. 'required': ['label', 'text', 'isRight'], 'title': 'Candidate', 'type': 'object'}, 'DifficultSentence': {
  95. 'properties': {'english': {'description': '文章中的一句难句', 'title': '英语难句', 'type': 'string'},
  96. 'chinese': {'description': '对英语难句的翻译', 'title': '中文难句', 'type': 'string'}}, 'required': ['english', 'chinese'],
  97. 'title': 'DifficultSentence', 'type': 'object'}, 'IsRight': {'enum': [1, 0], 'title': 'IsRight', 'type': 'integer'},
  98. 'Options': {'enum': ['A', 'B', 'C', 'D'],
  99. 'title': 'Options',
  100. 'type': 'string'},
  101. 'Question': {'properties': {'trunk': {
  102. 'description': '用英语给出的选择题题目',
  103. 'title': '选择题题目', 'type': 'string'},
  104. 'analysis': {
  105. 'description': '中文,选择题的分析思路;不要给出答案的ABCD序号',
  106. 'title': '选择题分析',
  107. 'type': 'string'},
  108. 'candidates': {
  109. 'description': '一共4个选择题',
  110. 'items': {
  111. '$ref': '#/$defs/Candidate'},
  112. 'title': '选项对象',
  113. 'type': 'array'}},
  114. 'required': ['trunk',
  115. 'analysis',
  116. 'candidates'],
  117. 'title': 'Question',
  118. 'type': 'object'}},
  119. 'properties': {'difficultSentences': {
  120. 'description': '挑选一句难句对象',
  121. 'items': {'$ref': '#/$defs/DifficultSentence'},
  122. 'title': '难句对象', 'type': 'array'},
  123. 'usedMeanIds': {
  124. 'items': {'type': 'integer'},
  125. 'title': '用到的词义id',
  126. 'type': 'array'}, 'questions': {
  127. 'description': '针对英语文章的选择题',
  128. 'items': {'$ref': '#/$defs/Question'},
  129. 'title': '问题对象', 'type': 'array'},
  130. 'englishArticle': {
  131. 'description': '',
  132. 'title': '英语文章',
  133. 'type': 'string'},
  134. 'chineseArticle': {
  135. 'description': '',
  136. 'title': '中文翻译',
  137. 'type': 'string'}},
  138. 'required': ['difficultSentences', 'usedMeanIds',
  139. 'questions', 'englishArticle',
  140. 'chineseArticle'], 'title': 'Article',
  141. 'type': 'object'}}}
  142. }
  143. if sys_prompt:
  144. d2['messages'].append({"role": "system", "content": sys_prompt})
  145. d2['messages'].append({"role": "user", "content": question})
  146. for num_count in range(3):
  147. try:
  148. response = requests.post('http://170.106.108.95/v1/chat/completions', json=d2)
  149. r_json = response.json()
  150. for choice in r_json["choices"]:
  151. Article.model_validate_json(choice["message"]["content"])
  152. simple_logger.info(f"问题日志task_id:{task_id},exercise_id:{exercise_id}\n回答日志:\n{r_json}")
  153. return r_json
  154. except ValidationError as e:
  155. logger.error(f"gpt回复校验失败task_id:{task_id},exercise_id:{exercise_id}:")
  156. except requests.exceptions.RequestException as e:
  157. logger.error(f"HTTP请求错误task_id:{task_id},exercise_id:{exercise_id}: {str(e)}")
  158. time.sleep(1)
  159. except json.decoder.JSONDecodeError as e:
  160. if 'response' in locals() and response is not None:
  161. logger.error(f"json格式化错误task_id:{task_id},exercise_id:{exercise_id}:{response.text}")
  162. except Exception as e:
  163. log_err_e(e, f"其他错误task_id:{task_id},exercise_id:{exercise_id}")
  164. def get_annotation_gpt_pydantic(question, real_ip="localhost", demo_name="无", model="gpt-4.1", max_tokens=3500, temperature: float = 0, n=1,
  165. check_fucn=None, sys_prompt=None, task_id=0, exercise_id=0):
  166. """
  167. 异步获取文章
  168. :param question: 问题
  169. :param real_ip: 真实IP
  170. :param demo_name: 项目名称
  171. :param model: 模型名称
  172. :param max_tokens: 最大token数
  173. :param temperature: 温度
  174. :param n: 生成数量
  175. :param check_fucn: 校验函数
  176. :param sys_prompt: 系统提示
  177. :param task_id: 任务id
  178. :param exercise_id: 学案id
  179. :return: 标注内容
  180. """
  181. d2 = {"model": model, "messages": [], "max_tokens": max_tokens, "temperature": temperature, "n": n,
  182. "response_format": {'type': 'json_schema', 'json_schema': {'name': 'Annotation', 'schema': {'properties': {
  183. 'annotation_text': {'description': '对句子或文章的每个单词进行词义id的标注', 'examples': ['an[33] apple[123]'], 'title': '标注文本', 'type': 'string'}},
  184. 'required': ['annotation_text'],
  185. 'title': 'Annotation', 'type': 'object'}}}
  186. }
  187. if sys_prompt:
  188. d2['messages'].append({"role": "system", "content": sys_prompt})
  189. d2['messages'].append({"role": "user", "content": question})
  190. for num_count in range(3):
  191. try:
  192. response = requests.post('http://170.106.108.95/v1/chat/completions', json=d2)
  193. r_json = response.json()
  194. for choice in r_json["choices"]:
  195. Annotation.model_validate_json(choice["message"]["content"])
  196. temp_logger.info(f"日志task_id:{task_id},exercise_id:{exercise_id}:\n问题日志:\n{question}")
  197. simple_logger.info(f"日志task_id:{task_id},exercise_id:{exercise_id}:\n回答日志:\n{r_json}")
  198. return r_json
  199. except ValidationError as e:
  200. logger.error(f"gpt回复校验失败task_id:{task_id},exercise_id:{exercise_id}:")
  201. except requests.exceptions.RequestException as e:
  202. logger.error(f"HTTP请求错误task_id:{task_id},exercise_id:{exercise_id}: {str(e)}")
  203. time.sleep(1)
  204. except json.decoder.JSONDecodeError as e:
  205. if 'response' in locals() and response is not None:
  206. logger.error(f"json格式化错误task_id:{task_id},exercise_id:{exercise_id}:{response.text}")
  207. except Exception as e:
  208. log_err_e(e, f"其他错误task_id:{task_id},exercise_id:{exercise_id}")
  209. def parse_gpt_phon_to_tuplelist(text: str) -> list:
  210. """解析gpt返回的音标数据"""
  211. result = []
  212. if not text:
  213. return []
  214. for i in text.split("\n"):
  215. ii = i.split("***")
  216. if len(ii) >= 3:
  217. result.append((ii[0].strip(), ii[1].strip(), ii[2].strip()))
  218. return result