Source code for pyretis.pyvisa.common

# Copyright (c) 2026, PyRETIS Development Team.
# Distributed under the LGPLv2.1+ License. See LICENSE for more info.
"""
Common functions for the path density.

Functions used to compare and process data, such as matching similar
lists or attempting periodic shifts of values.

Important methods defined here
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

find_rst_file (:py:func:`.find_rst_file`)
    Search for an ``.rst`` file, starting from a chosen directory and
    moving up through its parent directories.

read_traj_txt_file (:py:func:`.read_traj_txt_file`)
    Read the sequence of files in a trajectory from a traj.txt file.

recalculate_all (:py:func:`.recalculate_all`)
    Recalculate order parameter and new collective variables by finding
    all trajectory files from a simulation.

shift_data (:py:func:`.shift_data`)
    Finds the midpoint of the range of a given list of floats, and shifts
    the values below that midpoint up by the full range of the data.

try_data_shift (:py:func:`.try_data_shift`)
    Takes in two lists of values, ``x`` and ``y``, and calculates a
    linear regression and R**2-correlation of the data set. Attempts a
    shift of each data set (see :py:func:`.shift_data`) to increase the
    correlation.

where_from_to (:py:func:`.where_from_to`)
    Check the initial and final steps of a trajectory with respect to
    the provided interfaces.

get_cv_names (:py:func:`.get_cv_names`)
    Outputs a list of the names of the descriptors in the simulation.

recalculate_all (:py:func:`.recalculate_all`)
    Recompute all the order parameters according to the PyRETIS storage
    scheme or for individual files/folders.

find_data (:py:func:`.find_data`)
    Find suitable frames/trajectories to recompute the order parameter on.

read_single_data_file (:py:func:`.read_single_data_file`)
    Parse a standalone ``txt`` or ``csv`` data table and return a
    DataFrame, the column labels and the main order-parameter label.

read_single_order_txt (:py:func:`.read_single_order_txt`)
    Parse a standalone order parameter text file and return a DataFrame
    together with a list of column names suitable for use in PyVisA.

run_user_script (:py:func:`.run_user_script`)
    Execute a user-supplied Python script and capture its stdout output
    to produce an order.txt file that PyVisA can load.

"""
import concurrent.futures
import io
import json
import logging
import os
import re
import shutil
import subprocess  # nosec B404
import sys
import tempfile
import timeit

import tqdm

import numpy as np
import pandas as pd
import scipy  # pylint: disable=import-error

from pyretis.initiation.initiate_load import write_order_parameters
from pyretis.inout import settings
from pyretis.inout.common import create_backup, TRJ_FORMATS
from pyretis.inout.formats.path import PathExtFile
from pyretis.orderparameter import expand_order_names
from pyretis.setup.common import create_orderparameter
from pyretis.tools.recalculate_order import recalculate_order

logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())

__all__ = ['find_rst_file', 'read_traj_txt_file',
           'shift_data', 'try_data_shift', 'where_from_to',
           'get_cv_names', 'recalculate_all', 'find_data',
           'read_single_data_file', 'read_single_order_txt',
           'run_user_script']


def try_data_shift(x, y, fixedx):
    """Check if shifting increases correlation.

    The r-squared value of a simple linear regression is computed for
    four variants of the data:

    * ``x`` against ``y``,
    * ``x`` against shifted ``y``,
    * shifted ``x`` against ``y``,
    * shifted ``x`` against shifted ``y``.

    Shifts are done with :py:func:`.shift_data`. A shifted variant is
    only accepted if it improves on the unshifted r-squared. The combined
    shift is preferred when it beats every other variant. Next comes the
    x-only shift, when it beats the unshifted and the y-only values. The
    y-only shift comes last. As a precaution, ``x`` is never shifted
    when ``fixedx`` is True (``x`` is then the main order parameter).

    Parameters
    ----------
    x, y : list
        Floats, data values.
    fixedx : bool
        If True, x is the main order parameter and should not be shifted.

    Returns
    -------
    x, y : list
        Floats, the updated (or unchanged) data values.

    """
    def _r_squared(x_vals, y_vals):
        # Square of the Pearson correlation from the linear regression.
        return scipy.stats.linregress(x_vals, y_vals).rvalue ** 2

    y_shifted = shift_data(y)
    x_shifted = shift_data(x)

    base = _r_squared(x, y)
    only_y = _r_squared(x, y_shifted)
    only_x = _r_squared(x_shifted, y)
    both = _r_squared(x_shifted, y_shifted)

    y_better = only_y > base
    x_better = only_x > base and only_x > only_y
    # ``all`` keeps the strict, NaN-safe semantics of chained comparisons.
    xy_better = all(both > other for other in (base, only_y, only_x))

    if not fixedx:
        if xy_better:
            return x_shifted, y_shifted
        if x_better:
            return x_shifted, y
    if y_better:
        return x, y_shifted
    return x, y
def shift_data(x):
    """Shift the values lying below the midpoint of the data range.

    Every value strictly below the midpoint of the range,
    ``min(x) + 0.5 * (max(x) - min(x))``, is shifted up by the full
    range ``max(x) - min(x)``. The other values are left unchanged. This
    wraps the lower part of the data on top of the upper part, which can
    merge periodically split clusters for visualization.

    Note that the threshold is the midpoint of the range, not the
    statistical median of the data.

    Parameters
    ----------
    x : iterable of float
        Data values. Any iterable is accepted; it is traversed once.

    Returns
    -------
    xnorm : list of float
        The values of ``x``, some of them shifted and some left
        unchanged. An empty input gives an empty list.

    """
    # Materialize once: the data is needed for min, max and the shift.
    values = list(x)
    if not values:
        return []
    xmin, xmax = min(values), max(values)
    # The max difference in the data; also the size of the shift.
    diff_x = xmax - xmin
    midpoint = xmin + 0.5 * diff_x
    return [val + diff_x if val < midpoint else val for val in values]
def read_traj_txt_file(path):
    """Read a traj.txt file.

    The rows of every block loaded by ``PathExtFile`` are scanned in
    order. Each entry of the returned dictionary is keyed by a
    consecutive integer starting at 0. A new entry starts whenever the
    file name (second column) differs from the current entry's name. A
    row whose step column is the string ``'0'`` overwrites the current
    entry.

    Parameters
    ----------
    path : string
        Path to the traj.txt file.

    Returns
    -------
    files : dict
        Maps an index to ``[file name, fourth column]``. Per the original
        documentation, the fourth column is the sign of the velocity.

    """
    files = {}
    index = 0
    with PathExtFile(path, 'r') as pfile:
        for block in pfile.load():
            for row in block['data']:
                name = row[1]
                if row[0] == '0':
                    files[index] = [name, row[3]]
                # A change of file name starts the next entry.
                if name != files[index][0]:
                    index += 1
                    files[index] = [name, row[3]]
    return files
def find_rst_file(search_dir):
    """Search for a ``.rst`` file in a directory or any of its parents.

    Starting at ``search_dir``, the entries of each directory are
    examined in sorted order, and the first name ending in ``.rst`` is
    returned. If there is none, the search continues in the parent
    directory, up to the filesystem root.

    Parameters
    ----------
    search_dir : string
        Directory where the search starts.

    Returns
    -------
    out : string
        Absolute path of the first ``.rst`` file found. If nothing is
        found, the last directory examined (the filesystem root) is
        returned. If a directory cannot be listed, that directory's path
        is returned.

    """
    directory = os.path.abspath(search_dir)
    while True:
        try:
            entries = sorted(os.listdir(directory))
        except OSError:
            return directory
        match = next(
            (entry for entry in entries if entry.endswith('.rst')), None)
        if match is not None:
            return os.path.join(directory, match)
        parent = os.path.dirname(directory)
        if parent == directory:
            # Reached the filesystem root without finding a file.
            return directory
        directory = parent
def where_from_to(trj, int_a, int_b=float('-inf')):
    r"""Detect L/R starts and L / R / \* ends.

    The first and last order parameters of the trajectory are each
    classified against the two interfaces:

    * ``'L'`` if the value is below the lower interface,
    * ``'R'`` if it is at or above the upper interface,
    * ``'*'`` otherwise.

    Note: for the 'REJ' paths, this function's results might differ from
    PyRETIS.

    Parameters
    ----------
    trj : numpy array
        The order parameters of the trajectory.
    int_a : float
        The interface that defines state A.
    int_b : float, optional
        The interface that defines state B. If not given, it is assumed
        that the 0^- ensemble is in use without the 0^- L interface.

    Returns
    -------
    start : string\*1
        Position of the first point ('L', 'R' or '\*').
    end : string\*1
        Position of the last point ('L', 'R' or '\*').

    """
    int_l = min(int_a, int_b)
    int_t = max(int_a, int_b)

    def _locate(value):
        # 'L' takes precedence, matching the original override order.
        if value < int_l:
            return 'L'
        if value >= int_t:
            return 'R'
        return '*'

    return _locate(trj[0]), _locate(trj[-1])
def get_cv_names(input_settings, num_columns=None):
    """Return labels for the order parameter and collective variables.

    Each ``[orderparameter]`` / ``[collective-variable]`` section may
    give ``name`` as a list of strings, as a single string, or not at
    all. Labels are produced with ``expand_order_names``:

    * a list/tuple ``name`` yields one label per entry;
    * any other ``name`` (a string or missing) yields one label for that
      section. The default prefix is ``'op'`` for the main order
      parameter. For collective variables it is ``'cv'`` when there is a
      single one and ``'cv<i>'`` when there are several;
    * when ``num_columns`` is given and exactly one section is
      configured, that section's ``name`` is expanded to cover all
      ``num_columns`` values.

    Parameters
    ----------
    input_settings : dict
        Dictionary with the settings from the simulations.
    num_columns : int, optional
        Total number of order-parameter columns observed (e.g. read from
        ``order.txt``).

    Returns
    -------
    names : list of str
        Flat list of column labels in section order (main order
        parameter first). Empty when no section is configured.

    Raises
    ------
    ValueError
        If ``num_columns`` is given, several sections are configured and
        the number of labels differs from ``num_columns``.

    """
    # (default prefix, configured name) for every section, in order.
    sections = []
    main_section = input_settings.get('orderparameter')
    if main_section is not None:
        sections.append(('op', main_section.get('name')))
    cv_sections = input_settings.get('collective-variable', [])
    for number, section in enumerate(cv_sections, start=1):
        prefix = 'cv' if len(cv_sections) == 1 else f'cv{number}'
        sections.append((prefix, section.get('name')))

    if not sections:
        return []

    # With only a column count to go on and a single configured section,
    # that section claims every column.
    if num_columns is not None and len(sections) == 1:
        only_prefix, only_name = sections[0]
        return expand_order_names(only_name, num_columns,
                                  default_prefix=only_prefix)

    labels = []
    for prefix, name in sections:
        count = len(name) if isinstance(name, (list, tuple)) else 1
        labels.extend(expand_order_names(name, count, prefix))

    if num_columns is not None and len(labels) != num_columns:
        raise ValueError(
            "Configured order-parameter names produce "
            f"{len(labels)} labels, but the data has {num_columns} "
            "columns. Provide a list ``name`` for any multi-valued "
            "order parameter or collective variable."
        )
    return labels
[docs]def _split_table_fields(line): """Split a text row using comma or whitespace delimiters.""" text = line.strip() if not text: return [] if ',' in text: return [part.strip() for part in text.split(',') if part.strip()] return text.split()
def _parse_numeric_row(line):
    """Parse a row as a list of floats.

    Returns ``None`` when the row has no fields or when any field cannot
    be converted with ``float``.
    """
    fields = _split_table_fields(line)
    if fields:
        try:
            return list(map(float, fields))
        except ValueError:
            pass
    return None
def _parse_header_candidate(line):
    """Extract a possible header row and whether it was commented.

    Returns ``(None, False)`` for blank or purely numeric rows. Otherwise
    the row counts as commented when it starts with characters other
    than letters, digits, ``.``, ``+`` or ``-``. The labels are the
    fields left after stripping any leading non-alphanumeric characters
    (``None`` if no field remains).
    """
    text = line.strip()
    if not text or _parse_numeric_row(text) is not None:
        return None, False
    commented = bool(re.match(r'^[^A-Za-z0-9.+-]+', text))
    labels = _split_table_fields(re.sub(r'^[^A-Za-z0-9]+', '', text))
    return (labels if labels else None), commented
def _labels_from_rst(rst_file, n_cols):
    """Infer column labels from an optional PyRETIS ``.rst`` file.

    This is a best-effort fallback. Any problem reading the settings or
    deriving labels from them yields ``None`` instead of an exception, so
    the caller can fall back to generic labels.

    Parameters
    ----------
    rst_file : str or None
        Path to the settings file; ``None``/empty or a missing file
        gives ``None``.
    n_cols : int
        Number of numeric columns in the data table.

    Returns
    -------
    labels : list of str or None
        The labels from :py:func:`get_cv_names` when they match
        ``n_cols``. If they are exactly one short (and ``n_cols > 1``),
        ``'time'`` is prepended. Otherwise ``None``.

    """
    if not rst_file or not os.path.isfile(rst_file):
        return None
    try:
        input_settings = settings.parse_settings_file(
            os.path.abspath(rst_file)
        )
    except (OSError, ValueError, KeyError):
        return None
    try:
        labels = get_cv_names(input_settings)
    except (ValueError, TypeError, AttributeError, KeyError) as exc:
        # Unusual or inconsistent settings must not abort loading the
        # data; generic labels will be used instead.
        logger.debug('Could not derive labels from %s: %s', rst_file, exc)
        return None
    if n_cols == len(labels):
        return labels
    if n_cols > 1 and n_cols == len(labels) + 1:
        return ['time'] + labels
    return None
[docs]def _select_header_labels(comment_headers, plain_headers, n_cols): """Choose the most plausible header candidate for a text table.""" for candidates in (comment_headers, plain_headers): for header in candidates: if len(header) == n_cols: return header if n_cols > 1 and len(header) == n_cols - 1: return ['time'] + header return None
def read_single_data_file(filepath, rst_file=None):
    """Parse a standalone ``txt`` or ``csv`` data file for PyVisA.

    The file may contain:

    * numeric rows separated by whitespace or commas,
    * optional commented header lines (``#``, ``##``, ``;``, ``//``, ...),
    * or an uncommented CSV header row.

    Only rows with the same width as the first numeric row are kept.
    Column titles are taken from a header when one fits. Otherwise they
    are inferred from *rst_file* when possible, and finally fall back to
    ``time``, ``op1``, ``op2``, ... (or just ``op1`` for a single column).

    Parameters
    ----------
    filepath : str
        Path to the text/CSV file.
    rst_file : str, optional
        Optional PyRETIS ``.rst`` file used to infer missing labels.

    Returns
    -------
    frames : pandas.DataFrame or None
        Parsed numeric data with every column preserved, including the
        first time/step column when present.
    plot_cols : list of str or None
        Column labels available for plotting.
    main_op_label : str or None
        Best-effort main order-parameter label for interface/range
        logic: the second column when the first one looks like a
        time/step/cycle column, the first column otherwise.

    """
    rows = []
    comment_headers = []
    plain_headers = []
    with open(filepath, encoding='utf-8') as handle:
        for raw_line in handle:
            text = raw_line.strip()
            if not text:
                continue
            values = _parse_numeric_row(text)
            if values is not None:
                rows.append(values)
                continue
            header, commented = _parse_header_candidate(text)
            if header:
                (comment_headers if commented else plain_headers).append(
                    header)

    if not rows:
        return None, None, None

    # The first numeric row defines the table width; the first row always
    # survives this filter, so ``rows`` stays non-empty.
    width = len(rows[0])
    rows = [values for values in rows if len(values) == width]

    col_names = _select_header_labels(comment_headers, plain_headers, width)
    if col_names is None:
        col_names = _labels_from_rst(rst_file, width)
    if col_names is None:
        col_names = (['op1'] if width == 1 else
                     ['time', *(f'op{i}' for i in range(1, width))])

    frames = pd.DataFrame(rows, columns=col_names, dtype=float)
    first_is_time = col_names[0].lower() in ('time', 't', 'step', 'cycle')
    main_op_label = (col_names[1] if first_is_time and len(col_names) > 1
                     else col_names[0])
    return frames, col_names, main_op_label
def read_single_order_txt(filepath, rst_file=None):
    """Load a standalone text table (backward-compatible wrapper).

    Calls :py:func:`read_single_data_file` and returns only the data
    frame and the column labels, discarding the main order-parameter
    label.
    """
    frames, labels, _main_label = read_single_data_file(
        filepath, rst_file=rst_file)
    return frames, labels
def _rows_to_floats(raw_rows):
    """Convert raw script-output rows into non-empty lists of floats.

    Scalar entries become single-column rows. Rows that are empty, or
    that contain a value ``float`` cannot convert (for instance a CSV
    header line), are skipped.
    """
    converted = []
    for raw in raw_rows:
        values = raw if isinstance(raw, (list, tuple)) else [raw]
        try:
            row = [float(value) for value in values]
        except (TypeError, ValueError):
            continue
        if row:
            converted.append(row)
    return converted


def run_user_script(script_path):  # pylint: disable=too-many-return-statements,too-many-locals
    """Execute a user script and capture order-parameter data from stdout.

    The script must print its result to *stdout* in one of two formats:

    * **JSON list** - ``[[op1_frame0, op2_frame0, ...], ...]``, where
      each inner list holds the order-parameter values for one frame.
      A flat list of numbers is treated as a single column.
    * **CSV table** - a comma- or whitespace-delimited table that
      ``pandas.read_csv`` can parse (one row per frame). Non-numeric
      rows such as a header line are skipped.

    The script runs with its own directory as working directory. The
    captured data are written to ``order.txt`` in that directory so that
    PyVisA can subsequently load the file.

    Parameters
    ----------
    script_path : str
        Path to the Python script to execute (relative paths are
        resolved against the current working directory).

    Returns
    -------
    order_txt_path : str or None
        Absolute path of the written ``order.txt`` file, or ``None`` on
        failure.
    error : str or None
        Human-readable error description, or ``None`` on success.

    """
    # Resolve before changing cwd, otherwise a relative path would be
    # looked up relative to the script directory.
    script_path = os.path.abspath(script_path)
    script_dir = os.path.dirname(script_path)
    try:
        result = subprocess.run(  # nosec B603
            [sys.executable, script_path],
            cwd=script_dir,
            capture_output=True,
            text=True,
            timeout=600,
            check=False,
        )
    except subprocess.TimeoutExpired:
        return None, 'Script timed out after 600 seconds.'
    except OSError as exc:
        return None, str(exc)
    if result.returncode != 0:
        return None, (
            f'Script exited with code {result.returncode}:\n{result.stderr}'
        )
    stdout = result.stdout.strip()
    if not stdout:
        return None, 'Script produced no output on stdout.'

    # Try JSON first, then fall back to CSV.
    try:
        data = json.loads(stdout)
    except json.JSONDecodeError:
        data = None
    raw_rows = data if isinstance(data, list) else None
    if raw_rows is None:
        try:
            df_out = pd.read_csv(
                io.StringIO(stdout),
                header=None,
                sep=r'\s*,\s*|\s+',
                engine='python',
            )
            raw_rows = df_out.values.tolist()
        except Exception:  # pylint: disable=broad-except
            return None, (
                'Could not parse script output as JSON list of lists or CSV.'
            )
    if not raw_rows:
        return None, 'Script returned empty data.'

    rows = _rows_to_floats(raw_rows)
    if not rows:
        return None, 'Script output contains no numeric rows.'

    n_cols = len(rows[0])
    op_cols = ['Orderp'] + [f'cv{j}' for j in range(1, n_cols)]
    order_txt_path = os.path.join(script_dir, 'order.txt')
    header_cols = ' '.join(f'{c:>14}' for c in op_cols)
    try:
        with open(order_txt_path, 'w', encoding='utf-8') as fh:
            fh.write('Recalculated data\n')
            fh.write(f'#{"Time":>12} {header_cols}\n')
            for i, row in enumerate(rows):
                vals = ' '.join(f'{v:14.6f}' for v in row)
                fh.write(f'{i:12d} {vals}\n')
    except OSError as exc:
        return None, f'Could not write {order_txt_path}: {exc}'

    # Verify we can read it back.
    frames, _, _ = read_single_data_file(order_txt_path)
    if frames is None:
        return None, 'order.txt was written but could not be re-parsed.'
    return order_txt_path, None
# 50 MB threshold for trajectory splitting _LARGE_TRJ_BYTES = 50 * 1024 * 1024 # Minimum frames per chunk when splitting _MIN_CHUNK_FRAMES = 30
def _count_gromacs_frames(filename, gmx_exe='gmx'):
    """Return (n_frames, dt_ps) for a GROMACS trajectory via ``gmx check``.

    Both values are parsed from the ``Step`` row of the summary table
    printed by ``gmx check``, so a single subprocess call is enough. For
    output containing a ``Highest frame ...`` line instead, the frame
    count is that number plus one and the timestep is unknown (``None``).
    Any failure is logged as a warning and gives ``(None, None)``.

    Parameters
    ----------
    filename : str
        Path to the .trr or .xtc trajectory file.
    gmx_exe : str, optional
        Path to the GROMACS executable.

    Returns
    -------
    n_frames : int or None
    dt_ps : float or None
        Frame count and timestep in ps, or (None, None) on failure.

    """
    try:
        result = subprocess.run(  # nosec B603
            [gmx_exe, 'check', '-f', filename],
            capture_output=True,
            text=True,
            errors='replace',  # never fail on undecodable gmx output
            timeout=120,
            check=False,
        )
    except FileNotFoundError:
        logger.warning(
            'gmx executable not found: "%s". '
            'Cannot split large trajectories.', gmx_exe)
        return None, None
    except (OSError, subprocess.SubprocessError) as exc:
        # Also covers TimeoutExpired and e.g. a non-executable gmx path.
        logger.warning(
            'gmx check failed for %s: %s', os.path.basename(filename), exc)
        return None, None

    output = result.stdout + result.stderr
    for line in output.splitlines():
        # GROMACS 2019+: "Step <n_frames> <dt_ps>" summary table row.
        step = re.match(r'\s*Step\s+(\d+)\s+([\d.eE+\-]+)', line)
        if step:
            try:
                n_frames = int(step.group(1))
                dt_ps = float(step.group(2))
            except ValueError:
                # The character class also matches strings such as "."
                # or "e" that are not valid floats; skip such rows.
                n_frames, dt_ps = 0, 0.0
            if n_frames > 0 and dt_ps > 0:
                return n_frames, dt_ps
        # Older GROMACS: "Highest frame number: N" (0-based), no dt.
        highest = re.search(r'[Hh]ighest frame.*?(\d+)', line)
        if highest:
            return int(highest.group(1)) + 1, None

    # gmx ran but no frame count was found - log the raw output.
    logger.warning(
        'gmx check ran but frame count not found in output for %s '
        '(gmx: %s). Raw output:\n%s',
        os.path.basename(filename), gmx_exe, output[:500])
    return None, None
def _split_large_trajectory(trj, n_chunks, engine_info):  # pylint: disable=too-many-return-statements
    """Split a large GROMACS trajectory into sub-trajectory files.

    A single ``gmx trjconv -split`` pass writes all chunks into a new
    temporary directory. The split interval is
    ``(n_frames // n_chunks) * dt`` ps, where ``n_chunks`` is first
    capped so that every chunk holds at least ``_MIN_CHUNK_FRAMES``
    frames. Leftover frames may end up in one extra, shorter chunk.

    The caller must remove the returned *tmpdir* when done. On any
    failure, including a nonzero ``trjconv`` exit status, the temporary
    directory is removed here and ``(None, None)`` is returned so the
    caller can process the whole file as a single task.

    Parameters
    ----------
    trj : str
        Path to the source .trr or .xtc file.
    n_chunks : int
        Desired number of output chunks.
    engine_info : dict
        May contain a ``'gmx'`` key with the GROMACS executable path
        (``'gmx'`` is used otherwise).

    Returns
    -------
    tmpdir : str or None
        Temporary directory holding the chunk files, or None on failure.
    chunks : list of str or None
        Chunk paths sorted by the number GROMACS appends to their names,
        or None on failure.

    """
    ext = os.path.splitext(trj)[1].lower()
    if ext not in ('.trr', '.xtc'):
        return None, None
    name = os.path.basename(trj)
    gmx_exe = engine_info.get('gmx', 'gmx')
    n_frames, dt = _count_gromacs_frames(trj, gmx_exe)
    if not n_frames or n_frames < 2:
        logger.warning(
            ' Cannot split %s: frame count unavailable or too few frames '
            '(got %s). Falling back to single task.', name, n_frames)
        return None, None
    if dt is None or dt <= 0:
        logger.warning(
            ' Cannot split %s: timestep not available from gmx check '
            '(got dt=%s). Falling back to single task.', name, dt)
        return None, None
    # Cap chunks at the requested number (n_workers); ensure each chunk
    # has at least _MIN_CHUNK_FRAMES frames.
    n_chunks = min(n_chunks, n_frames // _MIN_CHUNK_FRAMES)
    if n_chunks < 2:
        logger.warning(
            ' Cannot split %s: only %d frames, need at least %d for 2 '
            'chunks of %d frames. Falling back to single task.',
            name, n_frames, 2 * _MIN_CHUNK_FRAMES, _MIN_CHUNK_FRAMES)
        return None, None

    chunk_frames = n_frames // n_chunks
    split_time = chunk_frames * dt
    tmpdir = tempfile.mkdtemp(prefix='pyvisa_split_')

    def _discard(message, *args):
        # Log, remove the temporary directory and signal the fallback.
        logger.warning(message, *args)
        shutil.rmtree(tmpdir, ignore_errors=True)
        return None, None

    try:
        # Write an index file selecting all frames (1-based for GROMACS).
        ndx_file = os.path.join(tmpdir, 'all_frames.ndx')
        with open(ndx_file, 'w', encoding='utf-8') as ndxf:
            ndxf.write('[ frames ]\n')
            ndxf.writelines(f'{idx}\n' for idx in range(1, n_frames + 1))
        output_base = os.path.join(tmpdir, f'seg{ext}')
        result = subprocess.run(  # nosec B603
            [gmx_exe, 'trjconv', '-f', trj, '-o', output_base,
             '-fr', ndx_file, '-split', str(split_time)],
            input=b'0\n',
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            check=False,
        )
    except (OSError, subprocess.SubprocessError) as exc:
        return _discard(
            ' gmx trjconv -split failed for %s: %s. '
            'Falling back to single task.', name, exc)
    if result.returncode != 0:
        # A failed run can leave partial chunks behind; using them would
        # silently drop frames, so use the unsplit file instead.
        return _discard(
            ' gmx trjconv -split exited with code %d for %s (gmx: %s). '
            'Falling back to single task.', result.returncode, name, gmx_exe)

    def _chunk_number(path):
        # Order chunks by the trailing number GROMACS adds to the name.
        digits = re.findall(r'(\d+)', os.path.basename(path).replace(ext, ''))
        return int(digits[-1]) if digits else 0

    chunks = sorted(
        (os.path.join(tmpdir, fname) for fname in os.listdir(tmpdir)
         if fname.startswith('seg') and fname.endswith(ext)),
        key=_chunk_number,
    )
    if not chunks:
        return _discard(
            ' gmx trjconv -split produced no output for %s '
            '(gmx: %s, split_time: %.4f ps). Falling back to single task.',
            name, gmx_exe, split_time)
    return tmpdir, chunks
def _recalculate_trajectory(task):
    """Recompute the order parameters of one trajectory (or chunk).

    Used as the worker function of the ``ProcessPoolExecutor`` in
    :py:func:`recalculate_all`. One task is one trajectory file (or one
    chunk of a split trajectory), so the main process can update the
    per-ensemble progress bar after every completed task.

    Parameters
    ----------
    task : dict
        Keys: ``ens_name``, ``cycle_key``, ``trj_idx``, ``trj``,
        ``input_settings`` and optionally ``chunk_idx`` (default 0).

    Returns
    -------
    ens_name : str
    cycle_key : str
    trj_idx : int
        Position of the trajectory within its cycle.
    chunk_idx : int
        Position of the chunk within a split trajectory (0 otherwise).
        Together with ``trj_idx`` this lets the caller reassemble the
        results in their original order.
    results : list
        Order-parameter values computed from this trajectory.
    ok : bool
        False only if the order parameter could not be created.
    warnings : list of str
        Human-readable problems encountered.

    """
    # Silence sub-WARNING logging in workers so output doesn't interleave
    # with the tqdm bars in the main process.
    logging.disable(logging.WARNING - 1)
    ident = (task['ens_name'], task['cycle_key'], task['trj_idx'],
             task.get('chunk_idx', 0))
    trj = task['trj']

    try:
        functions = create_orderparameter(task['input_settings'])
    except (ImportError, ValueError) as exc:
        return (*ident, [], False, [str(exc)])

    values = []
    notes = []
    try:
        for order_p in recalculate_order(functions, trj, {}):
            values.append(order_p)
    except (KeyError, AttributeError):
        # Values computed before the failure are kept.
        notes.append(f'File {trj} not valid')
    return (*ident, values, True, notes)
def _cycle_sort_key(cycle_key):
    """Return a key ordering cycle names numerically when possible.

    Cycle folder names may not be zero-padded, and a plain string sort
    would put ``'10'`` before ``'2'``. Non-numeric keys sort after all
    numeric ones.
    """
    key = str(cycle_key)
    if key.isdecimal():
        return (0, int(key), key)
    return (1, 0, key)


def _make_progress_bars(ens_names_counts, progress):
    """Create one tqdm bar per ensemble, sized by its number of tasks."""
    return {
        ens_name: tqdm.tqdm(
            total=n_trajs,
            desc=f'Ens {ens_name}',
            unit='traj',
            position=i,
            leave=True,
            disable=not progress,
        )
        for i, (ens_name, n_trajs) in enumerate(ens_names_counts)
    }


def _recalculate_serial(trj_dict, functions, progress):
    """Recompute every cycle in the main process, one file at a time.

    Each cycle's order file is written and immediately appended to the
    ensemble's main order file, in numeric cycle order.
    """
    bars = _make_progress_bars(
        [(ens_name,
          sum(len(v.get('traj', [])) for v in ens['traj'].values()))
         for ens_name, ens in trj_dict.items()],
        progress)
    try:
        for ens_name, ens in trj_dict.items():
            main_order = ens.get('main_o')
            main_backed_up = False
            for _, cycles in sorted(ens['traj'].items(),
                                    key=lambda item: _cycle_sort_key(item[0])):
                if not cycles.get('traj'):
                    continue
                here = os.path.dirname(os.path.abspath(cycles['traj'][0]))
                new_o = os.path.join(here, 'order.txt')
                results_dict = []
                for trj in cycles['traj']:
                    try:
                        for order_p in recalculate_order(functions, trj, {}):
                            results_dict.append(order_p)  # pragma: no cover
                    except (KeyError, AttributeError):
                        logger.warning('File %s not valid', trj)
                    bars[ens_name].update(1)
                local_order = cycles.get('o_txt', new_o)
                create_backup(local_order)
                write_order_parameters(
                    local_order, results_dict,
                    cycles.get('header', 'Recalculated data'))
                if main_order is not None:
                    # Back up the main file once, then append every cycle.
                    if not main_backed_up:
                        create_backup(main_order)
                        main_backed_up = True
                    with open(local_order, 'rb') as src, \
                            open(main_order, 'ab') as dst:
                        dst.write(src.read())
    finally:
        for bar in bars.values():
            bar.close()


def _recalculate_parallel(trj_dict, input_settings, n_workers, progress):  # pylint: disable=too-many-locals
    """Recompute every cycle with a pool of worker processes.

    Large GROMACS trajectories are split into chunks first. Every chunk
    or trajectory becomes one task, and the largest trajectories are
    scheduled first. A cycle's order file is written as soon as all of
    its tasks are done. The per-cycle files are then appended to each
    ensemble's main order file in numeric cycle order. Temporary chunk
    directories and progress bars are cleaned up even on errors.
    """
    gmx_exe = input_settings.get('engine', {}).get('gmx', 'gmx')
    engine_info = {'gmx': gmx_exe}
    logger.progress('GROMACS executable for splitting: "%s"', gmx_exe)

    # Per-ensemble bookkeeping; pending task counts are filled in below.
    ens_meta = {}
    # (size, ens_name, cycle_key, trj_idx, trj) for every trajectory.
    all_trj_info = []
    for ens_name, ens in trj_dict.items():
        cycles_items = [
            (cycle_key, cycles)
            for cycle_key, cycles in sorted(
                ens['traj'].items(),
                key=lambda item: _cycle_sort_key(item[0]))
            if cycles.get('traj')
        ]
        ens_meta[ens_name] = {
            'main_o': ens.get('main_o'),
            'cycle_orders': {},
            'cycles': {
                cycle_key: {
                    'meta': cycles,
                    'pending': 0,
                    'results': {},   # {trj_idx: {chunk_idx: list}}
                    'tmpdirs': {},   # {trj_idx: tmpdir} for split trajs
                }
                for cycle_key, cycles in cycles_items
            },
        }
        for cycle_key, cycles in cycles_items:
            for trj_idx, trj in enumerate(cycles['traj']):
                size = os.path.getsize(trj) if os.path.isfile(trj) else 0
                all_trj_info.append((size, ens_name, cycle_key, trj_idx, trj))
    # Largest first so the heaviest work is always scheduled first.
    all_trj_info.sort(key=lambda info: info[0], reverse=True)

    bars = {}

    def _accumulate(fut_result):
        """Incorporate one completed task into ens_meta and the bars."""
        ens_n, cyc_k, t_idx, c_idx, res, ok, warns = fut_result
        for warn in warns:
            logger.warning('%s cycle %s: %s', ens_n, cyc_k, warn)
        if not ok:
            logger.warning(
                'Ensemble %s cycle %s trajectory %d failed',
                ens_n, cyc_k, t_idx)
        bars[ens_n].update(1)
        cycle_data = ens_meta[ens_n]['cycles'][cyc_k]
        cycle_data['results'].setdefault(t_idx, {})[c_idx] = res
        cycle_data['pending'] -= 1
        if cycle_data['pending'] != 0:
            return
        # All tasks of this cycle are done: write its order file with the
        # results in trajectory order, then chunk order.
        cycs = cycle_data['meta']
        here = os.path.dirname(os.path.abspath(cycs['traj'][0]))
        new_o = os.path.join(here, 'order.txt')
        results_dict = []
        for ti in sorted(cycle_data['results']):
            per_chunk = cycle_data['results'][ti]
            for ci in sorted(per_chunk):
                results_dict.extend(per_chunk[ci])
        local_order = cycs.get('o_txt', new_o)
        create_backup(local_order)
        write_order_parameters(
            local_order, results_dict,
            cycs.get('header', 'Recalculated data'))
        ens_meta[ens_n]['cycle_orders'][cyc_k] = local_order
        for tmpdir in cycle_data['tmpdirs'].values():
            shutil.rmtree(tmpdir, ignore_errors=True)

    try:
        # Each large trajectory becomes its own task group so all workers
        # focus on it before moving on; small ones share one final batch.
        split_groups = []
        normal_tasks = []
        for size, ens_name, cycle_key, trj_idx, trj in all_trj_info:
            cycle_data = ens_meta[ens_name]['cycles'][cycle_key]
            ext = os.path.splitext(trj)[1].lower()
            chunk_paths = None
            if size > _LARGE_TRJ_BYTES and ext in ('.trr', '.xtc'):
                logger.progress(
                    'Large trajectory detected: %s (%.0f MB) — '
                    'splitting into chunks, please wait...',
                    os.path.basename(trj), size / 1024 / 1024)
                tmpdir, chunk_paths = _split_large_trajectory(
                    trj, n_workers, engine_info)
                if chunk_paths:
                    cycle_data['tmpdirs'][trj_idx] = tmpdir
                    logger.progress(
                        ' Done — %d chunks ready.', len(chunk_paths))
            if chunk_paths:
                split_groups.append([
                    {'ens_name': ens_name, 'cycle_key': cycle_key,
                     'trj_idx': trj_idx, 'chunk_idx': ci,
                     'is_split': True, 'trj': cp,
                     'input_settings': input_settings}
                    for ci, cp in enumerate(chunk_paths)
                ])
                cycle_data['pending'] += len(chunk_paths)
            else:
                normal_tasks.append({
                    'ens_name': ens_name,
                    'cycle_key': cycle_key,
                    'trj_idx': trj_idx,
                    'chunk_idx': 0,
                    'trj': trj,
                    'input_settings': input_settings,
                })
                cycle_data['pending'] += 1

        bars = _make_progress_bars(
            [(ens_name, sum(c['pending'] for c in meta['cycles'].values()))
             for ens_name, meta in ens_meta.items()],
            progress)

        batches = split_groups + ([normal_tasks] if normal_tasks else [])
        with concurrent.futures.ProcessPoolExecutor(
                max_workers=n_workers) as executor:
            # Each batch is a barrier: it finishes before the next starts
            # (largest trajectories first, small ones last).
            for batch in batches:
                futures = [executor.submit(_recalculate_trajectory, task)
                           for task in batch]
                for future in concurrent.futures.as_completed(futures):
                    _accumulate(future.result())
    finally:
        for bar in bars.values():
            bar.close()
        # Remove any chunk directories left behind (e.g. after an error).
        for meta in ens_meta.values():
            for cycle_data in meta['cycles'].values():
                for tmpdir in cycle_data['tmpdirs'].values():
                    shutil.rmtree(tmpdir, ignore_errors=True)

    # Merge per-cycle order files into each ensemble's main order file in
    # numeric cycle order (completion order is lost in parallel runs).
    for meta in ens_meta.values():
        main_order = meta['main_o']
        cycle_orders = meta['cycle_orders']
        if main_order is None or not cycle_orders:
            continue
        create_backup(main_order)
        with open(main_order, 'ab') as dst:
            for cycle_key in sorted(cycle_orders, key=_cycle_sort_key):
                with open(cycle_orders[cycle_key], 'rb') as src:
                    dst.write(src.read())


def recalculate_all(  # pylint: disable=too-many-locals
        runfolder, iofile, ensemble_names=None, data=None, progress=False,
        n_workers=1):
    """Recalculate order parameter and collective variables.

    Performs post-processing by analyzing trajectories of old simulations
    to extract data, do new calculations, and write new order.txt files.
    Cycles are processed, and appended to each ensemble's main order
    file, in numeric cycle order.

    Parameters
    ----------
    runfolder: string
        The path of the execution directory.
    iofile: string
        The input file where the settings are collected.
    ensemble_names: list, optional
        List of ensemble names in the simulation to work with.
    data: string, optional
        If given, the function will check only the single file or look
        only in the given directory.
    progress : bool, optional
        If True, display tqdm progress bars.
    n_workers : int, optional
        Number of parallel worker processes. Values > 1 process
        trajectories (and chunks of large GROMACS trajectories)
        concurrently via :class:`ProcessPoolExecutor`.

    Returns
    -------
    out : boolean
        True if the recomputation was successful, False otherwise.

    Raises
    ------
    ValueError
        If ``iofile`` is None.

    """
    if iofile is None:
        raise ValueError('Input file not given')
    input_settings = settings.parse_settings_file(
        os.path.join(runfolder, iofile)
    )
    trj_dict = find_data(runfolder, ensemble_names, data=data)
    logger.progress('Re-computing the collective variables.')
    tic = timeit.default_timer()
    if not trj_dict:
        logger.warning(
            'No data to re-compute from in %s — no trajectory data found. '
            'Ensure the run folder contains numeric ensemble subdirectories '
            '(e.g. 000, 001) or pass the `data` argument pointing to a '
            'file or directory to re-process.', runfolder)
        return False

    n_workers = max(1, n_workers)
    # Create the order parameter once in the main process, before any
    # bars, so configuration errors are reported early.
    try:
        functions = create_orderparameter(input_settings)
    except (ImportError, ValueError):
        logger.warning('Invalid Order Parameter')  # pragma: no cover
        return False  # pragma: no cover

    if n_workers > 1:
        _recalculate_parallel(trj_dict, input_settings, n_workers, progress)
    else:
        _recalculate_serial(trj_dict, functions, progress)

    logger.progress('# Data successfully recomputed!')
    logger.progress('# Time spent: %.2fs', timeit.default_timer() - tic)
    return True
def find_data(runfolder, ensemble_names=None, data=None):
    """Find the trajectory data used to do post-processing.

    Returns a dict whose structure resembles that of the simulation.

    Parameters
    ----------
    runfolder: string, optional
        The path of the execution directory.
    ensemble_names: list, optional
        List of ensemble names in the simulation to work with. Defaults
        to every sub-directory of ``runfolder`` with an all-digit name.
    data: string, optional
        If given, only this single file, this ensemble folder (a folder
        containing ``traj``) or this flat folder of trajectory files is
        used. Relative paths are resolved against ``runfolder``. A path
        that is neither a file nor a directory gives an empty result.

    Returns
    -------
    trj_dict : dict
        Keyed by ensemble name (e.g. 000, 001). For structured data each
        value holds ``'main_o'`` (the ensemble's order.txt path) and
        ``'traj'``. ``'traj'`` maps cycle names from ``traj/traj-acc``
        and then ``traj/traj-rej`` (each group in numeric order) to a
        dict with ``'header'``, ``'o_txt'`` and ``'traj'`` (the
        trajectory files, empty if the cycle has no ``traj`` folder).
        For single files and flat folders, only ``'traj'`` is given,
        under ensemble ``'000'`` and cycle ``'0'``.

    """
    def _cycle_order(name):
        # Numeric order: cycle folders need not be zero-padded, and a
        # string sort would put '10' before '2'.
        return (0, int(name), name) if name.isdecimal() else (1, 0, name)

    trj_dict = {}
    flag_map = {'traj-acc': 'ACC', 'traj-rej': 'REJ'}
    if data is not None:
        abs_data = (data if os.path.isabs(data)
                    else os.path.join(runfolder, data))
        if os.path.isfile(abs_data):
            # Single trajectory file
            trj_dict['000'] = {'traj': {'0': {'traj': [abs_data]}}}
            return trj_dict
        if not os.path.isdir(abs_data):
            logger.warning(
                'Data path %s is neither a file nor a directory.', abs_data)
            return trj_dict
        if os.path.isdir(os.path.join(abs_data, 'traj')):
            # Ensemble folder — use structured loading for this ensemble only
            ens_name = os.path.basename(os.path.normpath(abs_data))
            ensemble_names = [ens_name]
            runfolder = os.path.dirname(os.path.normpath(abs_data))
        else:
            # Flat folder of trajectory files (legacy behaviour)
            sources = _get_trjs(abs_data)
            if sources:
                trj_dict['000'] = {'traj': {'0': {'traj': sources}}}
            return trj_dict

    # Structured data
    if ensemble_names is None:
        with os.scandir(runfolder) as entries:
            ensemble_names = [entry.name for entry in entries
                              if entry.is_dir() and entry.name.isdigit()]

    root = os.path.abspath(runfolder)
    for ens_name in sorted(ensemble_names):
        order_txt = os.path.join(root, ens_name, 'order.txt')
        trj_dict[ens_name] = {'main_o': order_txt, 'traj': {}}
        for trj_type in ('traj-acc', 'traj-rej'):
            here = os.path.join(root, ens_name, 'traj', trj_type)
            if not os.path.isdir(here):
                continue
            with os.scandir(here) as entries:
                cycles = [entry.name for entry in entries
                          if entry.is_dir() and entry.name.isdigit()]
            for cycle in sorted(cycles, key=_cycle_order):
                loc = os.path.join(here, cycle, 'traj')
                o_txt = os.path.join(here, cycle, 'order.txt')
                header = f'# Cycle: {cycle},' \
                         f' status: {flag_map[trj_type]}'
                if os.path.exists(o_txt):
                    # Reuse the existing header line of the cycle file.
                    with open(o_txt, 'r', encoding='utf-8') as file_in:
                        header = file_in.readline().replace('\n', '')
                trj_dict[ens_name]['traj'][cycle] = {
                    'header': header,
                    'o_txt': o_txt,
                    'traj': _get_trjs(loc) if os.path.isdir(loc) else [],
                }
    return trj_dict
def _get_trjs(runfolder='.'):
    """Find the trajectory files in a folder.

    Only entries whose last four characters are in ``TRJ_FORMATS`` are
    considered, and known text outputs (``order.txt``, ``energy.txt``,
    ``pathensemble.txt``, ``error.txt``) are excluded. Symlinks that
    resolve to a real file are followed transparently. Cycle-0
    trajectories of runs initialised with the ``load`` method are stored
    as such symlinks. Broken symlinks are skipped with a warning, so the
    user knows why a cycle appears empty after moving a simulation tree
    without its load folder.

    The result is sorted by file name. ``os.scandir`` returns entries in
    an arbitrary, filesystem-dependent order, and the order of the files
    determines the order of the recomputed values.

    Parameters
    ----------
    runfolder : string, optional
        The folder to search.

    Returns
    -------
    full_name_trj : list
        The trajectory files contained in the folder, sorted by name.

    """
    excluded = {'order.txt', 'energy.txt', 'pathensemble.txt', 'error.txt'}
    names = []
    with os.scandir(runfolder) as entries:
        for entry in entries:
            if entry.name[-4:] not in TRJ_FORMATS or entry.name in excluded:
                continue
            if entry.is_file():
                # True for regular files and symlinks to existing files.
                names.append(entry.name)
            elif entry.is_symlink():
                logger.warning(
                    'Broken trajectory symlink: %s -> %s',
                    entry.path,
                    os.readlink(entry.path),
                )
    return [os.path.join(runfolder, name) for name in sorted(names)]