THRML

Ising models

A ready-made Ising energy-based model, its sampling program, and the utilities to initialize, train, and estimate its moments.

IsingEBM class
IsingEBM(nodes: list[AbstractNode], edges: list[tuple[AbstractNode, AbstractNode]], biases: Array, weights: Array, beta: Array)

An EBM with the energy function,

$$\mathcal{E}(s) = -\beta \left( \sum_{i \in S_1} b_i s_i + \sum_{(i, j) \in S_2} J_{ij} s_i s_j \right)$$

where $S_1$ and $S_2$ are the sets of biases and weights that make up the model, respectively. $b_i$ is the bias associated with spin $s_i$, $J_{ij}$ is the weight that couples $s_i$ and $s_j$, and $\beta$ is the inverse-temperature parameter.

Attributes:

  • nodes: the nodes that have an associated bias (i.e., $S_1$).
  • biases: the bias associated with each node in nodes.
  • edges: the edges that have an associated weight (i.e., $S_2$).
  • weights: the weight associated with each pair of nodes in edges.
  • beta: the scalar inverse-temperature parameter of the model.
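As a sanity check on the sign conventions, the energy above can be evaluated directly from these five attributes. The sketch below is plain Python, not THRML code: it assumes spins take values in $\{-1, +1\}$, uses strings in place of `AbstractNode` objects, and `ising_energy` is a hypothetical helper rather than part of the library.

```python
from collections.abc import Hashable, Mapping, Sequence


def ising_energy(
    nodes: Sequence[Hashable],
    biases: Sequence[float],
    edges: Sequence[tuple[Hashable, Hashable]],
    weights: Sequence[float],
    beta: float,
    state: Mapping[Hashable, int],
) -> float:
    """Evaluate E(s) = -beta * (sum_i b_i s_i + sum_(i,j) J_ij s_i s_j).

    `state` maps every node to a spin value in {-1, +1} (an assumption).
    """
    # First sum: one bias per node in S_1.
    bias_term = sum(b * state[node] for node, b in zip(nodes, biases))
    # Second sum: one weight per edge in S_2.
    coupling_term = sum(w * state[i] * state[j] for (i, j), w in zip(edges, weights))
    return -beta * (bias_term + coupling_term)


# Hypothetical three-spin chain; strings stand in for node objects.
nodes = ["a", "b", "c"]
biases = [0.5, -0.25, 0.0]
edges = [("a", "b"), ("b", "c")]
weights = [1.0, -0.5]
beta = 2.0
state = {"a": 1, "b": 1, "c": -1}
energy = ising_energy(nodes, biases, edges, weights, beta, state)  # -3.5
```

Because the bracketed sum is multiplied by $-\beta$, configurations that agree with positive biases and positive weights receive lower energy and therefore higher probability.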

IsingSamplingProgram class
IsingSamplingProgram(ebm: IsingEBM, free_blocks: list[tuple[Block, ...] | Block], clamped_blocks: list[Block])

A thin wrapper around FactorSamplingProgram that specializes it to Ising models.
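Block sampling of an Ising model rests on a simple conditional. Assuming spins take values in $\{-1, +1\}$, the energy terms involving $s_i$ are $-\beta s_i h_i^{\mathrm{loc}}$ with local field $h_i^{\mathrm{loc}} = b_i + \sum_j J_{ij} s_j$, so $\mathbb{P}(s_i = +1 \mid \text{rest}) = \sigma(2\beta h_i^{\mathrm{loc}})$. If no edge joins two spins in the same free block, every spin in that block can be resampled at once from this conditional. The pure-Python sketch below illustrates the idea on a two-colored chain; it is not how `IsingSamplingProgram` is implemented, and `spin_up_probability` and `block_gibbs_sweep` are hypothetical helpers. Spins left out of every free block are never updated, which mimics a clamped block.

```python
import math
import random


def spin_up_probability(i, state, biases, neighbors, beta):
    """P(s_i = +1 | all other spins) = sigmoid(2 * beta * local_field)."""
    local_field = biases[i] + sum(J * state[j] for j, J in neighbors[i])
    return 1.0 / (1.0 + math.exp(-2.0 * beta * local_field))


def block_gibbs_sweep(state, free_blocks, biases, neighbors, beta, rng):
    """Resample each free block in turn, modifying `state` in place.

    Assumes no edge joins two spins of the same block, so the spins in a
    block are conditionally independent given everything else.
    """
    for block in free_blocks:
        # Compute every conditional from the current state before writing,
        # which mirrors updating the whole block in parallel.
        probs = [spin_up_probability(i, state, biases, neighbors, beta) for i in block]
        for i, p in zip(block, probs):
            state[i] = 1 if rng.random() < p else -1
    return state


# Hypothetical 6-spin ferromagnetic chain, two-colored into even/odd blocks.
n, beta = 6, 1.0
biases = [0.0] * n
neighbors = {i: [] for i in range(n)}
for i in range(n - 1):
    neighbors[i].append((i + 1, 0.5))
    neighbors[i + 1].append((i, 0.5))
free_blocks = [list(range(0, n, 2)), list(range(1, n, 2))]

rng = random.Random(0)
state = [rng.choice([-1, 1]) for _ in range(n)]
for _ in range(100):
    block_gibbs_sweep(state, free_blocks, biases, neighbors, beta, rng)
```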

IsingTrainingSpec class
IsingTrainingSpec(ebm: IsingEBM, data_blocks: list[Block], conditioning_blocks: list[Block], positive_sampling_blocks: list[tuple[Block, ...] | Block], negative_sampling_blocks: list[tuple[Block, ...] | Block], schedule_positive: SamplingSchedule, schedule_negative: SamplingSchedule)

Contains a complete specification of an Ising EBM that can be trained using sampling-based gradients.

Defines the sampling programs and schedules used to collect the positive- and negative-phase samples required for Monte Carlo estimation of the gradient of the KL-divergence between the model and a data distribution.
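To see why both phases are needed, consider the simplest case: a fully visible model with no conditioning, where the positive phase has nothing left to sample and its moments are plain data averages. The sketch below trains a hypothetical three-spin chain by gradient descent on the KL-divergence, using the two-phase gradient stated under `estimate_kl_grad`. It does not use THRML: negative-phase expectations are computed exactly by enumerating all $2^3$ states, which stands in for the sampling that an `IsingTrainingSpec` schedules and only works for tiny models. All function names are illustrative, and spins are assumed to take values in $\{-1, +1\}$.

```python
import itertools
import math
from collections import Counter

N = 3
EDGES = [(0, 1), (1, 2)]
STATES = list(itertools.product([-1, 1], repeat=N))  # all 2**N configurations


def model_probs(b, J, beta):
    """Exact Boltzmann probabilities p(s) proportional to exp(-E(s))."""
    unnormalized = []
    for s in STATES:
        field = sum(bi * si for bi, si in zip(b, s))
        coupling = sum(Jk * s[i] * s[j] for Jk, (i, j) in zip(J, EDGES))
        unnormalized.append(math.exp(beta * (field + coupling)))  # exp(-E(s))
    Z = sum(unnormalized)
    return [w / Z for w in unnormalized]


def moments(dist):
    """First and second moments of a distribution given as (prob, state) pairs."""
    first = [sum(p * s[i] for p, s in dist) for i in range(N)]
    second = [sum(p * s[i] * s[j] for p, s in dist) for i, j in EDGES]
    return first, second


def empirical(data):
    """Empirical data distribution as (prob, state) pairs."""
    counts = Counter(map(tuple, data))
    return [(c / len(data), s) for s, c in counts.items()]


def kl_divergence(data, b, J, beta):
    q = dict(zip(STATES, model_probs(b, J, beta)))
    return sum(p * math.log(p / q[s]) for p, s in empirical(data))


def training_step(data, b, J, beta, lr):
    # Positive phase: with no latent spins, the data fixes every spin,
    # so the positive moments are just data averages.
    pos_first, pos_second = moments(empirical(data))
    # Negative phase: expectations under the model (exact here, sampled in practice).
    neg_first, neg_second = moments(list(zip(model_probs(b, J, beta), STATES)))
    grad_b = [-beta * (pp - pn) for pp, pn in zip(pos_first, neg_first)]
    grad_J = [-beta * (pp - pn) for pp, pn in zip(pos_second, neg_second)]
    # Plain gradient descent on the KL-divergence.
    b = [x - lr * g for x, g in zip(b, grad_b)]
    J = [x - lr * g for x, g in zip(J, grad_J)]
    return b, J


data = [(1, 1, 1), (1, 1, -1), (-1, -1, -1), (1, 1, 1)]
b, J, beta = [0.0] * N, [0.0] * len(EDGES), 1.0
kl_before = kl_divergence(data, b, J, beta)
for _ in range(200):
    b, J = training_step(data, b, J, beta, lr=0.1)
kl_after = kl_divergence(data, b, J, beta)
```

For a fully visible model the KL-divergence is convex in the biases and weights, so with a small enough step size each update reduces it and `kl_after` ends below `kl_before`. With latent spins, the positive phase would instead require sampling those spins with the data clamped.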

hinton_init function
hinton_init(key: Key[Array, ''], model: IsingEBM, blocks: list[Block[AbstractNode]], batch_shape: tuple[int]) -> list[Bool[Array, 'batch_size block_size']]

Initialize the blocks according to the marginal bias.

Each binary unit $i$ in a block is sampled independently as

$$\mathbb{P}(S_i = 1) = \sigma(\beta h_i) = \frac{1}{1 + e^{-\beta h_i}}$$

where $h_i$ is the bias of unit $i$ (the $b_i$ of the energy function above) and $\beta$ is the inverse-temperature scaling factor. See Hinton (2012) for a discussion of this initialization heuristic.

Arguments:

  • key: the JAX PRNG key to use.
  • model: the Ising model to initialize the blocks for.
  • blocks: the blocks to initialize.
  • batch_shape: the batch dimensions prepended to each block's state.

Returns: the initialized block states, one boolean array per block.
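The sampling rule can be sketched without JAX. The following is a hypothetical stand-in for `hinton_init`, not its source: a `random.Random` generator replaces the JAX PRNG key, each block is given directly as its list of biases $h_i$, and `batch_shape` is reduced to a single integer `batch_size`. It applies $\sigma(\beta h_i)$ exactly as written above.

```python
import math
import random


def hinton_init_sketch(rng, block_biases, beta, batch_size):
    """For each block, draw `batch_size` rows of booleans where unit i is
    True with probability sigmoid(beta * h_i), independently of the others."""
    states = []
    for biases in block_biases:
        probs = [1.0 / (1.0 + math.exp(-beta * h)) for h in biases]
        rows = [[rng.random() < p for p in probs] for _ in range(batch_size)]
        states.append(rows)  # batch_size rows, one entry per unit in the block
    return states


rng = random.Random(42)
init = hinton_init_sketch(rng, [[0.0, 2.0], [-2.0]], beta=1.0, batch_size=4)
```

Here `init[0]` holds four rows of two booleans and `init[1]` four rows of one, mirroring the `batch_size block_size` layout of the return type.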

estimate_moments function
estimate_moments(key: Key[Array, ''], first_moment_nodes: list[AbstractNode], second_moment_edges: list[tuple[AbstractNode, AbstractNode]], program: BlockSamplingProgram, schedule: SamplingSchedule, init_state: list[Array], clamped_data: list[Array])

Estimates the first and second moments of an Ising model's Boltzmann distribution via sampling.

Arguments:

  • key: the JAX PRNG key.
  • first_moment_nodes: the nodes representing the variables whose first moments are estimated.
  • second_moment_edges: the edges connecting the variables whose second moments are estimated.
  • program: the BlockSamplingProgram used for sampling.
  • schedule: the schedule used for sampling.
  • init_state: the variable values used to initialize sampling.
  • clamped_data: the variable values assigned to the clamped nodes.

Returns: the first- and second-moment estimates.
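Once samples are available, the estimates are plain averages. In the sketch below the samples are passed in directly; with THRML they would come from running `program` under `schedule` starting from `init_state`, a step this sketch omits. `estimate_moments_sketch` is a hypothetical helper, and spins are assumed to take values in $\{-1, +1\}$.

```python
def estimate_moments_sketch(samples, first_moment_nodes, second_moment_edges):
    """Monte Carlo moment estimates from a list of sampled spin states.

    Each sample is a sequence of spins in {-1, +1}, indexed by node.
    """
    n = len(samples)
    # <s_i>: average of each requested spin.
    first = [sum(s[i] for s in samples) / n for i in first_moment_nodes]
    # <s_i s_j>: average of the product across each requested edge.
    second = [sum(s[i] * s[j] for s in samples) / n for i, j in second_moment_edges]
    return first, second


# Four hypothetical post-warmup samples of a 3-spin model.
samples = [[1, 1, -1], [1, -1, -1], [1, 1, 1], [-1, -1, -1]]
first, second = estimate_moments_sketch(samples, [0, 1, 2], [(0, 1), (1, 2)])
# first == [0.5, 0.0, -0.5], second == [0.5, 0.5]
```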

estimate_kl_grad function
estimate_kl_grad(key: Key[Array, ''], training_spec: IsingTrainingSpec, bias_nodes: list[AbstractNode], weight_edges: list[tuple[AbstractNode, AbstractNode]], data: list[Array], conditioning_values: list[Array], init_state_positive: list[Array], init_state_negative: list[Array]) -> tuple

Estimate the KL-gradients of an Ising model with respect to its weights and biases.

Uses the standard two-term Monte Carlo estimator of the gradient of the KL-divergence between an Ising model and a data distribution.

The gradients with respect to each weight $J_{ij}$ and bias $b_i$ are:

$$\Delta J_{ij} = -\beta (\langle s_i s_j \rangle_{+} - \langle s_i s_j \rangle_{-})$$
$$\Delta b_i = -\beta (\langle s_i \rangle_{+} - \langle s_i \rangle_{-})$$

Here, $\langle\cdot\rangle_{+}$ denotes an expectation under the *positive* phase (data-clamped Boltzmann distribution) and $\langle\cdot\rangle_{-}$ under the *negative* phase (model distribution).

Arguments:

  • key: the JAX PRNG key.
  • training_spec: the IsingTrainingSpec describing the Ising EBM and the sampling programs and schedules used to estimate its gradients.
  • bias_nodes: the nodes for which to estimate the bias gradients.
  • weight_edges: the edges for which to estimate the weight gradients.
  • data: the data values to use for the positive phase of the gradient estimate. Each array has shape [batch nodes].
  • conditioning_values: the values to assign to the nodes that the model is conditioned on. Each array has shape [nodes].
  • init_state_positive: the initial state for the positive sampling chain. Each array has shape [n_chains_pos batch nodes].
  • init_state_negative: the initial state for the negative sampling chain. Each array has shape [n_chains_neg nodes].

Returns: the weight gradients and the bias gradients.
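Given samples from both phases, the estimator reduces to differences of averages. The sketch below is a hypothetical, framework-free illustration, not THRML's implementation: it takes already-collected samples (assumed flattened over batch and chain dimensions), assumes spins in $\{-1, +1\}$, and skips the sampling that `training_spec` drives.

```python
def estimate_kl_grad_sketch(pos_samples, neg_samples, bias_nodes, weight_edges, beta):
    """Two-term Monte Carlo estimate of the KL gradient.

    pos_samples: spin states from the positive (data-clamped) chains.
    neg_samples: spin states from the negative (model) chains.
    Returns (weight_grads, bias_grads), in the order documented above.
    """

    def mean(values):
        values = list(values)
        return sum(values) / len(values)

    # Delta J_ij = -beta * (<s_i s_j>_+ - <s_i s_j>_-)
    weight_grads = [
        -beta * (mean(s[i] * s[j] for s in pos_samples) - mean(s[i] * s[j] for s in neg_samples))
        for i, j in weight_edges
    ]
    # Delta b_i = -beta * (<s_i>_+ - <s_i>_-)
    bias_grads = [
        -beta * (mean(s[i] for s in pos_samples) - mean(s[i] for s in neg_samples))
        for i in bias_nodes
    ]
    return weight_grads, bias_grads


pos = [[1, 1], [1, 1]]    # positive phase: the two spins always agree
neg = [[1, -1], [-1, 1]]  # negative phase: the model currently anti-aligns them
weight_grads, bias_grads = estimate_kl_grad_sketch(pos, neg, [0, 1], [(0, 1)], beta=1.0)
# weight_grads == [-2.0], bias_grads == [-1.0, -1.0]
```

A descent step `J - lr * grad` therefore raises the coupling by `2 * lr`, pushing the model toward the aligned configurations seen in the positive phase.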