# Copyright (c) 2026, PyRETIS Development Team.
# Distributed under the LGPLv2.1+ License. See LICENSE for more info.
"""WHAM crossing-probability analysis for infinite-swapping RETIS.
This module computes per-interface and total crossing probabilities, the
initial flux, and the rate constant from the path-data file written by
the infinite-swapping (replica-exchange) sampler
(:func:`pyretis.simulation.repex.write_path_ensemble_data`; the file is
named ``infswap_data.txt``).
The estimator follows the standard weighted-histogram (WHAM)
crossing-probability procedure for transition interface sampling: the
high-acceptance (HA) weight unweighting, the ``eta`` normalisation, the
Q-factor WHAM stitch (Lervik et al., J. Comput. Chem. 2015), a
point-matching cross-check and a block-error analysis. The
implementation returns its results as values rather than writing report
files.
Data-file column layout (after ``str.split`` on a data line):
* ``0`` -- path index
* ``1`` -- path length
* ``2`` -- maximum order parameter (``lambda_max``)
* ``3 .. 3 + nintf - 1`` -- ``Cxy`` (fractional sampling occurrence) for
ensembles ``[0-]``, ``[0+]``, ``[1+]``, ...
* ``3 + nintf .. 3 + 2*nintf - 1`` -- the high-acceptance (HA) weights for
the same ensembles.
Here ``nintf`` is the number of interfaces and ``I0PLUS = 4`` (the module
constant) is the column of the ``[0+]`` ``Cxy`` value.
"""
import logging
import os
import numpy as np
from pyretis.core.engine_time import engine_time_per_step
# Module-level logger. The NullHandler means this library emits nothing
# through logging's last-resort handler unless the application configures
# logging itself.
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())
# Column of the [0+] ensemble's Cxy value in a split data line (column 0 is
# the path index, 1 the path length, 2 lambda_max). The [0-] ensemble sits
# one column to the left at ``I0PLUS - 1``.
I0PLUS = 4
def read_data_matrix(filename, nskip=0):
    """Load the numeric rows of an ``infswap_data.txt`` file.

    Lines starting with ``#`` and blank lines are ignored; every other
    line is split on whitespace and each field is converted to a float.

    Parameters
    ----------
    filename : str
        Path to the ``infswap_data.txt`` file.
    nskip : int, optional
        How many leading data rows to drop (e.g. loaded paths or
        equilibration).

    Returns
    -------
    matrix : list of list of float
        One row per retained path; placeholder ``"----"`` fields are
        stored as ``0.0``.
    """

    def _to_float(token):
        # "----" marks an ensemble column with no value for this path.
        return 0.0 if token == '----' else float(token)

    rows = []
    with open(filename, encoding='utf-8') as handle:
        for raw in handle:
            if raw.startswith('#'):
                continue
            fields = raw.split()
            if fields:
                rows.append([_to_float(token) for token in fields])
    del rows[:nskip]
    return rows
[docs]def rec_block_errors(runav, minblocks):
"""Block-error analysis of a running-average series.
Standard recursive block-error analysis of the running average.
Parameters
----------
runav : sequence of float
The running average of an observable, one entry per path.
minblocks : int
Minimum number of blocks in the block-error analysis.
Returns
-------
half_av_err : float
Average relative error over the second half of the block lengths.
n_stat_ineff : float
Statistical inefficiency estimate.
rel_errors : list of float
Relative error as a function of block length.
"""
runav = np.asarray(runav, dtype=float)
if len(runav) < 2 * minblocks:
return 0.0, 0.0, []
maxbll = int(len(runav) / minblocks) # maximum block length
bestav = runav[-1] # most accurate average we have
if bestav == 0.0:
return 0.0, 0.0, []
rel_errors = []
for blocklength in range(1, maxbll + 1):
runav_red = runav[blocklength - 1::blocklength]
blocks = _rec_blocks(runav_red)
sum_qudiff = np.sum((blocks - bestav) ** 2)
n_blocks = len(blocks)
if n_blocks < 2:
continue
abs_err = np.sqrt(sum_qudiff / (n_blocks * (n_blocks - 1)))
rel_errors.append(abs_err / bestav)
if not rel_errors:
return 0.0, 0.0, []
second_half = rel_errors[len(rel_errors) // 2:]
half_av_err = float(np.mean(second_half))
n_stat_ineff = (half_av_err / rel_errors[0]) ** 2
return half_av_err, n_stat_ineff, rel_errors
[docs]def _rec_blocks(reduced):
"""Recover block averages from a reduced running-average array."""
reduced = np.asarray(reduced, dtype=float)
n_red = len(reduced)
result = np.zeros(n_red, dtype=reduced.dtype)
result += np.arange(1, n_red + 1) * reduced
result[1:] -= np.arange(1, n_red) * reduced[:-1]
return result
def _unweight_matrix(matrix, nintf):
    """Divide the ``Cxy`` columns by their HA-weights, in place.

    Three things happen to every row, ensemble by ensemble:

    1. Each ``Cxy`` value with a positive HA-weight is divided by it.
    2. The HA-weight column is overwritten with the running sum of the
       ``Cxy`` values seen so far *before* unweighting. The running WHAM
       averages read it later; the weights themselves are no longer
       needed.
    3. After all rows are processed, each ``Cxy`` column is divided by
       ``sum(Cxy / w) / sum(Cxy)``, the average inverse HA-weight. This
       keeps the ``eta`` totals comparable between ensembles that use
       different moves (e.g. ``[0+]`` shooting vs ``[i+]`` wire fencing).

    Parameters
    ----------
    matrix : list of list of float
        The data matrix from :func:`read_data_matrix`; modified in place.
    nintf : int
        Number of interfaces.

    Raises
    ------
    ValueError
        If a positive ``Cxy`` is paired with a non-positive HA-weight.
    """
    first_cxy = I0PLUS - 1
    raw_totals = [0.0] * nintf  # running sum of Cxy before unweighting
    unweighted_totals = [0.0] * nintf  # sum of Cxy / HA-weight
    for row in matrix:
        for ens in range(nintf):
            cxy_col = first_cxy + ens
            w_col = cxy_col + nintf
            weight = row[w_col]
            if weight > 0:
                raw_totals[ens] += row[cxy_col]
                row[cxy_col] /= weight
                unweighted_totals[ens] += row[cxy_col]
            elif row[cxy_col] > 0:
                raise ValueError(
                    f'Division by zero HA-weight for path row {row}.'
                )
            row[w_col] = raw_totals[ens]
    for ens, (raw, unweighted) in enumerate(
            zip(raw_totals, unweighted_totals)):
        if raw == 0.0:
            continue
        mean_inverse_weight = unweighted / raw
        col = first_cxy + ens
        for row in matrix:
            row[col] /= mean_inverse_weight
[docs]def _wham_ptot_run(interfaces, ploc_matrix, sum_pxy):
"""Running-average total crossing probability via WHAM.
``ploc_matrix[j][i]`` is ``P_A(lambda_i | lambda_j)`` from the
``[j+]`` ensemble using the data so far.
"""
n_ens = len(interfaces) - 1 # lambda_n == lambda_B
q_array = [0.0] * n_ens
prob = 1.0
inv_q = sum_pxy[0] * prob
if inv_q == 0:
return 0.0, q_array
q_fac = 1.0 / inv_q
q_array[0] = q_fac
for i in range(1, n_ens + 1):
nominator = 0.0
for j in range(i):
nominator += sum_pxy[j] * ploc_matrix[j][i]
prob = nominator * q_fac
if i == n_ens or prob == 0:
return prob, q_array
inv_q += sum_pxy[i] / prob
q_fac = 1.0 / inv_q
q_array[i] = q_fac
return prob, q_array
[docs]def _wham_pq(n_plus_ens, interfaces, lamres, eta, v_alpha):
"""WHAM crossing probabilities at the interfaces and Q-factors.
WHAM crossing probabilities at the interfaces and Q-factors (Q-factor
recursion of Lervik et al., JCTC 2015). Returns ``P`` (crossing
probability at each interface, with ``lambda_B`` appended) and ``Q``
(the per-ensemble normalisation factors for ``v_alpha``).
"""
prob = [0.0] * n_plus_ens
q_fac = [0.0] * n_plus_ens
inv_q = [0.0] * n_plus_ens
prob[0] = 1.0
inv_q[0] = eta[0]
if inv_q[0] == 0:
# No [0+] sampling at all: every deeper interface is unreachable,
# so its crossing probability is zero. Append the lambda_B slot
# (as the normal path does) so the returned list always has one
# entry per interface -- never nintf - 1.
prob.append(0.0)
return prob, q_fac
q_fac[0] = 1.0 / inv_q[0]
lambda_a = interfaces[0]
for i in range(1, n_plus_ens):
alpha = round((interfaces[i] - lambda_a) / lamres)
prob[i] = v_alpha[alpha] * q_fac[i - 1]
if prob[i] == 0:
# Nothing reaches interface i; the deeper probabilities are
# already 0. Pad the lambda_B slot so len(prob) == nintf, as
# on the normal exit below.
prob.append(0.0)
return prob, q_fac
inv_q[i] = inv_q[i - 1] + (eta[i] / prob[i])
q_fac[i] = 1.0 / inv_q[i]
prob.append(v_alpha[-1] * q_fac[n_plus_ens - 1])
return prob, q_fac
[docs]def _default_lamres(interfaces):
"""Pick a default order-parameter resolution from the interfaces.
The WHAM stitch places every interface on a uniform grid of step
``lamres`` via ``round((lambda - lambda_A) / lamres)``. To keep
adjacent interfaces on *distinct* grid points the step must be small
relative to the *smallest* interface spacing -- not the first one.
Using the first gap (the historical default) silently collapses two
interfaces onto the same grid index whenever a later gap is smaller,
which corrupts the per-interface crossing probabilities. We therefore
take one tenth of the smallest spacing.
Parameters
----------
interfaces : list of float
Strictly increasing interface positions.
Returns
-------
lamres : float
The chosen resolution.
"""
gaps = [
interfaces[i + 1] - interfaces[i]
for i in range(len(interfaces) - 1)
]
min_gap = min(gaps)
if min_gap <= 0.0:
raise ValueError(
'Interfaces must be strictly increasing; found a '
f'non-positive spacing in {interfaces}.'
)
return min_gap / 10.0
[docs]def _validate_lamres(interfaces, lamres):
"""Check that ``lamres`` resolves every interface separately.
Raises if two interfaces round to the same grid index (the stitch
would double-count an ensemble) and warns if an interface does not
sit exactly on the grid (the per-interface crossing probability is
then read at the nearest grid point; the interface is mis-placed by
at most ``lamres / 2`` in lambda, and the probability error that
induces depends on the local crossing-probability slope). Failing
loud here is deliberate: a wrong crossing probability is worse than a
refusal to run.
Parameters
----------
interfaces : list of float
Strictly increasing interface positions.
lamres : float
Candidate order-parameter resolution.
"""
if lamres <= 0.0:
raise ValueError(f'lamres must be positive, got {lamres}.')
lambda_a = interfaces[0]
indices = []
for lam in interfaces:
steps = (lam - lambda_a) / lamres
nearest = round(steps)
indices.append(nearest)
if abs(steps - nearest) > 1.0e-6 * max(1.0, abs(steps)):
logger.warning(
'Interface %s is not commensurate with lamres=%s: '
'(lambda - lambda_A)/lamres = %s is not an integer, so '
'its crossing probability is read at the nearest grid '
'point. Use a lamres that evenly divides every interface '
'spacing for exact placement.', lam, lamres, steps,
)
if len(set(indices)) != len(indices):
raise ValueError(
f'lamres={lamres} is too coarse: two interfaces map to the '
'same grid point, which would corrupt the WHAM stitch. '
'Reduce lamres (it must be smaller than the closest '
'interface spacing).'
)
def wham_crossing_probability(matrix, interfaces, lamres=None,
                              minblocks=5, interval=1.0):
    """Compute WHAM crossing probabilities, flux and rate from a matrix.

    This is the core WHAM computation (crossing probability, initial flux
    and rate). Instead of writing report files or plots, it returns the
    results in a dictionary. The matrix is modified in place by the
    HA-weight unweighting step.

    The order-parameter grid is anchored at ``lambda_A``: grid point
    ``alpha`` sits at ``lambda_A + alpha * lamres``. Every interface
    index in the stitch, and the returned ``lambda_values``, use this
    grid. Anchoring it at zero instead would mis-size the grid whenever
    ``lambda_A`` is not a multiple of ``lamres``.

    Parameters
    ----------
    matrix : list of list of float
        Data matrix from :func:`read_data_matrix` (already ``nskip``-ped).
        Each row needs at least ``3 + 2 * len(interfaces)`` columns.
    interfaces : sequence of float
        The interface positions ``[lambda_0, ..., lambda_B]``.
    lamres : float, optional
        Order-parameter resolution. Defaults to one tenth of the
        *smallest* interface spacing (see :func:`_default_lamres`), which
        keeps non-uniform interfaces on distinct grid points. Whatever
        value is used is validated by :func:`_validate_lamres`.
    minblocks : int, optional
        Minimum number of blocks for the block-error analysis.
    interval : float, optional
        Simulation time per recorded path step, i.e. ``timestep *
        subcycles``. The ``[0-]``/``[0+]`` path lengths are in steps, so
        the flux is divided by this to give a rate per unit time,
        matching the native PyRETIS flux analysis
        (:func:`pyretis.analysis.flux_analysis.analyse_flux`). The
        default ``1.0`` reproduces the upstream inftools convention of a
        rate per step; pass the real interval to obtain a physical rate.

    Returns
    -------
    results : dict
        The dictionary has these keys:

        * ``pcross_pm`` / ``pcross_wham`` -- total crossing probability,
          from point matching and from WHAM.
        * ``pcross_pm_relerr`` / ``pcross_wham_relerr`` -- relative
          block errors of those.
        * ``pcross_at_intf`` / ``pcross_at_intf_pm`` -- crossing
          probability at each interface (WHAM / point matching).
        * ``flux``, ``rate_pm`` / ``rate_wham``, ``rate_pm_relerr``.
        * ``length_0minus`` / ``length_0plus``, ``n_records``.
        * ``pcross_pm_series`` / ``pcross_wham_series`` /
          ``rate_pm_series`` / ``rate_wham_series`` -- per-path running
          averages.
        * ``lambda_values`` and ``pcross_curve_wham`` /
          ``pcross_curve_pm`` -- the full ``P_A(lambda)`` profiles on
          that grid.

    Raises
    ------
    ValueError
        Raised in any of these cases:

        * ``matrix`` is empty or fewer than two interfaces are given.
        * ``lamres`` is invalid.
        * A data row has fewer columns than the interface count
          requires.
        * A non-zero ``Cxy`` has a zero HA-weight.
    """
    if not matrix:
        raise ValueError('Empty path-data matrix (all rows skipped?).')
    interfaces = [float(item) for item in interfaces]
    if len(interfaces) < 2:
        raise ValueError(
            f'Need at least two interfaces, got {interfaces}.'
        )
    if lamres is None:
        lamres = _default_lamres(interfaces)
    _validate_lamres(interfaces, lamres)
    lambda_a = interfaces[0]
    nintf = len(interfaces)
    n_plus_ens = nintf - 1
    # Each row holds index, length and lambda_max, then one Cxy and one
    # HA-weight column per ensemble. Check this before the in-place
    # unweighting, so a mismatched interface list fails clearly instead
    # of raising an IndexError (or silently reading the wrong columns).
    n_columns = (I0PLUS - 1) + 2 * nintf
    for row in matrix:
        if len(row) < n_columns:
            raise ValueError(
                f'Data row has {len(row)} columns, but {nintf} interfaces '
                f'require at least {n_columns}; check that the interface '
                'list matches the run.'
            )
    _unweight_matrix(matrix, nintf)
    # Grid index of every interface on the lambda_A-anchored grid.
    intf_alpha = [round((lam - lambda_a) / lamres) for lam in interfaces]
    n_alpha = intf_alpha[-1] + 1
    lambda_values = [lambda_a + alpha * lamres for alpha in range(n_alpha)]
    v_alpha = [0.0] * n_alpha  # WHAM total crossing probability
    v_alpha[0] = 1.0
    u_alpha = [0.0] * n_alpha  # point-matching total crossing prob
    u_alpha[0] = 1.0
    p_loc = [[0.0] * n_alpha for _ in range(n_plus_ens)]
    eta = [0.0] * n_plus_ens
    run_av_ptot_wham = []
    # Entries ploc_runav[i][j] with j < i are never set and stay 0.0.
    ploc_runav = [[0.0] * nintf for _ in range(nintf)]
    # After unweighting, these columns hold the running pre-unweighting
    # Cxy sums of the plus ensembles [0+], [1+], ...
    sum_start = I0PLUS + nintf
    for row in matrix:
        lambdamax = row[2]
        alpha_max = min(
            int(np.floor((lambdamax - lambda_a) / lamres)), n_alpha - 1
        )
        for i in range(n_plus_ens):
            cxy = row[I0PLUS + i]
            eta[i] += cxy
            alpha_min = intf_alpha[i]
            for alpha in range(alpha_min, alpha_max + 1):
                p_loc[i][alpha] += cxy
            # v(alpha) at the next interface is set by the lower ensemble.
            for alpha in range(alpha_min + 1, alpha_max + 1):
                v_alpha[alpha] += cxy
            for j in range(i, nintf):
                ploc_runav[i][j] = (
                    p_loc[i][intf_alpha[j]] / eta[i]
                    if eta[i] != 0.0 else 0.0
                )
        sum_pxy = row[sum_start:sum_start + n_plus_ens]
        ptot_wham, _ = _wham_ptot_run(interfaces, ploc_runav, sum_pxy)
        run_av_ptot_wham.append(ptot_wham)
    # Point-matching running total: product of local crossing probs.
    run_av_ptot_pm = _running_pm(matrix, interfaces, lamres)
    # Final WHAM normalisation of v_alpha and the per-interface probs.
    pcross_at_intf, q_fac = _wham_pq(
        n_plus_ens, interfaces, lamres, eta, v_alpha
    )
    p_loc = [
        [val / eta_i if eta_i != 0.0 else 0.0 for val in hist]
        for hist, eta_i in zip(p_loc, eta)
    ]
    for k in range(n_plus_ens):
        for alpha in range(intf_alpha[k] + 1, intf_alpha[k + 1] + 1):
            v_alpha[alpha] *= q_fac[k]
    # Point-matching profile and per-interface probabilities.
    pcross_at_intf_pm = [0.0] * nintf
    pcross_at_intf_pm[0] = 1.0
    for i in range(n_plus_ens):
        alpha_min = intf_alpha[i] + 1
        alpha_max = intf_alpha[i + 1]
        u_alpha[alpha_min:alpha_max + 1] = [
            pcross_at_intf_pm[i] * num
            for num in p_loc[i][alpha_min:alpha_max + 1]
        ]
        pcross_at_intf_pm[i + 1] = u_alpha[alpha_max]
    # Initial flux from the [0-] and [0+] path lengths. Those lengths are
    # in simulation steps, so multiply by ``interval`` (the time per step,
    # ``timestep * subcycles``) to get a flux per unit time, matching the
    # native PyRETIS flux convention. Upstream inftools omits this factor
    # and so reports a rate per step.
    length_0min, length_0plus = _running_lengths(matrix)
    eff_time = (length_0min + length_0plus - 4) * interval
    flux = 1.0 / eff_time if eff_time > 0 else 0.0
    pcross_pm = run_av_ptot_pm[-1]
    pcross_wham = run_av_ptot_wham[-1]
    rate_pm = flux * pcross_pm
    rate_wham = flux * pcross_wham
    pm_err, _, _ = rec_block_errors(run_av_ptot_pm, minblocks)
    wham_err, _, _ = rec_block_errors(run_av_ptot_wham, minblocks)
    flux_series = _running_flux(matrix, interval)
    runav_rate_pm = [
        flux_t * p_t for flux_t, p_t in zip(flux_series, run_av_ptot_pm)
    ]
    runav_rate_wham = [
        flux_t * p_t for flux_t, p_t in zip(flux_series, run_av_ptot_wham)
    ]
    rate_err, _, _ = rec_block_errors(runav_rate_pm, minblocks)
    return {
        'pcross_pm': pcross_pm,
        'pcross_wham': pcross_wham,
        'pcross_pm_relerr': pm_err,
        'pcross_wham_relerr': wham_err,
        'pcross_at_intf': pcross_at_intf,
        'pcross_at_intf_pm': pcross_at_intf_pm,
        'flux': flux,
        'rate_pm': rate_pm,
        'rate_wham': rate_wham,
        'rate_pm_relerr': rate_err,
        'length_0minus': length_0min,
        'length_0plus': length_0plus,
        'n_records': len(matrix),
        # Per-path running-average series (for rate-vs-cycles plots).
        'pcross_pm_series': run_av_ptot_pm,
        'pcross_wham_series': run_av_ptot_wham,
        'rate_pm_series': runav_rate_pm,
        'rate_wham_series': runav_rate_wham,
        'lambda_values': lambda_values,
        'pcross_curve_wham': v_alpha,
        'pcross_curve_pm': u_alpha,
    }
def _running_pm(matrix, interfaces, lamres):
    """Running point-matching estimate of the total crossing probability.

    The local crossing histograms are accumulated path by path. After
    each path the product over ensembles of
    ``P_A(lambda_{i+1} | lambda_i)`` is recorded, using the data seen so
    far, giving one entry per path.
    """
    origin = interfaces[0]
    n_ens = len(interfaces) - 1
    # Grid index of every interface on the lambda_A-anchored grid.
    grid = [round((lam - origin) / lamres) for lam in interfaces]
    n_alpha = grid[-1] + 1
    hist = [[0.0] * n_alpha for _ in range(n_ens)]
    weight_sum = [0.0] * n_ens
    series = []
    for row in matrix:
        top = min(int(np.floor((row[2] - origin) / lamres)), n_alpha - 1)
        product = 1.0
        for ens in range(n_ens):
            cxy = row[I0PLUS + ens]
            weight_sum[ens] += cxy
            counts = hist[ens]
            for alpha in range(grid[ens], top + 1):
                counts[alpha] += cxy
            total = weight_sum[ens]
            product *= (
                counts[grid[ens + 1]] / total if total != 0.0 else 0.0
            )
        series.append(product)
    return series
def _running_lengths(matrix):
    """Cxy-weighted mean path length in the [0-] and [0+] ensembles.

    Returns ``(mean_0minus, mean_0plus)``; a mean is ``0.0`` when its
    ensemble has zero total weight (including an empty matrix).
    """
    col_minus = I0PLUS - 1
    eta_minus = eta_plus = 0.0
    len_minus = len_plus = 0.0
    for row in matrix:
        length = row[1]
        w_minus = row[col_minus]
        w_plus = row[I0PLUS]
        eta_minus += w_minus
        eta_plus += w_plus
        len_minus += w_minus * length
        len_plus += w_plus * length
    mean_minus = len_minus / eta_minus if eta_minus != 0.0 else 0.0
    mean_plus = len_plus / eta_plus if eta_plus != 0.0 else 0.0
    return mean_minus, mean_plus
def _running_flux(matrix, interval=1.0):
    """Running-average conventional flux, one entry per path.

    After each path the flux is
    ``1 / ((<len_0-> + <len_0+> - 4) * interval)``, where the means are
    Cxy-weighted path lengths (in simulation steps) over the data seen so
    far. ``interval`` is the time per step (``timestep * subcycles``). A
    non-positive denominator gives ``0.0``. See
    :func:`wham_crossing_probability` for the ``interval`` convention.
    """
    col_minus = I0PLUS - 1
    eta_minus = eta_plus = 0.0
    len_minus = len_plus = 0.0
    fluxes = []
    for row in matrix:
        length = row[1]
        w_minus = row[col_minus]
        w_plus = row[I0PLUS]
        eta_minus += w_minus
        eta_plus += w_plus
        len_minus += w_minus * length
        len_plus += w_plus * length
        mean_minus = len_minus / eta_minus if eta_minus != 0.0 else 0.0
        mean_plus = len_plus / eta_plus if eta_plus != 0.0 else 0.0
        denom = (mean_minus + mean_plus - 4) * interval
        fluxes.append(1.0 / denom if denom > 0 else 0.0)
    return fluxes
def path_length_statistics(matrix):
    """Summarise the path lengths (column 1) of a data matrix.

    Parameters
    ----------
    matrix : list of list of float
        Data matrix from :func:`read_data_matrix`.

    Returns
    -------
    stats : dict
        ``mean``, ``std`` and ``median`` as floats, ``min`` and ``max``
        as ints. All are ``0`` for an empty matrix.
    """
    lengths = np.array([row[1] for row in matrix])
    if lengths.size == 0:
        return {'mean': 0, 'std': 0, 'min': 0, 'max': 0, 'median': 0}
    summary = {
        'mean': float(lengths.mean()),
        'std': float(lengths.std()),
        'min': int(lengths.min()),
        'max': int(lengths.max()),
        'median': float(np.median(lengths)),
    }
    return summary
def analyse_wham_output(data_file, interfaces, lamres=None, nskip=0,
                        minblocks=5, interval=1.0):
    """Run the standard WHAM analysis on an infinite-swapping run.

    Parameters
    ----------
    data_file : str
        Path to ``infswap_data.txt``.
    interfaces : sequence of float
        The interface positions ``[lambda_0, ..., lambda_B]`` (from the
        run's ``restart.toml`` / config).
    lamres : float, optional
        Order-parameter resolution (see
        :func:`wham_crossing_probability`).
    nskip : int, optional
        Number of initial records to skip.
    minblocks : int, optional
        Minimum number of blocks for the block-error analysis.
    interval : float, optional
        Simulation time per recorded step (``timestep * subcycles``),
        used to convert the flux to a rate per unit time. See
        :func:`wham_crossing_probability`; read it from the run config
        with :func:`read_engine_interval`. The default ``1.0`` yields a
        rate per step.

    Returns
    -------
    results : dict
        The dictionary returned by :func:`wham_crossing_probability`,
        plus ``path_lengths`` (the :func:`path_length_statistics` dict).
    """
    data = read_data_matrix(data_file, nskip=nskip)
    # Summarise lengths before the WHAM step, which modifies ``data`` in
    # place.
    length_stats = path_length_statistics(data)
    output = wham_crossing_probability(
        data,
        interfaces,
        lamres=lamres,
        minblocks=minblocks,
        interval=interval,
    )
    output.update(path_lengths=length_stats)
    return output
[docs]def _load_run_config(directory='.'):
"""Load a run's ``restart.toml`` (or ``infswap.toml``) as a dict.
Returns the parsed config from the first of ``restart.toml`` /
``infswap.toml`` found in ``directory``, or ``None`` if neither
exists.
"""
try:
import tomllib
except ImportError: # pragma: no cover - py<3.11 fallback
import tomli as tomllib
for name in ('restart.toml', 'infswap.toml'):
candidate = os.path.join(directory, name)
if os.path.isfile(candidate):
with open(candidate, 'rb') as infile:
return tomllib.load(infile)
return None
def read_interfaces(directory='.'):
    """Return the ``[simulation] interfaces`` list of a run config.

    Parameters
    ----------
    directory : str
        Directory containing ``restart.toml`` (or ``infswap.toml``).

    Returns
    -------
    interfaces : list of float or None
        The interface positions; ``None`` if no config file exists or it
        has no ``interfaces`` entry under ``simulation``.
    """
    config = _load_run_config(directory)
    if config is None:
        return None
    simulation = config.get('simulation', {})
    return simulation.get('interfaces')
def read_engine_interval(directory='.'):
    """Return the time per step (``timestep * subcycles``) of a run.

    The infinite-swapping data file stores path lengths in steps, so the
    WHAM flux and rate need this interval to be expressed per unit time
    (the convention of the native PyRETIS flux analysis). The value is
    taken from the run config via ``engine_time_per_step``, mirroring
    :func:`pyretis.analysis.flux_analysis.analyse_flux`.

    Parameters
    ----------
    directory : str
        Directory containing ``restart.toml`` (or ``infswap.toml``).

    Returns
    -------
    interval : float or None
        ``None`` when no config file exists. Otherwise, whatever
        ``engine_time_per_step`` returns for the config (it may also be
        ``None``), so the caller decides whether to fail or fall back to
        a per-step rate.
    """
    config = _load_run_config(directory)
    return None if config is None else engine_time_per_step(config)
#: Names of the infinite-swapping data file, newest first; the tuple order
#: is the search order of :func:`detect_infswap_output`. ``infswap_data.txt``
#: is the current name. ``infretis_data.txt`` is the pre-debrand legacy name,
#: still produced by runs whose ``restart.toml`` persists the old ``data_file``
#: (so a restart started before the rename keeps writing the old name).
INFSWAP_DATA_NAMES = ('infswap_data.txt', 'infretis_data.txt')
def detect_infswap_output(directory='.'):
    """Locate the infinite-swapping data file in ``directory``.

    The names in ``INFSWAP_DATA_NAMES`` are tried in order, so the
    current ``infswap_data.txt`` wins over the legacy
    ``infretis_data.txt`` when both exist. Runs created (or restarted)
    before the data-file rename are therefore still found.

    Parameters
    ----------
    directory : str
        Path to check.

    Returns
    -------
    data_file : str or None
        Path of the first data file found, or ``None``.
    """
    paths = (os.path.join(directory, name) for name in INFSWAP_DATA_NAMES)
    return next((path for path in paths if os.path.isfile(path)), None)