Block Sampling¤

thrml.BlockGibbsSpec(thrml.BlockSpec) ¤

A BlockGibbsSpec is a type of BlockSpec which contains additional information on free and clamped blocks.

BlockGibbsSpec also supports superblocks: groups of blocks that are sampled at the same algorithmic time, but not in the same programmatic step. That is, superblock = (block1, block2) means that block1 and block2 receive the same input state, but their samplers are executed separately. This is useful when blocks share a color on the graph but require sampling methods so different that they cannot be parallelized with JAX SIMD approaches.

A recurring theme in thrml is the importance of implicit indexing. One such example can be seen here. Because global states are created by concatenating lists of free and clamped blocks, providing the inputs in the same order as the blocks are defined is essential. This is almost always taken care of internally, but when writing custom functions or interfaces this is important to keep in mind.

Attributes:

  • free_blocks: the list of free blocks (in order)
  • sampling_order: a list of len(superblocks) lists, where sampling_order[i] holds the indices into free_blocks of the blocks to sample at step i. Sampling is done by iterating over this order and sampling each sublist of free blocks at the same algorithmic time.
  • clamped_blocks: the list of clamped blocks
  • superblocks: the list of superblocks
__init__(free_super_blocks: Sequence[tuple[thrml.Block, ...] | thrml.Block], clamped_blocks: list[thrml.Block], node_shape_dtypes: Mapping[type[thrml.AbstractNode], PyTree[jax._src.core.ShapeDtypeStruct]] = {thrml.SpinNode: ShapeDtypeStruct(shape=(), dtype=bool), thrml.CategoricalNode: ShapeDtypeStruct(shape=(), dtype=uint8)}) ¤

Create a Gibbs specification from free and clamped blocks.

Arguments:

  • free_super_blocks: An ordered sequence where each element is either a single Block, or a tuple of blocks that must share the same global state when calling their individual samplers.
  • clamped_blocks: Blocks whose nodes stay fixed during sampling.
  • node_shape_dtypes: Mapping from node class to a PyTree of jax.ShapeDtypeStruct; identical to the argument in BlockSpec.
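
The relationship between free_super_blocks, free_blocks, and sampling_order can be illustrated with a minimal pure-Python sketch. It does not use thrml: ToyBlock, flatten_super_blocks, and global_block_order are hypothetical stand-ins, and the free-then-clamped concatenation order is an assumption based on the description above rather than a statement about the library's internals.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyBlock:
    """Hypothetical stand-in for thrml.Block: just a tuple of node labels."""
    nodes: tuple


def flatten_super_blocks(free_super_blocks):
    """Flatten a mix of blocks and tuples of blocks.

    Returns (free_blocks, sampling_order), where sampling_order[i] lists the
    indices into free_blocks that belong to the i-th superblock.
    """
    free_blocks, sampling_order = [], []
    for entry in free_super_blocks:
        # A bare block is treated as a superblock with a single member.
        members = entry if isinstance(entry, tuple) else (entry,)
        indices = []
        for block in members:
            indices.append(len(free_blocks))
            free_blocks.append(block)
        sampling_order.append(indices)
    return free_blocks, sampling_order


def global_block_order(free_blocks, clamped_blocks):
    """Assumed layout of the global state: free blocks first, then clamped."""
    return list(free_blocks) + list(clamped_blocks)


a = ToyBlock(("a0", "a1"))
b = ToyBlock(("b0",))
c = ToyBlock(("c0", "c1"))
d = ToyBlock(("d0",))

# Block a is sampled alone; b and c form one superblock.
free_blocks, sampling_order = flatten_super_blocks([a, (b, c)])
global_order = global_block_order(free_blocks, [d])
```

Under these assumptions, sampling_order is [[0], [1, 2]], and any per-block state list has to follow the order of global_order, which is why inputs must be supplied in the order the blocks were defined.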

thrml.BlockSamplingProgram ¤

A PGM block-sampling program.

This class encapsulates everything that is needed to run a PGM block sampling program in THRML. per_block_interactions and per_block_interaction_active are parallel to the free blocks in gibbs_spec, and their members are passed directly to a sampler when the state of the corresponding free block is being updated during a sampling program. per_block_interaction_global_inds and per_block_interaction_global_slices are also parallel to the free blocks, and are used to slice the global state of the program to produce the state information required to update the state of each block alongside the static information contained in the interactions.

Attributes:

  • gibbs_spec: A division of some PGM into free and clamped blocks.
  • samplers: The samplers used to update the free blocks in gibbs_spec, one per free block.
  • per_block_interactions: All the interactions that touch each free block in gibbs_spec.
  • per_block_interaction_active: Indicates which interactions are real and which are not part of the model but were added as padding so that the data structures can be rectangular.
  • per_block_interaction_global_inds: How to find, within the global state list, the information required to update each block.
  • per_block_interaction_global_slices: How to slice each array in the global state list to find the information required to update each block.
__init__(gibbs_spec: thrml.BlockGibbsSpec, samplers: list[thrml.AbstractConditionalSampler], interaction_groups: list[thrml.InteractionGroup]) ¤

Construct a BlockSamplingProgram.

This code is the beating heart of THRML, and the chances that you should be modifying it or trying to understand it deeply are very low (as this would basically correspond to rewriting the library). It takes in a set of information that implicitly defines a sampling program and manipulates it into a form appropriate for a practical vectorized block-sampling program. This involves reindexing, slicing, and often padding.

Arguments:

  • gibbs_spec: A division of some PGM into free and clamped blocks.
  • samplers: The update rule to use for each free block in gibbs_spec.
  • interaction_groups: A list of InteractionGroups that define how the variables in your sampling program affect one another.
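
The padding idea behind per_block_interaction_active can be sketched in plain Python. This is only an illustration of the general technique (padding ragged neighbor lists to a rectangle and carrying a mask of real entries); pad_neighbors and masked_neighbor_sum are hypothetical helpers, not part of thrml, and the library's actual data layout may differ.

```python
def pad_neighbors(neighbor_lists, pad_index=0):
    """Pad ragged neighbor lists to a common width.

    Returns (padded, active): padded[i][k] is a neighbor index, and
    active[i][k] is True only where that entry is a real interaction.
    """
    width = max((len(n) for n in neighbor_lists), default=0)
    padded, active = [], []
    for neighbors in neighbor_lists:
        n_pad = width - len(neighbors)
        padded.append(list(neighbors) + [pad_index] * n_pad)
        active.append([True] * len(neighbors) + [False] * n_pad)
    return padded, active


def masked_neighbor_sum(values, padded, active):
    """Sum neighbor values per node, ignoring padded entries."""
    return [
        sum(values[j] for j, is_real in zip(row, mask) if is_real)
        for row, mask in zip(padded, active)
    ]


# A 3-node chain 0 - 1 - 2: the end nodes have one neighbor, the middle has two.
neighbors = [[1], [0, 2], [1]]
padded, active = pad_neighbors(neighbors)
values = [10, 20, 30]
sums = masked_neighbor_sum(values, padded, active)
```

Every row of padded has the same length, so it could be stored as a rectangular array, while active keeps the padded entries from contributing to the result.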

thrml.SamplingSchedule ¤

Represents a sampling schedule for a process.

Attributes:

  • n_warmup: The number of warmup steps to run before collecting samples.
  • n_samples: The number of samples to collect.
  • steps_per_sample: The number of steps to run between each sample.
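
As a rough illustration of how the three fields interact, the following pure-Python sketch runs a schedule over an arbitrary step function. ToySchedule and run_schedule are hypothetical, and the exact timing is an assumption: here steps_per_sample steps are taken before every recorded sample, whereas the library may, for example, record the first sample immediately after warm-up.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToySchedule:
    """Hypothetical stand-in mirroring the fields of thrml.SamplingSchedule."""
    n_warmup: int
    n_samples: int
    steps_per_sample: int


def run_schedule(schedule, step, state):
    """Run warm-up steps, then record n_samples states.

    Assumption: steps_per_sample steps are taken before each recorded sample.
    """
    for _ in range(schedule.n_warmup):
        state = step(state)  # warm-up states are discarded
    samples = []
    for _ in range(schedule.n_samples):
        for _ in range(schedule.steps_per_sample):
            state = step(state)
        samples.append(state)
    return samples


# A trivial "chain" that just counts steps makes the timing visible.
schedule = ToySchedule(n_warmup=3, n_samples=4, steps_per_sample=2)
samples = run_schedule(schedule, lambda s: s + 1, 0)
```

With the counting step used here, the recorded states are the step counts 5, 7, 9, and 11: three warm-up steps, then one sample every two steps.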

thrml.sample_blocks(key: Key[Array, ''], state_free: list[PyTree[Shaped[Array, 'nodes ?*state'], '_State']], clamp_state: list[PyTree[Shaped[Array, 'nodes ?*state'], '_State']], program: thrml.BlockSamplingProgram, sampler_state: list[~_SamplerState]) -> tuple[list[PyTree[Shaped[Array, 'nodes ?*state'], '_State']], list[~_SamplerState]] ¤

Perform one iteration of sampling, visiting every block.

Arguments:

  • key: The JAX PRNG key.
  • state_free: The state of the free blocks.
  • clamp_state: The state of the clamped blocks.
  • program: The Gibbs program.
  • sampler_state: The state of the sampler.

Returns:

  • Updated free-block state list and sampler-state list.
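
To make "one iteration, visiting every block" concrete, here is a minimal pure-Python sketch of a block Gibbs sweep on a small Ising chain. It is not thrml's implementation: sweep, the ±1 spin encoding, the coupling value, and the convention that nodes outside every block are clamped are all assumptions chosen for illustration.

```python
import math
import random


def sweep(rng, spins, blocks, neighbors, coupling=0.5):
    """One block Gibbs sweep over spins in {-1, +1}.

    Nodes that appear in no block are treated as clamped and left unchanged.
    """
    spins = list(spins)  # do not mutate the caller's state
    for block in blocks:
        # All nodes in a block are updated from the same input state.
        snapshot = list(spins)
        for i in block:
            field = coupling * sum(snapshot[j] for j in neighbors[i])
            # Conditional of an Ising spin given its neighbors (beta = 1).
            p_up = 1.0 / (1.0 + math.exp(-2.0 * field))
            spins[i] = 1 if rng.random() < p_up else -1
    return spins


# 5-node chain; blocks [0, 2] and [1, 3] are free, node 4 is clamped.
neighbors = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2, 4], 4: [3]}
blocks = [[0, 2], [1, 3]]
spins = [1, -1, 1, -1, 1]

rng = random.Random(0)
new_spins = sweep(rng, spins, blocks, neighbors)
```

Because no two nodes within a block are neighbors in this coloring, updating a block from a snapshot gives the same conditional distributions as updating its nodes one at a time.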

thrml.sample_single_block(key: Key[Array, ''], state_free: list[PyTree[Shaped[Array, 'nodes ?*state'], '_State']], clamp_state: list[PyTree[Shaped[Array, 'nodes ?*state'], '_State']], program: thrml.BlockSamplingProgram, block: int, sampler_state: ~_SamplerState, global_state: list[PyTree] | None = None) -> tuple[PyTree[Shaped[Array, 'nodes ?*state'], '_State'], ~_SamplerState] ¤

Sample a single block within a Gibbs sampling program, given the current states and the program configuration. The function extracts the neighboring states the block depends on, prepares the data its sampler requires, and applies the sampler to produce the block's new state.

Arguments:

  • key: Pseudo-random number generator key to ensure reproducibility of sampling.
  • state_free: Current states of free blocks, representing the values to be updated during sampling.
  • clamp_state: Clamped states that remain fixed during the sampling process.
  • program: The Gibbs sampling program containing specifications, samplers, neighborhood information, and parameters.
  • block: Index of the block to be sampled in the current iteration.
  • sampler_state: The current state of the sampler that will be used to perform the update.
  • global_state: Optionally precomputed global state for the concatenated free and clamped blocks; when omitted the function constructs it internally.

Returns:

  • Updated block state and sampler state for the specified block.

thrml.sample_with_observation(key: Key[Array, ''], program: thrml.BlockSamplingProgram, schedule: thrml.SamplingSchedule, init_chain_state: list[PyTree[Shaped[Array, 'nodes ?*state']]], state_clamp: list[PyTree[Shaped[Array, 'nodes ?*state'], '_State']], observation_carry_init: ~ObserveCarry, f_observe: thrml.AbstractObserver) -> tuple[~ObserveCarry, list[PyTree[Shaped[Array, 'n_samples nodes ?*state']]]] ¤

Run the full chain and call an Observer after every recorded sample.

Arguments:

  • key: RNG key.
  • program: The sampling program.
  • schedule: Warm-up length, number of samples, number of steps between samples.
  • init_chain_state: Initial free-block state.
  • state_clamp: Clamped-block state.
  • observation_carry_init: Initial carry handed to f_observe.
  • f_observe: Observer instance.

Returns:

  • Tuple (final_observer_carry, samples), where samples is a list of PyTrees whose arrays have a leading axis of size schedule.n_samples.
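
The carry-threading pattern described above can be sketched without thrml. In this hypothetical pure-Python version, the observer is a plain function (carry, state) -> (carry, observation), which is a simplification and not the signature of thrml.AbstractObserver; the sample timing is also an assumption.

```python
def run_with_observer(step, observe, state, carry,
                      n_warmup, n_samples, steps_per_sample):
    """Run a chain and call observe once per recorded sample.

    Returns (final_carry, observations). Assumption: steps_per_sample
    steps are taken before each observation.
    """
    for _ in range(n_warmup):
        state = step(state)
    observations = []
    for _ in range(n_samples):
        for _ in range(steps_per_sample):
            state = step(state)
        # The observer updates its carry and emits one record per sample.
        carry, out = observe(carry, state)
        observations.append(out)
    return carry, observations


def running_sum_observer(carry, state):
    """Accumulate the observed states in the carry; also emit each state."""
    return carry + state, state


final_carry, observations = run_with_observer(
    step=lambda s: s + 1,
    observe=running_sum_observer,
    state=0,
    carry=0,
    n_warmup=1,
    n_samples=3,
    steps_per_sample=2,
)
```

The carry is useful for statistics that should be accumulated across samples (for example, running moments), while the per-sample outputs are stacked and returned alongside it.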

thrml.sample_states(key: Key[Array, ''], program: thrml.BlockSamplingProgram, schedule: thrml.SamplingSchedule, init_state_free: list[PyTree[Shaped[Array, 'nodes ?*state']]], state_clamp: list[PyTree[Shaped[Array, 'nodes ?*state'], '_State']], nodes_to_sample: list[thrml.Block]) -> list[PyTree[Shaped[Array, 'n_samples nodes ?*state']]] ¤

Convenience wrapper to collect state information for nodes_to_sample only.

Internally builds a thrml.StateObserver, runs thrml.sample_with_observation, and returns the stacked samples as a list of PyTrees, each with shape (schedule.n_samples, ...).