Energy-Based Models

Base classes for energy-based models. AbstractEBM defines the interface; AbstractFactorizedEBM adds discrete-specific factor construction.

hamon.models.AbstractEBM

An object with a well-defined energy function, i.e., a map from a state to a scalar.

energy(state: list[PyTree[Shaped[Array, 'nodes ?*state'], _State]], blocks: BlockSpec | list[Block]) -> Float[Array, '']

Evaluate the energy function of the EBM given some state information.

Arguments:

  • state: The state for which to evaluate the energy function. Must be compatible with blocks.
  • blocks: Specifies how the information in state is organized. May be either a pre-built BlockSpec (the fast path, which avoids rebuilding the spec) or a plain list[Block], which is convenient when calling from user code.

Returns:

A scalar representing the energy value associated with state.
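
The relationship between state, blocks, and the returned scalar can be illustrated with a minimal sketch. It deliberately uses plain Python lists instead of JAX arrays and does not call the hamon API: the function toy_energy, the node labels, and the couplings argument are all hypothetical, chosen only to show a block-organized state being reduced to a single energy value.

```python
# Illustrative only: plain-Python stand-ins, not the hamon API.

def toy_energy(state, blocks, couplings):
    """Ising-style energy E(s) = -sum over (a, b) of J_ab * s_a * s_b.

    state:     one list of spin values (+1 or -1) per block
    blocks:    one list of node labels per block, aligned with `state`
    couplings: dict mapping (label_a, label_b) -> coupling strength J_ab
    """
    if len(state) != len(blocks):
        raise ValueError("state and blocks must have the same number of entries")

    # Flatten the blocked layout into a label -> value lookup.
    value_of = {}
    for block_values, block_labels in zip(state, blocks):
        if len(block_values) != len(block_labels):
            raise ValueError("each state entry must match its block's size")
        value_of.update(zip(block_labels, block_values))

    # Reduce every pairwise term to one scalar.
    return -sum(j * value_of[a] * value_of[b] for (a, b), j in couplings.items())


blocks = [["a", "c"], ["b"]]          # hypothetical two-block layout
state = [[1, -1], [1]]                # values in the same order as `blocks`
couplings = {("a", "b"): 1.0, ("b", "c"): 0.5}
energy = toy_energy(state, blocks, couplings)
```

The size checks mirror the requirement that state be compatible with blocks: the values carry no labels of their own, so the block layout is what gives each entry its meaning.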

hamon.models.AbstractFactorizedEBM

An EBM made up of Factors, i.e., an EBM whose energy function has the form

\[\mathcal{E}(x) = \sum_i \mathcal{E}^i(x)\]

where the sum over \(i\) is taken over factors.

Child classes must define a property that returns the list of factors making up the EBM.
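
As a rough illustration of the sum over factors, the sketch below treats each factor as a plain Python callable that maps the full state to its own energy contribution. This is an assumption made for brevity: factorized_energy, bias_factor, and pair_factor are hypothetical names and are not part of hamon, whose factors are objects rather than bare functions.

```python
# Illustrative only: each "factor" is a plain callable, not a hamon Factor.

def factorized_energy(x, factors):
    """Total energy E(x) = sum_i E^i(x), summed over all factors."""
    return sum(factor(x) for factor in factors)


def bias_factor(x):
    # Single-node term acting on node "a".
    return -0.5 * x["a"]


def pair_factor(x):
    # Pairwise term coupling nodes "a" and "b".
    return -2.0 * x["a"] * x["b"]


x = {"a": 1, "b": -1}
total = factorized_energy(x, [bias_factor, pair_factor])
```

Because the total is a plain sum, adding or removing a factor changes the energy only by that factor's own term.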

Attributes:

  • node_shape_dtypes: The shapes and dtypes of the nodes involved in this EBM, keyed by node type. Used to generate the BlockSpec that defines the global state that factors receive to compute energy.

__init__(node_shape_dtypes: typing.Mapping[typing.Type[hamon.pgm.AbstractNode], PyTree[jax.ShapeDtypeStruct]] = {<class 'hamon.pgm.SpinNode'>: ShapeDtypeStruct(shape=(), dtype=bool), <class 'hamon.pgm.CategoricalNode'>: ShapeDtypeStruct(shape=(), dtype=uint8)})