Source code for pyretis.simulation.async_runner

"""asyncio-based task runner for the infinite-swapping scheduler.

The runner drives the worker pool that executes the per-cycle
``run_md`` calls produced by
:py:mod:`pyretis.simulation.scheduler`.
"""

import asyncio
import concurrent.futures
import functools
import logging
import multiprocessing
import threading
import time
from asyncio import Future
from collections.abc import Callable
from typing import Any, Dict, List, Optional

from pyretis.setup.common import create_orderparameters
from pyretis.engines.factory import create_engines
from pyretis.inout._formatter_inf import get_log_formatter

# The empty name selects the *root* logger, so the level set here and the
# per-worker file handler attached in ``worker_initializer`` apply to
# records propagated from every logger in the process.
logger = logging.getLogger("")
logger.setLevel(logging.DEBUG)


[docs]class _ShutdownNoiseFilter(logging.Filter): """Drop the expected ``BrokenProcessPool`` teardown noise. When the worker pool is force-released on an interrupt / SIGTERM (a tutorial smoke run killed by its timeout, an HPC walltime ``kill``, a user Ctrl-C), an in-flight ``run_in_executor`` future can finish with a :class:`concurrent.futures.BrokenExecutor` *after* its wrapper task was already cancelled. asyncio then reports it -- usually during interpreter shutdown -- as "Future exception was never retrieved" at ERROR level with a full traceback. That is expected teardown noise, not a real failure (a genuine pool failure mid-run is retrieved onto the work future via ``future.set_exception`` and never reaches this "never retrieved" path), so it is dropped. Everything else passes through untouched. """
[docs] def filter(self, record: logging.LogRecord) -> bool: """Return False only for the unretrieved-BrokenExecutor record.""" if "Future exception was never retrieved" not in record.getMessage(): return True exc = record.exc_info[1] if record.exc_info else None return not isinstance(exc, concurrent.futures.BrokenExecutor)
def _install_shutdown_noise_filter() -> None:
    """Attach a :class:`_ShutdownNoiseFilter` to the asyncio logger once.

    asyncio's default exception handler logs through the ``asyncio``
    logger, so the filter is installed on that logger; a logger-level
    filter rejects the record before it propagates to the root handlers.

    The call is idempotent: if a ``_ShutdownNoiseFilter`` is already
    attached (for instance by an earlier runner), nothing is added.
    """
    aio_logger = logging.getLogger("asyncio")
    for existing in aio_logger.filters:
        if isinstance(existing, _ShutdownNoiseFilter):
            # Already installed; do not stack a duplicate.
            return
    aio_logger.addFilter(_ShutdownNoiseFilter())
# Process-local engine pool for the current worker. Populated by # worker_initializer when the ProcessPool spins up each worker process and # read back by the run_md bridge task. It is deliberately process-local # (each worker imports this module fresh) so concurrent workers never share # engine instances -- the pool is no longer a global of the core sampling # module pyretis.core._tis_inf. _WORKER_ENGINES: Dict[str, Any] = {}
def get_worker_engines() -> Dict[str, Any]:
    """Return the engine pool created for the current worker process.

    Return:
        The module-level pool dictionary itself (not a copy), so
        mutations made by the caller are visible to later readers in
        the same process.
    """
    return _WORKER_ENGINES
def set_worker_engines(engines: Dict[str, Any]) -> None:
    """Install the engine pool for the current worker process.

    The module-level pool is replaced *in place* (cleared, then refilled)
    so that references previously handed out by
    :func:`get_worker_engines` keep pointing at the live pool.

    Args:
        engines: The new pool contents, keyed by engine name. It may be
            the live pool object itself.
    """
    # Snapshot before clearing: if ``engines`` is the live pool (e.g. the
    # value returned by get_worker_engines()), clearing first would empty
    # the argument too and every engine would be silently dropped.
    replacement = dict(engines)
    _WORKER_ENGINES.clear()
    _WORKER_ENGINES.update(replacement)
class RunnerError(Exception):
    """Error raised by the runner for start-up and submission failures."""
class aiorunner:
    """A light asynchronous runner based on asyncio.

    The runner manages an asyncio queue served by a pool of workers. Upon
    instantiation, a dedicated event loop is launched in a separate
    thread. The user can then attach a worker function to the runner and
    start multiple instances of that function in the background. As work
    is submitted to the runner, it is picked up by workers on-the-fly.
    """

    def __init__(self, config: Dict, n_workers: int = 1) -> None:
        """Init function of runner.

        Args:
            config: The simulation configuration dictionary. It is
                forwarded **unchanged** to every worker process via the
                pool initializer (:func:`worker_initializer`) and must
                therefore be picklable (the pool uses the ``spawn`` start
                method). When ``config`` contains a ``"simulation"``
                section the initializer builds that worker's engine pool
                and order parameters from it (via
                :func:`pyretis.engines.factory.create_engines` and
                :func:`create_orderparameters`) and installs the engines
                as process-local state with :func:`set_worker_engines`;
                the engines are intentionally *not* part of the per-task
                work units, so they never cross the process boundary.
            n_workers: Number of worker processes in the pool.
        """
        self._n_workers: int = n_workers
        # One spawn context serves both the shared worker-id counter and
        # the process pool.
        mp_context = multiprocessing.get_context("spawn")
        self._counter = mp_context.Value("i", 0)
        self._executor: concurrent.futures.Executor = (
            concurrent.futures.ProcessPoolExecutor(
                max_workers=n_workers,
                initializer=worker_initializer,
                initargs=(self._counter, config),
                mp_context=mp_context,
            )
        )
        self._stop_event = asyncio.Event()
        self._loop = asyncio.new_event_loop()
        _install_shutdown_noise_filter()
        # The event loop lives in its own daemon thread for the lifetime
        # of the runner.
        self._thread = threading.Thread(
            target=self._start_event_loop, daemon=True
        )
        self._thread.start()
        self._queue: asyncio.Queue[Any] = asyncio.Queue()
        self._task_f: Optional[Callable] = None
        self._tasks: Optional[List[asyncio.Task[Any]]] = None

    def start(self) -> None:
        """Launch background tasks.

        Raises:
            RunnerError: if no task function was attached, or if the
                tasks could not be launched within 5 seconds.
        """
        future = asyncio.run_coroutine_threadsafe(
            self._start_tasks(), self._loop
        )
        try:
            # Task startup should be fast
            future.result(5.0)
        except concurrent.futures.TimeoutError as err:
            # ``Future.result`` raises concurrent.futures.TimeoutError,
            # which is only an alias of the builtin TimeoutError from
            # Python 3.11 on; catching the builtin misses it on older
            # interpreters.
            raise RunnerError(
                "Launching background tasks took too long"
            ) from err

    def _start_event_loop(self) -> None:
        """Start the event loop in a separate thread."""
        asyncio.set_event_loop(self._loop)
        self._loop.run_forever()

    def set_task(self, task_f: Callable) -> None:
        """Attach the task function to the runner.

        Args:
            task_f: a callable function
        """
        self._task_f = task_f

    async def _task_wrapper(
        self,
        stop_event: asyncio.Event,
        queue: asyncio.Queue,
        executor: concurrent.futures.Executor,
        taskID: int,
    ) -> None:
        """Wrap the sync task.

        To enable running the sync task_f from a dynamic list of tasks.

        Args:
            stop_event: an asyncio event to stop the worker
            queue: an asyncio queue to get work from
            executor: an executor
            taskID: an ID for the long running task (currently unused)
        """
        while not stop_event.is_set():
            try:
                # Unpack queue element
                md_item, future = queue.get_nowait()
            except asyncio.QueueEmpty:
                # Nothing to do yet: poll again shortly so the stop event
                # is re-checked regularly.
                await asyncio.sleep(0.02)
                continue

            if self._task_f is None:
                raise RuntimeError("worker has no task function set")
            loop = asyncio.get_running_loop()
            try:
                # Run the blocking task in the executor
                md_item = await loop.run_in_executor(
                    executor, functools.partial(self._task_f, md_item)
                )
                future.set_result(md_item)
            except Exception as e:
                # Pass the exception up in the future
                future.set_exception(e)
            # Mark the task as done
            queue.task_done()

    async def _add_work_to_queue(
        self, work_unit: Dict[str, Any]
    ) -> asyncio.Future:
        """Async function adding work to queue, returns a future.

        Args:
            work_unit: a unit of work encapsulated in a dict

        Return:
            A future with the results of the work
        """
        future: asyncio.Future = asyncio.get_running_loop().create_future()
        await self._queue.put((work_unit, future))
        return future

    def submit_work(self, work_unit: Dict[str, Any]) -> Future:
        """Submit work to the runner.

        Args:
            work_unit: a unit of work encapsulated in a dict

        Return:
            A future with the results of the work

        Raises:
            RunnerError: if the background tasks have not been started,
                or if the runner's event loop did not accept the work
                within 5 seconds (e.g. the loop is no longer running).
        """
        if not self._tasks:
            raise RunnerError(
                "Unable to submit work if the tasks haven't been initiated"
            )
        # asyncio queues and futures are not thread-safe: enqueue on the
        # runner's own loop (as start/close do) rather than from a
        # throw-away loop in the calling thread.
        handoff = asyncio.run_coroutine_threadsafe(
            self._add_work_to_queue(work_unit), self._loop
        )
        try:
            future = handoff.result(5.0)
        except concurrent.futures.TimeoutError as err:
            handoff.cancel()
            raise RunnerError(
                "Submitting work to the runner took too long"
            ) from err
        # Need to wait otherwise some race condition can occur
        time.sleep(0.05)
        return future

    async def _start_tasks(self) -> None:
        """Launch the background tasks.

        Raises:
            RunnerError: if no task function has been attached.
        """
        if not self._task_f:
            raise RunnerError("Can't start task(s) without a task function.")
        self._tasks = [
            asyncio.create_task(
                self._task_wrapper(
                    self._stop_event, self._queue, self._executor, worker_id
                )
            )
            for worker_id in range(self._n_workers)
        ]

    async def wait_for_tasks_to_end(self) -> None:
        """Async function waiting for tasks to end.

        Polls until no task is left on the runner's event loop.
        """
        while len(asyncio.all_tasks(self._loop)) > 0:
            await asyncio.sleep(0.1)

    def n_workers(self) -> int:
        """Return runner number of workers."""
        return self._n_workers

    def stop(self) -> None:
        """Terminate the runner and release the worker pool."""
        # Make sure there is no more work in the queue
        # before dispatching the task stopping event
        while self._queue.qsize() > 0:
            time.sleep(0.1)

        # Stop ongoing tasks
        self._stop_event.set()

        # Wait until all tasks are done. This waiter must run on a
        # separate, temporary loop: scheduled on the runner's loop it
        # would itself be counted by ``all_tasks`` and never finish.
        asyncio.run(self.wait_for_tasks_to_end())

        # Close the event loop
        self._loop.call_soon_threadsafe(self._loop.stop)
        self._thread.join()

        # Release the worker pool deterministically. ProcessPoolExecutor
        # otherwise only reaps its workers via its interpreter-exit
        # (atexit) handler, which does not run on SIGTERM.
        self._executor.shutdown(wait=True)

    async def _cancel_pending_tasks(self) -> None:
        """Cancel the worker-wrapper tasks and await their completion."""
        tasks = [t for t in (self._tasks or []) if not t.done()]
        for task in tasks:
            task.cancel()
        if tasks:
            await asyncio.gather(*tasks, return_exceptions=True)

    def close(self) -> None:
        """Force-release runner resources without draining the queue.

        Safe to call from an error/interrupt path (unlike :meth:`stop`,
        which waits for the work queue to empty): it cancels the pending
        worker tasks, stops the event loop and shuts the pool down so it
        is never orphaned.
        """
        self._stop_event.set()
        if self._loop.is_running():
            future = asyncio.run_coroutine_threadsafe(
                self._cancel_pending_tasks(), self._loop
            )
            try:
                future.result(timeout=5)
            except Exception:  # pylint: disable=broad-exception-caught
                # Best-effort cleanup: never let shutdown hang on a stuck
                # task; the pool is released unconditionally below.
                logger.debug("shutdown: pending tasks did not finish in time")
            self._loop.call_soon_threadsafe(self._loop.stop)
            self._thread.join(timeout=5)
        self._executor.shutdown(wait=False, cancel_futures=True)
def prepare_streaming_engines(engines, config):
    """Hand every streaming-capable engine its streaming configuration.

    The coordinator hands every engine file-backed phase points (a
    snapshot ``System`` whose ``config`` points at a trajectory file).
    The external engines read that file natively; the internal
    integrators (``langevin`` / ``velocityverlet`` / ``verlet`` /
    ``randomwalk``) integrate from in-memory particles instead, so they
    need a streaming template -- their own box / particles / masses /
    force field -- built from the carried-through native ``[system]`` /
    ``[box]`` / ``[particles]`` / ``[potential]`` / ``[forcefield]``
    sections. This mirrors how :py:class:`.TurtleMDEngine` builds its own
    box / particles / potential in ``__init__``.

    An engine takes part only if it exposes a callable
    ``setup_streaming`` attribute, which is then invoked with ``config``;
    engines without one (the external engines) are skipped.

    Parameters
    ----------
    engines : dict of lists
        The per-worker engine pool, keyed by engine name.
    config : dict
        The coordinator configuration dictionary.

    """
    for engine_group in engines.values():
        for candidate in engine_group:
            hook = getattr(candidate, "setup_streaming", None)
            if not callable(hook):
                # No usable streaming hook on this engine: leave it alone.
                continue
            hook(config)
def worker_initializer(counter, config):
    """Initialize a freshly started worker process.

    When ``config`` has a ``"simulation"`` section, the worker's engines
    are created, their order parameters and streaming setup are prepared,
    and the engines are installed as this process's engine pool. The
    worker then takes the next id from the shared ``counter`` and attaches
    an INFO-level file handler writing to ``worker<id>.log`` (append
    mode) to the module logger.

    Parameters
    ----------
    counter : multiprocessing synchronized value
        Shared integer counter; read and incremented under its lock to
        give each worker a distinct id.
    config : dict
        The configuration dictionary forwarded by the pool.

    """
    # load engines if infretis simulation
    if "simulation" in config:
        engines, _ = create_engines(config)
        create_orderparameters(engines, config)
        prepare_streaming_engines(engines, config)
        # Install this worker's engine pool as process-local state, read
        # back by the run_md bridge task (no module global in _tis_inf).
        set_worker_engines(engines)

    # Read-and-increment under the lock so no two workers get the same id.
    with counter.get_lock():
        worker_id = counter.value
        counter.value += 1

    level = logging.INFO
    handler = logging.FileHandler(f"worker{worker_id}.log", mode="a")
    handler.setLevel(level)
    handler.setFormatter(get_log_formatter(level))
    logger.addHandler(handler)

    banner = "============================="
    logger.info(banner)
    logger.info("Logging file for worker %s", worker_id)
    logger.info(banner + "\n")
class future_list:
    """A managed list of futures that are polled for completion."""

    # Pause between polling sweeps. Without it, waiting for a future to
    # finish spins a CPU core at 100% and competes for the interpreter
    # with the thread that is supposed to complete the futures.
    _POLL_INTERVAL = 0.005

    def __init__(self) -> None:
        """Initialize an empty future list."""
        self._futures: List[asyncio.Future] = []

    def add(self, future: asyncio.Future) -> None:
        """Add a future to the list.

        Args:
            future: the future to track
        """
        self._futures.append(future)

    def as_completed(self) -> Optional[asyncio.Future]:
        """Get futures as they are done.

        Blocks (polling) until one of the tracked futures is done, then
        removes it from the list and returns it. When several futures are
        done, the one added earliest is returned first.

        Return:
            A finished future from the list, or None when the list is
            empty.
        """
        while self._futures:
            for fut in self._futures:
                if fut.done():
                    # Safe to mutate here: we leave the loop immediately.
                    self._futures.remove(fut)
                    return fut
            # Nothing finished in this sweep: yield the CPU before
            # looking again.
            time.sleep(self._POLL_INTERVAL)
        return None