# Factor Graphs
## Overview
Factor graphs are a type of [[graphical_models|graphical model]] that represents how a global function, typically a probability distribution, factorizes into a product of local functions. They provide a unified framework for [[inference_algorithms|inference algorithms]] and subsume both [[directed_graphical_models|directed]] and [[undirected_graphical_models|undirected]] graphical models.
## Mathematical Foundation
### 1. Basic Structure
A factor graph G = (V, F, E) consists of:
- V: Variable nodes representing [[random_variables|random variables]]
- F: Factor nodes representing local functions, e.g. [[probability_distributions|conditional distributions]] or potentials
- E: Edges connecting each factor to the variables in its scope, encoding [[probabilistic_dependencies|probabilistic dependencies]]; the graph is bipartite
### 2. Factorization
For a probability distribution p(x):
```math
p(x_1, ..., x_n) = \frac{1}{Z} \prod_{a \in F} f_a(x_{\partial a})
```
where:
- f_a are [[factor_functions|factor functions]]
- ∂a denotes the set of variables connected to factor a
- Z is a normalization constant (Z = 1 when the factors are already normalized, e.g. conditional distributions)
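As a concrete example, a chain of three binary variables with two factors factorizes as p(x₁, x₂, x₃) ∝ f_a(x₁, x₂) f_b(x₂, x₃). A minimal sketch (with hypothetical factor tables) showing that normalizing the product of factors yields a valid distribution:

```julia
# Hypothetical 2x2 factor tables over binary variables x1, x2, x3
f_a = [1.0 2.0; 3.0 4.0]   # f_a(x1, x2)
f_b = [2.0 1.0; 1.0 2.0]   # f_b(x2, x3)

# Unnormalized joint: p(x1, x2, x3) ∝ f_a(x1, x2) * f_b(x2, x3)
joint = [f_a[x1, x2] * f_b[x2, x3] for x1 in 1:2, x2 in 1:2, x3 in 1:2]

# Normalization constant Z and the resulting distribution
Z = sum(joint)
p = joint ./ Z

# Marginal of x2 obtained by summing out x1 and x3
p_x2 = vec(sum(p, dims=(1, 3)))
```

Here Z = 30, so the marginal of x₂ is [0.4, 0.6]; larger graphs avoid this explicit enumeration via message passing.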
### 3. [[bayesian_factorization|Bayesian Factorization]]
In Bayesian terms:
```math
p(x, \theta \mid y) \propto p(y \mid x, \theta)\, p(x \mid \theta)\, p(\theta)
```
where:
- p(y | x, θ) is the [[likelihood_function|likelihood]]
- p(x | θ) is the [[prior_distribution|prior]]
- p(θ) is the [[hyperprior|hyperprior]]
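As a concrete instance of this factorization, a Beta prior combined with a Bernoulli likelihood yields a closed-form posterior by conjugacy; a small sketch using Distributions.jl with hypothetical data:

```julia
using Distributions

# Beta prior on the success probability θ, and hypothetical observations
prior = Beta(2.0, 2.0)
y = [1, 0, 1, 1, 0, 1]

# p(θ | y) ∝ p(y | θ) p(θ): conjugacy gives Beta(α + Σy, β + n - Σy)
k, n = sum(y), length(y)
posterior = Beta(prior.α + k, prior.β + n - k)
```

With this data the posterior is Beta(6, 4), so the posterior mean of θ is 0.6.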
## Components and Structure
### 1. [[variable_nodes|Variable Nodes]]
#### Types and Properties
```julia
struct VariableNode{T}
id::Symbol
domain::Domain{T}
neighbors::Set{FactorNode}
messages::Dict{FactorNode, Message}
belief::Distribution{T}
end
```
#### Categories
- **Observable Variables**
- Represent data points
- Fixed during inference
- Drive belief updates
- **Latent Variables**
- Hidden states
- Model parameters
- Inferred quantities
- **Parameter Nodes**
- [[hyperparameters|Hyperparameters]]
- [[model_parameters|Model parameters]]
- [[sufficient_statistics|Sufficient statistics]]
### 2. [[factor_nodes|Factor Nodes]]
#### Base Implementation
```julia
abstract type FactorNode end
struct ProbabilisticFactor <: FactorNode
distribution::Distribution
variables::Vector{VariableNode}
parameters::Dict{Symbol, Any}
end
struct DeterministicFactor <: FactorNode
    fn::Function  # `function` is a reserved word in Julia
    inputs::Vector{VariableNode}
    outputs::Vector{VariableNode}
end
struct ConstraintFactor <: FactorNode
constraint::Function
variables::Vector{VariableNode}
tolerance::Float64
end
```
#### Message Computation
```julia
function compute_message(factor::FactorNode, to::VariableNode)
    # Collect incoming messages from every neighbor except the target
    messages = [msg for (node, msg) in incoming_messages(factor) if node != to]
    # Dispatch on the concrete factor type (idiomatic Julia multiple dispatch)
    return compute_outgoing(factor, messages, to)
end

compute_outgoing(f::ProbabilisticFactor, msgs, to) = compute_probabilistic_message(f, msgs, to)
compute_outgoing(f::DeterministicFactor, msgs, to) = compute_deterministic_message(f, msgs, to)
compute_outgoing(f::ConstraintFactor, msgs, to) = compute_constraint_message(f, msgs, to)
```
### 3. [[edges|Edges]]
#### Properties
```julia
struct Edge
source::Union{VariableNode, FactorNode}
target::Union{VariableNode, FactorNode}
message_type::Type{<:Message}
parameters::Dict{Symbol, Any}
end
```
#### Message Types
- Forward messages (variable to factor)
- Backward messages (factor to variable)
- Parameter messages
- Constraint messages
## Message Passing and Inference
### 1. [[belief_propagation|Belief Propagation]]
#### Forward Messages
```math
μ_{x→f}(x) = \prod_{g \in N(x) \backslash f} μ_{g→x}(x)
```
#### Backward Messages
```math
μ_{f→x}(x) = \sum_{x_{\partial f \backslash x}} f(x_{\partial f}) \prod_{y \in \partial f \backslash x} μ_{y→f}(y)
```
#### Implementation
```julia
function belief_propagation!(graph::FactorGraph; max_iters=100)
for iter in 1:max_iters
# Update variable to factor messages
for var in graph.variables
for factor in var.neighbors
message = compute_var_to_factor_message(var, factor)
update_message!(var, factor, message)
end
end
# Update factor to variable messages
for factor in graph.factors
for var in factor.variables
message = compute_factor_to_var_message(factor, var)
update_message!(factor, var, message)
end
end
# Check convergence
if check_convergence(graph)
break
end
end
end
```
### 2. [[variational_message_passing|Variational Message Passing]]
#### ELBO Optimization
```math
\mathcal{L}(q) = \mathbb{E}_q[\log p(x,z)] - \mathbb{E}_q[\log q(z)]
```
#### Natural Gradient Updates
```math
\theta_t = \theta_{t-1} + \eta\,\nabla_{\text{nat}}\mathcal{L}(q)
```
#### Implementation
```julia
function variational_message_passing!(graph::FactorGraph;
learning_rate=0.01,
max_iters=100)
for iter in 1:max_iters
# Compute natural gradients
gradients = compute_natural_gradients(graph)
# Update variational parameters
for (node, grad) in gradients
update_parameters!(node, grad, learning_rate)
end
# Update messages
update_messages!(graph)
# Check ELBO convergence
if check_elbo_convergence(graph)
break
end
end
end
```
### 3. [[expectation_propagation|Expectation Propagation]]
#### Moment Matching
```math
q_i^{\text{new}} = \arg\min_{q_i} \text{KL}\!\left( \tfrac{1}{Z_i}\, f_i(x)\, q^{\backslash i}(x) \;\middle\|\; q_i(x)\, q^{\backslash i}(x) \right)
```
where q^{∖i} = ∏_{j≠i} q_j is the cavity distribution obtained by removing the approximate factor q_i.
#### Implementation
```julia
function expectation_propagation!(graph::FactorGraph; max_iters=100)
for iter in 1:max_iters
# Update approximate factors
for factor in graph.factors
# Compute cavity distribution
cavity = compute_cavity_distribution(factor)
# Moment matching
new_approx = moment_match(cavity, factor)
# Update approximation
update_approximation!(factor, new_approx)
end
# Check convergence
if check_convergence(graph)
break
end
end
end
```
## Advanced Topics
### 1. [[structured_factor_graphs|Structured Factor Graphs]]
#### Temporal Structure
```julia
@model function temporal_factor_graph(T)
# State variables
x = Vector{VariableNode}(undef, T)
# Temporal factors
for t in 2:T
@factor f[t] begin
x[t] ~ transition(x[t-1])
end
end
return x
end
```
#### Hierarchical Structure
```julia
@model function hierarchical_factor_graph()
# Global parameters
θ_global ~ prior()
# Local parameters
θ_local = Vector{VariableNode}(undef, N)
for i in 1:N
θ_local[i] ~ conditional(θ_global)
end
return θ_local
end
```
### 2. [[continuous_variables|Continuous Variables]]
#### Gaussian Messages
```julia
struct GaussianMessage <: Message
mean::Vector{Float64}
precision::Matrix{Float64}
function GaussianMessage(μ, Λ)
@assert size(μ, 1) == size(Λ, 1) == size(Λ, 2)
new(μ, Λ)
end
end
function multiply_messages(m1::GaussianMessage, m2::GaussianMessage)
Λ = m1.precision + m2.precision
μ = Λ \ (m1.precision * m1.mean + m2.precision * m2.mean)
return GaussianMessage(μ, Λ)
end
```
### 3. [[convergence_properties|Convergence Properties]]
#### Fixed Point Conditions
```math
b^*(x) = \frac{1}{Z} \prod_{f \in N(x)} μ^*_{f→x}(x)
```
#### Bethe Free Energy
```math
F_{\text{Bethe}} = \sum_{a} \sum_{x_{\partial a}} b_a(x_{\partial a}) \log \frac{b_a(x_{\partial a})}{f_a(x_{\partial a})} - \sum_{i} (d_i - 1) \sum_{x_i} b_i(x_i) \log b_i(x_i)
```
where d_i is the number of factors neighboring variable i. Stationary points of the Bethe free energy correspond to fixed points of loopy belief propagation.
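On tree-structured graphs these fixed points are unique and the resulting beliefs equal the exact marginals. A brute-force check on a hypothetical three-variable binary chain x₁ — f_a — x₂ — f_b — x₃:

```julia
# Hypothetical factor tables for the chain
f_a = [1.0 2.0; 3.0 4.0]   # f_a(x1, x2)
f_b = [2.0 1.0; 1.0 3.0]   # f_b(x2, x3)

# Exact marginal of x2 by enumerating the joint
joint = [f_a[x1, x2] * f_b[x2, x3] for x1 in 1:2, x2 in 1:2, x3 in 1:2]
exact = vec(sum(joint, dims=(1, 3))); exact ./= sum(exact)

# Sum-product messages into x2 (leaf-to-root, uniform messages at the leaves)
μ_a = vec(sum(f_a, dims=1))   # μ_{f_a → x2}(x2): sum out x1
μ_b = vec(sum(f_b, dims=2))   # μ_{f_b → x2}(x2): sum out x3
belief = μ_a .* μ_b; belief ./= sum(belief)
```

Here `belief` and `exact` agree (both [1/3, 2/3]), illustrating exactness of the sum-product fixed point on trees.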
## Implementation
### 1. Graph Construction
```julia
struct FactorGraph
variables::Set{VariableNode}
factors::Set{FactorNode}
edges::Set{Edge}
function FactorGraph()
new(Set{VariableNode}(), Set{FactorNode}(), Set{Edge}())
end
end
function add_variable!(graph::FactorGraph, var::VariableNode)
push!(graph.variables, var)
end
function add_factor!(graph::FactorGraph, factor::FactorNode)
push!(graph.factors, factor)
for var in factor.variables
add_edge!(graph, Edge(var, factor))
end
end
```
### 2. Message Scheduling
```julia
struct MessageSchedule
    order::Vector{Tuple{Union{VariableNode,FactorNode},
                        Union{VariableNode,FactorNode}}}
    priorities::Vector{Float64}
    # Zero-argument constructor for building schedules incrementally
    MessageSchedule() = new([], Float64[])
end
function schedule_messages(graph::FactorGraph)
schedule = MessageSchedule()
# Forward pass
for level in graph.levels
for node in level
schedule_forward_messages!(schedule, node)
end
end
# Backward pass
for level in reverse(graph.levels)
for node in level
schedule_backward_messages!(schedule, node)
end
end
return schedule
end
```
### 3. Inference Execution
```julia
function run_inference(graph::FactorGraph;
method=:belief_propagation,
max_iters=100)
if method == :belief_propagation
belief_propagation!(graph, max_iters=max_iters)
elseif method == :variational
variational_message_passing!(graph, max_iters=max_iters)
elseif method == :expectation_propagation
expectation_propagation!(graph, max_iters=max_iters)
else
error("Unknown inference method: $method")
end
return compute_beliefs(graph)
end
```
## Applications
### 1. [[bayesian_inference|Bayesian Inference]]
- Parameter estimation
- Model selection
- Uncertainty quantification
### 2. [[probabilistic_programming|Probabilistic Programming]]
- Model specification
- Automatic inference
- Compositional modeling
### 3. [[active_inference|Active Inference]]
- Policy selection
- Perception-action loops
- Free energy minimization
## Best Practices
### 1. Design Patterns
- Modular factor construction
- Reusable message computations
- Efficient graph structures
### 2. Numerical Considerations
- Message normalization
- Numerical stability
- Convergence monitoring
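In particular, messages are commonly stored in the log domain and normalized via log-sum-exp to avoid underflow; a minimal sketch (hypothetical `lognormalize` helper):

```julia
# Normalize a log-domain message without leaving the log domain
function lognormalize(logμ::Vector{Float64})
    m = maximum(logμ)
    # log-sum-exp: log Σ exp(logμ), computed stably by factoring out the max
    logZ = m + log(sum(exp.(logμ .- m)))
    return logμ .- logZ
end

logμ = [-1000.0, -1001.0, -1002.0]   # exp(-1000) would underflow to 0.0
normalized = lognormalize(logμ)
```

After normalization the message exponentiates to a proper distribution even though the raw values were far below floating-point range.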
### 3. Testing and Validation
- Unit tests for factors
- Message validity checks
- End-to-end inference tests
## Integration with Bayesian Networks
### 1. [[bayesian_network_conversion|Bayesian Network Conversion]]
```julia
function from_bayesian_network(bn::BayesianNetwork)
# Create factor graph
fg = FactorGraph()
# Add variables
for node in bn.nodes
add_variable!(fg, VariableNode(node))
end
# Add CPT factors
for (var, cpt) in bn.parameters
add_cpt_factor!(fg, var, cpt)
end
return fg
end
```
### 2. [[inference_equivalence|Inference Equivalence]]
- Message passing in trees
- Loopy belief propagation
- Variational approximations
### 3. [[model_comparison|Model Comparison]]
- Structure comparison
- Parameter learning
- Performance metrics
## RxInfer Integration
### 1. Reactive Message Passing
```julia
using RxInfer

@model function rxinfer_factor_graph(y)
    # Priors on the mean and precision
    θ ~ NormalMeanPrecision(0.0, 1.0)
    τ ~ GammaShapeRate(1.0, 1.0)
    # Likelihood factors
    for i in eachindex(y)
        y[i] ~ NormalMeanPrecision(θ, τ)
    end
end

# Execute reactive message-passing inference
results = infer(
    model = rxinfer_factor_graph(),
    data  = (y = observed_data,)
)
```
### 2. Streaming Factor Graphs
```julia
struct StreamingFactorGraph
base_graph::FactorGraph
stream_nodes::Vector{StreamNode}
buffer_size::Int
function StreamingFactorGraph(model, buffer_size=1000)
graph = create_base_graph(model)
stream_nodes = initialize_stream_nodes(model)
new(graph, stream_nodes, buffer_size)
end
end
function process_stream!(graph::StreamingFactorGraph, data_stream)
for batch in data_stream
# Update streaming nodes
update_stream_nodes!(graph, batch)
# Perform message passing
propagate_messages!(graph)
# Update beliefs
update_beliefs!(graph)
# Prune old messages if needed
prune_old_messages!(graph)
end
end
```
### 3. Reactive Message Types
```julia
abstract type ReactiveMessage end
struct GaussianReactiveMessage <: ReactiveMessage
mean::Observable{Float64}
precision::Observable{Float64}
function GaussianReactiveMessage(μ::Observable, τ::Observable)
new(μ, τ)
end
end
struct StreamingMessage <: ReactiveMessage
distribution::Observable{Distribution}
timestamp::Observable{Float64}
function StreamingMessage(dist::Observable, ts::Observable)
new(dist, ts)
end
end
```
## Advanced Bayesian Integration
### 1. Hierarchical Bayesian Models
```mermaid
graph TD
subgraph "Hierarchical Factor Structure"
A[Global Parameters] --> B[Group Parameters]
B --> C[Individual Parameters]
C --> D[Observations]
E[Hyperpriors] --> A
end
style A fill:#f9f,stroke:#333
style C fill:#bbf,stroke:#333
style E fill:#bfb,stroke:#333
```
### 2. Conjugate Factor Pairs
```julia
struct ConjugateFactor{T<:Distribution} <: FactorNode
prior::T
likelihood::Function
posterior::T
sufficient_stats::Dict{Symbol, Any}
function ConjugateFactor(prior::T, likelihood::Function) where T
posterior = deepcopy(prior)
stats = initialize_sufficient_stats(prior)
new{T}(prior, likelihood, posterior, stats)
end
end
function update_conjugate_factor!(factor::ConjugateFactor, data)
# Update sufficient statistics
update_stats!(factor.sufficient_stats, data)
# Compute posterior parameters
posterior_params = compute_posterior_params(
factor.prior, factor.sufficient_stats)
# Update posterior
update_posterior!(factor.posterior, posterior_params)
end
```
### 3. Non-parametric Extensions
```julia
struct DirichletProcessFactor <: FactorNode
base_measure::Distribution
concentration::Float64
clusters::Vector{Cluster}
assignments::Dict{Int, Int}
function DirichletProcessFactor(base::Distribution, α::Float64)
new(base, α, Cluster[], Dict())
end
end
function sample_assignment(factor::DirichletProcessFactor, data_point)
# Compute cluster probabilities
probs = compute_cluster_probabilities(factor, data_point)
# Sample new assignment
assignment = sample_categorical(probs)
# Update clusters if necessary
if assignment > length(factor.clusters)
create_new_cluster!(factor, data_point)
end
return assignment
end
```
## Advanced Message Passing Schemes
### 1. Stochastic Message Passing
```julia
struct StochasticMessagePassing
n_particles::Int
resampling_threshold::Float64
function StochasticMessagePassing(n_particles=1000, threshold=0.5)
new(n_particles, threshold)
end
end
function propagate_particles!(smp::StochasticMessagePassing,
factor::FactorNode,
particles::Vector{Particle})
# Propagate particles through factor
weights = zeros(smp.n_particles)
new_particles = similar(particles)
for i in 1:smp.n_particles
# Sample
new_particles[i] = propose_particle(factor, particles[i])
# Weight
weights[i] = compute_importance_weight(
factor, particles[i], new_particles[i])
end
    # Resample when the effective sample size drops below a fraction of n_particles
    if effective_sample_size(weights) < smp.resampling_threshold * smp.n_particles
        new_particles = resample_particles(new_particles, weights)
end
return new_particles
end
```
### 2. Distributed Message Passing
```julia
struct DistributedFactorGraph
subgraphs::Vector{FactorGraph}
interfaces::Dict{Tuple{Int,Int}, Interface}
function DistributedFactorGraph(graph::FactorGraph, n_partitions)
# Partition graph
subgraphs = partition_graph(graph, n_partitions)
# Create interfaces
interfaces = create_interfaces(subgraphs)
new(subgraphs, interfaces)
end
end
function distributed_inference!(graph::DistributedFactorGraph)
# Initialize workers
workers = [Worker(subgraph) for subgraph in graph.subgraphs]
while !converged(workers)
# Local inference
@sync for worker in workers
@async local_inference!(worker)
end
# Exchange messages
exchange_interface_messages!(graph)
# Update beliefs
update_worker_beliefs!(workers)
end
end
```
### 3. Adaptive Message Scheduling
```julia
using DataStructures: PriorityQueue, enqueue!, dequeue!

struct AdaptiveScheduler
    priority_queue::PriorityQueue{Message, Float64}
    residual_threshold::Float64
    function AdaptiveScheduler(threshold=1e-6)
        # Reverse ordering so the message with the largest residual is dequeued first
        new(PriorityQueue{Message, Float64}(Base.Order.Reverse), threshold)
    end
end
function schedule_message!(scheduler::AdaptiveScheduler,
message::Message,
residual::Float64)
if residual > scheduler.residual_threshold
enqueue!(scheduler.priority_queue, message => residual)
end
end
function process_messages!(scheduler::AdaptiveScheduler)
while !isempty(scheduler.priority_queue)
# Get highest priority message
message = dequeue!(scheduler.priority_queue)
# Update message
new_message = update_message!(message)
# Compute residual and reschedule if needed
residual = compute_residual(message, new_message)
schedule_message!(scheduler, new_message, residual)
end
end
```
## Integration with Active Inference
### 1. Free Energy Minimization
```julia
struct FreeEnergyFactor <: FactorNode
internal_states::Vector{VariableNode}
external_states::Vector{VariableNode}
precision::Matrix{Float64}
function FreeEnergyFactor(internal, external, precision)
new(internal, external, precision)
end
end
function compute_free_energy(factor::FreeEnergyFactor)
# Compute prediction error
error = compute_prediction_error(
factor.internal_states,
factor.external_states
)
# Weight by precision
weighted_error = error' * factor.precision * error
# Add complexity penalty
complexity = compute_complexity_term(factor.internal_states)
return 0.5 * weighted_error + complexity
end
```
### 2. Policy Selection
```julia
struct PolicyFactor <: FactorNode
policies::Vector{Policy}
expected_free_energy::Vector{Float64}
temperature::Float64
function PolicyFactor(policies, temperature=1.0)
n_policies = length(policies)
new(policies, zeros(n_policies), temperature)
end
end
function select_policy(factor::PolicyFactor)
    # Softmax over negative expected free energy; lower temperature → greedier selection
    probs = softmax(-factor.expected_free_energy ./ factor.temperature)
# Sample policy
policy_idx = sample_categorical(probs)
return factor.policies[policy_idx]
end
```
## Performance Optimization
### 1. Message Caching
```julia
struct MessageCache
storage::Dict{Tuple{FactorNode, VariableNode}, Message}
max_size::Int
eviction_policy::Symbol
function MessageCache(max_size=10000, policy=:lru)
new(Dict(), max_size, policy)
end
end
function cache_message!(cache::MessageCache,
factor::FactorNode,
variable::VariableNode,
message::Message)
key = (factor, variable)
# Evict if needed
if length(cache.storage) >= cache.max_size
evict_message!(cache)
end
# Store message
cache.storage[key] = message
end
```
### 2. Parallel Message Updates
```julia
function parallel_message_passing!(graph::FactorGraph)
    # Group messages that can be updated independently of one another
    message_groups = group_independent_messages(graph)
    # Update each group's messages on multiple threads (run julia with --threads=N)
    for group in message_groups
        Threads.@threads for i in eachindex(group)
            update_message!(group[i])
        end
    end
end
```
### 3. GPU Acceleration
```julia
using CUDA

# Note: CUDA kernels require isbits element types; the node and message
# types below stand in for flat, GPU-compatible representations.
struct GPUFactorGraph
    variables::CuArray{VariableNode}
    factors::CuArray{FactorNode}
    messages::CuArray{Message}
function GPUFactorGraph(graph::FactorGraph)
# Transfer to GPU
variables = cu(collect(graph.variables))
factors = cu(collect(graph.factors))
messages = cu(collect_messages(graph))
new(variables, factors, messages)
end
end
function gpu_message_passing!(graph::GPUFactorGraph)
    # Launch one thread per message (kernel defined elsewhere);
    # @cuda takes a kernel call, not a do-block
    n = length(graph.messages)
    @cuda threads=256 blocks=cld(n, 256) update_messages_kernel!(graph.messages)
    CUDA.synchronize()
end
```
## References
1. Kschischang, F. R., Frey, B. J., & Loeliger, H.-A. (2001). Factor graphs and the sum-product algorithm. IEEE Transactions on Information Theory, 47(2), 498–519.
2. Wainwright, M. J., & Jordan, M. I. (2008). Graphical models, exponential families, and variational inference. Foundations and Trends in Machine Learning, 1(1–2), 1–305.
3. Loeliger, H.-A. (2004). An introduction to factor graphs. IEEE Signal Processing Magazine, 21(1), 28–41.
4. Bishop, C. M. (2006). Pattern Recognition and Machine Learning. Springer.
5. Koller, D., & Friedman, N. (2009). Probabilistic Graphical Models: Principles and Techniques. MIT Press.