@@ -9,7 +9,7 @@ Purpose:
 
 from __future__ import annotations
 
-from typing import Any, Dict, Iterable, Optional, Tuple
+from typing import Any, Dict, Iterable, Optional, Tuple, cast
 from collections import defaultdict
 import re
 import math
@@ -43,9 +43,10 @@ class DixonColesAnalyzer(AnalyticsBase):
         l2_base: L2 regularization on log base means (default 0.0).
         step_size: Learning rate for "coord" optimizer (default 0.1).
         outer_iters: Max outer iterations for "coord" (default 10).
-        history: "none" (default), "predictions" (per-match), or "snapshots".
+        history: "both" (default), "predictions" (per-match), "snapshots", or "none".
           - predictions: persist pre-match H/D/A probs and mu/nu for each match.
           - snapshots: persist team parameter snapshots at cutoffs.
+          - both: do both predictions and snapshots in a single pass.
         snapshot_every: For snapshots, persist every Nth match (default 10).
         max_iters_history: Iterations to use for history fits (default max_iters//2).
         max_goals: Max goals cap for probability table (default 8).
@@ -67,7 +68,7 @@ class DixonColesAnalyzer(AnalyticsBase):
         l2_base: float = float(kwargs.get("l2_base", 0.0))
         step_size: float = float(kwargs.get("step_size", 0.1))
         outer_iters: int = int(kwargs.get("outer_iters", 10))
-        history: str = str(kwargs.get("history", "none"))
+        history: str = str(kwargs.get("history", "both")).lower()
         snapshot_every: int = int(kwargs.get("snapshot_every", 10))
         max_iters_history: Optional[int] = kwargs.get("max_iters_history")
         if max_iters_history is not None:
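For orientation, here is how the new kwargs might be passed end to end. This is a hypothetical usage sketch: the analyzer's constructor and `run` signature are not shown in this diff, so those names are assumptions.

```python
# Hypothetical usage sketch; constructor/run signature assumed, not shown in the diff.
analyzer = DixonColesAnalyzer()
result = analyzer.run(
    history="both",       # new default: predictions + snapshots in one pass
    snapshot_every=10,    # persist a parameter snapshot every 10th match
    flush_every=1000,     # batch DB writes in chunks of 1000 docs
    skip_if_exists=True,  # skip groups whose summary doc is already current
    verbose=True,         # emit [DC] progress lines
)
print(result["predictions_written"], result["snapshots_written"])
```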
@@ -76,6 +77,19 @@ class DixonColesAnalyzer(AnalyticsBase):
             except (ValueError, TypeError):
                 max_iters_history = None
         max_goals: int = int(kwargs.get("max_goals", 8))
+        # Progress / flushing controls
+        verbose: bool = bool(kwargs.get("verbose", True))
+        progress_every: int = int(kwargs.get("progress_every", 100))
+        flush_every: int = int(kwargs.get("flush_every", 1000))
+        # Skip switch: avoid recomputing groups already up-to-date
+        skip_if_exists: bool = bool(kwargs.get("skip_if_exists", False))
+        # Safety clamps to avoid extreme values in history
+        param_min: float = float(kwargs.get("param_min", 0.3))
+        param_max: float = float(kwargs.get("param_max", 3.0))
+        base_min: float = float(kwargs.get("base_min", 0.3))
+        base_max: float = float(kwargs.get("base_max", 3.0))
+        mu_max_cap: float = float(kwargs.get("mu_max", 6.0))
+        min_history_matches: int = int(kwargs.get("min_history_matches", 3))
 
         # Helpers
         def _get_ts(match: dict) -> int:
@@ -266,6 +280,20 @@ class DixonColesAnalyzer(AnalyticsBase):
             for k in def_a:
                 def_a[k] /= gm
 
+        def _clamp_params() -> None:
+            # Clamp parameters to reasonable ranges
+            nonlocal base_h, base_a
+            base_h = max(base_min, min(base_max, float(base_h)))
+            base_a = max(base_min, min(base_max, float(base_a)))
+            for k in att_h:
+                att_h[k] = max(param_min, min(param_max, float(att_h[k])))
+            for k in att_a:
+                att_a[k] = max(param_min, min(param_max, float(att_a[k])))
+            for k in def_h:
+                def_h[k] = max(param_min, min(param_max, float(def_h[k])))
+            for k in def_a:
+                def_a[k] = max(param_min, min(param_max, float(def_a[k])))
+
         def _expected(r: dict) -> tuple[float, float]:
             mu = base_h * att_h[r["home"]] * def_a[r["away"]]
             nu = base_a * att_a[r["away"]] * def_h[r["home"]]
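For reference, the clamp pattern used by `_clamp_params`, in isolation. A minimal, self-contained sketch; the bounds here mirror the kwarg defaults above but are otherwise illustrative:

```python
# Same max(lo, min(hi, x)) idiom as _clamp_params; bounds match the defaults
# param_min=0.3 / param_max=3.0 from the diff, values below are invented.
def clamp(x: float, lo: float = 0.3, hi: float = 3.0) -> float:
    """Pin x into [lo, hi] so one noisy update cannot explode a rating."""
    return max(lo, min(hi, float(x)))

assert clamp(7.2) == 3.0   # runaway attack rating is capped
assert clamp(0.01) == 0.3  # near-zero defense rating is floored
assert clamp(1.1) == 1.1   # in-range values pass through unchanged
```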
@@ -297,9 +325,15 @@ class DixonColesAnalyzer(AnalyticsBase):
                     den += r["w"] * mu
                 if den > 0:
                     factor = num / den
+                    # Dampen extreme single-step updates
+                    if factor < 0.5:
+                        factor = 0.5
+                    elif factor > 2.0:
+                        factor = 2.0
                     delta = max(delta, abs(1 - factor))
                     att_h[t] *= factor
             _normalize()
+            _clamp_params()
 
             # Update attack_away
             for t in teams:
@@ -313,9 +347,14 @@ class DixonColesAnalyzer(AnalyticsBase):
                     den += r["w"] * nu
                 if den > 0:
                     factor = num / den
+                    if factor < 0.5:
+                        factor = 0.5
+                    elif factor > 2.0:
+                        factor = 2.0
                     delta = max(delta, abs(1 - factor))
                     att_a[t] *= factor
             _normalize()
+            _clamp_params()
 
             # Update defense_away (affects mu)
             for t in teams:
@@ -329,9 +368,14 @@ class DixonColesAnalyzer(AnalyticsBase):
                     den += r["w"] * mu
                 if den > 0:
                     factor = num / den
+                    if factor < 0.5:
+                        factor = 0.5
+                    elif factor > 2.0:
+                        factor = 2.0
                     delta = max(delta, abs(1 - factor))
                     def_a[t] *= factor
             _normalize()
+            _clamp_params()
 
             # Update defense_home (affects nu)
             for t in teams:
@@ -345,9 +389,14 @@ class DixonColesAnalyzer(AnalyticsBase):
                     den += r["w"] * nu
                 if den > 0:
                     factor = num / den
+                    if factor < 0.5:
+                        factor = 0.5
+                    elif factor > 2.0:
+                        factor = 2.0
                     delta = max(delta, abs(1 - factor))
                     def_h[t] *= factor
             _normalize()
+            _clamp_params()
 
             if delta < tol:
                 break
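The four update blocks above share one pattern: a multiplicative fixed-point update whose per-sweep factor is pinned to [0.5, 2.0]. A minimal sketch of that pattern with made-up numbers (the real num/den are the weighted observed-vs-expected goal sums in the hunks above):

```python
# Illustrative sketch of one damped multiplicative coordinate update;
# the inputs are invented for demonstration.
def damped_factor(num: float, den: float) -> float:
    factor = num / den
    # Same damping as the diff: one sweep may at most halve or double a rating.
    return min(2.0, max(0.5, factor))

rating = 1.0
rating *= damped_factor(num=9.0, den=2.0)  # raw factor 4.5 -> damped to 2.0
print(rating)  # 2.0; _clamp_params then keeps it inside [param_min, param_max]
```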
@@ -362,7 +411,7 @@ class DixonColesAnalyzer(AnalyticsBase):
         log_base_a = math.log(max(1e-9, base_a))
 
         def _sync_from_logs() -> None:
-            nonlocal base_h, base_a
+            nonlocal base_h, base_a, log_base_h, log_base_a
             for t in teams:
                 att_h[t] = math.exp(log_att_h[t])
                 att_a[t] = math.exp(log_att_a[t])
@@ -370,6 +419,15 @@ class DixonColesAnalyzer(AnalyticsBase):
                 def_a[t] = math.exp(log_def_a[t])
             base_h = math.exp(log_base_h)
             base_a = math.exp(log_base_a)
+            _clamp_params()
+            # Reflect clamps back to logs
+            for t in teams:
+                log_att_h[t] = math.log(att_h[t])
+                log_att_a[t] = math.log(att_a[t])
+                log_def_h[t] = math.log(def_h[t])
+                log_def_a[t] = math.log(def_a[t])
+            log_base_h = math.log(base_h)
+            log_base_a = math.log(base_a)
 
         def _center_logs() -> None:
             # Enforce identifiability: mean of logs = 0 per block
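Writing the clamp back into log space keeps the optimizer's log-state consistent with the natural-scale parameters; without it, the next `exp()` would silently undo the clamp. A toy round trip (values invented):

```python
import math

# Toy illustration of the clamp/log round trip performed by _sync_from_logs.
log_x = math.log(5.0)          # optimizer state drifted high
x = math.exp(log_x)            # natural scale: 5.0
x = max(0.3, min(3.0, x))      # clamp, as _clamp_params does -> 3.0
log_x = math.log(x)            # reflect the clamp back to the log state
assert abs(math.exp(log_x) - 3.0) < 1e-12  # next sync starts from 3.0, not 5.0
```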
@@ -512,7 +570,67 @@ class DixonColesAnalyzer(AnalyticsBase):
         all_params: Dict[str, Dict[str, Dict[str, float]]] = {}
         docs: list[Document] = []
         total_matches = 0
+        preds_written = 0
+        snaps_written = 0
+        skipped_groups = 0
+
+        def _maybe_flush(reason: str = "") -> None:
+            nonlocal docs
+            if persist and db and docs and len(docs) >= max(1, flush_every):
+                try:
+                    db.insert_many(docs)
+                    if verbose:
+                        print(
+                            f"[DC] Flushed {len(docs)} docs to DB"
+                            + (f" ({reason})" if reason else ""),
+                            flush=True,
+                        )
+                finally:
+                    docs = []
+
+        if verbose:
+            print(f"[DC] Groups to fit: {len(groups)}", flush=True)
+
         for gk, rows in groups.items():
+            # Optional: skip group if summary doc exists with same match count
+            if skip_if_exists and persist and db:
+                try:
+                    db_any = cast(Any, db)
+                    existing = None
+                    docs_iter = db_any.find(
+                        "dc_params", projection={"_id": 0}, limit=None
+                    )
+                    if docs_iter is not None:
+                        for d in docs_iter:
+                            if not isinstance(d, dict):
+                                continue
+                            if d.get("group") == gk and d.get("summary") is True:
+                                existing = d
+                                break
+                    if isinstance(existing, dict):
+                        try:
+                            prev_matches = int(existing.get("matches", -1))
+                        except (ValueError, TypeError):
+                            prev_matches = -1
+                        if prev_matches == len(rows):
+                            if verbose:
+                                print(
+                                    (
+                                        f"[DC] Skip group '{gk}' "
+                                        f"(skip_if_exists; matches={prev_matches})"
+                                    ),
+                                    flush=True,
+                                )
+                            skipped_groups += 1
+                            continue
+                except (RuntimeError, ValueError, TypeError):
+                    # Diagnostics only: if existence check fails, proceed normally
+                    pass
+            if verbose:
+                print(
+                    f"[DC] Processing group '{gk}' with {len(rows)} matches",
+                    flush=True,
+                )
             p, stats = _fit_group(rows)
             all_params[gk] = p
             total_matches += int(stats.get("matches", 0))
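`_maybe_flush` above amortizes DB round trips by writing in batches rather than per document. The pattern in isolation, with a stand-in sink (the real `db.insert_many` belongs to this repo's DB wrapper, which is not shown in the diff):

```python
from typing import Any, List

# Illustrative batching sketch; `sink` stands in for db.insert_many.
class BatchedWriter:
    def __init__(self, sink, flush_every: int = 1000) -> None:
        self.sink, self.flush_every = sink, max(1, flush_every)
        self.buf: List[Any] = []

    def add(self, doc: Any) -> None:
        self.buf.append(doc)
        if len(self.buf) >= self.flush_every:
            self.flush()

    def flush(self) -> None:
        if self.buf:
            try:
                self.sink(self.buf)  # one bulk write per batch
            finally:
                self.buf = []        # cleared even on error, as in _maybe_flush

writer = BatchedWriter(sink=print, flush_every=3)
for i in range(7):
    writer.add({"i": i})
writer.flush()  # final partial batch, mirroring the end-of-run insert_many
```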
@@ -529,6 +647,7 @@ class DixonColesAnalyzer(AnalyticsBase):
                     },
                 )
             )
+            _maybe_flush("dc_params")
             # Optionally also persist a group-level summary doc
             docs.append(
                 Document(
@@ -545,14 +664,18 @@ class DixonColesAnalyzer(AnalyticsBase):
                     },
                 )
             )
+            _maybe_flush("dc_params_summary")
 
-            # Optional history: pre-match predictions or parameter snapshots
-            if history in ("predictions", "snapshots") and persist and db:
+            # Optional history: pre-match predictions and/or parameter snapshots
+            do_preds = history in ("predictions", "both")
+            do_snaps = history in ("snapshots", "both")
+            if (do_preds or do_snaps) and persist and db:
                 # Helper for DC-phi adjusted joint probabilities
                 def _probs_hda(
                     mu: float, nu: float, rho: float
                 ) -> tuple[float, float, float]:
                     def _pois(k: int, lam: float) -> float:
+                        lam = max(1e-9, float(lam))
                         return math.exp(k * math.log(lam) - lam - math.lgamma(k + 1))
 
                     # Use same phi as in likelihood
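For context, here is how H/D/A probabilities fall out of an independent-Poisson goal grid with the Dixon-Coles low-score correction. The diff's own phi is not fully shown, so the sketch below uses the textbook tau adjustment as a stand-in; the numbers at the end are invented:

```python
import math

# Sketch: H/D/A from a Poisson grid with the standard Dixon-Coles tau.
def probs_hda(mu: float, nu: float, rho: float, max_goals: int = 8):
    def pois(k: int, lam: float) -> float:
        lam = max(1e-9, float(lam))  # same guard the diff adds to _pois
        return math.exp(k * math.log(lam) - lam - math.lgamma(k + 1))

    def tau(x: int, y: int) -> float:
        # Reweights the four low-score cells; 1.0 everywhere else.
        if x == 0 and y == 0:
            return 1.0 - mu * nu * rho
        if x == 0 and y == 1:
            return 1.0 + mu * rho
        if x == 1 and y == 0:
            return 1.0 + nu * rho
        if x == 1 and y == 1:
            return 1.0 - rho
        return 1.0

    ph = pd = pa = 0.0
    for x in range(max_goals + 1):
        for y in range(max_goals + 1):
            p = tau(x, y) * pois(x, mu) * pois(y, nu)
            if x > y:
                ph += p
            elif x == y:
                pd += p
            else:
                pa += p
    s = ph + pd + pa  # renormalize mass lost to the max_goals cap
    return ph / s, pd / s, pa / s

print(probs_hda(1.5, 1.1, -0.05))  # e.g. (home, draw, away)
```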
@@ -606,7 +729,7 @@ class DixonColesAnalyzer(AnalyticsBase):
 
                 for i, r in enumerate(rows_sorted):
                     prior = rows_sorted[:i]
-                    if not prior:
+                    if not prior or len(prior) < max(0, int(min_history_matches)):
                         continue
                     p_h, st_h = _fit_group(
                         prior,
@@ -642,8 +765,10 @@ class DixonColesAnalyzer(AnalyticsBase):
                         * att_a_map.get(r["away"], 1.0)
                         * def_h_map.get(r["home"], 1.0)
                     )
+                    mu = max(1e-9, min(mu_max_cap, float(mu)))
+                    nu = max(1e-9, min(mu_max_cap, float(nu)))
 
-                    if history == "predictions":
+                    if do_preds:
                         ph, pd, pa = _probs_hda(mu, nu, rho_h)
                         docs.append(
                             Document(
@@ -668,8 +793,10 @@ class DixonColesAnalyzer(AnalyticsBase):
                                 },
                             )
                         )
+                        preds_written += 1
+                        _maybe_flush("dc_predictions")
 
-                    if history == "snapshots" and (i % max(1, snapshot_every) == 0):
+                    if do_snaps and (i % max(1, snapshot_every) == 0):
                         cut_ts = int(r["ts"])
                         for team_id, vals in p_h.items():
                             docs.append(
@@ -700,13 +827,30 @@ class DixonColesAnalyzer(AnalyticsBase):
                                     },
                                 )
                             )
+                            snaps_written += 1
+                            _maybe_flush("dc_params_history")
+
+                    if verbose and (i + 1) % max(1, progress_every) == 0:
+                        print(
+                            (
+                                f"[DC] Group '{gk}' history progress: "
+                                f"{i + 1}/{len(rows_sorted)} "
+                                f"(preds={preds_written}, snaps={snaps_written})"
+                            ),
+                            flush=True,
+                        )
 
         if persist and db and docs:
             db.insert_many(docs)
+            if verbose:
+                print(f"[DC] Final flush {len(docs)} docs", flush=True)
 
         return {
             "groups": list(all_params.keys()),
             "params": all_params,
             "matches_used": total_matches,
             "persisted": len(docs) if docs else 0,
+            "predictions_written": preds_written,
+            "snapshots_written": snaps_written,
+            "groups_skipped": skipped_groups,
         }
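Continuing the hypothetical usage sketch from the top of this diff, the extended summary can be inspected like so. Field names come from the return dict above; the `analyzer.run` call itself remains an assumption:

```python
# Sketch: reading the patched run's summary (field names from the return dict).
summary = analyzer.run(history="both", skip_if_exists=True)
print(
    f"fit {len(summary['groups'])} groups over {summary['matches_used']} matches; "
    f"preds={summary['predictions_written']}, "
    f"snaps={summary['snapshots_written']}, "
    f"skipped={summary['groups_skipped']}"
)
```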