
Observers

Observers collect statistics during sampling. StateObserver records raw states. MomentAccumulatorObserver accumulates running sums of moments, from which means and variances can be derived, without storing every sample.

hamon.AbstractObserver

Interface for objects that inspect the sampling program while it is running.

A concrete Observer is called once per block-sampling iteration and can maintain an arbitrary "carry" state across calls (e.g. running averages, histogram buffers, or log-probabilities).

init() -> PyTree

Initialize the observer's carry state. The default implementation returns None.
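The carry-threading contract can be pictured with a small pure-Python sketch. The class names, the `__call__(state, carry, iteration)` signature, and the `run_sampling` driver are hypothetical stand-ins chosen for illustration, not hamon's actual API. The point is only that `init()` supplies the initial carry and each call returns an updated carry that is fed into the next call.

```python
from typing import Any


class SketchObserver:
    """Hypothetical stand-in for an observer interface (not hamon's API)."""

    def init(self) -> Any:
        # Mirrors the documented default: no carry state.
        return None

    def __call__(self, state, carry, iteration):
        # Assumed signature: return (new_carry, per-iteration output).
        raise NotImplementedError


class RunningMeanObserver(SketchObserver):
    """Keeps a running mean of a scalar state in its carry."""

    def init(self):
        return (0.0, 0)  # (current mean, number of calls so far)

    def __call__(self, state, carry, iteration):
        mean, n = carry
        n += 1
        mean += (state - mean) / n  # incremental mean update
        return (mean, n), None


def run_sampling(states, observer: SketchObserver):
    """Hypothetical driver: calls the observer once per iteration, threading the carry."""
    carry = observer.init()
    outputs = []
    for it, state in enumerate(states):
        carry, out = observer(state, carry, it)
        outputs.append(out)
    return carry, outputs


carry, outputs = run_sampling([1.0, 2.0, 3.0, 6.0], RunningMeanObserver())
print(carry)  # (3.0, 4)
```

In a JAX sampler the same pattern would typically live inside a scanned loop body, which is why the carry is a PyTree rather than arbitrary Python state.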

hamon.StateObserver

Observer which logs the raw state of some set of nodes.

This observer is stateless: its carry is always None and iteration is ignored.

Attributes:

  • blocks_to_sample: the list of Blocks whose states are logged.

__init__(blocks_to_sample: list[hamon.block_management.Block]) -> None

Create a StateObserver that logs the states of the given blocks_to_sample.
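A stateless raw-state logger can be sketched in the same style. Here blocks are modeled as plain string labels and the sampler state as a dict keyed by those labels; both are simplifying assumptions, and `SketchStateObserver` and its call signature are hypothetical, not hamon's class.

```python
class SketchStateObserver:
    """Hypothetical sketch of a stateless observer that logs raw block states."""

    def __init__(self, blocks_to_sample):
        self.blocks_to_sample = list(blocks_to_sample)

    def init(self):
        return None  # stateless: the carry is always None

    def __call__(self, state, carry, iteration):
        # `iteration` is ignored; only the requested blocks are recorded.
        logged = [state[b] for b in self.blocks_to_sample]
        return None, logged


obs = SketchStateObserver(["even", "odd"])
carry = obs.init()
trace = []
states = [
    {"even": [1, -1], "odd": [1], "extra": [0]},
    {"even": [-1, -1], "odd": [-1], "extra": [0]},
]
for it, state in enumerate(states):
    carry, logged = obs(state, carry, it)
    trace.append(logged)
print(trace)  # [[[1, -1], [1]], [[-1, -1], [-1]]]
```

Blocks not listed in `blocks_to_sample` (here `"extra"`) are simply not recorded.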

hamon.MomentAccumulatorObserver

Observer that accumulates and updates the provided moments.

It logs no samples and only accumulates moments. Note that it does not divide the accumulated values by the number of times it was called; it simply records a running sum, over calls \(i\), of a product of (transformed) state variables:

\[\sum_i f(x_1^i) f(x_2^i) \dots f(x_N^i)\]
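As a worked instance of the sum above, take \(N = 2\) with the identity transform and three recorded samples of a two-node system. The sample values are made up for illustration. The carry would hold only the raw sum; turning it into an average is left to the caller.

```python
# Three hypothetical samples of (x_1, x_2), e.g. spin values.
samples = [(1, 1), (1, -1), (-1, -1)]


def f(x):
    return x  # identity transform


running_sum = 0.0
for x1, x2 in samples:
    running_sum += f(x1) * f(x2)  # 1, then -1, then 1

print(running_sum)                 # 1.0
print(running_sum / len(samples))  # sample estimate of <f(x_1) f(x_2)>
```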

Attributes:

  • blocks_to_sample: the blocks over which moments are accumulated, one per node type. They are used to construct the final state and are not "blocks" in the algorithmic sense (their nodes may be connected to one another).
  • flat_nodes_list: a list of all nodes appearing in the moments, each exactly once (so len(set(x)) == len(x)).
  • flat_to_type_slices_list: a list over node types; each element is an array of the indices into flat_nodes_list that belong to that type.
  • flat_to_full_moment_slices: a list over moment types; each element is a 2D array with the same shape as moment_spec[i], whose entries are indices into flat_nodes_list.
  • f_transform: the element-wise transformation \(f\) to apply to sample values before accumulation.
  • _flat_scatter_index: precomputed concatenation of all flat_to_type_slices_list arrays, used to build flat_state in a single scatter call.
  • _flat_scatter_sizes: number of entries contributed by each node type, used to split the concatenated sampled state before scattering.
  • _flat_value_order: precomputed argsort(_flat_scatter_index); used in __call__ to permute the concatenated sampled values into flat-node order without allocating a zeros array.
  • _accumulate_dtype: dtype for the accumulator, fixed at construction time.
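One plausible way the index attributes fit together is sketched below in pure Python. It assumes nodes are hashable labels whose type can be looked up in a `node_type` mapping; `build_indices` and all local names are hypothetical, and hamon's actual construction may differ (for instance, it also applies `f_transform` before taking products). The final step shows the argsort idea: sampled values arrive grouped by node type, and reading them in `argsort(scatter_index)` order places them in flat-node order without first allocating a zero-filled buffer.

```python
from math import prod


def argsort(seq):
    """Indices that would sort seq (pure-Python stand-in for jnp.argsort)."""
    return sorted(range(len(seq)), key=lambda i: seq[i])


def build_indices(moment_spec, node_type, type_order):
    """Hypothetical helper deriving the documented index attributes."""
    # flat_nodes_list: every node appearing in any moment, once, in first-seen order.
    flat_nodes_list = []
    for moment_type in moment_spec:
        for moment in moment_type:
            for node in moment:
                if node not in flat_nodes_list:
                    flat_nodes_list.append(node)
    pos = {node: i for i, node in enumerate(flat_nodes_list)}

    # flat_to_type_slices_list: for each node type, its indices into flat_nodes_list.
    flat_to_type_slices_list = [
        [i for i, n in enumerate(flat_nodes_list) if node_type[n] == t]
        for t in type_order
    ]
    # flat_to_full_moment_slices: same nesting as moment_spec, holding flat indices.
    flat_to_full_moment_slices = [
        [[pos[n] for n in moment] for moment in moment_type]
        for moment_type in moment_spec
    ]
    # Concatenated scatter targets, and the permutation that inverts the scatter.
    scatter_index = [i for sl in flat_to_type_slices_list for i in sl]
    value_order = argsort(scatter_index)
    return (flat_nodes_list, flat_to_type_slices_list,
            flat_to_full_moment_slices, scatter_index, value_order)


# Toy setup: node "a" is categorical, "b" and "c" are spins.
node_type = {"a": "cat", "b": "spin", "c": "spin"}
moment_spec = [[("a",), ("b",), ("c",)], [("a", "b"), ("b", "c")]]
flat_nodes, type_slices, moment_slices, scatter_index, value_order = build_indices(
    moment_spec, node_type, type_order=["spin", "cat"]
)

# Sampled values grouped by type, in type-slice order: spins (b, c), then cat (a).
concatenated = [1, -1, 3]
flat_state = [concatenated[j] for j in value_order]  # permute into flat-node order
moments = [[prod(flat_state[i] for i in idx) for idx in mt] for mt in moment_slices]

print(scatter_index, value_order)  # [1, 2, 0] [2, 0, 1]
print(flat_state)                  # [3, 1, -1]
print(moments)                     # [[3, 1, -1], [3, -1]]
```

Because `scatter_index` is a permutation of the flat positions, gathering by its argsort gives the same result as scattering `concatenated` into those positions.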
__init__(moment_spec: typing.Sequence[typing.Sequence[typing.Sequence[hamon.pgm.AbstractNode]]], f_transform: typing.Callable = _f_identity, dtype: dtype = float32)

Create a MomentAccumulatorObserver.

Arguments:

  • moment_spec: A sequence nested three levels deep. The outer level is a sequence over moment types; for each moment type there is a sequence over moments; and each moment is a sequence of nodes. All moments of a given type must contain the same number of nodes.

    For example, to get the first and second moments on a simple two-node graph (o-o):

    [ [(node1,), (node2,)], [(node1, node2)] ]

  • f_transform: A function that takes (state, blocks) and returns a value with the same structure as state. It defines a transformation \(y=f(x)\), so the accumulated quantities are sums of products such as \(f(x_1) f(x_2)\); dividing by the number of calls gives an estimate of \(\langle f(x_1) f(x_2) \rangle\).

  • dtype: Accumulator dtype, fixed at construction. Defaults to jnp.float32. Use jnp.float64 for double-precision models. Fixing this here avoids a per-step cast inside the scan body.
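To make the nesting of `moment_spec` concrete, the sketch below mimics the accumulation semantics for the o-o example in pure Python. Nodes are plain strings, a sample is a dict from node to value, and the transform is simplified to act on one value at a time rather than on `(state, blocks)`. These are assumptions for illustration; `check_moment_spec` and `accumulate` are hypothetical helpers, not part of hamon. The result is a raw running sum per moment, consistent with the observer not dividing by the number of calls.

```python
from math import prod


def check_moment_spec(moment_spec):
    """Each moment type must use the same number of nodes in every moment."""
    for moment_type in moment_spec:
        sizes = {len(moment) for moment in moment_type}
        if len(sizes) != 1:
            raise ValueError(f"inconsistent moment sizes: {sorted(sizes)}")


def accumulate(moment_spec, samples, f=lambda x: x):
    """Hypothetical sketch: running sum of prod_k f(x_k) for every moment."""
    check_moment_spec(moment_spec)
    sums = [[0.0] * len(moment_type) for moment_type in moment_spec]
    for sample in samples:
        for t, moment_type in enumerate(moment_spec):
            for m, moment in enumerate(moment_type):
                sums[t][m] += prod(f(sample[node]) for node in moment)
    return sums


# First and second moments on the two-node graph node1 - node2.
moment_spec = [[("node1",), ("node2",)], [("node1", "node2")]]
samples = [
    {"node1": 1, "node2": 1},
    {"node1": -1, "node2": 1},
    {"node1": 1, "node2": -1},
]

sums = accumulate(moment_spec, samples)
print(sums)  # [[1.0, 1.0], [-1.0]]

# The transform x -> (1 + x) / 2 maps spins {-1, 1} to {0, 1}.
mapped = accumulate(moment_spec, samples, f=lambda x: (1 + x) / 2)
print(mapped)  # [[2.0, 2.0], [1.0]]
```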