# Source code for pyretis.analysis.wham_analysis

# Copyright (c) 2026, PyRETIS Development Team.
# Distributed under the LGPLv2.1+ License. See LICENSE for more info.
"""WHAM crossing-probability analysis for infinite-swapping RETIS.

This module computes per-interface and total crossing probabilities, the
initial flux, and the rate constant from the path-data file written by
the infinite-swapping (replica-exchange) sampler
(:func:`pyretis.simulation.repex.write_path_ensemble_data`; the file is
named ``infswap_data.txt``).

The estimator follows the standard weighted-histogram (WHAM)
crossing-probability procedure for transition interface sampling: the
high-acceptance (HA) weight unweighting, the ``eta`` normalisation, the
Q-factor WHAM stitch (Lervik et al., J. Comput. Chem. 2015), a
point-matching cross-check and a block-error analysis. The
implementation returns its results as values rather than writing report
files.

Data-file column layout (after ``str.split`` on a data line):

* ``0`` -- path index
* ``1`` -- path length
* ``2`` -- maximum order parameter (``lambda_max``)
* ``3 .. 3 + nintf - 1`` -- ``Cxy`` (fractional sampling occurrence) for
  ensembles ``[0-]``, ``[0+]``, ``[1+]``, ...
* ``3 + nintf .. 3 + 2*nintf - 1`` -- the high-acceptance (HA) weights for
  the same ensembles.

Here ``nintf`` is the number of interfaces and ``I0PLUS = 4`` is the column
of the ``[0+]`` ensemble (``[0-]`` sits one column to the left, column ``3``).
"""
import logging
import os

import numpy as np

from pyretis.core.engine_time import engine_time_per_step

# Module-level logger. The NullHandler keeps the module silent unless the
# application using it configures logging itself.
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())

# Column of the [0+] ensemble in a split data line (column 0 is the path
# index). [0-] sits one column to the left at ``I0PLUS - 1``. The helpers
# below derive every Cxy / HA-weight column index from this constant.
I0PLUS = 4


def read_data_matrix(filename, nskip=0):
    """Read an ``infswap_data.txt`` file into a numeric matrix.

    Blank lines and comment lines (first non-blank character ``#``) are
    ignored. Every other line is split on whitespace and each field is
    converted to ``float``; the placeholder ``"----"`` is mapped to
    ``0.0``.

    Parameters
    ----------
    filename : str
        Path to the ``infswap_data.txt`` file.
    nskip : int, optional
        Number of initial data rows to discard (loaded paths /
        equilibration). Must not be negative.

    Returns
    -------
    matrix : list of list of float
        One row per path; ``"----"`` entries are mapped to ``0.0``.

    Raises
    ------
    ValueError
        If ``nskip`` is negative, or if a field is neither ``"----"`` nor
        a valid float.

    """
    # A negative slice bound would silently keep only the *last* rows,
    # which is never what "skip the first nskip rows" means.
    if nskip < 0:
        raise ValueError(f'nskip must be non-negative, got {nskip}.')
    matrix = []
    with open(filename, encoding='utf-8') as infile:
        for line in infile:
            # Strip before testing for a comment so that an indented
            # ``#`` line is skipped instead of reaching ``float('#')``.
            stripped = line.strip()
            if not stripped or stripped.startswith('#'):
                continue
            matrix.append([
                0.0 if item == '----' else float(item)
                for item in stripped.split()
            ])
    del matrix[:nskip]
    return matrix
def rec_block_errors(runav, minblocks):
    """Block-error analysis of a running-average series.

    Standard recursive block-error analysis of the running average. For
    every block length from 1 up to ``len(runav) / minblocks`` the block
    averages are recovered from the running average, and the relative
    standard error of their mean (relative to the final running average)
    is recorded.

    Parameters
    ----------
    runav : sequence of float
        The running average of an observable, one entry per path.
    minblocks : int
        Minimum number of blocks in the block-error analysis. Must be at
        least 1.

    Returns
    -------
    half_av_err : float
        Average relative error over the second half of the block lengths.
    n_stat_ineff : float
        Statistical inefficiency estimate, ``(half_av_err / err_1) ** 2``
        with ``err_1`` the relative error at block length 1. It is
        ``0.0`` when ``err_1`` is zero (a constant series), because the
        ratio is undefined in that case.
    rel_errors : list of float
        Relative error as a function of block length.

    Raises
    ------
    ValueError
        If ``minblocks`` is smaller than 1.

    Notes
    -----
    ``(0.0, 0.0, [])`` is returned when the series is shorter than
    ``2 * minblocks``, when the final running average is exactly zero, or
    when no block length yields at least two blocks.

    """
    if minblocks < 1:
        raise ValueError(f'minblocks must be at least 1, got {minblocks}.')
    runav = np.asarray(runav, dtype=float)
    if len(runav) < 2 * minblocks:
        return 0.0, 0.0, []
    bestav = runav[-1]  # most accurate average we have
    if bestav == 0.0:
        # Relative errors are undefined for a vanishing average.
        return 0.0, 0.0, []
    maxbll = int(len(runav) / minblocks)  # maximum block length
    rel_errors = []
    for blocklength in range(1, maxbll + 1):
        # Every ``blocklength``-th running average is enough to recover
        # the averages of consecutive blocks of that length.
        blocks = _rec_blocks(runav[blocklength - 1::blocklength])
        n_blocks = len(blocks)
        if n_blocks < 2:
            continue
        sum_qudiff = np.sum((blocks - bestav) ** 2)
        abs_err = np.sqrt(sum_qudiff / (n_blocks * (n_blocks - 1)))
        rel_errors.append(abs_err / bestav)
    if not rel_errors:
        return 0.0, 0.0, []
    second_half = rel_errors[len(rel_errors) // 2:]
    half_av_err = float(np.mean(second_half))
    first_err = rel_errors[0]
    # A constant series has zero error at block length 1; avoid 0/0 -> NaN.
    if first_err == 0.0:
        n_stat_ineff = 0.0
    else:
        n_stat_ineff = (half_av_err / first_err) ** 2
    return half_av_err, n_stat_ineff, rel_errors
[docs]def _rec_blocks(reduced): """Recover block averages from a reduced running-average array.""" reduced = np.asarray(reduced, dtype=float) n_red = len(reduced) result = np.zeros(n_red, dtype=reduced.dtype) result += np.arange(1, n_red + 1) * reduced result[1:] -= np.arange(1, n_red) * reduced[:-1] return result
def _unweight_matrix(matrix, nintf):
    """Unweight the ``Cxy`` columns by the HA-weights, in place.

    Three things happen to every ensemble column:

    1. each ``Cxy`` entry with a positive HA-weight is divided by that
       weight;
    2. the HA-weight entry is replaced by the running sum of the
       *pre-unweighting* ``Cxy`` values (the running WHAM averages need
       it, and the HA-weights themselves are not used afterwards);
    3. once all rows are processed, the ``Cxy`` column is divided by the
       average inverse HA-weight, so the ``eta`` values are comparable
       across ensembles that use different moves (e.g. ``[0+]`` shooting
       vs ``[i+]`` wire fencing).

    Parameters
    ----------
    matrix : list of list of float
        The data matrix from :func:`read_data_matrix`; modified in place.
    nintf : int
        Number of interfaces.

    Raises
    ------
    ValueError
        If a row has a positive ``Cxy`` but a non-positive HA-weight.

    """
    first_cxy = I0PLUS - 1  # column of the [0-] ensemble
    raw_totals = [0.0] * nintf  # sum of Cxy before unweighting
    unweighted_totals = [0.0] * nintf  # sum of Cxy after unweighting
    for row in matrix:
        for ens in range(nintf):
            col = first_cxy + ens
            weight_col = col + nintf
            ha_weight = row[weight_col]
            if ha_weight > 0:
                raw_totals[ens] += row[col]
                row[col] /= ha_weight
                unweighted_totals[ens] += row[col]
            elif row[col] > 0:
                raise ValueError(
                    f'Division by zero HA-weight for path row {row}.'
                )
            # The weight slot now carries the running pre-unweighting sum.
            row[weight_col] = raw_totals[ens]
    totals = zip(raw_totals, unweighted_totals)
    for ens, (raw, unweighted) in enumerate(totals):
        if raw == 0.0:
            # Ensemble never sampled: nothing to normalise.
            continue
        mean_inverse_weight = unweighted / raw
        col = first_cxy + ens
        for row in matrix:
            row[col] /= mean_inverse_weight
[docs]def _wham_ptot_run(interfaces, ploc_matrix, sum_pxy): """Running-average total crossing probability via WHAM. ``ploc_matrix[j][i]`` is ``P_A(lambda_i | lambda_j)`` from the ``[j+]`` ensemble using the data so far. """ n_ens = len(interfaces) - 1 # lambda_n == lambda_B q_array = [0.0] * n_ens prob = 1.0 inv_q = sum_pxy[0] * prob if inv_q == 0: return 0.0, q_array q_fac = 1.0 / inv_q q_array[0] = q_fac for i in range(1, n_ens + 1): nominator = 0.0 for j in range(i): nominator += sum_pxy[j] * ploc_matrix[j][i] prob = nominator * q_fac if i == n_ens or prob == 0: return prob, q_array inv_q += sum_pxy[i] / prob q_fac = 1.0 / inv_q q_array[i] = q_fac return prob, q_array
[docs]def _wham_pq(n_plus_ens, interfaces, lamres, eta, v_alpha): """WHAM crossing probabilities at the interfaces and Q-factors. WHAM crossing probabilities at the interfaces and Q-factors (Q-factor recursion of Lervik et al., JCTC 2015). Returns ``P`` (crossing probability at each interface, with ``lambda_B`` appended) and ``Q`` (the per-ensemble normalisation factors for ``v_alpha``). """ prob = [0.0] * n_plus_ens q_fac = [0.0] * n_plus_ens inv_q = [0.0] * n_plus_ens prob[0] = 1.0 inv_q[0] = eta[0] if inv_q[0] == 0: # No [0+] sampling at all: every deeper interface is unreachable, # so its crossing probability is zero. Append the lambda_B slot # (as the normal path does) so the returned list always has one # entry per interface -- never nintf - 1. prob.append(0.0) return prob, q_fac q_fac[0] = 1.0 / inv_q[0] lambda_a = interfaces[0] for i in range(1, n_plus_ens): alpha = round((interfaces[i] - lambda_a) / lamres) prob[i] = v_alpha[alpha] * q_fac[i - 1] if prob[i] == 0: # Nothing reaches interface i; the deeper probabilities are # already 0. Pad the lambda_B slot so len(prob) == nintf, as # on the normal exit below. prob.append(0.0) return prob, q_fac inv_q[i] = inv_q[i - 1] + (eta[i] / prob[i]) q_fac[i] = 1.0 / inv_q[i] prob.append(v_alpha[-1] * q_fac[n_plus_ens - 1]) return prob, q_fac
[docs]def _default_lamres(interfaces): """Pick a default order-parameter resolution from the interfaces. The WHAM stitch places every interface on a uniform grid of step ``lamres`` via ``round((lambda - lambda_A) / lamres)``. To keep adjacent interfaces on *distinct* grid points the step must be small relative to the *smallest* interface spacing -- not the first one. Using the first gap (the historical default) silently collapses two interfaces onto the same grid index whenever a later gap is smaller, which corrupts the per-interface crossing probabilities. We therefore take one tenth of the smallest spacing. Parameters ---------- interfaces : list of float Strictly increasing interface positions. Returns ------- lamres : float The chosen resolution. """ gaps = [ interfaces[i + 1] - interfaces[i] for i in range(len(interfaces) - 1) ] min_gap = min(gaps) if min_gap <= 0.0: raise ValueError( 'Interfaces must be strictly increasing; found a ' f'non-positive spacing in {interfaces}.' ) return min_gap / 10.0
[docs]def _validate_lamres(interfaces, lamres): """Check that ``lamres`` resolves every interface separately. Raises if two interfaces round to the same grid index (the stitch would double-count an ensemble) and warns if an interface does not sit exactly on the grid (the per-interface crossing probability is then read at the nearest grid point; the interface is mis-placed by at most ``lamres / 2`` in lambda, and the probability error that induces depends on the local crossing-probability slope). Failing loud here is deliberate: a wrong crossing probability is worse than a refusal to run. Parameters ---------- interfaces : list of float Strictly increasing interface positions. lamres : float Candidate order-parameter resolution. """ if lamres <= 0.0: raise ValueError(f'lamres must be positive, got {lamres}.') lambda_a = interfaces[0] indices = [] for lam in interfaces: steps = (lam - lambda_a) / lamres nearest = round(steps) indices.append(nearest) if abs(steps - nearest) > 1.0e-6 * max(1.0, abs(steps)): logger.warning( 'Interface %s is not commensurate with lamres=%s: ' '(lambda - lambda_A)/lamres = %s is not an integer, so ' 'its crossing probability is read at the nearest grid ' 'point. Use a lamres that evenly divides every interface ' 'spacing for exact placement.', lam, lamres, steps, ) if len(set(indices)) != len(indices): raise ValueError( f'lamres={lamres} is too coarse: two interfaces map to the ' 'same grid point, which would corrupt the WHAM stitch. ' 'Reduce lamres (it must be smaller than the closest ' 'interface spacing).' )
def wham_crossing_probability(matrix, interfaces, lamres=None,
                              minblocks=5, interval=1.0):
    """Compute WHAM crossing probabilities and rate from a data matrix.

    Core WHAM crossing-probability, flux and rate computation, with the
    file/plot output replaced by a returned dictionary. The matrix is
    modified in place by the HA-weight unweighting step.

    The order-parameter grid is anchored at ``lambda_A``: grid point
    ``alpha`` sits at ``lambda_A + alpha * lamres`` and the grid ends at
    the index of ``lambda_B``. All histogram indices in this module are
    computed relative to ``lambda_A``, so the grid size and the reported
    ``lambda_values`` must use the same anchor.

    Parameters
    ----------
    matrix : list of list of float
        Data matrix from :func:`read_data_matrix` (already ``nskip``-ped).
    interfaces : sequence of float
        The interface positions ``[lambda_0, ..., lambda_B]``, strictly
        increasing.
    lamres : float, optional
        Order-parameter resolution. Defaults to one tenth of the
        *smallest* interface spacing (see :func:`_default_lamres`); this
        keeps non-uniform interfaces on distinct grid points. Whatever
        value is used is validated by :func:`_validate_lamres`.
    minblocks : int, optional
        Minimum number of blocks for the block-error analysis.
    interval : float, optional
        Simulation time per recorded path step, i.e.
        ``timestep * subcycles``. The ``[0-]``/``[0+]`` path lengths are
        in steps, so the flux is divided by this to give a rate per unit
        time -- matching the native PyRETIS flux analysis
        (:func:`pyretis.analysis.flux_analysis.analyse_flux`). The default
        ``1.0`` reproduces the upstream inftools convention of a rate per
        step; pass the real interval to obtain a physical rate.

    Returns
    -------
    results : dict
        Keys:

        * ``pcross_pm`` / ``pcross_wham`` -- total crossing probability
          (point-matching and WHAM);
        * ``pcross_pm_relerr`` / ``pcross_wham_relerr`` -- their relative
          block errors;
        * ``pcross_at_intf`` / ``pcross_at_intf_pm`` -- crossing
          probability at each interface (WHAM and point-matching);
        * ``flux``, ``rate_pm`` / ``rate_wham``;
        * ``rate_pm_relerr`` / ``rate_wham_relerr`` -- relative block
          errors of the two rates;
        * ``length_0minus`` / ``length_0plus`` -- weighted mean path
          lengths;
        * ``n_records`` -- number of rows in ``matrix``;
        * ``pcross_pm_series`` / ``pcross_wham_series`` /
          ``rate_pm_series`` / ``rate_wham_series`` -- per-path
          running-average series;
        * ``lambda_values`` and ``pcross_curve_wham`` /
          ``pcross_curve_pm`` -- the full ``P_A(lambda)`` profiles.

    Raises
    ------
    ValueError
        If ``matrix`` is empty, fewer than two interfaces are given, the
        interfaces are not strictly increasing, or ``lamres`` fails
        :func:`_validate_lamres`.

    """
    if not matrix:
        raise ValueError('Empty path-data matrix (all rows skipped?).')
    interfaces = [float(item) for item in interfaces]
    if len(interfaces) < 2:
        raise ValueError(
            f'Need at least two interfaces, got {interfaces}.'
        )
    # Checked here as well as in _default_lamres so that an explicit
    # ``lamres`` cannot smuggle in decreasing interfaces, which would
    # produce negative grid indices and silently wrong histograms.
    if any(not upper > lower
           for lower, upper in zip(interfaces, interfaces[1:])):
        raise ValueError(
            'Interfaces must be strictly increasing; found a '
            f'non-positive spacing in {interfaces}.'
        )
    if lamres is None:
        lamres = _default_lamres(interfaces)
    _validate_lamres(interfaces, lamres)

    lambda_a = interfaces[0]
    nintf = len(interfaces)
    n_plus_ens = nintf - 1
    _unweight_matrix(matrix, nintf)

    # Grid index of every interface, relative to lambda_A. The grid ends
    # exactly at lambda_B, so every interface index is a valid position.
    intf_alpha = [round((lam - lambda_a) / lamres) for lam in interfaces]
    n_alpha = intf_alpha[-1] + 1
    lambda_values = [lambda_a + alpha * lamres for alpha in range(n_alpha)]

    v_alpha = [0.0] * n_alpha  # WHAM total crossing probability
    v_alpha[0] = 1.0
    u_alpha = [0.0] * n_alpha  # point-matching total crossing prob
    u_alpha[0] = 1.0
    p_loc = [[0.0] * n_alpha for _ in range(n_plus_ens)]
    eta = [0.0] * n_plus_ens
    run_av_ptot_wham = []
    # Only entries [i][j >= i] are ever written; the rest stay zero.
    ploc_runav = [[0.0] * nintf for _ in range(nintf)]
    # After unweighting, the former HA-weight columns hold the running
    # pre-unweighting Cxy sums; the [0+] one sits at I0PLUS + nintf.
    # Explicit bounds stay correct even if a row has trailing columns.
    sum_start = I0PLUS + nintf
    sum_stop = sum_start + n_plus_ens

    for row in matrix:
        lambdamax = row[2]
        alpha_max = int(np.floor((lambdamax - lambda_a) / lamres))
        if alpha_max > n_alpha - 1:
            alpha_max = n_alpha - 1
        for i in range(n_plus_ens):
            cxy = row[I0PLUS + i]
            eta[i] += cxy
            alpha_min = intf_alpha[i]
            for alpha in range(alpha_min, alpha_max + 1):
                p_loc[i][alpha] += cxy
            # v(alpha) at the next interface is set by the lower ensemble.
            for alpha in range(alpha_min + 1, alpha_max + 1):
                v_alpha[alpha] += cxy
            for j in range(i, nintf):
                ploc_runav[i][j] = (
                    p_loc[i][intf_alpha[j]] / eta[i]
                    if eta[i] != 0.0 else 0.0
                )
        sum_pxy = row[sum_start:sum_stop]
        ptot_wham, _ = _wham_ptot_run(interfaces, ploc_runav, sum_pxy)
        run_av_ptot_wham.append(ptot_wham)

    # Point-matching running total: product of local crossing probs.
    run_av_ptot_pm = _running_pm(matrix, interfaces, lamres)

    # Final WHAM normalisation of v_alpha and the per-interface probs.
    # _wham_pq needs the un-normalised v_alpha, so it runs before the
    # in-place scaling below.
    pcross_at_intf, q_fac = _wham_pq(
        n_plus_ens, interfaces, lamres, eta, v_alpha
    )
    p_loc = [
        [val / eta[i] if eta[i] != 0.0 else 0.0 for val in p_loc[i]]
        for i in range(n_plus_ens)
    ]
    for k in range(n_plus_ens):
        for alpha in range(intf_alpha[k] + 1, intf_alpha[k + 1] + 1):
            v_alpha[alpha] *= q_fac[k]

    # Point-matching profile and per-interface probabilities.
    pcross_at_intf_pm = [0.0] * nintf
    pcross_at_intf_pm[0] = 1.0
    for i in range(n_plus_ens):
        alpha_min = intf_alpha[i] + 1
        alpha_max = intf_alpha[i + 1]
        u_alpha[alpha_min:alpha_max + 1] = [
            pcross_at_intf_pm[i] * num
            for num in p_loc[i][alpha_min:alpha_max + 1]
        ]
        pcross_at_intf_pm[i + 1] = u_alpha[alpha_max]

    # Initial flux from the [0-] and [0+] path lengths. Those lengths are
    # in simulation steps, so multiply by ``interval`` (the time per step,
    # ``timestep * subcycles``) to get a flux per unit time, matching the
    # native PyRETIS flux convention. Upstream inftools omits this factor
    # and so reports a rate per step.
    length_0min, length_0plus = _running_lengths(matrix)
    eff_time = (length_0min + length_0plus - 4) * interval
    flux = 1.0 / eff_time if eff_time > 0 else 0.0

    pcross_pm = run_av_ptot_pm[-1]
    pcross_wham = run_av_ptot_wham[-1]
    rate_pm = flux * pcross_pm
    rate_wham = flux * pcross_wham

    pm_err, _, _ = rec_block_errors(run_av_ptot_pm, minblocks)
    wham_err, _, _ = rec_block_errors(run_av_ptot_wham, minblocks)
    flux_series = _running_flux(matrix, interval)
    runav_rate_pm = [
        flux_t * p_t for flux_t, p_t in zip(flux_series, run_av_ptot_pm)
    ]
    runav_rate_wham = [
        flux_t * p_t for flux_t, p_t in zip(flux_series, run_av_ptot_wham)
    ]
    rate_err, _, _ = rec_block_errors(runav_rate_pm, minblocks)
    rate_wham_err, _, _ = rec_block_errors(runav_rate_wham, minblocks)

    return {
        'pcross_pm': pcross_pm,
        'pcross_wham': pcross_wham,
        'pcross_pm_relerr': pm_err,
        'pcross_wham_relerr': wham_err,
        'pcross_at_intf': pcross_at_intf,
        'pcross_at_intf_pm': pcross_at_intf_pm,
        'flux': flux,
        'rate_pm': rate_pm,
        'rate_wham': rate_wham,
        'rate_pm_relerr': rate_err,
        'rate_wham_relerr': rate_wham_err,
        'length_0minus': length_0min,
        'length_0plus': length_0plus,
        'n_records': len(matrix),
        # Per-path running-average series (for rate-vs-cycles plots).
        'pcross_pm_series': run_av_ptot_pm,
        'pcross_wham_series': run_av_ptot_wham,
        'rate_pm_series': runav_rate_pm,
        'rate_wham_series': runav_rate_wham,
        'lambda_values': lambda_values,
        'pcross_curve_wham': v_alpha,
        'pcross_curve_pm': u_alpha,
    }
def _running_pm(matrix, interfaces, lamres):
    """Compute the running point-matching total crossing probability.

    The local-crossing histograms are accumulated row by row. After each
    row the product over the ``[i+]`` ensembles of
    ``P_A(lambda_{i+1} | lambda_i)`` is evaluated from the histograms so
    far; an ensemble without any accumulated weight contributes a factor
    ``0.0``.

    Parameters
    ----------
    matrix : list of list of float
        The data matrix, one row per path.
    interfaces : list of float
        Interface positions, ``lambda_A`` first.
    lamres : float
        Order-parameter resolution of the histogram grid.

    Returns
    -------
    series : list of float
        One running product per row of ``matrix``.

    """
    origin = interfaces[0]
    n_ensembles = len(interfaces) - 1
    grid_size = round((interfaces[-1] - origin) / lamres) + 1
    top_index = grid_size - 1
    # Grid index of every interface; these do not depend on the row.
    intf_index = [round((lam - origin) / lamres) for lam in interfaces]
    histograms = [[0.0] * grid_size for _ in range(n_ensembles)]
    totals = [0.0] * n_ensembles
    products = []
    for row in matrix:
        # Highest grid point reached by this path, clipped to the grid.
        reach = min(int(np.floor((row[2] - origin) / lamres)), top_index)
        product = 1.0
        for ens in range(n_ensembles):
            weight = row[I0PLUS + ens]
            totals[ens] += weight
            histogram = histograms[ens]
            for alpha in range(intf_index[ens], reach + 1):
                histogram[alpha] += weight
            if totals[ens] != 0.0:
                product *= histogram[intf_index[ens + 1]] / totals[ens]
            else:
                product *= 0.0
        products.append(product)
    return products
def _running_lengths(matrix):
    """Weighted mean path length in the [0-] and [0+] ensembles.

    Every row contributes its path length (column 1) weighted by its
    ``[0-]`` and ``[0+]`` column values.

    Parameters
    ----------
    matrix : list of list of float
        The data matrix, one row per path.

    Returns
    -------
    r0min : float
        Weighted mean length in ``[0-]``; ``0.0`` if the total weight is
        zero.
    r0plus : float
        Weighted mean length in ``[0+]``; ``0.0`` if the total weight is
        zero.

    """
    col_minus = I0PLUS - 1
    weight_minus = weight_plus = 0.0
    weighted_len_minus = weighted_len_plus = 0.0
    for row in matrix:
        n_steps = row[1]
        w_minus = row[col_minus]
        w_plus = row[I0PLUS]
        weight_minus += w_minus
        weight_plus += w_plus
        weighted_len_minus += w_minus * n_steps
        weighted_len_plus += w_plus * n_steps
    # Only the final ratios are returned, so they are formed once here.
    if weight_minus != 0.0:
        mean_minus = weighted_len_minus / weight_minus
    else:
        mean_minus = 0.0
    if weight_plus != 0.0:
        mean_plus = weighted_len_plus / weight_plus
    else:
        mean_plus = 0.0
    return mean_minus, mean_plus
def _running_flux(matrix, interval=1.0):
    """Running-average conventional flux, one entry per path.

    After each row the weighted mean ``[0-]`` and ``[0+]`` path lengths
    are updated and the flux is ``1 / ((r0min + r0plus - 4) * interval)``.
    The path lengths are in simulation steps; ``interval`` (the time per
    step, ``timestep * subcycles``) converts the running flux to a flux
    per unit time. See :func:`wham_crossing_probability` for the
    ``interval`` convention.

    Parameters
    ----------
    matrix : list of list of float
        The data matrix, one row per path.
    interval : float, optional
        Simulation time per recorded path step.

    Returns
    -------
    series : list of float
        The running flux; ``0.0`` wherever the denominator is not
        positive.

    """
    col_minus = I0PLUS - 1
    weight_minus = weight_plus = 0.0
    weighted_len_minus = weighted_len_plus = 0.0
    flux_series = []
    for row in matrix:
        n_steps = row[1]
        w_minus = row[col_minus]
        w_plus = row[I0PLUS]
        weight_minus += w_minus
        weight_plus += w_plus
        weighted_len_minus += w_minus * n_steps
        weighted_len_plus += w_plus * n_steps
        mean_minus = (
            weighted_len_minus / weight_minus if weight_minus != 0.0 else 0.0
        )
        mean_plus = (
            weighted_len_plus / weight_plus if weight_plus != 0.0 else 0.0
        )
        cycle_time = (mean_minus + mean_plus - 4) * interval
        if cycle_time > 0:
            flux_series.append(1.0 / cycle_time)
        else:
            flux_series.append(0.0)
    return flux_series
def path_length_statistics(matrix):
    """Compute path-length statistics from a data matrix.

    Parameters
    ----------
    matrix : list of list of float
        Data matrix from :func:`read_data_matrix`; the path length is
        read from column 1 of every row.

    Returns
    -------
    stats : dict
        ``mean``, ``std``, ``min``, ``max``, ``median`` of path lengths.
        ``min`` and ``max`` are truncated to ``int``. All five values are
        ``0`` for an empty matrix.

    """
    if not matrix:
        return dict.fromkeys(('mean', 'std', 'min', 'max', 'median'), 0)
    lengths = np.array([row[1] for row in matrix])
    stats = {
        'mean': float(lengths.mean()),
        'std': float(lengths.std()),
        'min': int(lengths.min()),
        'max': int(lengths.max()),
        'median': float(np.median(lengths)),
    }
    return stats
def analyse_wham_output(data_file, interfaces, lamres=None, nskip=0,
                        minblocks=5, interval=1.0):
    """Run the standard WHAM analysis on an infinite-swapping run.

    Reads the data file, gathers the path-length statistics and runs
    :func:`wham_crossing_probability` on the same matrix.

    Parameters
    ----------
    data_file : str
        Path to ``infswap_data.txt``.
    interfaces : sequence of float
        The interface positions ``[lambda_0, ..., lambda_B]`` (from the
        run's ``restart.toml`` / config).
    lamres : float, optional
        Order-parameter resolution (see
        :func:`wham_crossing_probability`).
    nskip : int, optional
        Number of initial records to skip.
    minblocks : int, optional
        Minimum number of blocks for the block-error analysis.
    interval : float, optional
        Simulation time per recorded step (``timestep * subcycles``),
        used to convert the flux to a rate per unit time. See
        :func:`wham_crossing_probability`; read it from the run config
        with :func:`read_engine_interval`. The default ``1.0`` yields a
        rate per step.

    Returns
    -------
    results : dict
        The dictionary returned by :func:`wham_crossing_probability`,
        plus ``path_lengths`` (the :func:`path_length_statistics` dict).

    """
    matrix = read_data_matrix(data_file, nskip=nskip)
    # Statistics are taken before the WHAM step, which rewrites the
    # matrix in place.
    path_lengths = path_length_statistics(matrix)
    results = wham_crossing_probability(
        matrix,
        interfaces,
        lamres=lamres,
        minblocks=minblocks,
        interval=interval,
    )
    results.update(path_lengths=path_lengths)
    return results
[docs]def _load_run_config(directory='.'): """Load a run's ``restart.toml`` (or ``infswap.toml``) as a dict. Returns the parsed config from the first of ``restart.toml`` / ``infswap.toml`` found in ``directory``, or ``None`` if neither exists. """ try: import tomllib except ImportError: # pragma: no cover - py<3.11 fallback import tomli as tomllib for name in ('restart.toml', 'infswap.toml'): candidate = os.path.join(directory, name) if os.path.isfile(candidate): with open(candidate, 'rb') as infile: return tomllib.load(infile) return None
def read_interfaces(directory='.'):
    """Read the interface list from a run's ``restart.toml``.

    Parameters
    ----------
    directory : str
        Directory containing ``restart.toml`` (or ``infswap.toml``).

    Returns
    -------
    interfaces : list of float or None
        The ``interfaces`` entry of the ``[simulation]`` section, or
        ``None`` if no config is found or the entry is absent.

    """
    config = _load_run_config(directory)
    if config is None:
        return None
    simulation = config.get('simulation', {})
    return simulation.get('interfaces')
def read_engine_interval(directory='.'):
    """Read the time per step (``timestep * subcycles``) from a run config.

    The infinite-swapping data file records path lengths in steps, so the
    WHAM flux/rate needs this interval to be expressed per unit time (the
    convention of the native PyRETIS flux analysis). The loaded config is
    handed to ``engine_time_per_step``, which is expected to read the
    ``[engine]`` section the same way as
    :func:`pyretis.analysis.flux_analysis.analyse_flux`
    (``timestep * subcycles``, with ``subcycles`` defaulting to 1).

    Parameters
    ----------
    directory : str
        Directory containing ``restart.toml`` (or ``infswap.toml``).

    Returns
    -------
    interval : float or None
        ``None`` if no config file is found; otherwise whatever
        ``engine_time_per_step`` returns for the config (documented as
        ``timestep * subcycles``, or ``None`` when the config has no
        engine ``timestep``). Returning ``None`` lets the caller decide
        whether to fail or to fall back to a per-step rate rather than
        silently assuming a timestep.

    """
    config = _load_run_config(directory)
    return None if config is None else engine_time_per_step(config)
#: Names of the infinite-swapping data file, newest first. ``infswap_data.txt`` #: is the current name; ``infretis_data.txt`` is the pre-debrand legacy name, #: still produced by runs whose ``restart.toml`` persists the old ``data_file`` #: (so a restart started before the rename keeps writing the old name). INFSWAP_DATA_NAMES = ('infswap_data.txt', 'infretis_data.txt')
def detect_infswap_output(directory='.'):
    """Check if a directory contains infinite-swapping sampler output.

    The file names in ``INFSWAP_DATA_NAMES`` are tried in order, so the
    current ``infswap_data.txt`` takes precedence over the legacy
    ``infretis_data.txt`` when both are present. Recognising the legacy
    name keeps runs created (or restarted) before the data-file rename
    analysable.

    Parameters
    ----------
    directory : str
        Path to check.

    Returns
    -------
    data_file : str or None
        Path to the infinite-swapping data file if found, else None.

    """
    candidates = (
        os.path.join(directory, name) for name in INFSWAP_DATA_NAMES
    )
    return next((path for path in candidates if os.path.isfile(path)), None)