Energy-based models
Energy-based models define a distribution through an energy function. THRML factorizes that energy so block Gibbs can sample it.
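As a reference point, the standard way an energy function induces a distribution is the Boltzmann form below. This convention is an assumption here: the text above does not state the normalization or any temperature parameter, so none is shown.

$$p(x) = \frac{e^{-\mathcal{E}(x)}}{Z}, \qquad Z = \sum_{x'} e^{-\mathcal{E}(x')}$$

When the energy is a sum of factor terms, $\mathcal{E}(x) = \sum_i \mathcal{E}^i(x)$, the unnormalized density becomes a product, $p(x) \propto \prod_i e^{-\mathcal{E}^i(x)}$. The conditional distribution of any one variable then depends only on the factors that involve it, which is what lets a Gibbs sampler update blocks of variables without evaluating $Z$.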
AbstractEBM
class AbstractEBM()
Something that has a well-defined energy function (map from a state to a scalar).
AbstractFactorizedEBM
class AbstractFactorizedEBM(node_shape_dtypes: Mapping[Type[AbstractNode], PyTree[jax.ShapeDtypeStruct]] = {<class 'SpinNode'>: ShapeDtypeStruct(shape=(), dtype=bool), <class 'CategoricalNode'>: ShapeDtypeStruct(shape=(), dtype=uint8)})
An EBM that is made up of Factors, i.e., an EBM with an energy function like,
$$\mathcal{E}(x) = \sum_i \mathcal{E}^i(x)$$
where the sum over $i$ is taken over factors.
Child classes must define a property that returns the list of factors making up the EBM.
Attributes:
node_shape_dtypes: the shape/dtypes of the nodes involved in this EBM. Used to generate the BlockSpec that defines the global state that factors receive to compute energy.
FactorizedEBM
class FactorizedEBM(factors: list[EBMFactor], node_shape_dtypes: Mapping[Type[AbstractNode], PyTree[jax.ShapeDtypeStruct]] = {<class 'SpinNode'>: ShapeDtypeStruct(shape=(), dtype=bool), <class 'CategoricalNode'>: ShapeDtypeStruct(shape=(), dtype=uint8)})
An EBM that is defined by a concrete list of factors.
Attributes:
_factors: the list of factors that defines this EBM.
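The relationship between the abstract and concrete factorized classes can be illustrated with a minimal pure-Python sketch. It does not use or reproduce the THRML API: the names `SketchFactorizedEBM`, `SketchListEBM`, `bias`, `coupling`, the `factors` property name, and the dict-based state are all hypothetical stand-ins chosen for illustration. It only shows the pattern described above, where a child class exposes its factors through a property and the total energy is the sum of the per-factor energies.

```python
from abc import ABC, abstractmethod
from typing import Callable, Mapping, Sequence

# Hypothetical state representation: node name -> spin value in {-1, +1}.
State = Mapping[str, int]
# Hypothetical factor: any callable mapping a global state to a scalar energy.
Factor = Callable[[State], float]


class SketchFactorizedEBM(ABC):
    """Illustrative stand-in for an EBM whose energy is a sum over factors."""

    @property
    @abstractmethod
    def factors(self) -> Sequence[Factor]:
        """Child classes return the factors that make up the EBM."""

    def energy(self, state: State) -> float:
        # E(x) = sum_i E^i(x), with the sum taken over factors.
        return sum(factor(state) for factor in self.factors)


class SketchListEBM(SketchFactorizedEBM):
    """Illustrative stand-in for an EBM defined by a concrete list of factors."""

    def __init__(self, factors: Sequence[Factor]):
        self._factors = list(factors)

    @property
    def factors(self) -> Sequence[Factor]:
        return self._factors


def bias(node: str, h: float) -> Factor:
    # Single-node factor: E(x) = -h * x_node.
    return lambda state: -h * state[node]


def coupling(a: str, b: str, j: float) -> Factor:
    # Pairwise factor: E(x) = -j * x_a * x_b.
    return lambda state: -j * state[a] * state[b]


ebm = SketchListEBM([bias("a", 0.5), coupling("a", "b", 1.0)])
total = ebm.energy({"a": 1, "b": -1})
```

Keeping the factor list behind a property means code that only needs the total energy can treat abstract and concrete EBMs the same way, while a subclass remains free to build its factors from its own parameters.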
EBMFactor
class EBMFactor(node_groups: list[Block])
A factor that defines an energy function.