# test_extract.py (5.9 KB)

  1. import requests
  2. import json
  3. import re
  4. import io
  5. import base64
  6. from PIL import Image
  7. def get_normalized_base64_image(image_url):
  8. try:
  9. response = requests.get(image_url, timeout=30)
  10. response.raise_for_status()
  11. with Image.open(io.BytesIO(response.content)) as img:
  12. if img.mode != 'RGB':
  13. img = img.convert('RGB')
  14. max_dim = 2000
  15. if max(img.width, img.height) > max_dim:
  16. ratio = max_dim / max(img.width, img.height)
  17. new_size = (int(img.width * ratio), int(img.height * ratio))
  18. img = img.resize(new_size, Image.Resampling.LANCZOS)
  19. buffer = io.BytesIO()
  20. img.save(buffer, format='JPEG', quality=85)
  21. b64_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
  22. return f'data:image/jpeg;base64,{b64_str}'
  23. except Exception as e:
  24. print(f'Error normalizing image: {e}')
  25. return image_url
  26. def call_doubao_image_api(image_url, prompt):
  27. api_key = 'a1800657-9212-4afe-9b7c-b49f015c54d3'
  28. api_url = 'https://ark.cn-beijing.volces.com/api/v3/responses'
  29. ai_payload_url = get_normalized_base64_image(image_url)
  30. payload = {
  31. 'model': 'doubao-seed-1-8-251228',
  32. 'stream': False,
  33. 'input': [
  34. {
  35. 'role': 'user',
  36. 'content': [
  37. {'type': 'input_image', 'image_url': ai_payload_url},
  38. {'type': 'input_text', 'text': prompt}
  39. ]
  40. }
  41. ]
  42. }
  43. headers = {
  44. 'Authorization': f'Bearer {api_key}',
  45. 'Content-Type': 'application/json'
  46. }
  47. try:
  48. response = requests.post(
  49. api_url,
  50. json=payload,
  51. headers=headers,
  52. timeout=120,
  53. verify=False,
  54. proxies={'http': None, 'https': None}
  55. )
  56. if response.status_code == 200:
  57. return response.json()
  58. else:
  59. print(f'API Error: {response.status_code}')
  60. return None
  61. except Exception as e:
  62. print(f'Exception: {e}')
  63. return None
  64. def extract_text_from_response(response):
  65. """从API响应中提取文本内容"""
  66. if not response:
  67. return ''
  68. # 尝试多种响应格式
  69. if 'output' in response:
  70. for item in response['output']:
  71. # 跳过reasoning类型
  72. if item.get('type') == 'reasoning':
  73. continue
  74. content = item.get('content')
  75. if isinstance(content, str):
  76. return content
  77. elif isinstance(content, list):
  78. text_parts = []
  79. for part in content:
  80. if isinstance(part, dict):
  81. if part.get('type') == 'text':
  82. text_parts.append(part.get('text', ''))
  83. elif part.get('type') == 'reasoning':
  84. continue
  85. elif isinstance(part, str):
  86. text_parts.append(part)
  87. return ''.join(text_parts)
  88. if 'choices' in response and len(response['choices']) > 0:
  89. message = response['choices'][0].get('message', {})
  90. return message.get('content', '')
  91. return str(response)
  92. def clean_text(text):
  93. """清理文本,去除多余内容"""
  94. if not text:
  95. return ''
  96. text = text.strip()
  97. # 去除代码块标记
  98. if text.startswith('```json'):
  99. text = text[7:]
  100. if text.startswith('```'):
  101. text = text[3:]
  102. if text.endswith('```'):
  103. text = text[:-3]
  104. text = text.strip()
  105. # 尝试解析JSON
  106. try:
  107. result = json.loads(text)
  108. if isinstance(result, dict):
  109. # 尝试多种可能的字段名
  110. for key in ['genealogy_traditional', 'traditional', 'text', 'content', 'result']:
  111. if key in result:
  112. text = str(result[key])
  113. break
  114. except json.JSONDecodeError:
  115. pass
  116. # 去除解释性文字
  117. unwanted_patterns = [
  118. '请分析', '要求', '提取', '转换', '繁体', '简体',
  119. 'genealogy', 'traditional', 'simplified',
  120. '原始', '原文', 'JSON', '格式', '输出',
  121. 'reasoning', 'thinking', '思考', '分析',
  122. '我现在需要', '首先', '然后', '接下来',
  123. '根据图片', '图片中', '识别', 'OCR'
  124. ]
  125. for pattern in unwanted_patterns:
  126. text = text.replace(pattern, '')
  127. # 去除JSON结构残留
  128. text = re.sub(r'["\']text["\']\s*[,:]\s*["\']', '', text)
  129. text = re.sub(r'["\']', '', text)
  130. # 提取纯中文
  131. chinese_text = re.findall(r'[\u4e00-\u9fff]+', text)
  132. if chinese_text:
  133. text = ''.join(chinese_text)
  134. return text.strip()
  135. # 测试不同的prompt
  136. prompts = [
  137. '提取图片中的繁体中文文字,直接输出,不要解释。',
  138. '识别图片中的竖排繁体中文,按阅读顺序输出原文。',
  139. 'OCR识别图片文字,只输出结果。',
  140. '读取图片中的族谱文字,直接返回。',
  141. '分析图片,提取所有中文文字,不要分析。'
  142. ]
  143. print('=== 测试不同Prompt效果 ===')
  144. for i, prompt in enumerate(prompts):
  145. print(f'\nPrompt {i+1}: {prompt}')
  146. print('-' * 50)
  147. # 这里需要实际的图片URL进行测试
  148. # 测试模式:打印prompt供参考
  149. print('(需要实际图片URL进行测试)')
  150. # 手动测试样例 - 根据用户提供的图片内容
  151. print('\n=== 预期提取结果(根据图片手动识别)===')
  152. print('因公图片原文(竖排繁体):')
  153. print('因公')
  154. print('字廷大授南州刺史上距陽公三十五世後漢延康元年二月初六日渡')
  155. print('婺州之金華縣長樂鄉 娶林氏生三子 塟藤就村見有石柱石人華表')