Ising Models
This module contains implementations of Ising models and spin systems.
thrml.models.IsingEBM(thrml.models.AbstractFactorizedEBM)
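An EBM with the energy function

\[
\mathcal{E}(s) = -\beta \left( \sum_{i \in S_1} b_i s_i + \sum_{(i, j) \in S_2} J_{ij} s_i s_j \right),
\]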
where \(S_1\) is the set of nodes that carry a bias and \(S_2\) is the set of edges that carry a weight. \(b_i\) is the bias associated with the spin \(s_i\), and \(J_{ij}\) is the weight that couples \(s_i\) and \(s_j\). \(\beta\) is the usual inverse-temperature parameter.
Attributes:
- `nodes`: the nodes that have an associated bias (i.e., \(S_1\))
- `biases`: the bias associated with each node in `nodes`
- `edges`: the edges that have an associated weight (i.e., \(S_2\))
- `weights`: the weight associated with each pair of nodes in `edges`
- `beta`: the scalar inverse-temperature parameter of the model
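To make the energy function concrete, here is a minimal NumPy sketch that evaluates it for a single spin configuration. It is not thrml's implementation. It assumes, for illustration, that nodes are plain integer indices `0..n-1` and that every node carries a bias (so `biases[i]` belongs to node `i`), that `weights[k]` couples the pair `edges[k]`, and that spins take values in \(\{-1, +1\}\). The helper name `ising_energy` is hypothetical.

```python
import numpy as np

def ising_energy(spins, biases, edges, weights, beta):
    """Ising EBM energy of one spin configuration (hypothetical helper).

    Illustrative assumptions: nodes are integer indices 0..n-1, every node has
    a bias (biases[i] belongs to node i), weights[k] couples the pair
    edges[k] = (i, j), and spins take values in {-1, +1}.
    """
    spins = np.asarray(spins, dtype=float)
    edges = np.asarray(edges, dtype=int).reshape(-1, 2)
    # sum_{i in S_1} b_i s_i
    bias_term = np.dot(biases, spins)
    # sum_{(i, j) in S_2} J_ij s_i s_j, evaluated for all edges at once
    pair_term = np.dot(weights, spins[edges[:, 0]] * spins[edges[:, 1]])
    return -beta * (bias_term + pair_term)

# 3-spin chain 0 - 1 - 2
biases = np.array([0.5, 0.0, -0.5])
edges = [(0, 1), (1, 2)]
weights = np.array([1.0, 1.0])
print(ising_energy([1, 1, 1], biases, edges, weights, beta=1.0))   # -2.0
print(ising_energy([1, 1, -1], biases, edges, weights, beta=1.0))  # -1.0
```

Aligned spins on a positively weighted edge lower the energy, so configurations with more agreeing neighbors become more probable under the Boltzmann distribution.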
__init__(nodes: list[thrml.AbstractNode], edges: list[tuple[thrml.AbstractNode, thrml.AbstractNode]], biases: Array, weights: Array, beta: Array)
Initialize an Ising EBM.
Arguments:
- `nodes`: list of nodes with associated biases
- `edges`: list of edge pairs with associated weights
- `biases`: bias values, one for each node in `nodes`
- `weights`: weight values, one for each edge in `edges`
- `beta`: inverse-temperature parameter
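The arguments are aligned by position: the k-th bias belongs to the k-th node and the k-th weight to the k-th edge. The NumPy sketch below builds such aligned lists for a 2D nearest-neighbor grid. It uses integer stand-ins for nodes; with thrml itself the entries of `nodes` would be node objects (instances of `thrml.AbstractNode` subclasses). The helper `grid_ising_params` is hypothetical.

```python
import numpy as np

def grid_ising_params(rows, cols, coupling=1.0, bias=0.0):
    """Aligned nodes/edges/biases/weights for a rows x cols grid (hypothetical helper).

    Node (r, c) is represented by the integer r * cols + c. biases[k] belongs to
    nodes[k] and weights[k] belongs to edges[k].
    """
    nodes = list(range(rows * cols))
    edges = []
    for r in range(rows):
        for c in range(cols):
            i = r * cols + c
            if c + 1 < cols:  # right neighbor
                edges.append((i, i + 1))
            if r + 1 < rows:  # neighbor below
                edges.append((i, i + cols))
    biases = np.full(len(nodes), bias)
    weights = np.full(len(edges), coupling)
    return nodes, edges, biases, weights

nodes, edges, biases, weights = grid_ising_params(3, 4, coupling=0.5)
print(len(nodes), len(edges))  # 12 17
```

A 3 x 4 grid has 3 * 3 horizontal and 2 * 4 vertical edges, hence 17 weights. Keeping the parameter arrays in the same order as the node and edge lists is what makes the lookup unambiguous.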
thrml.models.IsingSamplingProgram(thrml.FactorSamplingProgram)
A thin wrapper around FactorSamplingProgram that specializes it to Ising models.
__init__(ebm: thrml.models.IsingEBM, free_blocks: list[tuple[thrml.Block, ...] | thrml.Block], clamped_blocks: list[thrml.Block])
Initialize an Ising sampling program.
Arguments:
- `ebm`: the Ising EBM to sample from
- `free_blocks`: list of super blocks (each a single block or a tuple of blocks) that are free to vary
- `clamped_blocks`: list of blocks that are held fixed
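A sampling program of this kind resamples the free blocks in turn while the clamped blocks stay fixed. As a rough illustration of that idea (not of how `FactorSamplingProgram` is implemented), the NumPy sketch below runs block Gibbs sweeps on a ±1 Ising chain. It assumes integer node indices, a symmetric coupling matrix `J`, and free blocks with no internal edges, so every spin in a block can be resampled in parallel from \(P(s_i = +1 \mid \text{rest}) = \sigma(2\beta h_i)\) with local field \(h_i = b_i + \sum_j J_{ij} s_j\). The helpers `coupling_matrix` and `block_gibbs_sweep` are hypothetical.

```python
import numpy as np

def coupling_matrix(n, edges, weights):
    """Symmetric n x n matrix with J[i, j] = J[j, i] = weight of edge (i, j)."""
    J = np.zeros((n, n))
    for (i, j), w in zip(edges, weights):
        J[i, j] = J[j, i] = w
    return J

def block_gibbs_sweep(rng, spins, biases, J, beta, free_blocks):
    """One block Gibbs sweep over `free_blocks`; all other nodes stay clamped.

    Each block must contain no internal edges, so its spins are conditionally
    independent given the rest and can be resampled in parallel.
    """
    spins = np.array(spins, dtype=float)  # copy, so the caller's state is untouched
    for block in free_blocks:
        block = np.asarray(block)
        # Local field h_i = b_i + sum_j J_ij s_j, using the current values of all spins
        field = biases[block] + J[block] @ spins
        # For +/-1 spins, P(s_i = +1 | rest) = sigmoid(2 * beta * h_i)
        p_up = 1.0 / (1.0 + np.exp(-2.0 * beta * field))
        spins[block] = np.where(rng.random(block.size) < p_up, 1.0, -1.0)
    return spins

# 6-spin ferromagnetic chain; node 0 is clamped to +1 by leaving it out of every free block
rng = np.random.default_rng(0)
n = 6
J = coupling_matrix(n, [(i, i + 1) for i in range(n - 1)], np.full(n - 1, 0.8))
spins = np.ones(n)
free_blocks = [[2, 4], [1, 3, 5]]  # even/odd coloring: no edges inside a block
for _ in range(100):
    spins = block_gibbs_sweep(rng, spins, np.zeros(n), J, 1.0, free_blocks)
```

The two-coloring of the chain is what makes each block free of internal edges; on a general graph a graph coloring plays the same role.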
thrml.models.IsingTrainingSpec
Contains a complete specification of an Ising EBM that can be trained using sampling-based gradients.
Defines sampling programs and schedules that allow for collection of the positive and negative phase samples required for Monte Carlo estimation of the gradient of the KL-divergence between the model and a data distribution.
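For a model small enough to enumerate, the quantities that the positive- and negative-phase samples estimate can be computed exactly. The NumPy sketch below does this for the pairwise correlations \(\langle s_i s_j\rangle\). The negative phase averages over the free Boltzmann distribution. The positive phase clamps the data nodes to each data vector, averages over the remaining (latent) nodes, and then averages over the data set. This is a brute-force reference rather than what thrml does. It assumes integer node indices, a symmetric coupling matrix `J` with zero diagonal, and ±1 spins; the helper names are hypothetical.

```python
import itertools
import numpy as np

def boltzmann_table(biases, J, beta):
    """All 2**n spin states and their exact Boltzmann probabilities (tiny n only)."""
    n = len(biases)
    states = np.array(list(itertools.product([-1.0, 1.0], repeat=n)))
    # E(s) = -beta * (b.s + sum_{i<j} J_ij s_i s_j); J symmetric with zero diagonal
    pair = 0.5 * np.einsum("ki,ij,kj->k", states, J, states)
    energy = -beta * (states @ biases + pair)
    p = np.exp(-(energy - energy.min()))  # shift by the minimum for numerical stability
    return states, p / p.sum()

def phase_correlations(biases, J, beta, visible=None, data=None):
    """Exact <s_i s_j> matrix for the negative phase (data=None) or the positive phase.

    In the positive phase the `visible` nodes are clamped to each row of `data`,
    the remaining nodes follow the conditional Boltzmann distribution, and the
    result is averaged over the rows.
    """
    states, p = boltzmann_table(biases, J, beta)
    if data is None:
        return np.einsum("k,ki,kj->ij", p, states, states)
    corr = np.zeros((len(biases), len(biases)))
    for row in data:
        # keep only the states that agree with this data vector, then renormalize
        q = p * np.all(states[:, visible] == row, axis=1)
        corr += np.einsum("k,ki,kj->ij", q / q.sum(), states, states)
    return corr / len(data)

# Nodes 0 and 1 are data nodes; nodes 2 and 3 are latent.
J = np.zeros((4, 4))
J[0, 2] = J[2, 0] = 0.5
J[1, 3] = J[3, 1] = -0.5
data = np.array([[1.0, 1.0], [1.0, -1.0]])
positive = phase_correlations(np.zeros(4), J, 1.0, visible=[0, 1], data=data)
negative = phase_correlations(np.zeros(4), J, 1.0)
```

Enumeration costs \(2^n\) states, which is exactly why the training specification relies on sampling programs and schedules instead.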
__init__(ebm: thrml.models.IsingEBM, data_blocks: list[thrml.Block], conditioning_blocks: list[thrml.Block], positive_sampling_blocks: list[tuple[thrml.Block, ...] | thrml.Block], negative_sampling_blocks: list[tuple[thrml.Block, ...] | thrml.Block], schedule_positive: thrml.SamplingSchedule, schedule_negative: thrml.SamplingSchedule)
thrml.models.hinton_init(key: Key[Array, ''], model: thrml.models.IsingEBM, blocks: list[thrml.Block[thrml.AbstractNode]], batch_shape: tuple[int]) -> list[Bool[Array, 'batch_size block_size']]
Initialize the blocks according to the marginal bias.
Each binary unit \(i\) in a block is sampled independently as

\[
\mathbb{P}(s_i = 1) = \sigma(\beta h_i) = \frac{1}{1 + e^{-\beta h_i}},
\]

where \(h_i\) is the bias of unit \(i\) and \(\beta\) is the inverse-temperature scaling factor. See Hinton (2012) for a discussion of this initialization heuristic.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `key` | `Key[Array, '']` | the JAX PRNG key to use | required |
| `model` | `thrml.models.IsingEBM` | the Ising model to initialize for | required |
| `blocks` | `list[thrml.Block[thrml.AbstractNode]]` | the blocks that are to be initialized | required |
| `batch_shape` | `tuple[int]` | the batch shape prepended to each block's state | required |
Returns:
| Type | Description |
|---|---|
| `list[Bool[Array, 'batch_size block_size']]` | the initialized blocks |
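The rule \(\mathbb{P}(s_i = 1) = \sigma(\beta h_i)\) is easy to reproduce outside thrml. The NumPy sketch below mirrors it under stated assumptions: biases are looked up by integer node index, each block is a list of such indices, and `True` encodes the "on" state. It returns one boolean array of shape `batch_shape + (block_size,)` per block, in the spirit of the documented return type. The function name `hinton_init_sketch` is hypothetical and uses a NumPy generator in place of a JAX PRNG key.

```python
import numpy as np

def hinton_init_sketch(rng, biases, beta, blocks, batch_shape):
    """Sample every unit independently with P(unit = True) = sigmoid(beta * h_i).

    Hypothetical stand-in: `biases` is indexed by integer node id, each block is
    a list of node ids, and True encodes the "on" state of a binary unit.
    """
    states = []
    for block in blocks:
        h = np.asarray(biases, dtype=float)[np.asarray(block)]
        p_on = 1.0 / (1.0 + np.exp(-beta * h))  # sigmoid(beta * h_i), one per unit
        # Uniform draws of shape batch_shape + (block_size,); p_on broadcasts over the batch
        states.append(rng.random(tuple(batch_shape) + (len(block),)) < p_on)
    return states

rng = np.random.default_rng(0)
init = hinton_init_sketch(rng, np.array([2.0, -2.0, 0.0]), 1.0, [[0, 2], [1]], (1000,))
print([s.shape for s in init])  # [(1000, 2), (1000, 1)]
```

Starting each unit at its bias-only marginal places the chains nearer to plausible states than a uniform start when the biases are strong.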
thrml.models.estimate_moments(key: Key[Array, ''], first_moment_nodes: list[thrml.AbstractNode], second_moment_edges: list[tuple[thrml.AbstractNode, thrml.AbstractNode]], program: thrml.BlockSamplingProgram, schedule: thrml.SamplingSchedule, init_state: list[Array], clamped_data: list[Array])
Estimates the first and second moments of an Ising model's Boltzmann distribution via sampling.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `key` | `Key[Array, '']` | the JAX PRNG key | required |
| `first_moment_nodes` | `list[thrml.AbstractNode]` | the nodes representing the variables whose first moments are estimated | required |
| `second_moment_edges` | `list[tuple[thrml.AbstractNode, thrml.AbstractNode]]` | the edges connecting the pairs of variables whose second moments are estimated | required |
| `program` | `thrml.BlockSamplingProgram` | the sampling program to run | required |
| `schedule` | `thrml.SamplingSchedule` | the schedule to use for sampling | required |
| `init_state` | `list[Array]` | the variable values used to initialize sampling | required |
| `clamped_data` | `list[Array]` | the variable values to assign to the clamped nodes | required |
Returns: the estimated first moments (for `first_moment_nodes`) and second moments (for `second_moment_edges`)
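Once samples have been drawn, Monte Carlo moment estimates are plain sample averages. The NumPy sketch below shows only that final averaging step, under simplifying assumptions: all chains and samples have been flattened into one `[n_samples, n_nodes]` array of ±1 spins, and nodes and edges are given as integer column indices. The helper `moments_from_samples` is hypothetical and does not reproduce the thrml function's signature. The sanity check uses independent spins, for which \(\langle s_i\rangle = \tanh(\beta b_i)\) and \(\langle s_i s_j\rangle = \langle s_i\rangle\langle s_j\rangle\).

```python
import numpy as np

def moments_from_samples(samples, nodes, edges):
    """Sample-average estimates of <s_i> for `nodes` and <s_i s_j> for `edges`.

    `samples` has shape [n_samples, n_nodes] with +/-1 entries; `nodes` and
    `edges` refer to columns by integer index.
    """
    samples = np.asarray(samples, dtype=float)
    edges = np.asarray(edges, dtype=int).reshape(-1, 2)
    first = samples[:, nodes].mean(axis=0)
    second = (samples[:, edges[:, 0]] * samples[:, edges[:, 1]]).mean(axis=0)
    return first, second

# Independent spins with biases b and no couplings: P(s = +1) = sigmoid(2 * beta * b)
rng = np.random.default_rng(0)
beta, b = 1.0, np.array([0.5, -1.0])
p_up = 1.0 / (1.0 + np.exp(-2.0 * beta * b))
samples = np.where(rng.random((200_000, 2)) < p_up, 1.0, -1.0)
first, second = moments_from_samples(samples, [0, 1], [(0, 1)])
print(first, np.tanh(beta * b))            # estimates vs exact first moments
print(second, np.prod(np.tanh(beta * b)))  # estimate vs exact second moment
```

The statistical error of these averages shrinks roughly as one over the square root of the number of effectively independent samples, which is what the sampling schedule trades off against cost.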
thrml.models.estimate_kl_grad(key: Key[Array, ''], training_spec: thrml.models.IsingTrainingSpec, bias_nodes: list[thrml.AbstractNode], weight_edges: list[tuple[thrml.AbstractNode, thrml.AbstractNode]], data: list[Array], conditioning_values: list[Array], init_state_positive: list[Array], init_state_negative: list[Array]) -> tuple
Estimate the gradients of the KL divergence of an Ising model with respect to its weights and biases.
Uses the standard two-term Monte Carlo estimator of the gradient of the KL divergence between an Ising model and a data distribution.
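The gradients are:

\[
\frac{\partial \mathrm{KL}}{\partial J_{ij}} = \beta \left( \langle s_i s_j \rangle_{-} - \langle s_i s_j \rangle_{+} \right), \qquad
\frac{\partial \mathrm{KL}}{\partial b_i} = \beta \left( \langle s_i \rangle_{-} - \langle s_i \rangle_{+} \right).
\]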
Here, \(\langle\cdot\rangle_{+}\) denotes an expectation under the positive phase (data-clamped Boltzmann distribution) and \(\langle\cdot\rangle_{-}\) under the negative phase (model distribution).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `key` | `Key[Array, '']` | the JAX PRNG key | required |
| `training_spec` | `thrml.models.IsingTrainingSpec` | the training specification of the Ising EBM whose gradients are estimated | required |
| `bias_nodes` | `list[thrml.AbstractNode]` | the nodes for which to estimate the bias gradients | required |
| `weight_edges` | `list[tuple[thrml.AbstractNode, thrml.AbstractNode]]` | the edges for which to estimate the weight gradients | required |
| `data` | `list[Array]` | the data values to use for the positive phase of the gradient estimate; each array has shape `[batch nodes]` | required |
| `conditioning_values` | `list[Array]` | values to assign to the nodes that the model is conditioned on; each array has shape `[nodes]` | required |
| `init_state_positive` | `list[Array]` | initial state for the positive sampling chain; each array has shape `[n_chains_pos batch nodes]` | required |
| `init_state_negative` | `list[Array]` | initial state for the negative sampling chain; each array has shape `[n_chains_neg nodes]` | required |
Returns: the weight gradients and the bias gradients
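Assuming the energy \(\mathcal{E}(s) = -\beta\left(\sum_i b_i s_i + \sum_{(i,j)} J_{ij} s_i s_j\right)\), the gradient is a difference of moments scaled by \(\beta\): \(\beta(\langle s_i s_j\rangle_- - \langle s_i s_j\rangle_+)\) for weights and \(\beta(\langle s_i\rangle_- - \langle s_i\rangle_+)\) for biases. The NumPy sketch below applies this to a fully visible three-spin chain, where the positive phase reduces to averages over the data and the negative phase can be computed exactly by enumeration. Exact enumeration stands in for the sampling that thrml performs, and the helper names are hypothetical. Following the documented return order, the sketch returns the weight gradients first and the bias gradients second.

```python
import itertools
import numpy as np

def exact_model_moments(biases, edges, weights, beta):
    """Exact negative-phase <s_i> and per-edge <s_i s_j> by enumerating all 2**n states."""
    n = len(biases)
    edges = np.asarray(edges, dtype=int).reshape(-1, 2)
    states = np.array(list(itertools.product([-1.0, 1.0], repeat=n)))
    pairs = states[:, edges[:, 0]] * states[:, edges[:, 1]]  # shape [2**n, n_edges]
    energy = -beta * (states @ biases + pairs @ weights)
    p = np.exp(-(energy - energy.min()))  # shift by the minimum for numerical stability
    p /= p.sum()
    return p @ states, p @ pairs

def kl_grad_from_moments(pos_first, pos_second, neg_first, neg_second, beta):
    """Return (weight gradients, bias gradients) of KL(data || model)."""
    grad_w = beta * (np.asarray(neg_second) - np.asarray(pos_second))
    grad_b = beta * (np.asarray(neg_first) - np.asarray(pos_first))
    return grad_w, grad_b

# Fully visible 3-spin chain: the positive phase is just the data average.
edges = [(0, 1), (1, 2)]
biases, weights, beta = np.array([0.1, -0.2, 0.3]), np.array([0.5, -0.4]), 0.7
data = np.array([[1, 1, -1], [1, 1, 1], [-1, -1, 1], [1, 1, -1]], dtype=float)
pos_first = data.mean(axis=0)
pos_second = np.array([(data[:, i] * data[:, j]).mean() for i, j in edges])
neg_first, neg_second = exact_model_moments(biases, edges, weights, beta)
grad_w, grad_b = kl_grad_from_moments(pos_first, pos_second, neg_first, neg_second, beta)
```

In a training loop these gradients would typically be passed to an optimizer, for example a plain descent step `weights - lr * grad_w`; with latent nodes, the positive-phase moments would themselves have to be sampled with the data nodes clamped.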