# Copyright (c) 2026, PyRETIS Development Team.
# Distributed under the LGPLv2.1+ License. See LICENSE for more info.
"""
Common functions for the path density.

Functions used to compare and process data, such as matching similar
lists or attempting periodic shifts of values.

Important methods defined here
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

find_rst_file (:py:func:`.find_rst_file`)
    Search for a rst-file, starting in a chosen directory and moving up
    through its parent directories.
read_traj_txt_file (:py:func:`.read_traj_txt_file`)
    Read the sequence of files in a trajectory from a traj.txt file.
recalculate_all (:py:func:`.recalculate_all`)
    Recalculate the order parameter and new collective variables by
    finding all trajectory files from a simulation, either following the
    PyRETIS storage scheme or for individual files/folders.
shift_data (:py:func:`.shift_data`)
    Find the mid-point of the range of a given list of floats, and shift
    the values below it up by the range of the data.
try_data_shift (:py:func:`.try_data_shift`)
    Take two lists of values, ``x`` and ``y``, and calculate a linear
    regression and R**2-correlation of the data set. Attempt a periodic
    shift of each data set (see ``shift_data``) to increase the
    correlation.
where_from_to (:py:func:`.where_from_to`)
    Check the initial and final steps of a trajectory with respect to
    the provided interfaces.
get_cv_names (:py:func:`.get_cv_names`)
    Output a list of the names of the descriptors in the simulation.
find_data (:py:func:`.find_data`)
    Find suitable frames/trajectories to recompute the order parameter on.
read_single_data_file (:py:func:`.read_single_data_file`)
    Parse a standalone txt/csv data file and return a DataFrame, the
    column names and a best-effort main order-parameter label.
read_single_order_txt (:py:func:`.read_single_order_txt`)
    Parse a standalone order parameter text file and return a DataFrame
    together with a list of column names suitable for use in PyVisA.
run_user_script (:py:func:`.run_user_script`)
    Execute a user-supplied Python script and capture its stdout output
    to produce an order.txt file that PyVisA can load.
"""
import concurrent.futures
import io
import json
import logging
import os
import re
import shutil
import subprocess # nosec B404
import sys
import tempfile
import timeit
import tqdm
import numpy as np
import pandas as pd
import scipy # pylint: disable=import-error
from pyretis.initiation.initiate_load import write_order_parameters
from pyretis.inout import settings
from pyretis.inout.common import create_backup, TRJ_FORMATS
from pyretis.inout.formats.path import PathExtFile
from pyretis.orderparameter import expand_order_names
from pyretis.setup.common import create_orderparameter
from pyretis.tools.recalculate_order import recalculate_order
logger = logging.getLogger(__name__)
# Attach a NullHandler, as is customary for library modules, so that this
# module stays silent unless the application configures logging itself.
logger.addHandler(logging.NullHandler())
# Names exported by ``from <module> import *``: the public API.
__all__ = ['find_rst_file', 'read_traj_txt_file',
           'shift_data', 'try_data_shift', 'where_from_to',
           'get_cv_names', 'recalculate_all', 'find_data',
           'read_single_data_file', 'read_single_order_txt',
           'run_user_script']
def try_data_shift(x, y, fixedx):
    """Check if shifting increases correlation.

    Check whether the linear correlation of the data increases when
    either set of values, ``x`` or ``y``, or both, is periodically
    shifted with :py:func:`.shift_data`. Correlation is measured with a
    simple linear regression on the combinations
    (x, y), (x, y_shift), (x_shift, y) and (x_shift, y_shift), and the
    combination with the largest r-squared value is returned. On ties the
    combination with fewer shifts wins, and a y-shift is preferred over
    an x-shift.

    As a precaution, ``x`` is never shifted when *fixedx* is True (e.g.
    when ``x`` holds the main order parameter 'op1').

    A combination whose x values are all identical cannot be fitted
    (``scipy.stats.linregress`` raises for it) and is never selected.
    Note that shifting data with exactly two distinct values yields such
    a constant series.

    Parameters
    ----------
    x, y : list
        Floats, data values.
    fixedx : bool
        If True, x is the main OP and must not be shifted.

    Returns
    -------
    x, y : list
        Floats, updated (or unchanged) data values.
        (If changed, the shifted version of x, of y, or of both.)
    """
    # ``import scipy`` alone does not load the ``stats`` submodule on
    # older SciPy releases, so import it explicitly.
    from scipy import stats  # pylint: disable=import-outside-toplevel

    def _r_squared(x_data, y_data):
        """Return r**2 of a linear fit, or -inf when x is constant."""
        if len(x_data) > 1 and min(x_data) == max(x_data):
            return float('-inf')
        return stats.linregress(x_data, y_data).rvalue ** 2

    # The unshifted data and the Y-shifted data.
    r2_plain = _r_squared(x, y)
    y_temp = shift_data(y)
    r2_y = _r_squared(x, y_temp)
    # Only try X-shifts when x is allowed to move.
    if not fixedx:
        x_temp = shift_data(x)
        r2_x = _r_squared(x_temp, y)
        r2_xy = _r_squared(x_temp, y_temp)
        if r2_xy > r2_plain and r2_xy > r2_y and r2_xy > r2_x:
            return x_temp, y_temp
        if r2_x > r2_plain and r2_x > r2_y:
            return x_temp, y
    if r2_y > r2_plain:
        return x, y_temp
    return x, y
def shift_data(x):
    """Shift the values lying below the mid-range of the data.

    Every value below the mid-point of the data range,
    ``min(x) + (max(x) - min(x)) / 2``, is moved up by the full range
    ``max(x) - min(x)``. This periodically "wraps" the lower half of
    the range on top of the upper half, which can merge clusters that
    are split across the range for visualization.

    Note that the split point is the mid-range, not the statistical
    median of the data.

    Parameters
    ----------
    x : list
        Floats, data values.

    Returns
    -------
    xnorm : list
        Floats where some values are shifted values of x, and some are
        left unchanged. An empty input gives an empty list.
    """
    if len(x) == 0:
        return []
    xmin, xmax = min(x), max(x)
    # The full range of the x-data.
    diff_x = xmax - xmin
    # The mid-point of the range (not the median).
    midpoint = xmin + 0.5 * diff_x
    return [value + diff_x if value < midpoint else value for value in x]
def read_traj_txt_file(path):
    """Read a traj.txt file.

    Walk the rows of a PyRETIS ``traj.txt`` file and record the
    trajectory files it refers to, in order. A new entry is started
    whenever the file name of a row differs from the current entry, and a
    row whose step column is ``'0'`` overwrites the current entry. The
    first row is expected to have step ``'0'``.

    Parameters
    ----------
    path : string
        Path to the traj.txt file.

    Returns
    -------
    files : dict
        Maps a running integer index (starting at 0) to a list
        ``[file_name, velocity_sign]`` taken from the file-name (second)
        and velocity (fourth) columns of the row that started the entry.
    """
    entries = {}
    position = 0
    with PathExtFile(path, 'r') as traj_file:
        rows = (row for block in traj_file.load() for row in block['data'])
        for row in rows:
            if row[0] == '0':
                entries[position] = [row[1], row[3]]
            if row[1] != entries[position][0]:
                position += 1
                entries[position] = [row[1], row[3]]
    return entries
def find_rst_file(search_dir):
    """Search for a ``.rst`` settings file.

    Starting in *search_dir*, look for a file whose name ends in
    ``.rst`` (the alphabetically first one wins). If none is found,
    repeat the search in each parent directory up to the filesystem
    root.

    Parameters
    ----------
    search_dir : string
        Directory where the search starts. If it cannot be listed (for
        instance because it is the ``.rst`` file itself), it is returned
        unchanged as an absolute path.

    Returns
    -------
    out : string
        Absolute path of the first ``.rst`` file found. When no file is
        found, the directory where the search stopped (the filesystem
        root, or the first directory that could not be listed) is
        returned instead, so callers should check the result.
    """
    directory = os.path.abspath(search_dir)
    while True:
        try:
            candidates = [name for name in sorted(os.listdir(directory))
                          if name.endswith('.rst')]
        except OSError:
            return directory
        if candidates:
            return os.path.join(directory, candidates[0])
        parent = os.path.dirname(directory)
        if parent == directory:
            # Reached the filesystem root without finding a file.
            return directory
        directory = parent
def where_from_to(trj, int_a, int_b=float('-inf')):
    r"""Detect L/R starts and L/R/\* ends.

    Given a list of order parameters (a trajectory), establish where the
    path started and where it ended with respect to the interfaces. A
    value is classified as ``'R'`` (right) when it is at or above the
    larger interface, ``'L'`` (left) when it is strictly below the
    smaller interface, and ``'*'`` otherwise.

    Note: for rejected ('REJ') paths, the result of this function may
    differ from PyRETIS.

    Parameters
    ----------
    trj: numpy array
        The order parameters of the trajectory.
    int_a: float
        The interface that defines state A.
    int_b: float, optional
        The interface that defines state B. If not given, it is assumed
        that the 0^- ensemble is in use without the 0^- L interface.

    Returns
    -------
    start: string\*1
        The initial position of the trajectory with respect to the given
        interfaces ('L', 'R' or '\*').
    end: string\*1
        The final position of the trajectory with respect to the given
        interfaces ('L', 'R' or '\*').
    """
    lower = min(int_a, int_b)
    upper = max(int_a, int_b)

    def _side(value):
        if value >= upper:
            return 'R'
        if value < lower:
            return 'L'
        return '*'

    return _side(trj[0]), _side(trj[-1])
def get_cv_names(input_settings, num_columns=None):
    """Return labels for the order parameter and collective variables.

    The ``[orderparameter]`` section and every ``[collective-variable]``
    section may declare ``name`` either as a single string or as a list
    of strings. Labels are produced by ``expand_order_names`` with a
    default prefix of ``'op'`` for the main order parameter, ``'cv'``
    for a single collective variable, and ``'cv1'``, ``'cv2'``, ... when
    there are several.

    * If exactly one section is configured and *num_columns* is given,
      that section's ``name`` is expanded to *num_columns* labels.
    * Otherwise a list ``name`` contributes one label per entry, and a
      string (or missing) ``name`` contributes exactly one label.

    Parameters
    ----------
    input_settings : dict
        Dictionary with the settings from the simulations.
    num_columns : int, optional
        Total number of order-parameter columns observed (e.g. read from
        ``order.txt``). Used to size the labels of a single configured
        section, and otherwise checked against the number of labels.

    Returns
    -------
    names : list of str
        Flat list of column labels, in the same order as the
        concatenated ``calculate()`` outputs. Empty when no section is
        configured.

    Raises
    ------
    ValueError
        If several sections are configured and *num_columns* differs
        from the number of labels they produce (``expand_order_names``
        may raise as well).
    """
    sections = []
    main_section = input_settings.get('orderparameter')
    if main_section is not None:
        sections.append(('op', main_section.get('name')))
    cv_sections = input_settings.get('collective-variable', [])
    numbered = len(cv_sections) != 1
    for number, section in enumerate(cv_sections, start=1):
        prefix = f'cv{number}' if numbered else 'cv'
        sections.append((prefix, section.get('name')))
    if not sections:
        return []
    # With a known column count and a single section, that section claims
    # every column.
    if num_columns is not None and len(sections) == 1:
        only_prefix, only_name = sections[0]
        return expand_order_names(only_name, num_columns,
                                  default_prefix=only_prefix)
    labels = []
    for prefix, name in sections:
        n_values = len(name) if isinstance(name, (list, tuple)) else 1
        labels.extend(expand_order_names(name, n_values, prefix))
    if num_columns is not None and len(labels) != num_columns:
        raise ValueError(
            "Configured order-parameter names produce "
            f"{len(labels)} labels, but the data has {num_columns} "
            "columns. Provide a list ``name`` for any multi-valued "
            "order parameter or collective variable."
        )
    return labels
def _split_table_fields(line):
    """Split a text row into fields.

    Commas take precedence: if the row contains one, it is split on
    commas and empty fields are dropped; otherwise it is split on runs
    of whitespace. A blank row gives an empty list.
    """
    text = line.strip()
    if ',' in text:
        pieces = (piece.strip() for piece in text.split(','))
        return [piece for piece in pieces if piece]
    return text.split()
def _parse_numeric_row(line):
    """Convert the fields of a row to floats.

    Returns the list of floats, or ``None`` when the row is blank or any
    field is not a number.
    """
    fields = _split_table_fields(line)
    if fields:
        try:
            return list(map(float, fields))
        except ValueError:
            pass
    return None
def _parse_header_candidate(line):
    """Return ``(fields, is_comment)`` for a possible header row.

    Blank and numeric rows are not headers and give ``(None, False)``.
    A row counts as commented when it starts with a character other than
    a letter, a digit, ``.``, ``+`` or ``-`` (e.g. ``#``, ``;``, ``//``).
    Leading non-alphanumeric characters are removed before the row is
    split into fields; ``fields`` is ``None`` if nothing is left.
    """
    text = line.strip()
    if not text or _parse_numeric_row(text) is not None:
        return None, False
    commented = re.match(r'^[^A-Za-z0-9.+-]+', text) is not None
    fields = _split_table_fields(re.sub(r'^[^A-Za-z0-9]+', '', text))
    return (fields if fields else None), commented
def _labels_from_rst(rst_file, n_cols):
    """Infer column labels from an optional PyRETIS ``.rst`` file.

    Labels are built with :py:func:`.get_cv_names` from the parsed
    settings. They are used as-is when their number equals *n_cols*, or
    prefixed with ``'time'`` when the data have exactly one extra
    (leading) column.

    This is a best-effort helper: any problem reading the settings gives
    ``None`` so the caller can fall back to default labels.

    Parameters
    ----------
    rst_file : str or None
        Path to the settings file.
    n_cols : int
        Number of columns in the data.

    Returns
    -------
    labels : list of str or None
        The inferred labels, or ``None`` when no usable file is given,
        it cannot be interpreted, or the labels do not fit *n_cols*.
    """
    if not rst_file or not os.path.isfile(rst_file):
        return None
    try:
        input_settings = settings.parse_settings_file(
            os.path.abspath(rst_file)
        )
        # Malformed order-parameter sections (e.g. a section that is not
        # a mapping) must not prevent the data file from loading.
        labels = get_cv_names(input_settings)
    except (OSError, ValueError, KeyError, TypeError, AttributeError):
        return None
    if n_cols == len(labels):
        return labels
    if n_cols > 1 and n_cols == len(labels) + 1:
        return ['time'] + labels
    return None
def read_single_data_file(filepath, rst_file=None):
    """Parse a standalone ``txt`` or ``csv`` data file for PyVisA.

    Accepted content:

    * numeric rows separated by whitespace or commas,
    * an optional commented header line (``#``, ``##``, ``;``, ``//``, ...),
    * or an uncommented CSV header row.

    Only numeric rows with the same number of values as the first
    numeric row are kept. Column titles come from a header line when a
    suitable one exists, otherwise from *rst_file* when possible, and
    finally default to ``time``, ``op1``, ``op2``, ... (``op1`` alone for
    single-column data).

    Parameters
    ----------
    filepath : str
        Path to the text/CSV file.
    rst_file : str, optional
        Optional PyRETIS ``.rst`` file used to infer missing labels.

    Returns
    -------
    frames : pandas.DataFrame or None
        Parsed numeric data with every column preserved, including the
        first time/step column when present.
    plot_cols : list of str or None
        Column labels available for plotting.
    main_op_label : str or None
        Best-effort main order-parameter label: the second column when
        the first one is named like a time/step/cycle column, otherwise
        the first column.
    """
    numeric_rows = []
    commented_headers = []
    bare_headers = []
    with open(filepath, encoding='utf-8') as handle:
        for raw_line in handle:
            text = raw_line.strip()
            if not text:
                continue
            values = _parse_numeric_row(text)
            if values is not None:
                numeric_rows.append(values)
                continue
            header, commented = _parse_header_candidate(text)
            if header:
                target = commented_headers if commented else bare_headers
                target.append(header)
    if not numeric_rows:
        return None, None, None
    width = len(numeric_rows[0])
    numeric_rows = [row for row in numeric_rows if len(row) == width]
    # NOTE(review): _select_header_labels is not defined in the visible
    # part of this module; presumably it is defined further down — confirm.
    col_names = _select_header_labels(commented_headers, bare_headers, width)
    if col_names is None:
        col_names = _labels_from_rst(rst_file, width)
    if col_names is None:
        if width == 1:
            col_names = ['op1']
        else:
            col_names = ['time'] + [f'op{k}' for k in range(1, width)]
    frames = pd.DataFrame(numeric_rows, columns=col_names, dtype=float)
    has_time_column = col_names[0].lower() in ('time', 't', 'step', 'cycle')
    main_op_label = (col_names[1] if has_time_column and len(col_names) > 1
                     else col_names[0])
    return frames, col_names, main_op_label
def read_single_order_txt(filepath, rst_file=None):
    """Load a standalone text table (backward-compatible wrapper).

    Same as :py:func:`.read_single_data_file`, but returns only
    ``(frames, plot_cols)`` and drops the main order-parameter label.
    """
    return read_single_data_file(filepath, rst_file=rst_file)[:2]
[docs]def run_user_script(script_path):
# pylint: disable=too-many-return-statements,too-many-locals
"""Execute a user script and capture order-parameter data from stdout.
The script must print its result to *stdout* in one of two formats:
* **JSON list of lists** – ``[[op1_frame0, op2_frame0, …], …]``
where each inner list contains the order-parameter values for one
simulation frame.
* **CSV table** – a comma- or whitespace-delimited table that
``pandas.read_csv`` can parse (one row per frame, optional header).
The captured data are written to ``order.txt`` in the same directory
as *script_path* so that PyVisA can subsequently load the file.
Parameters
----------
script_path : str
Absolute path to the Python script to execute.
Returns
-------
order_txt_path : str or None
Absolute path of the written ``order.txt`` file, or ``None`` on
failure.
error : str or None
Human-readable error description, or ``None`` on success.
"""
script_dir = os.path.dirname(os.path.abspath(script_path))
try:
result = subprocess.run( # nosec B603
[sys.executable, script_path],
cwd=script_dir,
capture_output=True, text=True, timeout=600,
check=False,
)
except subprocess.TimeoutExpired:
return None, 'Script timed out after 600 seconds.'
except OSError as exc:
return None, str(exc)
if result.returncode != 0:
return None, (
f'Script exited with code {result.returncode}:\n{result.stderr}'
)
stdout = result.stdout.strip()
if not stdout:
return None, 'Script produced no output on stdout.'
# Try JSON first, then fall back to CSV
rows = None
try:
data = json.loads(stdout)
if isinstance(data, list):
rows = [list(r) for r in data]
except json.JSONDecodeError:
pass
if rows is None:
try:
df_out = pd.read_csv(
io.StringIO(stdout),
header=None,
sep=r'\s*,\s*|\s+',
engine='python',
)
rows = df_out.values.tolist()
except Exception: # pylint: disable=broad-except
return None, (
'Could not parse script output as JSON list of lists or CSV.'
)
if not rows:
return None, 'Script returned empty data.'
n_cols = len(rows[0]) if rows else 0
op_cols = ['Orderp'] + [f'cv{j}' for j in range(1, n_cols)]
order_txt_path = os.path.join(script_dir, 'order.txt')
with open(order_txt_path, 'w', encoding='utf-8') as fh:
header_cols = ' '.join(f'{c:>14}' for c in op_cols)
fh.write('Recalculated data\n')
fh.write(f'#{"Time":>12} {header_cols}\n')
for i, row in enumerate(rows):
vals = ' '.join(f'{float(v):14.6f}' for v in row)
fh.write(f'{i:12d} {vals}\n')
# Verify we can read it back
frames, _, _ = read_single_data_file(order_txt_path)
if frames is None:
return None, 'order.txt was written but could not be re-parsed.'
_ = np.asarray(frames) # trigger numpy import for early error detection
return order_txt_path, None
# 50 MB threshold for trajectory splitting: in the parallel path of
# recalculate_all, .trr/.xtc files larger than this are split into chunks
# with ``gmx trjconv -split`` so several workers can share them.
_LARGE_TRJ_BYTES = 50 * 1024 * 1024
# Minimum frames per chunk when splitting; the number of chunks is capped
# at ``n_frames // _MIN_CHUNK_FRAMES``.
_MIN_CHUNK_FRAMES = 30
def _count_gromacs_frames(filename, gmx_exe='gmx'):
    """Return (n_frames, dt_ps) for a GROMACS trajectory via gmx check.

    Both values are parsed from the same ``Step`` row of the summary
    table printed by ``gmx check``, so a single subprocess call is
    sufficient and ``gmx dump`` is never needed. For older GROMACS
    versions that only print ``Highest frame number: N`` (0-based), the
    frame count is ``N + 1`` and the timestep is unknown.

    Parameters
    ----------
    filename : str
        Path to the .trr or .xtc trajectory file.
    gmx_exe : str, optional
        Path to the GROMACS executable.

    Returns
    -------
    n_frames : int or None
    dt_ps : float or None
        Frame count and timestep in ps, or (None, None) on failure. The
        timestep is None when only the older output format is found.
    """
    # The timestep must be a real unsigned number ("0.002", "2", ".5",
    # "2e-3"); a lone sign, dot or exponent marker is not accepted, since
    # float() would reject it.
    step_row = re.compile(
        r'\s*Step\s+(\d+)\s+((?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)')
    highest_frame = re.compile(r'[Hh]ighest frame.*?(\d+)')
    try:
        result = subprocess.run(  # nosec B603
            [gmx_exe, 'check', '-f', filename],
            capture_output=True, text=True, errors='replace',
            timeout=120, check=False,
        )
    except FileNotFoundError:
        logger.warning(
            'gmx executable not found: "%s". '
            'Cannot split large trajectories.', gmx_exe)
        return None, None
    except (subprocess.SubprocessError, OSError) as exc:
        # Covers timeouts as well as e.g. a non-executable gmx path.
        logger.warning(
            'gmx check failed for %s: %s', os.path.basename(filename), exc)
        return None, None
    output = result.stdout + result.stderr
    for line in output.splitlines():
        # GROMACS 2019+: "Step <n_frames> <dt_ps>" summary table
        match = step_row.match(line)
        if match:
            n_frames = int(match.group(1))
            dt_ps = float(match.group(2))
            if n_frames > 0 and dt_ps > 0:
                return n_frames, dt_ps
        # Older GROMACS: "Highest frame number: N" (0-based), no dt
        match = highest_frame.search(line)
        if match:
            return int(match.group(1)) + 1, None
    # gmx check ran but no frame count was found — log the raw output.
    logger.warning(
        'gmx check ran but frame count not found in output for %s '
        '(gmx: %s). Raw output:\n%s',
        os.path.basename(filename), gmx_exe, output[:500])
    return None, None
def _split_large_trajectory(trj, n_chunks, engine_info):
    """Split a large GROMACS trajectory into sub-trajectory files.

    Uses a single ``gmx trjconv -split`` pass to write all chunks. The
    split interval is ``(n_frames // n_chunks) * dt`` ps, with the frame
    count and timestep taken from ``gmx check``. The number of chunks is
    capped so that every chunk has at least ``_MIN_CHUNK_FRAMES`` frames.

    The caller must clean up the returned *tmpdir* when done. On failure
    the temporary directory is removed here and (None, None) is returned
    so the caller can fall back to processing the file as a single task.

    Parameters
    ----------
    trj : str
        Path to the source .trr or .xtc file.
    n_chunks : int
        Desired number of output chunks.
    engine_info : dict
        May contain a ``'gmx'`` key with the GROMACS executable path
        (defaults to ``'gmx'``).

    Returns
    -------
    tmpdir : str or None
        Temporary directory holding the chunk files, or None on failure.
    chunks : list of str or None
        Paths to the chunk files sorted by chunk number, or None on
        failure.
    """
    ext = os.path.splitext(trj)[1].lower()
    if ext not in ('.trr', '.xtc'):
        return None, None
    gmx_exe = engine_info.get('gmx', 'gmx')
    n_frames, dt = _count_gromacs_frames(trj, gmx_exe)
    if not n_frames or n_frames < 2:
        logger.warning(
            ' Cannot split %s: frame count unavailable or too few frames '
            '(got %s). Falling back to single task.',
            os.path.basename(trj), n_frames)
        return None, None
    if dt is None or dt <= 0:
        logger.warning(
            ' Cannot split %s: timestep not available from gmx check '
            '(got dt=%s). Falling back to single task.',
            os.path.basename(trj), dt)
        return None, None
    # Cap chunks at n_workers; ensure each chunk has at least
    # _MIN_CHUNK_FRAMES frames.
    n_chunks = min(n_chunks, n_frames // _MIN_CHUNK_FRAMES)
    if n_chunks < 2:
        logger.warning(
            ' Cannot split %s: only %d frames, need at least %d for 2 '
            'chunks of %d frames. Falling back to single task.',
            os.path.basename(trj), n_frames,
            2 * _MIN_CHUNK_FRAMES, _MIN_CHUNK_FRAMES)
        return None, None
    chunk_frames = n_frames // n_chunks
    split_time = chunk_frames * dt
    tmpdir = tempfile.mkdtemp(prefix='pyvisa_split_')

    def _fail(message, *args):
        """Log *message*, remove the temporary directory, signal failure."""
        logger.warning(message, *args)
        shutil.rmtree(tmpdir, ignore_errors=True)
        return None, None

    # Write an index file selecting all frames (1-based for GROMACS).
    ndx_file = os.path.join(tmpdir, 'all_frames.ndx')
    output_base = os.path.join(tmpdir, f'seg{ext}')
    try:
        with open(ndx_file, 'w', encoding='utf-8') as ndxf:
            ndxf.write('[ frames ]\n')
            ndxf.writelines(f'{idx}\n' for idx in range(1, n_frames + 1))
        result = subprocess.run(  # nosec B603
            [gmx_exe, 'trjconv', '-f', trj, '-o', output_base,
             '-fr', ndx_file, '-split', str(split_time)],
            input=b'0\n',
            stdout=subprocess.DEVNULL, stderr=subprocess.PIPE,
            check=False,
        )
    except (OSError, subprocess.SubprocessError) as exc:
        return _fail(
            ' gmx trjconv failed for %s (gmx: %s): %s. '
            'Falling back to single task.',
            os.path.basename(trj), gmx_exe, exc)
    if result.returncode != 0:
        # A failed run can leave partial chunks behind; using them would
        # silently drop frames from the recomputed data.
        stderr_tail = result.stderr.decode('utf-8', errors='replace')[-500:]
        return _fail(
            ' gmx trjconv -split exited with code %d for %s (gmx: %s). '
            'Falling back to single task. Output:\n%s',
            result.returncode, os.path.basename(trj), gmx_exe, stderr_tail)

    # Collect and sort output chunks by the trailing number GROMACS adds.
    def _chunk_sort_key(path):
        nums = re.findall(r'(\d+)', os.path.basename(path).replace(ext, ''))
        return int(nums[-1]) if nums else 0

    chunks = sorted(
        (os.path.join(tmpdir, f) for f in os.listdir(tmpdir)
         if f.startswith('seg') and f.endswith(ext)),
        key=_chunk_sort_key,
    )
    if not chunks:
        return _fail(
            ' gmx trjconv -split produced no output for %s '
            '(gmx: %s, split_time: %.4f ps). Falling back to single task.',
            os.path.basename(trj), gmx_exe, split_time)
    return tmpdir, chunks
def _recalculate_trajectory(task):
    """Recompute the order parameters of one trajectory (worker task).

    Run inside a :class:`concurrent.futures.ProcessPoolExecutor` worker.
    One task is one trajectory file (or one chunk of a split trajectory),
    so the main process can advance the per-ensemble progress bar after
    every completed task.

    Parameters
    ----------
    task : dict
        Required keys: ``ens_name``, ``cycle_key``, ``trj_idx``, ``trj``
        and ``input_settings``. The optional ``chunk_idx`` (default 0)
        is the chunk position when the trajectory was split.

    Returns
    -------
    tuple
        ``(ens_name, cycle_key, trj_idx, chunk_idx, results, ok,
        warnings)``. ``trj_idx`` and ``chunk_idx`` let the caller
        reassemble results in the original order; ``results`` holds the
        order-parameter values computed from ``trj`` (possibly partial
        when the file turned out to be invalid); ``ok`` is False only
        when the order parameter could not be created; ``warnings`` is a
        list of messages for the main process to log.
    """
    # Silence sub-WARNING logging in workers so output doesn't interleave
    # with the tqdm bars in the main process.
    logging.disable(logging.WARNING - 1)
    identity = (task['ens_name'], task['cycle_key'], task['trj_idx'],
                task.get('chunk_idx', 0))
    trj = task['trj']
    try:
        functions = create_orderparameter(task['input_settings'])
    except (ImportError, ValueError) as exc:
        return (*identity, [], False, [str(exc)])
    collected = []
    problems = []
    try:
        for value in recalculate_order(functions, trj, {}):
            collected.append(value)
    except (KeyError, AttributeError):
        problems.append(f'File {trj} not valid')
    return (*identity, collected, True, problems)
def _cycle_sort_key(cycle_key):
    """Return a sort key that orders cycle names numerically.

    Cycle keys are digit strings (see :py:func:`.find_data`). Unless they
    are zero-padded, plain string sorting would place ``'10'`` before
    ``'9'`` and scramble the order of the appended data. Keys that are
    not integers sort after all numeric ones.
    """
    try:
        return (0, int(cycle_key), str(cycle_key))
    except (TypeError, ValueError):
        return (1, 0, str(cycle_key))


def _sorted_cycles(cycles):
    """Return the ``(cycle_key, info)`` items of *cycles* in cycle order."""
    return sorted(cycles.items(), key=lambda item: _cycle_sort_key(item[0]))


def recalculate_all(  # pylint: disable=too-many-locals,too-many-statements
        runfolder, iofile, ensemble_names=None, data=None,
        progress=False, n_workers=1):
    """Recalculate order parameter and collective variables.

    Performs post-processing by analyzing trajectories of old simulations
    to extract data, do new calculations, and write new order.txt files.
    For every cycle a local order file is written (after calling
    ``create_backup`` on it). When the ensemble has a main ``order.txt``,
    ``create_backup`` is called on it once and the per-cycle files are
    appended to it in numeric cycle order.

    With ``n_workers > 1`` trajectories are processed in a
    :class:`concurrent.futures.ProcessPoolExecutor`, largest first;
    ``.trr``/``.xtc`` files above ``_LARGE_TRJ_BYTES`` are first split
    into chunks with GROMACS so that several workers can share them.
    Temporary chunk directories are removed even if a worker fails.

    Parameters
    ----------
    runfolder: string
        The path of the execution directory.
    iofile: string
        The input file where the settings are collected.
    ensemble_names: list, optional
        List of ensemble names in the simulation to work with.
    data: string, optional
        If given, the function will check only the single file or
        look only in the given directory.
    progress : bool, optional
        If True, display tqdm progress bars.
    n_workers : int, optional
        Number of parallel worker processes. Values > 1 process
        trajectories concurrently via :class:`ProcessPoolExecutor`.

    Returns
    -------
    out : boolean
        True if the recomputation was successful, False otherwise.

    Raises
    ------
    ValueError
        If *iofile* is None.
    """
    if iofile is None:
        raise ValueError('Input file not given')
    input_settings = settings.parse_settings_file(
        os.path.join(runfolder, iofile)
    )
    trj_dict = find_data(runfolder, ensemble_names, data=data)
    logger.progress('Re-computing the collective variables.')
    tic = timeit.default_timer()
    if not trj_dict:
        logger.warning(
            'No data to re-compute from in %s — no trajectory data found. '
            'Ensure the run folder contains numeric ensemble subdirectories '
            '(e.g. 000, 001) or pass the `data` argument pointing to a '
            'file or directory to re-process.', runfolder)
        return False
    n_workers = max(1, n_workers)
    # Build the order parameter once in the main process, before any bars,
    # so invalid settings are reported early; the serial path reuses it.
    try:
        functions = create_orderparameter(input_settings)
    except (ImportError, ValueError):
        logger.warning('Invalid Order Parameter')  # pragma: no cover
        return False  # pragma: no cover

    def _make_bars(ens_names_counts):
        """Create one tqdm bar per ensemble, sized by trajectory count."""
        return {
            ens_name: tqdm.tqdm(
                total=n_trajs,
                desc=f'Ens {ens_name}',
                unit='traj',
                position=i,
                leave=True,
                disable=not progress,
            )
            for i, (ens_name, n_trajs) in enumerate(ens_names_counts)
        }

    if n_workers > 1:
        gmx_exe = input_settings.get('engine', {}).get('gmx', 'gmx')
        engine_info = {'gmx': gmx_exe}
        logger.progress('GROMACS executable for splitting: "%s"', gmx_exe)
        # Initialise per-ensemble metadata; pending counts filled below.
        ens_meta = {}
        for ens_name, ens in trj_dict.items():
            cycles_items = [(k, v) for k, v in _sorted_cycles(ens['traj'])
                            if v.get('traj')]
            ens_meta[ens_name] = {
                'main_o': ens.get('main_o'),
                'cycle_orders': {},
                'cycles': {
                    cycle_key: {
                        'meta': cycles,
                        'pending': 0,
                        'results': {},  # {trj_idx: {chunk_idx: list}}
                        'tmpdirs': {},  # {trj_idx: tmpdir} for split trajs
                    }
                    for cycle_key, cycles in cycles_items
                },
            }
        # Collect every trajectory with its file size, then sort largest
        # first so the heaviest work is always scheduled first.
        all_trj_info = []
        for ens_name, ens in trj_dict.items():
            for cycle_key, cycles in _sorted_cycles(ens['traj']):
                if not cycles.get('traj'):
                    continue
                for trj_idx, trj in enumerate(cycles['traj']):
                    size = os.path.getsize(trj) if os.path.isfile(trj) else 0
                    all_trj_info.append(
                        (size, ens_name, cycle_key, trj_idx, trj))
        all_trj_info.sort(key=lambda x: x[0], reverse=True)
        bars = {}

        def _accumulate(fut_result):
            """Incorporate one completed future into ens_meta and bars."""
            ens_n, cyc_k, t_idx, c_idx, res, ok, warns = fut_result
            for w in warns:
                logger.warning('%s cycle %s: %s', ens_n, cyc_k, w)
            if not ok:
                logger.warning(
                    'Ensemble %s cycle %s trajectory %d failed',
                    ens_n, cyc_k, t_idx)
            bars[ens_n].update(1)
            cd = ens_meta[ens_n]['cycles'][cyc_k]
            cd['results'].setdefault(t_idx, {})[c_idx] = res
            cd['pending'] -= 1
            if cd['pending'] == 0:
                # Every trajectory/chunk of this cycle is done: write its
                # order file (trajectory order, then chunk order) and
                # drop the cycle's split temporary directories.
                cycs = cd['meta']
                here = os.path.dirname(os.path.abspath(cycs['traj'][0]))
                new_o = os.path.join(here, 'order.txt')
                results_dict = []
                for ti in sorted(cd['results']):
                    for ci2 in sorted(cd['results'][ti]):
                        results_dict.extend(cd['results'][ti][ci2])
                local_order = cycs.get('o_txt', new_o)
                create_backup(local_order)
                write_order_parameters(
                    local_order, results_dict,
                    cycs.get('header', 'Recalculated data'))
                ens_meta[ens_n]['cycle_orders'][cyc_k] = local_order
                for td in cd['tmpdirs'].values():
                    shutil.rmtree(td, ignore_errors=True)

        # Build task groups. Each large trajectory becomes its own group
        # so all workers focus on it before moving to the next one.
        # Small trajectories share a single batch at the end.
        split_groups = []  # list of task-lists, one per large trajectory
        normal_tasks = []  # all non-split trajectories
        try:
            for size, ens_name, cycle_key, trj_idx, trj in all_trj_info:
                cycle_data = ens_meta[ens_name]['cycles'][cycle_key]
                ext = os.path.splitext(trj)[1].lower()
                chunk_paths = None
                if size > _LARGE_TRJ_BYTES and ext in ('.trr', '.xtc'):
                    logger.progress(
                        'Large trajectory detected: %s (%.0f MB) — '
                        'splitting into chunks, please wait...',
                        os.path.basename(trj), size / 1024 / 1024)
                    tmpdir, chunk_paths = _split_large_trajectory(
                        trj, n_workers, engine_info)
                    if chunk_paths:
                        cycle_data['tmpdirs'][trj_idx] = tmpdir
                        logger.progress(
                            ' Done — %d chunks ready.', len(chunk_paths))
                if chunk_paths:
                    group = [
                        {'ens_name': ens_name, 'cycle_key': cycle_key,
                         'trj_idx': trj_idx, 'chunk_idx': ci,
                         'is_split': True,
                         'trj': cp, 'input_settings': input_settings}
                        for ci, cp in enumerate(chunk_paths)
                    ]
                    split_groups.append(group)
                    cycle_data['pending'] += len(chunk_paths)
                else:
                    normal_tasks.append({
                        'ens_name': ens_name, 'cycle_key': cycle_key,
                        'trj_idx': trj_idx, 'chunk_idx': 0,
                        'trj': trj, 'input_settings': input_settings,
                    })
                    cycle_data['pending'] += 1
            bars = _make_bars([
                (ens_name,
                 sum(c['pending'] for c in meta['cycles'].values()))
                for ens_name, meta in ens_meta.items()
            ])
            with concurrent.futures.ProcessPoolExecutor(
                    max_workers=n_workers) as executor:
                # Each large trajectory is its own barrier: all workers
                # finish it before the next one starts (largest first).
                for group in split_groups:
                    futures = [executor.submit(_recalculate_trajectory, t)
                               for t in group]
                    for future in concurrent.futures.as_completed(futures):
                        _accumulate(future.result())
                # All remaining small trajectories run together in one batch.
                if normal_tasks:
                    futures = [executor.submit(_recalculate_trajectory, t)
                               for t in normal_tasks]
                    for future in concurrent.futures.as_completed(futures):
                        _accumulate(future.result())
        finally:
            for bar in bars.values():
                bar.close()
            # Remove split directories of cycles that never completed
            # (e.g. because a worker raised); completed ones are already
            # gone and ignore_errors makes this a no-op for them.
            for meta in ens_meta.values():
                for cycle_data in meta['cycles'].values():
                    for leftover in cycle_data['tmpdirs'].values():
                        shutil.rmtree(leftover, ignore_errors=True)
        # Merge per-cycle order.txt into each ensemble's main_order in
        # numeric cycle order (ordering is lost during parallel execution).
        for meta in ens_meta.values():
            main_order = meta['main_o']
            cycle_orders = meta['cycle_orders']
            if main_order is None or not cycle_orders:
                continue
            create_backup(main_order)
            with open(main_order, 'ab') as dst:
                for cycle_key in sorted(cycle_orders, key=_cycle_sort_key):
                    with open(cycle_orders[cycle_key], 'rb') as src:
                        dst.write(src.read())
    else:
        ens_counts = [
            (ens_name,
             sum(len(v.get('traj', [])) for v in ens['traj'].values()))
            for ens_name, ens in trj_dict.items()
        ]
        bars = _make_bars(ens_counts)
        for ens_name, ens in trj_dict.items():
            main_order = ens.get('main_o')
            main_backed_up = False
            for _, cycles in _sorted_cycles(ens['traj']):
                if not cycles.get('traj'):
                    continue
                here = os.path.dirname(os.path.abspath(cycles['traj'][0]))
                new_o = os.path.join(here, 'order.txt')
                results_dict = []
                for trj in cycles['traj']:
                    try:
                        for order_p in recalculate_order(functions, trj, {}):
                            results_dict.append(order_p)  # pragma: no cover
                    except (KeyError, AttributeError):
                        logger.warning('File %s not valid', trj)
                    bars[ens_name].update(1)
                local_order = cycles.get('o_txt', new_o)
                create_backup(local_order)
                write_order_parameters(local_order, results_dict,
                                       cycles.get('header',
                                                  'Recalculated data'))
                if main_order is not None:
                    # Back up the main order file only once per ensemble,
                    # then append every cycle to it in cycle order.
                    if not main_backed_up:
                        create_backup(main_order)
                        main_backed_up = True
                    with open(local_order, 'rb') as src, \
                            open(main_order, 'ab') as dst:
                        dst.write(src.read())
        for bar in bars.values():
            bar.close()
    logger.progress('# Data successfully recomputed!')
    logger.progress('# Time spent: %.2fs', timeit.default_timer() - tic)
    return True
def find_data(runfolder, ensemble_names=None, data=None):
    """Find the trajectory data used to do post-processing.

    Returns a dict whose structure mirrors the simulation's folders
    (``<runfolder>/<ensemble>/traj/traj-acc|traj-rej/<cycle>/traj``).

    Parameters
    ----------
    runfolder: string, optional
        The path of the execution directory.
    ensemble_names: list, optional
        List of ensemble names in the simulation to work with. By
        default, every digit-named sub-directory of *runfolder*.
    data: string, optional
        If given, the function will check only the single file or
        look only in the given directory (relative paths are taken with
        respect to *runfolder*). A directory containing a ``traj``
        sub-folder is treated as one ensemble folder; any other
        directory is scanned for trajectory files directly.

    Returns
    -------
    trj_dict : dict
        Maps an ensemble name (e.g. ``'000'``, ``'001'``) to a dict with:

        * ``'main_o'``: path of the ensemble's ``order.txt`` (only for
          structured, per-ensemble data);
        * ``'traj'``: dict mapping each cycle number (a string) to a dict
          with ``'header'`` (first line of the cycle's ``order.txt``, or
          ``'# Cycle: <n>, status: <ACC|REJ>'`` when that file is absent),
          ``'o_txt'`` (path of the cycle's ``order.txt``) and ``'traj'``
          (list of trajectory files; empty when the cycle has no
          ``traj`` folder). Accepted (``traj-acc``) and rejected
          (``traj-rej``) cycles share this mapping.

        For a single file or a flat folder the result is
        ``{'000': {'traj': {'0': {'traj': [...]}}}}``. An empty dict is
        returned when nothing is found.
    """
    trj_dict = {}
    flag_map = {'traj-acc': 'ACC', 'traj-rej': 'REJ'}
    if data is not None:
        abs_data = (data if os.path.isabs(data)
                    else os.path.join(runfolder, data))
        if os.path.isfile(abs_data):
            # Single trajectory file
            trj_dict['000'] = {'traj': {'0': {'traj': [abs_data]}}}
            return trj_dict
        if not os.path.isdir(abs_data):
            logger.warning(
                'Data path %s is neither a file nor a directory.', abs_data)
            return trj_dict
        if os.path.isdir(os.path.join(abs_data, 'traj')):
            # Ensemble folder — use structured loading for this ensemble only
            ens_name = os.path.basename(os.path.normpath(abs_data))
            ensemble_names = [ens_name]
            runfolder = os.path.dirname(os.path.normpath(abs_data))
        else:
            # Flat folder of trajectory files (legacy behaviour)
            sources = _get_trjs(abs_data)
            if sources:
                trj_dict['000'] = {'traj': {'0': {'traj': sources}}}
            return trj_dict
    # Structured data
    if ensemble_names is None:
        with os.scandir(runfolder) as entries:
            ensemble_names = [i.name for i in entries
                              if i.is_dir() and i.name.isdigit()]
    root = os.path.abspath(runfolder)
    for ens_name in sorted(ensemble_names):
        trj_dict[ens_name] = {
            'main_o': os.path.join(root, ens_name, 'order.txt'),
            'traj': {},
        }
        for trj_type in ('traj-acc', 'traj-rej'):
            here = os.path.join(root, ens_name, 'traj', trj_type)
            if not os.path.isdir(here):
                continue
            # Only digit-named directories are cycles; stray files with
            # numeric names are ignored.
            with os.scandir(here) as entries:
                cycles = sorted(i.name for i in entries
                                if i.is_dir() and i.name.isdigit())
            for cycle in cycles:
                loc = os.path.join(here, cycle, 'traj')
                o_txt = os.path.join(here, cycle, 'order.txt')
                header = f'# Cycle: {cycle}, status: {flag_map[trj_type]}'
                if os.path.isfile(o_txt):
                    with open(o_txt, 'r', encoding='utf-8') as file_in:
                        header = file_in.readline().rstrip('\r\n')
                trj_dict[ens_name]['traj'][cycle] = {
                    'header': header,
                    'o_txt': o_txt,
                    'traj': _get_trjs(loc) if os.path.isdir(loc) else [],
                }
    return trj_dict
def _get_trjs(runfolder='.'):
    """Find the trajectory files.

    A file is selected when its last four characters are listed in
    ``TRJ_FORMATS`` and it is not one of the known text outputs
    (``order.txt``, ``energy.txt``, ``pathensemble.txt``, ``error.txt``).
    Files are returned sorted by name, so the result (and therefore the
    order in which trajectories are processed) does not depend on the
    arbitrary order of ``os.scandir``.

    Symlinks that resolve to a real file are followed transparently
    (cycle-0 trajectories for runs initialised with the ``load`` method
    are stored as symlinks pointing back to the load folder). Broken
    symlinks are skipped with a warning so the user knows why a cycle
    appears empty after moving a simulation tree without its load
    folder.

    Parameters
    ----------
    runfolder : string, optional
        The folder to scan.

    Returns
    -------
    full_name_trj : list
        The trajectory files contained in the folder (joined with
        *runfolder*), sorted by file name.
    """
    excluded = {'order.txt', 'energy.txt', 'pathensemble.txt', 'error.txt'}
    trj = []
    with os.scandir(runfolder) as entries:
        for entry in sorted(entries, key=lambda item: item.name):
            if entry.name[-4:] not in TRJ_FORMATS or entry.name in excluded:
                continue
            if entry.is_file():
                trj.append(entry.name)
            elif entry.is_symlink():
                logger.warning(
                    'Broken trajectory symlink: %s -> %s',
                    entry.path, os.readlink(entry.path),
                )
    return [os.path.join(runfolder, i) for i in trj]