# -*- coding: utf-8 -*-
# Copyright (c) 2019, PyRETIS Development Team.
# Distributed under the LGPLv2.1+ License. See LICENSE for more info.
"""Definition of some common methods that might be useful.

Important methods defined here
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

inspect_function (:py:func:`.inspect_function`)
    A method to obtain information about the positional and keyword
    arguments of a function.

initiate_instance (:py:func:`.initiate_instance`)
    Method to initiate a class with optional arguments.

generic_factory (:py:func:`.generic_factory`)
    Create instances of classes based on settings.

compare_objects (:py:func:`.compare_objects`)
    Method to compare two PyRETIS objects.

crossing_counter (:py:func:`.crossing_counter`)
    Function to count the crossings of a path over an interface.

crossing_finder (:py:func:`.crossing_finder`)
    Function to get the snapshots just before and just after each
    crossing of a path over an interface (e.g. candidate shooting
    points).

select_and_trim_a_segment (:py:func:`.select_and_trim_a_segment`)
    Function to select a segment of a path between two interfaces,
    keeping also the two snapshots just outside the interfaces.

trim_path_between_interfaces (:py:func:`.trim_path_between_interfaces`)
    Function to trim a path between interfaces.
"""
import logging
import inspect
import numpy as np
logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
logger.addHandler(logging.NullHandler())

# The public interface of this module, i.e. what ``from ... import *``
# will pick up:
__all__ = [
    'inspect_function',
    'initiate_instance',
    'generic_factory',
    'crossing_counter',
    'crossing_finder',
    'select_and_trim_a_segment',
    'trim_path_between_interfaces',
    'big_fat_comparer',
]
def _arg_kind(arg):
"""Determine kind for a given argument.
This method will help :py:func:`.inspect_function` to determine
the correct kind for arguments.
Parameters
----------
arg : object like :py:class:`inspect.Parameter`
The argument we will determine the type of.
Returns
-------
out : string
A string we use for determine the kind.
"""
kind = None
if arg.kind == arg.POSITIONAL_OR_KEYWORD:
if arg.default is arg.empty:
kind = 'args'
else:
kind = 'kwargs'
elif arg.kind == arg.POSITIONAL_ONLY:
kind = 'args'
elif arg.kind == arg.VAR_POSITIONAL:
kind = 'varargs'
elif arg.kind == arg.VAR_KEYWORD:
kind = 'keywords'
elif arg.kind == arg.KEYWORD_ONLY:
# We treat these as keyword arguments:
kind = 'kwargs'
return kind
def big_fat_comparer(any1, any2, hard=False):
    """Check if two objects are the same, regardless of their complexity.

    Lists and tuples are compared item by item, dictionaries key by key
    (recursively for container values) and numpy arrays by shape and
    elements. Any other objects are compared with ``!=``.

    Parameters
    ----------
    any1 : anything
        The first object in the comparison.
    any2 : anything
        The second object in the comparison.
    hard : boolean, optional
        If True, raise a ValueError describing the difference instead
        of returning False.

    Returns
    -------
    out : boolean
        True if any1 = any2, False otherwise.

    Raises
    ------
    ValueError
        If ``hard`` is True and any1 and any2 are different.
    """
    containers = (dict, list, tuple, np.ndarray)
    if type(any1) is not type(any2):
        if hard:
            raise ValueError('Fail type', any1, any2)
        return False
    if isinstance(any1, (list, tuple)):
        if len(any1) != len(any2):
            if hard:
                raise ValueError('Fail list length', any1, any2)
            return False
        for key1, key2 in zip(any1, any2):
            if not big_fat_comparer(key1, key2, hard):
                if hard:
                    raise ValueError('Fail item in list',
                                     any1, any2)  # pragma: no cover
                return False
    elif isinstance(any1, np.ndarray):
        if any1.shape != any2.shape:
            if hard:
                raise ValueError('Fail np array shape', any1, any2)
            return False
        # np.array_equal also handles zero-sized and object arrays, for
        # which element-wise iteration with np.nditer raises an error.
        if not np.array_equal(any1, any2):
            if hard:
                raise ValueError('Fail np array item', any1, any2)
            return False
    elif isinstance(any1, dict):
        for key in any1:
            if key not in any2:
                if hard:
                    raise ValueError('Fail dict', any1, any2)
                return False
            val1, val2 = any1[key], any2[key]
            # Check the types in both directions so that the outcome does
            # not depend on the order of the arguments (e.g. bool vs int
            # or numpy.float64 vs float).
            if not (isinstance(val1, type(val2)) or
                    isinstance(val2, type(val1))):
                if hard:
                    raise ValueError('Fail types', val1, val2)
                return False
            # Containers are compared recursively (with a strict type
            # check), so that ``!=`` is never applied to e.g. arrays.
            if isinstance(val1, containers) or isinstance(val2, containers):
                if not big_fat_comparer(val1, val2, hard):
                    if hard:
                        raise ValueError('Fail item',
                                         val1,
                                         val2)  # pragma: no cover
                    return False
            elif val1 != val2:
                if hard:
                    raise ValueError('Fail item', val1, val2)
                return False
        for key in any2:
            if key not in any1:
                if hard:
                    raise ValueError('Fail item', any1, any2)
                return False
    else:
        if any1 != any2:
            if hard:
                raise ValueError('Fail item', any1, any2)
            return False
    return True
def inspect_function(function):
    """Return arguments/kwargs of a given function.

    This is meant for checking that a certain function can be called:
    it collects the names of the parameters the function expects,
    grouped by kind. It may be fragile, since no further information
    is gathered about ``*args``/``**kwargs``-style parameters.

    Parameters
    ----------
    function : callable
        The function to inspect.

    Returns
    -------
    out : dict
        Parameter names grouped under the following keys:

        * `args` : positional arguments (without a default value).
        * `kwargs` : keyword arguments (with a default value, or
          keyword-only).
        * `varargs` : the ``*args``-style parameter, if any.
        * `keywords` : the ``**kwargs``-style parameter, if any.
    """
    out = {group: [] for group in ('args', 'kwargs', 'varargs', 'keywords')}
    signature = inspect.signature(function)  # pylint: disable=no-member
    for param in signature.parameters.values():
        group = _arg_kind(param)
        if group is None:  # pragma: no cover
            logger.critical('Unknown variable kind "%s" for "%s"',
                            param.kind, param.name)
            continue
        out[group].append(param.name)
    return out
def _pick_out_arg_kwargs(klass, settings):
    """Pick out arguments for a class from settings.

    The parameters of ``klass.__init__`` are obtained with
    :py:func:`.inspect_function`. Every positional parameter (except
    ``self``) must be present in ``settings``, while keyword parameters
    are only picked up when they are present.

    Parameters
    ----------
    klass : class
        The class to initiate.
    settings : dict
        Positional and keyword arguments to pass to `klass.__init__()`.

    Returns
    -------
    out[0] : list
        A list of the positional arguments, in the order expected by
        `klass.__init__()`.
    out[1] : dict
        The keyword arguments.

    Raises
    ------
    ValueError
        If a required positional argument is missing from ``settings``.
    """
    info = inspect_function(klass.__init__)
    args = []
    for arg in info['args']:
        if arg == 'self':
            continue
        try:
            args.append(settings[arg])
        except KeyError as err:
            msg = 'Required argument "{}" for "{}" not found!'.format(
                arg, klass)
            raise ValueError(msg) from err
    kwargs = {arg: settings[arg] for arg in info['kwargs']
              if arg != 'self' and arg in settings}
    return args, kwargs
def initiate_instance(klass, settings):
    """Initialise a class with optional arguments.

    The arguments expected by ``klass.__init__`` are picked out from
    ``settings`` and passed on to the class.

    Parameters
    ----------
    klass : class
        The class to initiate.
    settings : dict
        Positional and keyword arguments to pass to `klass.__init__()`.

    Returns
    -------
    out : instance of `klass`
        The initiated instance of the given class.
    """
    args, kwargs = _pick_out_arg_kwargs(klass, settings)
    # Describe how the instance is created (only used for logging):
    if args and kwargs:
        how = 'with positional and keyword arguments.'
    elif args:
        how = 'with positional arguments.'
    elif kwargs:
        how = 'with keyword arguments.'
    else:
        how = 'without arguments.'
    logger.debug('Initiated "%s" from "%s" %s',
                 klass.__name__, klass.__module__, how)
    # Unpacking empty containers is the same as passing nothing:
    return klass(*args, **kwargs)
def generic_factory(settings, object_map, name='generic'):
    """Create instances of classes based on settings.

    This method is intended as a semi-generic factory for creating
    instances of different objects based on simulation input settings.
    The value of ``settings['class']`` (compared case-insensitively)
    selects an entry in ``object_map``. The class stored under the
    ``'cls'`` key of that entry is then initiated with
    :py:func:`.initiate_instance`, using ``settings`` as arguments.

    Parameters
    ----------
    settings : dict
        The settings for the object to create. It should contain the
        key ``'class'``; other items may be used as arguments for the
        selected class.
    object_map : dict
        Definitions on how to initiate the different classes, keyed by
        lower-case class names.
    name : string, optional
        Short name for the object type. Only used for error messages.

    Returns
    -------
    out : instance of a class or None
        The created object, in case we were successful. Otherwise a
        critical message is logged and None is returned.
    """
    try:
        requested = settings['class']
    except KeyError:
        logger.critical('No class given for %s -- could not create object!',
                        name)
        return None
    klass_key = requested.lower()
    if klass_key not in object_map:
        logger.critical('Could not create unknown class "%s" for %s',
                        requested, name)
        return None
    return initiate_instance(object_map[klass_key]['cls'], settings)
def numpy_allclose(val1, val2):
    """Compare two values with allclose from numpy.

    Here, we allow for one, or both, of the values to be None. Values
    that numpy can not compare, for instance values of a type unknown to
    numpy or arrays whose shapes can not be broadcast together, are
    reported as different. Note that this means that if val1 == val2
    but they are of a type unknown to numpy, the returned value will be
    False.

    Parameters
    ----------
    val1 : np.array
        The first variable in the comparison.
    val2 : np.array
        The second variable in the comparison.

    Returns
    -------
    out : boolean
        True if the values are equal, False otherwise.
    """
    if val1 is None or val2 is None:
        # Two missing values are equal, a single missing value is not.
        return val1 is None and val2 is None
    try:
        return np.allclose(val1, val2)
    except (TypeError, ValueError):
        # ValueError: e.g. shapes that can not be broadcast together.
        return False
def compare_objects(obj1, obj2, attrs, numpy_attrs=None):
    """Compare two PyRETIS objects.

    The objects are considered equal when they have the same class, the
    same number of entries in their ``__dict__`` and equal values for
    all the requested attributes. Attributes listed in ``numpy_attrs``
    are compared with :py:func:`.numpy_allclose` (defined in this
    module), all other attributes with ``==``.

    Parameters
    ----------
    obj1 : object
        The first object for the comparison.
    obj2 : object
        The second object for the comparison.
    attrs : iterable of strings
        The attributes to check.
    numpy_attrs : iterable of strings, optional
        The subset of attributes which are numpy arrays.

    Returns
    -------
    out : boolean
        True if the objects are equal, False otherwise. A requested
        attribute missing from either object counts as a difference.
    """
    cls1, cls2 = obj1.__class__, obj2.__class__
    if not cls1 == cls2:
        logger.debug('The classes are different %s != %s', cls1, cls2)
        return False
    if not len(obj1.__dict__) == len(obj2.__dict__):
        logger.debug('Number of attributes differ.')
        return False
    array_attrs = numpy_attrs or ()
    for key in attrs:
        try:
            val1 = getattr(obj1, key)
            val2 = getattr(obj2, key)
        except AttributeError:
            logger.debug('Failed to compare attribute "%s"', key)
            return False
        if key in array_attrs:
            same = numpy_allclose(val1, val2)
        else:
            same = val1 == val2
        if not same:
            logger.debug('Attribute "%s" differ.', key)
            return False
    return True
def segments_counter(path, interface_l, interface_r):
    """Count the directional segments between interfaces.

    Count the number of segments of the path that go, along the order
    parameter, FROM interface_l TO interface_r. A segment is opened when
    two consecutive phase points satisfy ``op1 <= interface_l < op2``,
    and it is counted when a later step satisfies
    ``op1 <= interface_r < op2``. Repeated crossings of interface_l
    before reaching interface_r belong to the same segment, and a
    segment that never reaches interface_r is not counted.

    NOTE(review): :py:func:`.select_and_trim_a_segment` detects crossings
    with ``op1 < interface <= op2`` instead, so the two can disagree when
    an order parameter equals an interface exactly. Confirm which
    convention is intended.

    Parameters
    -----------
    path : object like :py:class:`.PathBase`
        This is the input path whose segments will be counted.
    interface_l : float
        This is the position of the LEFT interface.
    interface_r : float
        This is the position of the RIGHT interface.

    Returns
    -------
    n_segments : integer
        The number of segments.
    """
    points = path.phasepoints
    started = False
    n_segments = 0
    for first, second in zip(points[:-1], points[1:]):
        op1, op2 = first.order[0], second.order[0]
        if op1 <= interface_l < op2:
            started = True
        if op1 <= interface_r < op2 and started:
            started = False
            n_segments += 1
    return n_segments
def crossing_counter(path, interface):
    """Count the crossings of an interface.

    A crossing is counted for every pair of consecutive phase points
    whose order parameters satisfy ``low < interface <= high``, where
    ``low`` and ``high`` are the smaller and larger of the two values.
    Both directions are counted.

    Parameters
    -----------
    path : object like :py:class:`.PathBase`
        This is the input path whose crossings will be counted.
    interface : float
        This is the position of the interface.

    Returns
    -------
    cnt : integer
        Number of crossings of the given interface.
    """
    points = path.phasepoints
    cnt = 0
    for first, second in zip(points[:-1], points[1:]):
        op1, op2 = first.order[0], second.order[0]
        if min(op1, op2) < interface <= max(op1, op2):
            cnt += 1
    return cnt
def crossing_finder(path, interface):
    """Find the crossings of an interface.

    For every pair of consecutive phase points whose order parameters
    satisfy ``low < interface <= high`` (``low``/``high`` being the
    smaller/larger of the two values), the two phase points are
    recorded. Both directions are considered.

    Parameters
    -----------
    path : object like :py:class:`.PathBase`
        This is the input path to search for crossings.
    interface : float
        This is the position of the interface.

    Returns
    -------
    ph1, ph2 : lists of snapshots
        ph1 holds the snapshot right before each crossing and ph2 the
        snapshot right after it, in path order.
    """
    points = path.phasepoints
    ph1, ph2 = [], []
    for first, second in zip(points[:-1], points[1:]):
        op1, op2 = first.order[0], second.order[0]
        if min(op1, op2) < interface <= max(op1, op2):
            ph1.append(first)
            ph2.append(second)
    return ph1, ph2
def trim_path_between_interfaces(path, interface_l, interface_r):
    """Cut a path between the two interfaces.

    Only the phase points whose order parameter lies strictly inside
    the open range (interface_l, interface_r) are kept, in their
    original order.

    Be careful: the result can consist of several discontinuous pieces
    of the input path. Also consider whether this check should be made
    left-inclusive (as the ensemble should be left-inclusive).

    Parameters
    ----------
    path : object like :py:class:`.PathBase`
        This is the input path which will be trimmed.
    interface_l : float
        This is the position of the LEFT interface.
    interface_r : float
        This is the position of the RIGHT interface.

    Returns
    -------
    new_path : object like :py:class:`.PathBase`
        This is the output trimmed path. It copies maxlen, status,
        time_origin and rgen from the input path and is marked as
        generated by ``'ct'``.
    """
    new_path = path.empty_path()
    kept = (point for point in path.phasepoints
            if interface_l < point.order[0] < interface_r)
    for point in kept:
        new_path.append(point)
    new_path.maxlen = path.maxlen
    new_path.status = path.status
    new_path.time_origin = path.time_origin
    new_path.generated = 'ct'
    new_path.rgen = path.rgen
    return new_path
def select_and_trim_a_segment(path, interface_l, interface_r,
                              segment_to_pick=None):
    """Cut a directional segment from interface_l to interface_r.

    A segment is opened by an upward crossing of interface_l
    (``op1 < interface_l <= op2`` for two consecutive phase points) and
    closed by the next upward crossing of interface_r
    (``op1 < interface_r <= op2``). The selected segment contains the
    snapshot just before the crossing of interface_l, all the following
    snapshots, and the snapshot just after the crossing of interface_r.

    Parameters
    ----------
    path : object like :py:class:`.PathBase`
        This is the input path which will be trimmed.
    interface_l : float
        This is the position of the LEFT interface.
    interface_r : float
        This is the position of the RIGHT interface.
    segment_to_pick : integer, optional
        The index (starting from 0) of the segment to select. If None,
        a segment is picked at random among the segments counted by
        :py:func:`.segments_counter`, using ``path.rgen``.

    Returns
    -------
    segment : object like :py:class:`.PathBase`
        A path composed of the snapshots of the selected segment. It is
        empty if the requested segment does not exist. It copies maxlen,
        status, time_origin and rgen from the input path and is marked
        as generated by ``'sg'``.
    """
    key = False  # True while inside a segment (interface_l crossed).
    segment = path.empty_path()
    segment_i = -1
    if segment_to_pick is None:
        segment_number = segments_counter(path, interface_l, interface_r)
        # The segments are numbered 0, ..., segment_number - 1, while
        # random_integers(low, high) is assumed to draw from the closed
        # range [low, high] (as numpy's legacy random_integers does).
        # Drawing up to segment_number would select a non-existing (or
        # incomplete) segment.
        segment_to_pick = path.rgen.random_integers(
            0, max(segment_number - 1, 0))
    for i, phasepoint in enumerate(path.phasepoints[:-1]):
        op1 = phasepoint.order[0]
        op2 = path.phasepoints[i + 1].order[0]
        # NB: these are directional crossings.
        if op2 >= interface_l > op1 and not key:
            # Entering a new segment:
            segment_i += 1
            key = True
        if key and segment_i == segment_to_pick:
            segment.append(phasepoint)
        if op2 >= interface_r > op1:
            if key and segment_i == segment_to_pick:
                segment.append(path.phasepoints[i + 1])
            key = False
    segment.maxlen = path.maxlen
    segment.status = path.status
    segment.time_origin = path.time_origin
    segment.generated = 'sg'
    segment.rgen = path.rgen
    return segment