Source code for pyretis.bin.pyretisanalyse

#!/usr/bin/env python3
# Copyright (c) 2026, PyRETIS Development Team.
# Distributed under the LGPLv2.1+ License. See LICENSE for more info.
"""pyretisanalyse - An application for analysing PyRETIS simulations.

This script is a part of the PyRETIS library and can be used for
analysing the result from simulations.

usage: pyretisanalyse.py [-h] [-i INPUT] [-V]

optional arguments:
  -h, --help            show this help message and exit
  -i INPUT, --input INPUT
                        Location of PyRETIS input file
  -V, --version         show program's version number and exit
"""
# pylint: disable=invalid-name
import argparse
import datetime
import os
import re
import shutil
import subprocess  # nosec B404
import sys
import traceback
import warnings
import colorama

# Surface PyRETIS-owned DeprecationWarnings (e.g. the rst input
# deprecation) to CLI users. Python hides DeprecationWarning by
# default unless it comes from __main__; this filter scopes the
# override to pyretis modules only.
warnings.filterwarnings('default', category=DeprecationWarning,
                        module=r'pyretis(\..*)?')
from pyretis import __version__ as VERSION  # noqa: E402
from pyretis.info import PROGRAM_NAME, URL, CITE, LOGO  # noqa: E402
from pyretis.inout.common import create_backup  # noqa: E402
from pyretis.core.units import CONSTANTS  # noqa: E402
from pyretis.core.pathensemble import generate_ensemble_name  # noqa: E402
from pyretis.inout.analysisio.analysisio import run_analysis  # noqa: E402
from pyretis.inout.common import (  # noqa: E402
    check_python_version,
    make_dirs,
    name_file,
)
from pyretis.inout.formats.formatter import setup_console_logging  # noqa: E402
from pyretis.inout.report import generate_report  # noqa: E402
from pyretis.inout.settings import parse_settings_file  # noqa: E402
from pyretis.inout.screen import REFERENCE  # noqa: E402 registers custom levels

# Format of the start/end time stamps printed by the greeting and
# goodbye messages.
_DATE_FMT = '%d.%m.%Y %H:%M:%S'

# Set up for logging:
logger = setup_console_logging()

# Working directory at import time; the entry point uses it as the run
# directory and as the parent of the report directory.
runpath = os.getcwd()

# Hard-coded patterns for report outputs:
# (maps a report/task identifier to the base name of its report file)
REPORTFILES = {
    'md-flux': 'md_flux_report',
    'retis': 'retis_report',
    'make-tis-files': 'tis-multiple_report',
    'tis': 'tis_report',
    'repptis': 'repptis_report',
}

# Name of the file (created in the run directory) that receives the
# traceback when the analysis fails.
ERROR_FILE = 'error.txt'
# Default executable used for compiling LaTeX reports to PDF.
PDFLATEX = 'pdflatex'


def hello_world(infile, run_dir, report_dir):
    """Print the standard PyRETIS analysis greeting.

    The logo, program version, start time and Python version are
    written as banner messages, followed by progress messages naming
    the run directory, the report directory and the input file.

    Parameters
    ----------
    infile : string
        Location of the input file.
    run_dir : string
        Directory in which the analysis is executed.
    report_dir : string
        Directory the output is written to.

    """
    started = datetime.datetime.now().strftime(_DATE_FMT)
    python_version = sys.version.split()[0]
    logger.banner('\n'.join([LOGO]))
    banner_lines = (
        ('%s version: %s', (PROGRAM_NAME, VERSION)),
        ('Start of analysis: %s', (started,)),
        ('Python version: %s', (python_version,)),
    )
    for template, values in banner_lines:
        logger.banner(template, *values)
    progress_lines = (
        ('\nRunning in directory: %s', run_dir),
        ('Report directory: %s', report_dir),
        ('Input file: %s\n', infile),
    )
    for template, value in progress_lines:
        logger.progress(template, value)
def bye_bye_world():
    """Log the closing message, the references and the project URL."""
    finished = datetime.datetime.now().strftime(_DATE_FMT)
    logger.progress(f'End of {PROGRAM_NAME} analysis: {finished}')
    # Heading, an underline of equal length, then every non-empty line
    # of the citation text.
    heading = f'{PROGRAM_NAME} references:'
    lines = [heading, '-' * len(heading)]
    lines.extend(entry for entry in CITE.split('\n') if entry)
    logger.log(REFERENCE, '\n%s', '\n'.join(lines))
    logger.log(REFERENCE, f'{URL}')
def write_traceback(filename):
    """Dump the traceback of the current exception to ``filename``.

    Any message returned by ``create_backup`` for ``filename`` is
    logged as a warning before the file is (over)written.

    Parameters
    ----------
    filename : string
        Path of the file that receives the traceback text.

    """
    backup_note = create_backup(filename)
    if backup_note:
        logger.warning(backup_note)
    with open(filename, 'w', encoding='utf-8') as error_fh:
        error_fh.write(traceback.format_exc())
def _format_cycle_suffix(cycles):
    """Return a file-name suffix for the number of analysed cycles.

    Parameters
    ----------
    cycles : object
        The cycle count. Anything accepted by ``int()`` is allowed.

    Returns
    -------
    out : string or None
        ``'cycles-NNNNNNNNN'`` (zero padded to at least nine digits),
        or None when ``cycles`` is None or cannot be converted to an
        integer.

    """
    if cycles is None:
        return None
    try:
        count = int(cycles)
    except (TypeError, ValueError, OverflowError):
        # OverflowError is what int() raises for an infinite float; a
        # count that is not a finite number simply gives no suffix.
        return None
    return f'cycles-{count:09d}'
def _report_base(report_type, prefix=None):
    """Return the report base name, without counter or cycle suffix.

    Parameters
    ----------
    report_type : string
        Key into ``REPORTFILES`` selecting the report family.
    prefix : string, optional
        Text to prepend, separated from the base name by ``_``.

    Returns
    -------
    out : string
        The (optionally prefixed) base name.

    """
    base = REPORTFILES[report_type]
    return base if prefix is None else f'{prefix}_{base}'
def _path_cycles(result):
    """Return the ``'tis-cycles'`` entry of a path-ensemble result.

    The value is read from ``result['out']`` when that key exists and
    from ``result`` itself otherwise.

    Parameters
    ----------
    result : object
        The analysis result to inspect.

    Returns
    -------
    out : object or None
        The stored cycle count, or None when ``result`` (or its
        ``'out'`` entry) is not a dict or the key is missing.

    """
    if isinstance(result, dict):
        payload = result.get('out', result)
        if isinstance(payload, dict):
            return payload.get('tis-cycles')
    return None
def completed_cycles(analysis_results):
    """Return the number of cycles represented by an analysis result.

    The sources are tried in this order and the first hit wins:

    1. the path-ensemble entries ``'pathensemble'``,
       ``'pathensemble_repptis'`` and ``'pathensemble0'`` (smallest
       cycle count found under the first key that has any),
    2. the length of ``analysis_results['matched']['out']
       ['overall-cycle']``,
    3. ``'totalcycle'`` of the ``'cross'`` entry (or of its ``'out'``),
    4. the length of ``analysis_results['overall-cycle']``.

    Parameters
    ----------
    analysis_results : object
        The results from the analysis.

    Returns
    -------
    out : object or None
        The cycle count, or None if it cannot be determined.

    """
    if not isinstance(analysis_results, dict):
        return None

    for name in ('pathensemble', 'pathensemble_repptis', 'pathensemble0'):
        entry = analysis_results.get(name)
        items = entry if isinstance(entry, list) else [entry]
        counts = [
            count for count in map(_path_cycles, items)
            if count is not None
        ]
        if counts:
            return min(counts)

    matched = analysis_results.get('matched')
    if isinstance(matched, dict):
        matched_out = matched.get('out')
        if isinstance(matched_out, dict) and 'overall-cycle' in matched_out:
            return len(matched_out['overall-cycle'])

    cross = analysis_results.get('cross')
    if isinstance(cross, dict):
        cross_out = cross.get('out', cross)
        if isinstance(cross_out, dict):
            return cross_out.get('totalcycle')

    if 'overall-cycle' in analysis_results:
        return len(analysis_results['overall-cycle'])
    return None
def _report_counter_pattern(report_base, extension):
    """Return a regexp matching countered report archive names.

    A matching name is ``<report_base>_<counter>``, optionally followed
    by ``_cycles-<digits>``, and then the extension. Group 1 of a match
    holds the counter digits.

    Parameters
    ----------
    report_base : string
        Base name of the report family (matched literally).
    extension : string
        File extension without the leading separator.

    Returns
    -------
    out : re.Pattern
        The compiled, fully anchored pattern.

    """
    suffix = re.escape(os.extsep + extension)
    # Counters are written zero padded to a *minimum* of three digits,
    # so four or more digits must be accepted as well; otherwise
    # archives numbered 1000 and above would go unnoticed.
    return re.compile(
        rf'^{re.escape(report_base)}_(\d{{3,}})(?:_cycles-\d+)?{suffix}$'
    )
def _latest_report_pattern(report_base, extension):
    """Return a regexp matching report names that carry no counter.

    A matching name is ``<report_base>``, optionally followed by
    ``_cycles-<digits>``, and then the extension.

    Parameters
    ----------
    report_base : string
        Base name of the report family (matched literally).
    extension : string
        File extension without the leading separator.

    Returns
    -------
    out : re.Pattern
        The compiled, fully anchored pattern.

    """
    suffix = re.escape(os.extsep + extension)
    pieces = (
        '^',
        re.escape(report_base),
        r'(?:_cycles-\d+)?',
        suffix,
        '$',
    )
    return re.compile(''.join(pieces))
def _next_report_counter(path, report_base, extension):
    """Return the lowest archive counter not yet used in ``path``.

    Parameters
    ----------
    path : string
        Directory to scan. A missing directory counts as empty.
    report_base : string
        Base name of the report family.
    extension : string
        File extension without the leading separator.

    Returns
    -------
    out : int
        The smallest non-negative integer that is not the counter of
        an existing archive of this report family.

    """
    matcher = _report_counter_pattern(report_base, extension)
    taken = set()
    if os.path.isdir(path):
        hits = (matcher.match(entry) for entry in os.listdir(path))
        taken = {int(hit.group(1)) for hit in hits if hit}
    candidate = 0
    while candidate in taken:
        candidate += 1
    return candidate
def _countered_report_name(reportfile, report_base, counter):
    """Return ``reportfile`` renamed to carry an archive counter.

    If the file stem is ``report_base`` followed by a ``_cycles-``
    descriptor, the counter goes between the two. In every other case
    the counter is appended to the stem.

    Parameters
    ----------
    reportfile : string
        Path of the report file.
    report_base : string
        Base name of the report family.
    counter : int
        The archive counter, zero padded to three digits.

    Returns
    -------
    out : string
        The path with the counter inserted. The directory and the
        extension are unchanged.

    """
    directory, filename = os.path.split(reportfile)
    stem, suffix = os.path.splitext(filename)
    tag = f'{counter:03d}'
    marker = stem.rfind('_cycles-')
    if marker >= 0 and stem[:marker] == report_base:
        new_stem = f'{report_base}_{tag}{stem[marker:]}'
    else:
        new_stem = f'{stem}_{tag}'
    return os.path.join(directory, new_stem + suffix)
def backup_latest_reports(reportfile, report_base):
    """Rename the uncountered reports of a family to countered names.

    The directory of ``reportfile`` is scanned for files of the same
    report family and extension that carry no archive counter. Each is
    renamed to the next countered name that does not already exist.

    Parameters
    ----------
    reportfile : string
        Path of the report about to be written. It selects the
        directory and the extension.
    report_base : string
        Base name of the report family.

    Returns
    -------
    out : list of strings
        One message per renamed file. Empty if the directory does not
        exist or nothing matched.

    """
    directory = os.path.dirname(reportfile) or os.curdir
    extension = os.path.splitext(reportfile)[1][1:]
    latest = _latest_report_pattern(report_base, extension)
    if not os.path.isdir(directory):
        return []
    candidates = [
        os.path.join(directory, entry)
        for entry in sorted(os.listdir(directory))
        if latest.match(entry)
    ]
    messages = []
    counter = _next_report_counter(directory, report_base, extension)
    for source in candidates:
        # Advance the counter until the target name is free.
        while True:
            target = _countered_report_name(source, report_base, counter)
            if not os.path.exists(target):
                break
            counter += 1
        messages.append(f'Backup existing file "{source}" to "{target}"')
        os.rename(source, target)
        counter += 1
    return messages
def get_report_name(report_type, ext, prefix=None, path=None,
                    cycles=None, counter=None):
    """Build the file name for a report.

    The name consists of the report base name, the optional archive
    counter and the optional cycle descriptor, joined by ``_``.

    Parameters
    ----------
    report_type : string
        Identifier for the report we are writing.
    ext : string
        Extension for the file to write.
    prefix : string, optional
        A prefix to add to the file name.
    path : string, optional
        A directory to use for saving the report to.
    cycles : int, optional
        Number of completed cycles represented by the report.
    counter : int, optional
        Archive counter to insert before the cycle descriptor.

    Returns
    -------
    out : string
        The file name, as returned by ``name_file``.

    """
    parts = [_report_base(report_type, prefix=prefix)]
    if counter is not None:
        parts.append(f'{counter:03d}')
    cycle_part = _format_cycle_suffix(cycles)
    if cycle_part is not None:
        parts.append(cycle_part)
    return name_file('_'.join(parts), ext, path=path)
def write_file(outname, report_txt, backup=True, report_base=None):
    """Write a generated report to the given file.

    Parameters
    ----------
    outname : string
        The name of the file to write/create.
    report_txt : string
        The generated report as a string.
    backup : boolean, optional
        If True, existing reports are backed up before writing.
    report_base : string, optional
        Base report name. When given, the previous uncountered reports
        of that family are backed up; otherwise only ``outname``
        itself is backed up.

    Returns
    -------
    out : string
        The name of the file written.

    """
    if backup:
        if report_base is None:
            note = create_backup(outname)
            notes = [note] if note else []
        else:
            notes = backup_latest_reports(outname, report_base)
        for note in notes:
            logger.warning(note)
    with open(outname, 'wt', encoding='utf-8') as handle:
        handle.write(report_txt)
    return outname
def create_pdf_report(texfile, pdflatex=PDFLATEX, report_base=None):
    """Compile a LaTeX report to PDF if pdflatex is available.

    Parameters
    ----------
    texfile : string
        Path of the LaTeX file to compile.
    pdflatex : string, optional
        Name of the LaTeX executable.
    report_base : string, optional
        Base report name. When given, the previous uncountered PDF
        reports of that family are backed up; otherwise only the
        target PDF itself is backed up.

    Returns
    -------
    out : string or None
        Path of the PDF file, or None when the executable is missing,
        cannot be started, or exits with a non-zero status.

    """
    tex_name = os.path.basename(texfile)
    if shutil.which(pdflatex) is None:
        logger.warning(
            'Could not create PDF report for "%s": %s was not found.',
            tex_name, pdflatex)
        return None

    pdffile = os.path.splitext(texfile)[0] + os.extsep + 'pdf'
    if report_base is None:
        note = create_backup(pdffile)
        notes = [note] if note else []
    else:
        notes = backup_latest_reports(pdffile, report_base)
    for note in notes:
        logger.warning(note)

    # The compiler runs inside the directory of the tex file, so only
    # the bare file name is passed on the command line.
    command = [pdflatex, '-interaction=nonstopmode', '-halt-on-error',
               tex_name]
    workdir = os.path.dirname(texfile) or None
    try:
        result = subprocess.run(  # nosec B603
            command,
            cwd=workdir,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            check=False,
            text=True,
        )
    except OSError as err:
        logger.warning('Could not create PDF report for "%s": %s',
                       tex_name, err)
        return None

    if result.returncode == 0:
        return pdffile
    logger.warning('Could not create PDF report for "%s".', tex_name)
    logger.debug(result.stdout)
    logger.debug(result.stderr)
    return None
def create_reports(settings, analysis_results, report_path):
    """Generate and write the requested reports.

    Parameters
    ----------
    settings : dict
        Settings for analysis (and the simulation).
    analysis_results : dict
        Results from the analysis.
    report_path : string
        The path to the directory where the reports should be saved.

    Yields
    ------
    out : string
        The report files created. A LaTeX report is followed by its
        PDF when that could be compiled.

    """
    ensembles = settings.get('ensemble', [])
    task = settings['simulation']['task']
    pfix = None
    if task == 'tis' and len(ensembles) == 1:
        # A single-ensemble TIS report is prefixed with the ensemble
        # name.
        number = ensembles[0]['tis'].get('ensemble_number', '1')
        pfix = generate_ensemble_name(number)
    cycles = completed_cycles(analysis_results)
    report_base = _report_base(task, prefix=pfix)
    for report_type in settings['analysis']['report']:
        report, ext = generate_report(task, analysis_results,
                                      output=report_type)
        if report is None:
            continue
        reportfile = get_report_name(task, ext, prefix=pfix,
                                     path=report_path, cycles=cycles)
        write_file(reportfile, report, report_base=report_base)
        yield reportfile
        if ext != 'tex':
            continue
        pdffile = create_pdf_report(reportfile, report_base=report_base)
        if pdffile is not None:
            yield pdffile
def _run_wham_analysis(data_file, report_dir, nskip=0):
    """Analyse infinite-swapping output (WHAM crossing probability).

    The results are logged as progress messages and also written to
    ``wham_analysis.txt`` inside ``report_dir`` (which is created if
    it does not exist).

    Parameters
    ----------
    data_file : string
        Path to the ``infswap_data.txt`` produced by the
        infinite-swapping sampler.
    report_dir : string
        Directory the ``wham_analysis.txt`` report is written to.
    nskip : int, optional
        Number of initial cycles (records) to discard as
        equilibration -- the ``skip_initial_cycles`` analysis setting.
        Defaults to 0.

    Returns
    -------
    int
        0 on success, 1 if the interfaces could not be read.

    """
    # Imported here rather than at module level.
    from pyretis.analysis.wham_analysis import (
        analyse_wham_output,
        read_interfaces,
        read_engine_interval,
    )
    # The interfaces and the engine interval are looked up in the
    # directory that holds the data file.
    config_dir = os.path.dirname(os.path.abspath(data_file))
    interfaces = read_interfaces(config_dir)
    if interfaces is None:
        logger.warning('No restart.toml/infswap.toml found next to %s; '
                       'cannot read interfaces for the WHAM analysis.',
                       data_file)
        return 1
    # Path lengths are in steps; convert the flux to a rate per unit time
    # with the engine interval (timestep * subcycles). If the timestep is
    # unavailable, fall back to a per-step rate but say so loudly.
    interval = read_engine_interval(config_dir)
    if interval is None:
        interval = 1.0
        logger.warning('No engine timestep found next to %s; reporting the '
                       'rate per step (interval=1). Multiply by 1/timestep '
                       'for a rate per unit time.', data_file)
    results = analyse_wham_output(data_file, interfaces, nskip=nskip,
                                  interval=interval)

    # Summary on the console.
    logger.progress('WHAM analysis of %s', data_file)
    if nskip:
        logger.progress(' Skipped initial cycles: %d', nskip)
    logger.progress(' Records analysed: %d', results['n_records'])
    logger.progress(' Total crossing probability (point matching): %.6g '
                    '(rel. err. %.3g)',
                    results['pcross_pm'], results['pcross_pm_relerr'])
    logger.progress(' Total crossing probability (WHAM): %.6g '
                    '(rel. err. %.3g)',
                    results['pcross_wham'], results['pcross_wham_relerr'])
    for i, p in enumerate(results['pcross_at_intf']):
        logger.progress(' P_A(lambda_%d | lambda_0) [WHAM]: %.6g', i, p)
    logger.progress(' Initial flux (interval=%.6g): %.6g',
                    interval, results['flux'])
    logger.progress(' Rate constant (point matching): %.6g',
                    results['rate_pm'])
    logger.progress(' Rate constant (WHAM): %.6g', results['rate_wham'])
    stats = results['path_lengths']
    logger.progress(' Path lengths: mean=%.1f, median=%.1f, '
                    'min=%d, max=%d',
                    stats['mean'], stats['median'],
                    stats['min'], stats['max'])

    # The same numbers, with more digits, in the text report.
    outfile = os.path.join(report_dir, 'wham_analysis.txt')
    os.makedirs(report_dir, exist_ok=True)
    with open(outfile, 'w', encoding='utf-8') as fh:
        fh.write(f'# WHAM analysis of {data_file}\n')
        fh.write(f'# Records: {results["n_records"]}\n')
        fh.write(f'# Skipped initial cycles: {nskip}\n')
        fh.write(f'# Interfaces: {interfaces}\n\n')
        fh.write('Total crossing probability:\n')
        fh.write(f' point matching: {results["pcross_pm"]:.10g} '
                 f'(rel. err. {results["pcross_pm_relerr"]:.4g})\n')
        fh.write(f' WHAM: {results["pcross_wham"]:.10g} '
                 f'(rel. err. {results["pcross_wham_relerr"]:.4g})\n\n')
        fh.write('Cumulative crossing probability P_A(lambda_i|lambda_0) '
                 '[WHAM]:\n')
        for i, p in enumerate(results['pcross_at_intf']):
            fh.write(f' interface {i}: {p:.10g}\n')
        fh.write(f'\nInitial flux: {results["flux"]:.10g}\n')
        rate_pm = results['rate_pm']
        fh.write(f'Rate constant (point matching): {rate_pm:.10g}\n')
        fh.write(f'Rate constant (WHAM): {results["rate_wham"]:.10g}\n')
        fh.write('\nPath length statistics:\n')
        for key, val in stats.items():
            fh.write(f' {key}: {val}\n')
    logger.progress(' Report written: %s', outfile)
    return 0
def main(input_file, run_path, report_dir):
    """Run the analysis.

    Any exception raised during the analysis is logged, its traceback
    is written to ``ERROR_FILE`` inside ``run_path`` and the exit
    status is set to 1. The goodbye message is printed on every exit
    path, including the early returns.

    Parameters
    ----------
    input_file : string
        The input file with settings for the analysis. May be None,
        in which case infinite-swapping output must be present in
        ``run_path``.
    run_path : string
        The location from which we are running the analysis.
    report_dir : string
        The location where we will write the report.

    Returns
    -------
    out : int
        0 on success, 1 when an exception occurred or when an
        "explore" run was analysed with the "crossing" method, and
        otherwise the status returned by the WHAM analysis.

    """
    exit_status = 0
    try:
        # Auto-detect infinite-swapping output.
        from pyretis.analysis.wham_analysis import detect_infswap_output
        inf_data = detect_infswap_output(run_path)
        # Without an input file, the detected data is analysed with
        # WHAM directly (no equilibration cycles are skipped).
        if inf_data is not None and input_file is None:
            return _run_wham_analysis(inf_data, report_dir)
        if input_file is None:
            raise FileNotFoundError('Input file required (-i filename).')
        if not os.path.isfile(os.path.join(run_path, input_file)):
            raise FileNotFoundError(f'Input file "{input_file}" NOT found!')
        # Run analysis
        logger.progress('Reading input file "%s"', input_file)
        settings = parse_settings_file(input_file)
        # override exe-path to the one we are executing in now:
        settings['simulation']['exe-path'] = run_path
        units = settings['system']['units']
        # set derived properties:
        settings['system']['beta'] = (settings['system']['temperature'] *
                                      CONSTANTS['kB'][units]) ** -1
        settings['analysis']['report-dir'] = report_dir
        msg_dir = make_dirs(report_dir)
        if msg_dir:
            logger.progress(msg_dir)
        # The [analysis] method selects the rate estimator(s) to run:
        #   "crossing" (default) -- native point-matching on the
        #                           per-ensemble pathensemble.txt files,
        #   "wham"               -- WHAM on the infinite-swapping
        #                           infswap_data.txt file,
        #   "both"               -- run both and report each.
        method = str(settings['analysis'].get('method', 'crossing')).lower()
        valid_methods = ('crossing', 'wham', 'both')
        if method not in valid_methods:
            raise ValueError(
                f'Unknown [analysis] method "{method}"; choose one of: '
                f'{", ".join(valid_methods)}.'
            )
        skip_cycles = int(settings['analysis'].get('skip_initial_cycles', 0))
        # Native crossing-probability / point-matching analysis. It reads
        # skip_initial_cycles through the shared analysis settings (see
        # pyretis.analysis.path_analysis.initial_skip).
        if method in ('crossing', 'both'):
            task = settings['simulation']['task']
            if task == 'explore':
                logger.warning(
                    'Task "explore" does not compute crossing probabilities '
                    'or rate constants — its paths carry no statistical '
                    'weight. pyretisanalyse cannot analyse an explore run. '
                    'Switch task to tis / retis / pptis / repptis and '
                    'restart from the saved paths before running analysis.'
                )
                # With "both", execution continues to the WHAM part
                # below; with "crossing" there is nothing left to do.
                if method == 'crossing':
                    return 1
            else:
                sep = '=' * (len(task) + 28)
                logger.banner('\n%s\n Running %s analysis.\n%s',
                              sep, task, sep)
                results = run_analysis(settings)
                logger.progress('\nAnalysis complete. Creating reports:')
                for outfile in create_reports(settings, results, report_dir):
                    relfile = os.path.relpath(outfile, start=run_path)
                    logger.progress(' Report created: %s', relfile)
        # WHAM analysis on the infinite-swapping data file.
        if method in ('wham', 'both'):
            inf_data = detect_infswap_output(run_path)
            if inf_data is None:
                msg = (
                    f'[analysis] method = "{method}" requests the WHAM '
                    'analysis, but no infswap_data.txt was found in '
                    f'"{run_path}".'
                )
                # Missing data is fatal for "wham" but only a warning
                # for "both".
                if method == 'wham':
                    raise FileNotFoundError(msg)
                logger.warning('%s Skipping the WHAM analysis.', msg)
            else:
                wham_status = _run_wham_analysis(
                    inf_data, report_dir, nskip=skip_cycles
                )
                if wham_status != 0:
                    exit_status = wham_status
    except Exception as error:  # Exceptions should subclass BaseException.
        exit_status = 1
        errtxt = f'{type(error).__name__}: {error.args}'
        logger.error(errtxt)
        logger.error('Execution failed!')
        logger.error('Error traceback is written to: %s', ERROR_FILE)
        write_traceback(os.path.join(run_path, ERROR_FILE))
    finally:
        bye_bye_world()
    return exit_status
def entry_point():  # pragma: no cover
    """Run pyretisanalyse from the command line.

    Parses the command-line arguments, prints the greeting and exits
    the interpreter with the status returned by ``main``. Reports are
    written to the ``report`` directory below the run path.
    """
    colorama.init(autoreset=True)
    cli = argparse.ArgumentParser(description=PROGRAM_NAME)
    cli.add_argument(
        '-i', '--input',
        default=None,
        required=False,
        help=f'Location of {PROGRAM_NAME} input file',
    )
    cli.add_argument(
        '-V', '--version',
        action='version',
        version=f'{PROGRAM_NAME} {VERSION}',
    )
    options = cli.parse_args()
    check_python_version()
    report_location = os.path.join(runpath, 'report')
    hello_world(options.input, runpath, report_location)
    sys.exit(main(options.input, runpath, report_location))
# Run the command-line entry point when this file is executed as a
# script.
if __name__ == '__main__':  # pragma: no cover
    entry_point()