NRPT#
Core NRPT functions: single runs, adaptive schedule tuning, schedule optimization, and automatic chain count discovery.
hamon.nrpt
#
Non-Reversible Parallel Tempering with vectorized swaps.
Based on Syed et al. (2021), "Non-Reversible Parallel Tempering: a Scalable Highly Parallel MCMC Scheme" (arXiv:1905.02939).
Exploits temperature-linearity (E_β = β·E_base) for single-eval-per-chain swap decisions. Adaptive schedule optimization (Algorithm 4) equalizes rejection rates. Optional energy caching with boundary-only deltas for rectangular block partitions.
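The single-evaluation swap decision can be illustrated with a small sketch. It assumes each chain targets a density proportional to exp(-β·E_base(x)), which the page does not state explicitly; the names `swap_log_accept` and `attempt_swap` are hypothetical and are not part of hamon.

```python
import math
import random


def swap_log_accept(beta_lo, beta_hi, base_e_lo, base_e_hi):
    """Log acceptance probability for swapping two adjacent chains.

    Assumes chain i targets exp(-beta_i * E_base(x)). Because the tempered
    energy is linear in beta, the Metropolis ratio needs only the two base
    energies that are already known, one per chain.
    """
    log_ratio = (beta_hi - beta_lo) * (base_e_hi - base_e_lo)
    return min(0.0, log_ratio)


def attempt_swap(rng, beta_lo, beta_hi, base_e_lo, base_e_hi):
    """Return True if the proposed swap is accepted."""
    # rng.random() lies in [0, 1), so 1 - rng.random() lies in (0, 1].
    log_u = math.log(1.0 - rng.random())
    return log_u <= swap_log_accept(beta_lo, beta_hi, base_e_lo, base_e_hi)


rng = random.Random(0)
# The colder chain (higher beta) holds the higher-energy state: always accepted.
always = attempt_swap(rng, 0.5, 1.0, base_e_lo=2.0, base_e_hi=3.0)
```

Under this convention, a swap that hands the lower-energy state to the colder chain is always accepted, and the reverse swap is accepted with probability exp((β_hi − β_lo)(E_hi − E_lo)).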
NRPTCarry
#
Scan carry for the NRPT inner loop.
Fields, in positional order:
- key: field 0
- states: field 1
- accepted: field 2
- attempted: field 3
- idx_state: field 4
- base_E: field 5
- obs_carry: field 6
AbstractEBM
#
Something that has a well-defined energy function (map from a state to a scalar).
energy(state: list[PyTree[Shaped[Array, 'nodes ?*state'], _State]], blocks: BlockSpec | list[Block]) -> Float[Array, '']
#
Evaluate the energy function of the EBM given some state information.
Arguments:
- state: The state for which to evaluate the energy function. Must be compatible with blocks.
- blocks: Specifies how the information in state is organized. May be either a pre-built BlockSpec (the fast path, which avoids rebuilding the spec) or a plain list[Block] for convenience when calling from user code.
Returns:
A scalar representing the energy value associated with state.
with_beta(beta: Array) -> AbstractEBM
#
Return a copy of this EBM with a different inverse-temperature β.
Subclasses that want to work with nrpt_adaptive(ebm=..., program=...)
must override this method.
AbstractNRPTObserver
#
Observer for NRPT rounds, called once per round after Gibbs sweeps and swaps.
Concrete subclasses must implement __call__ and may override init
to provide a non-trivial carry. The observation returned by
__call__ is stacked by lax.scan into a pytree with a leading axis
of size n_rounds. Return None as the observation for
accumulate-only observers that do not need per-round storage.
__call__(stacked_states: list[Array], base_energies: Array, round_idx: Int[Array, ''], carry: ~ObserveCarry) -> tuple[~ObserveCarry, PyTree]
#
Observe one NRPT round.
Arguments:
- stacked_states: Per-block arrays, each of shape (n_chains, ...). The cold chain (target) is at index -1.
- base_energies: Shape (n_chains,) base energies (no β factor).
- round_idx: Zero-based round counter.
- carry: Arbitrary pytree state threaded across rounds.
Returns:
(updated_carry, observation) — observation is stacked by
lax.scan; use None for accumulate-only mode.
init() -> PyTree
#
Initialize the observer carry. Defaults to None.
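A minimal sketch of the observer protocol described above, written as a plain Python class so that it runs without hamon or JAX. The class `RunningColdEnergy` and the driver loop are hypothetical; a real observer would subclass AbstractNRPTObserver and receive JAX arrays inside lax.scan.

```python
class RunningColdEnergy:
    """Hypothetical accumulate-only observer: mean base energy of the cold chain."""

    def init(self):
        # Carry is (number of rounds seen, running sum of cold-chain energy).
        return (0, 0.0)

    def __call__(self, stacked_states, base_energies, round_idx, carry):
        count, total = carry
        cold_energy = base_energies[-1]  # the cold chain (target) is at index -1
        # Returning None as the observation means nothing is stored per round.
        return (count + 1, total + cold_energy), None


# Plain Python loop standing in for the lax.scan that threads the carry.
observer = RunningColdEnergy()
carry = observer.init()
fake_rounds = [[3.0, 2.0, 1.0], [3.5, 2.5, 2.0], [4.0, 1.0, 3.0]]  # made-up energies
for round_idx, base_energies in enumerate(fake_rounds):
    carry, observation = observer(None, base_energies, round_idx, carry)

count, total = carry
mean_cold_energy = total / count
```

Keeping the statistic in the carry rather than in the per-round observation avoids storing an array with a leading axis of size n_rounds.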
BlockSamplingProgram
#
A PGM block-sampling program.
This class encapsulates everything that is needed to run a PGM block sampling program in hamon.
per_block_interactions and per_block_interaction_active are parallel to the free blocks in gibbs_spec, and
their members are passed directly to a sampler when the state of the corresponding free block is being updated
during a sampling program. per_block_interaction_global_inds and per_block_interaction_global_slices are
also parallel to the free blocks, and are used to slice the global state of the program to produce the
state information required to update the state of each block alongside the static information contained in the
interactions.
Attributes:
- gibbs_spec: A division of some PGM into free and clamped blocks.
- samplers: A sampler to use to update every free block in gibbs_spec.
- per_block_interactions: All the interactions that touch each free block in gibbs_spec.
- per_block_interaction_active: Indicates which interactions are real and which are not part of the model and have been added to pad data structures so that they can be rectangular.
- per_block_interaction_global_inds: How to find the information required to update each block within the global state list.
- per_block_interaction_global_slices: How to slice each array in the global state list to find the information required to update each block.
- _block_sd_inds: Precomputed sd_index for each free block (avoids recomputing inside scan).
- _block_positions: Precomputed node positions in global state for each free block (avoids recomputing inside scan).
- _block_output_sds: Precomputed output ShapeDtypeStruct pytree for each free block.
__init__(gibbs_spec: BlockGibbsSpec, samplers: list[hamon.conditional_samplers.AbstractConditionalSampler], interaction_groups: list[hamon.interaction.InteractionGroup])
#
Construct a BlockSamplingProgram.
Takes in a set of information that implicitly defines a sampling program and manipulates it into a shape appropriate for vectorized block-sampling. This involves reindexing, slicing, and often padding.
Arguments:
- gibbs_spec: A division of some PGM into free and clamped blocks.
- samplers: The update rule to use for each free block in gibbs_spec.
- interaction_groups: A list of InteractionGroups that define how the variables in your sampling program affect one another.
with_ebm(ebm) -> BlockSamplingProgram
#
Return a copy of this program rewired to a different EBM.
Subclasses that want to work with nrpt_adaptive(ebm=..., program=...)
must override this method.
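The way with_beta() and with_ebm() can be combined into per-chain factories is sketched below with toy stand-ins. `ToyEBM`, `ToyProgram`, and `make_factories` are hypothetical names used only for illustration; the real classes carry much more state, and the internal factory construction in hamon may differ.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ToyEBM:
    """Stand-in for an EBM whose only tempering parameter is beta."""
    beta: float = 1.0

    def with_beta(self, beta):
        # Copy with a different inverse temperature; the original is untouched.
        return replace(self, beta=beta)


@dataclass(frozen=True)
class ToyProgram:
    """Stand-in for a sampling program bound to one EBM."""
    ebm: ToyEBM

    def with_ebm(self, ebm):
        # Copy rewired to a different EBM.
        return replace(self, ebm=ebm)


def make_factories(ebm, program):
    """Build (betas -> list of EBMs) and (EBMs -> list of programs) from templates."""
    def ebm_factory(betas):
        return [ebm.with_beta(b) for b in betas]

    def program_factory(ebms):
        return [program.with_ebm(e) for e in ebms]

    return ebm_factory, program_factory


template_ebm = ToyEBM()
template_program = ToyProgram(template_ebm)
ebm_factory, program_factory = make_factories(template_ebm, template_program)
ebms = ebm_factory([0.0, 0.5, 1.0])
programs = program_factory(ebms)
```

Returning copies instead of mutating the template keeps one independent EBM and program per chain, which is what a per-β schedule requires.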
optimize_schedule(rejection_rates: jax.Array, betas: jax.Array) -> jax.Array
#
Equalize per-pair rejection rates by redistributing β values.
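One way to equalize rejection rates is to treat the cumulative rejection as a function of β and place the new β values at equally spaced levels of that function. The sketch below uses piecewise-linear interpolation in plain Python; it is an assumption about the general technique, not hamon's implementation, and `equalize_rejections` is a hypothetical name.

```python
def equalize_rejections(rejection_rates, betas):
    """Redistribute interior betas so each adjacent pair carries equal rejection.

    rejection_rates has one entry per adjacent pair, so len(betas) - 1 entries.
    The endpoints betas[0] and betas[-1] are kept fixed.
    """
    n = len(betas)
    # Cumulative rejection at each beta, starting from 0 at betas[0].
    cum = [0.0]
    for r in rejection_rates:
        cum.append(cum[-1] + r)
    total = cum[-1]
    if total <= 0.0:
        return list(betas)  # nothing was rejected, keep the schedule

    new_betas = [betas[0]]
    seg = 0
    for k in range(1, n - 1):
        target = total * k / (n - 1)  # equally spaced cumulative levels
        while seg < n - 2 and cum[seg + 1] < target:
            seg += 1
        width = cum[seg + 1] - cum[seg]
        frac = 0.0 if width == 0.0 else (target - cum[seg]) / width
        new_betas.append(betas[seg] + frac * (betas[seg + 1] - betas[seg]))
    new_betas.append(betas[-1])
    return new_betas


# The first pair rejects three times as often, so the middle beta moves toward it.
shifted = equalize_rejections([0.6, 0.2], [0.0, 0.5, 1.0])
```

With this rule, intervals that reject more often receive more closely spaced β values, while an already equalized schedule is left unchanged.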
nrpt(key: jax.Array, ebms: Sequence[AbstractEBM], programs: Sequence[BlockSamplingProgram], init_states: Sequence[list], clamp_state: list, n_rounds: int, gibbs_steps_per_round: int, betas: jax.Array | None = None, track_round_trips: bool = True, energy_delta_fn: Callable | None = None, observer: AbstractNRPTObserver | None = None) -> tuple[list, dict]
#
Non-Reversible Parallel Tempering with vectorized swaps.
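Single-pass DEO (deterministic even-odd): each round applies one swap parity, alternating between even and odd pairs from round to round. Running multiple passes per round would break non-reversibility (even∘odd∘odd∘even = identity).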
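The pairing pattern and the identity above can be checked with a short sketch. It assumes even rounds propose pairs (0,1), (2,3), ... and odd rounds propose (1,2), (3,4), ...; which parity hamon starts with is not stated here. `deo_pairs` and `apply_swaps` are hypothetical helpers, and every proposal is treated as accepted to expose the structure.

```python
def deo_pairs(n_chains, round_idx):
    """Adjacent pairs (i, i+1) proposed in one round, chosen by round parity."""
    start = round_idx % 2
    return [(i, i + 1) for i in range(start, n_chains - 1, 2)]


def apply_swaps(order, pairs):
    """Swap every listed pair, as if all proposals were accepted."""
    order = list(order)
    for i, j in pairs:
        order[i], order[j] = order[j], order[i]
    return order


n = 5
even, odd = deo_pairs(n, 0), deo_pairs(n, 1)

# Multi-pass pattern even, odd, odd, even: each pass undoes its twin.
multi = list(range(n))
for pairs in (even, odd, odd, even):
    multi = apply_swaps(multi, pairs)

# Single-pass alternation even, odd, even, odd keeps states moving.
single = list(range(n))
for round_idx in range(4):
    single = apply_swaps(single, deo_pairs(n, round_idx))
```

After the multi-pass sequence the ordering is back where it started, while four single-pass rounds have carried the state that began at index 0 several positions along the ladder.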
Chains are ordered by ascending β: index 0 is the hottest chain
(lowest β, closest to the reference distribution) and index −1 is the
coldest chain (highest β, the target distribution you want to
sample from). The returned states list preserves this ordering.
.. warning::
To collect samples from the target distribution, always use
states[-1] (the cold chain), not states[0].
Stats keys
- accepted, attempted, acceptance_rate, rejection_rates, betas
- round_trip_diagnostics (if track_round_trips=True): Lambda, tau_predicted, tau_observed, efficiency, lambda_profile, round_trips_per_chain, restarts_per_chain
- observations (if observer is not None): Per-round observer output stacked along axis 0.
- observer_carry (if observer is not None): Final observer carry after all rounds.
nrpt_adaptive(key: jax.Array, ebm_factory: Callable | None = None, program_factory: Callable | None = None, init_states: Sequence[list] = (), clamp_state: list | None = None, n_rounds: int = 0, gibbs_steps_per_round: int = 0, initial_betas: jax.Array | None = None, n_tune: int = 5, rounds_per_tune: int = 200, track_round_trips: bool = True, *, ebm: AbstractEBM | None = None, program: BlockSamplingProgram | None = None, observer: AbstractNRPTObserver | None = None) -> tuple[list, dict]
#
NRPT with iterative schedule optimization (Algorithm 4).
Runs n_tune adaptation phases, each of rounds_per_tune rounds, updating the β schedule after each phase. Then runs the final n_rounds production phase with the optimized schedule.
Instead of providing ebm_factory and program_factory, you can pass
a template ebm and program and the factories will be built
internally using ebm.with_beta() and program.with_ebm().
Returns (states, stats) where stats includes tuning history in
stats["tuning_history"]. States are ordered by ascending β — the
cold chain (target distribution) is states[-1].
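The phase structure described above (n_tune tuning phases followed by one production phase) can be sketched as a plain loop. Both callables are hypothetical placeholders: `run_phase(betas, n_rounds)` stands for one NRPT run that returns per-pair rejection rates, and `update_schedule(rejection_rates, betas)` stands for the schedule update. The history layout is also an assumption, not the format of stats["tuning_history"].

```python
def adaptive_run(run_phase, update_schedule, betas, n_tune, rounds_per_tune, n_rounds):
    """Tune the schedule n_tune times, then run production with the final betas."""
    tuning_history = []
    for _ in range(n_tune):
        rejection_rates = run_phase(betas, rounds_per_tune)
        tuning_history.append(
            {"betas": list(betas), "rejection_rates": list(rejection_rates)}
        )
        betas = update_schedule(rejection_rates, betas)
    # Production phase: the schedule is no longer changed.
    production = run_phase(betas, n_rounds)
    return betas, production, tuning_history


# Stub phase and update, only to show the call pattern.
calls = []


def stub_run_phase(betas, n_rounds):
    calls.append(n_rounds)
    return [0.3] * (len(betas) - 1)


def stub_update(rejection_rates, betas):
    return list(betas)


final_betas, production, history = adaptive_run(
    stub_run_phase, stub_update, [0.0, 0.5, 1.0],
    n_tune=5, rounds_per_tune=200, n_rounds=1000,
)
```

Samples from the tuning phases are typically discarded, because the target of each intermediate chain changes whenever the schedule is updated.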
discover_chain_count(key: jax.Array, ebm_factory: Callable | None = None, program_factory: Callable | None = None, init_factory: Callable | None = None, clamp_state: list | None = None, beta_range: tuple[float, float] = (0.0, 1.0), gibbs_steps_per_round: int = 0, initial_n: int = 8, target_acceptance: float = 0.6, rounds_per_probe: int = 200, n_tune_per_probe: int = 4, max_iters: int = 6, min_chains: int = 3, max_chains: int = 128, lambda_rtol: float = 0.05, *, ebm: AbstractEBM | None = None, program: BlockSamplingProgram | None = None) -> dict
#
Iteratively discover the right chain count for a given target acceptance.
The bootstrapping problem: Λ estimated with too few chains is biased low because the schedule can't resolve the peak in λ(β). Each iteration:
- Build N chains, run a short schedule optimization to estimate Λ.
- Update the running max-Λ (conservative: never underestimate).
- Compute N_rec from max-Λ, step halfway toward it.
- Stop when EITHER:
  - N_rec ≈ N (chain count converged), OR
  - Λ has stabilized (|ΔΛ/Λ| < lambda_rtol for 2 consecutive iters)
Using max-Λ prevents the "overshoot then drop" pattern where a noisy high estimate at iteration k inflates the recommendation, then a lower estimate at k+1 can't undo the damage. Stabilization detection catches the case where Λ is already well-resolved but N_rec still differs from N by a few chains.
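The iteration above can be sketched as follows. Several details are assumptions rather than documented behavior: the recommendation rule N_rec = ceil(Λ / (1 − target_acceptance)) + 1 (equal rejection over N − 1 adjacent pairs), a tolerance of one chain for "N_rec ≈ N", and rounding the half step away from zero. `estimate_lambda` is a hypothetical callable standing in for the short schedule optimization, and the "no_progress" exit is omitted.

```python
import math


def recommended_chains(lam, target_acceptance):
    # Assumed rule: equal rejection r on n-1 pairs gives Lambda ~ (n-1) * r.
    return math.ceil(lam / (1.0 - target_acceptance)) + 1


def discover(estimate_lambda, initial_n=8, target_acceptance=0.6, max_iters=6,
             min_chains=3, max_chains=128, lambda_rtol=0.05):
    n, n_rec = initial_n, initial_n
    max_lam, prev_lam, stable_iters = 0.0, None, 0
    reason = "max_iters"
    for _ in range(max_iters):
        lam = estimate_lambda(n)        # short tuning run at n chains
        max_lam = max(max_lam, lam)     # conservative running maximum
        is_stable = (prev_lam is not None and lam > 0.0
                     and abs(lam - prev_lam) / lam < lambda_rtol)
        stable_iters = stable_iters + 1 if is_stable else 0
        prev_lam = lam
        n_rec = recommended_chains(max_lam, target_acceptance)
        n_rec = min(max_chains, max(min_chains, n_rec))
        if abs(n_rec - n) <= 1:         # assumed meaning of "N_rec ~ N"
            reason = "chain_count"
            break
        if stable_iters >= 2:
            reason = "lambda_stable"
            break
        half = (n_rec - n) / 2          # step halfway toward the recommendation
        n += math.ceil(half) if half > 0 else math.floor(half)
    return {"n_chains": n_rec, "Lambda": max_lam, "converged_reason": reason}


result = discover(lambda n: 4.0, initial_n=3, target_acceptance=0.5)
```

Stepping only halfway limits the cost of a single noisy Λ estimate, since the next probe can still correct the recommendation before the chain count grows too far.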
Instead of providing ebm_factory and program_factory, you can pass
a template ebm and program and the factories will be built
internally using ebm.with_beta() and program.with_ebm().
init_factory is still required as initialization varies by use case.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| key | Array | PRNG key | required |
| ebm_factory | Callable \| None | betas_array → list[EBM] | None |
| program_factory | Callable \| None | list[EBM] → list[Program] | None |
| init_factory | Callable \| None | (n_chains, list[EBM], list[Program]) → list[init_states]. Receives EBMs and programs so it can extract the correct free_blocks for initialization (block nodes must be the same objects as the EBMs' nodes). | None |
| clamp_state | list \| None | clamped block states | None |
| beta_range | tuple[float, float] | (β_min, β_max) for the temperature range | (0.0, 1.0) |
| gibbs_steps_per_round | int | Gibbs sweeps between swap attempts | 0 |
| initial_n | int | starting chain count | 8 |
| target_acceptance | float | desired per-pair swap acceptance rate | 0.6 |
| rounds_per_probe | int | rounds for the final production probe | 200 |
| n_tune_per_probe | int | schedule tuning iterations for the final probe | 4 |
| max_iters | int | maximum discovery iterations | 6 |
| min_chains | int | floor on chain count | 3 |
| max_chains | int | ceiling on chain count | 128 |
| lambda_rtol | float | relative tolerance for Λ stabilization (default 5%) | 0.05 |
Returns:
| Type | Description |
|---|---|
| dict | dict with the keys listed below |

- n_chains: final recommended chain count
- betas: optimized schedule at that chain count
- Lambda: conservative (max) barrier estimate
- Lambda_raw: last raw estimate (may be lower than Lambda)
- target_acceptance: the target used
- converged_reason: "chain_count" | "lambda_stable" | "no_progress" | "max_iters"
- history: list of per-iteration dicts
init_index_state(n_chains: int) -> dict
#
Initialize index process tracking arrays.
machine_to_chain[j] = which chain position machine j's state
currently occupies. Initially machine j is at chain j.
visited_top[j] = whether machine j has reached chain N since
its last round trip completion.
Returns a dict suitable for inclusion in lax.scan carry.
round_trip_summary(index_state: dict, rejection_rates: jax.Array, betas: jax.Array, n_rounds: int) -> dict
#
Compute full diagnostic summary for NRPT run.
Returns dict with
- Lambda: global communication barrier estimate
- tau_predicted: theoretical optimal round trip rate
- tau_observed: empirical round trip rate
- efficiency: tau_observed / tau_predicted (closer to 1 = better)
- lambda_profile: local barrier at each pair midpoint
- round_trips_per_chain: per-machine round trip counts
- restarts_per_chain: per-machine restart counts
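A rough sketch of how such diagnostics can be computed from rejection rates. It follows the estimates in Syed et al. as an assumption: Λ as the sum of per-pair rejection rates, the predicted rate as 1 / (2 + 2Λ), and the local barrier as rejection divided by the β gap. Normalizing the observed rate as total round trips over all machines per round is also assumed, and `barrier_summary` is a hypothetical helper, not hamon's function.

```python
def barrier_summary(rejection_rates, betas, round_trips_per_chain, n_rounds):
    """Barrier and round-trip diagnostics from one run (illustrative only)."""
    lam = sum(rejection_rates)                      # global barrier estimate
    tau_predicted = 1.0 / (2.0 + 2.0 * lam)         # assumed theoretical rate
    tau_observed = sum(round_trips_per_chain) / n_rounds
    # Local barrier: rejection per unit of beta on each adjacent pair.
    lambda_profile = [
        r / (b_hi - b_lo)
        for r, b_lo, b_hi in zip(rejection_rates, betas[:-1], betas[1:])
    ]
    return {
        "Lambda": lam,
        "tau_predicted": tau_predicted,
        "tau_observed": tau_observed,
        "efficiency": tau_observed / tau_predicted,
        "lambda_profile": lambda_profile,
    }


summary = barrier_summary(
    rejection_rates=[0.25, 0.25],
    betas=[0.0, 0.5, 1.0],
    round_trips_per_chain=[10, 10, 10],
    n_rounds=100,
)
```

With Λ = 0.5 the predicted rate in this example is 1/3, so an observed rate of 0.3 corresponds to an efficiency of 0.9.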
update_index_state(index_state: dict, perm: jax.Array, n_chains: int) -> dict
#
Update the index process after a swap pass.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| index_state | dict | current tracking dict | required |
| perm | Array | (n_chains,) int array: permutation applied to states | required |
| n_chains | int | total number of chains | required |
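The index process can be sketched in plain Python. The convention that the new state at position k is the old state at perm[k], the definition of a round trip as bottom to top and back to chain 0, and the `round_trips` key are all assumptions; the `toy_` functions are hypothetical and use lists instead of JAX arrays. Restart counting is left out.

```python
def toy_init_index_state(n_chains):
    return {
        "machine_to_chain": list(range(n_chains)),  # machine j starts at chain j
        "visited_top": [False] * n_chains,
        "round_trips": [0] * n_chains,              # hypothetical counter key
    }


def toy_update_index_state(index_state, perm, n_chains):
    # Assumed convention: new_states[k] = old_states[perm[k]], so a machine at
    # old position c moves to the position k for which perm[k] == c.
    inverse = [0] * n_chains
    for new_pos, old_pos in enumerate(perm):
        inverse[old_pos] = new_pos

    out = {"machine_to_chain": [], "visited_top": [], "round_trips": []}
    for j in range(n_chains):
        pos = inverse[index_state["machine_to_chain"][j]]
        seen_top = index_state["visited_top"][j] or pos == n_chains - 1
        finished = seen_top and pos == 0            # back at the bottom after the top
        out["machine_to_chain"].append(pos)
        out["visited_top"].append(seen_top and not finished)
        out["round_trips"].append(index_state["round_trips"][j] + int(finished))
    return out


# Two chains, both swaps accepted: machine 0 goes up and comes back down.
state = toy_init_index_state(2)
for _ in range(2):
    state = toy_update_index_state(state, [1, 0], 2)
```

For swap permutations built from disjoint adjacent pairs, the permutation equals its own inverse, so the assumed direction of perm does not change the result.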
hamon.nrpt_adaptive(key: jax.Array, ebm_factory: Callable | None = None, program_factory: Callable | None = None, init_states: Sequence[list] = (), clamp_state: list | None = None, n_rounds: int = 0, gibbs_steps_per_round: int = 0, initial_betas: jax.Array | None = None, n_tune: int = 5, rounds_per_tune: int = 200, track_round_trips: bool = True, *, ebm: AbstractEBM | None = None, program: BlockSamplingProgram | None = None, observer: AbstractNRPTObserver | None = None) -> tuple[list, dict]
#
Top-level export; see the nrpt_adaptive entry above for the full description.
hamon.optimize_schedule(rejection_rates: jax.Array, betas: jax.Array) -> jax.Array
#
Equalize per-pair rejection rates by redistributing β values.
hamon.discover_chain_count(key: jax.Array, ebm_factory: Callable | None = None, program_factory: Callable | None = None, init_factory: Callable | None = None, clamp_state: list | None = None, beta_range: tuple[float, float] = (0.0, 1.0), gibbs_steps_per_round: int = 0, initial_n: int = 8, target_acceptance: float = 0.6, rounds_per_probe: int = 200, n_tune_per_probe: int = 4, max_iters: int = 6, min_chains: int = 3, max_chains: int = 128, lambda_rtol: float = 0.05, *, ebm: AbstractEBM | None = None, program: BlockSamplingProgram | None = None) -> dict
#
Top-level export; see the discover_chain_count entry above for the algorithm description, parameters, and return value.