
commit: Fix concurrent vector index loading

Hai Lin, 2 days ago
parent
commit
b7cb76fa66
4 changed files with 160 additions and 13 deletions
  1. Dockerfile (+1 / -1)
  2. app.py (+24 / -2)
  3. duplicate_checker.py (+134 / -10)
  4. requirements.txt (+1 / -0)

+ 1 - 1
Dockerfile

@@ -28,4 +28,4 @@ EXPOSE 8888
 # --error-logfile -: send the error log to stdout
 # --log-level info: log level, makes troubleshooting easier
 # app:app: Flask application entry point (the first app is the app.py file, the second is the Flask instance inside that file)
-CMD ["gunicorn", "-w", "2", "-b", "0.0.0.0:8888", "--timeout", "300", "--graceful-timeout", "300", "--worker-class", "gevent", "--worker-connections", "1000", "--access-logfile", "-", "--error-logfile", "-", "--log-level", "info", "app:app"]
+CMD ["gunicorn", "-w", "1", "-b", "0.0.0.0:8888", "--timeout", "300", "--graceful-timeout", "300", "--worker-class", "gevent", "--worker-connections", "1000", "--access-logfile", "-", "--error-logfile", "-", "--log-level", "info", "app:app"]

+ 24 - 2
app.py

@@ -29,6 +29,8 @@ def check_duplicate():
     if not question_data["stem"]:
         return jsonify({"code": -1, "message": "stem is required"}), 400
 
+    # Ensure the index is loaded (avoids an empty index under multi-process workers)
+    checker.ensure_index_loaded()
     # Run the content-based duplicate check
     result = checker.check_duplicate_by_content(question_data)
 
@@ -62,8 +64,11 @@ def sync_index():
     """手动触发全量同步接口"""
     print("🔄 收到同步索引请求")
     try:
-        checker.sync_all_from_db()
-        return jsonify({"code": 0, "result": "Sync completed"})
+        checker.ensure_index_loaded()
+        started = checker.sync_all_from_db()
+        if started:
+            return jsonify({"code": 0, "result": "Sync completed"})
+        return jsonify({"code": 0, "result": "Sync already running"})
     except Exception as e:
         return jsonify({"code": -1, "message": str(e)}), 500
 
@@ -85,6 +90,7 @@ def confirm_repeat():
         return jsonify({"code": -1, "message": "Missing questionId or isRepeat"}), 400
 
     try:
+        checker.ensure_index_loaded()
         success = checker.confirm_repeat(int(question_id), int(is_repeat))
         if success:
             return jsonify({"code": 0, "result": "ok"})
@@ -104,6 +110,7 @@ def get_question_info():
         return jsonify({"code": -1, "message": "Missing questionId"}), 400
     
     try:
+        checker.ensure_index_loaded()
         result = checker.get_question_data(int(question_id))
         return jsonify({
             "code": 0,
@@ -114,6 +121,21 @@ def get_question_info():
     except Exception as e:
         return jsonify({"code": -1, "message": str(e)}), 500
 
+@app.route('/api/index_info', methods=['GET'])
+def get_index_info():
+    """查看当前索引文件路径及条数"""
+    checker.ensure_index_loaded()
+    index_count = int(checker.index.ntotal) if checker.index else 0
+    return jsonify({
+        "code": 0,
+        "result": {
+            "index_path": checker.index_path,
+            "metadata_path": checker.metadata_path,
+            "index_count": index_count,
+            "metadata_count": len(checker.metadata)
+        }
+    })
+
 if __name__ == '__main__':
     # Start the service (listens on port 8888)
     app.run(host='0.0.0.0', port=8888, debug=False)
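
The new GET /api/index_info endpoint makes it easy to confirm which index files a running worker is serving and how many entries each holds. A minimal client sketch, assuming the service is reachable on localhost:8888 as configured above:

    import json
    import urllib.request

    # Host and port assumed from the gunicorn/app.run configuration in this commit
    with urllib.request.urlopen("http://localhost:8888/api/index_info", timeout=10) as resp:
        info = json.load(resp)["result"]

    print(info["index_path"], info["index_count"], info["metadata_count"])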

+ 134 - 10
duplicate_checker.py

@@ -5,6 +5,9 @@ import json
 import os
 import pickle
 import re
+import threading
+import time
+import tempfile
 import numpy as np
 import faiss
 import pymysql
@@ -28,10 +31,19 @@ class QuestionDuplicateChecker:
             api_key=OPENAI_API_KEY,
             base_url=OPENAI_BASE_URL
         )
-        self.index_path = index_path
-        self.metadata_path = metadata_path
+        base_dir = os.path.dirname(os.path.abspath(__file__))
+        # Always use absolute paths so differing working directories cannot scatter the index files
+        self.index_path = index_path if os.path.isabs(index_path) else os.path.join(base_dir, index_path)
+        self.metadata_path = metadata_path if os.path.isabs(metadata_path) else os.path.join(base_dir, metadata_path)
         self.index = None
         self.metadata = [] # stores question IDs and text so results can be echoed back
+        self.sync_lock = threading.Lock()
+        self.reload_lock = threading.Lock()
+        self.index_mtime = None
+        self.metadata_mtime = None
+        self.sync_in_progress = False
+        self.reload_failures = 0
+        self.next_reload_time = 0.0
         
         # Weight settings
         self.weights = {
@@ -42,6 +54,12 @@ class QuestionDuplicateChecker:
         }
         # Dimensions: the four concatenated parts give a total of 3072 * 4 = 12288
         self.dimension = 3072 * 4
+        # Limit FAISS/OpenMP threads to avoid overloading the CPU on small machines
+        faiss_threads = max(1, min(2, os.cpu_count() or 1))
+        try:
+            faiss.omp_set_num_threads(faiss_threads)
+        except Exception:
+            pass
         
         self._load_index()
 
@@ -98,6 +116,7 @@ class QuestionDuplicateChecker:
                     with open(self.metadata_path, 'rb') as f:
                         self.metadata = pickle.load(f)
                     print(f"✓ 已加载现有索引,包含 {len(self.metadata)} 道题目")
+                    self._update_index_mtime()
             except Exception as e:
                 print(f"⚠️ 加载索引失败: {e},将初始化新索引")
                 self._init_new_index()
@@ -109,14 +128,105 @@ class QuestionDuplicateChecker:
         # Use an inner-product index with the weighted concatenated vectors
         self.index = faiss.IndexFlatIP(self.dimension)
         self.metadata = []
+        self.index_mtime = None
+        self.metadata_mtime = None
+        self.reload_failures = 0
+        self.next_reload_time = 0.0
         print("✓ 已初始化新的FAISS索引")
 
+    def _update_index_mtime(self):
+        """更新索引文件的最后修改时间记录"""
+        self.index_mtime = os.path.getmtime(self.index_path) if os.path.exists(self.index_path) else None
+        self.metadata_mtime = os.path.getmtime(self.metadata_path) if os.path.exists(self.metadata_path) else None
+
+    def ensure_index_loaded(self, force: bool = False) -> bool:
+        """确保索引已从磁盘加载(多进程下避免空索引常驻)"""
+        # 读取失败后做退避,避免频繁重载导致磁盘抖动
+        now = time.time()
+        if not force and now < self.next_reload_time:
+            return False
+
+        index_exists = os.path.exists(self.index_path) and os.path.exists(self.metadata_path)
+        if not index_exists:
+            return False
+
+        need_reload = force
+        if not need_reload:
+            # If the index is empty or the metadata is missing, try loading from disk first
+            if self.index is None or self.index.ntotal == 0 or not self.metadata:
+                need_reload = True
+            else:
+                # Reload if the files on disk have been updated
+                current_index_mtime = os.path.getmtime(self.index_path)
+                current_metadata_mtime = os.path.getmtime(self.metadata_path)
+                if (self.index_mtime is None or self.metadata_mtime is None or
+                        current_index_mtime > self.index_mtime or current_metadata_mtime > self.metadata_mtime):
+                    need_reload = True
+
+        if not need_reload:
+            return False
+
+        with self.reload_lock:
+            # While a sync is in progress, avoid reloading an index that may not be fully written yet
+            if self.sync_in_progress and not force:
+                return False
+
+            # Double-check to avoid redundant reloads
+            if not force and self.index is not None and self.index.ntotal > 0 and self.metadata:
+                current_index_mtime = os.path.getmtime(self.index_path)
+                current_metadata_mtime = os.path.getmtime(self.metadata_path)
+                if (self.index_mtime is not None and self.metadata_mtime is not None and
+                        current_index_mtime <= self.index_mtime and current_metadata_mtime <= self.metadata_mtime):
+                    return False
+
+            try:
+                self.index = faiss.read_index(self.index_path)
+                with open(self.metadata_path, 'rb') as f:
+                    self.metadata = pickle.load(f)
+                self._update_index_mtime()
+                self.reload_failures = 0
+                self.next_reload_time = 0.0
+                print(f"↻ 已重新加载索引,包含 {len(self.metadata)} 道题目")
+                return True
+            except Exception as e:
+                print(f"⚠️ 重新加载索引失败: {e}")
+                # 指数退避,最多 60 秒
+                self.reload_failures += 1
+                backoff = min(60.0, 2 ** self.reload_failures)
+                self.next_reload_time = time.time() + backoff
+                return False
+
     def save_index(self):
         """保存索引和元数据"""
-        faiss.write_index(self.index, self.index_path)
-        with open(self.metadata_path, 'wb') as f:
-            pickle.dump(self.metadata, f)
-        print(f"✓ 索引和元数据已保存到 {self.index_path}")
+        # 先写入临时文件,再原子替换,避免读写冲突导致索引损坏
+        index_dir = os.path.dirname(self.index_path)
+        meta_dir = os.path.dirname(self.metadata_path)
+        os.makedirs(index_dir, exist_ok=True)
+        os.makedirs(meta_dir, exist_ok=True)
+
+        tmp_index_fd, tmp_index_path = tempfile.mkstemp(prefix=".questions_tem.index.", dir=index_dir)
+        tmp_meta_fd, tmp_meta_path = tempfile.mkstemp(prefix=".questions_tem_metadata.pkl.", dir=meta_dir)
+        try:
+            os.close(tmp_index_fd)
+            os.close(tmp_meta_fd)
+
+            faiss.write_index(self.index, tmp_index_path)
+            with open(tmp_meta_path, 'wb') as f:
+                pickle.dump(self.metadata, f)
+                f.flush()
+                os.fsync(f.fileno())
+
+            os.replace(tmp_index_path, self.index_path)
+            os.replace(tmp_meta_path, self.metadata_path)
+
+            self._update_index_mtime()
+            print(f"✓ 索引和元数据已保存到 {self.index_path}")
+        finally:
+            # Clean up the temporary files (already gone if replace succeeded)
+            if os.path.exists(tmp_index_path):
+                os.remove(tmp_index_path)
+            if os.path.exists(tmp_meta_path):
+                os.remove(tmp_meta_path)
 
     def get_weighted_embedding(self, question: Dict) -> np.ndarray:
         """
@@ -356,6 +466,7 @@ class QuestionDuplicateChecker:
         """
         Main duplicate-check logic (not-yet-saved mode: does not update the database or auto-insert)
         """
+        self.ensure_index_loaded()
         # 1. Fetch the question info
         question = self.fetch_question_from_db(question_id)
         if not question:
@@ -430,6 +541,7 @@ class QuestionDuplicateChecker:
         """
         Duplicate check based on raw text content (pre-check mode)
         """
+        self.ensure_index_loaded()
         # 1. Get the weighted concatenated vector
         mega_vector = self.get_weighted_embedding(question_data)
         if mega_vector is None:
@@ -478,6 +590,7 @@ class QuestionDuplicateChecker:
 
     def get_question_data(self, question_id: int) -> Dict:
         """获取向量库中特定ID的数据及总数"""
+        self.ensure_index_loaded()
         total_count = self.index.ntotal if self.index else 0
         
         target_metadata = None
@@ -503,8 +616,13 @@ class QuestionDuplicateChecker:
         })
         self.save_index()
 
-    def sync_all_from_db(self, batch_size=50, max_workers=5):
+    def sync_all_from_db(self, batch_size=50, max_workers=None):
         """同步数据库中所有题目到索引 (支持加权模式 + 批量 + 多线程)"""
+        if not self.sync_lock.acquire(blocking=False):
+            print("⏳ 同步正在进行中,已忽略重复请求")
+            return False
+
+        self.sync_in_progress = True
         print("🔄 开始全量同步 (优化版 - 加权模式)...")
         existing_ids = {m['id'] for m in self.metadata}
         try:
@@ -529,7 +647,7 @@ class QuestionDuplicateChecker:
             
             if total_new == 0:
                 print("✅ 已经是最新状态,无需同步。")
-                return
+                return True
 
             print(f"📊 数据库总计: {len(all_questions)}, 需同步新增: {total_new}")
 
@@ -540,8 +658,9 @@ class QuestionDuplicateChecker:
                 mega_vectors = self.get_weighted_embeddings_batch(chunk)
                 return chunk, mega_vectors
 
-            # Run concurrently with a thread pool
-            with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
+            # Run concurrently with a thread pool (cap concurrency on small machines to avoid CPU overload)
+            worker_count = max_workers or max(1, min(2, os.cpu_count() or 1))
+            with concurrent.futures.ThreadPoolExecutor(max_workers=worker_count) as executor:
                 future_to_chunk = {executor.submit(process_chunk, chunk): chunk for chunk in chunks}
                 
                 count = 0
@@ -565,13 +684,18 @@ class QuestionDuplicateChecker:
             
             self.save_index()
             print(f"🎉 同步完成!当前索引总数: {len(self.metadata)}")
+            return True
         except Exception as e:
             print(f"❌ 同步失败: {e}")
             import traceback
             traceback.print_exc()
+            return False
         finally:
             if 'conn' in locals() and conn:
                 conn.close()
+            self.sync_in_progress = False
+            if self.sync_lock.locked():
+                self.sync_lock.release()
 
 if __name__ == "__main__":
     # Test code
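
sync_all_from_db now takes sync_lock with a non-blocking acquire, so a request that arrives while a sync is already running returns immediately instead of queuing behind it. A self-contained sketch of that pattern (hypothetical names, not the project's code):

    import threading
    import time

    sync_lock = threading.Lock()

    def run_sync(label):
        # Non-blocking acquire: skip the work if another sync already holds the lock
        if not sync_lock.acquire(blocking=False):
            print(f"{label}: sync already running, request ignored")
            return False
        try:
            print(f"{label}: syncing...")
            time.sleep(0.5)  # stands in for the real embedding and indexing work
            return True
        finally:
            sync_lock.release()

    # Three concurrent requests: only one performs the sync
    threads = [threading.Thread(target=run_sync, args=(f"req-{i}",)) for i in range(3)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()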

+ 1 - 0
requirements.txt

@@ -6,3 +6,4 @@ faiss-cpu
 pymysql
 python-dotenv
 cryptography
+gevent>=24.10.1