NRPT

Core NRPT functions: single runs, adaptive schedule tuning, schedule optimization, and automatic chain count discovery.

hamon.nrpt

Non-Reversible Parallel Tempering with vectorized swaps.

Based on Syed et al. (2021), "Non-Reversible Parallel Tempering: a Scalable Highly Parallel MCMC Scheme" (arXiv:1905.02939).

Exploits temperature linearity (E_β = β·E_base), so each swap decision needs only one energy evaluation per chain. Adaptive schedule optimization (Algorithm 4) equalizes rejection rates across adjacent pairs. Optionally, energies can be cached and updated with boundary-only deltas for rectangular block partitions.
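As a worked check of the single-evaluation claim, here is the standard Metropolis swap ratio for adjacent chains i and i+1 under tempered targets π_β(x) ∝ exp(−β·E_base(x)), with chains ordered by ascending β. This is the textbook derivation, not code taken from hamon.

```latex
\frac{\pi_{\beta_i}(x_{i+1})\,\pi_{\beta_{i+1}}(x_i)}{\pi_{\beta_i}(x_i)\,\pi_{\beta_{i+1}}(x_{i+1})}
  = \exp\!\big[(\beta_{i+1}-\beta_i)\,\big(E_{\text{base}}(x_{i+1}) - E_{\text{base}}(x_i)\big)\big],
\qquad
\alpha_i = \min\!\Big(1,\ \exp\!\big[(\beta_{i+1}-\beta_i)\,\big(E_{\text{base}}(x_{i+1}) - E_{\text{base}}(x_i)\big)\big]\Big)
```

Because β enters only as a multiplicative factor, one E_base evaluation per chain (the base_E field of the scan carry) is enough to score every adjacent pair at once.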

NRPTCarry

Scan carry for the NRPT inner loop.


Fields, in tuple order:

  • key (field 0)
  • states (field 1)
  • accepted (field 2)
  • attempted (field 3)
  • idx_state (field 4)
  • base_E (field 5)
  • obs_carry (field 6)

AbstractEBM

Something that has a well-defined energy function (map from a state to a scalar).



energy(state: list[PyTree[Shaped[Array, 'nodes ?*state'], _State]], blocks: BlockSpec | list[Block]) -> Float[Array, '']

Evaluate the energy function of the EBM given some state information.

Arguments:

  • state: The state for which to evaluate the energy function. Must be compatible with blocks.
  • blocks: Specifies how the information in state is organized. May be either a pre-built BlockSpec (fast path — avoids rebuilding the spec) or a plain list[Block] for convenience when calling from user code.

Returns:

A scalar representing the energy value associated with state.

with_beta(beta: Array) -> AbstractEBM

Return a copy of this EBM with a different inverse-temperature β.

Subclasses that want to work with nrpt_adaptive(ebm=..., program=...) must override this method.

AbstractNRPTObserver

Observer for NRPT rounds, called once per round after Gibbs sweeps and swaps.

Concrete subclasses must implement __call__ and may override init to provide a non-trivial carry. The observation returned by __call__ is stacked by lax.scan into a pytree with a leading axis of size n_rounds. Return None as the observation for accumulate-only observers that do not need per-round storage.


__call__(stacked_states: list[Array], base_energies: Array, round_idx: Int[Array, ''], carry: ~ObserveCarry) -> tuple[~ObserveCarry, PyTree]

Observe one NRPT round.

Arguments:

  • stacked_states: Per-block arrays, each of shape (n_chains, ...). The cold chain (target) is at index -1.
  • base_energies: Shape (n_chains,) base energies (no β factor).
  • round_idx: Zero-based round counter.
  • carry: Arbitrary pytree state threaded across rounds.

Returns:

(updated_carry, observation). The observation is stacked by lax.scan; return None for accumulate-only mode.


init() -> PyTree

Initialize the observer carry. Defaults to None.
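A minimal sketch of the accumulate-only pattern described above: the carry holds a running sum and count of the cold chain's base energy, and __call__ returns None as the observation. It is written as a plain Python class with NumPy so it runs standalone; a real observer would subclass hamon's AbstractNRPTObserver. The class name ColdEnergyMean and the helper run_rounds (a stand-in for the lax.scan loop) are hypothetical.

```python
import numpy as np


class ColdEnergyMean:
    """Hypothetical accumulate-only observer (plain class, not hamon's base)."""

    def init(self):
        # Carry: (running sum of cold-chain base energy, number of rounds seen).
        return (0.0, 0)

    def __call__(self, stacked_states, base_energies, round_idx, carry):
        total, count = carry
        cold_energy = base_energies[-1]  # cold chain (target) is at index -1
        # Return None as the observation: nothing is stored per round.
        return (total + cold_energy, count + 1), None


def run_rounds(observer, base_energies_per_round):
    """Stand-in for lax.scan: thread the carry and collect observations."""
    carry = observer.init()
    observations = []
    for round_idx, base_energies in enumerate(base_energies_per_round):
        carry, obs = observer(None, np.asarray(base_energies), round_idx, carry)
        observations.append(obs)
    return carry, observations
```

Two rounds whose cold-chain energies are 1.0 and 2.0 leave the carry at (3.0, 2), i.e. a mean of 1.5. Inside a jitted lax.scan the carry should hold JAX arrays (for example jnp.zeros(())) rather than Python scalars.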

BlockSamplingProgram

A PGM block-sampling program.

This class encapsulates everything that is needed to run a PGM block sampling program in hamon. per_block_interactions and per_block_interaction_active are parallel to the free blocks in gibbs_spec, and their members are passed directly to a sampler when the state of the corresponding free block is being updated during a sampling program. per_block_interaction_global_inds and per_block_interaction_global_slices are also parallel to the free blocks, and are used to slice the global state of the program to produce the state information required to update the state of each block alongside the static information contained in the interactions.

Attributes:

  • gibbs_spec: A division of some PGM into free and clamped blocks.
  • samplers: The sampler used to update each free block in gibbs_spec.
  • per_block_interactions: All the interactions that touch each free block in gibbs_spec.
  • per_block_interaction_active: Indicates which interactions are real and which are padding, i.e. interactions that are not part of the model and were added only so the data structures can be rectangular.
  • per_block_interaction_global_inds: How to find, within the global state list, the information required to update each block.
  • per_block_interaction_global_slices: How to slice each array in the global state list to obtain the information required to update each block.
  • _block_sd_inds: Precomputed sd_index for each free block (avoids recomputing inside scan).
  • _block_positions: Precomputed node positions in the global state for each free block (avoids recomputing inside scan).
  • _block_output_sds: Precomputed output ShapeDtypeStruct pytree for each free block.

__init__(gibbs_spec: BlockGibbsSpec, samplers: list[hamon.conditional_samplers.AbstractConditionalSampler], interaction_groups: list[hamon.interaction.InteractionGroup])

Construct a BlockSamplingProgram.

Takes in a set of information that implicitly defines a sampling program and manipulates it into a shape appropriate for vectorized block-sampling. This involves reindexing, slicing, and often padding.

Arguments:

  • gibbs_spec: A division of some PGM into free and clamped blocks.
  • samplers: The update rule to use for each free block in gibbs_spec.
  • interaction_groups: A list of InteractionGroups that define how the variables in your sampling program affect one another.

with_ebm(ebm) -> BlockSamplingProgram

Return a copy of this program rewired to a different EBM.

Subclasses that want to work with nrpt_adaptive(ebm=..., program=...) must override this method.

optimize_schedule(rejection_rates: jax.Array, betas: jax.Array) -> jax.Array

Equalize per-pair rejection rates by redistributing β values.
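One way to implement this kind of update is to treat the cumulative sum of per-pair rejection rates as an estimate of the barrier function Λ(β) at the current knots, then invert that piecewise-linear function at equally spaced levels so each new adjacent pair carries about the same share of Λ. The sketch below is a piecewise-linear version of that idea written in NumPy (every call used also exists in jax.numpy); it is not necessarily hamon's exact implementation, and the name equalize_schedule is hypothetical.

```python
import numpy as np


def equalize_schedule(rejection_rates, betas):
    """Redistribute betas so adjacent pairs have ~equal estimated rejection.

    rejection_rates: (n_chains - 1,) per-pair rejection rates.
    betas: (n_chains,) ascending schedule; endpoints are kept fixed.
    """
    betas = np.asarray(betas, dtype=float)
    # Clip to a tiny positive value so the cumulative barrier is strictly
    # increasing, which np.interp needs for its x-coordinates.
    r = np.maximum(np.asarray(rejection_rates, dtype=float), 1e-12)
    # Estimated cumulative barrier Λ(β_k) at each current knot, with Λ(β_0) = 0.
    cum_barrier = np.concatenate([np.zeros(1), np.cumsum(r)])
    # Equally spaced barrier levels, one per chain, from 0 to Λ(β_max).
    levels = np.linspace(0.0, cum_barrier[-1], betas.shape[0])
    # Invert the piecewise-linear Λ(β): look up β as a function of Λ.
    return np.interp(levels, cum_barrier, betas)
```

With a constant local barrier, a pair's rejection is proportional to its gap, so one update maps an uneven schedule such as [0, 0.1, 0.5, 1] onto evenly spaced betas [0, 1/3, 2/3, 1]. When the local barrier varies, repeating the update (as the adaptive tuning phases do) moves the knots toward the equal-barrier schedule.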

nrpt(key: jax.Array, ebms: Sequence[AbstractEBM], programs: Sequence[BlockSamplingProgram], init_states: Sequence[list], clamp_state: list, n_rounds: int, gibbs_steps_per_round: int, betas: jax.Array | None = None, track_round_trips: bool = True, energy_delta_fn: Callable | None = None, observer: AbstractNRPTObserver | None = None) -> tuple[list, dict]

Non-Reversible Parallel Tempering with vectorized swaps.

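Single-pass DEO (deterministic even-odd swaps): each round attempts swaps of a single parity, alternating between even and odd pairs from round to round. Running multiple swap passes per round breaks non-reversibility (even∘odd∘odd∘even = identity).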

Chains are ordered by ascending β: index 0 is the hottest chain (lowest β, closest to the reference distribution) and index −1 is the coldest chain (highest β, the target distribution you want to sample from). The returned states list preserves this ordering.

Warning: to collect samples from the target distribution, always use states[-1] (the cold chain), not states[0].

Stats keys:

  • accepted, attempted, acceptance_rate, rejection_rates, betas
  • round_trip_diagnostics (if track_round_trips=True): Lambda, tau_predicted, tau_observed, efficiency, lambda_profile, round_trips_per_chain, restarts_per_chain
  • observations (if observer is not None): per-round observer output stacked along axis 0
  • observer_carry (if observer is not None): final observer carry after all rounds
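A minimal NumPy sketch of one vectorized DEO swap pass, assuming base energies are already cached per chain and one uniform draw is supplied per adjacent pair. Passing the uniforms in keeps the function pure, so the same array calls also work under jax.numpy with draws from jax.random.uniform. The name deo_swap and its return values are hypothetical, not hamon's API. Pairs of one parity are disjoint, so all accepted swaps can be applied as a single permutation.

```python
import numpy as np


def deo_swap(states, base_E, betas, parity, u):
    """One DEO swap pass over adjacent pairs (i, i + 1) with i % 2 == parity.

    states: (n_chains, ...) array; base_E: (n_chains,) base energies (no beta);
    betas: (n_chains,) ascending; u: (n_chains - 1,) uniforms in [0, 1).
    Returns (new_states, new_base_E, perm, accepted), where
    new_states[i] = states[perm[i]].
    """
    n = betas.shape[0]
    # Temperature linearity: the log swap ratio uses cached base energies only.
    log_ratio = (betas[1:] - betas[:-1]) * (base_E[1:] - base_E[:-1])
    accept_prob = np.exp(np.minimum(0.0, log_ratio))
    # Attempt only pairs whose left index has the current parity.
    attempted = (np.arange(n - 1) % 2) == parity
    accepted = attempted & (u < accept_prob)
    # Left member of an accepted pair takes state i+1; right member takes i-1.
    no_pair = np.zeros(1, dtype=bool)
    is_left = np.concatenate([accepted, no_pair])
    is_right = np.concatenate([no_pair, accepted])
    perm = np.arange(n) + is_left.astype(int) - is_right.astype(int)
    return states[perm], base_E[perm], perm, accepted
```

Calling it with parity = round_idx % 2 gives the single-pass alternation described above; the accepted mask, together with the parity mask, is what per-pair acceptance counters would accumulate.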

nrpt_adaptive(key: jax.Array, ebm_factory: Callable | None = None, program_factory: Callable | None = None, init_states: Sequence[list] = (), clamp_state: list | None = None, n_rounds: int = 0, gibbs_steps_per_round: int = 0, initial_betas: jax.Array | None = None, n_tune: int = 5, rounds_per_tune: int = 200, track_round_trips: bool = True, *, ebm: AbstractEBM | None = None, program: BlockSamplingProgram | None = None, observer: AbstractNRPTObserver | None = None) -> tuple[list, dict]

NRPT with iterative schedule optimization (Algorithm 4).

Runs n_tune adaptation phases, each of rounds_per_tune rounds, updating the β schedule after each phase. Then runs the final n_rounds production phase with the optimized schedule.

Instead of providing ebm_factory and program_factory, you can pass a template ebm and program and the factories will be built internally using ebm.with_beta() and program.with_ebm().

Returns (states, stats) where stats includes tuning history in stats["tuning_history"]. States are ordered by ascending β — the cold chain (target distribution) is states[-1].
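The template path can be pictured as two small factories built from the template's with_beta and with_ebm methods, matching the factory shapes listed for discover_chain_count (betas_array → list[EBM], list[EBM] → list[Program]). This is a sketch of the idea, not hamon's internal code; make_factories, ToyEBM and ToyProgram are hypothetical stand-ins so the example runs without hamon.

```python
from dataclasses import dataclass, replace


def make_factories(ebm, program):
    """Build (ebm_factory, program_factory) from a template EBM and program."""

    def ebm_factory(betas):
        # One EBM per chain, identical except for its inverse temperature.
        return [ebm.with_beta(beta) for beta in betas]

    def program_factory(ebms):
        # One program per chain, rewired to that chain's EBM.
        return [program.with_ebm(e) for e in ebms]

    return ebm_factory, program_factory


@dataclass(frozen=True)
class ToyEBM:  # hypothetical stand-in for an AbstractEBM subclass
    beta: float

    def with_beta(self, beta):
        return replace(self, beta=beta)


@dataclass(frozen=True)
class ToyProgram:  # hypothetical stand-in for a BlockSamplingProgram
    ebm: ToyEBM

    def with_ebm(self, ebm):
        return replace(self, ebm=ebm)
```

A real EBM or program subclass has to override with_beta and with_ebm in the same spirit, as the AbstractEBM and BlockSamplingProgram entries require.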

discover_chain_count(key: jax.Array, ebm_factory: Callable | None = None, program_factory: Callable | None = None, init_factory: Callable | None = None, clamp_state: list | None = None, beta_range: tuple[float, float] = (0.0, 1.0), gibbs_steps_per_round: int = 0, initial_n: int = 8, target_acceptance: float = 0.6, rounds_per_probe: int = 200, n_tune_per_probe: int = 4, max_iters: int = 6, min_chains: int = 3, max_chains: int = 128, lambda_rtol: float = 0.05, *, ebm: AbstractEBM | None = None, program: BlockSamplingProgram | None = None) -> dict

Iteratively discover the right chain count for a given target acceptance.

The bootstrapping problem: Λ estimated with too few chains is biased low because the schedule can't resolve the peak in λ(β). Each iteration:

  1. Build N chains and run a short schedule optimization to estimate Λ.
  2. Update the running max-Λ (conservative: never underestimate).
  3. Compute N_rec from max-Λ and step halfway toward it.
  4. Stop when EITHER:
     • N_rec ≈ N (chain count converged), OR
     • Λ has stabilized (|ΔΛ/Λ| < lambda_rtol for 2 consecutive iterations).

Using max-Λ prevents the "overshoot then drop" pattern where a noisy high estimate at iteration k inflates the recommendation, then a lower estimate at k+1 can't undo the damage. Stabilization detection catches the case where Λ is already well-resolved but N_rec still differs from N by a few chains.
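The iteration can be sketched as follows. Several pieces are assumptions, not taken from hamon: the recommendation rule N_rec = 1 + Λ/(1 − target_acceptance), which follows if an equi-rejection schedule makes each of the N − 1 pairs reject about Λ/(N − 1) of swaps; reading "N_rec ≈ N" as "within one chain"; and reading the stabilization test as two consecutive iterations with |ΔΛ/Λ| < lambda_rtol. estimate_lambda is a hypothetical callable standing in for "build N chains and tune briefly", and the "no_progress" exit is omitted.

```python
import math
from typing import Callable


def discover_n(estimate_lambda: Callable[[int], float], initial_n: int = 8,
               target_acceptance: float = 0.6, max_iters: int = 6,
               min_chains: int = 3, max_chains: int = 128,
               lambda_rtol: float = 0.05) -> dict:
    """Hypothetical sketch of the chain-count discovery loop."""
    n, n_rec = initial_n, initial_n
    lam, lam_max, prev_lam = 0.0, 0.0, None
    stable_iters = 0
    history = []
    for _ in range(max_iters):
        lam = estimate_lambda(n)                    # 1. estimate Λ with N chains
        lam_max = max(lam_max, lam)                 # 2. running max, never lower
        # 3. assumed rule: pick N so that 1 - Λ/(N - 1) ≈ target_acceptance
        n_rec = math.ceil(1.0 + lam_max / (1.0 - target_acceptance))
        n_rec = min(max(n_rec, min_chains), max_chains)
        history.append({"n_chains": n, "Lambda_raw": lam,
                        "Lambda": lam_max, "n_rec": n_rec})
        # Count consecutive iterations whose relative Λ change is below lambda_rtol.
        if prev_lam is not None and abs(lam - prev_lam) < lambda_rtol * abs(lam):
            stable_iters += 1
        else:
            stable_iters = 0
        prev_lam = lam
        # 4. stop on a converged chain count or a stable Λ estimate
        if abs(n_rec - n) <= 1:
            reason = "chain_count"
            break
        if stable_iters >= 2:
            reason = "lambda_stable"
            break
        n += int((n_rec - n) / 2)                   # step halfway toward N_rec
    else:
        reason = "max_iters"
    return {"n_chains": n_rec, "Lambda": lam_max, "Lambda_raw": lam,
            "target_acceptance": target_acceptance,
            "converged_reason": reason, "history": history}
```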

Instead of providing ebm_factory and program_factory, you can pass a template ebm and program and the factories will be built internally using ebm.with_beta() and program.with_ebm(). init_factory is still required as initialization varies by use case.

Parameters:

  • key (Array, required): PRNG key.
  • ebm_factory (Callable | None, default None): betas_array → list[EBM].
  • program_factory (Callable | None, default None): list[EBM] → list[Program].
  • init_factory (Callable | None, default None): (n_chains, list[EBM], list[Program]) → list[init_states]. Receives the EBMs and programs so it can extract the correct free_blocks for initialization (block nodes must be the same objects as the EBMs' nodes).
  • clamp_state (list | None, default None): Clamped block states.
  • beta_range (tuple[float, float], default (0.0, 1.0)): (β_min, β_max) for the temperature range.
  • gibbs_steps_per_round (int, default 0): Gibbs sweeps between swap attempts.
  • initial_n (int, default 8): Starting chain count.
  • target_acceptance (float, default 0.6): Desired per-pair swap acceptance rate.
  • rounds_per_probe (int, default 200): Rounds for the final production probe.
  • n_tune_per_probe (int, default 4): Schedule tuning iterations for the final probe.
  • max_iters (int, default 6): Maximum discovery iterations.
  • min_chains (int, default 3): Floor on chain count.
  • max_chains (int, default 128): Ceiling on chain count.
  • lambda_rtol (float, default 0.05): Relative tolerance for Λ stabilization (5%).

Returns:

A dict with keys:

  • n_chains: final recommended chain count
  • betas: optimized schedule at that chain count
  • Lambda: conservative (max) barrier estimate
  • Lambda_raw: last raw estimate (may be lower than Lambda)
  • target_acceptance: the target used
  • converged_reason: "chain_count" | "lambda_stable" | "no_progress" | "max_iters"
  • history: list of per-iteration dicts

init_index_state(n_chains: int) -> dict

Initialize index process tracking arrays.

machine_to_chain[j] = which chain position machine j's state currently occupies. Initially machine j is at chain j.

visited_top[j] = whether machine j has reached chain N (the top, coldest chain) since its last round-trip completion.

Returns a dict suitable for inclusion in lax.scan carry.

round_trip_summary(index_state: dict, rejection_rates: jax.Array, betas: jax.Array, n_rounds: int) -> dict

Compute full diagnostic summary for NRPT run.

Returns a dict with:

  • Lambda: global communication barrier estimate
  • tau_predicted: theoretical optimal round-trip rate
  • tau_observed: empirical round-trip rate
  • efficiency: tau_observed / tau_predicted (closer to 1 is better)
  • lambda_profile: local barrier at each pair midpoint
  • round_trips_per_chain: per-machine round-trip counts
  • restarts_per_chain: per-machine restart counts
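Two of these quantities have simple closed forms in Syed et al. (2021): the global barrier is estimated by summing the per-pair rejection rates, Λ ≈ Σ r_i, and the asymptotic optimal round-trip rate of NRPT is τ = 1/(2 + 2Λ). The NumPy sketch below computes those plus a local-barrier profile λ ≈ r_i / (β_{i+1} − β_i) at pair midpoints. It illustrates those assumptions only (hamon may use a finite-N variant), and barrier_summary is a hypothetical name.

```python
import numpy as np


def barrier_summary(rejection_rates, betas):
    r = np.asarray(rejection_rates, dtype=float)
    b = np.asarray(betas, dtype=float)
    lam = float(np.sum(r))                      # global barrier estimate Λ
    tau_predicted = 1.0 / (2.0 + 2.0 * lam)     # asymptotic NRPT round-trip rate
    midpoints = 0.5 * (b[1:] + b[:-1])          # where each local estimate sits
    lambda_profile = r / np.diff(b)             # local barrier λ(β) per pair
    return {"Lambda": lam, "tau_predicted": tau_predicted,
            "midpoints": midpoints, "lambda_profile": lambda_profile}
```

efficiency then follows as tau_observed / tau_predicted once round trips have been counted.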

update_index_state(index_state: dict, perm: jax.Array, n_chains: int) -> dict

Update the index process after a swap pass.

Parameters:

  • index_state (dict, required): Current tracking dict.
  • perm (Array, required): (n_chains,) int array, the permutation applied to states.
  • n_chains (int, required): Total number of chains.
required

hamon.nrpt_adaptive(key: jax.Array, ebm_factory: Callable | None = None, program_factory: Callable | None = None, init_states: Sequence[list] = (), clamp_state: list | None = None, n_rounds: int = 0, gibbs_steps_per_round: int = 0, initial_betas: jax.Array | None = None, n_tune: int = 5, rounds_per_tune: int = 200, track_round_trips: bool = True, *, ebm: AbstractEBM | None = None, program: BlockSamplingProgram | None = None, observer: AbstractNRPTObserver | None = None) -> tuple[list, dict]

Package-level export with the same signature and documentation as nrpt_adaptive in hamon.nrpt above.

hamon.optimize_schedule(rejection_rates: jax.Array, betas: jax.Array) -> jax.Array

Package-level export with the same signature and documentation as optimize_schedule in hamon.nrpt above.

hamon.discover_chain_count(key: jax.Array, ebm_factory: Callable | None = None, program_factory: Callable | None = None, init_factory: Callable | None = None, clamp_state: list | None = None, beta_range: tuple[float, float] = (0.0, 1.0), gibbs_steps_per_round: int = 0, initial_n: int = 8, target_acceptance: float = 0.6, rounds_per_probe: int = 200, n_tune_per_probe: int = 4, max_iters: int = 6, min_chains: int = 3, max_chains: int = 128, lambda_rtol: float = 0.05, *, ebm: AbstractEBM | None = None, program: BlockSamplingProgram | None = None) -> dict

Package-level export with the same signature and documentation as discover_chain_count in hamon.nrpt above.