
Energy-Based Models

This module contains implementations of energy-based models.

thrml.models.AbstractEBM

An object with a well-defined energy function, i.e., a map from a state to a scalar.

energy(state: list[PyTree[Shaped[Array, 'nodes ?*state'], '_State']], blocks: list[thrml.Block]) -> Float[Array, '']

Evaluate the energy function of the EBM given some state information.

Arguments:

  • state: The state for which to evaluate the energy function. Must be compatible with blocks.
  • blocks: Specifies how the information in state is organized.

Returns:

A scalar representing the energy value associated with state.
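To make "compatible with blocks" concrete, here is a minimal numpy sketch of the pairing between state and blocks, assuming the layout suggested by the 'nodes ?*state' annotation: one array per block, whose leading axis runs over that block's nodes. Blocks are represented as plain tuples of integer node ids, a hypothetical stand-in for thrml.Block, and chain_energy is an illustrative toy Ising-chain energy, not part of thrml.

```python
import numpy as np


def chain_energy(state, blocks, coupling=1.0):
    """Toy Ising-chain energy E(s) = -J * sum_i s_i * s_{i+1}.

    Hypothetical layout (not thrml's classes): blocks[k] is a tuple of node ids,
    and state[k] holds the +/-1 spin values of exactly those nodes, in order.
    The blocks together are assumed to cover node ids 0..n-1 exactly once.
    """
    # "Compatible with blocks": one array per block, one leading entry per node.
    if len(state) != len(blocks):
        raise ValueError("need exactly one state array per block")
    for block_state, block in zip(state, blocks):
        if block_state.shape[0] != len(block):
            raise ValueError("leading axis of a state array must match its block size")

    # Scatter the block-wise values back into one vector indexed by node id.
    n_nodes = sum(len(block) for block in blocks)
    spins = np.zeros(n_nodes)
    for block_state, block in zip(state, blocks):
        spins[list(block)] = block_state

    # Reduce the whole configuration to a single scalar energy.
    return -coupling * float(np.sum(spins[:-1] * spins[1:]))
```

For example, splitting a four-node chain into even and odd blocks, `chain_energy([np.array([1.0, 1.0]), np.array([1.0, -1.0])], [(0, 2), (1, 3)])` reassembles the spins (+1, +1, +1, -1) and returns -1.0.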

thrml.models.AbstractFactorizedEBM(thrml.models.AbstractEBM)

An EBM composed of factors, i.e., an EBM whose energy function has the form

\[\mathcal{E}(x) = \sum_i \mathcal{E}^i(x)\]

where \(i\) indexes the factors and \(\mathcal{E}^i\) is the energy function of the \(i\)-th factor.
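As an illustration (an example chosen here, not taken from the thrml source), the familiar Ising energy over spins \(s_j \in \{-1, +1\}\) on a graph with edge set \(E\) already has this factorized form:

\[\mathcal{E}(s) = -\sum_{(j,k) \in E} J_{jk}\, s_j s_k \;-\; \sum_j h_j\, s_j = \sum_{(j,k) \in E} \mathcal{E}^{(j,k)}(s) + \sum_j \mathcal{E}^{j}(s)\]

with one pairwise factor \(\mathcal{E}^{(j,k)}(s) = -J_{jk}\, s_j s_k\) per edge and one unary factor \(\mathcal{E}^{j}(s) = -h_j\, s_j\) per node, so the index \(i\) in the general formula runs over both edges and nodes. Each factor reads only one or two variables, which is what makes the decomposition useful.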

Subclasses must define a property that returns the list of factors making up the EBM.

Attributes:

  • node_shape_dtypes: The shapes and dtypes of the nodes in this EBM, keyed by node type. Used to generate the BlockSpec that defines the global state the factors receive when computing their energies.
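The contract above (subclasses supply the factors, the base class sums them) can be sketched in plain Python. This is a hypothetical illustration, not thrml's implementation: ToyFactorizedEBM, ToyIsing, and the property name `factors` are assumptions, factors are modelled as plain callables rather than EBMFactor objects, and the state is a flat +/-1 numpy vector.

```python
from abc import ABC, abstractmethod

import numpy as np


class ToyFactorizedEBM(ABC):
    """Hypothetical base class (not thrml's): energy = sum of factor energies."""

    @property
    @abstractmethod
    def factors(self):
        """Return a list of callables, each mapping a state x to a scalar E^i(x)."""

    def energy(self, x):
        # E(x) = sum_i E^i(x), over whatever factors the subclass provides.
        return float(sum(factor(x) for factor in self.factors))


class ToyIsing(ToyFactorizedEBM):
    """Ising model: one factor per nonzero coupling J[j, k] and one per field h[j]."""

    def __init__(self, J, h):
        self.J = np.asarray(J, dtype=float)
        self.h = np.asarray(h, dtype=float)

    @property
    def factors(self):
        n = len(self.h)
        pairwise = [
            # Default arguments bind j and k now, not when the lambda is called.
            lambda s, j=j, k=k: -self.J[j, k] * s[j] * s[k]
            for j in range(n)
            for k in range(j + 1, n)
            if self.J[j, k] != 0.0
        ]
        unary = [lambda s, j=j: -self.h[j] * s[j] for j in range(n)]
        return pairwise + unary
```

Because `factors` is declared abstract, `ToyFactorizedEBM()` cannot be instantiated directly; a concrete subclass such as `ToyIsing([[0.0, 1.0], [1.0, 0.0]], [0.5, 0.0])` has three factors and assigns energy 0.5 to the state `np.array([1.0, -1.0])`.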

thrml.models.FactorizedEBM(thrml.models.AbstractFactorizedEBM)

An EBM that is defined by a concrete list of factors.

Attributes:

  • _factors: The list of factors that defines this EBM.

__init__(factors: list[thrml.models.EBMFactor], node_shape_dtypes: Mapping[type[thrml.AbstractNode], PyTree[jax._src.core.ShapeDtypeStruct]] = {thrml.SpinNode: ShapeDtypeStruct(shape=(), dtype=bool), thrml.CategoricalNode: ShapeDtypeStruct(shape=(), dtype=uint8)})

energy(state: list[PyTree[Shaped[Array, 'nodes ?*state'], '_State']], blocks: list[thrml.Block]) -> Float[Array, '']

Evaluate the energy of the EBM as the sum of the energies of its factors.
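The default node_shape_dtypes gives spin nodes a scalar bool value and categorical nodes a scalar uint8 value. The numpy sketch below shows what that implies for a block's state array, under our reading of the 'nodes ?*state' annotation (a leading axis over nodes, followed by the per-node shape). The string keys and (shape, dtype) tuples are hypothetical stand-ins for the node classes and ShapeDtypeStruct, and empty_block_state is an illustrative helper, not a thrml function.

```python
import numpy as np

# Hypothetical mirror of the default mapping: string keys stand in for
# thrml.SpinNode / thrml.CategoricalNode, and (shape, dtype) tuples stand in
# for ShapeDtypeStruct(shape=..., dtype=...).
DEFAULT_NODE_SHAPE_DTYPES = {
    "SpinNode": ((), np.bool_),
    "CategoricalNode": ((), np.uint8),
}


def empty_block_state(node_kind, n_nodes, node_shape_dtypes=DEFAULT_NODE_SHAPE_DTYPES):
    """Zero-filled state array for a block of n_nodes nodes of one kind.

    The leading axis runs over nodes; any remaining axes are the per-node
    shape, which is empty for the scalar default node types.
    """
    shape, dtype = node_shape_dtypes[node_kind]
    return np.zeros((n_nodes, *shape), dtype=dtype)
```

With the defaults, a block of five spin nodes is a length-5 bool vector and a block of three categorical nodes is a length-3 uint8 vector; a node type with a per-node shape of `(2,)` would instead give an array of shape `(n_nodes, 2)`.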

thrml.models.EBMFactor(thrml.AbstractFactor)

A factor that defines an energy function.

energy(global_state: list[Array], block_spec: thrml.BlockSpec) -> Float[Array, '']

Evaluate the energy function of the factor.

Arguments:

  • global_state: The state information used to evaluate the energy function; this is the global state generated from block_spec.
  • block_spec: The BlockSpec used to generate global_state.
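A factor's energy method reduces the global state to a scalar. The sketch below is a hypothetical pairwise spin factor written against that shape of interface, with several loudly labelled assumptions: global_state[0] is taken to be one flat bool array holding every spin node; the factor is handed its endpoint positions in that array directly (in thrml such positions would presumably be derived via the BlockSpec, so block_spec is accepted but unused here); and bool spins are mapped to +/-1 with True as +1. ToyPairwiseSpinFactor is not a thrml class.

```python
import numpy as np


class ToyPairwiseSpinFactor:
    """Hypothetical factor with energy E = -sum_e w_e * s[a_e] * s[b_e]."""

    def __init__(self, idx_a, idx_b, weights):
        # Positions of each edge's endpoints in the (assumed) flat spin array.
        self.idx_a = np.asarray(idx_a, dtype=int)
        self.idx_b = np.asarray(idx_b, dtype=int)
        self.weights = np.asarray(weights, dtype=float)

    def energy(self, global_state, block_spec=None):
        # block_spec is unused: this toy receives its positions in __init__.
        spins_bool = global_state[0]
        # Assumed encoding: True -> +1.0, False -> -1.0.
        s = 2.0 * spins_bool.astype(float) - 1.0
        return float(-np.sum(self.weights * s[self.idx_a] * s[self.idx_b]))
```

For instance, a factor over edges (0, 1) and (1, 2) with weights 1.0 and 0.5, evaluated on the spins `[True, True, False]`, sees the values (+1, +1, -1) and returns -0.5.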