"""
Math question duplicate checker: uses vector similarity to decide whether a question repeats an existing one.
"""
import json
import os
import pickle
import re
import numpy as np
import faiss
import pymysql
import concurrent.futures
from openai import OpenAI
from typing import List, Dict, Tuple, Optional
from config import DB_HOST, DB_PORT, DB_DATABASE, DB_USERNAME, DB_PASSWORD, OPENAI_API_KEY, OPENAI_BASE_URL
class QuestionDuplicateChecker:
"""题目查重器"""
def __init__(self, index_path="questions_tem.index", metadata_path="questions_tem_metadata.pkl"):
"""
初始化查重器
Args:
index_path: FAISS索引保存路径
metadata_path: 元数据保存路径
"""
self.client = OpenAI(
api_key=OPENAI_API_KEY,
base_url=OPENAI_BASE_URL
)
self.index_path = index_path
self.metadata_path = metadata_path
self.index = None
        self.metadata = []  # stores question IDs and text previews for display
        # Per-field weights used in the similarity score
self.weights = {
'stem': 0.8,
'options': 0.05,
'answer': 0.05,
'solution': 0.1
}
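        # Note: the weights sum to 1.0, so the inner-product score of two indexed vectors
        # stays in the same [-1, 1] range as a single cosine similarity, which is what keeps
        # the 0.85 / 0.5 thresholds used in check_duplicate meaningful.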
        # Total dimension after concatenating the 4 parts: 3072 * 4 = 12288
self.dimension = 3072 * 4
self._load_index()
def clean_text(self, text: str) -> str:
"""清理文本:移除图片标签、归一化空白字符"""
if not text:
return ""
# 移除 标签
text = re.sub(r']*/>', '', text)
# 归一化空白:将多个空格、换行替换为单个空格
text = re.sub(r'\s+', ' ', text)
return text.strip()
def normalize_options(self, options) -> str:
"""标准化选项内容"""
if not options:
return ""
data = None
if isinstance(options, str):
            # If it is a string, try to parse it as JSON
text = options.strip()
if (text.startswith('{') and text.endswith('}')) or (text.startswith('[') and text.endswith(']')):
try:
data = json.loads(text)
                except json.JSONDecodeError:
pass
if data is None:
return self.clean_text(options)
else:
data = options
if isinstance(data, dict):
            # Sort keys and clean values so equivalent dicts serialize identically
sorted_data = {str(k).strip(): self.clean_text(str(v)) for k, v in sorted(data.items())}
return self.clean_text(json.dumps(sorted_data, ensure_ascii=False))
elif isinstance(data, list):
            # Clean each list element
cleaned_data = [self.clean_text(str(v)) for v in data]
return self.clean_text(json.dumps(cleaned_data, ensure_ascii=False))
return self.clean_text(str(data))
def _load_index(self):
"""加载或初始化索引"""
if os.path.exists(self.index_path) and os.path.exists(self.metadata_path):
try:
self.index = faiss.read_index(self.index_path)
                # Check that the index dimension matches; re-initialize if not (e.g. after a weighting upgrade)
if self.index.d != self.dimension:
print(f"⚠️ 索引维度不匹配 ({self.index.d} != {self.dimension}),正在重新初始化...")
self._init_new_index()
else:
with open(self.metadata_path, 'rb') as f:
self.metadata = pickle.load(f)
print(f"✓ 已加载现有索引,包含 {len(self.metadata)} 道题目")
except Exception as e:
print(f"⚠️ 加载索引失败: {e},将初始化新索引")
self._init_new_index()
else:
self._init_new_index()
def _init_new_index(self):
"""初始化新的FAISS索引"""
# 使用内积索引,配合权重拼接向量
self.index = faiss.IndexFlatIP(self.dimension)
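        # Because each 3072-dim part is L2-normalized and scaled by sqrt(weight) before
        # concatenation (see get_weighted_embedding), the inner product of two stored
        # vectors equals the weighted sum of the per-part cosine similarities.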
self.metadata = []
print("✓ 已初始化新的FAISS索引")
def save_index(self):
"""保存索引和元数据"""
faiss.write_index(self.index, self.index_path)
with open(self.metadata_path, 'wb') as f:
pickle.dump(self.metadata, f)
print(f"✓ 索引和元数据已保存到 {self.index_path}")
def get_weighted_embedding(self, question: Dict) -> np.ndarray:
"""
        Build the weighted concatenated embedding for a question.
        Embeds each of the 4 parts, multiplies each embedding by the square root of its weight, and concatenates them.
"""
parts = {
'stem': self.clean_text(question.get('stem', '') or ''),
'options': self.normalize_options(question.get('options', '') or ''),
'answer': self.clean_text(question.get('answer', '') or ''),
'solution': self.clean_text(question.get('solution', '') or '')
}
embeddings = []
for key, text in parts.items():
            # Embed this part; the API rejects empty input, so fall back to "无"
            emb = self.get_embedding(text if text.strip() else "无")
if emb is None:
return None
            # L2-normalize this part's vector
norm = np.linalg.norm(emb)
if norm > 0:
emb = emb / norm
            # Scale by sqrt(weight): the dot product of two such vectors is then the weighted sum of per-part similarities
weight_factor = np.sqrt(self.weights[key])
embeddings.append(emb * weight_factor)
        # Concatenate the parts into one large vector
mega_vector = np.concatenate(embeddings).astype('float32')
return mega_vector
def get_embedding(self, text: str) -> np.ndarray:
"""获取文本的embedding向量"""
try:
            # Clean the text first
cleaned_text = self.clean_text(text)
if not cleaned_text.strip():
cleaned_text = "无"
response = self.client.embeddings.create(
model="text-embedding-3-large",
input=cleaned_text
)
embedding = np.array(response.data[0].embedding).astype('float32')
return embedding
except Exception as e:
print(f"❌ 获取embedding失败: {e}")
return None
def get_embeddings_batch(self, texts: List[str]) -> List[np.ndarray]:
"""批量获取向量化结果"""
try:
if not texts:
return []
            # Clean every text; the API rejects empty input, so substitute "无"
cleaned_texts = []
for t in texts:
c = self.clean_text(t)
cleaned_texts.append(c if c.strip() else "无")
response = self.client.embeddings.create(
model="text-embedding-3-large",
input=cleaned_texts
)
return [np.array(data.embedding).astype('float32') for data in response.data]
except Exception as e:
print(f"❌ 批量获取 embedding 失败: {e}")
return []
def get_weighted_embeddings_batch(self, questions: List[Dict]) -> List[np.ndarray]:
"""
        Build weighted concatenated embeddings for a batch of questions.
"""
all_texts = []
for q in questions:
            # Extract and clean the 4 parts; empty values get a placeholder below
stem = self.clean_text(q.get('stem', '') or '')
opts = self.normalize_options(q.get('options', '') or '')
ans = self.clean_text(q.get('answer', '') or '')
sol = self.clean_text(q.get('solution', '') or '')
            # The OpenAI API rejects empty strings, so force the "无" placeholder
all_texts.append(stem if stem.strip() else "无")
all_texts.append(opts if opts.strip() else "无")
all_texts.append(ans if ans.strip() else "无")
all_texts.append(sol if sol.strip() else "无")
        # One batched request (4 parts per question)
all_embeddings = self.get_embeddings_batch(all_texts)
if not all_embeddings:
return []
results = []
for i in range(len(questions)):
q_embs = all_embeddings[i*4 : (i+1)*4]
if len(q_embs) < 4:
results.append(None)
continue
weighted_parts = []
keys = ['stem', 'options', 'answer', 'solution']
for j, emb in enumerate(q_embs):
                # L2-normalize
norm = np.linalg.norm(emb)
if norm > 0:
emb = emb / norm
                # Scale by sqrt(weight)
weight_factor = np.sqrt(self.weights[keys[j]])
weighted_parts.append(emb * weight_factor)
mega_vector = np.concatenate(weighted_parts).astype('float32')
results.append(mega_vector)
return results
def combine_content(self, q: Dict) -> str:
"""组合题目内容用于预览显示"""
stem = q.get('stem', '') or ''
return self.clean_text(stem) # 预览以清理后的题干为主
def fetch_question_from_db(self, question_id: int) -> Optional[Dict]:
"""从数据库获取题目信息"""
try:
conn = pymysql.connect(
host=DB_HOST,
port=DB_PORT,
user=DB_USERNAME,
password=DB_PASSWORD,
database=DB_DATABASE,
charset='utf8mb4',
cursorclass=pymysql.cursors.DictCursor
)
with conn.cursor() as cursor:
sql = "SELECT id, stem, options, answer, solution, is_repeat FROM questions_tem WHERE id = %s"
cursor.execute(sql, (question_id,))
result = cursor.fetchone()
return result
except Exception as e:
print(f"❌ 数据库查询失败: {e}")
return None
finally:
if 'conn' in locals() and conn:
conn.close()
def confirm_repeat(self, question_id: int, is_repeat: int) -> bool:
"""
        After manual review, update the database and the vector index.
        is_repeat: 0 means no duplicate found, 1 means it is a duplicate
"""
        # 1. Update the database status
self.update_is_repeat(question_id, is_repeat)
        # 2. If marked as non-duplicate, add it to the vector index
if is_repeat == 0:
question = self.fetch_question_from_db(question_id)
if question:
mega_vec = self.get_weighted_embedding(question)
if mega_vec is not None:
                    # Reshape to 2D so FAISS accepts it
mega_vec = mega_vec.reshape(1, -1)
self.add_to_index(question_id, mega_vec, self.clean_text(question.get('stem', ''))[:100])
return True
return True
def update_is_repeat(self, question_id: int, is_repeat: int):
"""更新数据库中的 is_repeat 字段"""
try:
conn = pymysql.connect(
host=DB_HOST,
port=DB_PORT,
user=DB_USERNAME,
password=DB_PASSWORD,
database=DB_DATABASE,
charset='utf8mb4'
)
with conn.cursor() as cursor:
sql = "UPDATE questions_tem SET is_repeat = %s WHERE id = %s"
cursor.execute(sql, (is_repeat, question_id))
conn.commit()
print(f"✓ 题目 ID {question_id} 的 is_repeat 已更新为 {is_repeat}")
except Exception as e:
print(f"❌ 更新 is_repeat 失败: {e}")
finally:
if 'conn' in locals() and conn:
conn.close()
def call_openai_for_duplicate_check(self, q1: Dict, q2: Dict) -> bool:
"""
        Call GPT-4o for a deeper comparison to decide whether two questions are duplicates.
"""
try:
prompt = f"""你是一个专业的数学题目查重专家。请对比以下两道题目,判断它们是否为“重复题”。
“重复题”的定义:题目背景、考查知识点、核心逻辑以及数值完全一致,或者仅有极其微小的表述差异但不影响题目本质。
题目 1:
题干: {q1.get('stem', '')}
选项: {q1.get('options', '')}
答案: {q1.get('answer', '')}
解析: {q1.get('solution', '')}
题目 2:
题干: {q2.get('stem', '')}
选项: {q2.get('options', '')}
答案: {q2.get('answer', '')}
解析: {q2.get('solution', '')}
请分析两道题的相似性,并给出最终结论。
结论必须以 JSON 格式输出,包含以下字段:
- is_duplicate: 布尔值,True 表示是重复题,False 表示不是
- reason: 简短的分析理由
只输出 JSON 即可,不要有其他解释说明。"""
response = self.client.chat.completions.create(
model="gpt-4o",
messages=[
{"role": "system", "content": "你是一个专业的数学题目查重助手。"},
{"role": "user", "content": prompt}
],
response_format={"type": "json_object"},
temperature=0
)
result = json.loads(response.choices[0].message.content)
is_dup = result.get('is_duplicate', False)
reason = result.get('reason', '')
print(f"🤖 GPT-4o 判定结果: {'重复' if is_dup else '不重复'} | 理由: {reason}")
return is_dup
except Exception as e:
print(f"❌ GPT-4o 兜底计算失败: {e}")
return False
def check_duplicate(self, question_id: int, threshold: float = 0.85) -> Dict:
"""
        Main duplicate-check logic (pre-insert mode: does not update the database and does not add to the index).
"""
        # 1. Fetch the question
question = self.fetch_question_from_db(question_id)
if not question:
return {"status": "error", "message": f"未找到ID为 {question_id} 的题目"}
        # 2. Build the weighted concatenated vector
mega_vector = self.get_weighted_embedding(question)
if mega_vector is None:
return {"status": "error", "message": "无法生成题目向量"}
        # 3. Search the index
if self.index.ntotal == 0:
return {
"status": "success",
"message": "向量库为空",
"is_duplicate": False,
"top_similar": []
}
        # Retrieve the most similar questions
k = min(3, self.index.ntotal)
query_vec = mega_vector.reshape(1, -1)
similarities, indices = self.index.search(query_vec, k)
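        # FAISS returns hits in decreasing score order, so the loop below stops at the
        # first usable (non-self) hit.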
best_similar = None
is_duplicate = False
gpt_checked = False
for i in range(k):
idx = indices[0][i]
score = float(similarities[0][i])
if idx != -1:
similar_info = self.metadata[idx]
similar_id = similar_info['id']
if similar_id == question_id:
continue
best_similar = {
"id": similar_id,
"similarity": round(score, 4),
"similar_point": similar_info.get('text_preview', '')
}
if score >= threshold:
is_duplicate = True
elif score > 0.5:
                    # Similarity between 0.5 and the threshold (default 0.85): fall back to GPT-4o
similar_question = self.fetch_question_from_db(similar_id)
if similar_question:
is_duplicate = self.call_openai_for_duplicate_check(question, similar_question)
gpt_checked = True
break
if is_duplicate:
return {
"status": "warning",
"message": "该题目与已存在题目相似,请人工核验" + (" (经 GPT-4o 兜底确认)" if gpt_checked else ""),
"is_duplicate": True,
"gpt_checked": gpt_checked,
"top_similar": [best_similar]
}
else:
return {
"status": "success",
"message": "该题目通过查重",
"is_duplicate": False,
"top_similar": [best_similar] if best_similar else []
}
def check_duplicate_by_content(self, question_data: Dict, threshold: float = 0.85) -> Dict:
"""
        Check raw question content for duplicates (pre-check mode, before the question is saved).
"""
        # 1. Build the weighted concatenated vector
mega_vector = self.get_weighted_embedding(question_data)
if mega_vector is None:
return {"status": "error", "message": "无法生成题目向量"}
        # 2. Search the index
if self.index.ntotal == 0:
return {"status": "success", "is_duplicate": False, "top_similar": []}
        k = 1  # the pre-check only needs the single most similar question
query_vec = mega_vector.reshape(1, -1)
similarities, indices = self.index.search(query_vec, k)
if indices[0][0] != -1:
score = float(similarities[0][0])
similar_info = self.metadata[indices[0][0]]
similar_id = similar_info['id']
best_similar = {
"id": similar_id,
"similarity": round(score, 4),
"similar_point": similar_info.get('text_preview', '')
}
is_duplicate = False
gpt_checked = False
if score >= threshold:
is_duplicate = True
elif score > 0.5:
                # Similarity between 0.5 and the threshold (default 0.85): fall back to GPT-4o
similar_question = self.fetch_question_from_db(similar_id)
if similar_question:
is_duplicate = self.call_openai_for_duplicate_check(question_data, similar_question)
gpt_checked = True
if is_duplicate:
return {
"status": "warning",
"is_duplicate": True,
"gpt_checked": gpt_checked,
"top_similar": [best_similar]
}
return {"status": "success", "is_duplicate": False, "top_similar": [best_similar]}
return {"status": "success", "is_duplicate": False, "top_similar": []}
def get_question_data(self, question_id: int) -> Dict:
"""获取向量库中特定ID的数据及总数"""
total_count = self.index.ntotal if self.index else 0
target_metadata = None
for m in self.metadata:
if m['id'] == question_id:
target_metadata = m
break
return {
"total_count": total_count,
"question_data": target_metadata
}
def add_to_index(self, question_id: int, vector: np.ndarray, text: str):
"""将题目加入索引"""
if any(m['id'] == question_id for m in self.metadata):
return
self.index.add(vector)
self.metadata.append({
'id': question_id,
'text_preview': self.clean_text(text)[:100]
})
self.save_index()
def sync_all_from_db(self, batch_size=50, max_workers=5):
"""同步数据库中所有题目到索引 (支持加权模式 + 批量 + 多线程)"""
print("🔄 开始全量同步 (优化版 - 加权模式)...")
existing_ids = {m['id'] for m in self.metadata}
try:
conn = pymysql.connect(
host=DB_HOST,
port=DB_PORT,
user=DB_USERNAME,
password=DB_PASSWORD,
database=DB_DATABASE,
charset='utf8mb4',
cursorclass=pymysql.cursors.DictCursor
)
with conn.cursor() as cursor:
print("📡 正在从数据库读取所有题目数据...")
sql = "SELECT id, stem, options, answer, solution FROM questions_tem"
cursor.execute(sql)
all_questions = cursor.fetchall()
print(f"📦 数据库加载完成,共计 {len(all_questions)} 条记录")
new_questions = [q for q in all_questions if q['id'] not in existing_ids]
total_new = len(new_questions)
if total_new == 0:
print("✅ 已经是最新状态,无需同步。")
return
print(f"📊 数据库总计: {len(all_questions)}, 需同步新增: {total_new}")
                # Split the new questions into chunks
chunks = [new_questions[i:i + batch_size] for i in range(0, total_new, batch_size)]
def process_chunk(chunk):
mega_vectors = self.get_weighted_embeddings_batch(chunk)
return chunk, mega_vectors
                # Embed the chunks concurrently with a thread pool
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
future_to_chunk = {executor.submit(process_chunk, chunk): chunk for chunk in chunks}
count = 0
for future in concurrent.futures.as_completed(future_to_chunk):
chunk, mega_vectors = future.result()
if mega_vectors:
for i, mega_vec in enumerate(mega_vectors):
if mega_vec is not None:
self.index.add(mega_vec.reshape(1, -1))
self.metadata.append({
'id': chunk[i]['id'],
'text_preview': self.clean_text(chunk[i].get('stem', ''))[:100]
})
count += len(chunk)
print(f"✅ 已同步进度: {count}/{total_new}")
                            # Checkpoint the index every 500 records
if count % 500 == 0:
self.save_index()
self.save_index()
print(f"🎉 同步完成!当前索引总数: {len(self.metadata)}")
except Exception as e:
print(f"❌ 同步失败: {e}")
import traceback
traceback.print_exc()
finally:
if 'conn' in locals() and conn:
conn.close()
if __name__ == "__main__":
    # Quick manual test
checker = QuestionDuplicateChecker()
    # checker.sync_all_from_db()  # run once on first use to populate the index
# result = checker.check_duplicate(10)
# print(json.dumps(result, ensure_ascii=False, indent=2))
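    # Illustrative pre-check of a not-yet-saved question (the dict below is made-up sample data):
    # pending = {
    #     "stem": "已知 a + b = 3, a - b = 1, 求 a 的值",
    #     "options": {"A": "1", "B": "2", "C": "3", "D": "4"},
    #     "answer": "B",
    #     "solution": "两式相加得 2a = 4, 所以 a = 2"
    # }
    # precheck = checker.check_duplicate_by_content(pending)
    # print(json.dumps(precheck, ensure_ascii=False, indent=2))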