Energy-Based Models
This module contains implementations of energy-based models.
thrml.models.AbstractEBM
An object with a well-defined energy function, i.e., a map from a state to a scalar.
energy(state: list[PyTree[Shaped[Array, 'nodes ?*state'], '_State']], blocks: list[thrml.Block]) -> Float[Array, '']
Evaluate the energy function of the EBM given some state information.
Arguments:
state: The state for which to evaluate the energy function. Must be compatible with blocks.
blocks: Specifies how the information in state is organized.
Returns:
A scalar representing the energy value associated with state.
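The contract above can be illustrated with a minimal pure-Python sketch that does not use thrml. `ToyEBM` and `SpinChain` are hypothetical stand-ins, spins are assumed to take the values ±1 (stored as plain ints), and "compatible with blocks" is read here, as an assumption, to mean that `state[i]` holds exactly one value per node of `blocks[i]`.

```python
from abc import ABC, abstractmethod


class ToyEBM(ABC):
    """Hypothetical stand-in for AbstractEBM: maps a state to a scalar energy."""

    @abstractmethod
    def energy(self, state: list[list[int]], blocks: list[list[str]]) -> float:
        """Return the energy of `state`, whose i-th entry holds the values
        of the nodes listed in blocks[i] (an assumption of this sketch)."""


class SpinChain(ToyEBM):
    """Hypothetical 1-D Ising chain: E(s) = -J * sum_k s_k * s_{k+1}."""

    def __init__(self, nodes: list[str], coupling: float = 1.0):
        self.nodes = nodes
        self.coupling = coupling

    def energy(self, state, blocks):
        # Check that state and blocks line up: one state entry per block,
        # and one value per node inside each block.
        if len(state) != len(blocks):
            raise ValueError("state and blocks must have the same length")
        values = {}
        for block, block_state in zip(blocks, state):
            if len(block) != len(block_state):
                raise ValueError("each block's state needs one value per node")
            values.update(zip(block, block_state))
        # Reassemble the spins in chain order and sum neighbouring products.
        spins = [values[n] for n in self.nodes]
        return -self.coupling * sum(a * b for a, b in zip(spins, spins[1:]))
```

For example, splitting the chain `["a", "b", "c"]` into the blocks `["a", "c"]` and `["b"]` with `state = [[1, 1], [-1]]` gives alternating spins, and `energy` returns `2.0`.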
thrml.models.AbstractFactorizedEBM(thrml.models.AbstractEBM)
An EBM made up of factors, i.e., an EBM whose energy function has the form
\[\mathcal{E}(x) = \sum_i \mathcal{E}^i(x)\]
where the sum runs over the factors \(i\) and \(\mathcal{E}^i\) is the energy contributed by factor \(i\).
Child classes must define a property that returns the list of factors making up the EBM.
Attributes:
node_shape_dtypes: the shapes and dtypes of the nodes involved in this EBM. Used to generate the BlockSpec that defines the global state which factors receive to compute their energy.
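A pure-Python sketch of the factorized form \(\mathcal{E}(x) = \sum_i \mathcal{E}^i(x)\) follows. All names are hypothetical, including the property name `factors`, and for brevity each factor reads node values from a plain dict rather than from a global state and `BlockSpec`.

```python
from abc import ABC, abstractmethod


class ToyFactor(ABC):
    """Hypothetical factor: contributes one term E^i(x) to the total energy."""

    @abstractmethod
    def energy(self, values: dict[str, int]) -> float: ...


class ToyFactorizedEBM(ABC):
    """Hypothetical factorized EBM: E(x) = sum_i E^i(x)."""

    @property
    @abstractmethod
    def factors(self) -> list[ToyFactor]:
        """Child classes return the factors that make up the EBM."""

    def energy(self, values: dict[str, int]) -> float:
        # The total energy is the sum of the per-factor energies.
        return sum(f.energy(values) for f in self.factors)


class Bias(ToyFactor):
    """Single-node term: E = -h * x_node."""

    def __init__(self, node: str, h: float):
        self.node, self.h = node, h

    def energy(self, values):
        return -self.h * values[self.node]


class Coupling(ToyFactor):
    """Pairwise term: E = -J * x_a * x_b."""

    def __init__(self, a: str, b: str, J: float):
        self.a, self.b, self.J = a, b, J

    def energy(self, values):
        return -self.J * values[self.a] * values[self.b]


class TwoSpinModel(ToyFactorizedEBM):
    """Concrete child class: supplies its factors through the property."""

    @property
    def factors(self):
        return [Bias("a", 0.5), Coupling("a", "b", 1.0)]
```

With `{"a": 1, "b": -1}` the bias contributes `-0.5` and the coupling `1.0`, so `TwoSpinModel().energy({"a": 1, "b": -1})` returns `0.5`.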
thrml.models.FactorizedEBM(thrml.models.AbstractFactorizedEBM)
An EBM that is defined by a concrete list of factors.
Attributes:
_factors: the list of factors that defines this EBM.
__init__(factors: list[thrml.models.EBMFactor], node_shape_dtypes: Mapping[type[thrml.AbstractNode], PyTree[jax._src.core.ShapeDtypeStruct]] = {thrml.SpinNode: ShapeDtypeStruct(shape=(), dtype=bool), thrml.CategoricalNode: ShapeDtypeStruct(shape=(), dtype=uint8)})
energy(state: list[PyTree[Shaped[Array, 'nodes ?*state'], '_State']], blocks: list[thrml.Block]) -> Float[Array, '']
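As a rough illustration, a list-backed EBM might store its factors together with a node shape/dtype table, as below. This is a hypothetical pure-Python sketch, not thrml's implementation: node types are represented by name strings, and each shape/dtype entry by a `(shape, dtype_name)` tuple mirroring the defaults in the `__init__` signature (scalar `bool` spins, scalar `uint8` categories).

```python
from dataclasses import dataclass
from typing import Callable

# Hypothetical stand-in for the default node_shape_dtypes mapping:
# node type name -> (shape, dtype name).
DEFAULT_NODE_SHAPE_DTYPES = {
    "SpinNode": ((), "bool"),
    "CategoricalNode": ((), "uint8"),
}


@dataclass(frozen=True)
class LambdaFactor:
    """Hypothetical factor wrapping a plain energy function of its nodes."""

    nodes: tuple[str, ...]
    fn: Callable[..., float]

    def energy(self, values: dict[str, int]) -> float:
        return self.fn(*(values[n] for n in self.nodes))


class ListEBM:
    """Sketch of an EBM defined by a concrete list of factors."""

    def __init__(self, factors, node_shape_dtypes=None):
        self._factors = list(factors)
        # Copy the table so instances never share one mutable dict.
        table = DEFAULT_NODE_SHAPE_DTYPES if node_shape_dtypes is None else node_shape_dtypes
        self.node_shape_dtypes = dict(table)

    @property
    def factors(self):
        return self._factors

    def energy(self, values: dict[str, int]) -> float:
        # Sum the energies of the stored factors.
        return sum(f.energy(values) for f in self._factors)
```

Copying the mapping in `__init__` sidesteps a common Python pitfall: a dict used as a default argument is created once and shared by every call.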
thrml.models.EBMFactor(thrml.AbstractFactor)
A factor that defines an energy function.
energy(global_state: list[Array], block_spec: thrml.BlockSpec) -> Float[Array, '']
Evaluate the energy function of the factor.
Arguments:
global_state: The state information used to evaluate the energy function. It is a global state of block_spec.
block_spec: The BlockSpec used to generate global_state.
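The relationship between `global_state` and `block_spec` can be sketched as follows. The sketch assumes, hypothetically, that the global state is a list of flat arrays (plain lists here) and that the spec records, for each node, which array and which position hold its value. `ToyBlockSpec` and `PairCoupling` are illustrative names, not thrml classes.

```python
from dataclasses import dataclass


@dataclass
class ToyBlockSpec:
    """Hypothetical stand-in for a BlockSpec: maps each node to
    (array index, position) inside the global state."""

    locations: dict[str, tuple[int, int]]

    def lookup(self, global_state: list[list[int]], node: str) -> int:
        array_idx, pos = self.locations[node]
        return global_state[array_idx][pos]


@dataclass(frozen=True)
class PairCoupling:
    """Hypothetical pairwise factor: E = -J * x_a * x_b."""

    a: str
    b: str
    J: float

    def energy(self, global_state: list[list[int]], block_spec: ToyBlockSpec) -> float:
        # The factor only knows its node names; the spec tells it where
        # those nodes' values live in the global state.
        xa = block_spec.lookup(global_state, self.a)
        xb = block_spec.lookup(global_state, self.b)
        return -self.J * xa * xb
```

For instance, with `global_state = [[1, -1], [1]]` and a spec placing node `"b"` at position 1 of array 0, `PairCoupling("a", "b", 2.0)` evaluates to `-2.0 * 1 * (-1) = 2.0`.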