"""Setup for the infinite-swapping scheduler.
Reads the TOML config, builds the replica-exchange (REPEX) state, and
hands the worker pool to
:py:class:`pyretis.simulation.async_runner.aiorunner`.
"""
import logging
import os
from typing import Optional, Tuple
import tomli
from pyretis.core.path_load import load_paths_from_disk
from pyretis.core._tis_inf import run_md
from pyretis.engines.factory import create_engines
from pyretis.inout._formatter_inf import get_log_formatter
from pyretis.simulation.async_runner import (
aiorunner,
future_list,
get_worker_engines,
)
from pyretis.simulation.repex import InfSwapState
from pyretis.inout.native_output import (
init_native_pathensemble_files,
native_output_requested,
)
# Module-wide "main" logger. It passes everything from DEBUG upwards on to
# its handlers (added by setup_logger()), which apply their own levels.
logger = logging.getLogger(name="main")
logger.setLevel(level=logging.DEBUG)
class TOMLConfigError(Exception):
    """Raised when there is an error in the .toml configuration."""
def setup_internal(config: dict) -> Tuple[dict, InfSwapState]:
    """Prepare the logger, the REPEX state and the first md_items dict.

    Args:
        config: The configuration dictionary.

    Returns:
        A tuple ``(md_items, state)``: a new md_items dict holding the
        state's mc_moves, interfaces and cap, and the initialised REPEX
        state with its ``engine_occ`` attribute set.
    """
    setup_logger()

    # REPEX state with ensembles, populated with the paths found on disk.
    state = InfSwapState(config, minus=True)
    state.initiate_ensembles()
    disk_paths = load_paths_from_disk(config)
    state.load_paths(disk_paths)

    # Native-config route: create the per-ensemble pathensemble.txt files
    # (headers + directories) only at a fresh start; a restarted run keeps
    # appending to the existing files.
    if native_output_requested(config) and "restarted_from" not in config["current"]:
        init_native_pathensemble_files(state)

    md_items = dict(
        mc_moves=state.mc_moves,
        interfaces=state.interfaces,
        cap=state.cap,
    )

    # Only the engine-occupation list is kept from create_engines().
    _engines, occupation = create_engines(config)
    state.engine_occ = occupation
    return md_items, state
def _run_md_task(md_items: dict) -> dict:
    """Run one MD task inside a worker process.

    The engine pool that belongs to the current worker process is fetched
    here and handed to :func:`pyretis.core._tis_inf.run_md` explicitly.
    This keeps the runner engine-agnostic, and ``run_md`` does not need to
    read a module global.

    Args:
        md_items: The task description passed in by the runner.

    Returns:
        Whatever ``run_md`` returns for this task.
    """
    worker_engines = get_worker_engines()
    return run_md(md_items, worker_engines)
def setup_runner(state: InfSwapState) -> Tuple[aiorunner, future_list]:
    """Create and start the task runner.

    Args:
        state: A REPEX state; ``state.config`` supplies the runner settings
            (the number of workers is read from ``[runner] workers``).

    Returns:
        A tuple ``(runner, futures)``: the started runner, with
        ``_run_md_task`` attached as its task, and a new managed
        ``future_list``.
    """
    cfg = state.config
    n_workers = cfg["runner"]["workers"]
    runner = aiorunner(cfg, n_workers)

    # The runner needs its task attached before its workers start.
    runner.set_task(_run_md_task)
    runner.start()

    return runner, future_list()
# Default for [simulation] load_dir. Shared by the restart check of the
# active paths and by the defaults filled in below, so both look in the
# same directory.
_DEFAULT_LOAD_DIR = "load"


def _apply_config_defaults(config: dict) -> None:
    """Set optional keywords once so they appear in restart.toml.

    This also lets other parts of the code avoid ``.get()`` lookups.
    """
    simulation = config["simulation"]
    # [simulation] defaults
    simulation.setdefault("seed", 0)
    simulation.setdefault("load_dir", _DEFAULT_LOAD_DIR)
    simulation.setdefault("zeroswap", 0.5)
    simulation.setdefault("pick_scheme", 0)
    # [simulation.tis_set] defaults. interface_cap is deliberately not set
    # here: it defaults to interfaces[-1] in wf already.
    tis_set = simulation["tis_set"]
    tis_set.setdefault("quantis", False)
    tis_set.setdefault("lambda_minus_one", False)
    tis_set.setdefault("accept_all", False)
    # [output] defaults
    output = config["output"]
    output.setdefault("keep_maxop_trajs", False)
    output.setdefault("delete_old", False)
    output.setdefault("delete_old_all", False)


def _check_output_options(config: dict) -> None:
    """Reject contradictory trajectory-deletion settings in [output]."""
    output = config["output"]
    keep_maxop_trajs = output["keep_maxop_trajs"]
    if keep_maxop_trajs and not output["delete_old"]:
        raise TOMLConfigError("keep_maxop_trajs=True requires delete_old=True")
    if keep_maxop_trajs and output["delete_old_all"]:
        msg = (
            "delete_old_all=True will delete all trajectories. Set "
            "keep_maxop_trajs to False in the [output] section"
        )
        raise TOMLConfigError(msg)


def setup_config(
    inp: str = "infswap.toml", re_inp: str = "restart.toml"
) -> Optional[dict]:
    """Set up the configuration dict from a TOML file.

    On a fresh start a ``current`` section (step 0) is created. On a
    restart (the file already has ``current``), ``restarted_from`` is
    updated and the active paths are checked on disk. Optional keywords
    are then filled in with defaults, and the result is validated with
    :func:`check_config`.

    Args:
        inp: Path of the input file (default: ``infswap.toml``).
        re_inp: Path of the restart file (default: ``restart.toml``).

    Returns:
        The configuration dictionary, or ``None`` in three cases: the
        input file is missing; ``cstep`` equals ``restarted_from``; or an
        active path is missing on disk.

    Raises:
        ValueError: If a restart file exists but ``inp`` is a different file.
        TOMLConfigError: If the settings are inconsistent.
    """
    if not os.path.isfile(inp):
        logger.info("%s file not found, exit.", inp)
        return None
    with open(inp, mode="rb") as read:
        config = tomli.load(read)

    # A restart file that is not the file being run is ambiguous. Compare
    # the files themselves rather than the path strings, so that e.g.
    # "./restart.toml" is recognised as "restart.toml".
    if os.path.isfile(re_inp) and not os.path.samefile(inp, re_inp):
        msg = f"Restart file '{re_inp}' found, but its not the run file!"
        raise ValueError(msg)

    if "current" in config:
        # Restart: the toml file already has a 'current' subdict.
        curr = config["current"]
        # Stop if cstep has not moved on from the step this run was last
        # restarted from.
        if curr.get("cstep") == curr.get("restarted_from", -1):
            logger.info(
                "cstep equals restarted_from (%s), exit.", curr.get("cstep")
            )
            return None
        curr["restarted_from"] = curr["cstep"]

        # Every active path must be present on disk. Use the same load_dir
        # default that _apply_config_defaults() sets.
        load_dir = config["simulation"].get("load_dir", _DEFAULT_LOAD_DIR)
        for act in curr["active"]:
            store_p = os.path.join(load_dir, str(act), "traj.txt")
            if not os.path.isfile(store_p):
                logger.info("Active path %s not found, exit.", store_p)
                return None
    else:
        # Fresh start: no 'current' in the toml file, start from step 0.
        size = len(config["simulation"]["interfaces"])
        config["current"] = {
            "traj_num": size,
            "cstep": 0,
            "active": list(range(size)),
            "locked": [],
            "size": size,
            "frac": {},
            "wsubcycles": [0] * config["runner"]["workers"],
            "tsubcycles": 0,
        }
        # write/overwrite infswap_data.txt. Skipped on the native-output
        # route: there the coordinator emits native per-ensemble
        # pathensemble.txt (see pyretis.inout.native_output) and the
        # inf-format data file would be a stray non-native artifact.
        if not native_output_requested(config):
            write_header(config)

    # quantis, or any other method needing different engines per ensemble,
    # lists them explicitly. Otherwise every ensemble uses "engine".
    has_ens_engs = config["simulation"].get("ensemble_engines", False)
    if not has_ens_engs:
        config["simulation"]["ensemble_engines"] = [
            ["engine"] for _ in config["simulation"]["interfaces"]
        ]

    _apply_config_defaults(config)
    _check_output_options(config)

    # quantis without explicit ensemble engines: the first ensemble
    # runs on "engine0".
    if config["simulation"]["tis_set"]["quantis"] and not has_ens_engs:
        config["simulation"]["ensemble_engines"][0] = ["engine0"]

    check_config(config)
    return config
def check_config(config: dict) -> None:
    """Validate the settings from the .toml file.

    Besides validating, this back-fills ``current.wsubcycles`` and
    ``current.tsubcycles`` for restart files from older versions. It also
    pads ``wsubcycles`` with zeros when the number of workers has grown.

    Args:
        config: The configuration dictionary; it is modified in place.

    Raises:
        TOMLConfigError: If any of the checks fails.
    """
    simulation = config["simulation"]
    intf = simulation["interfaces"]
    n_ens = len(intf)
    n_workers = config["runner"]["workers"]
    n_sh_moves = len(simulation["shooting_moves"])
    lambda_minus_one = simulation["tis_set"]["lambda_minus_one"]
    # False marks "not given"; any number (including 0.0) is a real cap.
    intf_cap = simulation["tis_set"].get("interface_cap", False)

    # Checked first: the checks below index intf[0] and intf[-1].
    if n_ens < 2:
        raise TOMLConfigError("Define at least 2 interfaces!")
    if lambda_minus_one is not False and lambda_minus_one >= intf[0]:
        raise TOMLConfigError(
            "lambda_minus_one interface must be less than the first interface!"
        )
    if n_workers > n_ens - 1:
        raise TOMLConfigError("Too many workers defined!")
    if sorted(intf) != intf:
        raise TOMLConfigError("Your interfaces are not sorted!")
    if len(set(intf)) != len(intf):
        raise TOMLConfigError("Your interfaces contain duplicate values!")
    if n_ens > n_sh_moves:
        raise TOMLConfigError(
            f"N_interfaces {n_ens} > N_shooting_moves {n_sh_moves}!"
        )
    if intf_cap is not False:
        if intf_cap > intf[-1]:
            raise TOMLConfigError(
                f"Interface_cap {intf_cap} > interface[-1]={intf[-1]}"
            )
        if intf_cap < intf[0]:
            raise TOMLConfigError(
                f"Interface_cap {intf_cap} < interface[0]={intf[0]}"
            )

    # engine checks: every engine referenced by an ensemble must be defined.
    # dict.fromkeys keeps first-seen order while dropping duplicates.
    unique_engines = list(
        dict.fromkeys(
            engine
            for engines in simulation["ensemble_engines"]
            for engine in engines
        )
    )
    for key in unique_engines:
        if key not in config:
            raise TOMLConfigError(f"Engine '{key}' not defined!")

    # gromacs check: GROMACS engines with differing settings must not share
    # an input_path, or one would overwrite the other's 'pyretis.mdp'.
    for key1 in unique_engines:
        if config[key1].get("class") != "gromacs":
            continue
        eng1 = config[key1].copy()
        inp_path1 = eng1.pop("input_path")
        for key2 in unique_engines:
            eng2 = config[key2].copy()
            # Engines of other classes may have no input_path at all.
            inp_path2 = eng2.pop("input_path", None)
            if eng1 != eng2 and inp_path1 == inp_path2:
                raise TOMLConfigError(
                    "Found differing engine settings with identic"
                    "al 'input_path'. This would overwrite the"
                    " settings of one of the engines in"
                    " 'pyretis.mdp'!"
                )

    # Restart files from older versions may lack the subcycle counters.
    current = config["current"]
    current.setdefault("wsubcycles", [0] * n_workers)
    current.setdefault("tsubcycles", 0)
    # If the number of workers increased, pad with zeros for the new ones.
    missing = n_workers - len(current["wsubcycles"])
    if missing > 0:
        current["wsubcycles"] += [0] * missing
def setup_logger(inp: str = "sim.log", level: str = "info") -> None:
    """Attach a console handler and a file handler to the main logger.

    Warnings and above go to the console (``sys.stderr``). Messages at
    ``level`` and above are appended to the file ``inp``.

    Args:
        inp: Path of the main log file (opened in append mode).
        level: Name of the logging level for the file handler, e.g.
            ``"info"`` or ``"debug"`` (case-insensitive). Names that are
            not a logging level fall back to INFO.
    """
    # Define a console logger. This will log to sys.stderr:
    console = logging.StreamHandler()
    console.setLevel(logging.WARNING)
    console.setFormatter(get_log_formatter(logging.WARNING))
    logger.addHandler(console)

    # Resolve the level name; reject attributes of the logging module
    # that are not numeric levels (e.g. "basicConfig").
    file_level = getattr(logging, level.upper(), None)
    if not isinstance(file_level, int):
        file_level = logging.INFO

    fileh = logging.FileHandler(inp, mode="a")
    fileh.setLevel(file_level)
    fileh.setFormatter(get_log_formatter(file_level))
    logger.addHandler(fileh)