THRML

Sampling observers

Observers accumulate statistics over a chain as it runs, so you can read off moments or stored states without materializing every sample.

AbstractObserver (class)
AbstractObserver()

Interface for objects that inspect the sampling program while it is running.

A concrete observer is called once per block-sampling iteration and can maintain an arbitrary "carry" state across calls (e.g. running averages, histogram buffers, or log-probabilities).
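To make the carry idea concrete, here is a minimal plain-Python sketch of an observer that keeps a running mean of each node's value. The names `ObserverSketch`, `RunningMeanObserver`, `init` and `run_chain`, and the call signature `(state, carry) -> carry`, are hypothetical illustrations, not THRML's actual interface; the loop in `run_chain` stands in for the sampler invoking the observer once per iteration.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Protocol, Sequence


class ObserverSketch(Protocol):
    """Hypothetical observer interface (not THRML's actual signature)."""

    def init(self) -> object:
        """Return the initial carry before the first iteration."""
        ...

    def __call__(self, state: Sequence[float], carry: object) -> object:
        """Inspect the current state and return the updated carry."""
        ...


@dataclass
class RunningMeanObserver:
    """Tracks the running mean of each state entry across iterations."""

    n_nodes: int

    def init(self) -> tuple[int, list[float]]:
        # Carry = (number of calls so far, current mean per node).
        return 0, [0.0] * self.n_nodes

    def __call__(self, state, carry):
        count, mean = carry
        count += 1
        # Incremental mean update: m_k = m_{k-1} + (x_k - m_{k-1}) / k
        mean = [m + (x - m) / count for m, x in zip(mean, state)]
        return count, mean


def run_chain(states, observer):
    """Feed each sampled state to the observer, threading the carry through."""
    carry = observer.init()
    for state in states:
        carry = observer(state, carry)
    return carry


count, mean = run_chain([[1.0, -1.0], [1.0, 1.0], [-1.0, 1.0]], RunningMeanObserver(2))
# count == 3; both entries of mean are approximately 1/3
```

Keeping the evolving bookkeeping in the returned carry, rather than in mutable attributes on the observer, leaves the observer holding only configuration; that style fits functional loops such as `jax.lax.scan`, which thread a carry through each step.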

StateObserver (class)
StateObserver(blocks_to_sample: list[Block])

Observer that logs the raw state of a set of nodes.

Attributes:

  • blocks_to_sample: the list of Blocks whose states are logged
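A hypothetical sketch of what logging raw states for a subset of blocks amounts to, in plain Python. Here a sampler state is assumed to be a dict from block name to a list of node values, and `StateLoggerSketch` is an illustrative name, not THRML's `StateObserver` implementation; blocks not listed in `blocks_to_sample` are simply never recorded.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class StateLoggerSketch:
    """Hypothetical stand-in for StateObserver: logs raw states of selected blocks."""

    blocks_to_sample: list[str]  # block names, standing in for Block objects

    def init(self):
        # Carry = one list of snapshots per logged block.
        return {name: [] for name in self.blocks_to_sample}

    def __call__(self, state: dict[str, list[int]], carry):
        for name in self.blocks_to_sample:
            # Copy so later in-place updates to the state don't alter the log.
            carry[name].append(list(state[name]))
        return carry


obs = StateLoggerSketch(["visible"])
carry = obs.init()
for state in [{"visible": [1, 0], "hidden": [1]}, {"visible": [0, 0], "hidden": [0]}]:
    carry = obs(state, carry)
# carry == {"visible": [[1, 0], [0, 0]]}
```

Unlike a moment accumulator, the memory used by this kind of observer grows linearly with the number of iterations, so it is best restricted to the blocks you actually need.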

MomentAccumulatorObserver (class)
MomentAccumulatorObserver(moment_spec: Sequence[Sequence[Sequence[AbstractNode]]], f_transform: Callable = _f_identity)

Observer that accumulates and updates the provided moments.

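It does not log any samples; it only accumulates moments. Note that this observer does not divide the accumulated values by the number of times it was called, so to obtain empirical means you must divide by that count yourself. It simply records a running sum of products of state variables,

$$\sum_i f(x_1^i) f(x_2^i) \dots f(x_N^i)$$

where $i$ indexes the observer calls and $x_1, \dots, x_N$ are the nodes that make up one moment.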
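A small worked sketch of the accumulated quantity for a single moment. Samples are assumed to be dicts from node label to value, and `accumulate_moment` is a hypothetical helper rather than part of THRML; it returns the raw running sum, so dividing by the number of samples is left to the caller, mirroring the behavior described above.

```python
import math


def accumulate_moment(samples, moment, f=lambda x: x):
    """Running sum over samples i of prod_k f(x_k^i) for the nodes in `moment`.

    samples: list of dicts mapping node label -> value (one dict per call).
    moment: tuple of node labels (x_1, ..., x_N).
    f: element-wise transform applied to each value before multiplying.
    """
    total = 0.0
    for sample in samples:
        total += math.prod(f(sample[node]) for node in moment)
    return total


samples = [{"a": 1, "b": -1}, {"a": 1, "b": 1}, {"a": -1, "b": 1}]
raw = accumulate_moment(samples, ("a", "b"))  # -1.0 (unnormalized sum)
mean_ab = raw / len(samples)                  # empirical E[ab], about -0.333
# Mapping spins {-1, 1} to {0, 1} before multiplying changes the sum to 1.0:
as_bits = accumulate_moment(samples, ("a", "b"), f=lambda x: (x + 1) / 2)
```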

Attributes:

  • blocks_to_sample: the blocks to accumulate the moments over. These are used to construct the final state and are not true "blocks" in the algorithmic sense (their nodes may be connected to each other). There is one block per node type.
  • flat_nodes_list: a list of all the nodes that appear in the moments, each occurring exactly once (so len(set(x)) == len(x)).
  • flat_to_type_slices_list: a list over node types in which each element is an array of indices into flat_nodes_list selecting the nodes of that type.
  • flat_to_full_moment_slices: a list over moment types in which element i is a 2D array with the same shape as moment_spec[i], whose entries are indices into flat_nodes_list.
  • f_transform: the element-wise transformation $f$ to apply to sample values before accumulation.
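The three index attributes can be illustrated by rebuilding them from a `moment_spec`. In this hypothetical sketch nodes are plain string labels, `node_type` is an assumed function mapping a node to its type, and nested lists stand in for arrays; it mirrors the attribute descriptions above but is not THRML's actual construction code.

```python
def build_moment_indices(moment_spec, node_type):
    """Hypothetical reconstruction of the index bookkeeping described above.

    moment_spec: list over moment types; each entry is a list of moments,
        and each moment is a tuple of nodes (here, hashable labels).
    node_type: function mapping a node to a hashable type label.
    """
    # 1. Deduplicate nodes, keeping first-seen order.
    flat_nodes_list = []
    position = {}
    for moments in moment_spec:
        for moment in moments:
            for node in moment:
                if node not in position:
                    position[node] = len(flat_nodes_list)
                    flat_nodes_list.append(node)

    # 2. Group flat indices by node type (dicts keep insertion order).
    groups = {}
    for idx, node in enumerate(flat_nodes_list):
        groups.setdefault(node_type(node), []).append(idx)
    flat_to_type_slices_list = list(groups.values())

    # 3. Replace every node in moment_spec by its flat index, keeping the shape.
    flat_to_full_moment_slices = [
        [[position[node] for node in moment] for moment in moments]
        for moments in moment_spec
    ]
    return flat_nodes_list, flat_to_type_slices_list, flat_to_full_moment_slices


spec = [
    [("s0",), ("s1",), ("c0",)],   # moment type 0: first moments
    [("s0", "s1"), ("s1", "c0")],  # moment type 1: pairwise moments
]
nodes, type_slices, moment_slices = build_moment_indices(spec, node_type=lambda n: n[0])
# nodes         == ["s0", "s1", "c0"]
# type_slices   == [[0, 1], [2]]
# moment_slices == [[[0], [1], [2]], [[0, 1], [1, 2]]]
```

With a deduplicated flat list, each node's value only needs to be read and passed through $f$ once per call, after which every moment can be formed by indexing; this is a plausible motivation for the layout rather than a documented one.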