Conditional Samplers#
Conditional samplers produce new values for a block given the current state
of all other variables. BernoulliConditional handles spin nodes;
SoftmaxConditional handles categorical nodes.
hamon.AbstractConditionalSampler
#
Base class for all conditional samplers.
A conditional sampler is used to update the state of a block of nodes during each iteration of a sampling algorithm. It takes in the states of all the neighbors and produces a sample for the current block of nodes. This can often be done exactly, but need not be. One could embed MCMC methods within this sampler (to do Metropolis within Gibbs, for example).
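As an illustration of the inexact case mentioned above, here is a minimal JAX sketch of a single Metropolis update for a block of spins. It is not hamon code: the function name `metropolis_flip`, the `gamma` array, and the assumed target \(\mathbb{P}(S=s) \propto e^{\gamma s}\) with \(s \in \{-1, 1\}\) are all assumptions made for the example.

```python
import jax
import jax.numpy as jnp


def metropolis_flip(key, spins, gamma):
    """One Metropolis step per node (hypothetical helper, not part of hamon).

    spins: bool array of shape [n], True for +1 and False for -1.
    gamma: float array of shape [n], assumed local parameter of each node.
    """
    s = jnp.where(spins, 1.0, -1.0)
    # Flipping s -> -s changes the unnormalized probability by exp(-2 * gamma * s).
    accept_prob = jnp.minimum(1.0, jnp.exp(-2.0 * gamma * s))
    accept = jax.random.uniform(key, spins.shape) < accept_prob
    # Accepted proposals flip the spin, rejected ones keep the current value.
    return jnp.where(accept, ~spins, spins)


key = jax.random.PRNGKey(0)
spins = jnp.array([True, False, True])
gamma = jnp.array([0.3, -1.2, 2.0])
new_spins = metropolis_flip(key, spins, gamma)
```

Unlike an exact draw, this update depends on the current value of the block, so a sampler built this way would need access to that value, for example through its sampler state.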
init() -> None
#
Initialize the sampler state before sampling begins.
This is called before the first iteration of block sampling, after which the return of this method is
superseded by the return from sample.
Returns:
| Type | Description |
|---|---|
| None | The initial sampler state to use for the first iteration of block sampling. |
sample(key: Key, interactions: list[PyTree], active_flags: list[Array], states: list[list[PyTree[Shaped[Array, 'nodes ?*state'], State]]], sampler_state: ~_SamplerState, output_sd: PyTree[jax.ShapeDtypeStruct]) -> tuple[PyTree[Shaped[Array, 'nodes ?*state'], State], ~_SamplerState]
#
Draw a sample from this conditional.
If this sampler is involved in a block sampling program, this function is called every iteration to update the state of a block of nodes.
Arguments:
- key: An RNG key that the sampler can use to sample from distributions using jax.random.
- interactions: A list of interactions that influence the result of this block update. Each interaction is a PyTree. Each array in the PyTree will have shape [n, k, ...], where n is the number of nodes in the block that is being updated and k is the maximum number of times any node in this block was detected as a head node for this interaction.
- active_flags: A list of arrays of flags that is parallel to interactions. Each array indicates which instances of a given interaction are active for each node in the block. This array has shape [n, k], and is False if a given instance is inactive (which means that it should be ignored during the computation that happens in this function).
- states: A list of PyTrees that is parallel to interactions, representing the sampling state information that is relevant to computing the influence of each interaction. Every array in each PyTree will have shape [n, k, ...].
- sampler_state: The current state of this sampler. Will be replaced by the second return from this function the next time it is called.
- output_sd: A PyTree indicating the expected shape/dtype of the output of this function.
Returns:
A tuple containing the new state for the block of nodes, matching the template given by output_sd, and the updated sampler state.
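The [n, k] shape convention and the role of active_flags can be illustrated with a short JAX sketch. It does not use hamon itself: the helper `masked_field` and the example arrays are hypothetical, and the choice of summing weight times spin is only one possible way a sampler might combine an interaction with its neighbor states.

```python
import jax.numpy as jnp


def masked_field(weights, active, neighbor_spins):
    """Sum weight * spin over the k slots of each node, ignoring inactive slots.

    weights:        float array [n, k], one interaction array (assumed layout).
    active:         bool array [n, k], the matching active flags.
    neighbor_spins: bool array [n, k], neighbor states (True = +1, False = -1).
    """
    spins = jnp.where(neighbor_spins, 1.0, -1.0)
    # Inactive slots are padding and must contribute nothing.
    contributions = jnp.where(active, weights * spins, 0.0)
    return contributions.sum(axis=1)  # shape [n]


weights = jnp.array([[0.5, -1.0, 2.0], [1.5, 0.0, 0.0]])
active = jnp.array([[True, True, False], [True, False, False]])
neighbor_spins = jnp.array([[True, False, True], [False, True, True]])
field = masked_field(weights, active, neighbor_spins)  # one value per node
```

Here the second node has a single active neighbor, so its two remaining slots are padding that the mask removes before the sum.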
hamon.AbstractParametricConditionalSampler
#
A conditional sampler that leverages a parameterized distribution.
When sample is called, this sampler will first compute a set of parameters, and then use those parameters
to draw a sample from some distribution. This workflow is frequently useful in practical cases; for example, to
sample from a Gaussian, we can first compute a mean vector and covariance matrix using any procedure, and then
draw a sample from the corresponding Gaussian distribution by appropriately transforming a vector of standard
normal random variables.
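The Gaussian example can be sketched in JAX as follows. This is a generic illustration of the transform, not hamon code: the function name `sample_gaussian` and the example mean and covariance are assumptions, and the Cholesky factor is just one common choice of transform.

```python
import jax
import jax.numpy as jnp


def sample_gaussian(key, mean, cov):
    """Draw one sample from N(mean, cov) by transforming standard normals."""
    chol = jnp.linalg.cholesky(cov)  # lower-triangular L with cov = L @ L.T
    z = jax.random.normal(key, mean.shape)  # standard normal vector
    return mean + chol @ z


mean = jnp.array([1.0, -2.0])
cov = jnp.array([[2.0, 0.6], [0.6, 1.0]])
sample = sample_gaussian(jax.random.PRNGKey(0), mean, cov)
```

Because z has identity covariance, mean + L z has covariance L Lᵀ = cov, which is why the parameters can be computed by any procedure and the sampling step stays the same.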
init() -> None
#
Initialize the sampler state before sampling begins.
This is called before the first iteration of block sampling, after which the return of this method is
superseded by the return from sample.
Returns:
| Type | Description |
|---|---|
| None | The initial sampler state to use for the first iteration of block sampling. |
compute_parameters(key: Key, interactions: list[PyTree], active_flags: list[Array], states: list[list[PyTree[Shaped[Array, 'nodes ?*state'], State]]], sampler_state: PyTree, output_sd: PyTree[jax.ShapeDtypeStruct]) -> PyTree
#
Compute the parameters of the distribution. For a description of the arguments, see hamon.AbstractConditionalSampler.sample.
sample(key: Key, interactions: list[PyTree], active_flags: list[Array], states: list[list[PyTree[Shaped[Array, 'nodes ?*state'], State]]], sampler_state: ~_SamplerState, output_sd: PyTree[jax.ShapeDtypeStruct]) -> tuple[PyTree[Shaped[Array, 'nodes ?*state'], State], ~_SamplerState]
#
Sample from the distribution by first computing the parameters and then generating a sample based off of them.
sample_given_parameters(key: Key, parameters: PyTree, sampler_state: ~_SamplerState, output_sd: PyTree[jax.ShapeDtypeStruct]) -> tuple[PyTree[Shaped[Array, 'nodes ?*state'], State], ~_SamplerState]
#
Produce a sample given the parameters of the distribution, passed in as the parameters argument.
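The two-step workflow can be sketched with a small stand-in class. This is not a hamon class and does not inherit from one: `ToyGaussianSampler`, its choice of parameters (a per-node mean built from the active interaction values), and the unit-variance noise are all assumptions made for illustration. Only the method names and argument lists mirror the signatures documented above.

```python
import jax
import jax.numpy as jnp


class ToyGaussianSampler:
    """Hypothetical stand-in that mimics the parametric sampler workflow."""

    def compute_parameters(self, key, interactions, active_flags, states,
                           sampler_state, output_sd):
        # Toy rule: the mean of each node is the sum of its active interaction
        # values. The key and states arguments are unused in this sketch.
        terms = [jnp.where(f, w, 0.0).sum(axis=1)
                 for w, f in zip(interactions, active_flags)]
        return sum(terms)

    def sample_given_parameters(self, key, parameters, sampler_state, output_sd):
        # Unit-variance Gaussian around the computed mean.
        noise = jax.random.normal(key, output_sd.shape, dtype=output_sd.dtype)
        return parameters.astype(output_sd.dtype) + noise, sampler_state

    def sample(self, key, interactions, active_flags, states,
               sampler_state, output_sd):
        # Step 1: compute parameters. Step 2: sample given those parameters.
        parameters = self.compute_parameters(
            key, interactions, active_flags, states, sampler_state, output_sd)
        return self.sample_given_parameters(
            key, parameters, sampler_state, output_sd)


sampler = ToyGaussianSampler()
key = jax.random.PRNGKey(0)
interactions = [jnp.array([[1.0, 2.0], [3.0, 4.0]])]
active_flags = [jnp.array([[True, False], [True, True]])]
states = [[]]
output_sd = jax.ShapeDtypeStruct((2,), jnp.float32)
new_state, new_sampler_state = sampler.sample(
    key, interactions, active_flags, states, None, output_sd)
```

Splitting the work this way means a subclass only has to supply the parameter computation or the sampling rule, not both.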
hamon.BernoulliConditional
#
Sample from a Bernoulli distribution.
This sampler is designed to sample from a spin-valued Bernoulli distribution:
\[\mathbb{P}(S = s) \propto e^{\gamma s}\]
where \(S\) is a spin-valued random variable, \(s \in \{-1, 1\}\). The parameter \(\gamma\) must be
computed by compute_parameters.
init() -> None
#
Initialize the sampler state before sampling begins.
This is called before the first iteration of block sampling, after which the return of this method is
superseded by the return from sample.
Returns:
| Type | Description |
|---|---|
| None | The initial sampler state to use for the first iteration of block sampling. |
sample(key: Key, interactions: list[PyTree], active_flags: list[Array], states: list[list[PyTree[Shaped[Array, 'nodes ?*state'], State]]], sampler_state: ~_SamplerState, output_sd: PyTree[jax.ShapeDtypeStruct]) -> tuple[PyTree[Shaped[Array, 'nodes ?*state'], State], ~_SamplerState]
#
Sample from the distribution by first computing the parameters and then generating a sample based off of them.
compute_parameters(key: Key, interactions: list[PyTree], active_flags: list[Array], states: list[list[PyTree[Shaped[Array, 'nodes ?*state'], State]]], sampler_state: None, output_sd: PyTree[jax.ShapeDtypeStruct]) -> PyTree
#
A concrete implementation of this function has to return a value of \(\gamma\) for every node in the block that is being updated. This array should have shape [b].
sample_given_parameters(key: Key, parameters: PyTree, sampler_state: None, output_sd: PyTree[jax.ShapeDtypeStruct]) -> tuple[PyTree[Shaped[Array, 'nodes ?*state'], State], None]
#
Sample from a spin-valued Bernoulli distribution given the parameter \(\gamma\). In hamon,
the spin value 1 is represented by the boolean value True and -1 is represented by False.
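A minimal JAX sketch of this step, assuming the distribution has the form \(\mathbb{P}(S=s) \propto e^{\gamma s}\) for \(s \in \{-1, 1\}\). Under that assumption \(\mathbb{P}(S=1) = e^{\gamma} / (e^{\gamma} + e^{-\gamma}) = \sigma(2\gamma)\), where \(\sigma\) is the logistic sigmoid. The function name `sample_spins` is hypothetical and this is not necessarily how hamon implements the method.

```python
import jax
import jax.numpy as jnp


def sample_spins(key, gamma):
    """Draw boolean spins (True = +1, False = -1) for an array of gamma values."""
    p_up = jax.nn.sigmoid(2.0 * gamma)  # P(S = +1) under the assumed form
    return jax.random.bernoulli(key, p_up)  # bool array with the shape of gamma


key = jax.random.PRNGKey(0)
gamma = jnp.array([-100.0, 0.5, 100.0])
spins = sample_spins(key, gamma)
```

With a strongly negative \(\gamma\) the spin is almost surely False, and with a strongly positive \(\gamma\) it is almost surely True; \(\gamma = 0\) gives a fair coin.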
hamon.SoftmaxConditional
#
Sample from a softmax distribution.
This sampler samples from the standard softmax distribution:
\[\mathbb{P}(X = k) \propto e^{\theta_k}\]
where \(X\) is a categorical random variable and \(\theta\) is a vector that parameterizes the relative probabilities of each of the categories.
init() -> None
#
Initialize the sampler state before sampling begins.
This is called before the first iteration of block sampling, after which the return of this method is
superseded by the return from sample.
Returns:
| Type | Description |
|---|---|
| None | The initial sampler state to use for the first iteration of block sampling. |
sample(key: Key, interactions: list[PyTree], active_flags: list[Array], states: list[list[PyTree[Shaped[Array, 'nodes ?*state'], State]]], sampler_state: ~_SamplerState, output_sd: PyTree[jax.ShapeDtypeStruct]) -> tuple[PyTree[Shaped[Array, 'nodes ?*state'], State], ~_SamplerState]
#
Sample from the distribution by first computing the parameters and then generating a sample based off of them.
compute_parameters(key: Key, interactions: list[PyTree], active_flags: list[Array], states: list[list[PyTree[Shaped[Array, 'nodes ?*state'], State]]], sampler_state: None, output_sd: PyTree[jax.ShapeDtypeStruct]) -> PyTree
#
A concrete implementation of this function has to return a \(\theta\) vector for every node in the block that is being updated. This array should have shape [b, M], where \(M\) is the number of possible values that \(X\) may take on.
sample_given_parameters(key: Key, parameters: PyTree, sampler_state: None, output_sd: PyTree[jax.ShapeDtypeStruct]) -> tuple[PyTree[Shaped[Array, 'nodes ?*state'], State], None]
#
Sample from a softmax distribution given the parameter vector \(\theta\).
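A minimal JAX sketch of this step, assuming the distribution has the form \(\mathbb{P}(X=k) \propto e^{\theta_k}\), so that \(\theta\) plays the role of unnormalized log-probabilities (logits). The function name `sample_categories` is hypothetical and this is not necessarily how hamon implements the method.

```python
import jax
import jax.numpy as jnp


def sample_categories(key, theta):
    """Draw one category index per node from logits theta of shape [b, M]."""
    # jax.random.categorical treats the last axis as unnormalized log-probabilities.
    return jax.random.categorical(key, theta, axis=-1)  # int array of shape [b]


key = jax.random.PRNGKey(0)
theta = jnp.array([[0.0, 100.0, 0.0],
                   [100.0, 0.0, 0.0]])
categories = sample_categories(key, theta)
```

Adding the same constant to every entry of a row of \(\theta\) leaves the distribution unchanged, since only the differences between the entries matter after normalization.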