Browse source

Add environment-variable configuration to the Dixon-Coles analyzer, with support for time decay and parameter persistence; optimize the computation logic and extend history recording with prediction and snapshot modes.

admin 1 month ago
parent
commit
65e72b89ba
2 files changed with 429 additions and 88 deletions
  1. +54 −1   scripts/run_analyzers.py
  2. +375 −87 src/databank/analytics/dixon_coles.py
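For reference, a minimal sketch of driving the script with the new environment variables (names and defaults are taken from the diff below; launching the entry point via subprocess is an assumption about local usage):

    import os
    import subprocess

    # Hypothetical driver: configure the DC analyzer through the environment,
    # then run the analyzer script in a child process that inherits os.environ.
    os.environ.update(
        {
            "DATABANK_DC_HALFLIFE_DAYS": "120",    # faster decay than the 180-day default
            "DATABANK_DC_RHO_RANGE": "-0.2,0.2",   # parsed as "lo,hi"
            "DATABANK_DC_RHO_STEP": "0.005",
            "DATABANK_DC_MAX_ITERS": "30",
            "DATABANK_DC_HISTORY": "predictions",  # "none" | "predictions" | "snapshots"
            "DATABANK_DC_SNAPSHOT_EVERY": "5",
            "DATABANK_DC_MAX_GOALS": "10",
        }
    )
    subprocess.run(["python", "scripts/run_analyzers.py"], check=True)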

+ 54 - 1
scripts/run_analyzers.py

@@ -88,13 +88,59 @@ def main() -> None:
         season_mc_cls(),
     ]
 
+    # Prepare optional DC config from environment
+    def _env_float(name: str, default: float) -> float:
+        raw = os.getenv(name)
+        try:
+            # Fall back to the default when unset or unparseable.
+            return float(raw) if raw is not None else default
+        except ValueError:
+            return default
+
+    def _env_int(name: str, default: int) -> int:
+        raw = os.getenv(name)
+        try:
+            return int(raw) if raw is not None else default
+        except ValueError:
+            return default
+
+    def _env_rho_range(name: str, default: tuple[float, float]) -> tuple[float, float]:
+        s = os.getenv(name)
+        if not s:
+            return default
+        try:
+            lo_str, hi_str = s.split(",", 1)
+            return float(lo_str), float(hi_str)
+        except (ValueError, TypeError):
+            return default
+
+    dc_kwargs = {
+        "halflife_days": _env_float("DATABANK_DC_HALFLIFE_DAYS", 180.0),
+        "rho_range": _env_rho_range("DATABANK_DC_RHO_RANGE", (-0.3, 0.3)),
+        "rho_step": _env_float("DATABANK_DC_RHO_STEP", 0.01),
+        "max_iters": _env_int("DATABANK_DC_MAX_ITERS", 20),
+        "tol": _env_float("DATABANK_DC_TOL", 1e-4),
+    }
+
+    # Optional history configuration for DC
+    history_mode = os.getenv("DATABANK_DC_HISTORY", "none").strip().lower()
+    if history_mode in {"none", "predictions", "snapshots"}:
+        dc_kwargs["history"] = history_mode
+    dc_kwargs["snapshot_every"] = _env_int("DATABANK_DC_SNAPSHOT_EVERY", 10)
+    dc_kwargs["max_iters_history"] = _env_int("DATABANK_DC_MAX_ITERS_HISTORY", 10)
+    dc_kwargs["max_goals"] = _env_int("DATABANK_DC_MAX_GOALS", 8)
+
     for analyzer in analyzers:
         print(f"Running analyzer: {analyzer.__class__.__name__}")
         try:
             analyzer.prepare(data)
             analyzer.validate(data)
             transformed = analyzer.transform(data)
-            result = analyzer.compute(transformed, db=db, persist=True)
+            if isinstance(analyzer, dixon_coles_cls):
+                # Pass DC-specific kwargs from environment
+                result = analyzer.compute(transformed, db=db, persist=True, **dc_kwargs)
+                print("    DC config:", dc_kwargs)
+            else:
+                result = analyzer.compute(transformed, db=db, persist=True)
             analyzer.finalize(result)
             print(f" -> Done: {analyzer.__class__.__name__}")
 
@@ -124,6 +170,13 @@ def main() -> None:
                     print(
                         f"    DC matches_used: {mu}; persisted docs in this run: {persisted}"
                     )
+                # Optional extra collections
+                preds_cnt = _safe_count(db, "dc_predictions")
+                snaps_cnt = _safe_count(db, "dc_params_history")
+                if preds_cnt:
+                    print("    DC predictions count:", preds_cnt)
+                if snaps_cnt:
+                    print("    DC params history count:", snaps_cnt)
         except NotImplementedError as exc:
             print(f" -> Skipped (not implemented): {exc}")
         except (RuntimeError, ValueError) as exc:  # pragma: no cover - diagnostics only
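All three env helpers above degrade to their defaults on unset or malformed values instead of raising. A standalone illustration of that contract (re-implemented here for demonstration, not imported from the script):

    import os

    def env_float(name: str, default: float) -> float:
        # Mirrors _env_float: default on unset or unparseable input.
        raw = os.getenv(name)
        try:
            return float(raw) if raw is not None else default
        except ValueError:
            return default

    def env_rho_range(name: str, default: tuple[float, float]) -> tuple[float, float]:
        # Mirrors _env_rho_range: expects "lo,hi", else the default pair.
        raw = os.getenv(name)
        if not raw:
            return default
        try:
            lo, hi = raw.split(",", 1)
            return float(lo), float(hi)
        except ValueError:
            return default

    os.environ["DATABANK_DC_RHO_RANGE"] = "-0.25,0.25"
    os.environ["DATABANK_DC_HALFLIFE_DAYS"] = "oops"
    assert env_rho_range("DATABANK_DC_RHO_RANGE", (-0.3, 0.3)) == (-0.25, 0.25)
    assert env_float("DATABANK_DC_HALFLIFE_DAYS", 180.0) == 180.0  # malformed -> default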

+ 375 - 87
src/databank/analytics/dixon_coles.py

@@ -13,10 +13,11 @@ from typing import Any, Dict, Iterable, Optional, Tuple
 from collections import defaultdict
 import re
 import math
+import time as _t
+import calendar as _cal
 
 from databank.core.models import Document
 from databank.db.base import BaseDB
-
 from .base import AnalyticsBase
 
 
@@ -24,26 +25,57 @@ class DixonColesAnalyzer(AnalyticsBase):
     """Estimate DC parameters with time-decayed likelihood and rho correlation."""
 
     def compute(self, data: Any, **kwargs: Any) -> Any:  # noqa: D401
-        """Fit/estimate Poisson-style parameters with DC correlation and decay.
+        """Fit Poisson-style parameters with DC correlation and time decay.
 
         Args:
-            data: Iterable of match-like docs (same shape as Elo input).
-            **kwargs: group_by ("league_season"|"global"), persist (bool), db (BaseDB),
-                halflife_days (float, default 180), rho_range (tuple[float,float], default (-0.3,0.3)),
-                rho_step (float, default 0.01), max_iters (int, default 20), tol (float, default 1e-4).
-
-        Returns:
-            dict with per-group parameters summary. If persist=True, writes
-            documents to 'dc_params'.
+            data: Iterable of match-like docs (same as Elo input).
+            group_by: "league_season" (default) or "global".
+            persist: Whether to persist results (default True).
+            db: Database handle when persist=True.
+            halflife_days: Time-decay half-life in days (default 180).
+            rho_range: Search interval for rho, e.g. (-0.3, 0.3).
+            rho_step: Grid step for rho search (default 0.01).
+            max_iters: IPF max iterations (default 20).
+            tol: IPF convergence tolerance (default 1e-4).
+            optimize: "ipf" (default) or "coord" (joint log-domain updates).
+            l2_attack: L2 regularization on log attack (default 0.0).
+            l2_defense: L2 regularization on log defense (default 0.0).
+            l2_base: L2 regularization on log base means (default 0.0).
+            step_size: Learning rate for "coord" optimizer (default 0.1).
+            outer_iters: Max outer iterations for "coord" (default 10).
+            history: "none" (default), "predictions" (per-match), or "snapshots".
+                - predictions: persist pre-match H/D/A probs and mu/nu for each match.
+                - snapshots: persist team parameter snapshots at cutoffs.
+            snapshot_every: For snapshots, persist every Nth match (default 10).
+            max_iters_history: Iterations for history fits (default max(5, max_iters // 2)).
+            max_goals: Max goals cap for probability table (default 8).
+
+        Returns:
+            Dict with a per-group parameter summary. When persist=True,
+            documents are written to 'dc_params' (plus 'dc_predictions' or
+            'dc_params_history' when history is enabled).
+        """
+
         group_by = str(kwargs.get("group_by", "league_season"))
         persist = bool(kwargs.get("persist", True))
         db: Optional[BaseDB] = kwargs.get("db")
         halflife_days: float = float(kwargs.get("halflife_days", 180.0))
-        rho_range: Tuple[float, float] = tuple(kwargs.get("rho_range", (-0.3, 0.3)))  # type: ignore[assignment]
+        rho_range: Tuple[float, float] = tuple(
+            kwargs.get("rho_range", (-0.3, 0.3))
+        )  # type: ignore[assignment]
         rho_step: float = float(kwargs.get("rho_step", 0.01))
         max_iters: int = int(kwargs.get("max_iters", 20))
         tol: float = float(kwargs.get("tol", 1e-4))
+        optimize: str = str(kwargs.get("optimize", "ipf"))
+        l2_attack: float = float(kwargs.get("l2_attack", 0.0))
+        l2_defense: float = float(kwargs.get("l2_defense", 0.0))
+        l2_base: float = float(kwargs.get("l2_base", 0.0))
+        step_size: float = float(kwargs.get("step_size", 0.1))
+        outer_iters: int = int(kwargs.get("outer_iters", 10))
+        history: str = str(kwargs.get("history", "none"))
+        snapshot_every: int = int(kwargs.get("snapshot_every", 10))
+        max_iters_history: Optional[int] = kwargs.get("max_iters_history")
+        if max_iters_history is not None:
+            try:
+                max_iters_history = int(max_iters_history)
+            except (ValueError, TypeError):
+                max_iters_history = None
+        max_goals: int = int(kwargs.get("max_goals", 8))
 
         # Helpers
         def _get_ts(match: dict) -> int:
@@ -64,11 +96,8 @@ class DixonColesAnalyzer(AnalyticsBase):
                 try:
                     y, m, d = [int(x) for x in re.split(r"[-/]", date_s.strip())[:3]]
                     hh, mm = [int(x) for x in time_s.strip().split(":")[:2]]
-                    # naive epoch assume UTC; avoid importing datetime to keep deps minimal
-                    # Use a rough conversion: days since epoch * 86400 + seconds
+                    # Naive epoch, assumed UTC; rough conversion via calendar.timegm.
                     # Fall back to 0 if parsing fails at any step.
-                    import time as _t
-                    import calendar as _cal
 
                     try:
                         struct = _t.struct_time((y, m, d, hh, mm, 0, 0, 0, 0))
@@ -179,12 +208,17 @@ class DixonColesAnalyzer(AnalyticsBase):
                 }
             )
 
-        def _fit_group(rows: list[dict]) -> tuple[dict, dict]:
+        def _fit_group(
+            rows: list[dict],
+            max_iters_override: Optional[int] = None,
+            outer_iters_override: Optional[int] = None,
+            ref_ts: Optional[int] = None,
+        ) -> tuple[dict, dict]:
             if not rows:
                 return {}, {"matches": 0}
 
-            # Time decay weights
-            max_ts = max(r["ts"] for r in rows)
+            # Time decay weights (use provided ref_ts if given; else latest in rows)
+            max_ts = int(ref_ts) if ref_ts is not None else max(r["ts"] for r in rows)
             lam = math.log(2.0) / max(1.0, halflife_days)
             for r in rows:
                 age_days = max(0.0, (max_ts - r["ts"]) / 86400.0)
@@ -237,75 +271,173 @@ class DixonColesAnalyzer(AnalyticsBase):
                 nu = base_a * att_a[r["away"]] * def_h[r["home"]]
                 return (max(1e-9, mu), max(1e-9, nu))
 
-            # IPF-like alternating updates
-            for _ in range(max_iters):
-                delta = 0.0
-                # Update attack_home
-                for t in teams:
-                    num = 0.0
-                    den = 0.0
-                    for r in rows:
-                        if r["home"] != t:
-                            continue
-                        mu, _ = _expected(r)
-                        num += r["w"] * r["hs"]
-                        den += r["w"] * mu
-                    if den > 0:
-                        factor = num / den
-                        delta = max(delta, abs(1 - factor))
-                        att_h[t] *= factor
-                _normalize()
-
-                # Update attack_away
-                for t in teams:
-                    num = 0.0
-                    den = 0.0
-                    for r in rows:
-                        if r["away"] != t:
-                            continue
-                        _, nu = _expected(r)
-                        num += r["w"] * r["as"]
-                        den += r["w"] * nu
-                    if den > 0:
-                        factor = num / den
-                        delta = max(delta, abs(1 - factor))
-                        att_a[t] *= factor
-                _normalize()
-
-                # Update defense_away (affects mu)
-                for t in teams:
-                    num = 0.0
-                    den = 0.0
-                    for r in rows:
-                        if r["away"] != t:
-                            continue
-                        mu, _ = _expected(r)
-                        num += r["w"] * r["hs"]
-                        den += r["w"] * mu
-                    if den > 0:
-                        factor = num / den
-                        delta = max(delta, abs(1 - factor))
-                        def_a[t] *= factor
-                _normalize()
-
-                # Update defense_home (affects nu)
-                for t in teams:
-                    num = 0.0
-                    den = 0.0
+            # Select iteration caps (allow overrides for history runs)
+            local_max_iters = (
+                int(max_iters_override) if max_iters_override is not None else max_iters
+            )
+            local_outer_iters = (
+                int(outer_iters_override)
+                if outer_iters_override is not None
+                else outer_iters
+            )
+
+            if optimize == "ipf":
+                # IPF-like alternating updates
+                for _ in range(local_max_iters):
+                    delta = 0.0
+                    # Update attack_home
+                    for t in teams:
+                        num = 0.0
+                        den = 0.0
+                        for r in rows:
+                            if r["home"] != t:
+                                continue
+                            mu, _ = _expected(r)
+                            num += r["w"] * r["hs"]
+                            den += r["w"] * mu
+                        if den > 0:
+                            factor = num / den
+                            delta = max(delta, abs(1 - factor))
+                            att_h[t] *= factor
+                    _normalize()
+
+                    # Update attack_away
+                    for t in teams:
+                        num = 0.0
+                        den = 0.0
+                        for r in rows:
+                            if r["away"] != t:
+                                continue
+                            _, nu = _expected(r)
+                            num += r["w"] * r["as"]
+                            den += r["w"] * nu
+                        if den > 0:
+                            factor = num / den
+                            delta = max(delta, abs(1 - factor))
+                            att_a[t] *= factor
+                    _normalize()
+
+                    # Update defense_away (affects mu)
+                    for t in teams:
+                        num = 0.0
+                        den = 0.0
+                        for r in rows:
+                            if r["away"] != t:
+                                continue
+                            mu, _ = _expected(r)
+                            num += r["w"] * r["hs"]
+                            den += r["w"] * mu
+                        if den > 0:
+                            factor = num / den
+                            delta = max(delta, abs(1 - factor))
+                            def_a[t] *= factor
+                    _normalize()
+
+                    # Update defense_home (affects nu)
+                    for t in teams:
+                        num = 0.0
+                        den = 0.0
+                        for r in rows:
+                            if r["home"] != t:
+                                continue
+                            _, nu = _expected(r)
+                            num += r["w"] * r["as"]
+                            den += r["w"] * nu
+                        if den > 0:
+                            factor = num / den
+                            delta = max(delta, abs(1 - factor))
+                            def_h[t] *= factor
+                    _normalize()
+
+                    if delta < tol:
+                        break
+            else:
+                # Coordinate updates in log-domain with L2 regularization
+                # Initialize logs at 0 (since params at 1.0)
+                log_att_h = {t: 0.0 for t in teams}
+                log_att_a = {t: 0.0 for t in teams}
+                log_def_h = {t: 0.0 for t in teams}
+                log_def_a = {t: 0.0 for t in teams}
+                log_base_h = math.log(max(1e-9, base_h))
+                log_base_a = math.log(max(1e-9, base_a))
+
+                def _sync_from_logs() -> None:
+                    nonlocal base_h, base_a
+                    for t in teams:
+                        att_h[t] = math.exp(log_att_h[t])
+                        att_a[t] = math.exp(log_att_a[t])
+                        def_h[t] = math.exp(log_def_h[t])
+                        def_a[t] = math.exp(log_def_a[t])
+                    base_h = math.exp(log_base_h)
+                    base_a = math.exp(log_base_a)
+
+                def _center_logs() -> None:
+                    # Enforce identifiability: mean of logs = 0 per block
+                    def _center(d: dict[str, float]) -> None:
+                        if len(d) == 0:
+                            return
+                        mean = sum(d.values()) / len(d)
+                        for k in d:
+                            d[k] -= mean
+
+                    _center(log_att_h)
+                    _center(log_att_a)
+                    _center(log_def_h)
+                    _center(log_def_a)
+
+                for _ in range(max(1, local_outer_iters)):
+                    # attack_home gradients
+                    for t in teams:
+                        grad = -l2_attack * log_att_h[t]
+                        for r in rows:
+                            if r["home"] != t:
+                                continue
+                            mu, _ = _expected(r)
+                            grad += r["w"] * (r["hs"] - mu)
+                        log_att_h[t] += step_size * grad / (sum_w + 1e-9)
+
+                    # attack_away gradients
+                    for t in teams:
+                        grad = -l2_attack * log_att_a[t]
+                        for r in rows:
+                            if r["away"] != t:
+                                continue
+                            _, nu = _expected(r)
+                            grad += r["w"] * (r["as"] - nu)
+                        log_att_a[t] += step_size * grad / (sum_w + 1e-9)
+
+                    # defense_away gradients (mu)
+                    for t in teams:
+                        grad = -l2_defense * log_def_a[t]
+                        for r in rows:
+                            if r["away"] != t:
+                                continue
+                            mu, _ = _expected(r)
+                            grad += r["w"] * (r["hs"] - mu)
+                        log_def_a[t] += step_size * grad / (sum_w + 1e-9)
+
+                    # defense_home gradients (nu)
+                    for t in teams:
+                        grad = -l2_defense * log_def_h[t]
+                        for r in rows:
+                            if r["home"] != t:
+                                continue
+                            _, nu = _expected(r)
+                            grad += r["w"] * (r["as"] - nu)
+                        log_def_h[t] += step_size * grad / (sum_w + 1e-9)
+
+                    # base means
+                    grad_bh = -l2_base * log_base_h
+                    grad_ba = -l2_base * log_base_a
                     for r in rows:
-                        if r["home"] != t:
-                            continue
-                        _, nu = _expected(r)
-                        num += r["w"] * r["as"]
-                        den += r["w"] * nu
-                    if den > 0:
-                        factor = num / den
-                        delta = max(delta, abs(1 - factor))
-                        def_h[t] *= factor
-                _normalize()
-
-                if delta < tol:
-                    break
+                        mu, nu = _expected(r)
+                        grad_bh += r["w"] * (r["hs"] - mu)
+                        grad_ba += r["w"] * (r["as"] - nu)
+                    log_base_h += step_size * grad_bh / (sum_w + 1e-9)
+                    log_base_a += step_size * grad_ba / (sum_w + 1e-9)
+
+                    _center_logs()
+                    _sync_from_logs()
 
             # Given parameters, grid-search rho for DC correlation
             def _dc_phi(hg: int, ag: int, mu: float, nu: float, rho: float) -> float:
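In the "coord" branch above, each log-parameter takes a damped gradient step: the weighted Poisson score sum of w * (observed - expected), minus an L2 pull of l2 * theta, scaled by step_size / sum_w. A single-parameter sketch of that update rule (rows and values here are made up for illustration):

    import math

    # One log-domain step for a single home-attack parameter, mirroring the
    # "coord" update; rows carry weight w, observed home goals hs, expected mu.
    rows = [
        {"w": 1.0, "hs": 2, "mu": 1.4},
        {"w": 0.5, "hs": 0, "mu": 1.1},
    ]
    log_att = 0.0                   # parameters start at 1.0, so log = 0
    l2_attack, step_size = 0.1, 0.1
    sum_w = sum(r["w"] for r in rows)

    grad = -l2_attack * log_att     # L2 regularization pulls the log toward 0
    for r in rows:
        grad += r["w"] * (r["hs"] - r["mu"])   # Poisson score term
    log_att += step_size * grad / (sum_w + 1e-9)

    print(math.exp(log_att))        # multiplicative attack factor after one step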
@@ -324,7 +456,8 @@ class DixonColesAnalyzer(AnalyticsBase):
                 s = 0.0
                 for r in rows:
                     mu, nu = _expected(r)
-                    # Poisson log pmf (ignoring constant factorial by Stirling or exact; include exact via math.lgamma)
+                    # Exact Poisson log pmf; the log-factorial term comes
+                    # from math.lgamma(k + 1) below.
                     x = r["hs"]
                     y = r["as"]
                     log_px = x * math.log(mu) - mu - math.lgamma(x + 1)
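The exact Poisson log-pmf used here is log P(X = k) = k*log(mu) - mu - lgamma(k + 1), where lgamma(k + 1) = log(k!). A quick numerical sanity check:

    import math

    lam, k = 1.5, 2
    log_p = k * math.log(lam) - lam - math.lgamma(k + 1)
    direct = lam**k * math.exp(-lam) / math.factorial(k)
    assert abs(math.exp(log_p) - direct) < 1e-12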
@@ -413,6 +546,161 @@ class DixonColesAnalyzer(AnalyticsBase):
                     )
                 )
 
+            # Optional history: pre-match predictions or parameter snapshots
+            if history in ("predictions", "snapshots") and persist and db:
+                # Helper for DC-phi adjusted joint probabilities
+                def _probs_hda(
+                    mu: float, nu: float, rho: float
+                ) -> tuple[float, float, float]:
+                    def _pois(k: int, lam: float) -> float:
+                        return math.exp(k * math.log(lam) - lam - math.lgamma(k + 1))
+
+                    # Same phi adjustment as used in the likelihood.
+                    def _phi(x: int, y: int) -> float:
+                        if x > 1 or y > 1:
+                            return 1.0
+                        if x == 0 and y == 0:
+                            return max(1e-9, 1.0 - mu * nu * rho)
+                        if x == 0 and y == 1:
+                            return max(1e-9, 1.0 + mu * rho)
+                        if x == 1 and y == 0:
+                            return max(1e-9, 1.0 + nu * rho)
+                        return max(1e-9, 1.0 - rho)  # x == 1 and y == 1
+
+                    total = 0.0
+                    ph, pd, pa = 0.0, 0.0, 0.0
+                    cap = max(0, int(max_goals))
+                    for x in range(cap + 1):
+                        px = _pois(x, mu)
+                        for y in range(cap + 1):
+                            py = _pois(y, nu)
+                            pr = px * py * _phi(x, y)
+                            total += pr
+                            if x > y:
+                                ph += pr
+                            elif x == y:
+                                pd += pr
+                            else:
+                                pa += pr
+                    if total <= 0:
+                        return 0.0, 0.0, 0.0
+                    inv = 1.0 / total
+                    return ph * inv, pd * inv, pa * inv
+
+                # Iterate matches chronologically; fit using only prior matches
+                rows_sorted = sorted(rows, key=lambda r: r["ts"])  # ascending
+                hist_iters = (
+                    int(max_iters_history)
+                    if isinstance(max_iters_history, int)
+                    else max(5, max_iters // 2)
+                )
+                hist_outer = max(1, outer_iters // 2)
+
+                for i, r in enumerate(rows_sorted):
+                    prior = rows_sorted[:i]
+                    if not prior:
+                        continue
+                    p_h, st_h = _fit_group(
+                        prior,
+                        max_iters_override=hist_iters,
+                        outer_iters_override=hist_outer,
+                        ref_ts=int(r["ts"]),
+                    )
+
+                    # Build parameter maps
+                    base_h = float(st_h.get("base_home", 1.0))
+                    base_a = float(st_h.get("base_away", 1.0))
+                    rho_h = float(st_h.get("rho", 0.0))
+                    att_h_map = {
+                        k: float(v.get("attack_home", 1.0)) for k, v in p_h.items()
+                    }
+                    att_a_map = {
+                        k: float(v.get("attack_away", 1.0)) for k, v in p_h.items()
+                    }
+                    def_h_map = {
+                        k: float(v.get("defense_home", 1.0)) for k, v in p_h.items()
+                    }
+                    def_a_map = {
+                        k: float(v.get("defense_away", 1.0)) for k, v in p_h.items()
+                    }
+
+                    mu = (
+                        base_h
+                        * att_h_map.get(r["home"], 1.0)
+                        * def_a_map.get(r["away"], 1.0)
+                    )
+                    nu = (
+                        base_a
+                        * att_a_map.get(r["away"], 1.0)
+                        * def_h_map.get(r["home"], 1.0)
+                    )
+
+                    if history == "predictions":
+                        ph, pd, pa = _probs_hda(mu, nu, rho_h)
+                        docs.append(
+                            Document(
+                                id=f"{gk}:{r['ts']}:{r['home']}:{r['away']}",
+                                kind="dc_predictions",
+                                data={
+                                    "group": gk,
+                                    "home": r["home"],
+                                    "away": r["away"],
+                                    "cutoff_ts": int(r["ts"]),
+                                    "train_matches": int(len(prior)),
+                                    "mu": float(mu),
+                                    "nu": float(nu),
+                                    "rho": float(rho_h),
+                                    "p_home": float(ph),
+                                    "p_draw": float(pd),
+                                    "p_away": float(pa),
+                                    "observed_h": int(r["hs"]),
+                                    "observed_a": int(r["as"]),
+                                    "model": "dixon_coles",
+                                    "halflife_days": float(halflife_days),
+                                },
+                            )
+                        )
+
+                    if history == "snapshots" and (i % max(1, snapshot_every) == 0):
+                        cut_ts = int(r["ts"])
+                        for team_id, vals in p_h.items():
+                            docs.append(
+                                Document(
+                                    id=f"{gk}:{cut_ts}:{team_id}",
+                                    kind="dc_params_history",
+                                    data={
+                                        "group": gk,
+                                        "team_id": team_id,
+                                        "cutoff_ts": cut_ts,
+                                        "train_matches": int(len(prior)),
+                                        "attack_home": float(
+                                            vals.get("attack_home", 1.0)
+                                        ),
+                                        "attack_away": float(
+                                            vals.get("attack_away", 1.0)
+                                        ),
+                                        "defense_home": float(
+                                            vals.get("defense_home", 1.0)
+                                        ),
+                                        "defense_away": float(
+                                            vals.get("defense_away", 1.0)
+                                        ),
+                                        "league_home_avg": float(base_h),
+                                        "league_away_avg": float(base_a),
+                                        "rho": float(rho_h),
+                                        "halflife_days": float(halflife_days),
+                                    },
+                                )
+                            )
+
         if persist and db and docs:
             db.insert_many(docs)
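For context on what the predictions history stores: p_home/p_draw/p_away come from renormalizing a DC-adjusted Poisson grid up to max_goals, as in _probs_hda above. A self-contained re-derivation with illustrative parameter values:

    import math

    def dc_hda(mu: float, nu: float, rho: float, max_goals: int = 8):
        """H/D/A probabilities from a DC-adjusted Poisson grid (sketch of _probs_hda)."""
        def pois(k: int, lam: float) -> float:
            return math.exp(k * math.log(lam) - lam - math.lgamma(k + 1))

        def tau(x: int, y: int) -> float:
            # Dixon-Coles low-score correction; identity outside {0,1} x {0,1}.
            if x > 1 or y > 1:
                return 1.0
            if (x, y) == (0, 0):
                return max(1e-9, 1.0 - mu * nu * rho)
            if (x, y) == (0, 1):
                return max(1e-9, 1.0 + mu * rho)
            if (x, y) == (1, 0):
                return max(1e-9, 1.0 + nu * rho)
            return max(1e-9, 1.0 - rho)

        ph = pd = pa = total = 0.0
        for x in range(max_goals + 1):
            for y in range(max_goals + 1):
                pr = pois(x, mu) * pois(y, nu) * tau(x, y)
                total += pr
                if x > y:
                    ph += pr
                elif x == y:
                    pd += pr
                else:
                    pa += pr
        return ph / total, pd / total, pa / total

    # Illustrative values only; real mu/nu/rho come from the fitted parameters.
    print(dc_hda(mu=1.5, nu=1.1, rho=-0.05))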