@@ -5,6 +5,9 @@ import json
 import os
 import pickle
 import re
+import threading
+import time
+import tempfile
 import numpy as np
 import faiss
 import pymysql
@@ -28,10 +31,19 @@ class QuestionDuplicateChecker:
             api_key=OPENAI_API_KEY,
             base_url=OPENAI_BASE_URL
         )
-        self.index_path = index_path
-        self.metadata_path = metadata_path
+        base_dir = os.path.dirname(os.path.abspath(__file__))
+        # Always use absolute paths so different working directories don't scatter index files
+        self.index_path = index_path if os.path.isabs(index_path) else os.path.join(base_dir, index_path)
+        self.metadata_path = metadata_path if os.path.isabs(metadata_path) else os.path.join(base_dir, metadata_path)
         self.index = None
         self.metadata = []  # Stores question IDs and text so matches can be echoed back
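+        # Cross-process coordination state: sync_lock serializes full syncs,
+        # reload_lock guards hot reloads, and the mtime/backoff fields track
+        # when the on-disk index was last read and when to retry after a failure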
+        self.sync_lock = threading.Lock()
+        self.reload_lock = threading.Lock()
+        self.index_mtime = None
+        self.metadata_mtime = None
+        self.sync_in_progress = False
+        self.reload_failures = 0
+        self.next_reload_time = 0.0

         # Weight settings
         self.weights = {
@@ -42,6 +54,12 @@ class QuestionDuplicateChecker:
         }
         # Dimension note: the 4 concatenated parts total 3072 * 4 = 12288
         self.dimension = 3072 * 4
+        # Cap FAISS/OpenMP threads to avoid CPU overload on small machines
+        faiss_threads = max(1, min(2, os.cpu_count() or 1))
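+        # Guarded call: faiss builds compiled without OpenMP may not expose
+        # omp_set_num_threads, so a missing attribute is silently tolerated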
+        try:
+            faiss.omp_set_num_threads(faiss_threads)
+        except Exception:
+            pass

         self._load_index()
@@ -98,6 +116,7 @@ class QuestionDuplicateChecker:
                 with open(self.metadata_path, 'rb') as f:
                     self.metadata = pickle.load(f)
                 print(f"✓ Loaded existing index with {len(self.metadata)} questions")
+                self._update_index_mtime()
             except Exception as e:
                 print(f"⚠️ Failed to load index: {e}; initializing a new one")
                 self._init_new_index()
@@ -109,14 +128,105 @@
         # Inner-product index, used with the weighted concatenated vectors
         self.index = faiss.IndexFlatIP(self.dimension)
         self.metadata = []
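+        # Reset reload tracking so the fresh empty index is not mistaken
+        # for an up-to-date copy of the files on disk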
+        self.index_mtime = None
+        self.metadata_mtime = None
+        self.reload_failures = 0
+        self.next_reload_time = 0.0
         print("✓ Initialized a new FAISS index")

+    def _update_index_mtime(self):
+        """Record the last-modified times of the index and metadata files"""
+        self.index_mtime = os.path.getmtime(self.index_path) if os.path.exists(self.index_path) else None
+        self.metadata_mtime = os.path.getmtime(self.metadata_path) if os.path.exists(self.metadata_path) else None
+
+    def ensure_index_loaded(self, force: bool = False) -> bool:
+        """Ensure the index is loaded from disk (keeps multi-process workers from serving a stale empty index)"""
+        # Back off after read failures so frequent reloads don't thrash the disk
+        now = time.time()
+        if not force and now < self.next_reload_time:
+            return False
+
+        index_exists = os.path.exists(self.index_path) and os.path.exists(self.metadata_path)
+        if not index_exists:
+            return False
+
+        need_reload = force
+        if not need_reload:
+            # If the in-memory index or metadata is empty, prefer loading from disk
+            if self.index is None or self.index.ntotal == 0 or not self.metadata:
+                need_reload = True
+            else:
+                # Reload when the on-disk files have been updated
+                current_index_mtime = os.path.getmtime(self.index_path)
+                current_metadata_mtime = os.path.getmtime(self.metadata_path)
+                if (self.index_mtime is None or self.metadata_mtime is None or
+                        current_index_mtime > self.index_mtime or current_metadata_mtime > self.metadata_mtime):
+                    need_reload = True
+
+        if not need_reload:
+            return False
+
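+        # Serialize the actual disk read; threads arriving mid-reload block
+        # briefly on the lock, then hit the double-check below and skip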
+        with self.reload_lock:
+            # While a sync is writing the index, avoid reloading a half-finished file
+            if self.sync_in_progress and not force:
+                return False
+
+            # Double-check to avoid redundant loads
+            if not force and self.index is not None and self.index.ntotal > 0 and self.metadata:
+                current_index_mtime = os.path.getmtime(self.index_path)
+                current_metadata_mtime = os.path.getmtime(self.metadata_path)
+                if (self.index_mtime is not None and self.metadata_mtime is not None and
+                        current_index_mtime <= self.index_mtime and current_metadata_mtime <= self.metadata_mtime):
+                    return False
+
+            try:
+                self.index = faiss.read_index(self.index_path)
+                with open(self.metadata_path, 'rb') as f:
+                    self.metadata = pickle.load(f)
+                self._update_index_mtime()
+                self.reload_failures = 0
+                self.next_reload_time = 0.0
+                print(f"↻ Reloaded index with {len(self.metadata)} questions")
+                return True
+            except Exception as e:
+                print(f"⚠️ Failed to reload index: {e}")
+                # Exponential backoff, capped at 60 seconds
+                self.reload_failures += 1
+                backoff = min(60.0, 2 ** self.reload_failures)
+                self.next_reload_time = time.time() + backoff
+                return False
+
     def save_index(self):
         """Save the index and metadata"""
-        faiss.write_index(self.index, self.index_path)
-        with open(self.metadata_path, 'wb') as f:
-            pickle.dump(self.metadata, f)
-        print(f"✓ Index and metadata saved to {self.index_path}")
+        # Write to temp files first, then atomically swap them in, so read/write
+        # races can't leave a corrupted, half-written index behind
+        index_dir = os.path.dirname(self.index_path)
+        meta_dir = os.path.dirname(self.metadata_path)
+        os.makedirs(index_dir, exist_ok=True)
+        os.makedirs(meta_dir, exist_ok=True)
+
+        tmp_index_fd, tmp_index_path = tempfile.mkstemp(prefix=".questions_tem.index.", dir=index_dir)
+        tmp_meta_fd, tmp_meta_path = tempfile.mkstemp(prefix=".questions_tem_metadata.pkl.", dir=meta_dir)
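+        # mkstemp returns open descriptors, but only the paths are needed:
+        # they are closed below so faiss and pickle can rewrite the files by path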
+        try:
+            os.close(tmp_index_fd)
+            os.close(tmp_meta_fd)
+
+            faiss.write_index(self.index, tmp_index_path)
+            with open(tmp_meta_path, 'wb') as f:
+                pickle.dump(self.metadata, f)
+                f.flush()
+                os.fsync(f.fileno())
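+
+            # os.replace is an atomic rename when source and destination share
+            # a filesystem (mkstemp(dir=...) above guarantees that), so readers
+            # see either the old files or the new ones, never a partial write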
+            os.replace(tmp_index_path, self.index_path)
+            os.replace(tmp_meta_path, self.metadata_path)
+
+            self._update_index_mtime()
+            print(f"✓ Index and metadata saved to {self.index_path}")
+        finally:
+            # Clean up the temp files (already gone if the replaces succeeded)
+            if os.path.exists(tmp_index_path):
+                os.remove(tmp_index_path)
+            if os.path.exists(tmp_meta_path):
+                os.remove(tmp_meta_path)

     def get_weighted_embedding(self, question: Dict) -> np.ndarray:
         """
@@ -356,6 +466,7 @@ class QuestionDuplicateChecker:
         """
         Main dedup logic (pre-ingest mode: does not update the database, does not auto-ingest)
         """
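+        # Pick up any index updates written by other worker processes first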
+        self.ensure_index_loaded()
         # 1. Fetch the question details
         question = self.fetch_question_from_db(question_id)
         if not question:
@@ -430,6 +541,7 @@ class QuestionDuplicateChecker:
         """
         Dedup against raw text content (pre-check mode)
         """
+        self.ensure_index_loaded()
         # 1. Build the weighted concatenated vector
         mega_vector = self.get_weighted_embedding(question_data)
         if mega_vector is None:
@@ -478,6 +590,7 @@ class QuestionDuplicateChecker:

     def get_question_data(self, question_id: int) -> Dict:
         """Fetch a specific ID's record from the vector store, plus the total count"""
+        self.ensure_index_loaded()
         total_count = self.index.ntotal if self.index else 0

         target_metadata = None
@@ -503,8 +616,13 @@ class QuestionDuplicateChecker:
         })
         self.save_index()

-    def sync_all_from_db(self, batch_size=50, max_workers=5):
+    def sync_all_from_db(self, batch_size=50, max_workers=None):
         """Sync all questions from the database into the index (weighted mode + batching + multithreading)"""
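+        # Non-blocking acquire: a concurrent sync request bails out immediately
+        # instead of queueing behind the one already running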
+        if not self.sync_lock.acquire(blocking=False):
+            print("⏳ Sync already in progress; ignoring the duplicate request")
+            return False
+
+        self.sync_in_progress = True
         print("🔄 Starting full sync (optimized, weighted mode)...")
         existing_ids = {m['id'] for m in self.metadata}
         try:
@@ -529,7 +647,7 @@ class QuestionDuplicateChecker:

             if total_new == 0:
                 print("✅ Already up to date; nothing to sync.")
-                return
+                return True

             print(f"📊 Total in database: {len(all_questions)}, new to sync: {total_new}")
@@ -540,8 +658,9 @@ class QuestionDuplicateChecker:
             mega_vectors = self.get_weighted_embeddings_batch(chunk)
             return chunk, mega_vectors

-        # Run the chunks concurrently in a thread pool
-        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
+        # Run the chunks concurrently in a thread pool (capped on small machines to avoid CPU overload)
+        worker_count = max_workers or max(1, min(2, os.cpu_count() or 1))
+        with concurrent.futures.ThreadPoolExecutor(max_workers=worker_count) as executor:
             future_to_chunk = {executor.submit(process_chunk, chunk): chunk for chunk in chunks}

             count = 0
@@ -565,13 +684,18 @@ class QuestionDuplicateChecker:

             self.save_index()
             print(f"🎉 Sync complete! Current index total: {len(self.metadata)}")
+            return True
         except Exception as e:
             print(f"❌ Sync failed: {e}")
             import traceback
             traceback.print_exc()
+            return False
         finally:
             if 'conn' in locals() and conn:
                 conn.close()
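+            # The finally block only runs once the lock was acquired (the
+            # duplicate-request path returns before the try), so clear the
+            # flag and release; locked() is just a defensive check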
+            self.sync_in_progress = False
+            if self.sync_lock.locked():
+                self.sync_lock.release()

 if __name__ == "__main__":
     # Test code