Bläddra i källkod

为校准分析器添加历史快照功能,支持快照周期和最大快照数量的配置;优化持久化逻辑以避免重复写入,并增强输出信息以显示快照写入情况。

admin 1 månad sedan
förälder
incheckning
306a9ee66a
2 ändrade filer med 274 tillägg och 2 borttagningar
  1. +30 −0   scripts/run_analyzers.py
  2. +244 −2  src/databank/analytics/calibration.py

+ 30 - 0
scripts/run_analyzers.py

@@ -182,6 +182,23 @@ def main() -> None:
                     "verbose": os.getenv("DATABANK_CAL_VERBOSE", "1").strip()
                     not in {"0", "false", "False"},
                 }
+                # History snapshots for calibration (optional)
+                cal_history = (
+                    os.getenv("DATABANK_CAL_HISTORY", "snapshots").strip().lower()
+                )
+                if cal_history in {"none", "snapshots", "both"}:
+                    cal_kwargs["history"] = cal_history
+                cal_kwargs["snapshot_every"] = _env_int(
+                    "DATABANK_CAL_SNAPSHOT_EVERY", 1000
+                )
+                pdays = _env_int_optional("DATABANK_CAL_PERIOD_DAYS")
+                if pdays is not None:
+                    cal_kwargs["period_days"] = pdays
+                else:
+                    cal_kwargs["period_days"] = 30
+                msnaps = _env_int_optional("DATABANK_CAL_MAX_SNAPSHOTS")
+                if msnaps is not None:
+                    cal_kwargs["max_snapshots"] = msnaps
                 since_ts = _env_int_optional("DATABANK_CAL_SINCE_TS")
                 until_ts = _env_int_optional("DATABANK_CAL_UNTIL_TS")
                 if since_ts is not None:
@@ -217,6 +234,19 @@ def main() -> None:
                             )
                         except (ValueError, TypeError, KeyError):
                             print(f"      - {gk}: (metrics unavailable)")
+                # Optional: print snapshot writes if present
+                if isinstance(result, dict):
+                    sm = result.get("snap_metrics_written")
+                    sb = result.get("snap_bins_written")
+                    if (isinstance(sm, int) and sm > 0) or (
+                        isinstance(sb, int) and sb > 0
+                    ):
+                        print(
+                            (
+                                f"    CAL snapshots persisted: metrics={sm or 0}"
+                                f" bins={sb or 0}"
+                            )
+                        )
             else:
                 result = analyzer.compute(transformed, db=db, persist=True)
             analyzer.finalize(result)

+ 244 - 2
src/databank/analytics/calibration.py

@@ -46,6 +46,21 @@ class CalibrationAnalyzer(AnalyticsBase):
         persist: bool = bool(kwargs.get("persist", True))
         verbose: bool = bool(kwargs.get("verbose", True))
         db: Optional[BaseDB] = kwargs.get("db")
+        # Optional history snapshots (default: enabled, 30-day period)
+        history: str = str(kwargs.get("history", "snapshots")).lower()
+        snapshot_every: int = int(kwargs.get("snapshot_every", 200))
+        period_days_raw = kwargs.get("period_days")
+        max_snapshots_raw = kwargs.get("max_snapshots")
+        try:
+            period_days = int(period_days_raw) if period_days_raw is not None else 30
+        except (ValueError, TypeError):
+            period_days = 30
+        try:
+            max_snapshots = (
+                int(max_snapshots_raw) if max_snapshots_raw is not None else None
+            )
+        except (ValueError, TypeError):
+            max_snapshots = None
         # Optional filters
         since_ts_raw = kwargs.get("since_ts")
         until_ts_raw = kwargs.get("until_ts")
@@ -151,6 +166,7 @@ class CalibrationAnalyzer(AnalyticsBase):
                 {
                     "p": (ph, pd, pa),
                     "outcome": outcome,
+                    "ts": int(cut_ts) if cut_ts is not None else 0,
                 }
             )
             count_in += 1
@@ -166,6 +182,62 @@ class CalibrationAnalyzer(AnalyticsBase):
         results: Dict[str, dict] = {}
         num_bins = max(2, int(bins))
 
+        # Optional skip caches (avoid duplicate writes)
+        skip_if_exists: bool = bool(kwargs.get("skip_if_exists", False))
+        exist_groups_metrics: set[str] = set()
+        exist_hist_keys: set[tuple[str, int]] = set()
+        if skip_if_exists and persist and db:
+            try:
+                db_any = cast(Any, db)
+                # existing overall metrics per group for this source+bins
+                for doc in (
+                    db_any.find(
+                        "calibration_metrics", projection={"_id": 0}, limit=None
+                    )
+                    or []
+                ):
+                    if not isinstance(doc, dict):
+                        continue
+                    data_blob = doc.get("data")
+                    if not isinstance(data_blob, dict):
+                        continue
+                    if (
+                        data_blob.get("source") == source_kind
+                        and int(data_blob.get("bins", -1)) == num_bins
+                    ):
+                        g = data_blob.get("group")
+                        if isinstance(g, str):
+                            exist_groups_metrics.add(g)
+                # existing snapshot metrics (group, ts) for this source+bins
+                for doc in (
+                    db_any.find(
+                        "calibration_metrics_history", projection={"_id": 0}, limit=None
+                    )
+                    or []
+                ):
+                    if not isinstance(doc, dict):
+                        continue
+                    data_blob = doc.get("data")
+                    if not isinstance(data_blob, dict):
+                        continue
+                    if (
+                        data_blob.get("source") == source_kind
+                        and int(data_blob.get("bins", -1)) == num_bins
+                    ):
+                        g = data_blob.get("group")
+                        ts0 = data_blob.get("cutoff_ts")
+                        if g is None or ts0 is None:
+                            continue
+                        try:
+                            ts_i = int(ts0)
+                        except (ValueError, TypeError):
+                            continue
+                        if isinstance(g, str):
+                            exist_hist_keys.add((g, ts_i))
+            except (RuntimeError, ValueError, TypeError):
+                # diagnostics only; if listing fails we just proceed without skipping
+                pass
+
         def _bin_index(p: float) -> int:
             p = min(1.0 - 1e-12, max(0.0, float(p)))
             return min(num_bins - 1, int(p * num_bins))
@@ -293,7 +365,7 @@ class CalibrationAnalyzer(AnalyticsBase):
             }
             results[gk] = metrics
 
-            if persist and db:
+            if persist and db and not (skip_if_exists and gk in exist_groups_metrics):
                 docs.append(
                     Document(
                         id=f"{gk}:{source_kind}:metrics:{num_bins}",
@@ -313,13 +385,181 @@ class CalibrationAnalyzer(AnalyticsBase):
                         )
                     )
 
+        # Optional: history snapshots (cumulative over time)
+        snap_metrics_written = 0
+        snap_bins_written = 0
+        if history in {"snapshots", "both"}:
+            for gk, rows in groups.items():
+                if not rows:
+                    continue
+                # sort by time ascending
+                rows_sorted = sorted(rows, key=lambda r: r.get("ts", 0))
+                # running accumulators
+                conf_sum = [0.0 for _ in range(num_bins)]
+                acc_sum = [0.0 for _ in range(num_bins)]
+                cnt = [0 for _ in range(num_bins)]
+                cls_bins = {
+                    "H": {
+                        "p_sum": [0.0] * num_bins,
+                        "y_sum": [0.0] * num_bins,
+                        "cnt": [0] * num_bins,
+                    },
+                    "D": {
+                        "p_sum": [0.0] * num_bins,
+                        "y_sum": [0.0] * num_bins,
+                        "cnt": [0] * num_bins,
+                    },
+                    "A": {
+                        "p_sum": [0.0] * num_bins,
+                        "y_sum": [0.0] * num_bins,
+                        "cnt": [0] * num_bins,
+                    },
+                }
+                brier_sum = 0.0
+                logloss_sum = 0.0
+                last_snap_ts = rows_sorted[0].get("ts", 0)
+                period_secs = max(0, int(period_days)) * 86400
+                snaps = 0
+                n = 0
+                for i, r in enumerate(rows_sorted):
+                    ph, pd, pa = r["p"]
+                    oh, od, oa = _one_hot(r["outcome"])
+                    # update accumulators
+                    conf = max(ph, pd, pa)
+                    pred_label = (
+                        "H" if ph >= pd and ph >= pa else ("D" if pd >= pa else "A")
+                    )
+                    correct = (
+                        1.0
+                        if (
+                            (pred_label == "H" and oh == 1)
+                            or (pred_label == "D" and od == 1)
+                            or (pred_label == "A" and oa == 1)
+                        )
+                        else 0.0
+                    )
+                    bi = _bin_index(conf)
+                    conf_sum[bi] += conf
+                    acc_sum[bi] += correct
+                    cnt[bi] += 1
+                    brier_sum += (ph - oh) ** 2 + (pd - od) ** 2 + (pa - oa) ** 2
+                    p_true = max(eps, ph if oh else (pd if od else pa))
+                    logloss_sum += -math.log(p_true)
+                    for lab, p_val, y_val in (
+                        ("H", ph, oh),
+                        ("D", pd, od),
+                        ("A", pa, oa),
+                    ):
+                        bi2 = _bin_index(p_val)
+                        cls_bins[lab]["p_sum"][bi2] += p_val
+                        cls_bins[lab]["y_sum"][bi2] += y_val
+                        cls_bins[lab]["cnt"][bi2] += 1
+                    n += 1
+
+                    # snapshot trigger
+                    do_count = snapshot_every > 0 and (n % snapshot_every == 0)
+                    ts_now = int(r.get("ts", 0))
+                    do_period = period_secs > 0 and (
+                        ts_now - last_snap_ts >= period_secs
+                    )
+                    limit_ok = max_snapshots is None or snaps < max_snapshots
+                    if (do_count or do_period) and limit_ok:
+                        # compute ece
+                        ece = 0.0
+                        for bi3 in range(num_bins):
+                            if cnt[bi3] == 0:
+                                continue
+                            conf_avg = conf_sum[bi3] / cnt[bi3]
+                            acc_avg = acc_sum[bi3] / cnt[bi3]
+                            w = cnt[bi3] / max(1.0, float(n))
+                            ece += w * abs(acc_avg - conf_avg)
+                        # build per-class bins
+                        bins_out: List[dict] = []
+                        for lab in ("H", "D", "A"):
+                            p_sum = cls_bins[lab]["p_sum"]
+                            y_sum = cls_bins[lab]["y_sum"]
+                            c_arr = cls_bins[lab]["cnt"]
+                            for j in range(num_bins):
+                                cval = c_arr[j]
+                                if cval < min_per_bin:
+                                    continue
+                                p_avg = p_sum[j] / cval
+                                y_avg = y_sum[j] / cval
+                                bins_out.append(
+                                    {
+                                        "group": gk,
+                                        "class": lab,
+                                        "bin": j,
+                                        "bins": num_bins,
+                                        "p_low": j / num_bins,
+                                        "p_high": (j + 1) / num_bins,
+                                        "avg_p": p_avg,
+                                        "emp_rate": y_avg,
+                                        "count": cval,
+                                        "source": source_kind,
+                                        "cutoff_ts": ts_now,
+                                    }
+                                )
+
+                        # write snapshot docs
+                        if (
+                            persist
+                            and db
+                            and not (skip_if_exists and (gk, ts_now) in exist_hist_keys)
+                        ):
+                            snap_id = f"{gk}:{source_kind}:snap:{ts_now}:{num_bins}"
+                            docs.append(
+                                Document(
+                                    id=snap_id,
+                                    kind="calibration_metrics_history",
+                                    data={
+                                        "group": gk,
+                                        "n": int(n),
+                                        "brier": brier_sum / max(1.0, float(n)),
+                                        "logloss": logloss_sum / max(1.0, float(n)),
+                                        "ece": ece,
+                                        "bins": num_bins,
+                                        "min_per_bin": int(min_per_bin),
+                                        "source": source_kind,
+                                        "cutoff_ts": ts_now,
+                                    },
+                                )
+                            )
+                            snap_metrics_written += 1
+                            for b in bins_out:
+                                cls_label = b["class"]
+                                bin_idx = b["bin"]
+                                doc_id = (
+                                    f"{gk}:{source_kind}:snap_bins:{cls_label}:"
+                                    f"{bin_idx}:{ts_now}:{num_bins}"
+                                )
+                                docs.append(
+                                    Document(
+                                        id=doc_id,
+                                        kind="calibration_bins_history",
+                                        data=b,
+                                    )
+                                )
+                                snap_bins_written += 1
+                        snaps += 1
+                        last_snap_ts = ts_now
+
         if persist and db and docs:
             db.insert_many(docs)
             if verbose:
                 metrics_cnt = sum(1 for d in docs if d.kind == "calibration_metrics")
                 bins_cnt = sum(1 for d in docs if d.kind == "calibration_bins")
+                hist_m_cnt = sum(
+                    1 for d in docs if d.kind == "calibration_metrics_history"
+                )
+                hist_b_cnt = sum(
+                    1 for d in docs if d.kind == "calibration_bins_history"
+                )
                 print(
-                    (f"[CAL] Persisted metrics={metrics_cnt} " f"bins={bins_cnt}"),
+                    (
+                        f"[CAL] Persisted metrics={metrics_cnt} bins={bins_cnt} "
+                        f"snap_metrics={hist_m_cnt} snap_bins={hist_b_cnt}"
+                    ),
                     flush=True,
                 )
 
@@ -327,5 +567,7 @@ class CalibrationAnalyzer(AnalyticsBase):
             "groups": list(results.keys()),
             "metrics": results,
             "persisted": len(docs) if (persist and db) else 0,
+            "snap_metrics_written": snap_metrics_written,
+            "snap_bins_written": snap_bins_written,
             "source": source_kind,
         }