Belief Propagation API

Status: Generated from current Python docstrings and type hints.

Inference backend surface for factor graphs, lowering, exact inference, junction tree inference, TRW-BP, Mean Field VI, and engine results.

gaia.engine.bp

BP v2 — belief propagation aligned with theory and Gaia IR.

Theory: docs/foundations/theory/06-factor-graphs.md, 07-belief-propagation.md

IR lowering: docs/foundations/gaia-ir/07-lowering.md

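The main CLI path uses InferenceEngine.run(), which dispatches automatically:

- junction_tree: treewidth ≤ 20; exact.
- trw_bp: n ≤ 2000 and treewidth > 20; bounded approximation.
- mean_field: n > 2000; fast approximation for large graphs.

The infer() function further down in this module is a legacy convenience function. It keeps the forced loopy_bp mode and the large-graph loopy-BP fallback for compatibility with older callers. New code that needs to behave the same way as gaia run infer should use InferenceEngine directly.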

BeliefPropagation

BeliefPropagation(damping: float = 0.5, max_iterations: int = 100, convergence_threshold: float = 1e-06)

Sum-product loopy Belief Propagation on a FactorGraph (v2).

Implements bp.md §3 exactly, with the following design principles:

- All messages are 2-vectors [P(x=0), P(x=1)], always normalized.
- Synchronous schedule: all new messages computed from old, then swapped.
- Damping per bp.md §4 prevents oscillation in loopy graphs.
- Relation variables (CONTRADICTION/EQUIVALENCE) participate fully.
- BPDiagnostics always collected (full belief history).

- damping: α in bp.md §4. Default 0.5. Range (0, 1]. 1.0 = fully replace the old message (fast, may oscillate). 0.5 = half-step (default, balanced stability). Lower values increase stability but slow convergence.
- max_iterations: Upper bound on sweep iterations.
- convergence_threshold: Stop early when max|Δbelief| < threshold across all variables.
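
The effect of the damping factor can be sketched in a few lines. This is only an illustration: it assumes damping is a linear convex combination of the old and new message followed by renormalization, which matches "1.0 = fully replace old message" but is not confirmed by bp.md §4 (not reproduced here). The helper name damp_message is hypothetical and is not part of gaia.engine.bp.

```python
from typing import List


def damp_message(old: List[float], new: List[float], alpha: float) -> List[float]:
    """Blend a freshly computed 2-vector message with the previous one.

    Assumed rule (not taken from the Gaia source):
        damped = alpha * new + (1 - alpha) * old, then renormalize.
    """
    if not (0.0 < alpha <= 1.0):
        raise ValueError(f"alpha must be in (0, 1], got {alpha}")
    blended = [alpha * n + (1.0 - alpha) * o for o, n in zip(old, new)]
    total = sum(blended)
    return [b / total for b in blended]


old_msg = [0.5, 0.5]   # previous message [P(x=0), P(x=1)]
new_msg = [0.2, 0.8]   # message computed in the current sweep

half_step = damp_message(old_msg, new_msg, 0.5)  # moves halfway toward new_msg
full_step = damp_message(old_msg, new_msg, 1.0)  # discards old_msg entirely
```

With alpha = 0.5 the message moves halfway toward the new value on each sweep, which is why smaller values trade speed for stability.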

Initialize loopy BP with damping and convergence controls.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| damping | float | Message damping factor in (0, 1]. | 0.5 |
| max_iterations | int | Maximum number of synchronous BP sweeps. | 100 |
| convergence_threshold | float | Stop when the maximum belief change falls below this value. | 1e-06 |

Raises:

| Type | Description |
| --- | --- |
| ValueError | If damping is outside (0, 1]. |

Source code in gaia/engine/bp/bp.py
def __init__(
    self,
    damping: float = 0.5,
    max_iterations: int = 100,
    convergence_threshold: float = 1e-6,
) -> None:
    """Initialize loopy BP with damping and convergence controls.

    Args:
        damping: Message damping factor in ``(0, 1]``.
        max_iterations: Maximum number of synchronous BP sweeps.
        convergence_threshold: Stop when the maximum belief change falls below this value.

    Raises:
        ValueError: If ``damping`` is outside ``(0, 1]``.
    """
    if not (0.0 < damping <= 1.0):
        raise ValueError(f"damping must be in (0, 1], got {damping}")
    self._damping = damping
    self._max_iter = max_iterations
    self._threshold = convergence_threshold

run

run(graph: FactorGraph) -> BPResult

Run loopy BP on graph and return beliefs + diagnostics.

Always returns a BPResult with full diagnostics (never None).

graph: A validated FactorGraph. Variables referenced by factors must be registered. Cromwell clamping is enforced at graph construction.

Returns:

| Type | Description |
| --- | --- |
| BPResult | A BPResult containing posterior P(x=1) beliefs and full run diagnostics. |

Source code in gaia/engine/bp/bp.py
def run(self, graph: FactorGraph) -> BPResult:
    """Run loopy BP on *graph* and return beliefs + diagnostics.

    Always returns a BPResult with full diagnostics (never None).

    Args:
        graph: A validated FactorGraph. Variables referenced by factors must
            be registered. Cromwell clamping is enforced at graph construction.

    Returns:
        A BPResult containing posterior ``P(x=1)`` beliefs and full run diagnostics.
    """
    diag = BPDiagnostics()

    # --- Edge case: empty graph ---
    if not graph.variables:
        diag.converged = True
        return BPResult(beliefs={}, diagnostics=diag)

    # --- Edge case: no factors — beliefs = unary factors or neutral measure ---
    if not graph.factors:
        diag.converged = True
        initial_beliefs = _unfactored_beliefs(graph)
        for vid, p in initial_beliefs.items():
            diag.belief_history[vid] = [p]
        return BPResult(beliefs=initial_beliefs, diagnostics=diag)

    # --- Build reverse index: var -> list of factor indices ---
    var_to_factors = graph.get_var_to_factors()

    # --- Initialize unary factors as 2-vectors ---
    priors = _graph_prior_messages(graph)

    # --- Initialize all messages to uniform [0.5, 0.5] ---
    # f2v_msgs[(fi, vid)] = message from factor fi to variable vid
    # v2f_msgs[(vid, fi)] = message from variable vid to factor fi
    f2v_msgs, v2f_msgs = _initial_message_maps(graph)

    # --- Compute initial beliefs from unary factors only ---
    prev_beliefs = _initialize_belief_history(graph, diag)

    max_change = 0.0

    # --- Main BP loop ---
    for iteration in range(self._max_iter):
        # Step 1: Compute all variable→factor messages (synchronous)
        new_v2f = _compute_all_v2f(v2f_msgs, priors, var_to_factors, f2v_msgs)

        # Step 2: Compute all factor→variable messages (synchronous)
        new_f2v = _compute_all_f2v(graph, f2v_msgs, new_v2f)

        # Step 3: Damp and normalize both sets of messages
        _damp_f2v_messages(f2v_msgs, new_f2v, self._damping)
        _damp_v2f_messages(v2f_msgs, new_v2f, self._damping)

        # Step 4: Compute beliefs
        beliefs = _compute_beliefs(graph, priors, var_to_factors, f2v_msgs, diag)

        # Step 5: Check convergence
        max_change = max(abs(beliefs[vid] - prev_beliefs[vid]) for vid in beliefs)
        prev_beliefs = beliefs

        if max_change < self._threshold:
            _complete_diagnostics(
                diag, converged=True, iterations_run=iteration + 1, max_change=max_change
            )
            return BPResult(beliefs=beliefs, diagnostics=diag)

    # Did not converge within max_iterations
    _complete_diagnostics(
        diag, converged=False, iterations_run=self._max_iter, max_change=max_change
    )
    return BPResult(beliefs=prev_beliefs, diagnostics=diag)
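
The control flow of the main loop (sweep, measure the largest belief change, stop early or give up after max_iterations) can be isolated from the message-passing details. The sketch below is a stand-alone illustration of that pattern only. The names iterate_until_stable and toy_step are hypothetical, and the toy update rule is a simple contraction chosen for the example, not a BP sweep.

```python
from typing import Callable, Dict, Tuple

Beliefs = Dict[str, float]


def iterate_until_stable(
    step: Callable[[Beliefs], Beliefs],
    initial: Beliefs,
    max_iterations: int = 100,
    threshold: float = 1e-6,
) -> Tuple[Beliefs, bool, int]:
    """Apply `step` until the largest per-variable change drops below `threshold`.

    Returns (beliefs, converged, iterations_run). `initial` must be non-empty;
    the empty-graph case is handled separately in the source above.
    """
    prev = dict(initial)
    for iteration in range(max_iterations):
        current = step(prev)
        max_change = max(abs(current[k] - prev[k]) for k in current)
        prev = current
        if max_change < threshold:
            return current, True, iteration + 1
    # Budget exhausted: report the last beliefs and flag non-convergence.
    return prev, False, max_iterations


def toy_step(beliefs: Beliefs) -> Beliefs:
    # Illustrative contraction with fixed point 0.8; stands in for one sweep.
    return {k: 0.5 * v + 0.4 for k, v in beliefs.items()}


beliefs, converged, sweeps = iterate_until_stable(toy_step, {"a": 0.5})
short_run = iterate_until_stable(toy_step, {"a": 0.5}, max_iterations=3)
```

Callers should check the converged flag before trusting the beliefs, since a run that exhausts its budget still returns values.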

EngineConfig dataclass

EngineConfig(jt_max_treewidth: int = JT_MAX_TREEWIDTH, mf_node_limit: int = MF_NODE_LIMIT, trw_damping: float = 0.5, trw_max_iter: int = 200, trw_threshold: float = 1e-08, mf_max_iter: int = 500, exact_max_vars: int = EXACT_MAX_VARS)

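Configuration parameters for InferenceEngine.

- jt_max_treewidth: Use JT (exact) when treewidth ≤ this value.
- mf_node_limit: Use Mean Field VI when the number of nodes > this value.
- trw_damping: TRW-BP damping coefficient.
- trw_max_iter: Maximum number of TRW-BP iterations.
- trw_threshold: TRW-BP convergence threshold.
- mf_max_iter: Maximum number of Mean Field iterations.
- exact_max_vars: Maximum number of variables for brute-force enumeration.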

InferenceEngine

InferenceEngine(config: EngineConfig | None = None)

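Unified inference engine that automatically selects the most suitable algorithm.

Automatic routing strategy (method='auto'):

1. n > mf_node_limit → Mean Field VI (fast approximation for large graphs)
2. treewidth ≤ jt_max_treewidth → JT (exact)
3. Otherwise → TRW-BP (bounded approximation)

config: EngineConfig controlling the routing thresholds and algorithm parameters.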
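
The routing order matters: the node-count check comes first, so treewidth is never computed for very large graphs. The sketch below illustrates that decision in isolation. The function choose_method is hypothetical, and the defaults 20 and 2000 are taken from the dispatch description at the top of this page; the actual values of JT_MAX_TREEWIDTH and MF_NODE_LIMIT are not shown here, so treat them as assumptions.

```python
from typing import Callable


def choose_method(
    n_variables: int,
    treewidth_of_graph: Callable[[], int],
    jt_max_treewidth: int = 20,   # assumed default, see lead-in
    mf_node_limit: int = 2000,    # assumed default, see lead-in
) -> str:
    """Return 'mean_field', 'jt', or 'trw_bp' following the auto-routing order.

    Treewidth is passed as a callable so it is only evaluated when the graph
    is small enough for the estimate to be worth computing.
    """
    if n_variables > mf_node_limit:
        return "mean_field"
    if treewidth_of_graph() <= jt_max_treewidth:
        return "jt"
    return "trw_bp"


def _never_called() -> int:
    raise AssertionError("treewidth should not be computed for large graphs")


large = choose_method(3000, _never_called)
narrow = choose_method(100, lambda: 5)
wide = choose_method(100, lambda: 25)
boundary = choose_method(2000, lambda: 20)
```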

Initialize the inference engine with optional configuration.

Source code in gaia/engine/bp/engine.py
def __init__(self, config: EngineConfig | None = None) -> None:
    """Initialize the inference engine with optional configuration."""
    self._config = config or EngineConfig()
    cfg = self._config
    self._jt = JunctionTreeInference()
    self._trw = TRWBeliefPropagation(
        damping=cfg.trw_damping,
        max_iterations=cfg.trw_max_iter,
        convergence_threshold=cfg.trw_threshold,
    )
    self._mf = MeanFieldVI(max_iterations=cfg.mf_max_iter)

run

run(graph: FactorGraph, method: MethodChoice = 'auto') -> InferenceResult

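Run inference on graph.

- graph: A FactorGraph that has already been lowered.
- method:
  - 'auto' (default): choose automatically based on n and treewidth.
  - 'jt': force JT (exact, treewidth ≤ 20).
  - 'trw_bp': force TRW-BP.
  - 'mean_field': force Mean Field VI.
  - 'exact': force brute-force enumeration (small graphs only).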

Returns:

| Type | Description |
| --- | --- |
| InferenceResult | InferenceResult containing the marginal probabilities, algorithm metadata, and elapsed time. |

Source code in gaia/engine/bp/engine.py
def run(
    self,
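    graph: FactorGraph,
    method: MethodChoice = "auto",
) -> InferenceResult:
    """Run inference on graph.

    Args:
        graph: A FactorGraph that has already been lowered.
        method:
            'auto' (default): choose automatically based on n and treewidth.
            'jt': force JT (exact, treewidth ≤ 20).
            'trw_bp': force TRW-BP.
            'mean_field': force Mean Field VI.
            'exact': force brute-force enumeration (small graphs only).

    Returns:
        InferenceResult containing the marginal probabilities, algorithm
        metadata, and elapsed time.
    """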
    cfg = self._config
    t0 = time.perf_counter()
    result: TRWResult | MFResult

    if method == "exact":
        n = len(graph.variables)
        if n > cfg.exact_max_vars:
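            # English gloss of the message below: "The graph has {n} variables,
            # which exceeds the brute-force enumeration limit
            # {cfg.exact_max_vars}. Use method='jt' for exact inference."
            raise ValueError(
                f"图有 {n} 个变量,超过暴力枚举上限 {cfg.exact_max_vars}。"
                "请使用 method='jt' 进行精确推断。"
            )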
        beliefs, _Z = exact_inference(graph)
        diag = TRWDiagnostics()
        diag.converged = True
        for v, b in beliefs.items():
            diag.belief_history[v] = [b]
        result = TRWResult(beliefs=beliefs, diagnostics=diag)
        elapsed = (time.perf_counter() - t0) * 1000
        logger.info("InferenceEngine: exact, %d vars, %.1fms", n, elapsed)
        return InferenceResult(
            result=result,
            method_used="exact",
            treewidth=-1,
            elapsed_ms=elapsed,
            is_exact=True,
        )

    if method == "auto":
        n = len(graph.variables)
        if n > cfg.mf_node_limit:
            warnings.warn(
                "Mean Field VI fallback "
                f"(n > {cfg.mf_node_limit}) for {n} variables. "
                "This large-graph path is approximate and not production-grade; "
                "use method='trw_bp' when belief values need higher accuracy.",
                UserWarning,
                stacklevel=2,
            )
            method = "mean_field"
        else:
            tw = jt_treewidth(graph)
            method = "jt" if tw <= cfg.jt_max_treewidth else "trw_bp"

    if method == "jt":
        tw = jt_treewidth(graph)
        result = self._jt.run(graph)
        elapsed = (time.perf_counter() - t0) * 1000
        logger.info("InferenceEngine: JT (exact), treewidth=%d, %.1fms", tw, elapsed)
        return InferenceResult(
            result=result,
            method_used="jt",
            treewidth=tw,
            elapsed_ms=elapsed,
            is_exact=True,
        )

    if method == "trw_bp":
        tw = jt_treewidth(graph) if len(graph.variables) <= cfg.mf_node_limit else -1
        result = self._trw.run(graph)
        elapsed = (time.perf_counter() - t0) * 1000
        logger.info("InferenceEngine: TRW-BP, treewidth=%d, %.1fms", tw, elapsed)
        return InferenceResult(
            result=result,
            method_used="trw_bp",
            treewidth=tw,
            elapsed_ms=elapsed,
            is_exact=False,
        )

    if method == "mean_field":
        result = self._mf.run(graph)
        elapsed = (time.perf_counter() - t0) * 1000
        logger.info(
            "InferenceEngine: Mean Field, %d vars, %.1fms", len(graph.variables), elapsed
        )
        return InferenceResult(
            result=result,
            method_used="mean_field",
            treewidth=-1,
            elapsed_ms=elapsed,
            is_exact=False,
        )

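    # English gloss of the message below: "method must be 'auto', 'jt',
    # 'trw_bp', 'mean_field', or 'exact'; got {method!r}".
    raise ValueError(
        f"method 必须是 'auto', 'jt', 'trw_bp', 'mean_field', 或 'exact';收到 {method!r}"
    )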

benchmark

benchmark(graph: FactorGraph) -> dict[str, dict[str, object]]

Run all feasible algorithms and return the comparison results.

Source code in gaia/engine/bp/engine.py
def benchmark(self, graph: FactorGraph) -> dict[str, dict[str, object]]:
    """Run all feasible algorithms and return the comparison results."""
    results: dict[str, dict[str, object]] = {}
    for m in ("jt", "trw_bp", "mean_field"):
        r = self.run(graph, method=m)
        results[m] = {
            "beliefs": r.beliefs,
            "elapsed_ms": r.elapsed_ms,
            "is_exact": r.is_exact,
            "treewidth": r.treewidth,
        }
    if len(graph.variables) <= self._config.exact_max_vars:
        r = self.run(graph, method="exact")
        results["exact"] = {
            "beliefs": r.beliefs,
            "elapsed_ms": r.elapsed_ms,
            "is_exact": True,
            "treewidth": -1,
        }
    return results
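
One natural use of the benchmark dictionary is to measure how far each approximate method drifts from an exact reference. The sketch below works on a hand-written dictionary that has the same shape as the return value above. The helper max_belief_gap is hypothetical, and the numbers are made up purely to exercise the code; they are not measurements from Gaia.

```python
from typing import Dict


def max_belief_gap(results: Dict[str, dict], reference: str = "jt") -> Dict[str, float]:
    """Largest absolute belief difference between each method and `reference`."""
    ref_beliefs = results[reference]["beliefs"]
    gaps: Dict[str, float] = {}
    for name, entry in results.items():
        if name == reference:
            continue
        other = entry["beliefs"]
        gaps[name] = max(
            (abs(other[v] - ref_beliefs[v]) for v in ref_beliefs), default=0.0
        )
    return gaps


# Illustrative stand-in for InferenceEngine.benchmark(graph) output.
sample = {
    "jt": {"beliefs": {"a": 0.70, "b": 0.20}, "elapsed_ms": 1.0,
           "is_exact": True, "treewidth": 1},
    "trw_bp": {"beliefs": {"a": 0.68, "b": 0.25}, "elapsed_ms": 2.0,
               "is_exact": False, "treewidth": 1},
    "mean_field": {"beliefs": {"a": 0.80, "b": 0.20}, "elapsed_ms": 0.5,
                   "is_exact": False, "treewidth": -1},
}

gaps = max_belief_gap(sample)
```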

InferenceResult dataclass

InferenceResult(result: TRWResult | MFResult, method_used: str = 'unknown', treewidth: int = -1, elapsed_ms: float = 0.0, is_exact: bool = False)

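Return value of InferenceEngine, bundling the inference result with algorithm metadata.

- result: Result from the underlying algorithm (TRWResult or MFResult).
- method_used: Algorithm actually used: 'jt', 'trw_bp', 'mean_field', or 'exact'.
- treewidth: Estimated treewidth of the factor graph (-1 if not computed).
- elapsed_ms: Inference time in milliseconds.
- is_exact: True means the algorithm is guaranteed to return exact marginal probabilities.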

beliefs property

beliefs: dict[str, float]

Shortcut for the beliefs dictionary.

diagnostics property

diagnostics: TRWDiagnostics | MFDiagnostics

Shortcut for diagnostics.

Factor dataclass

Factor(factor_id: str, factor_type: FactorType, variables: list[str], conclusion: str, p1: float | None = None, p2: float | None = None, cpt: tuple[float, ...] | None = None)

Factor in a factor graph with variables and potential function.

all_vars property

all_vars: list[str]

Return all variables involved in this factor.

FactorGraph

FactorGraph()

Factor graph for probabilistic inference.

Initialize an empty factor graph.

Source code in gaia/engine/bp/factor_graph.py
def __init__(self) -> None:
    """Initialize an empty factor graph."""
    self.variables: dict[str, float] = {}
    self.unary_factors: dict[str, float] = {}
    self.hard_evidence: dict[str, int] = {}
    self.factors: list[Factor] = []
    # V8 audit trail: every class-II likelihood update appends a record
    # {"prior_before", "likelihood_ratio", "prior_after"} so the
    # full II→IV chain is recoverable from the graph alone (Gaia
    # auditability requirement; does not affect inference).
    self.posterior_evidence: dict[str, list[dict[str, float]]] = {}
    # V9 audit trail: D2 structural deduplications performed during
    # lowering. Each entry: {"op", "args", "conclusion", "dropped_count"}.
    # Populated by lowering, untouched by inference.
    self.dedup_audit: list[dict[str, object]] = []

add_variable

add_variable(var_id: str, prior: float | None = None) -> None

Register a binary variable, optionally with an explicit unary factor.

variables records the neutral display/initial measure for every variable. Only unary_factors is a Jaynes-style class IV soft prior (Cromwell ε permitted). Class I logical assertions belong in hard_evidence via add_evidence. Those install a Cromwell-clamped {ε, 1-ε} strong prior (Gaia's adjusted Jaynes semantics), not a strict δ; downstream BP treats hard-evidence variables as pinned but still Bayes-updatable.

Source code in gaia/engine/bp/factor_graph.py
def add_variable(self, var_id: str, prior: float | None = None) -> None:
    """Register a binary variable, optionally with an explicit unary factor.

    ``variables`` records the neutral display/initial measure for every
    variable. Only ``unary_factors`` is a Jaynes-style class IV soft prior
    (Cromwell ε permitted). Class I logical assertions belong in
    ``hard_evidence`` via :meth:`add_evidence`. Those install a
    Cromwell-clamped {ε, 1-ε} strong prior (Gaia's adjusted Jaynes
    semantics), not a strict δ; downstream BP treats hard-evidence
    variables as pinned but still Bayes-updatable.
    """
    if prior is None:
        self.variables.setdefault(var_id, 0.5)
        return
    clamped = _cromwell_clamp(prior, label=f"variable '{var_id}' unary")
    if var_id in self.hard_evidence:
        target = float(self.hard_evidence[var_id])
        if abs(clamped - target) > 0.5:
            raise ValueError(
                f"Variable '{var_id}': soft prior {clamped:g} contradicts "
                f"hard evidence={self.hard_evidence[var_id]} (D5)."
            )
        return
    if var_id in self.unary_factors:
        existing = self.unary_factors[var_id]
        if abs(existing - clamped) > CROMWELL_EPS:
            raise ValueError(
                f"Variable '{var_id}': conflicting unary priors "
                f"existing={existing:g}, new={clamped:g} (D1 violation)."
            )
    self.variables[var_id] = clamped
    self.unary_factors[var_id] = clamped

add_evidence

add_evidence(var_id: str, value: int) -> None

Class I hard observation with Cromwell clamp.

Gaia adjusts Jaynes: hard evidence is stored as a very strong soft prior {ε, 1-ε} (ε = CROMWELL_EPS = 1e-3), not as strict δ {0, 1}. This preserves Bayesian updatability (Cromwell's rule) and prevents log(0) pathologies in BP message passing, at the cost of a small O(ε) systematic bias vs. strict Jaynes Class I semantics.

Source code in gaia/engine/bp/factor_graph.py
def add_evidence(self, var_id: str, value: int) -> None:
    """Class I hard observation with Cromwell clamp.

    Gaia adjusts Jaynes: hard evidence is stored as a very strong soft
    prior {ε, 1-ε} (ε = CROMWELL_EPS = 1e-3), not as strict δ {0, 1}.
    This preserves Bayesian updatability (Cromwell's rule) and prevents
    log(0) pathologies in BP message passing, at the cost of a small
    O(ε) systematic bias vs. strict Jaynes Class I semantics.
    """
    if var_id not in self.variables:
        raise KeyError(f"Variable '{var_id}' not registered.")
    if value not in (0, 1):
        raise ValueError(f"add_evidence() value must be 0 or 1, got {value}.")
    if var_id in self.hard_evidence and self.hard_evidence[var_id] != value:
        raise ValueError(
            f"Variable '{var_id}': conflicting hard evidence "
            f"{self.hard_evidence[var_id]} vs {value} (D5 violation)."
        )
    if var_id in self.unary_factors:
        existing = self.unary_factors[var_id]
        if (value == 1 and existing < 0.5) or (value == 0 and existing > 0.5):
            raise ValueError(
                f"Variable '{var_id}': hard evidence={value} contradicts "
                f"existing soft prior {existing:g} (D5 violation)."
            )
        self.unary_factors.pop(var_id, None)
    self.hard_evidence[var_id] = value
    self.variables[var_id] = (1.0 - CROMWELL_EPS) if value == 1 else CROMWELL_EPS
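
The Cromwell clamp used by add_variable and add_evidence can be pictured as keeping every probability inside [ε, 1-ε]. The sketch below is a minimal stand-in under that assumption. The body of Gaia's _cromwell_clamp is not shown on this page (it also takes a label and may warn or raise), so clamp_probability and hard_evidence_measure are hypothetical helpers, with ε = 1e-3 taken from the add_evidence docstring.

```python
CROMWELL_EPS = 1e-3  # value quoted in the add_evidence docstring


def clamp_probability(p: float, eps: float = CROMWELL_EPS) -> float:
    """Keep p inside [eps, 1 - eps] so that no state is ruled out entirely."""
    return min(max(p, eps), 1.0 - eps)


def hard_evidence_measure(value: int, eps: float = CROMWELL_EPS) -> float:
    """Measure stored for a hard observation: 1 - eps for 1, eps for 0."""
    if value not in (0, 1):
        raise ValueError(f"value must be 0 or 1, got {value}")
    return (1.0 - eps) if value == 1 else eps


clamped_low = clamp_probability(0.0)    # pushed up to eps
clamped_high = clamp_probability(1.0)   # pulled down to 1 - eps
unchanged = clamp_probability(0.3)      # already inside the interval
observed_true = hard_evidence_measure(1)
observed_false = hard_evidence_measure(0)
```

Because the stored value is never exactly 0 or 1, later evidence can still move the belief, at the price of the small O(ε) bias noted above.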

observe

observe(var_id: str, value: int) -> None

Hard evidence alias; delegates to add_evidence.

Source code in gaia/engine/bp/factor_graph.py
def observe(self, var_id: str, value: int) -> None:
    """Hard evidence alias — delegates to :meth:`add_evidence`."""
    self.add_evidence(var_id, value)

add_likelihood

add_likelihood(var_id: str, likelihood_ratio: float) -> None

Soft evidence (class II): fold likelihood ratio into the class-IV unary.

P_new(x=1) = normalize(π · lr, (1−π) · 1) where lr = P(E|x=1)/P(E|x=0). Records the update in posterior_evidence[var_id] for audit.

Source code in gaia/engine/bp/factor_graph.py
def add_likelihood(
    self,
    var_id: str,
    likelihood_ratio: float,
) -> None:
    """Soft evidence (class II): fold likelihood ratio into the class-IV unary.

    P_new(x=1) = normalize(π · lr, (1−π) · 1) where lr = P(E|x=1)/P(E|x=0).
    Records the update in posterior_evidence[var_id] for audit.
    """
    if var_id not in self.variables:
        raise KeyError(f"Variable '{var_id}' not registered.")
    if likelihood_ratio <= 0:
        raise ValueError(f"likelihood_ratio must be > 0, got {likelihood_ratio}.")
    if var_id in self.hard_evidence:
        raise ValueError(
            f"Variable '{var_id}': cannot apply soft likelihood — variable is "
            f"already pinned by hard_evidence={self.hard_evidence[var_id]} (D5)."
        )
    pi = self.unary_factors.get(var_id, self.variables.get(var_id, 0.5))
    odds = pi / (1.0 - pi) * likelihood_ratio
    new_pi = odds / (1.0 + odds)
    clamped = _cromwell_clamp(new_pi, label=f"variable {var_id!r} likelihood-updated unary")
    self.variables[var_id] = clamped
    self.unary_factors[var_id] = clamped
    self.posterior_evidence.setdefault(var_id, []).append(
        {
            "prior_before": float(pi),
            "likelihood_ratio": float(likelihood_ratio),
            "prior_after": float(clamped),
        }
    )
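
The update is Bayes' rule in odds form: prior odds π/(1−π) are multiplied by the likelihood ratio and converted back to a probability. The stand-alone sketch below reproduces just that arithmetic, without the Cromwell clamp or the audit record. The function name apply_likelihood_ratio is hypothetical.

```python
def apply_likelihood_ratio(pi: float, likelihood_ratio: float) -> float:
    """Posterior P(x=1) after folding lr = P(E|x=1)/P(E|x=0) into prior pi."""
    if not (0.0 < pi < 1.0):
        raise ValueError(f"pi must be strictly between 0 and 1, got {pi}")
    if likelihood_ratio <= 0:
        raise ValueError(f"likelihood_ratio must be > 0, got {likelihood_ratio}")
    odds = pi / (1.0 - pi) * likelihood_ratio
    return odds / (1.0 + odds)


# Neutral prior, evidence three times likelier under x=1: odds 1 -> 3.
posterior = apply_likelihood_ratio(0.5, 3.0)

# Applying lr = 4 and then lr = 1/4 returns to the starting prior.
round_trip = apply_likelihood_ratio(apply_likelihood_ratio(0.2, 4.0), 0.25)
```

Because the updates multiply in odds space, a sequence of likelihood ratios can be applied in any order with the same final result, as long as no intermediate value is clamped.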

add_factor

add_factor(factor_id: str, factor_type: FactorType, variables: Sequence[str], conclusion: str, *, p1: float | None = None, p2: float | None = None, cpt: Sequence[float] | None = None) -> None

Add a factor to the graph with specified type and variables.

Source code in gaia/engine/bp/factor_graph.py
def add_factor(  # noqa: C901
    self,
    factor_id: str,
    factor_type: FactorType,
    variables: Sequence[str],
    conclusion: str,
    *,
    p1: float | None = None,
    p2: float | None = None,
    cpt: Sequence[float] | None = None,
) -> None:
    """Add a factor to the graph with specified type and variables."""
    v_list = list(variables)
    if conclusion in v_list:
        raise ValueError(
            f"Factor '{factor_id}': conclusion '{conclusion}' must not appear in variables."
        )

    ft = factor_type
    fp1: float | None = None
    fp2: float | None = None
    fcpt: tuple[float, ...] | None = None

    if ft in (
        FactorType.IMPLICATION,
        FactorType.NEGATION,
        FactorType.CONJUNCTION,
        FactorType.DISJUNCTION,
        FactorType.EQUIVALENCE,
        FactorType.CONTRADICTION,
        FactorType.COMPLEMENT,
    ):
        if p1 is not None or p2 is not None or cpt is not None:
            raise ValueError(f"Deterministic factor '{factor_id}' must not set p1/p2/cpt.")
        self._validate_deterministic(factor_id, ft, v_list)

    elif ft == FactorType.SOFT_ENTAILMENT:
        if cpt is not None:
            raise ValueError(f"SOFT_ENTAILMENT '{factor_id}' must not set cpt.")
        if len(v_list) != 1:
            raise ValueError(
                f"SOFT_ENTAILMENT '{factor_id}' requires exactly 1 premise variable, "
                f"got {len(v_list)}."
            )
        if p1 is None or p2 is None:
            raise ValueError(f"SOFT_ENTAILMENT '{factor_id}' requires p1 and p2.")
        p1c = _cromwell_clamp(p1, label=f"factor '{factor_id}' p1")
        p2c = _cromwell_clamp(p2, label=f"factor '{factor_id}' p2")
        if p1c + p2c <= 1.0:
            raise ValueError(
                f"SOFT_ENTAILMENT '{factor_id}' requires p1 + p2 > 1 "
                f"(after Cromwell clamp got {p1c + p2c})."
            )
        fp1, fp2 = p1c, p2c

    elif ft == FactorType.CONDITIONAL:
        if p1 is not None or p2 is not None:
            raise ValueError(f"CONDITIONAL '{factor_id}' must not set p1/p2.")
        if not v_list:
            raise ValueError(
                f"CONDITIONAL '{factor_id}' requires at least one premise variable."
            )
        if cpt is None:
            raise ValueError(f"CONDITIONAL '{factor_id}' requires cpt.")
        expected = 1 << len(v_list)
        fcpt = tuple(_cromwell_clamp(float(x), label=f"cpt[{i}]") for i, x in enumerate(cpt))
        if len(fcpt) != expected:
            raise ValueError(
                f"CONDITIONAL '{factor_id}': cpt length must be 2^k = {expected}, "
                f"got {len(fcpt)}."
            )

    elif ft == FactorType.PAIRWISE_POTENTIAL:
        if p1 is not None or p2 is not None:
            raise ValueError(f"PAIRWISE_POTENTIAL '{factor_id}' must not set p1/p2.")
        if len(v_list) != 1:
            raise ValueError(
                f"PAIRWISE_POTENTIAL '{factor_id}' requires exactly 1 variable plus "
                f"the paired conclusion variable, got {len(v_list)} variables."
            )
        if cpt is None:
            raise ValueError(f"PAIRWISE_POTENTIAL '{factor_id}' requires cpt.")
        fcpt = tuple(float(x) for x in cpt)
        if len(fcpt) != 4:
            raise ValueError(
                f"PAIRWISE_POTENTIAL '{factor_id}': cpt length must be 4, got {len(fcpt)}."
            )
        if any((not isfinite(x)) or x < 0.0 for x in fcpt):
            raise ValueError(
                f"PAIRWISE_POTENTIAL '{factor_id}' requires finite non-negative weights."
            )
        if sum(fcpt) <= 0.0:
            raise ValueError(
                f"PAIRWISE_POTENTIAL '{factor_id}' requires at least one positive weight."
            )
    else:
        raise ValueError(f"Unknown FactorType: {ft!r}")

    self.factors.append(
        Factor(
            factor_id=factor_id,
            factor_type=factor_type,
            variables=v_list,
            conclusion=conclusion,
            p1=fp1,
            p2=fp2,
            cpt=fcpt,
        )
    )

get_var_to_factors

get_var_to_factors() -> dict[str, list[int]]

Return mapping from variable names to factor indices.

Source code in gaia/engine/bp/factor_graph.py
def get_var_to_factors(self) -> dict[str, list[int]]:
    """Return mapping from variable names to factor indices."""
    index: dict[str, list[int]] = {vid: [] for vid in self.variables}
    for fi, factor in enumerate(self.factors):
        for vid in factor.all_vars:
            if vid in index:
                index[vid].append(fi)
            else:
                logger.warning(
                    "Factor '%s' references undeclared variable '%s'.",
                    factor.factor_id,
                    vid,
                )
    return index
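
The reverse index is what lets each variable find the factors it participates in during message passing. The sketch below rebuilds the same structure from plain lists so that the shape of the result is easy to see. The function build_var_to_factors is a hypothetical stand-in: each factor is represented only by its list of variable names, and undeclared variables are skipped silently, whereas the method above logs a warning.

```python
from typing import Dict, List, Sequence


def build_var_to_factors(
    variables: Sequence[str],
    factors: Sequence[Sequence[str]],
) -> Dict[str, List[int]]:
    """Map each declared variable to the indices of the factors that touch it."""
    index: Dict[str, List[int]] = {vid: [] for vid in variables}
    for fi, factor_vars in enumerate(factors):
        for vid in factor_vars:
            if vid in index:
                index[vid].append(fi)
            # Undeclared variables are ignored in this sketch.
    return index


index = build_var_to_factors(
    ["a", "b", "c", "d"],
    [["a", "b"], ["b", "c"], ["c", "zzz"]],  # "zzz" is deliberately undeclared
)
```

A variable with an empty list, such as "d" here, has no factors attached and keeps its unary or neutral measure.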

validate

validate() -> list[str]

Validate the factor graph and return list of errors.

Source code in gaia/engine/bp/factor_graph.py
def validate(self) -> list[str]:
    """Validate the factor graph and return list of errors."""
    errors: list[str] = []
    for fi, factor in enumerate(self.factors):
        seen: set[str] = set()
        for vid in factor.all_vars:
            if vid not in self.variables:
                errors.append(
                    f"Factor[{fi}] '{factor.factor_id}': variable '{vid}' not registered."
                )
            if vid in seen:
                errors.append(
                    f"Factor[{fi}] '{factor.factor_id}': "
                    f"variable '{vid}' appears more than once in all_vars."
                )
            seen.add(vid)
    return errors

summary

summary() -> str

Generate summary string of the factor graph.

Source code in gaia/engine/bp/factor_graph.py
def summary(self) -> str:
    """Generate summary string of the factor graph."""
    lines = [f"FactorGraph: {len(self.variables)} variables, {len(self.factors)} factors"]
    lines.append("Variables:")
    for vid, measure in sorted(self.variables.items()):
        unary = self.unary_factors.get(vid)
        if unary is None:
            lines.append(f"  {vid:30s}  latent_measure={measure:.4f}")
        else:
            lines.append(f"  {vid:30s}  unary={unary:.4f}")
    lines.append("Factors:")
    for factor in self.factors:
        extra = ""
        if factor.p1 is not None and factor.p2 is not None:
            extra = f"  p1={factor.p1:.4f}  p2={factor.p2:.4f}"
        if factor.cpt is not None:
            extra += f"  cpt_len={len(factor.cpt)}"
        lines.append(
            f"  [{factor.factor_type.name:18s}] {factor.factor_id}"
            f"  variables={factor.variables}  conclusion={factor.conclusion}{extra}"
        )
    return "\n".join(lines)

FactorType

Bases: Enum

Enumeration of factor types in the factor graph.

JointDistribution

Bases: BaseModel

A normalized joint table over a binary variable set.

JointQueryUnavailable

Bases: BaseModel

A method-specific joint query miss.

JointQueryUnavailableError

JointQueryUnavailableError(method: JointQueryMethod, variables: Sequence[str], reason: str, diagnostics: dict[str, Any] | None = None)

Bases: RuntimeError

Raised when a method cannot provide a requested joint distribution.

Initialize an unavailable joint-query error.

Source code in gaia/engine/bp/joint_query.py
def __init__(
    self,
    method: JointQueryMethod,
    variables: Sequence[str],
    reason: str,
    diagnostics: dict[str, Any] | None = None,
) -> None:
    """Initialize an unavailable joint-query error."""
    super().__init__(reason)
    self.method = method
    self.variables = list(variables)
    self.reason = reason
    self.diagnostics = diagnostics or {}

JunctionTreeInference

Exact inference via the Junction Tree Algorithm.

Converts the FactorGraph to a Junction Tree (chordal graph with clique potentials), then runs exact two-pass message passing (Shafer-Shenoy collect + distribute). The result is mathematically identical to brute-force enumeration but runs in O(n * 2^w) time.

This fixes loopy BP's double-counting error on graphs with short cycles. For Gaia's factor graphs (treewidth ≤ ~15), this is the preferred engine.

Returns the same TRWResult interface as BeliefPropagation for drop-in use.

run

run(graph: FactorGraph) -> TRWResult

Run exact Junction Tree inference on graph.

Parameters:

graph: A validated FactorGraph. All variables referenced by factors must be registered.

Returns:

| Type | Description |
| --- | --- |
| TRWResult | TRWResult containing exact marginal P(v=1) beliefs and diagnostics recording treewidth and clique count. |

Source code in gaia/engine/bp/junction_tree.py
def run(self, graph: FactorGraph) -> TRWResult:
    """Run exact Junction Tree inference on *graph*.

    Args:
        graph: A validated FactorGraph. All variables referenced by factors
            must be registered.

    Returns:
        TRWResult containing exact marginal ``P(v=1)`` beliefs and
        diagnostics recording treewidth and clique count.
    """
    diag = TRWDiagnostics()

    if not graph.variables:
        diag.converged = True
        return TRWResult(beliefs={}, diagnostics=diag)

    if not graph.factors:
        # No factors: beliefs are explicit unary factors or neutral MaxEnt.
        diag.converged = True
        diag.treewidth = 0

        # Priority: hard_evidence > unary_factors > neutral MaxEnt
        def _belief0(vid: str) -> float:
            if vid in graph.hard_evidence:
                return (1.0 - CROMWELL_EPS) if graph.hard_evidence[vid] == 1 else CROMWELL_EPS
            return graph.unary_factors.get(vid, 0.5)

        beliefs = {vid: _belief0(vid) for vid in graph.variables}
        for vid, p in beliefs.items():
            diag.belief_history[vid] = [p]
        return TRWResult(beliefs=beliefs, diagnostics=diag)

    calibration = calibrate_junction_tree(graph)
    diag.treewidth = calibration.treewidth
    diag.iterations_run = 2
    logger.debug(
        "JT: %d variables, %d cliques, treewidth=%d",
        len(graph.variables),
        len(calibration.cliques),
        calibration.treewidth,
    )
    beliefs = _extract_beliefs(
        calibration.cliques,
        calibration.calibrated,
        calibration.clique_var_lists,
        set(graph.variables.keys()),
    )

    # NOTE: Do NOT apply Cromwell clamping to the output beliefs.
    # Cromwell's rule applies to author-supplied inputs (priors, p values)
    # to prevent zero probabilities from locking out future evidence.
    # Computed posterior beliefs are allowed to be near 0 or 1 — that is
    # the correct answer when evidence overwhelmingly supports or refutes
    # a proposition.

    # Record beliefs in history (single "iteration" = 0)
    for vid, b in beliefs.items():
        diag.belief_history[vid] = [b]
    diag.converged = True
    diag.max_change_at_stop = 0.0

    return TRWResult(beliefs=beliefs, diagnostics=diag)
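
The no-factor branch above resolves each belief by a fixed priority: hard evidence, then a unary factor, then the neutral 0.5. A minimal standalone sketch of that rule follows. It does not import gaia; the value of `CROMWELL_EPS` and the helper name `belief_without_factors` are assumptions made for illustration only.

```python
# Assumed value for illustration; the real constant is defined inside gaia.
CROMWELL_EPS = 1e-3


def belief_without_factors(vid, hard_evidence, unary_factors):
    """Hypothetical helper: hard evidence > unary factor > neutral 0.5."""
    if vid in hard_evidence:
        # Hard evidence is clamped to {eps, 1 - eps} rather than {0, 1}.
        return (1.0 - CROMWELL_EPS) if hard_evidence[vid] == 1 else CROMWELL_EPS
    # No evidence: fall back to the prior, or to 0.5 when no prior exists.
    return unary_factors.get(vid, 0.5)


hard_evidence = {"a": 1, "b": 0}
unary_factors = {"a": 0.2, "c": 0.7}
beliefs = {
    vid: belief_without_factors(vid, hard_evidence, unary_factors)
    for vid in ["a", "b", "c", "d"]
}
```

Note that variable `a` has both hard evidence and a prior; the evidence wins and the prior of 0.2 is ignored.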

MeanFieldVI

MeanFieldVI(max_iterations: int = 500, convergence_threshold: float = 1e-06, track_elbo: bool = False)

Coordinate Ascent Variational Inference (CAVI) for binary factor graphs.

Scales to large graphs (n > 2000) where Junction Tree and TRW-BP are too expensive. Complexity O(n * F * 2^k) per sweep.

Parameters

max_iterations: Maximum number of full CAVI sweeps.
convergence_threshold: Stop when max|delta_mu| < threshold.
track_elbo: If True, compute and record ELBO after each sweep (adds O(F*2^k) cost).

Initialize mean field inference state.

Source code in gaia/engine/bp/mean_field.py
def __init__(
    self,
    max_iterations: int = 500,
    convergence_threshold: float = 1e-6,
    track_elbo: bool = False,
) -> None:
    """Initialize mean field inference state."""
    self._max_iter = max_iterations
    self._threshold = convergence_threshold
    self._track_elbo = track_elbo

run

run(graph: FactorGraph) -> MFResult

Run CAVI on graph and return beliefs + diagnostics.

Source code in gaia/engine/bp/mean_field.py
def run(self, graph: FactorGraph) -> MFResult:
    """Run CAVI on graph and return beliefs + diagnostics."""
    diag = MFDiagnostics()

    if not graph.variables:
        return MFResult(beliefs={}, diagnostics=diag)

    var_to_factors = graph.get_var_to_factors()

    # Initialise mu: hard_evidence -> Cromwell-clamped {ε, 1-ε},
    # others -> prior or 0.5
    mu: dict[str, float] = {}
    for vid in graph.variables:
        if vid in graph.hard_evidence:
            mu[vid] = (1.0 - CROMWELL_EPS) if graph.hard_evidence[vid] == 1 else CROMWELL_EPS
        elif vid in graph.unary_factors:
            mu[vid] = _clamp(graph.unary_factors[vid])
        else:
            mu[vid] = 0.5

    # Soft variables (updated by CAVI)
    soft_vars = [v for v in graph.variables if v not in graph.hard_evidence]

    # Seed belief history
    for vid in graph.variables:
        diag.belief_history[vid] = [mu[vid]]

    if self._track_elbo:
        diag.elbo_history.append(_compute_elbo(graph, mu, var_to_factors))

    max_change = 0.0

    for iteration in range(self._max_iter):
        max_change = 0.0

        for vid in soft_vars:
            old_mu = mu[vid]
            mu[vid] = _cavi_update(vid, graph, mu, var_to_factors)
            max_change = max(max_change, abs(mu[vid] - old_mu))

        for vid in graph.variables:
            diag.belief_history[vid].append(mu[vid])

        if self._track_elbo:
            diag.elbo_history.append(_compute_elbo(graph, mu, var_to_factors))

        if max_change < self._threshold:
            diag.converged = True
            diag.iterations_run = iteration + 1
            diag.max_change_at_stop = max_change
            return MFResult(beliefs=dict(mu), diagnostics=diag)

    diag.converged = False
    diag.iterations_run = self._max_iter
    diag.max_change_at_stop = max_change
    return MFResult(beliefs=dict(mu), diagnostics=diag)
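
The source above delegates the per-variable step to `_cavi_update`, whose body is not shown here. As a rough illustration of what a coordinate-ascent sweep looks like, the sketch below implements the textbook mean-field update for binary variables with strictly positive factor tables. The function name `cavi`, the factor representation, and the omission of hard evidence are all assumptions; this is not the gaia implementation.

```python
import math
from itertools import product


def cavi(priors, factors, max_iterations=500, threshold=1e-6):
    """Textbook CAVI sketch (hypothetical, not gaia's code).

    priors:  {var: P(x=1)} with values strictly inside (0, 1).
    factors: list of (scope, table); table maps a tuple of 0/1 values
             (ordered like scope) to a strictly positive potential.
    """
    mu = dict(priors)
    for sweep in range(max_iterations):
        max_change = 0.0
        for v in mu:
            # Start from the prior log-odds of v.
            logit = math.log(priors[v]) - math.log(1.0 - priors[v])
            for scope, table in factors:
                if v not in scope:
                    continue
                others = [u for u in scope if u != v]
                # Expected log-potential difference under the current q.
                for values in product((0, 1), repeat=len(others)):
                    assign = dict(zip(others, values))
                    weight = 1.0
                    for u, x in assign.items():
                        weight *= mu[u] if x == 1 else 1.0 - mu[u]
                    on = tuple(1 if u == v else assign[u] for u in scope)
                    off = tuple(0 if u == v else assign[u] for u in scope)
                    logit += weight * (math.log(table[on]) - math.log(table[off]))
            new_mu = 1.0 / (1.0 + math.exp(-logit))
            max_change = max(max_change, abs(new_mu - mu[v]))
            mu[v] = new_mu  # in-place update, as in the source loop
        if max_change < threshold:
            return mu, sweep + 1, True
    return mu, max_iterations, False


# Two variables coupled by a factor that mildly favours agreement.
agree = {(0, 0): 2.0, (1, 1): 2.0, (0, 1): 1.0, (1, 0): 1.0}
mu, sweeps, converged = cavi({"a": 0.9, "b": 0.5}, [(("a", "b"), agree)])
```

Because `a` starts with a high prior and the factor rewards agreement, the belief in `b` is pulled above 0.5. The stopping rule mirrors the source: stop once the largest change in any `mu` within a sweep falls below the threshold.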

TRWBeliefPropagation

TRWBeliefPropagation(damping: float = 0.5, max_iterations: int = 200, convergence_threshold: float = 1e-06, schedule: str = 'synchronous')

Tree-Reweighted Belief Propagation (Wainwright et al. 2003/2005).

Replaces loopy BP as the default approximate inference algorithm. Uses factor-level reweighting for higher-order factor graphs.

Parameters

damping: Message mixing coefficient alpha in (0, 1]. Default 0.5.
max_iterations: Maximum number of full sweeps.
convergence_threshold: Stop when max|delta_belief| < threshold.
schedule: Only "synchronous" (standard parallel sweep, the default) is accepted. "residual" is currently rejected with a ValueError; residual TRW-BP is not yet stable.

Initialize TRW-BP oscillation diagnostic state.

Source code in gaia/engine/bp/trw_bp.py
def __init__(
    self,
    damping: float = 0.5,
    max_iterations: int = 200,
    convergence_threshold: float = 1e-6,
    schedule: str = "synchronous",
) -> None:
    """Initialize TRW-BP oscillation diagnostic state."""
    if not (0.0 < damping <= 1.0):
        raise ValueError(f"damping must be in (0, 1], got {damping}")
    if schedule not in ("synchronous",):
        raise ValueError(  # pragma: no cover
            f"schedule must be 'synchronous', got {schedule!r}. "
            f"Residual schedule for TRW-BP is not yet stable."
        )
    self._damping = damping
    self._max_iter = max_iterations
    self._threshold = convergence_threshold
    self._schedule = schedule
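
The damping coefficient validated above controls how far each message moves per sweep. The sketch below assumes the linear mixing convention described for BeliefPropagation earlier on this page (1.0 fully replaces the old message, 0.5 is a half-step); whether TRW-BP applies it in exactly this form is an assumption, and `damp_message` is a hypothetical helper name.

```python
def damp_message(old_msg, new_msg, damping=0.5):
    """Hypothetical helper: mix a freshly computed 2-vector message
    [P(x=0), P(x=1)] with the previous one, then renormalize."""
    if not (0.0 < damping <= 1.0):
        raise ValueError(f"damping must be in (0, 1], got {damping}")
    mixed = [damping * n + (1.0 - damping) * o for o, n in zip(old_msg, new_msg)]
    total = sum(mixed)
    return [m / total for m in mixed]


half_step = damp_message([0.5, 0.5], [0.9, 0.1], damping=0.5)
full_step = damp_message([0.5, 0.5], [0.9, 0.1], damping=1.0)
```

With `damping=1.0` the old message is discarded entirely, which is fast but can oscillate on loopy graphs; smaller values trade speed for stability.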

run

run(graph: FactorGraph) -> TRWResult

Run TRW-BP on graph and return beliefs + diagnostics.

Source code in gaia/engine/bp/trw_bp.py
def run(self, graph: FactorGraph) -> TRWResult:  # noqa: C901
    """Run TRW-BP on graph and return beliefs + diagnostics."""
    diag = TRWDiagnostics()

    if not graph.variables:
        diag.converged = True
        return TRWResult(beliefs={}, diagnostics=diag)

    if not graph.factors:
        beliefs = {}
        for vid in graph.variables:
            if vid in graph.hard_evidence:
                beliefs[vid] = (
                    (1.0 - CROMWELL_EPS) if graph.hard_evidence[vid] == 1 else CROMWELL_EPS
                )
            else:
                beliefs[vid] = graph.unary_factors.get(vid, 0.5)
        for vid, b in beliefs.items():
            diag.belief_history[vid] = [b]
        diag.converged = True
        diag.factor_joint_tables = []
        return TRWResult(beliefs=beliefs, diagnostics=diag)

    var_to_factors = graph.get_var_to_factors()
    rho = _compute_factor_weights(graph, var_to_factors)
    if rho:
        diag.rho = (
            next(v for v in rho.values() if v < 1.0)
            if any(v < 1.0 for v in rho.values())
            else 1.0
        )

    def _prior_for(vid: str) -> Msg:
        if vid in graph.hard_evidence:
            v = graph.hard_evidence[vid]
            # Cromwell-clamped {ε, 1-ε} per Gaia's adjusted Jaynes Class I
            if v == 0:
                return np.array([1.0 - CROMWELL_EPS, CROMWELL_EPS], dtype=np.float64)
            return np.array([CROMWELL_EPS, 1.0 - CROMWELL_EPS], dtype=np.float64)
        if vid in graph.unary_factors:
            return _prior_to_msg(graph.unary_factors[vid])
        return _uniform_msg()

    priors: dict[str, Msg] = {vid: _prior_for(vid) for vid in graph.variables}

    f2v_msgs: dict[tuple[int, str], Msg] = {}
    v2f_msgs: dict[tuple[str, int], Msg] = {}
    for fi, factor in enumerate(graph.factors):
        for vid in factor.all_vars:
            if vid in graph.variables:
                f2v_msgs[(fi, vid)] = _uniform_msg()
                v2f_msgs[(vid, fi)] = _uniform_msg()

    prev_beliefs: dict[str, float] = {}
    for vid in graph.variables:
        if vid in graph.hard_evidence:
            pi = (1.0 - CROMWELL_EPS) if graph.hard_evidence[vid] == 1 else CROMWELL_EPS
        else:
            pi = graph.unary_factors.get(vid, 0.5)
        prev_beliefs[vid] = pi
        diag.belief_history[vid] = [pi]

    if self._schedule == "synchronous":
        return self._run_synchronous(
            graph, diag, priors, var_to_factors, f2v_msgs, v2f_msgs, prev_beliefs, rho
        )
    return self._run_residual(
        graph, diag, priors, var_to_factors, f2v_msgs, v2f_msgs, prev_beliefs, rho
    )

exact_inference

exact_inference(graph: FactorGraph) -> tuple[dict[str, float], float]

Compute exact marginal beliefs via enumeration over joint distribution.

Source code in gaia/engine/bp/exact.py
def exact_inference(graph: FactorGraph) -> tuple[dict[str, float], float]:
    """Compute exact marginal beliefs via enumeration over joint distribution."""
    var_ids, _, all_log_joints = _enumerate_log_joint(graph)
    joint, z_shifted = _shifted_joint(all_log_joints)
    log_Z = all_log_joints.max() + np.log(z_shifted)
    Z = float(np.exp(log_Z))

    full_arange = np.arange(len(all_log_joints), dtype=np.int64)
    beliefs: dict[str, float] = {}
    for i, vid in enumerate(var_ids):
        mask = ((full_arange >> i) & 1) == 1
        beliefs[vid] = float(joint[mask].sum() / z_shifted)

    return beliefs, Z
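
The two ideas in `exact_inference` are the max-shift before exponentiating (to avoid overflow) and the bit-mask marginalization, where bit `i` of a state index holds the value of the `i`-th variable. The pure-Python sketch below illustrates both on a hand-written log-joint; `marginals_from_log_joint` is a hypothetical name, and the private helpers `_enumerate_log_joint` and `_shifted_joint` are not reproduced.

```python
import math


def marginals_from_log_joint(var_ids, log_joint):
    """Hypothetical helper. log_joint[s] is the unnormalized log-weight of
    state s, where bit i of s is the value of var_ids[i]."""
    shift = max(log_joint)
    # Subtract the maximum before exponentiating so exp() cannot overflow.
    joint = [math.exp(lj - shift) for lj in log_joint]
    z_shifted = sum(joint)
    log_z = shift + math.log(z_shifted)
    beliefs = {
        # P(v=1): sum the states whose i-th bit is set.
        vid: sum(p for s, p in enumerate(joint) if (s >> i) & 1) / z_shifted
        for i, vid in enumerate(var_ids)
    }
    return beliefs, math.exp(log_z)


# States 0..3 encode (a, b) = (0,0), (1,0), (0,1), (1,1) with weights 1, 2, 3, 4.
log_joint = [math.log(w) for w in (1.0, 2.0, 3.0, 4.0)]
beliefs, Z = marginals_from_log_joint(["a", "b"], log_joint)
```

Here the partition function is 10, so P(a=1) = (2 + 4) / 10 and P(b=1) = (3 + 4) / 10.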

exact_joint_over

exact_joint_over(graph: FactorGraph, free_vars: list[str]) -> np.ndarray

Return the normalized joint over free_vars by exact enumeration.

The result is indexed by the bit pattern over free_vars in order: index sum(v_i << i for i, v_i in enumerate(free_vars)).

Source code in gaia/engine/bp/exact.py
def exact_joint_over(graph: FactorGraph, free_vars: list[str]) -> np.ndarray:
    """Return the normalized joint over ``free_vars`` by exact enumeration.

    The result is indexed by the bit pattern over ``free_vars`` in order:
    index ``sum(v_i << i for i, v_i in enumerate(free_vars))``.
    """
    if not free_vars:
        return np.array([1.0], dtype=np.float64)

    _var_ids, var_idx, all_log_joints = _enumerate_log_joint(graph)
    missing = [v for v in free_vars if v not in var_idx]
    if missing:
        raise KeyError(f"exact_joint_over: unknown free vars {missing!r}")

    joint, z_shifted = _shifted_joint(all_log_joints)
    full_arange = np.arange(len(all_log_joints), dtype=np.int64)
    assignment_idx = np.zeros(len(all_log_joints), dtype=np.int64)
    for bit, vid in enumerate(free_vars):
        assignment_idx |= ((full_arange >> var_idx[vid]) & 1).astype(np.int64) << bit

    probs = np.bincount(assignment_idx, weights=joint, minlength=1 << len(free_vars))
    return probs / z_shifted
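
The index convention is easiest to see on a small table: bit `i` of an output index holds the 0/1 value of `free_vars[i]`, regardless of where that variable sits in the full enumeration. The sketch below repeats the projection step with NumPy on a hand-written, already normalized joint; the variable names and the `var_idx` mapping are made up for illustration.

```python
import numpy as np

# Normalized joint over three binary variables; bit k of the state index
# is the value of the variable whose var_idx is k (illustrative data).
joint = np.arange(1, 9, dtype=np.float64) / 36.0
var_idx = {"a": 0, "b": 1, "c": 2}

free_vars = ["c", "a"]  # output bit 0 is c, output bit 1 is a
states = np.arange(len(joint), dtype=np.int64)
assignment_idx = np.zeros(len(joint), dtype=np.int64)
for bit, vid in enumerate(free_vars):
    # Move the variable's bit from its global position to its output position.
    assignment_idx |= ((states >> var_idx[vid]) & 1) << bit

# Sum the joint mass of all full states that share an output index.
table = np.bincount(assignment_idx, weights=joint, minlength=1 << len(free_vars))
```

`table[1]` is P(c=1, a=0) and `table[2]` is P(c=0, a=1); swapping the order of `free_vars` swaps those two entries.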

compare_joint_over

compare_joint_over(graph: FactorGraph, variables: Sequence[str], *, methods: Sequence[JointQueryMethod] = ('exact', 'junction_tree', 'trw_bp', 'mean_field')) -> list[JointDistribution | JointQueryUnavailable]

Run several joint providers and collect unavailable methods explicitly.

Source code in gaia/engine/bp/joint_query.py
def compare_joint_over(
    graph: FactorGraph,
    variables: Sequence[str],
    *,
    methods: Sequence[JointQueryMethod] = ("exact", "junction_tree", "trw_bp", "mean_field"),
) -> list[JointDistribution | JointQueryUnavailable]:
    """Run several joint providers and collect unavailable methods explicitly."""
    requested = _normalized_variables(variables)
    results: list[JointDistribution | JointQueryUnavailable] = []
    for method in methods:
        try:
            results.append(joint_over(graph, requested, method=method))
        except JointQueryUnavailableError as error:
            results.append(
                JointQueryUnavailable(
                    variables=error.variables,
                    method=error.method,
                    reason=error.reason,
                    diagnostics=error.diagnostics,
                )
            )
    return results

joint_over

joint_over(graph: FactorGraph, variables: Sequence[str], *, method: JointQueryMethod) -> JointDistribution

Return a joint table over variables using one inference method.

Source code in gaia/engine/bp/joint_query.py
def joint_over(
    graph: FactorGraph,
    variables: Sequence[str],
    *,
    method: JointQueryMethod,
) -> JointDistribution:
    """Return a joint table over ``variables`` using one inference method."""
    requested = _normalized_variables(variables)
    _require_known_variables(graph, requested, method)

    if method == "exact":
        return _exact_joint_over(graph, requested)
    if method == "mean_field":
        return _mean_field_joint_over(graph, requested)
    if method == "junction_tree":
        return _junction_tree_joint_over(graph, requested)
    if method == "trw_bp":
        return _trw_bp_joint_over(graph, requested)

    raise ValueError(f"Unknown joint query method: {method!r}")

jt_treewidth

jt_treewidth(graph: FactorGraph) -> int

Estimate the treewidth of the factor graph via min-fill triangulation.

Returns the size of the largest maximal clique minus 1.

Source code in gaia/engine/bp/junction_tree.py
def jt_treewidth(graph: FactorGraph) -> int:
    """Estimate the treewidth of the factor graph via min-fill triangulation.

    Returns the size of the largest maximal clique minus 1.
    """
    if not graph.variables:
        return 0
    moral_adj = _build_moral_graph(graph)
    _, elim_cliques = _triangulate_min_fill(moral_adj)
    max_cliques = _maximal_cliques(elim_cliques)
    return max(len(c) for c in max_cliques) - 1
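
For intuition about the estimate, the following is a generic min-fill elimination sketch on a plain adjacency dict. It is not the gaia code: `_build_moral_graph`, `_triangulate_min_fill`, and `_maximal_cliques` are private helpers not shown here, and `min_fill_treewidth` is a hypothetical name. The input is assumed to be symmetric, with every neighbor also present as a key.

```python
from itertools import combinations


def _fill_in(adj, v):
    """Number of edges that eliminating v would add among its neighbours."""
    return sum(1 for a, b in combinations(sorted(adj[v]), 2) if b not in adj[a])


def min_fill_treewidth(adjacency):
    """Hypothetical helper: upper-bound treewidth by greedy min-fill."""
    adj = {v: set(nbrs) for v, nbrs in adjacency.items()}
    width = 0
    while adj:
        # Greedily eliminate the vertex that needs the fewest fill-in edges.
        v = min(sorted(adj), key=lambda u: _fill_in(adj, u))
        nbrs = adj.pop(v)
        # The elimination clique is {v} plus nbrs; its size minus 1 is len(nbrs).
        width = max(width, len(nbrs))
        for a, b in combinations(sorted(nbrs), 2):
            adj[a].add(b)
            adj[b].add(a)
        for u in nbrs:
            adj[u].discard(v)
    return width


chain = {"a": {"b"}, "b": {"a", "c"}, "c": {"b"}}
square = {"a": {"b", "d"}, "b": {"a", "c"}, "c": {"b", "d"}, "d": {"a", "c"}}
chain_width = min_fill_treewidth(chain)
square_width = min_fill_treewidth(square)
```

Min-fill is a heuristic, so the value is an upper bound on the true treewidth; on the chain and the 4-cycle above it happens to be exact.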

lower_local_graph

lower_local_graph(canonical: LocalCanonicalGraph, *, node_priors: dict[str, float] | None = None, strategy_conditional_params: dict[str, list[float]] | None = None, expand_formal: bool = True, infer_use_degraded_noisy_and: bool = False, review_manifest: ReviewManifest | None = None) -> FactorGraph

Build a FactorGraph from a local canonical Gaia IR graph.

Parameters

canonical: Local graph with knowledges, operators, strategies.
node_priors: Optional prior P(claim=1) per Knowledge id (claim nodes only).
strategy_conditional_params: Maps strategy_id -> conditional_probabilities list (infer: 2^k entries, noisy_and: 1 entry).
expand_formal: If True, expand FormalStrategy to deterministic factors. If False, raises NotImplementedError; folded FormalStrategy lowering is a future backend path.
infer_use_degraded_noisy_and: If True, lower infer with CONJUNCTION+SOFT_ENTAILMENT using only all-true / all-false CPT entries (information loss for general CPT).
review_manifest: Optional qualitative ReviewManifest. When present, v6 action-backed strategies/operators are lowered only after their latest review is accepted. Legacy IR targets without metadata.action_label are not gated.

Source code in gaia/engine/bp/lowering.py
def lower_local_graph(
    canonical: LocalCanonicalGraph,
    *,
    node_priors: dict[str, float] | None = None,
    strategy_conditional_params: dict[str, list[float]] | None = None,
    expand_formal: bool = True,
    infer_use_degraded_noisy_and: bool = False,
    review_manifest: ReviewManifest | None = None,
) -> FactorGraph:
    """Build a FactorGraph from a local canonical Gaia IR graph.

    Parameters
    ----------
    canonical:
        Local graph with knowledges, operators, strategies.
    node_priors:
        Optional prior P(claim=1) per Knowledge id (claim nodes only).
    strategy_conditional_params:
        Maps strategy_id -> conditional_probabilities list (infer: 2^k entries,
        noisy_and: 1 entry).
    expand_formal:
        If True, expand FormalStrategy to deterministic factors. If False,
        raises NotImplementedError; folded FormalStrategy lowering is a
        future backend path.
    infer_use_degraded_noisy_and:
        If True, lower ``infer`` with CONJUNCTION+SOFT_ENTAILMENT using only
        all-true / all-false CPT entries (information loss for general CPT).
    review_manifest:
        Optional qualitative ReviewManifest. When present, v6 action-backed
        strategies/operators are lowered only after their latest review is
        accepted. Legacy IR targets without ``metadata.action_label`` are not
        gated.
    """
    priors = node_priors or {}
    no_user_prior_ids, expression_helper_ids = _helper_prior_filter_ids(canonical)
    if no_user_prior_ids:
        priors = {k: v for k, v in priors.items() if k not in no_user_prior_ids}
    metadata_priors = _metadata_priors(canonical, expression_helper_ids)
    strat_params = strategy_conditional_params or {}
    fg = FactorGraph()
    ctr = [0]

    lowerable_operators = _review_allowed_operators(canonical, review_manifest)
    lowerable_operators = _dedup_operators(
        lowerable_operators,
        dedup_audit=fg.dedup_audit,
        context="graph_operators",
    )
    claim_ids = _add_claim_variables(
        fg,
        canonical,
        priors=priors,
        expression_helper_ids=expression_helper_ids,
        relation_concl_ids=_relation_conclusion_ids(lowerable_operators),
    )

    strat_by_id = {s.strategy_id: s for s in canonical.strategies if s.strategy_id}

    _lower_operators(
        fg,
        lowerable_operators,
        priors=priors,
        claim_ids=claim_ids,
        expression_helper_ids=expression_helper_ids,
        ctr=ctr,
    )
    _lower_graph_strategies(
        fg,
        canonical,
        strat_by_id=strat_by_id,
        priors=priors,
        strat_params=strat_params,
        metadata_priors=metadata_priors,
        expand_formal=expand_formal,
        infer_degraded=infer_use_degraded_noisy_and,
        ctr=ctr,
        claim_ids=claim_ids,
        review_manifest=review_manifest,
    )

    return fg

merge_factor_graphs

merge_factor_graphs(local_fg: FactorGraph, dep_graphs: list[tuple[str, FactorGraph, str]], *, local_prefix: str) -> FactorGraph

Merge local and dependency factor graphs for joint inference.

Parameters

local_fg: The local package's factor graph.
dep_graphs: List of (dep_import_name, dep_factor_graph, dep_qid_prefix) triples. dep_qid_prefix identifies variables owned by that dependency, e.g. "github:dep_pkg::".
local_prefix: QID prefix for the local package, e.g. "github:my_pkg::". Variables starting with this prefix are owned by the local package.

Returns:

Type Description
FactorGraph

A merged FactorGraph where shared QIDs map to a single variable (dep-owned prior takes precedence for dep nodes) and all factors coexist with prefixed IDs to avoid collision.

Source code in gaia/engine/bp/lowering.py
def merge_factor_graphs(  # noqa: C901
    local_fg: FactorGraph,
    dep_graphs: list[tuple[str, FactorGraph, str]],
    *,
    local_prefix: str,
) -> FactorGraph:
    """Merge local and dependency factor graphs for joint inference.

    Parameters
    ----------
    local_fg:
        The local package's factor graph.
    dep_graphs:
        List of ``(dep_import_name, dep_factor_graph, dep_qid_prefix)``
        triples. ``dep_qid_prefix`` identifies variables owned by that
        dependency, e.g. ``"github:dep_pkg::"``.
    local_prefix:
        QID prefix for the local package, e.g. ``"github:my_pkg::"``.
        Variables starting with this prefix are owned by the local package.

    Returns:
        A merged :class:`FactorGraph` where shared QIDs map to a single
        variable (dep-owned prior takes precedence for dep nodes) and all
        factors coexist with prefixed IDs to avoid collision.
    """
    merged = FactorGraph()

    def _copy_variable(source: FactorGraph, var_id: str, *, force: bool = False) -> None:
        if var_id in source.unary_factors:
            prior = source.unary_factors[var_id]
            if force or var_id not in merged.unary_factors:
                merged.variables[var_id] = prior
                merged.unary_factors[var_id] = prior
        else:
            if force or var_id not in merged.variables:
                merged.variables[var_id] = source.variables.get(var_id, 0.5)
                merged.unary_factors.pop(var_id, None)

    # 1. Add dep variables first. Owner dep is authoritative; non-owner references
    # are placeholders that must not overwrite the owner prior.
    for _dep_name, dep_fg, dep_prefix in dep_graphs:
        for var_id in dep_fg.variables:
            _copy_variable(dep_fg, var_id, force=var_id.startswith(dep_prefix))

    # 2. Add local variables — overwrite only for locally-owned nodes
    for var_id in local_fg.variables:
        if var_id.startswith(local_prefix):
            # Local owns this node — always use local prior
            _copy_variable(local_fg, var_id, force=True)
        elif var_id not in merged.variables:
            # New variable only seen locally (e.g. intermediate _m_ vars)
            _copy_variable(local_fg, var_id)
        # else: dep owns it, dep prior already set — skip

    # 3. Copy dep factors with prefixed IDs
    for dep_name, dep_fg, _dep_prefix in dep_graphs:
        for factor in dep_fg.factors:
            prefixed = replace(factor, factor_id=f"dep_{dep_name}_{factor.factor_id}")
            merged.factors.append(prefixed)

    # 4. Copy local factors with prefix
    for factor in local_fg.factors:
        prefixed = replace(factor, factor_id=f"local_{factor.factor_id}")
        merged.factors.append(prefixed)

    return merged
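
The ownership rule in steps 1 and 2 can be reduced to a prior-only sketch: the package whose prefix matches a QID always wins, and a non-owner only contributes a value when nothing has been recorded yet. The version below works on plain dicts and ignores the distinction the real function draws between `variables` and `unary_factors`; `merge_priors` and the example QIDs are hypothetical.

```python
def merge_priors(local_priors, dep_priors, *, local_prefix):
    """Hypothetical simplification of the variable-merging precedence.

    dep_priors: list of (dep_qid_prefix, {qid: prior}) pairs.
    """
    merged = {}
    # 1. Dependencies first: the owning dep overwrites, others only fill gaps.
    for dep_prefix, priors in dep_priors:
        for qid, prior in priors.items():
            if qid.startswith(dep_prefix) or qid not in merged:
                merged[qid] = prior
    # 2. Local graph: overwrite only locally owned nodes, add unseen ones.
    for qid, prior in local_priors.items():
        if qid.startswith(local_prefix) or qid not in merged:
            merged[qid] = prior
    return merged


merged = merge_priors(
    local_priors={"github:my_pkg::y": 0.3, "github:dep_pkg::x": 0.5},
    dep_priors=[
        ("github:dep_pkg::", {"github:dep_pkg::x": 0.9, "github:my_pkg::y": 0.5}),
    ],
    local_prefix="github:my_pkg::",
)
```

The local placeholder of 0.5 for the dependency's node does not overwrite the dependency's own prior of 0.9, and the dependency's placeholder for the local node is replaced by the local prior of 0.3.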

infer

infer(graph: FactorGraph, method: str = 'auto') -> dict[str, float]

Legacy convenience wrapper: infer FactorGraph marginals.

Prefer InferenceEngine for new code and CLI-parity behavior.

Parameters

graph: An already-lowered FactorGraph.
method: Inference method to use.
"auto": choose the algorithm automatically from treewidth / n.
"junction_tree": force JT (exact, treewidth ≤ 20).
"trw_bp": force TRW-BP.
"loopy_bp": legacy, force loopy BP.
"mean_field": force Mean Field VI.

Returns:

dict[str, float] Mapping from variable ID to the marginal probability P(x=1).

Source code in gaia/engine/bp/__init__.py
def infer(
    graph: FactorGraph,
    method: str = "auto",
) -> dict[str, float]:
    """Legacy convenience wrapper: infer FactorGraph marginals.

    Prefer :class:`InferenceEngine` for new code and CLI-parity behavior.

    Parameters
    ----------
    graph:
        An already-lowered FactorGraph.
    method:
        "auto"          : choose the algorithm from treewidth / n
        "junction_tree" : force JT (exact, treewidth ≤ 20)
        "trw_bp"        : force TRW-BP
        "loopy_bp"      : legacy, force loopy BP
        "mean_field"    : force Mean Field VI

    Returns:
    -------
    dict[str, float]
        Mapping from variable ID to the marginal probability P(x=1).
    """
    if method == "auto":
        n = len(graph.variables)
        if n > _LOOPY_BP_NODE_LIMIT:
            # Legacy convenience fallback. The CLI's InferenceEngine routes
            # n > 2000 to Mean Field VI instead.
            method = "loopy_bp"
        else:
            tw = jt_treewidth(graph)
            method = "junction_tree" if tw <= _JT_TREEWIDTH_LIMIT else "trw_bp"

    result: TRWResult | MFResult | BPResult

    if method == "junction_tree":
        jt = JunctionTreeInference()
        result = jt.run(graph)
        return result.beliefs

    if method == "trw_bp":
        trw = TRWBeliefPropagation()
        result = trw.run(graph)
        return result.beliefs

    if method == "loopy_bp":
        bp = BeliefPropagation(damping=0.5, max_iterations=500, convergence_threshold=1e-6)
        result = bp.run(graph)
        return result.beliefs

    if method == "mean_field":
        mf = MeanFieldVI()
        result = mf.run(graph)
        return result.beliefs

    raise ValueError(
        f"method must be auto, junction_tree, trw_bp, loopy_bp, or mean_field; got {method!r}"
    )
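
The "auto" branch can be summarized as a small decision function. The sketch below assumes the two private limits equal the thresholds quoted in the module overview (2000 variables, treewidth 20); the actual values of `_LOOPY_BP_NODE_LIMIT` and `_JT_TREEWIDTH_LIMIT` are not shown on this page. The `legacy` flag and the name `choose_method` are hypothetical, used only to contrast this wrapper with the dispatch the overview attributes to `InferenceEngine`.

```python
# Assumed values, taken from the thresholds quoted in the module overview.
NODE_LIMIT = 2000
JT_TREEWIDTH_LIMIT = 20


def choose_method(n_variables, treewidth, *, legacy=True):
    """Hypothetical helper mirroring the "auto" dispatch.

    legacy=True  follows infer(): large graphs fall back to loopy BP.
    legacy=False follows the overview's description of InferenceEngine:
                 large graphs go to Mean Field VI.
    """
    if n_variables > NODE_LIMIT:
        return "loopy_bp" if legacy else "mean_field"
    return "junction_tree" if treewidth <= JT_TREEWIDTH_LIMIT else "trw_bp"


small_sparse = choose_method(150, 8)
small_dense = choose_method(150, 35)
large_legacy = choose_method(5000, 8)
large_engine = choose_method(5000, 8, legacy=False)
```

In the real function the treewidth is only computed when the graph is under the node limit, since the min-fill estimate itself is not free on very large graphs.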