THRML

Energy-based models

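Energy-based models (EBMs) define a probability distribution through an energy function that assigns a scalar to every state. THRML writes that energy as a sum of factors so that block Gibbs sampling can resample each block using only the factors that touch it.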
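To show why a factorized energy suits block Gibbs sampling, here is a minimal sketch in plain Python rather than THRML's API. It assumes a 1D Ising chain with energy $\mathcal{E}(s) = -J \sum_i s_i s_{i+1} - h \sum_i s_i$, spins in $\{-1, +1\}$, and the standard form $p(s) \propto e^{-\beta \mathcal{E}(s)}$. The chain, the even/odd blocking, and all function names are illustrative assumptions. Because no factor links two sites of the same parity, each parity class can be resampled in one step from $p(s_i = +1 \mid \text{rest}) = \sigma(2\beta f_i)$, where $f_i$ is the local field on site $i$.

```python
import math
import random

def sigmoid(x):
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)

def local_field(spins, i, J, h):
    """Coupling to chain neighbors plus external field h acting on site i."""
    left = spins[i - 1] if i > 0 else 0
    right = spins[i + 1] if i < len(spins) - 1 else 0
    return J * (left + right) + h

def block_gibbs_sweep(spins, J=1.0, h=0.0, beta=1.0, rng=random):
    """One sweep for E(s) = -J sum_i s_i s_{i+1} - h sum_i s_i.

    Even sites form one block and odd sites the other. No factor links two
    sites of the same block, so given the other block they are conditionally
    independent and can all be resampled at once.
    """
    spins = list(spins)
    for parity in (0, 1):
        block = range(parity, len(spins), 2)
        # Read every field before writing; safe because sites in a block
        # do not interact with each other.
        fields = [(i, local_field(spins, i, J, h)) for i in block]
        for i, f in fields:
            # p(s_i = +1 | rest) = sigmoid(2 * beta * f)
            spins[i] = 1 if rng.random() < sigmoid(2.0 * beta * f) else -1
    return spins

rng = random.Random(0)
state = [rng.choice([-1, 1]) for _ in range(10)]
for _ in range(100):
    state = block_gibbs_sweep(state, J=1.0, h=0.5, beta=2.0, rng=rng)
```

The two-block split is the classic checkerboard coloring of a chain. Any partition in which no factor connects two nodes of the same block would work the same way.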

AbstractEBM (class)
AbstractEBM()

An object with a well-defined energy function, i.e., a map from a state to a scalar.
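The contract above can be sketched in plain Python. This assumes the standard Boltzmann form $p(x) \propto e^{-\beta \mathcal{E}(x)}$ with inverse temperature $\beta$, which this section does not state. `ToyEBM`, `TwoSpinEBM`, `energy`, and `boltzmann` are hypothetical names, not THRML's API. Exact normalization by enumeration only works for tiny state spaces, which is why samplers such as block Gibbs are needed in practice.

```python
import abc
import itertools
import math

class ToyEBM(abc.ABC):
    """Hypothetical stand-in for an EBM: anything with energy(state) -> float."""

    @abc.abstractmethod
    def energy(self, state):
        """Map a state to a scalar energy."""

class TwoSpinEBM(ToyEBM):
    """E(s) = -J * s0 * s1 for two spins in {-1, +1}."""

    def __init__(self, J):
        self.J = J

    def energy(self, state):
        s0, s1 = state
        return -self.J * s0 * s1

def boltzmann(ebm, states, beta=1.0):
    """Exact p(x) = exp(-beta * E(x)) / Z by enumerating a small state space."""
    weights = [math.exp(-beta * ebm.energy(s)) for s in states]
    Z = sum(weights)
    return {s: w / Z for s, w in zip(states, weights)}

states = list(itertools.product([-1, 1], repeat=2))
probs = boltzmann(TwoSpinEBM(J=1.0), states)
```

With $J > 0$, the aligned states have lower energy and therefore higher probability than the anti-aligned ones.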

AbstractFactorizedEBM (class)
AbstractFactorizedEBM(node_shape_dtypes: Mapping[Type[AbstractNode], PyTree[jax.ShapeDtypeStruct]] = {<class 'SpinNode'>: ShapeDtypeStruct(shape=(), dtype=bool), <class 'CategoricalNode'>: ShapeDtypeStruct(shape=(), dtype=uint8)})

An EBM made up of factors, i.e., an EBM whose energy function has the form

$$\mathcal{E}(x) = \sum_i \mathcal{E}^i(x)$$

where the sum runs over the factors and $\mathcal{E}^i$ is the energy contributed by the $i$-th factor.
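One consequence of the sum form is that changing a single variable changes only the terms $\mathcal{E}^i$ that involve it, so a Gibbs update needs to read just that variable's factors. The plain-Python sketch below checks this on a hypothetical four-spin model. The factor list and the helpers `total_energy`, `local_energy`, and `flip_delta` are illustrative assumptions, not part of THRML.

```python
import itertools

# Hypothetical toy factors on 4 spins in {-1, +1}; each factor records the
# variables it touches and a function of those variables' values.
FACTORS = [
    ((0, 1), lambda a, b: -1.0 * a * b),   # coupling between spins 0 and 1
    ((1, 2), lambda a, b: -0.5 * a * b),   # coupling between spins 1 and 2
    ((2, 3), lambda a, b: 2.0 * a * b),    # antiferromagnetic coupling
    ((3,), lambda a: -0.3 * a),            # bias on spin 3
]

def total_energy(state, factors=FACTORS):
    """E(x) = sum_i E^i(x): add up every factor's contribution."""
    return sum(f(*(state[v] for v in vars_)) for vars_, f in factors)

def local_energy(state, site, factors=FACTORS):
    """Sum only the factors that involve `site`."""
    return sum(f(*(state[v] for v in vars_))
               for vars_, f in factors if site in vars_)

def flip_delta(state, site, factors=FACTORS):
    """Energy change from flipping `site`, using only its local factors."""
    flipped = list(state)
    flipped[site] = -flipped[site]
    return local_energy(flipped, site, factors) - local_energy(state, site, factors)

# Compare every local delta against a full recomputation of the energy.
max_error = 0.0
for state in itertools.product([-1, 1], repeat=4):
    for site in range(4):
        flipped = list(state)
        flipped[site] = -flipped[site]
        full = total_energy(flipped) - total_energy(state)
        max_error = max(max_error, abs(flip_delta(state, site) - full))
```

The two computations agree up to floating-point rounding, but `flip_delta` touches at most two factors here regardless of how large the model grows.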

Child classes must define a property that returns the list of factors making up the EBM.

Attributes:

  • node_shape_dtypes: the shape and dtype of each node type in this EBM. It is used to generate the BlockSpec that defines the global state passed to the factors when they compute energy.

FactorizedEBM (class)
FactorizedEBM(factors: list[EBMFactor], node_shape_dtypes: Mapping[Type[AbstractNode], PyTree[jax.ShapeDtypeStruct]] = {<class 'SpinNode'>: ShapeDtypeStruct(shape=(), dtype=bool), <class 'CategoricalNode'>: ShapeDtypeStruct(shape=(), dtype=uint8)})

An EBM that is defined by a concrete list of factors.

Attributes:

  • _factors: the list of factors that defines this EBM.
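The relationship between the abstract class, which requires a factors property, and the concrete class, which stores a list in `_factors`, can be sketched in plain Python. The property name `factors` is an assumption (the text only says "a property"), factors are modeled as simple callables from state to float, and `ToyAbstractFactorizedEBM` and `ToyFactorizedEBM` are hypothetical analogues, not THRML's classes.

```python
import abc

class ToyAbstractFactorizedEBM(abc.ABC):
    """Hypothetical analogue of AbstractFactorizedEBM: subclasses supply
    the factors through a property, and the energy is their sum."""

    @property
    @abc.abstractmethod
    def factors(self):
        """Return the list of factors (callables state -> float)."""

    def energy(self, state):
        return sum(factor(state) for factor in self.factors)

class ToyFactorizedEBM(ToyAbstractFactorizedEBM):
    """Hypothetical analogue of FactorizedEBM: wraps a concrete list."""

    def __init__(self, factors):
        self._factors = list(factors)

    @property
    def factors(self):
        return self._factors

# Two toy factors over a state given as a tuple of spins.
def coupling(s):
    return -1.0 * s[0] * s[1]

def bias(s):
    return -0.5 * s[1]

ebm = ToyFactorizedEBM([coupling, bias])
print(ebm.energy((1, 1)))  # -1.0 + -0.5 = -1.5
```

Keeping the summation in the base class means each subclass only decides where its factors come from.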

EBMFactor (class)
EBMFactor(node_groups: list[Block])

A factor that defines one term $\mathcal{E}^i$ of the energy function, acting on the blocks of nodes listed in node_groups.
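As an illustration of a factor built from node groups, the sketch below assumes a pairwise factor whose two groups have equal length, with node $k$ of the first group coupled to node $k$ of the second. That pairing convention, the dict-based state, and the names `ToyPairwiseFactor` and `weights` are assumptions for illustration, not THRML's documented behavior.

```python
from dataclasses import dataclass

@dataclass
class ToyPairwiseFactor:
    """Hypothetical factor: node_groups holds two equal-length blocks, and
    node k of the first block is coupled to node k of the second."""
    node_groups: list  # [block_a, block_b], each a tuple of node ids
    weights: list      # one coupling weight per node pair

    def __post_init__(self):
        a, b = self.node_groups
        if not (len(a) == len(b) == len(self.weights)):
            raise ValueError("blocks and weights must have equal length")

    def energy(self, state):
        """E = -sum_k w_k * x[a_k] * x[b_k] for spins x in {-1, +1}."""
        a, b = self.node_groups
        return -sum(w * state[i] * state[j]
                    for w, i, j in zip(self.weights, a, b))

factor = ToyPairwiseFactor(node_groups=[("s0", "s1"), ("s2", "s3")],
                           weights=[1.0, 2.0])
state = {"s0": 1, "s1": -1, "s2": 1, "s3": 1}
print(factor.energy(state))  # -(1*1*1 + 2*(-1)*1) = 1.0
```

Grouping nodes into blocks lets one factor describe many interactions of the same kind at once, which maps naturally onto vectorized evaluation.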