# gpt.py (3.3 KB)
# -*- coding:utf-8 -*-
if __name__ == '__main__':
    # When this file is executed directly, move the working directory one
    # level up before the project imports below run.
    # NOTE(review): presumably so that relative paths used by the `tools`
    # package (config files, log directories, ...) resolve from the project
    # root — confirm. Has no effect when the module is imported.
    import os
    os.chdir("..")
import requests
import random  # NOTE(review): not referenced anywhere in this file as shown.
import time
from tools.loglog import logger, simple_logger
from tools.new_mysql import MySQLUploader

# Single MySQLUploader instance shared by this module (insert_ip_token writes
# through it). It is constructed once, at import time.
# NOTE(review): whether construction opens a DB connection immediately cannot
# be seen from here — confirm in tools.new_mysql.
m = MySQLUploader()
  11. def insert_ip_token(ip, demo_name, gpt_content, prompt_tokens, completion_tokens, total_tokens):
  12. sql = "insert into consumer_token (ip,demo_name,gpt_content,prompt_tokens,completion_tokens,total_tokens) values (%s,%s,%s,%s,%s,%s)"
  13. m.execute_(sql, (ip, demo_name, str(gpt_content), prompt_tokens, completion_tokens, total_tokens))
  14. def get_answer_from_gpt(question, real_ip="localhost", demo_name="无", model="gpt-4o", max_tokens=3500, temperature: float = 0,
  15. json_resp=False, n=1, sys_prompt=None):
  16. if "3.5" in model or "3.5-turbo" in model or "3.5turbo" in model:
  17. model = "gpt-3.5-turbo"
  18. elif "4o" in model or "gpt4o" in model:
  19. model = "gpt-4o"
  20. elif "4turbo" in model or "4-turbo" in model:
  21. model = "gpt-4-turbo"
  22. d2 = {
  23. "model": model,
  24. "messages": [],
  25. "max_tokens": max_tokens,
  26. "temperature": temperature,
  27. 'n': n}
  28. if sys_prompt:
  29. d2['messages'].append({"role": "system", "content": sys_prompt})
  30. d2['messages'].append({"role": "user", "content": question})
  31. if json_resp is True:
  32. d2["response_format"] = {"type": "json_object"}
  33. elif json_resp is False:
  34. pass
  35. else:
  36. d2["response_format"] = json_resp
  37. for _ in range(3):
  38. try:
  39. response = requests.post(f'http://170.106.108.95/v1/chat/completions', json=d2)
  40. r_json = response.json()
  41. if r2 := r_json.get("choices", None):
  42. if n > 1:
  43. gpt_res = []
  44. for i in r2:
  45. gpt_res.append(i["message"]["content"])
  46. else:
  47. gpt_res = r2[0]["message"]["content"]
  48. gpt_content = str(gpt_res)
  49. prompt_tokens = r_json["usage"]["prompt_tokens"]
  50. completion_tokens = r_json["usage"]["completion_tokens"]
  51. total_tokens = r_json["usage"]["total_tokens"]
  52. insert_ip_token(real_ip, demo_name, gpt_content, prompt_tokens, completion_tokens, total_tokens)
  53. simple_logger.info(f"问题日志:\n{question}\n回答日志:\n{gpt_res}")
  54. return gpt_res
  55. elif r_json.get("message") == "IP address blocked":
  56. print("IP address blocked")
  57. raise Exception("IP address blocked")
  58. else:
  59. print(f"小错误:{question[:10]}")
  60. logger.error(response.text)
  61. except Exception as e:
  62. logger.info(f"小报错忽略{e}")
  63. time.sleep(10)
  64. logger.critical("get_answer_from_gpt 严重错误,3次后都失败了")
  65. def parse_gpt_phon_to_tuplelist(text: str) -> list:
  66. """解析gpt返回的音标数据"""
  67. result = []
  68. if not text:
  69. return []
  70. for i in text.split("\n"):
  71. ii = i.split("***")
  72. if len(ii) >= 3:
  73. result.append((ii[0].strip(), ii[1].strip(), ii[2].strip()))
  74. return result
  75. if __name__ == '__main__':
  76. pass
  77. resp = get_answer_from_gpt("hello", temperature=0.8, model='gpt-4o')
  78. print(resp)