THRML

Discrete energy-based models

Discrete EBM building blocks for spin and categorical variables, with square-tensor specializations and their matching Gibbs conditionals.

DiscreteEBMFactor (class)
DiscreteEBMFactor(spin_node_groups: list[Block], categorical_node_groups: list[Block], weights: Array)

Implements batches of energy function terms of the form s_1 * ... * s_M * W[c_1, ..., c_N], where the s_i are spin variables and the c_i are categorical variables.

No variable should appear more than once in any given interaction. If one does, samples from a model that includes the offending factor may not follow the Boltzmann distribution. For example, the interaction w * s_1 * s_1 * s_2 violates this rule because s_1 appears twice. This condition is not enforced in the code, so such factors can still be constructed deliberately.

Attributes:

  • spin_node_groups: the node groups involved in the batch of factors that represent spin-valued random variables.
  • categorical_node_groups: the node groups involved in the batch of factors that represent categorical-valued random variables.
  • weights: the batch of weight tensors W associated with the factors. Its shape is [b, x_1, ..., x_N], where b is the number of nodes in each block of spin_node_groups and categorical_node_groups, N is the length of categorical_node_groups, and x_i is the number of categories of the i-th categorical node group.
  • is_spin: a map that indicates if a given node type represents a spin-valued random variable or not.
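
The sketch below illustrates the term s_1 * ... * s_M * W[c_1, ..., c_N] and the weights layout by evaluating one batch of terms with NumPy; jax.numpy supports the same indexing. The function name factor_terms and the argument layout (lists of per-group state arrays) are hypothetical and not part of THRML. The sketch also does not decide how these terms enter the Boltzmann weight, since the sign convention is not specified here.

```python
import numpy as np

def factor_terms(spins, cats, weights):
    """Hypothetical helper: evaluate s_1 * ... * s_M * W[c_1, ..., c_N] per factor.

    spins:   list of M arrays, each shape [b], values in {-1, +1}
    cats:    list of N integer arrays, each shape [b], values in [0, x_i)
    weights: array of shape [b, x_1, ..., x_N]
    Returns an array of shape [b], one term per factor in the batch.
    """
    b = weights.shape[0]
    # Product of spin values across the M spin node groups (all ones if M == 0).
    spin_prod = np.ones(b, dtype=weights.dtype)
    for s in spins:
        spin_prod = spin_prod * s
    # Row k selects W[k, c_1[k], ..., c_N[k]] via advanced indexing.
    w = weights[(np.arange(b), *cats)]
    return spin_prod * w

# Two spin groups and one 3-way categorical group, batch size b = 2.
terms = factor_terms([np.array([1, -1]), np.array([1, 1])],
                     [np.array([0, 2])],
                     np.arange(6.0).reshape(2, 3))

# A spin-only factor with weights of shape [b] reduces to Ising-style J * s_i * s_j.
ising = factor_terms([np.array([1, -1]), np.array([-1, -1])], [], np.array([0.5, 2.0]))
```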

DiscreteEBMInteraction (class)
DiscreteEBMInteraction(n_spin: int, weights: Array)

An interaction that shows up when sampling from discrete-variable EBMs.

Attributes:

  • n_spin: the number of spin states involved in the interaction.
  • weights: the weight tensor associated with this interaction.
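
The following sketches one plausible reading of these two attributes. It is an assumption, not THRML's actual code. Here the interaction's tail states are ordered with the n_spin spin states first and the categorical states after. The contribution to the head nodes is the product of the tail spins times the weights indexed by the categorical tail states. Any axis left unindexed (for example, one of length n_categories for a categorical head) becomes a vector of logit contributions. interaction_contribution is a hypothetical helper.

```python
import numpy as np

def interaction_contribution(n_spin, weights, tail_states):
    """Hypothetical helper: contribution of one batched interaction to its head nodes.

    ASSUMPTION: tail_states lists the n_spin spin states (values in {-1, +1})
    first, then the categorical states (integer indices); each array has shape [b].
    weights has shape [b, *head_axes, x_1, ..., x_N]; the trailing N axes are
    indexed by the categorical tail states.
    """
    weights = np.asarray(weights)
    b = weights.shape[0]
    spins, cats = tail_states[:n_spin], tail_states[n_spin:]
    # Select W[k, ..., c_1[k], ..., c_N[k]] for each factor k; head axes stay free.
    w = weights[(np.arange(b), Ellipsis, *cats)]
    # Product of tail spins, broadcast over any remaining head axes.
    spin_prod = np.ones(b)
    for s in spins:
        spin_prod = spin_prod * s
    return spin_prod.reshape((b,) + (1,) * (w.ndim - 1)) * w

# Categorical head with 2 categories, one spin tail and one 3-way categorical tail.
cat_head = interaction_contribution(
    1, np.arange(12.0).reshape(2, 2, 3), [np.array([1, -1]), np.array([2, 0])]
)
# Spin head (no head axis), one categorical tail.
spin_head = interaction_contribution(0, np.arange(6.0).reshape(2, 3), [np.array([1, 2])])
```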

SquareDiscreteEBMFactor (class)
SquareDiscreteEBMFactor(spin_node_groups: list[Block], categorical_node_groups: list[Block], weights: Array)

A discrete factor with a square interaction weight tensor (shape [b, x, x, ..., x]).

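If a discrete factor is square, the interaction groups corresponding to different choices of the head node block can be merged. Because every categorical axis has the same size x, the weight tensor seen from any head has the same shape, so these groups can be batched together. This can yield smaller XLA programs and better runtime performance through more efficient use of accelerators.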
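The merging idea can be sketched in NumPy as follows. This is an illustration, not THRML's implementation. For each choice of head, the corresponding categorical axis is moved next to the batch axis. For a square tensor every such view has the same shape [b, x, ..., x], so all views concatenate into one batch, and the conditional logits for every head come out of a single indexing operation. The helper names head_views and merge_square are hypothetical.

```python
import numpy as np

def head_views(weights):
    """For each categorical axis, return weights with that axis moved right after the batch axis."""
    n_cat = weights.ndim - 1
    return [np.moveaxis(weights, h, 1) for h in range(1, n_cat + 1)]

def merge_square(weights):
    """Concatenate all head views along the batch axis (only possible for square tensors)."""
    views = head_views(np.asarray(weights))
    if any(v.shape != views[0].shape for v in views):
        raise ValueError("weight tensor is not square; head views have different shapes")
    return np.concatenate(views, axis=0)

# Square factor with N = 2 categorical groups: W has shape [b=1, x=2, x=2].
W = np.array([[[0.0, 1.0], [2.0, 3.0]]])
c1, c2 = np.array([1]), np.array([0])       # current categorical states
merged = merge_square(W)                     # shape [2, 2, 2]: head c_1 rows, then head c_2 rows
tails = np.concatenate([c2, c1])             # the tail state each merged row is conditioned on
logits = merged[np.arange(2), :, tails]      # shape [2, x], one logit vector per head
```

In this example, row 0 holds W[c_1, c_2 = 0] as a function of c_1, and row 1 holds W[c_1 = 1, c_2] as a function of c_2. A non-square tensor such as shape [1, 2, 3] produces views of shapes [1, 2, 3] and [1, 3, 2], which cannot be concatenated.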

SpinEBMFactor (class)
SpinEBMFactor(node_groups: list[Block], weights: Array)

A DiscreteEBMFactor that involves only spin variables.

CategoricalEBMFactor (class)
CategoricalEBMFactor(node_groups: list[Block], weights: Array)

A DiscreteEBMFactor that involves only categorical variables.

SquareCategoricalEBMFactor (class)
SquareCategoricalEBMFactor(node_groups: list[Block], weights: Array)

A DiscreteEBMFactor that involves only categorical variables and has a square weight tensor.

SpinGibbsConditional (class)
SpinGibbsConditional()

A conditional update for spin-valued random variables that will perform a Gibbs sampling update given one or more DiscreteEBMInteractions.

This class can be extended to handle a broader set of interactions via inheritance. Specifically, a child class can override the compute_parameters method defined here, compute the contributions to the conditional parameter $\gamma$ from other types of interactions, and then call this method to add the contributions from DiscreteEBMInteractions.
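The inheritance pattern above can be sketched in plain Python. Everything in this sketch is hypothetical and not THRML's API: the class names, the (kind, contribution) pairs standing in for interactions, and the compute_parameters signature. It also assumes the spin conditional has the form P(s) ∝ exp(γ s) for s ∈ {-1, +1}, so that P(s = +1) = σ(2γ).

```python
import numpy as np

class ToySpinConditional:
    """Hypothetical stand-in for SpinGibbsConditional (not the THRML API)."""

    def compute_parameters(self, interactions):
        # Sum the gamma contributions from the interaction kind this class knows about.
        gamma = 0.0
        for kind, contrib in interactions:
            if kind == "discrete":
                gamma = gamma + contrib
        return gamma

    def sample(self, interactions, rng):
        gamma = np.asarray(self.compute_parameters(interactions), dtype=float)
        # ASSUMED form: P(s) ∝ exp(gamma * s), s in {-1, +1}  =>  P(s = +1) = sigmoid(2 * gamma).
        p_up = 1.0 / (1.0 + np.exp(-2.0 * gamma))
        return np.where(rng.random(gamma.shape) < p_up, 1, -1)

class ToySpinConditionalWithBias(ToySpinConditional):
    """Child class: adds contributions from an extra 'bias' kind, then defers to the parent."""

    def compute_parameters(self, interactions):
        extra = sum(c for kind, c in interactions if kind == "bias")
        return extra + super().compute_parameters(interactions)

rng = np.random.default_rng(0)
interactions = [("discrete", np.full(5, 10.0)), ("bias", np.full(5, 10.0))]
child = ToySpinConditionalWithBias()
gamma = child.compute_parameters(interactions)   # 10 from "discrete" + 10 from "bias"
spins = child.sample(interactions, rng)
```

The child class only adds the kinds it knows about and delegates the rest, so the parent's handling of discrete interactions is reused unchanged.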

CategoricalGibbsConditional (class)
CategoricalGibbsConditional(n_categories: int)

A conditional update for categorical random variables that will perform a Gibbs sampling update given one or more DiscreteEBMInteractions.

This class can be extended to handle other interactions in the same way as SpinGibbsConditional.

Attributes:

  • n_categories: the number of categories in the softmax distribution that this sampler draws from.
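
Below is a minimal sketch of drawing from the softmax distribution over n_categories. For illustration only, and not taken from THRML's code, it assumes the logits are the sum of the interaction contributions, each of shape [b, n_categories]. It uses the Gumbel-max trick: adding independent standard Gumbel noise to the logits and taking the argmax yields an exact sample from softmax(logits). The function names are hypothetical.

```python
import numpy as np

def softmax(logits):
    """Numerically stable softmax over the last axis."""
    z = logits - logits.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def sample_categorical(contributions, n_categories, rng):
    """Hypothetical Gibbs step for a batch of categorical nodes.

    contributions: list of arrays, each shape [b, n_categories] (ASSUMED layout).
    Returns integer categories of shape [b].
    """
    logits = np.sum(contributions, axis=0)  # add the contributions from all interactions
    if logits.shape[-1] != n_categories:
        raise ValueError("last axis of logits must have length n_categories")
    # Gumbel-max trick: argmax(logits + Gumbel noise) is distributed as softmax(logits).
    gumbel = rng.gumbel(size=logits.shape)
    return np.argmax(logits + gumbel, axis=-1)

rng = np.random.default_rng(0)
p = np.array([0.2, 0.3, 0.5])
draws = sample_categorical([np.tile(np.log(p), (20000, 1))], 3, rng)
freq = np.bincount(draws, minlength=3) / draws.size  # should be close to p
```

The Gumbel-max form avoids explicitly normalizing the logits. jax.random.categorical offers an equivalent draw from unnormalized logits in JAX.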