|
|
@@ -46,6 +46,21 @@ class CalibrationAnalyzer(AnalyticsBase):
|
|
|
persist: bool = bool(kwargs.get("persist", True))
|
|
|
verbose: bool = bool(kwargs.get("verbose", True))
|
|
|
db: Optional[BaseDB] = kwargs.get("db")
|
|
|
+ # Optional history snapshots (default: enabled, 30-day period)
|
|
|
+ history: str = str(kwargs.get("history", "snapshots")).lower()
|
|
|
+ snapshot_every: int = int(kwargs.get("snapshot_every", 200))
|
|
|
+ period_days_raw = kwargs.get("period_days")
|
|
|
+ max_snapshots_raw = kwargs.get("max_snapshots")
|
|
|
+ try:
|
|
|
+ period_days = int(period_days_raw) if period_days_raw is not None else 30
|
|
|
+ except (ValueError, TypeError):
|
|
|
+ period_days = 30
|
|
|
+ try:
|
|
|
+ max_snapshots = (
|
|
|
+ int(max_snapshots_raw) if max_snapshots_raw is not None else None
|
|
|
+ )
|
|
|
+ except (ValueError, TypeError):
|
|
|
+ max_snapshots = None
|
|
|
# Optional filters
|
|
|
since_ts_raw = kwargs.get("since_ts")
|
|
|
until_ts_raw = kwargs.get("until_ts")
|
|
|
@@ -151,6 +166,7 @@ class CalibrationAnalyzer(AnalyticsBase):
|
|
|
{
|
|
|
"p": (ph, pd, pa),
|
|
|
"outcome": outcome,
|
|
|
+ "ts": int(cut_ts) if cut_ts is not None else 0,
|
|
|
}
|
|
|
)
|
|
|
count_in += 1
|
|
|
@@ -166,6 +182,62 @@ class CalibrationAnalyzer(AnalyticsBase):
|
|
|
results: Dict[str, dict] = {}
|
|
|
num_bins = max(2, int(bins))
|
|
|
|
|
|
+ # Optional skip caches (avoid duplicate writes)
|
|
|
+ skip_if_exists: bool = bool(kwargs.get("skip_if_exists", False))
|
|
|
+ exist_groups_metrics: set[str] = set()
|
|
|
+ exist_hist_keys: set[tuple[str, int]] = set()
|
|
|
+ if skip_if_exists and persist and db:
|
|
|
+ try:
|
|
|
+ db_any = cast(Any, db)
|
|
|
+ # existing overall metrics per group for this source+bins
|
|
|
+ for doc in (
|
|
|
+ db_any.find(
|
|
|
+ "calibration_metrics", projection={"_id": 0}, limit=None
|
|
|
+ )
|
|
|
+ or []
|
|
|
+ ):
|
|
|
+ if not isinstance(doc, dict):
|
|
|
+ continue
|
|
|
+ data_blob = doc.get("data")
|
|
|
+ if not isinstance(data_blob, dict):
|
|
|
+ continue
|
|
|
+ if (
|
|
|
+ data_blob.get("source") == source_kind
|
|
|
+ and int(data_blob.get("bins", -1)) == num_bins
|
|
|
+ ):
|
|
|
+ g = data_blob.get("group")
|
|
|
+ if isinstance(g, str):
|
|
|
+ exist_groups_metrics.add(g)
|
|
|
+ # existing snapshot metrics (group, ts) for this source+bins
|
|
|
+ for doc in (
|
|
|
+ db_any.find(
|
|
|
+ "calibration_metrics_history", projection={"_id": 0}, limit=None
|
|
|
+ )
|
|
|
+ or []
|
|
|
+ ):
|
|
|
+ if not isinstance(doc, dict):
|
|
|
+ continue
|
|
|
+ data_blob = doc.get("data")
|
|
|
+ if not isinstance(data_blob, dict):
|
|
|
+ continue
|
|
|
+ if (
|
|
|
+ data_blob.get("source") == source_kind
|
|
|
+ and int(data_blob.get("bins", -1)) == num_bins
|
|
|
+ ):
|
|
|
+ g = data_blob.get("group")
|
|
|
+ ts0 = data_blob.get("cutoff_ts")
|
|
|
+ if g is None or ts0 is None:
|
|
|
+ continue
|
|
|
+ try:
|
|
|
+ ts_i = int(ts0)
|
|
|
+ except (ValueError, TypeError):
|
|
|
+ continue
|
|
|
+ if isinstance(g, str):
|
|
|
+ exist_hist_keys.add((g, ts_i))
|
|
|
+ except (RuntimeError, ValueError, TypeError):
|
|
|
+ # diagnostics only; if listing fails we just proceed without skipping
|
|
|
+ pass
|
|
|
+
|
|
|
def _bin_index(p: float) -> int:
|
|
|
p = min(1.0 - 1e-12, max(0.0, float(p)))
|
|
|
return min(num_bins - 1, int(p * num_bins))
|
|
|
@@ -293,7 +365,7 @@ class CalibrationAnalyzer(AnalyticsBase):
|
|
|
}
|
|
|
results[gk] = metrics
|
|
|
|
|
|
- if persist and db:
|
|
|
+ if persist and db and not (skip_if_exists and gk in exist_groups_metrics):
|
|
|
docs.append(
|
|
|
Document(
|
|
|
id=f"{gk}:{source_kind}:metrics:{num_bins}",
|
|
|
@@ -313,13 +385,181 @@ class CalibrationAnalyzer(AnalyticsBase):
|
|
|
)
|
|
|
)
|
|
|
|
|
|
+ # Optional: history snapshots (cumulative over time)
|
|
|
+ snap_metrics_written = 0
|
|
|
+ snap_bins_written = 0
|
|
|
+ if history in {"snapshots", "both"}:
|
|
|
+ for gk, rows in groups.items():
|
|
|
+ if not rows:
|
|
|
+ continue
|
|
|
+ # sort by time ascending
|
|
|
+ rows_sorted = sorted(rows, key=lambda r: r.get("ts", 0))
|
|
|
+ # running accumulators
|
|
|
+ conf_sum = [0.0 for _ in range(num_bins)]
|
|
|
+ acc_sum = [0.0 for _ in range(num_bins)]
|
|
|
+ cnt = [0 for _ in range(num_bins)]
|
|
|
+ cls_bins = {
|
|
|
+ "H": {
|
|
|
+ "p_sum": [0.0] * num_bins,
|
|
|
+ "y_sum": [0.0] * num_bins,
|
|
|
+ "cnt": [0] * num_bins,
|
|
|
+ },
|
|
|
+ "D": {
|
|
|
+ "p_sum": [0.0] * num_bins,
|
|
|
+ "y_sum": [0.0] * num_bins,
|
|
|
+ "cnt": [0] * num_bins,
|
|
|
+ },
|
|
|
+ "A": {
|
|
|
+ "p_sum": [0.0] * num_bins,
|
|
|
+ "y_sum": [0.0] * num_bins,
|
|
|
+ "cnt": [0] * num_bins,
|
|
|
+ },
|
|
|
+ }
|
|
|
+ brier_sum = 0.0
|
|
|
+ logloss_sum = 0.0
|
|
|
+ last_snap_ts = rows_sorted[0].get("ts", 0)
|
|
|
+ period_secs = max(0, int(period_days)) * 86400
|
|
|
+ snaps = 0
|
|
|
+ n = 0
|
|
|
+ for i, r in enumerate(rows_sorted):
|
|
|
+ ph, pd, pa = r["p"]
|
|
|
+ oh, od, oa = _one_hot(r["outcome"])
|
|
|
+ # update accumulators
|
|
|
+ conf = max(ph, pd, pa)
|
|
|
+ pred_label = (
|
|
|
+ "H" if ph >= pd and ph >= pa else ("D" if pd >= pa else "A")
|
|
|
+ )
|
|
|
+ correct = (
|
|
|
+ 1.0
|
|
|
+ if (
|
|
|
+ (pred_label == "H" and oh == 1)
|
|
|
+ or (pred_label == "D" and od == 1)
|
|
|
+ or (pred_label == "A" and oa == 1)
|
|
|
+ )
|
|
|
+ else 0.0
|
|
|
+ )
|
|
|
+ bi = _bin_index(conf)
|
|
|
+ conf_sum[bi] += conf
|
|
|
+ acc_sum[bi] += correct
|
|
|
+ cnt[bi] += 1
|
|
|
+ brier_sum += (ph - oh) ** 2 + (pd - od) ** 2 + (pa - oa) ** 2
|
|
|
+ p_true = max(eps, ph if oh else (pd if od else pa))
|
|
|
+ logloss_sum += -math.log(p_true)
|
|
|
+ for lab, p_val, y_val in (
|
|
|
+ ("H", ph, oh),
|
|
|
+ ("D", pd, od),
|
|
|
+ ("A", pa, oa),
|
|
|
+ ):
|
|
|
+ bi2 = _bin_index(p_val)
|
|
|
+ cls_bins[lab]["p_sum"][bi2] += p_val
|
|
|
+ cls_bins[lab]["y_sum"][bi2] += y_val
|
|
|
+ cls_bins[lab]["cnt"][bi2] += 1
|
|
|
+ n += 1
|
|
|
+
|
|
|
+ # snapshot trigger
|
|
|
+ do_count = snapshot_every > 0 and (n % snapshot_every == 0)
|
|
|
+ ts_now = int(r.get("ts", 0))
|
|
|
+ do_period = period_secs > 0 and (
|
|
|
+ ts_now - last_snap_ts >= period_secs
|
|
|
+ )
|
|
|
+ limit_ok = max_snapshots is None or snaps < max_snapshots
|
|
|
+ if (do_count or do_period) and limit_ok:
|
|
|
+ # compute ece
|
|
|
+ ece = 0.0
|
|
|
+ for bi3 in range(num_bins):
|
|
|
+ if cnt[bi3] == 0:
|
|
|
+ continue
|
|
|
+ conf_avg = conf_sum[bi3] / cnt[bi3]
|
|
|
+ acc_avg = acc_sum[bi3] / cnt[bi3]
|
|
|
+ w = cnt[bi3] / max(1.0, float(n))
|
|
|
+ ece += w * abs(acc_avg - conf_avg)
|
|
|
+ # build per-class bins
|
|
|
+ bins_out: List[dict] = []
|
|
|
+ for lab in ("H", "D", "A"):
|
|
|
+ p_sum = cls_bins[lab]["p_sum"]
|
|
|
+ y_sum = cls_bins[lab]["y_sum"]
|
|
|
+ c_arr = cls_bins[lab]["cnt"]
|
|
|
+ for j in range(num_bins):
|
|
|
+ cval = c_arr[j]
|
|
|
+ if cval < min_per_bin:
|
|
|
+ continue
|
|
|
+ p_avg = p_sum[j] / cval
|
|
|
+ y_avg = y_sum[j] / cval
|
|
|
+ bins_out.append(
|
|
|
+ {
|
|
|
+ "group": gk,
|
|
|
+ "class": lab,
|
|
|
+ "bin": j,
|
|
|
+ "bins": num_bins,
|
|
|
+ "p_low": j / num_bins,
|
|
|
+ "p_high": (j + 1) / num_bins,
|
|
|
+ "avg_p": p_avg,
|
|
|
+ "emp_rate": y_avg,
|
|
|
+ "count": cval,
|
|
|
+ "source": source_kind,
|
|
|
+ "cutoff_ts": ts_now,
|
|
|
+ }
|
|
|
+ )
|
|
|
+
|
|
|
+ # write snapshot docs
|
|
|
+ if (
|
|
|
+ persist
|
|
|
+ and db
|
|
|
+ and not (skip_if_exists and (gk, ts_now) in exist_hist_keys)
|
|
|
+ ):
|
|
|
+ snap_id = f"{gk}:{source_kind}:snap:{ts_now}:{num_bins}"
|
|
|
+ docs.append(
|
|
|
+ Document(
|
|
|
+ id=snap_id,
|
|
|
+ kind="calibration_metrics_history",
|
|
|
+ data={
|
|
|
+ "group": gk,
|
|
|
+ "n": int(n),
|
|
|
+ "brier": brier_sum / max(1.0, float(n)),
|
|
|
+ "logloss": logloss_sum / max(1.0, float(n)),
|
|
|
+ "ece": ece,
|
|
|
+ "bins": num_bins,
|
|
|
+ "min_per_bin": int(min_per_bin),
|
|
|
+ "source": source_kind,
|
|
|
+ "cutoff_ts": ts_now,
|
|
|
+ },
|
|
|
+ )
|
|
|
+ )
|
|
|
+ snap_metrics_written += 1
|
|
|
+ for b in bins_out:
|
|
|
+ cls_label = b["class"]
|
|
|
+ bin_idx = b["bin"]
|
|
|
+ doc_id = (
|
|
|
+ f"{gk}:{source_kind}:snap_bins:{cls_label}:"
|
|
|
+ f"{bin_idx}:{ts_now}:{num_bins}"
|
|
|
+ )
|
|
|
+ docs.append(
|
|
|
+ Document(
|
|
|
+ id=doc_id,
|
|
|
+ kind="calibration_bins_history",
|
|
|
+ data=b,
|
|
|
+ )
|
|
|
+ )
|
|
|
+ snap_bins_written += 1
|
|
|
+ snaps += 1
|
|
|
+ last_snap_ts = ts_now
|
|
|
+
|
|
|
if persist and db and docs:
|
|
|
db.insert_many(docs)
|
|
|
if verbose:
|
|
|
metrics_cnt = sum(1 for d in docs if d.kind == "calibration_metrics")
|
|
|
bins_cnt = sum(1 for d in docs if d.kind == "calibration_bins")
|
|
|
+ hist_m_cnt = sum(
|
|
|
+ 1 for d in docs if d.kind == "calibration_metrics_history"
|
|
|
+ )
|
|
|
+ hist_b_cnt = sum(
|
|
|
+ 1 for d in docs if d.kind == "calibration_bins_history"
|
|
|
+ )
|
|
|
print(
|
|
|
- (f"[CAL] Persisted metrics={metrics_cnt} " f"bins={bins_cnt}"),
|
|
|
+ (
|
|
|
+ f"[CAL] Persisted metrics={metrics_cnt} bins={bins_cnt} "
|
|
|
+ f"snap_metrics={hist_m_cnt} snap_bins={hist_b_cnt}"
|
|
|
+ ),
|
|
|
flush=True,
|
|
|
)
|
|
|
|
|
|
@@ -327,5 +567,7 @@ class CalibrationAnalyzer(AnalyticsBase):
|
|
|
"groups": list(results.keys()),
|
|
|
"metrics": results,
|
|
|
"persisted": len(docs) if (persist and db) else 0,
|
|
|
+ "snap_metrics_written": snap_metrics_written,
|
|
|
+ "snap_bins_written": snap_bins_written,
|
|
|
"source": source_kind,
|
|
|
}
|