Files
Masterprojekt-Campusnetz/Stochastisches_Modell.py
2025-12-06 16:19:34 +01:00

122 lines
3.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from dataclasses import dataclass, field
from typing import Dict, Iterable, Tuple

import sympy as sp
@dataclass
class StochastischesModellApriori:
    """A-priori stochastic model with group-wise variance components.

    Every observation i has an a-priori standard deviation sigma_i and belongs
    to a group j. Its cofactor is q_ii = sigma_i^2 and its weight is

        p_i = 1 / (sigma_0j^2 * q_ii),

    where sigma_0j^2 is the variance component of group j.
    ``varianzkomponenten_schaetzung`` re-estimates the variance components
    from the residuals of an adjustment. ``update_sigma0`` writes the
    estimates back, which allows an iterative variance component estimation.

    Attributes:
        sigma_obs: a-priori standard deviations sigma_i. Converted to a sympy
            column vector in ``__post_init__``.
        group_ids: group id of the i-th observation. Converted to a sympy
            column vector in ``__post_init__``.
        sigma0_sq_groups: variance component sigma_0j^2 per group id. The
            model keeps its own copy. Groups missing from it start at 1.0.
    """

    sigma_obs: Iterable[float]  # σ_i
    group_ids: Iterable[int]  # group membership of the i-th observation
    sigma0_sq_groups: Dict[int, float] = field(default_factory=dict)

    def __post_init__(self):
        """Convert inputs to sympy column vectors and complete sigma0_sq_groups.

        Raises:
            ValueError: if sigma_obs and group_ids differ in length.
        """
        self.sigma_obs = sp.Matrix(list(self.sigma_obs))  # column vector
        self.group_ids = sp.Matrix(list(self.group_ids))  # column vector
        if self.sigma_obs.rows != self.group_ids.rows:
            raise ValueError("sigma_obs und group_ids müssen gleich viele Einträge haben.")
        # Work on a private copy: otherwise the defaults below (and later
        # update_sigma0 calls) would silently modify the caller's dict.
        self.sigma0_sq_groups = dict(self.sigma0_sq_groups)
        # Groups without an explicit variance component start at σ_0j^2 = 1.0.
        for g in sorted({int(g) for g in self.group_ids}):
            self.sigma0_sq_groups.setdefault(g, 1.0)

    @property
    def n_obs(self) -> int:
        """Number of observations."""
        return int(self.sigma_obs.rows)

    def build_Qll_P(self) -> Tuple[sp.Matrix, sp.Matrix]:
        """Build the diagonal cofactor matrix Q_ll and the weight matrix P.

        Returns:
            (Q_ll, P), both n_obs × n_obs and diagonal, with
            Q_ll[i, i] = sigma_i^2 and P[i, i] = 1 / (sigma_0j^2 * sigma_i^2),
            where j is the group of observation i.

        Raises:
            KeyError: if a group id has no entry in sigma0_sq_groups.
        """
        n = self.n_obs
        Q_ll = sp.zeros(n, n)
        P = sp.zeros(n, n)
        for i in range(n):
            q_ii = self.sigma_obs[i, 0] ** 2
            sigma0_sq = self.sigma0_sq_groups[int(self.group_ids[i, 0])]
            Q_ll[i, i] = q_ii
            P[i, i] = 1 / (sigma0_sq * q_ii)
        return Q_ll, P

    @staticmethod
    def _redundanz_pro_beobachtung(A: sp.Matrix, P: sp.Matrix) -> sp.Matrix:
        """Compute the redundancy number r_i of every observation.

        With the homogenised design matrix Ã = P^(1/2) A (P diagonal):

            r_i = 1 - ã_i (ÃᵀÃ)^(-1) ã_iᵀ,

        where ã_i is the i-th row of Ã. This is the i-th diagonal element
        of Q_vv P.

        Args:
            A: design matrix (n_obs × n_param).
            P: diagonal weight matrix (n_obs × n_obs). Only its diagonal is read.

        Returns:
            Column vector (n_obs × 1) of redundancy numbers.

        Raises:
            ValueError: if A and P have a different number of rows. sympy
                also raises if ÃᵀÃ is singular (e.g. a rank-deficient A).
        """
        if A.rows != P.rows:
            raise ValueError("A und P müssen gleich viele Zeilen haben.")
        n_obs = P.rows
        # P is diagonal, so P^(1/2) A simply scales row i by sqrt(p_ii).
        sqrt_p = [sp.sqrt(P[i, i]) for i in range(n_obs)]
        A_tilde = sp.Matrix(n_obs, A.cols, lambda i, k: sqrt_p[i] * A[i, k])  # Ã
        N_inv = (A_tilde.T * A_tilde).inv()  # (ÃᵀÃ)^(-1)
        r_vec = sp.zeros(n_obs, 1)
        for i in range(n_obs):
            a_i = A_tilde.row(i)  # 1 × n_param
            r_vec[i, 0] = 1 - (a_i * N_inv * a_i.T)[0, 0]
        return r_vec

    def varianzkomponenten_schaetzung(
        self,
        v: sp.Matrix,  # residual vector (n × 1)
        A: sp.Matrix,  # design matrix (n × n_param)
    ) -> Dict[int, float]:
        """Estimate the variance component sigma_0j^2 of every observation group.

        Uses the weights P built from the *current* sigma0_sq_groups and the
        redundancy numbers r_i of the adjustment with design matrix A. For
        group j:

            lambda_j = (v_jᵀ P_j v_j) / r_j,   r_j = sum of r_i over group j.

        P_j already contains 1/sigma_0j^2(current), so lambda_j is the variance
        factor of group j *relative to* its current value. The returned
        estimate is

            sigma_0j^2(new) = sigma_0j^2(current) * lambda_j.

        This is what update_sigma0 expects. When iterating, lambda_j tends to 1
        and the returned values stop changing.

        Args:
            v: residuals of the adjustment (n_obs × 1).
            A: design matrix of the adjustment (n_obs × n_param).

        Returns:
            Mapping group id -> new sigma_0j^2 as float, in ascending group order.

        Raises:
            ValueError: if v does not have n_obs rows, if A does not match the
                number of observations, or if a group has zero redundancy
                (r_j = 0), in which case its variance component cannot be
                estimated.
        """
        if v.rows != self.n_obs:
            raise ValueError("Länge von v passt nicht zur Anzahl Beobachtungen im Modell.")

        _, P = self.build_Qll_P()
        r_vec = self._redundanz_pro_beobachtung(A, P)

        # Group the observation indices once instead of rescanning per group.
        members: Dict[int, list] = {}
        for i in range(self.n_obs):
            members.setdefault(int(self.group_ids[i, 0]), []).append(i)

        new_sigma0_sq: Dict[int, float] = {}
        for g in sorted(members):
            idx = members[g]
            # P is diagonal, so v_jᵀ P_j v_j is a weighted sum of squares.
            vPv_j = sum(P[i, i] * v[i, 0] ** 2 for i in idx)
            r_j = sum(r_vec[i, 0] for i in idx)
            if r_j == 0:
                raise ValueError(
                    f"Gruppe {g} hat keine Redundanz (r_j = 0); σ_0j^2 ist nicht schätzbar."
                )
            lambda_j = vPv_j / r_j  # variance factor relative to current σ_0j^2
            new_sigma0_sq[g] = float(self.sigma0_sq_groups[g] * lambda_j)
        return new_sigma0_sq

    def update_sigma0(self, new_sigma0_sq: Dict[int, float]) -> None:
        """Overwrite the variance components sigma_0j^2 of the given groups.

        Args:
            new_sigma0_sq: mapping group id -> new sigma_0j^2. Keys are cast to
                int and values to float.
        """
        for g, val in new_sigma0_sq.items():
            self.sigma0_sq_groups[int(g)] = float(val)