Block Management
thrml.Block
A Block is the basic unit through which Gibbs sampling can operate.
Each block represents a collection of nodes that can efficiently be sampled simultaneously in a JAX-friendly SIMD manner. In THRML, this means that the nodes must all be of the same type.
Attributes:
nodes: the tuple of nodes that this block contains
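The same-type requirement can be illustrated with a small stand-in class. This is a minimal sketch, not thrml.Block itself: MiniBlock, SpinNode and CategoricalNode are hypothetical names, and the type check only shows the homogeneity rule described above.

```python
from dataclasses import dataclass


class SpinNode:
    """Hypothetical stand-in for a THRML node type."""


class CategoricalNode:
    """Another hypothetical node type."""


@dataclass(frozen=True)
class MiniBlock:
    """Hypothetical stand-in for thrml.Block: an immutable tuple of same-type nodes."""

    nodes: tuple

    def __post_init__(self):
        # Enforce homogeneity: all nodes must share one type so the block
        # can be updated as a single SIMD-style array operation.
        types = {type(n) for n in self.nodes}
        if len(types) > 1:
            raise TypeError(f"Block nodes must share one type, got {types}")

    def __len__(self):
        return len(self.nodes)


spins = MiniBlock(tuple(SpinNode() for _ in range(4)))
print(len(spins))  # 4

try:
    MiniBlock((SpinNode(), CategoricalNode()))
except TypeError as err:
    print("rejected:", err)
```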
thrml.BlockSpec
This contains the mappings needed for bookkeeping the indices of states and node types.
This helps convert between block states and global states. A block state is a list of pytrees, one per block, where every leaf has shape[0] equal to the number of nodes in that block; the length of the list is therefore the number of blocks. The global state is a flattened version of this: pytrees of the same type are combined (regardless of which block they come from), giving a list of pytrees in which every leaf has shape[0] equal to the total number of nodes with that pytree structure. As an example, consider an Ising model. Every node has the same pytree (a scalar array), so the block state is a list of arrays, one holding the state of each block, and the global state is a length-1 list containing a single array of shape (total_nodes,).
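The Ising case can be sketched with plain NumPy. This is an illustration of the layout only, assuming scalar int8 spins; it does not call THRML, and the concatenation order simply follows the block order.

```python
import numpy as np

# Ising-style block state: every node's state is one scalar, so all blocks
# share the same pytree structure (a bare array).
block_state = [
    np.array([1, -1], dtype=np.int8),      # block 0: 2 nodes
    np.array([-1, -1, 1], dtype=np.int8),  # block 1: 3 nodes
    np.array([1], dtype=np.int8),          # block 2: 1 node
]

# Global state: a length-1 list holding one array of shape (total_nodes,).
global_state = [np.concatenate(block_state, axis=0)]

print(len(global_state), global_state[0].shape)  # 1 (6,)
print(global_state[0])  # [ 1 -1 -1 -1  1  1]
```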
Why is this global/block representation necessary? The global representation is preferred for many JAX operations, but it requires careful indexing (to know where in the long array each block resides), so the block representation is more natural and easier to use for many users. Why is the global state easier to work with? Consider sampling: to sample a block (or even a single node), we need to collect the states of all neighboring nodes. With only the block state, we would have to loop over the blocks, collect the neighbors from each one, and pass them to the sampler. The sampler would in turn need to know the type of each block (to know how to interpret its states) and loop over the blocks again to collect each one. This works programmatically, but the extra for loops slow JAX down compared to gathering indices from a single array.
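The advantage of a single global array can be seen in a NumPy sketch of the neighbor gather. The neighbor table neighbor_idx is a hypothetical precomputed index array (not a THRML attribute); the point is that one fancy-indexing operation replaces a Python loop over blocks. jax.numpy arrays support the same integer-array indexing.

```python
import numpy as np

# Global state for six scalar-valued nodes, stored as one flat array.
global_spins = np.array([1, -1, -1, -1, 1, 1], dtype=np.int8)

# Hypothetical neighbor table: row i holds the global indices of the
# neighbors of the i-th node in the block being updated (two each here).
neighbor_idx = np.array([[1, 2], [0, 3], [4, 5]])

# One vectorized gather: the result has the same shape as neighbor_idx.
neighbor_states = global_spins[neighbor_idx]

# Example use: the local field each node sees, summed over its neighbors.
local_field = neighbor_states.sum(axis=1)
print(local_field)  # [-2  0  2]
```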
Attributes:
blocks: the list of blocks this spec contains.
all_block_sds: an SD is a single _PyTreeStruct. Each node/block has only one SD associated with it, but each node can have neighbors of many types. This is the list of SDs of each block, in the same order as blocks (this internal ordering is important for bookkeeping), and thus has length len(blocks).
global_sd_order: the list of SDs, providing a source of truth for the global ordering.
sd_index_map: a dictionary mapping each SD to its integer index in global_sd_order. This is like calling .index on it.
node_global_location_map: a dictionary mapping a given node to a tuple. That tuple contains the global index (i.e. which element of the global list the node is in) and the node's relative position within that pytree. That is, if loc = node_global_location_map[node], you can get the node's state via jax.tree.map(lambda x: x[loc[1]], global_repr[loc[0]]).
block_to_global_slice_spec: a list over unique SDs (so its length is len(global_sd_order)), where each inner list is the list of blocks with that pytree structure. For example, [[0, 1], [2]] indicates that blocks[0] and blocks[1] both have SD 0 and blocks[2] has SD 1.
node_shape_dtypes: a dictionary mapping node types to hashable _PyTreeStructs.
node_shape_struct: a dictionary mapping node types to pytrees of jax.ShapeDtypeStructs. This is provided only for user access, since unhashable entries like these cause issues for JAX in other areas.
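How these attributes relate to one another can be sketched in pure Python. This is a hypothetical reconstruction, not BlockSpec's code: blocks are lists of node labels, SDs are stand-in (shape, dtype) tuples instead of _PyTreeStruct, and global_sd_order uses first-seen order, which is an assumption about THRML's ordering.

```python
# Hypothetical stand-ins: each block is a list of node labels, and its SD is a
# hashable (shape, dtype) tuple playing the role of _PyTreeStruct.
blocks = [["a", "b"], ["c"], ["d", "e", "f"]]
all_block_sds = [((), "bool"), ((), "int8"), ((), "bool")]

# global_sd_order: unique SDs in first-seen order (assumed ordering rule).
global_sd_order = list(dict.fromkeys(all_block_sds))
sd_index_map = {sd: i for i, sd in enumerate(global_sd_order)}

# block_to_global_slice_spec: for each unique SD, the blocks that use it.
block_to_global_slice_spec = [[] for _ in global_sd_order]
for b, sd in enumerate(all_block_sds):
    block_to_global_slice_spec[sd_index_map[sd]].append(b)

# node_global_location_map: node -> (global list index, position in that entry).
node_global_location_map = {}
for sd_idx, block_ids in enumerate(block_to_global_slice_spec):
    offset = 0
    for b in block_ids:
        for pos, node in enumerate(blocks[b]):
            node_global_location_map[node] = (sd_idx, offset + pos)
        offset += len(blocks[b])

print(block_to_global_slice_spec)     # [[0, 2], [1]]
print(node_global_location_map["e"])  # (0, 3)
```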
__init__(blocks: list[thrml.Block], node_shape_dtypes: Mapping[type[thrml.AbstractNode], PyTree[jax._src.core.ShapeDtypeStruct]])
Create a BlockSpec from blocks.
Based on the information passed in via node_shape_dtypes, determine the minimal global state that can be used to represent the blocks.
Arguments:
blocks: the list of Blocks that this specification operates on.
node_shape_dtypes: the mapping of node types to their structures. Each value should be a pytree of jax.ShapeDtypeStructs.
thrml.block_state_to_global(block_state: list[PyTree[Shaped[Array, 'nodes ?*state'], 'State']], spec: thrml.BlockSpec) -> list[PyTree[Shaped[Array, 'nodes_global ?*state'], '_GlobalState']]
Convert block-local state to the global stacked representation.
The block representation is a list where block_state[i] contains the state of spec.blocks[i], with the block's nodes laid out along axis 0 of every leaf.
The global representation is a shorter list (one entry per distinct PyTree structure) in which all blocks with the same structure are concatenated along their node axis.
Arguments:
block_state: State organised per block, same length as spec.blocks.
spec: The thrml.BlockSpec that defines the mapping.
Returns:
A list whose length equals len(spec.global_sd_order): the stacked global state.
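The stacking step can be sketched for states that mix structures. This is a hypothetical re-implementation, not thrml.block_state_to_global: it takes the slice spec directly instead of a BlockSpec, and it only handles bare arrays or flat dicts of arrays (real pytrees would be handled with jax.tree.map).

```python
import numpy as np


def concat_trees(trees):
    """Concatenate same-structure trees (arrays or flat dicts of arrays) on axis 0."""
    first = trees[0]
    if isinstance(first, dict):
        return {k: np.concatenate([t[k] for t in trees], axis=0) for k in first}
    return np.concatenate(trees, axis=0)


def block_state_to_global_sketch(block_state, block_to_global_slice_spec):
    """Hypothetical sketch: one stacked entry per distinct structure."""
    return [
        concat_trees([block_state[b] for b in block_ids])
        for block_ids in block_to_global_slice_spec
    ]


# Blocks 0 and 2 hold scalar spins; block 1 holds a dict-shaped pytree.
block_state = [
    np.array([1, -1]),
    {"mean": np.zeros((3, 2)), "count": np.ones(3, dtype=np.int32)},
    np.array([1, 1, -1]),
]
slice_spec = [[0, 2], [1]]  # same convention as block_to_global_slice_spec

global_state = block_state_to_global_sketch(block_state, slice_spec)
print(len(global_state))               # 2
print(global_state[0])                 # [ 1 -1  1  1 -1]
print(global_state[1]["mean"].shape)   # (3, 2)
```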
thrml.get_node_locations(nodes: thrml.Block, spec: thrml.BlockSpec) -> tuple[int, Int[Array, 'nodes']]
Locate a contiguous set of nodes inside the global state.
Arguments:
nodes: A thrml.Block whose nodes you want locations for.
spec: The thrml.BlockSpec generated from the same graph.
Returns:
Tuple (sd_index, positions) where
- sd_index is the position inside the global list returned by thrml.block_state_to_global, and
- positions is a 1D array with the indices each node occupies inside that particular PyTree.
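The lookup can be sketched on top of a location map shaped like node_global_location_map. get_node_locations_sketch and the hand-written map are hypothetical; the sketch only shows how (sd_index, positions) index straight into the global list, and it assumes all requested nodes share one structure.

```python
import numpy as np

# Hypothetical table with the same meaning as
# BlockSpec.node_global_location_map: node -> (sd_index, position).
node_global_location_map = {
    "a": (0, 0), "b": (0, 1), "d": (0, 2), "e": (0, 3), "f": (0, 4),
    "c": (1, 0),
}


def get_node_locations_sketch(nodes, location_map):
    """Return (sd_index, positions) for nodes that share one structure."""
    locs = [location_map[n] for n in nodes]
    sd_indices = {sd for sd, _ in locs}
    if len(sd_indices) != 1:
        raise ValueError("all nodes must live in the same global entry")
    positions = np.array([pos for _, pos in locs], dtype=np.int32)
    return sd_indices.pop(), positions


sd_index, positions = get_node_locations_sketch(["e", "a", "f"], node_global_location_map)
print(sd_index, positions)  # 0 [3 0 4]

# The positions index directly into the matching global array.
global_state = [np.array([10, 11, 12, 13, 14]), np.array([20])]
print(global_state[sd_index][positions])  # [13 10 14]
```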
thrml.from_global_state(global_state: list[PyTree[Shaped[Array, 'nodes_global ?*state'], '_GlobalState']], spec_from: thrml.BlockSpec, blocks_to_extract: list[thrml.Block]) -> list[PyTree[Shaped[Array, 'nodes ?*state'], 'State']]
Extract the states for a subset of blocks from a global state.
Arguments:
global_state: A state produced by thrml.block_state_to_global using spec_from.
spec_from: The thrml.BlockSpec associated with global_state.
blocks_to_extract: The blocks whose node states should be returned.
Returns:
A list with one element per block in blocks_to_extract; each element is a PyTree with exactly len(block) nodes in its leading dimension.
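Extraction is the inverse gather. This hypothetical sketch (not thrml.from_global_state) works on bare arrays only, takes a location map in place of spec_from, and assumes every requested block is non-empty and uses a single structure.

```python
import numpy as np


def from_global_state_sketch(global_state, location_map, blocks_to_extract):
    """Hypothetical sketch: gather each requested block's rows from the global state."""
    out = []
    for block in blocks_to_extract:
        locs = [location_map[n] for n in block]
        sd_index = locs[0][0]  # assumes a non-empty block with one structure
        positions = np.array([pos for _, pos in locs])
        out.append(global_state[sd_index][positions])
    return out


global_state = [np.array([1, -1, 1, 1, -1]), np.array([7])]
location_map = {"a": (0, 0), "b": (0, 1), "d": (0, 2), "e": (0, 3), "f": (0, 4), "c": (1, 0)}

parts = from_global_state_sketch(global_state, location_map, [["d", "e"], ["c"]])
print([p.tolist() for p in parts])  # [[1, 1], [7]]
```

Each returned array has len(block) entries along its leading axis, matching the return contract above.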
thrml.make_empty_block_state(blocks: list[thrml.Block], node_shape_dtypes: Mapping[type[thrml.AbstractNode], PyTree[jax._src.core.ShapeDtypeStruct]], batch_shape: tuple | None = None) -> list[PyTree[Shaped[Array, 'nodes ?*state'], 'State']]
Allocate a zero-initialised block state.
Arguments:
blocks: All blocks in the graph (order is preserved).
node_shape_dtypes: Maps every node class to its jax.ShapeDtypeStruct PyTree template.
batch_shape: Optional batch dimension(s) to prepend to every leaf.
Returns:
A list of PyTrees, one per block, whose leaves are zeros(batch_shape + (len(block),) + leaf.shape).
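The leaf layout can be sketched with np.zeros. Everything here is a stand-in: SpinNode and GaussianNode are hypothetical node types, (shape, dtype) tuples replace jax.ShapeDtypeStruct, and a flat dict mimics a multi-leaf pytree. Blocks are assumed non-empty so the node type can be read from block[0].

```python
import numpy as np


class SpinNode:  # hypothetical node type with a scalar state
    pass


class GaussianNode:  # hypothetical node type with a two-leaf state
    pass


# Stand-in for node_shape_dtypes: (shape, dtype) tuples instead of
# jax.ShapeDtypeStruct, either bare or inside a dict to mimic a pytree.
node_shape_dtypes = {
    SpinNode: ((), np.bool_),
    GaussianNode: {"mu": ((2,), np.float32), "var": ((2,), np.float32)},
}


def make_empty_block_state_sketch(blocks, node_shape_dtypes, batch_shape=None):
    batch = tuple(batch_shape) if batch_shape is not None else ()
    state = []
    for block in blocks:
        template = node_shape_dtypes[type(block[0])]

        def alloc(leaf):
            shape, dtype = leaf
            # Layout: batch dims, then one row per node, then the per-node shape.
            return np.zeros(batch + (len(block),) + shape, dtype=dtype)

        if isinstance(template, dict):
            state.append({k: alloc(v) for k, v in template.items()})
        else:
            state.append(alloc(template))
    return state


blocks = [[SpinNode() for _ in range(3)], [GaussianNode() for _ in range(2)]]
state = make_empty_block_state_sketch(blocks, node_shape_dtypes, batch_shape=(4,))
print(state[0].shape, state[1]["mu"].shape)  # (4, 3) (4, 2, 2)
```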
thrml.verify_block_state(blocks: list[thrml.Block], states: list[PyTree[Shaped[Array, 'nodes ?*state'], 'State']], node_shape_dtypes: Mapping[type[thrml.AbstractNode], PyTree[jax._src.core.ShapeDtypeStruct]], block_axis: int | None = None) -> None
Check that a state is what it should be, given some blocks and node shapes/dtypes.
Passing incompatible state information into THRML functions can lead to unintended casting or other silent errors, so this should always be checked.
Arguments:
blocks: A list of Blocks.
states: A list of states to verify against blocks.
node_shape_dtypes: Maps every node class to its jax.ShapeDtypeStruct PyTree template.
block_axis: Index in the state batch shape at which to expect the block length.
Returns:
None. Raises RuntimeError if blocks and states are incompatible.
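The kinds of checks involved can be sketched for single-leaf states. This is a hypothetical checker, not thrml.verify_block_state: templates use (shape, dtype) tuples instead of jax.ShapeDtypeStruct, and the default for block_axis=None (node axis immediately before the per-node shape) is an assumption rather than THRML's documented behavior.

```python
import numpy as np


def verify_block_state_sketch(blocks, states, templates, block_axis=None):
    """Hypothetical checker for array-valued (single-leaf) states.

    templates maps a node type to a (shape, dtype) tuple. If block_axis is
    None, the node axis is ASSUMED to sit just before the per-node shape.
    """
    if len(blocks) != len(states):
        raise RuntimeError(f"{len(blocks)} blocks but {len(states)} states")
    for i, (block, state) in enumerate(zip(blocks, states)):
        shape, dtype = templates[type(block[0])]
        state = np.asarray(state)
        # Catch dtype mismatches that would otherwise be silently cast.
        if state.dtype != np.dtype(dtype):
            raise RuntimeError(f"block {i}: dtype {state.dtype}, expected {np.dtype(dtype)}")
        # Trailing dimensions must equal the per-node shape.
        if state.ndim < len(shape) + 1 or state.shape[state.ndim - len(shape):] != shape:
            raise RuntimeError(f"block {i}: shape {state.shape} does not end in {shape}")
        # The node axis must have exactly one entry per node in the block.
        axis = state.ndim - len(shape) - 1 if block_axis is None else block_axis
        if state.shape[axis] != len(block):
            raise RuntimeError(f"block {i}: axis {axis} has size {state.shape[axis]}, expected {len(block)}")


class SpinNode:  # hypothetical node type
    pass


templates = {SpinNode: ((), np.int8)}
blocks = [[SpinNode(), SpinNode()]]
verify_block_state_sketch(blocks, [np.zeros((5, 2), dtype=np.int8)], templates)  # passes

try:
    verify_block_state_sketch(blocks, [np.zeros(3, dtype=np.float32)], templates)
except RuntimeError as err:
    print(err)  # block 0: dtype float32, expected int8
```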