ds_api.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. # -*- coding: utf-8 -*-
  2. import json
  3. from openai import OpenAI
  4. import os
  5. from tools.loglog import SimpleLogger
  6. class DS:
  7. def __init__(self):
  8. self.client = OpenAI(
  9. api_key=os.getenv("DASHSCOPE_API_KEY"),
  10. base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
  11. )
  12. self.logger = SimpleLogger(base_file_name="deepseek")
  13. def write_log(self, message: str, log_type="info"):
  14. """写入日志"""
  15. log_methods = {
  16. "warning": self.logger.warning,
  17. "error": self.logger.error,
  18. "info": self.logger.info
  19. }
  20. log_methods.get(log_type, self.logger.info)(message=message)
  21. def check_article_response(self, response: str) -> bool:
  22. """校验文章的回复是否符合预期格式"""
  23. try:
  24. resp_json = json.loads(response)
  25. required_fields = ["english", "chinese", "difficultSentences"]
  26. return all(field in resp_json for field in required_fields)
  27. except Exception as e:
  28. self.write_log(f"Response validation error: {e}", log_type="error")
  29. return False
  30. def get_article(self, user_prompt: str, sys_prompt: str = None, temperature: float = 0.8,
  31. json_resp: bool = False, real_ip: str = "", demo_name: str = "",
  32. max_tokens: int = 5192) -> str:
  33. """获取AI生成的文章
  34. Args:
  35. user_prompt: 用户输入的提示词
  36. sys_prompt: 系统提示词
  37. temperature: 温度参数,控制输出的随机性
  38. json_resp: 是否返回JSON格式
  39. real_ip: 用户IP
  40. demo_name: 演示名称
  41. max_tokens: 最大token数
  42. Returns:
  43. str: AI生成的回复内容
  44. """
  45. messages = []
  46. if sys_prompt:
  47. messages.append({'role': 'system', 'content': sys_prompt})
  48. messages.append({'role': 'user', 'content': user_prompt})
  49. response_format = {"type": "json_object"} if json_resp else {"type": "text"}
  50. resp = ""
  51. for _ in range(3):
  52. completion = self.client.chat.completions.create(
  53. model="deepseek-v3",
  54. messages=messages,
  55. temperature=temperature,
  56. response_format=response_format,
  57. max_tokens=max_tokens
  58. )
  59. resp = completion.choices[0].message.content
  60. if self.check_article_response(resp):
  61. break
  62. if sys_prompt and resp:
  63. self.write_log(sys_prompt)
  64. self.write_log(user_prompt)
  65. self.write_log(resp)
  66. return resp
  67. if __name__ == '__main__':
  68. os.chdir('..')
  69. p = """下面我会为你提供两组数据,[单词组1]和[单词组2](里面包含词义id,英语单词,中文词义),优先使用[单词组1]内的单词,请根据这些单词的中文词义,生成一篇带中文翻译的考场英语文章,英语文章和中文翻译要有[标题]。注意这个单词有多个词义时,生成的英语文章一定要用提供的中文词义。并挑选一句复杂的句子和其中文翻译,放入difficultSentences。英语文章,放入"englishArticle"中。中文翻译,放入"chineseArticle"中。最终文中使用到的单词id放入"usedMeanIds"中。4个选择题,放入questions字段。questions结构下有4个选择题对象,其中trunk是[英语]问题文本,analysis是[中文]的问题分析,candidates是4个ABCD选项,内部有label是指选项序号A B C D ,text是[英语]选项文本,isRight是否正确答案1是正确0是错误。
  70. 提供[单词组1]:847 protect:保护;592 bear:出生, 结果;601 lie:位于;431 close:近, 靠近;1031 direction:方向;1282 coffee:咖啡豆;303 once:曾经;827 raise:养育;373 follow:听懂, 领会;1286 solve:解决, 解答;
  71. 提供[单词组2]:1288 destroy:破坏, 摧毁;1290 project:放映, 展现;1292 waste:浪费, 荒芜, 废物;1293 environment:环境, 外界;1294 memory:记忆, 记忆力, 回忆;
  72. 要求:
  73. 1.必须用提供的这个词义的单词,其他单词使用常见、高中难度的的单词。文章整体难度适中,大约和中国的高中生,中国CET-6,雅思6分这样的难度标准。
  74. 2.优先保证文章语句通顺,意思不要太生硬。不要为了使用特定的单词,造成文章语义前后不搭,允许不使用个别词义。
  75. 3.文章中使用提供单词,一定要和提供单词的中文词义匹配,尤其是一词多义时,务必使用提供单词的词义。必须要用提供单词的词义。如果用到的词义与提供单词词义不一致,请不要使用这个单词。
  76. 4.生成的文章要求500词左右,可以用\n\n字符分段,一般5个段落左右。第一段是文章标题。
  77. 5.生成文章优先使用[单词组1]的词义,其次可以挑选使用[单词组2]的词义。允许不使用[单词组1]的个别单词,优先保证文章整体意思通顺连贯和故事完整。
  78. 6.回复标准json数据,示例:
  79. {"difficultSentences":[{"english":"string","chinese":"string"}],"usedMeanIds":[0,0,0],"englishArticle":"string","chineseArticle":"string","questions":[{"trunk":"string","analysis":"string","candidates":[{"label":"string","text":"string","isRight":0}]}]}
  80. """
  81. ds = DS()
  82. resp = ds.get_article(user_prompt=p, json_resp=True)
  83. print(resp)
  84. print()
  85. print(resp.replace(r'\"n', '\n').replace(r"\\n", '\n'))
  86. print()
  87. print(json.loads(resp))