#!/usr/bin/env python3
# Copyright (c) 2026, PyRETIS Development Team.
# Distributed under the LGPLv2.1+ License. See LICENSE for more info.
"""pyretisanalyse - An application for analysing PyRETIS simulations.
This script is a part of the PyRETIS library and can be used for
analysing the result from simulations.
usage: pyretisanalyse.py [-h] [-i INPUT] [-V]
optional arguments:
-h, --help show this help message and exit
-i INPUT, --input INPUT
Location of PyRETIS input file
-V, --version show program's version number and exit
"""
# pylint: disable=invalid-name
import argparse
import datetime
import os
import re
import shutil
import subprocess # nosec B404
import sys
import traceback
import warnings
import colorama
# Surface PyRETIS-owned DeprecationWarnings (e.g. the rst input
# deprecation) to CLI users. Python hides DeprecationWarning by
# default unless it comes from __main__; this filter scopes the
# override to pyretis modules only.
warnings.filterwarnings('default', category=DeprecationWarning,
module=r'pyretis(\..*)?')
from pyretis import __version__ as VERSION # noqa: E402
from pyretis.info import PROGRAM_NAME, URL, CITE, LOGO # noqa: E402
from pyretis.inout.common import create_backup # noqa: E402
from pyretis.core.units import CONSTANTS # noqa: E402
from pyretis.core.pathensemble import generate_ensemble_name # noqa: E402
from pyretis.inout.analysisio.analysisio import run_analysis # noqa: E402
from pyretis.inout.common import ( # noqa: E402
check_python_version,
make_dirs,
name_file,
)
from pyretis.inout.formats.formatter import setup_console_logging # noqa: E402
from pyretis.inout.report import generate_report # noqa: E402
from pyretis.inout.settings import parse_settings_file # noqa: E402
from pyretis.inout.screen import REFERENCE # noqa: E402 registers custom levels
# Timestamp format used in the start/end-of-analysis log messages:
_DATE_FMT = '%d.%m.%Y %H:%M:%S'
# Set up for logging:
logger = setup_console_logging()
# Working directory captured at import time; entry_point() uses it as the
# run directory and as the parent of the "report" directory.
runpath = os.getcwd()
# Hard-coded patterns for report outputs. The keys are the task names used
# to look up a report base name, the values are the base file names:
REPORTFILES = {
'md-flux': 'md_flux_report',
'retis': 'retis_report',
'make-tis-files': 'tis-multiple_report',
'tis': 'tis_report',
'repptis': 'repptis_report',
}
# File (created in the run directory) that receives the error traceback:
ERROR_FILE = 'error.txt'
# Default executable used for compiling LaTeX reports to PDF:
PDFLATEX = 'pdflatex'
def hello_world(infile, run_dir, report_dir):
    """Log the start-up banner for a PyRETIS analysis.

    The logo, program version, start time and Python version are sent to
    the ``banner`` log level; the run directory, report directory and
    input file are sent to the ``progress`` log level.

    Parameters
    ----------
    infile : string
        String showing the location of the input file.
    run_dir : string
        The location where we are executing the analysis.
    report_dir : string
        String showing the location of where we write the output.

    """
    started = datetime.datetime.now().strftime(_DATE_FMT)
    python_version = sys.version.split()[0]
    logger.banner('\n'.join([LOGO]))
    banner_lines = (
        ('%s version: %s', (PROGRAM_NAME, VERSION)),
        ('Start of analysis: %s', (started,)),
        ('Python version: %s', (python_version,)),
    )
    for template, values in banner_lines:
        logger.banner(template, *values)
    progress_lines = (
        ('\nRunning in directory: %s', run_dir),
        ('Report directory: %s', report_dir),
        ('Input file: %s\n', infile),
    )
    for template, value in progress_lines:
        logger.progress(template, value)
def bye_bye_world():
    """Log the end-of-analysis message followed by the references.

    The end time goes to the ``progress`` log level. The citation text
    (``CITE``, with empty lines removed) and the project URL are logged
    at the custom ``REFERENCE`` level.
    """
    finished = datetime.datetime.now().strftime(_DATE_FMT)
    logger.progress(f'End of {PROGRAM_NAME} analysis: {finished}')
    # Heading, an underline of the same width, then the non-empty
    # citation lines:
    heading = f'{PROGRAM_NAME} references:'
    lines = [heading, '-' * len(heading)]
    lines.extend(entry for entry in CITE.split('\n') if entry)
    logger.log(REFERENCE, '\n%s', '\n'.join(lines))
    logger.log(REFERENCE, f'{URL}')
def write_traceback(filename):
    """Write the traceback of the current exception to a file.

    An already existing file is handed to ``create_backup`` first, and
    any message it returns is logged as a warning.

    Parameters
    ----------
    filename : string
        The file to write the traceback to.

    """
    backup_notice = create_backup(filename)
    if backup_notice:
        logger.warning(backup_notice)
    with open(filename, mode='w', encoding='utf-8') as handle:
        handle.write(traceback.format_exc())
def _report_base(report_type, prefix=None):
    """Return the base name for a report without counter or cycles.

    Parameters
    ----------
    report_type : string
        Key into ``REPORTFILES`` selecting the report family.
    prefix : string, optional
        If given, it is prepended to the base name, separated by an
        underscore.

    Returns
    -------
    out : string
        The base name of the report.

    """
    base = REPORTFILES[report_type]
    return base if prefix is None else f'{prefix}_{base}'
def _path_cycles(result):
    """Return the cycle count from a path-ensemble analysis result.

    Parameters
    ----------
    result : object
        A single analysis result. Only dictionaries are inspected: the
        count is read from ``result['out']`` when that key is present,
        otherwise from ``result`` itself.

    Returns
    -------
    out : object or None
        The value stored under ``'tis-cycles'``, or None when it can
        not be located.

    """
    if isinstance(result, dict):
        payload = result.get('out', result)
        if isinstance(payload, dict):
            return payload.get('tis-cycles')
    return None
def completed_cycles(analysis_results):
    """Return the number of cycles represented by an analysis result.

    The sources are tried in order and the first one that yields a
    value wins:

    1. The path-ensemble results (``'pathensemble'``,
       ``'pathensemble_repptis'``, ``'pathensemble0'``): the smallest
       ``'tis-cycles'`` value found for that key.
    2. ``analysis_results['matched']['out']['overall-cycle']`` (its
       length).
    3. ``'totalcycle'`` of the ``'cross'`` result.
    4. ``analysis_results['overall-cycle']`` (its length).

    Parameters
    ----------
    analysis_results : dict
        Results from the analysis. Anything that is not a dictionary
        gives None.

    Returns
    -------
    out : int or None
        The number of cycles, or None if it could not be determined.

    """
    if not isinstance(analysis_results, dict):
        return None
    for key in ('pathensemble', 'pathensemble_repptis', 'pathensemble0'):
        result = analysis_results.get(key)
        items = result if isinstance(result, list) else [result]
        cycles = [count for count in map(_path_cycles, items)
                  if count is not None]
        if cycles:
            return min(cycles)
    matched = analysis_results.get('matched')
    if isinstance(matched, dict):
        out = matched.get('out')
        if isinstance(out, dict) and 'overall-cycle' in out:
            return len(out['overall-cycle'])
    cross = analysis_results.get('cross')
    if isinstance(cross, dict):
        out = cross.get('out', cross)
        # Only stop here when a cycle count is actually stored; otherwise
        # fall through to the last fallback, like the "matched" branch.
        if isinstance(out, dict) and out.get('totalcycle') is not None:
            return out['totalcycle']
    if 'overall-cycle' in analysis_results:
        return len(analysis_results['overall-cycle'])
    return None
def _report_counter_pattern(report_base, extension):
    """Return a regexp matching countered report archive names.

    The names matched are of the form
    ``<report_base>_<counter>[_cycles-<n>].<extension>``, and the
    counter is available as group 1 of a match.

    Parameters
    ----------
    report_base : string
        Base name of the report family (matched literally).
    extension : string
        File extension without the leading separator (matched
        literally).

    Returns
    -------
    out : re.Pattern
        The compiled regular expression.

    """
    base = re.escape(report_base)
    ext = re.escape(os.extsep + extension)
    # Counters are written zero-padded to *at least* three digits
    # (``{counter:03d}``), so 1000 and above use more digits. Accept
    # those too, otherwise such archives are never seen as taken.
    return re.compile(rf'^{base}_(\d{{3,}})(?:_cycles-\d+)?{ext}$')
def _latest_report_pattern(report_base, extension):
    """Return a regexp matching uncountered latest report names.

    The names matched are of the form
    ``<report_base>[_cycles-<n>].<extension>``, i.e. reports without an
    archive counter.

    Parameters
    ----------
    report_base : string
        Base name of the report family (matched literally).
    extension : string
        File extension without the leading separator (matched
        literally).

    Returns
    -------
    out : re.Pattern
        The compiled regular expression.

    """
    pieces = (
        '^',
        re.escape(report_base),
        r'(?:_cycles-\d+)?',
        re.escape(os.extsep + extension),
        '$',
    )
    return re.compile(''.join(pieces))
def _next_report_counter(path, report_base, extension):
    """Return the next free archive counter for a report family.

    Parameters
    ----------
    path : string
        Directory to scan for archived reports. A path that is not a
        directory is treated as containing no archives.
    report_base : string
        Base name of the report family.
    extension : string
        File extension without the leading separator.

    Returns
    -------
    out : int
        The smallest non-negative counter not used by any archived
        report of this family in `path`.

    """
    matcher = _report_counter_pattern(report_base, extension)
    used = set()
    if os.path.isdir(path):
        hits = map(matcher.match, os.listdir(path))
        used = {int(hit.group(1)) for hit in hits if hit}
    # Among the len(used) + 1 smallest candidates at least one must be
    # free; pick the lowest of those.
    return min(set(range(len(used) + 1)) - used)
def _countered_report_name(reportfile, report_base, counter):
    """Insert the archive counter before the cycle descriptor.

    Parameters
    ----------
    reportfile : string
        Path to the report file to derive the archive name from.
    report_base : string
        Base name of the report family.
    counter : int
        The archive counter, written zero-padded to three digits.

    Returns
    -------
    out : string
        The archive name, in the same directory as `reportfile`. For
        ``<report_base>_cycles-<n>`` the counter goes between the base
        and the cycle descriptor; for any other name it is appended to
        the file name root.

    """
    directory, filename = os.path.split(reportfile)
    root, ext = os.path.splitext(filename)
    tag = f'{counter:03d}'
    head, marker, tail = root.rpartition('_cycles-')
    if marker and head == report_base:
        archived = f'{report_base}_{tag}{marker}{tail}{ext}'
    else:
        # Covers both the plain "<report_base>" name and unrelated names.
        archived = f'{root}_{tag}{ext}'
    return os.path.join(directory, archived)
def backup_latest_reports(reportfile, report_base):
    """Back up uncountered reports from a report family.

    Every file in the directory of `reportfile` that looks like a
    "latest" (uncountered) report of the family, with the same
    extension as `reportfile`, is renamed to a countered archive name.

    Parameters
    ----------
    reportfile : string
        Path of the report about to be written. It supplies the
        directory to scan and the file extension.
    report_base : string
        Base name of the report family.

    Returns
    -------
    out : list of strings
        One message per renamed file. The list is empty when the
        directory does not exist or nothing had to be renamed.

    """
    directory = os.path.dirname(reportfile) or os.curdir
    extension = os.path.splitext(reportfile)[1][1:]
    latest = _latest_report_pattern(report_base, extension)
    if not os.path.isdir(directory):
        return []
    candidates = []
    for entry in sorted(os.listdir(directory)):
        if latest.match(entry):
            candidates.append(os.path.join(directory, entry))
    counter = _next_report_counter(directory, report_base, extension)
    messages = []
    for current in candidates:
        # Advance the counter until the archive name is unused.
        while True:
            target = _countered_report_name(current, report_base, counter)
            if not os.path.exists(target):
                break
            counter += 1
        messages.append(f'Backup existing file "{current}" to "{target}"')
        os.rename(current, target)
        counter += 1
    return messages
def get_report_name(report_type, ext, prefix=None, path=None, cycles=None,
                    counter=None):
    """Generate file name for a report.

    The name is assembled as
    ``[<prefix>_]<base>[_<counter>][_<cycle suffix>]`` and then handed
    to ``name_file`` together with the extension and the path.

    Parameters
    ----------
    report_type : string
        Identifier for the report we are writing.
    ext : string
        Extension for the file to write.
    prefix : string, optional
        A prefix to add to the file name. ``create_reports`` uses it to
        mark single-ensemble 'tis' reports with the ensemble name.
    path : string, optional
        A directory to use for saving the report to.
    cycles : int, optional
        Number of completed cycles represented by the report.
    counter : int, optional
        Archive counter to insert before the cycle descriptor.

    Returns
    -------
    out : string
        The file name generated for the report.

    """
    parts = [_report_base(report_type, prefix=prefix)]
    if counter is not None:
        parts.append(f'{counter:03d}')
    # NOTE(review): `_format_cycle_suffix` is neither defined nor imported
    # in the part of this module visible here - confirm that it exists.
    cycle_suffix = _format_cycle_suffix(cycles)
    if cycle_suffix is not None:
        parts.append(f'{cycle_suffix}')
    return name_file('_'.join(parts), ext, path=path)
def write_file(outname, report_txt, backup=True, report_base=None):
    """Write a generated report to a given file.

    Parameters
    ----------
    outname : string
        The name of the file to write/create.
    report_txt : string
        This is the generated report as a string.
    backup : boolean, optional
        If True, back up existing reports before writing the new one.
        Every backup message is logged as a warning.
    report_base : string, optional
        Base report name. When given, all uncountered reports of that
        family are archived with ``backup_latest_reports``; when None,
        only `outname` itself is backed up with ``create_backup``.

    Returns
    -------
    out : string
        The name of the file written.

    """
    if backup and report_base is None:
        notice = create_backup(outname)
        if notice:
            logger.warning(notice)
    elif backup:
        for notice in backup_latest_reports(outname, report_base):
            logger.warning(notice)
    with open(outname, mode='wt', encoding='utf-8') as handle:
        handle.write(report_txt)
    return outname
def create_pdf_report(texfile, pdflatex=PDFLATEX, report_base=None):
    """Compile a LaTeX report to PDF if pdflatex is available.

    Existing PDF output is backed up before compiling: with
    `report_base` all uncountered PDFs of that report family are
    archived, without it only the target PDF is backed up. The compiler
    is run once, in the directory of `texfile`.

    Parameters
    ----------
    texfile : string
        Path to the LaTeX file to compile.
    pdflatex : string, optional
        Name of (or path to) the LaTeX compiler executable.
    report_base : string, optional
        Base report name used to back up previous uncountered PDFs.

    Returns
    -------
    out : string or None
        The path to the PDF file, or None if the compiler was not
        found, could not be started or exited with an error. These
        failures are logged as warnings and never raised.

    """
    texname = os.path.basename(texfile)
    if shutil.which(pdflatex) is None:
        logger.warning(
            'Could not create PDF report for "%s": %s was not found.',
            texname, pdflatex)
        return None
    pdffile = os.path.splitext(texfile)[0] + os.extsep + 'pdf'
    if report_base is None:
        msg = create_backup(pdffile)
        if msg:
            logger.warning(msg)
    else:
        for msg in backup_latest_reports(pdffile, report_base):
            logger.warning(msg)
    command = [pdflatex, '-interaction=nonstopmode',
               '-halt-on-error', texname]
    try:
        result = subprocess.run(  # nosec B603
            command,
            cwd=os.path.dirname(texfile) or None,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            check=False,
            text=True,
            # The compiler output is only used for debug logging and may
            # contain bytes that are invalid in the locale encoding.
            # Replace them instead of letting a UnicodeDecodeError (which
            # is not an OSError) abort the report generation.
            errors='replace',
        )
    except OSError as err:
        logger.warning('Could not create PDF report for "%s": %s',
                       texname, err)
        return None
    if result.returncode != 0:
        logger.warning('Could not create PDF report for "%s".', texname)
        logger.debug(result.stdout)
        logger.debug(result.stderr)
        return None
    return pdffile
def create_reports(settings, analysis_results, report_path):
    """Create some reports to display the output.

    One report is generated for each format listed in
    ``settings['analysis']['report']``. Reports with the extension
    ``'tex'`` are additionally compiled to PDF when possible.

    Parameters
    ----------
    settings : dict
        Settings for analysis (and the simulation).
    analysis_results : dict
        Results from the analysis.
    report_path : string
        The path to the directory where the reports should be saved.

    Yields
    ------
    out : string
        The report files created.

    """
    ensembles = settings.get('ensemble', [])
    task = settings['simulation']['task']
    prefix = None
    # A TIS run with a single ensemble gets its reports marked with the
    # name of that ensemble.
    if task == 'tis' and len(ensembles) == 1:
        number = ensembles[0]['tis'].get('ensemble_number', '1')
        prefix = generate_ensemble_name(number)
    cycles = completed_cycles(analysis_results)
    report_base = _report_base(task, prefix=prefix)
    for report_type in settings['analysis']['report']:
        report, ext = generate_report(task, analysis_results,
                                      output=report_type)
        if report is None:
            continue
        reportfile = get_report_name(task, ext, prefix=prefix,
                                     path=report_path, cycles=cycles)
        write_file(reportfile, report, report_base=report_base)
        yield reportfile
        if ext != 'tex':
            continue
        pdffile = create_pdf_report(reportfile, report_base=report_base)
        if pdffile is not None:
            yield pdffile
def _run_wham_analysis(data_file, report_dir, nskip=0):
    """Analyse infinite-swapping output (WHAM crossing probability).

    The interfaces and the engine interval are read from the directory
    containing `data_file`. The results are logged and also written to
    ``wham_analysis.txt`` in `report_dir`, which is created if needed.

    Parameters
    ----------
    data_file : string
        Path to the ``infswap_data.txt`` produced by the infinite-swapping
        sampler.
    report_dir : string
        Directory the ``wham_analysis.txt`` report is written to.
    nskip : int, optional
        Number of initial cycles (records) to discard as equilibration --
        the ``skip_initial_cycles`` analysis setting. Defaults to 0.

    Returns
    -------
    int
        0 on success, 1 if the interfaces could not be read.

    """
    from pyretis.analysis.wham_analysis import (
        analyse_wham_output, read_interfaces, read_engine_interval,
    )
    source_dir = os.path.dirname(os.path.abspath(data_file))
    interfaces = read_interfaces(source_dir)
    if interfaces is None:
        logger.warning('No restart.toml/infswap.toml found next to %s; '
                       'cannot read interfaces for the WHAM analysis.',
                       data_file)
        return 1
    # Path lengths are counted in steps. The engine interval
    # (timestep * subcycles) turns the flux into a rate per unit time.
    # Without it we fall back to a per-step rate and warn about it.
    interval = read_engine_interval(source_dir)
    if interval is None:
        interval = 1.0
        logger.warning('No engine timestep found next to %s; reporting the '
                       'rate per step (interval=1). Multiply by 1/timestep '
                       'for a rate per unit time.', data_file)
    results = analyse_wham_output(data_file, interfaces, nskip=nskip,
                                  interval=interval)

    # Summary for the log:
    logger.progress('WHAM analysis of %s', data_file)
    if nskip:
        logger.progress(' Skipped initial cycles: %d', nskip)
    logger.progress(' Records analysed: %d', results['n_records'])
    logger.progress(
        ' Total crossing probability (point matching): %.6g '
        '(rel. err. %.3g)',
        results['pcross_pm'], results['pcross_pm_relerr'],
    )
    logger.progress(
        ' Total crossing probability (WHAM): %.6g '
        '(rel. err. %.3g)',
        results['pcross_wham'], results['pcross_wham_relerr'],
    )
    for idx, prob in enumerate(results['pcross_at_intf']):
        logger.progress(' P_A(lambda_%d | lambda_0) [WHAM]: %.6g', idx, prob)
    logger.progress(' Initial flux (interval=%.6g): %.6g',
                    interval, results['flux'])
    logger.progress(' Rate constant (point matching): %.6g',
                    results['rate_pm'])
    logger.progress(' Rate constant (WHAM): %.6g', results['rate_wham'])
    length_stats = results['path_lengths']
    logger.progress(
        ' Path lengths: mean=%.1f, median=%.1f, '
        'min=%d, max=%d',
        length_stats['mean'], length_stats['median'],
        length_stats['min'], length_stats['max'],
    )

    # The same numbers, with more digits, for the report file:
    outfile = os.path.join(report_dir, 'wham_analysis.txt')
    os.makedirs(report_dir, exist_ok=True)
    with open(outfile, 'w', encoding='utf-8') as report:
        emit = report.write
        emit(f'# WHAM analysis of {data_file}\n')
        emit(f'# Records: {results["n_records"]}\n')
        emit(f'# Skipped initial cycles: {nskip}\n')
        emit(f'# Interfaces: {interfaces}\n\n')
        emit('Total crossing probability:\n')
        emit(f' point matching: {results["pcross_pm"]:.10g} '
             f'(rel. err. {results["pcross_pm_relerr"]:.4g})\n')
        emit(f' WHAM: {results["pcross_wham"]:.10g} '
             f'(rel. err. {results["pcross_wham_relerr"]:.4g})\n\n')
        emit('Cumulative crossing probability P_A(lambda_i|lambda_0) '
             '[WHAM]:\n')
        for idx, prob in enumerate(results['pcross_at_intf']):
            emit(f' interface {idx}: {prob:.10g}\n')
        emit(f'\nInitial flux: {results["flux"]:.10g}\n')
        rate_pm = results['rate_pm']
        emit(f'Rate constant (point matching): {rate_pm:.10g}\n')
        emit(f'Rate constant (WHAM): {results["rate_wham"]:.10g}\n')
        emit('\nPath length statistics:\n')
        for label, value in length_stats.items():
            emit(f' {label}: {value}\n')
    logger.progress(' Report written: %s', outfile)
    return 0
def main(input_file, run_path, report_dir):
    """Run the analysis.

    Without an input file, the WHAM analysis is run directly when
    infinite-swapping output is detected in `run_path`. With an input
    file, the ``[analysis]`` method setting selects the analysis:
    "crossing" (the default), "wham" or "both".

    Any exception is logged and its traceback is written to
    ``ERROR_FILE`` in `run_path`. The goodbye message is always logged.

    Parameters
    ----------
    input_file : string
        The input file with settings for the analysis.
    run_path : string
        The location from which we are running the analysis.
    report_dir : string
        The location where we will write the report.

    Returns
    -------
    out : int
        The exit status: 0 on success, non-zero on failure.

    """
    exit_status = 0
    try:
        # Auto-detect infinite-swapping output.
        from pyretis.analysis.wham_analysis import detect_infswap_output
        inf_data = detect_infswap_output(run_path)
        if input_file is None:
            if inf_data is not None:
                return _run_wham_analysis(inf_data, report_dir)
            raise FileNotFoundError('Input file required (-i filename).')
        if not os.path.isfile(os.path.join(run_path, input_file)):
            raise FileNotFoundError(f'Input file "{input_file}" NOT found!')

        logger.progress('Reading input file "%s"', input_file)
        settings = parse_settings_file(input_file)
        # The exe-path is always the directory we are executing in now:
        settings['simulation']['exe-path'] = run_path
        # Derived property: beta = 1 / (kB * T)
        system = settings['system']
        system['beta'] = (system['temperature'] *
                          CONSTANTS['kB'][system['units']]) ** -1
        settings['analysis']['report-dir'] = report_dir
        dir_message = make_dirs(report_dir)
        if dir_message:
            logger.progress(dir_message)

        # The [analysis] method selects the rate estimator(s) to run:
        #   "crossing" (default) -- native point-matching on the
        #                           per-ensemble pathensemble.txt files,
        #   "wham"               -- WHAM on the infinite-swapping
        #                           infswap_data.txt file,
        #   "both"               -- run both and report each.
        valid_methods = ('crossing', 'wham', 'both')
        method = str(settings['analysis'].get('method', 'crossing')).lower()
        if method not in valid_methods:
            raise ValueError(
                f'Unknown [analysis] method "{method}"; choose one of: '
                f'{", ".join(valid_methods)}.'
            )
        skip_cycles = int(settings['analysis'].get('skip_initial_cycles', 0))

        # Native crossing-probability / point-matching analysis. It reads
        # skip_initial_cycles through the shared analysis settings.
        # (The method is validated above, so "not wham" means "crossing"
        # or "both".)
        if method != 'wham':
            task = settings['simulation']['task']
            if task != 'explore':
                sep = '=' * (len(task) + 28)
                logger.banner('\n%s\n Running %s analysis.\n%s',
                              sep, task, sep)
                results = run_analysis(settings)
                logger.progress('\nAnalysis complete. Creating reports:')
                for outfile in create_reports(settings, results, report_dir):
                    relfile = os.path.relpath(outfile, start=run_path)
                    logger.progress(' Report created: %s', relfile)
            else:
                logger.warning(
                    'Task "explore" does not compute crossing probabilities '
                    'or rate constants — its paths carry no statistical '
                    'weight. pyretisanalyse cannot analyse an explore run. '
                    'Switch task to tis / retis / pptis / repptis and '
                    'restart from the saved paths before running analysis.'
                )
                if method == 'crossing':
                    return 1

        # WHAM analysis on the infinite-swapping data file ("wham" or
        # "both").
        if method != 'crossing':
            inf_data = detect_infswap_output(run_path)
            if inf_data is not None:
                wham_status = _run_wham_analysis(
                    inf_data, report_dir, nskip=skip_cycles
                )
                if wham_status != 0:
                    exit_status = wham_status
            else:
                msg = (
                    f'[analysis] method = "{method}" requests the WHAM '
                    'analysis, but no infswap_data.txt was found in '
                    f'"{run_path}".'
                )
                # Missing data is fatal only if WHAM was the sole method.
                if method == 'wham':
                    raise FileNotFoundError(msg)
                logger.warning('%s Skipping the WHAM analysis.', msg)
    except Exception as error:
        # Top-level boundary: report the failure and keep the traceback.
        exit_status = 1
        logger.error(f'{type(error).__name__}: {error.args}')
        logger.error('Execution failed!')
        logger.error('Error traceback is written to: %s', ERROR_FILE)
        write_traceback(os.path.join(run_path, ERROR_FILE))
    finally:
        bye_bye_world()
    return exit_status
def entry_point():  # pragma: no cover
    """entry_point - The entry point for the pip install of pyretisanalyse.

    Parse the command line, log the greeting and run the analysis in the
    current run directory. Reports go to its ``report`` sub-directory,
    and the process exits with the status returned by ``main``.
    """
    colorama.init(autoreset=True)
    parser = argparse.ArgumentParser(description=PROGRAM_NAME)
    parser.add_argument(
        '-i', '--input',
        default=None,
        required=False,
        help=(f'Location of {PROGRAM_NAME} input file'),
    )
    parser.add_argument(
        '-V', '--version',
        action='version',
        version=f'{PROGRAM_NAME} {VERSION}',
    )
    cli_args = parser.parse_args()
    check_python_version()
    inputfile = cli_args.input
    reportdir = os.path.join(runpath, 'report')
    hello_world(inputfile, runpath, reportdir)
    sys.exit(main(inputfile, runpath, reportdir))


if __name__ == '__main__':  # pragma: no cover
    entry_point()