
Conditional Samplers

Conditional samplers produce new values for a block given the current state of all other variables. BernoulliConditional handles spin nodes; SoftmaxConditional handles categorical nodes.

hamon.AbstractConditionalSampler

Base class for all conditional samplers.

A conditional sampler updates the state of a block of nodes during each iteration of a sampling algorithm. Given the current states of the block's neighbors, it produces a new sample for the block. This sample is often drawn exactly from the conditional distribution, but it need not be: one could embed an MCMC method within the sampler (for example, to do Metropolis-within-Gibbs).
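As a minimal sketch of the inexact case, the JAX function below performs one Metropolis step for a block of spins whose exact conditional would be \(\mathbb{P}(S=s) \propto e^{\gamma s}\) with \(s \in \{-1, 1\}\). It proposes flipping every spin and accepts each flip with probability \(\min(1, e^{-2\gamma s})\). The function name and the way \(\gamma\) and the spins are passed in are hypothetical illustrations, not part of hamon's API.

```python
import jax
import jax.numpy as jnp


def metropolis_flip(key, gamma, spins):
    """One Metropolis update for a block of spins (hypothetical helper).

    gamma: float array of shape [b].
    spins: boolean array of shape [b] (True = +1, False = -1).
    """
    s = jnp.where(spins, 1.0, -1.0)
    # Proposing s' = -s is symmetric, so the acceptance ratio is
    # exp(gamma * (s' - s)) = exp(-2 * gamma * s).
    log_accept = -2.0 * gamma * s
    u = jax.random.uniform(key, shape=spins.shape)
    accept = jnp.log(u) < log_accept
    # Flip the spins whose proposal was accepted; keep the rest unchanged.
    return jnp.where(accept, ~spins, spins)
```

This step leaves the target conditional invariant without sampling from it exactly, which is the kind of approximate update the paragraph above allows.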

__init__() -> None

Construct the sampler; the constructor takes no arguments.

init() -> None

Initialize the sampler state before sampling begins.

This is called once, before the first iteration of block sampling. Its return value is passed as sampler_state to the first call of sample; on every later iteration, the sampler state returned by the previous call to sample is used instead.

Returns:

  • None: the initial sampler state to use for the first iteration of block sampling.
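The following is a minimal sketch of how a driver might thread the sampler state through iterations. CountingSampler, its simplified sample signature, and the run loop are hypothetical stand-ins, not hamon's actual block-sampling program; the toy state is just an iteration counter.

```python
class CountingSampler:
    """Toy sampler whose state counts how many times sample was called."""

    def init(self):
        # Initial sampler state, used only for the first call to sample.
        return 0

    def sample(self, block_state, sampler_state):
        # Simplified, hypothetical signature: return the new block state
        # and the updated sampler state.
        return block_state, sampler_state + 1


def run(sampler, block_state, n_iters):
    # init is called once, before the first iteration.
    sampler_state = sampler.init()
    for _ in range(n_iters):
        # The sampler state returned here supersedes the previous one.
        block_state, sampler_state = sampler.sample(block_state, sampler_state)
    return block_state, sampler_state
```

For example, run(CountingSampler(), "x", 5) returns ("x", 5): the value from init is only seen by the first call.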

sample(key: Key, interactions: list[PyTree], active_flags: list[Array], states: list[list[PyTree[Shaped[Array, 'nodes ?*state'], State]]], sampler_state: ~_SamplerState, output_sd: PyTree[jax.ShapeDtypeStruct]) -> tuple[PyTree[Shaped[Array, 'nodes ?*state'], State], ~_SamplerState]

Draw a sample from this conditional.

If this sampler is involved in a block sampling program, this function is called every iteration to update the state of a block of nodes.

Arguments:

  • key: A JAX PRNG key that the sampler can use to draw from distributions via jax.random.
  • interactions: A list of interactions that influence the result of this block update. Each interaction is a PyTree. Each array in the PyTree has shape [n, k, ...], where n is the number of nodes in the block being updated and k is the maximum number of times any node in this block was detected as a head node for this interaction.
  • active_flags: A list of boolean arrays, parallel to interactions. Each array has shape [n, k] and indicates which instances of the corresponding interaction are active for each node in the block. An entry is False if that instance is inactive, meaning it should be ignored in the computation performed by this function.
  • states: A list of PyTrees, parallel to interactions, holding the sampling state information needed to compute the influence of each interaction. Every array in each PyTree has shape [n, k, ...].
  • sampler_state: The current state of this sampler. On the next call, it is replaced by the second element of this function's return value.
  • output_sd: A PyTree of jax.ShapeDtypeStruct describing the expected shape and dtype of this function's output.

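Returns:

A tuple whose first element is the new state for the block of nodes, matching the template given by output_sd, and whose second element is the updated sampler state for the next call.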
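To make the return contract concrete, here is a minimal sketch of a sample-like function that ignores its neighbors and draws uniformly random values whose shapes and dtypes follow output_sd, returning the sampler state unchanged. It is a standalone function with the same argument order as sample, not a hamon subclass, and the name random_sample is hypothetical.

```python
import jax
import jax.numpy as jnp


def random_sample(key, interactions, active_flags, states, sampler_state, output_sd):
    """Hypothetical sampler body: fair coin flips shaped like output_sd."""
    # jax.ShapeDtypeStruct is not a registered pytree, so each one is a leaf.
    leaves, treedef = jax.tree_util.tree_flatten(output_sd)
    keys = jax.random.split(key, len(leaves))
    new_leaves = [
        # Draw booleans with the requested shape, then cast to the requested dtype.
        jax.random.bernoulli(k, 0.5, shape=sd.shape).astype(sd.dtype)
        for k, sd in zip(keys, leaves)
    ]
    # Rebuild the same PyTree structure and pass the sampler state through.
    return jax.tree_util.tree_unflatten(treedef, new_leaves), sampler_state
```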

hamon.AbstractParametricConditionalSampler

A conditional sampler that leverages a parameterized distribution.

When sample is called, this sampler first computes a set of parameters and then uses them to draw a sample from some distribution. This two-step workflow is common in practice. For example, to sample from a Gaussian, one can first compute a mean vector and covariance matrix by any procedure, and then draw a sample from the corresponding Gaussian by transforming a vector of standard normal random variables (for instance, multiplying it by a Cholesky factor of the covariance and adding the mean).
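A minimal JAX sketch of the Gaussian example's second step, assuming the mean and covariance have already been computed. It uses a Cholesky factor as the transformation, which is one standard choice; the function name is illustrative and not part of hamon.

```python
import jax
import jax.numpy as jnp


def sample_gaussian(key, mean, cov, num_samples):
    """Draw samples from N(mean, cov) by transforming standard normals."""
    # Factor cov = L @ L.T (cov must be symmetric positive definite).
    chol = jnp.linalg.cholesky(cov)
    # z ~ N(0, I) with shape [num_samples, d].
    z = jax.random.normal(key, shape=(num_samples, mean.shape[0]))
    # Each row x = mean + L z has mean `mean` and covariance L L^T = cov.
    return mean + z @ chol.T
```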

init() -> None

Initialize the sampler state before sampling begins.

This is called once, before the first iteration of block sampling. Its return value is passed as sampler_state to the first call of sample; on every later iteration, the sampler state returned by the previous call to sample is used instead.

Returns:

  • None: the initial sampler state to use for the first iteration of block sampling.

__init__() -> None

Construct the sampler; the constructor takes no arguments.

compute_parameters(key: Key, interactions: list[PyTree], active_flags: list[Array], states: list[list[PyTree[Shaped[Array, 'nodes ?*state'], State]]], sampler_state: PyTree, output_sd: PyTree[jax.ShapeDtypeStruct]) -> PyTree

Compute the parameters of the distribution. The returned PyTree is passed to sample_given_parameters as its parameters argument. For a description of the arguments, see hamon.AbstractConditionalSampler.sample.

sample(key: Key, interactions: list[PyTree], active_flags: list[Array], states: list[list[PyTree[Shaped[Array, 'nodes ?*state'], State]]], sampler_state: ~_SamplerState, output_sd: PyTree[jax.ShapeDtypeStruct]) -> tuple[PyTree[Shaped[Array, 'nodes ?*state'], State], ~_SamplerState]

Sample from the distribution by first computing the parameters and then generating a sample based on them.

sample_given_parameters(key: Key, parameters: PyTree, sampler_state: ~_SamplerState, output_sd: PyTree[jax.ShapeDtypeStruct]) -> tuple[PyTree[Shaped[Array, 'nodes ?*state'], State], ~_SamplerState]

Produce a sample given the parameters of the distribution, passed in as the parameters argument.
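The split between compute_parameters and sample_given_parameters can be sketched with a hypothetical stand-in class. It does not subclass hamon.AbstractParametricConditionalSampler, and its simplified signatures (a single fields array instead of interactions, active_flags, and states) are illustrative assumptions; it only mirrors how sample composes the two steps.

```python
import jax
import jax.numpy as jnp


class ToyGaussianConditional:
    """Hypothetical stand-in mirroring the parametric-sampler split."""

    def init(self):
        return None

    def compute_parameters(self, key, fields):
        # Step 1: turn neighbor information into distribution parameters.
        # Here (hypothetically) each node's mean is the sum of its fields.
        return {"mean": jnp.sum(fields, axis=1), "std": 1.0}

    def sample_given_parameters(self, key, parameters, sampler_state):
        # Step 2: draw from the distribution those parameters define.
        z = jax.random.normal(key, shape=parameters["mean"].shape)
        return parameters["mean"] + parameters["std"] * z, sampler_state

    def sample(self, key, fields, sampler_state):
        # Split the key so the two steps use independent randomness.
        key_params, key_sample = jax.random.split(key)
        params = self.compute_parameters(key_params, fields)
        return self.sample_given_parameters(key_sample, params, sampler_state)
```

Keeping the parameter computation separate lets the same sample_given_parameters serve any model that produces parameters of the same form.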

hamon.BernoulliConditional

Sample from a Bernoulli distribution.

This sampler is designed to sample from a spin-valued Bernoulli distribution:

\[\mathbb{P}(S=s) \propto e^{\gamma s}\]

where \(S\) is a spin-valued random variable and \(s \in \{-1, 1\}\). The parameter \(\gamma\) must be computed by a concrete implementation of compute_parameters.
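Normalizing over the two spin values gives explicit probabilities. This short derivation is standard algebra rather than a statement from the library, and it shows that \(\gamma\) acts as a logit scaled by 2, where \(\sigma\) is the logistic sigmoid:

\[\mathbb{P}(S=1) = \frac{e^{\gamma}}{e^{\gamma} + e^{-\gamma}} = \frac{1}{1 + e^{-2\gamma}} = \sigma(2\gamma), \qquad \mathbb{P}(S=-1) = \sigma(-2\gamma), \qquad \mathbb{E}[S] = \sigma(2\gamma) - \sigma(-2\gamma) = \tanh(\gamma).\]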

init() -> None

Initialize the sampler state before sampling begins.

This is called once, before the first iteration of block sampling. Its return value is passed as sampler_state to the first call of sample; on every later iteration, the sampler state returned by the previous call to sample is used instead.

Returns:

  • None: the initial sampler state to use for the first iteration of block sampling.

sample(key: Key, interactions: list[PyTree], active_flags: list[Array], states: list[list[PyTree[Shaped[Array, 'nodes ?*state'], State]]], sampler_state: ~_SamplerState, output_sd: PyTree[jax.ShapeDtypeStruct]) -> tuple[PyTree[Shaped[Array, 'nodes ?*state'], State], ~_SamplerState]

Sample from the distribution by first computing the parameters and then generating a sample based on them.

__init__() -> None

Construct the sampler; the constructor takes no arguments.

compute_parameters(key: Key, interactions: list[PyTree], active_flags: list[Array], states: list[list[PyTree[Shaped[Array, 'nodes ?*state'], State]]], sampler_state: None, output_sd: PyTree[jax.ShapeDtypeStruct]) -> PyTree

A concrete implementation of this function must return a value of \(\gamma\) for every node in the block being updated, as an array of shape [b], where b is the number of nodes in the block.
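For an Ising-style model, \(\gamma\) for each node is its bias plus the coupling-weighted sum of its neighbors' spins. The sketch below computes that, masking inactive instances as described for active_flags. The flat argument layout (bias, weights, active, neighbor_spins) is a hypothetical simplification, not hamon's actual interaction or state format.

```python
import jax.numpy as jnp


def ising_gamma(bias, weights, active, neighbor_spins):
    """Hypothetical compute_parameters body for an Ising-style model.

    Assumed layout (illustrative only):
      bias:           [b]     external field on each node in the block
      weights:        [b, k]  coupling to each neighbor instance
      active:         [b, k]  False for inactive instances
      neighbor_spins: [b, k]  boolean neighbor states (True = +1, False = -1)
    Returns gamma with shape [b].
    """
    s = jnp.where(neighbor_spins, 1.0, -1.0)
    # Inactive instances contribute nothing to the local field.
    contrib = jnp.where(active, weights * s, 0.0)
    return bias + jnp.sum(contrib, axis=1)
```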

sample_given_parameters(key: Key, parameters: PyTree, sampler_state: None, output_sd: PyTree[jax.ShapeDtypeStruct]) -> tuple[PyTree[Shaped[Array, 'nodes ?*state'], State], None]

Sample from a spin-valued Bernoulli distribution given the parameter \(\gamma\). In hamon, the spin value 1 is represented by the boolean True and -1 by False.
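Because \(\mathbb{P}(S=1) = \sigma(2\gamma)\), a spin sample can be drawn with jax.random.bernoulli using success probability \(\sigma(2\gamma)\), which yields booleans in the True = +1 convention directly. This is a minimal sketch under that reading, not hamon's implementation; the helper names are hypothetical.

```python
import jax
import jax.numpy as jnp


def sample_spins(key, gamma):
    """Sample spins with P(S=+1) = exp(g) / (exp(g) + exp(-g)) = sigmoid(2g).

    Returns a boolean array shaped like gamma (True = +1, False = -1).
    """
    return jax.random.bernoulli(key, p=jax.nn.sigmoid(2.0 * gamma))


def to_plus_minus_one(spins):
    # Convert the boolean representation back to the values {-1, +1}.
    return jnp.where(spins, 1, -1)
```

Averaging to_plus_minus_one over many samples should approach \(\tanh(\gamma)\).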

hamon.SoftmaxConditional

Sample from a softmax distribution.

This sampler samples from the standard softmax distribution:

\[\mathbb{P}(X=k) \propto e^{\theta_k}\]

where \(X\) is a categorical random variable and \(\theta\) is a vector that parameterizes the relative probabilities of each of the categories.
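Written out with its normalizer, the distribution below is standard softmax algebra (not a statement taken from the library). Adding the same constant to every \(\theta_k\) leaves the probabilities unchanged, and with two categories and \(\theta = (-\gamma, \gamma)\) the second category has probability \(e^{\gamma} / (e^{-\gamma} + e^{\gamma})\), the same spin distribution used by BernoulliConditional:

\[\mathbb{P}(X=k) = \frac{e^{\theta_k}}{\sum_{j} e^{\theta_j}}, \qquad \frac{e^{\theta_k + c}}{\sum_{j} e^{\theta_j + c}} = \frac{e^{\theta_k}}{\sum_{j} e^{\theta_j}} \quad \text{for any constant } c.\]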

init() -> None

Initialize the sampler state before sampling begins.

This is called once, before the first iteration of block sampling. Its return value is passed as sampler_state to the first call of sample; on every later iteration, the sampler state returned by the previous call to sample is used instead.

Returns:

  • None: the initial sampler state to use for the first iteration of block sampling.

sample(key: Key, interactions: list[PyTree], active_flags: list[Array], states: list[list[PyTree[Shaped[Array, 'nodes ?*state'], State]]], sampler_state: ~_SamplerState, output_sd: PyTree[jax.ShapeDtypeStruct]) -> tuple[PyTree[Shaped[Array, 'nodes ?*state'], State], ~_SamplerState]

Sample from the distribution by first computing the parameters and then generating a sample based on them.

__init__() -> None

Construct the sampler; the constructor takes no arguments.

compute_parameters(key: Key, interactions: list[PyTree], active_flags: list[Array], states: list[list[PyTree[Shaped[Array, 'nodes ?*state'], State]]], sampler_state: None, output_sd: PyTree[jax.ShapeDtypeStruct]) -> PyTree

A concrete implementation of this function must return a \(\theta\) vector for every node in the block being updated, as an array of shape [b, M], where b is the number of nodes in the block and \(M\) is the number of values that \(X\) can take.
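For a pairwise categorical (Potts-style) model, \(\theta_{i,m}\) can be built by summing, over active neighbor instances, the weight that couples value \(m\) of node \(i\) to the neighbor's current value. The sketch below does this with a one-hot encoding and jnp.einsum. The four-dimensional weight layout and the flat arguments are hypothetical assumptions for illustration, not hamon's actual interaction format.

```python
import jax
import jax.numpy as jnp


def potts_theta(weights, active, neighbor_states):
    """Hypothetical compute_parameters body for a pairwise categorical model.

    Assumed layout (illustrative only):
      weights:         [b, k, M, V]  weights[i, j, m, v] couples value m of node i
                                     to value v of its j-th neighbor instance
      active:          [b, k]        False for inactive instances
      neighbor_states: [b, k]        integer neighbor values in [0, V)
    Returns theta with shape [b, M].
    """
    num_neighbor_values = weights.shape[-1]
    onehot = jax.nn.one_hot(neighbor_states, num_neighbor_values)  # [b, k, V]
    mask = active.astype(weights.dtype)  # [b, k]
    # theta[i, m] = sum_j mask[i, j] * weights[i, j, m, neighbor_states[i, j]]
    return jnp.einsum("ijmv,ijv,ij->im", weights, onehot, mask)
```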

sample_given_parameters(key: Key, parameters: PyTree, sampler_state: None, output_sd: PyTree[jax.ShapeDtypeStruct]) -> tuple[PyTree[Shaped[Array, 'nodes ?*state'], State], None]

Sample from a softmax distribution given the parameter vector \(\theta\).
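Given \(\theta\) with shape [b, M], one draw per node can be taken with jax.random.categorical, which accepts unnormalized logits along the chosen axis. This is a minimal sketch of the sampling step, not hamon's implementation; the function name is illustrative.

```python
import jax
import jax.numpy as jnp


def sample_categorical(key, theta):
    """Draw one category per node with P(X = m) proportional to exp(theta[i, m]).

    theta has shape [b, M]; the result has shape [b] with integers in [0, M).
    """
    return jax.random.categorical(key, theta, axis=-1)
```

Because jax.random.categorical works with unnormalized logits, \(\theta\) does not need to be normalized, which matches the shift invariance of the softmax.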