Conditional Samplers
thrml.AbstractConditionalSampler
Base class for all conditional samplers.
A conditional sampler is used to update the state of a block of nodes during each iteration of a sampling algorithm. It takes in the states of all the neighbors and produces a sample for the current block of nodes. This sample can often be drawn exactly from the conditional distribution, but it need not be: one could embed MCMC methods within this sampler (for example, to perform Metropolis within Gibbs).
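As a hedged sketch of an inexact update of the kind mentioned above, the hypothetical helper `metropolis_spin_update` (not part of thrml) performs one Metropolis-within-Gibbs step for a block of \(\pm 1\) nodes. It assumes each node's conditional has the form \(\mathbb{P}(S = s) \propto e^{\gamma s}\), with \(\gamma\) already computed from the neighbors. Each spin proposes a flip and accepts it with probability \(\min(1, e^{-2\gamma s})\), which leaves that conditional invariant.

```python
import jax
import jax.numpy as jnp


def metropolis_spin_update(key, spins, gamma):
    """One Metropolis step on a block of +/-1 spins (hypothetical helper).

    Assumes each spin's conditional is P(S = s) proportional to exp(gamma * s),
    with one precomputed gamma per node. The step leaves that conditional
    invariant but is not an exact draw from it.
    """
    # Propose s -> -s for every spin; log of the ratio P(-s) / P(s) is -2 * gamma * s.
    log_accept = -2.0 * gamma * spins
    u = jax.random.uniform(key, spins.shape)
    accept = jnp.log(u) < jnp.minimum(0.0, log_accept)
    return jnp.where(accept, -spins, spins)


# Many independent chains for a single node with gamma = 0.5, all started at -1.
n_chains = 20_000
gamma = jnp.full((n_chains,), 0.5)
spins = -jnp.ones((n_chains,))
key = jax.random.PRNGKey(0)
for _ in range(50):
    key, subkey = jax.random.split(key)
    spins = metropolis_spin_update(subkey, spins, gamma)

# Under P(S = s) proportional to exp(gamma * s), the mean of S is tanh(gamma).
print(float(spins.mean()), float(jnp.tanh(0.5)))
```

Because the kernel only leaves the conditional invariant, a single call is not an exact draw; with \(\gamma = 0\), for instance, every flip is accepted and each chain simply alternates.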
sample(key: Key, interactions: list[PyTree], active_flags: list[Array], states: list[list[PyTree[Shaped[Array, 'nodes ?*state'], 'State']]], sampler_state: ~_SamplerState, output_sd: PyTree[jax._src.core.ShapeDtypeStruct]) -> tuple[PyTree[Shaped[Array, 'nodes ?*state'], 'State'], ~_SamplerState]
Draw a sample from this conditional.
If this sampler is involved in a block sampling program, this function is called every iteration to update the state of a block of nodes.
Arguments:
key: An RNG key that the sampler can use to sample from distributions using jax.random.
interactions: A list of interactions that influence the result of this block update. Each interaction is a PyTree. Each array in the PyTree will have shape [n, k, ...], where n is the number of nodes in the block being updated and k is the maximum number of times any node in this block was detected as a head node for this interaction.
active_flags: A list of arrays of flags, parallel to interactions. Each array indicates which instances of a given interaction are active for each node in the block. It has shape [n, k] and is False where an instance is inactive, meaning that the instance should be ignored during the computation performed by this function.
states: A list of PyTrees, parallel to interactions, holding the sampling-state information relevant to computing the influence of each interaction. Every array in each PyTree will have shape [n, k, ...].
sampler_state: The current state of this sampler. It will be replaced by the second return value of this function the next time it is called.
output_sd: A PyTree indicating the expected shape and dtype of the output of this function.
Returns:
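A tuple whose first element is the new state for the block of nodes, matching the template given by output_sd, and whose second element is the updated sampler state, which is passed back in as sampler_state on the next call.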
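To make the argument shapes concrete, here is a minimal sketch of a sample-style function under illustrative assumptions that are not taken from thrml: each interaction is a dict with a `"weight"` array of shape [n, k], each entry of `states` is a list holding one neighbor-state array of shape [n, k], and the block's nodes are Gaussian with unit variance and a mean equal to the weighted sum of neighbor states. The function `toy_gaussian_sample` is hypothetical. The key step is masking with `active_flags` before reducing over the k axis.

```python
import jax
import jax.numpy as jnp


def toy_gaussian_sample(key, interactions, active_flags, states, sampler_state, output_sd):
    """Hypothetical sampler body illustrating the argument layout (not thrml code).

    Assumed layout: interactions[i] = {"weight": [n, k]}, states[i] = [neighbor_state [n, k]].
    Each node is drawn from N(mean, 1), where mean sums weight * neighbor state
    over the *active* interaction instances only.
    """
    mean = jnp.zeros(output_sd.shape, output_sd.dtype)
    for interaction, active, state in zip(interactions, active_flags, states):
        contrib = interaction["weight"] * state[0]      # [n, k]
        contrib = jnp.where(active, contrib, 0.0)        # drop inactive instances
        mean = mean + contrib.sum(axis=1).astype(output_sd.dtype)  # reduce over k -> [n]
    noise = jax.random.normal(key, output_sd.shape, output_sd.dtype)
    # New block state (matches output_sd) and the unchanged sampler state.
    return mean + noise, sampler_state


# Block of n = 3 nodes, one interaction type with k = 2 slots per node.
interactions = [{"weight": jnp.array([[1.0, 1e6], [2.0, 0.5], [0.0, 0.0]])}]
active_flags = [jnp.array([[True, False], [True, True], [False, False]])]
states = [[jnp.array([[1.0, 1.0], [-1.0, 1.0], [1.0, 1.0]])]]
output_sd = jax.ShapeDtypeStruct((3,), jnp.float32)

new_state, new_sampler_state = toy_gaussian_sample(
    jax.random.PRNGKey(0), interactions, active_flags, states, None, output_sd
)
print(new_state)
```

Selecting with `jnp.where` rather than multiplying by the flags also keeps padded, inactive slots from leaking non-finite values into the sum.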
thrml.AbstractParametricConditionalSampler(thrml.AbstractConditionalSampler)
A conditional sampler that leverages a parameterized distribution.
When sample is called, this sampler first computes a set of parameters and then uses those parameters
to draw a sample from the distribution they define. This two-stage workflow is useful in many practical cases; for example, to
sample from a Gaussian, we can first compute a mean vector and covariance matrix using any procedure, and then
draw a sample from the corresponding Gaussian distribution by appropriately transforming a vector of standard
normal random variables.
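The Gaussian example above can be sketched in JAX as follows. The parameters here are fixed constants standing in for whatever procedure computes them, and `sample_gaussian` is a hypothetical helper. It uses the Cholesky factor \(L\) of the covariance (\(\Sigma = L L^\top\)): if \(z \sim \mathcal{N}(0, I)\), then \(\mu + L z \sim \mathcal{N}(\mu, \Sigma)\).

```python
import jax
import jax.numpy as jnp


def sample_gaussian(key, mean, cov, num_samples):
    """Draw samples from N(mean, cov) by transforming standard normals (hypothetical helper)."""
    # Stage 2: with cov = L @ L.T, mean + L @ z ~ N(mean, cov) when z ~ N(0, I).
    chol = jnp.linalg.cholesky(cov)                        # lower-triangular L
    z = jax.random.normal(key, (num_samples, mean.shape[0]))
    return mean + z @ chol.T                               # each row is mean + L @ z_row


# Stage 1 stand-in: parameters "computed by any procedure".
mean = jnp.array([1.0, -2.0])
cov = jnp.array([[2.0, 0.6], [0.6, 1.0]])

samples = sample_gaussian(jax.random.PRNGKey(0), mean, cov, 50_000)
emp_mean = samples.mean(axis=0)
emp_cov = jnp.cov(samples, rowvar=False)
print(emp_mean, emp_cov)
```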
compute_parameters(key: Key, interactions: list[PyTree], active_flags: list[Array], states: list[list[PyTree[Shaped[Array, 'nodes ?*state'], 'State']]], sampler_state: PyTree, output_sd: PyTree[jax._src.core.ShapeDtypeStruct]) -> PyTree
Compute the parameters of the distribution. For a description of the arguments, see
thrml.AbstractConditionalSampler.sample.
thrml.BernoulliConditional(thrml.AbstractParametricConditionalSampler)
Sample from a Bernoulli distribution.
This sampler is designed to sample from a spin-valued Bernoulli distribution:
\[ \mathbb{P}(S = s) \propto e^{\gamma s}, \]
where \(S\) is a spin-valued random variable and \(s \in \{-1, 1\}\). The parameter \(\gamma\) must be
computed by compute_parameters.
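Assuming the spin-valued distribution has the form \(\mathbb{P}(S = s) \propto e^{\gamma s}\), normalizing over \(s \in \{-1, 1\}\) gives \(\mathbb{P}(S = 1) = e^{\gamma} / (e^{\gamma} + e^{-\gamma}) = \sigma(2\gamma)\), where \(\sigma\) is the logistic sigmoid, and \(\mathbb{E}[S] = \tanh(\gamma)\). A minimal sketch of the sampling step, with `sample_spins` as a hypothetical helper rather than thrml's implementation, draws a Boolean with that probability and maps it to \(\pm 1\):

```python
import jax
import jax.numpy as jnp


def sample_spins(key, gamma):
    """Sample S in {-1, +1} elementwise with P(S = s) proportional to exp(gamma * s).

    Hypothetical helper; gamma is assumed to be already computed for each node.
    """
    # Normalizing: P(S = +1) = exp(gamma) / (exp(gamma) + exp(-gamma)) = sigmoid(2 * gamma).
    p_up = jax.nn.sigmoid(2.0 * gamma)
    up = jax.random.bernoulli(key, p_up)   # Boolean array with the same shape as gamma
    return jnp.where(up, 1.0, -1.0)


gamma = jnp.array([-1.0, 0.0, 0.75])
# 100,000 independent draws for each of the three gamma values.
draws = sample_spins(jax.random.PRNGKey(0), jnp.broadcast_to(gamma, (100_000, 3)))
print(draws.mean(axis=0), jnp.tanh(gamma))
```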
thrml.SoftmaxConditional(thrml.AbstractParametricConditionalSampler)
Sample from a softmax distribution.
This sampler samples from the standard softmax distribution:
\[ \mathbb{P}(X = k) \propto e^{\theta_k}, \]
where \(X\) is a categorical random variable and \(\theta\) is a vector that parameterizes the relative probabilities of each of the categories.
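Assuming the form \(\mathbb{P}(X = k) \propto e^{\theta_k}\), `jax.random.categorical` draws exactly from this distribution, because it interprets its input as unnormalized log-probabilities (logits). In the sketch below, `sample_softmax` is a hypothetical helper, not thrml's implementation; it draws one category per node for a block whose parameters have shape [n, K].

```python
import jax
import jax.numpy as jnp


def sample_softmax(key, theta):
    """Draw one category per node with P(X = k) proportional to exp(theta[..., k]).

    Hypothetical helper; theta has shape [n, K], one row of logits per node.
    """
    # categorical samples along the last axis from softmax(theta).
    return jax.random.categorical(key, theta, axis=-1)


theta = jnp.array([0.0, 1.0, 2.0])
block = jnp.broadcast_to(theta, (100_000, 3))   # 100,000 nodes sharing the same theta
draws = sample_softmax(jax.random.PRNGKey(0), block)
freqs = jnp.bincount(draws, length=3) / draws.shape[0]
print(freqs, jax.nn.softmax(theta))
```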