Discrete Energy-Based Models
This module contains implementations of discrete energy-based models.
thrml.models.DiscreteEBMFactor(thrml.models.EBMFactor, thrml.WeightedFactor)
Implements batches of energy function terms of the form s_1 * ... * s_M * W[c_1, ..., c_N], where the s_i are spin variables and the c_i are categorical variables.
No variable should appear more than once in any given interaction. If one does, samples from a model that includes the offending factor may not follow the Boltzmann distribution. For example, the interaction w * s_1 * s_1 * s_2 violates this rule because s_1 appears twice. This condition is deliberately not enforced in the code, so that unusual constructions remain possible if you really want them.
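A minimal plain-Python sketch of why the repeated-variable rule matters. Because a spin satisfies s * s = 1, the term w * s_1 * s_1 * s_2 equals w * s_2, so the true conditional of s_1 is flat. The sketch assumes p(s) ∝ exp(-E(s)) and a hypothetical naive update that treats the second copy of s_1 as an ordinary neighbor frozen at its current value. This is one illustrative way a sampler can go wrong, not a description of thrml's internals; energy, exact_p_up and naive_p_up are hypothetical names.

```python
import math


def energy(s1, s2, w):
    # The "bad" factor: s_1 appears twice, and s1 * s1 == 1 for spins.
    return w * s1 * s1 * s2


def exact_p_up(s2, w):
    # Exact conditional P(s1 = +1 | s2) from p(s) ∝ exp(-E(s)).
    up = math.exp(-energy(+1, s2, w))
    down = math.exp(-energy(-1, s2, w))
    return up / (up + down)


def naive_p_up(s1_current, s2, w):
    # Hypothetical naive Gibbs update: the second copy of s_1 is treated as a
    # fixed neighbor, so the head s_1 sees the local field w * s1_current * s2.
    field = w * s1_current * s2
    up = math.exp(-field * (+1))
    down = math.exp(-field * (-1))
    return up / (up + down)


w, s2 = 1.5, +1
print(exact_p_up(s2, w))      # 0.5: the factor does not depend on s_1 at all
print(naive_p_up(+1, s2, w))  # about 0.047: the naive update is strongly biased
```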
Attributes:
- spin_node_groups: the node groups involved in the batch of factors that represent spin-valued random variables.
- categorical_node_groups: the node groups involved in the batch of factors that represent categorical-valued random variables.
- weights: the batch of weight tensors W associated with the factors being implemented. This tensor has shape [b, x_1, ..., x_N], where b is the number of nodes in each element of spin_node_groups and categorical_node_groups (that is, the number of factors in the batch), x_i is the number of categories of the i-th categorical variable, and N is the length of categorical_node_groups.
- is_spin: a map that indicates whether a given node type represents a spin-valued random variable.
energy(global_state: list[Array], block_spec: thrml.BlockSpec)
Compute the energy associated with this factor.
In this case, that is the sum of terms like s_1 * ... * s_M * W[c_1, ..., c_N].
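The sum described above can be written out directly. This is a plain-Python sketch that follows the formula exactly as stated in this section; it assumes spins are stored as ±1 values, categorical states as integer indices, and the weight tensor as a nested list of shape [b, x_1, ..., x_N]. Whether the library applies an overall sign or a scale factor is not specified here, so none is applied. factor_energy and index_nested are hypothetical names.

```python
import math


def index_nested(tensor, idx):
    # Walk a nested list with a tuple of indices, e.g. W[k][c1][c2].
    for i in idx:
        tensor = tensor[i]
    return tensor


def factor_energy(spin_states, cat_states, weights):
    """Sum over the batch of s_1 * ... * s_M * W[c_1, ..., c_N].

    spin_states: M lists of length b with values in {-1, +1}
    cat_states:  N lists of length b with integer category indices
    weights:     nested list with shape [b, x_1, ..., x_N]
    """
    total = 0.0
    for k in range(len(weights)):
        spin_product = math.prod(group[k] for group in spin_states)  # 1 if M == 0
        cats = tuple(group[k] for group in cat_states)
        total += spin_product * index_nested(weights[k], cats)
    return total


# b = 2 factors, one spin group and one categorical group with 3 categories
print(factor_energy([[+1, -1]], [[2, 0]], [[0.5, 1.0, 1.5], [2.0, 2.5, 3.0]]))  # -0.5
```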
to_interaction_groups() -> list[thrml.InteractionGroup]
Produce interaction groups that implement this factor.
In this case, we have to treat the spin and categorical node groups slightly differently.
thrml.models.DiscreteEBMInteraction
An interaction that shows up when sampling from discrete-variable EBMs.
Attributes:
- n_spin: the number of spin states involved in the interaction.
- weights: the weight tensor associated with this interaction.
thrml.models.SquareDiscreteEBMFactor(thrml.models.DiscreteEBMFactor)
A discrete factor with a square interaction weight tensor (shape [b, x, x, ..., x]).
If a discrete factor is square, the interaction groups corresponding to different choices of the head node blocks can be merged. This could yield smaller XLA programs and improved runtime performance via more efficient use of accelerators.
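One way to see why squareness matters, sketched in plain Python under an assumption: suppose the interaction group for a categorical head is built by moving that head's axis of W to sit right after the batch axis, while spin heads leave W unpermuted. For a square tensor every such permutation has the same shape, so the per-head tensors could be stacked into a single group; for a non-square tensor they cannot. head_shapes and can_merge are hypothetical helpers that only compare shapes.

```python
def head_shapes(shape):
    """Shapes obtained by moving each categorical axis to position 1.

    shape is [b, x_1, ..., x_N]; the batch axis b stays first.
    """
    b, *cats = shape
    shapes = []
    for h in range(len(cats)):
        rest = cats[:h] + cats[h + 1:]
        shapes.append((b, cats[h], *rest))
    return shapes


def can_merge(shape):
    # Groups can only be stacked if every choice of head yields the same shape.
    return len(set(head_shapes(shape))) <= 1


print(can_merge([100, 4, 4, 4]))  # True: square tensor
print(can_merge([100, 2, 3]))     # False: (100, 2, 3) vs (100, 3, 2)
```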
to_interaction_groups() -> list[thrml.InteractionGroup]
Call the parent class's to_interaction_groups method and merge the resulting interaction groups.
thrml.models.SpinEBMFactor(thrml.models.SquareDiscreteEBMFactor)
A DiscreteEBMFactor that involves only spin variables.
thrml.models.CategoricalEBMFactor(thrml.models.DiscreteEBMFactor)
A DiscreteEBMFactor that involves only categorical variables.
thrml.models.SquareCategoricalEBMFactor(thrml.models.SquareDiscreteEBMFactor)
A DiscreteEBMFactor that involves only categorical variables and has a square weight tensor.
thrml.models.SpinGibbsConditional(thrml.BernoulliConditional)
A conditional update for spin-valued random variables that will perform a Gibbs sampling update given one or
more DiscreteEBMInteractions.
This class can be extended to handle a broader class of interactions via inheritance. Specifically, a
child class can override the compute_parameters method defined here, compute the contributions to \(\gamma\)
from other types of interactions, and then call this method to add the contributions from
DiscreteEBMInteractions.
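The extension pattern described above, sketched with hypothetical stand-in classes in plain Python rather than the real thrml classes. The real compute_parameters also receives a key, active flags, states, a sampler state and output shapes; here each interaction is collapsed to a dict holding its precomputed contribution to \(\gamma\), so only the override-then-delegate structure is shown.

```python
class DiscreteOnlyConditional:
    """Hypothetical stand-in for SpinGibbsConditional.

    It only understands interactions of kind "discrete" and ignores the rest.
    """

    def compute_parameters(self, interactions):
        return sum(i["gamma"] for i in interactions if i["kind"] == "discrete")


class WithExternalField(DiscreteOnlyConditional):
    """Hypothetical child class that also handles a "field" interaction."""

    def compute_parameters(self, interactions):
        # 1) Contributions from the new interaction type.
        extra = sum(i["h"] for i in interactions if i["kind"] == "field")
        # 2) Delegate to the parent for the discrete-interaction contributions.
        return extra + super().compute_parameters(interactions)


interactions = [
    {"kind": "discrete", "gamma": 0.5},
    {"kind": "field", "h": 0.25},
    {"kind": "discrete", "gamma": -1.0},
]
print(WithExternalField().compute_parameters(interactions))  # -0.25
```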
compute_parameters(key: Key, interactions: list[PyTree], active_flags: list[Array], states: list[list[PyTree[Shaped[Array, 'nodes ?*state'], '_State']]], sampler_state: None, output_sd: PyTree[jax._src.core.ShapeDtypeStruct]) -> PyTree
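Compute the parameter \(\gamma\) of a spin-valued Bernoulli distribution given DiscreteEBMInteractions:
\[\gamma = \sum_i s^{(i)}_1 \cdots s^{(i)}_{M_i} \, W^{(i)}[c^{(i)}_1, \ldots, c^{(i)}_{N_i}]\]
where the sum over \(i\) is over all the DiscreteEBMInteractions seen by this function, and the \(s^{(i)}_j\) and \(c^{(i)}_j\) are the current states of the spin and categorical variables in interaction \(i\) other than the node being updated.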
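A plain-Python sketch of the formula above for a single node, followed by one Gibbs draw. It assumes each interaction has already been reduced to the tail spin values and the weight entry of \(W^{(i)}\) at the current tail categories, and it assumes the spin-valued Bernoulli convention \(P(s) \propto e^{\gamma s}\), so that \(P(s = +1) = (1 + \tanh\gamma)/2\). spin_gamma and sample_spin are hypothetical names.

```python
import math
import random


def spin_gamma(tail_terms):
    """gamma = sum_i s_1^(i) ... s_M^(i) * W^(i)[c_1^(i), ..., c_N^(i)] for one node.

    tail_terms: one (tail_spins, w) pair per interaction, where tail_spins are the
    current +/-1 values of the other spins and w is the weight entry already
    indexed at the current tail categories.
    """
    return sum(math.prod(spins) * w for spins, w in tail_terms)


def sample_spin(gamma, rng):
    # Assumed convention: P(s) ∝ exp(gamma * s) for s in {-1, +1},
    # hence P(s = +1) = sigmoid(2 * gamma) = (1 + tanh(gamma)) / 2.
    p_up = 0.5 * (1.0 + math.tanh(gamma))
    return 1 if rng.random() < p_up else -1


rng = random.Random(0)
gamma = spin_gamma([((+1, -1), 0.5), ((), 0.25)])  # -0.5 + 0.25
print(gamma)                    # -0.25
print(sample_spin(gamma, rng))  # +1 with probability (1 + tanh(-0.25)) / 2
```

Using tanh instead of an explicit exponential keeps the draw finite for very large \(|\gamma|\).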
thrml.models.CategoricalGibbsConditional(thrml.SoftmaxConditional)
A conditional update for categorical random variables that will perform a Gibbs sampling update given one or
more DiscreteEBMInteractions.
This class can be extended to handle other interactions in the same way as thrml.models.SpinGibbsConditional.
Attributes:
n_categories: how many categories are involved in the softmax distribution this sampler will sample from.
compute_parameters(key: Key, interactions: list[PyTree], active_flags: list[Array], states: list[list[PyTree[Shaped[Array, 'nodes ?*state'], '_State']]], sampler_state: None, output_sd: PyTree[jax._src.core.ShapeDtypeStruct]) -> PyTree
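Compute the parameter \(\theta\) of a softmax distribution given DiscreteEBMInteractions:
\[\theta = \sum_i s^{(i)}_1 \cdots s^{(i)}_{M_i} \, W^{(i)}[:, c^{(i)}_1, \ldots, c^{(i)}_{N_i}]\]
where the sum over \(i\) is over all the DiscreteEBMInteractions seen by this function, the \(s^{(i)}_j\) and \(c^{(i)}_j\) are the current states of the other variables in interaction \(i\), and the free index \(:\) runs over the n_categories categories of the node being updated, so \(\theta\) is a vector of length n_categories.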
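A plain-Python sketch of the softmax parameter for a single node, followed by one Gibbs draw. It assumes each interaction has already been indexed at the current tail categories, leaving a vector over the head's n_categories categories, and it assumes the convention \(P(x = k) \propto e^{\theta_k}\). categorical_theta and sample_category are hypothetical names.

```python
import math
import random


def categorical_theta(tail_terms, n_categories):
    """theta[k] = sum_i s_1^(i) ... s_M^(i) * W^(i)[k, c_1^(i), ..., c_N^(i)] for one node.

    tail_terms: one (tail_spins, w_vec) pair per interaction, where w_vec is the
    weight tensor already indexed at the current tail categories, leaving only
    the head axis of length n_categories.
    """
    theta = [0.0] * n_categories
    for spins, w_vec in tail_terms:
        sign = math.prod(spins)
        for k in range(n_categories):
            theta[k] += sign * w_vec[k]
    return theta


def sample_category(theta, rng):
    # Assumed convention: P(x = k) ∝ exp(theta[k]).
    # Subtracting the max before exponentiating avoids overflow.
    m = max(theta)
    unnormalized = [math.exp(t - m) for t in theta]
    return rng.choices(range(len(theta)), weights=unnormalized, k=1)[0]


rng = random.Random(0)
theta = categorical_theta([((+1,), [0.0, 1.0, 2.0]), ((-1,), [0.5, 0.5, 0.0])], 3)
print(theta)                        # [-0.5, 0.5, 2.0]
print(sample_category(theta, rng))  # category 2 is the most likely draw
```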