
THRML is a JAX library for block Gibbs sampling of hypergraphical and energy-based models. Build a model from nodes and many-body factors, partition it into blocks by graph colouring so that no two nodes in the same block interact, and sample it block by block, updating every node in a block in parallel. This is the same structure Extropic's hardware is built to accelerate.
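To show why colouring permits parallel updates, here is a from-scratch sketch of two-colour block Gibbs on a 5-spin chain. It uses plain JAX without THRML and is not how THRML implements sampling internally. It assumes the usual Ising convention E(s) = -β(Σ h_i s_i + Σ J_i s_i s_{i+1}) with s_i ∈ {-1, +1}, under which each spin's conditional is P(s_i = +1 | rest) = sigmoid(2β · field_i). Even-indexed spins are never neighbours of each other, so once the odd spins are fixed, the even spins are conditionally independent and can all be redrawn at once, and the same holds with the roles swapped. The function names are hypothetical.

```python
import jax
import jax.numpy as jnp


def local_field(s, h, J):
    """Field on each spin of an open chain: h_i + J_{i-1} s_{i-1} + J_i s_{i+1}.

    J[k] couples spins k and k+1, so len(J) == len(s) - 1.
    """
    left = jnp.concatenate([jnp.zeros(1), J * s[:-1]])   # contribution of left neighbour
    right = jnp.concatenate([J * s[1:], jnp.zeros(1)])   # contribution of right neighbour
    return h + left + right


def update_block(key, s, mask, h, J, beta):
    """Redraw every spin where mask is True, in parallel, from its exact conditional."""
    p_up = jax.nn.sigmoid(2.0 * beta * local_field(s, h, J))
    draw = jnp.where(jax.random.uniform(key, s.shape) < p_up, 1.0, -1.0)
    return jnp.where(mask, draw, s)


def gibbs_sweep(key, s, h, J, beta):
    """One sweep: update the even colour, then the odd colour."""
    even = jnp.arange(s.shape[0]) % 2 == 0
    k_even, k_odd = jax.random.split(key)
    s = update_block(k_even, s, even, h, J, beta)
    return update_block(k_odd, s, ~even, h, J, beta)


def run_chain(key, n_sweeps, h, J, beta):
    """Random ±1 start, then n_sweeps sweeps; returns every state, shape (n_sweeps, n)."""
    k_init, k_run = jax.random.split(key)
    s0 = jnp.where(jax.random.bernoulli(k_init, 0.5, h.shape), 1.0, -1.0)

    def step(s, k):
        s = gibbs_sweep(k, s, h, J, beta)
        return s, s

    _, trace = jax.lax.scan(step, s0, jax.random.split(k_run, n_sweeps))
    return trace


h = jnp.zeros(5)       # no external field
J = jnp.full(4, 0.5)   # ferromagnetic nearest-neighbour couplings
trace = run_chain(jax.random.key(0), 2000, h, J, 1.0)[100:]  # drop 100 warm-up sweeps
nn_corr = jnp.mean(trace[:, :-1] * trace[:, 1:])             # average <s_i s_{i+1}>
```

For an open chain at zero field, the exact nearest-neighbour correlation is tanh(βJ), so `nn_corr` should land near tanh(0.5) ≈ 0.46.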
Install with pip install thrml (Python 3.10+), then build a model and sample it:
```python
import jax
import jax.numpy as jnp
from thrml import SpinNode, Block, SamplingSchedule, sample_states
from thrml.models import IsingEBM, IsingSamplingProgram, hinton_init

# A 5-spin Ising chain, two-coloured into parallel blocks
nodes = [SpinNode() for _ in range(5)]
edges = [(nodes[i], nodes[i + 1]) for i in range(4)]
model = IsingEBM(
    nodes,
    edges,
    jnp.zeros((5,)),       # biases, one per node
    jnp.ones((4,)) * 0.5,  # coupling weights, one per edge
    jnp.array(1.0),        # inverse temperature beta
)

# Even and odd spins never share an edge, so each block updates in parallel
free_blocks = [Block(nodes[::2]), Block(nodes[1::2])]
program = IsingSamplingProgram(model, free_blocks, clamped_blocks=[])

k_init, k_samp = jax.random.split(jax.random.key(0), 2)
state = hinton_init(k_init, model, free_blocks, ())
schedule = SamplingSchedule(n_warmup=100, n_samples=1000, steps_per_sample=2)

samples = sample_states(k_samp, program, schedule, state, [], [Block(nodes)])
# samples: 1000 draws from the chain, by block Gibbs
```
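For a model this small, you can check the sampler against exact answers by enumerating all 2^5 = 32 spin configurations. The sketch below does this in plain JAX. It assumes the energy convention E(s) = -β(Σ h_i s_i + Σ J_i s_i s_{i+1}) with s_i ∈ {-1, +1}, and `exact_moments` is a hypothetical helper, not part of THRML. To compare it with `samples`, convert the sampled spins to ±1 if they are stored in another form and average over draws. We have not confirmed the exact array layout THRML returns, so treat that step as an assumption.

```python
import itertools

import jax.numpy as jnp


def exact_moments(h, J, beta):
    """Exact <s_i> and <s_i s_{i+1}> for an open Ising chain, by brute force.

    Assumes E(s) = -beta * (sum_i h_i s_i + sum_i J_i s_i s_{i+1}), s_i in {-1, +1}.
    Cost is O(2**n), so this is only practical for tiny n.
    """
    n = h.shape[0]
    states = jnp.array(list(itertools.product([-1.0, 1.0], repeat=n)))  # (2**n, n)
    pairs = states[:, :-1] * states[:, 1:]                               # (2**n, n-1)
    energy = -beta * (states @ h + pairs @ J)
    weights = jnp.exp(-(energy - energy.min()))  # shift by the minimum to avoid overflow
    p = weights / weights.sum()                  # Boltzmann probabilities
    return p @ states, p @ pairs


mag, nn_corr = exact_moments(jnp.zeros(5), jnp.full(4, 0.5), 1.0)
```

With zero field, the exact magnetisation is 0 and each nearest-neighbour correlation is tanh(βJ) = tanh(0.5) ≈ 0.46. Block Gibbs estimates from the 1000 quickstart draws should scatter around these values.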
Four runnable notebooks take you from a first Ising chain to the full sampling stack and hardware-scale spin models.
What the sampling stack is for: real problems compiled to graphical models and sampled on thermodynamic hardware.