class Computation:
    """
    A base class for computations, intended to be subclassed.

    :param root_parameters: a list of :py:class:`~reikna.core.Parameter` objects.

    .. py:attribute:: signature

        A :py:class:`~reikna.core.Signature` object representing current computation signature
        (taking into account connected transformations).

    .. py:attribute:: parameter

        A named tuple of :py:class:`~reikna.core.computation.ComputationParameter` objects
        corresponding to parameters from the current :py:attr:`signature`.
    """

    def __init__(self, root_parameters):
        # Materialise once: the parameters are iterated here for validation and then
        # handed to the tree, so a one-shot iterable would otherwise arrive exhausted.
        root_parameters = list(root_parameters)
        for param in root_parameters:
            check_external_parameter_name(param.name)
        self._tr_tree = TransformationTree(root_parameters)
        self._update_attributes()

    def _update_attributes(self):
        """
        Updates ``signature`` and ``parameter`` attributes.
        Called by the methods that change the signature.
        """
        leaf_params = self._tr_tree.get_leaf_parameters()
        self.signature = Signature(leaf_params)
        self.parameter = make_parameter_container(self, leaf_params)

    # The names are underscored to avoid name conflicts with ``tr_from_comp`` keys
    # (where the user can introduce new parameter names)
    def connect(self, _comp_connector, _trf, _tr_connector, **tr_from_comp):
        """
        Connect a transformation to the computation.

        :param _comp_connector: connection target ---
            a :py:class:`~reikna.core.computation.ComputationParameter` object
            belonging to this computation object, or a string with its name.
        :param _trf: a :py:class:`~reikna.core.Transformation` object.
        :param _tr_connector: connector on the side of the transformation ---
            a :py:class:`~reikna.core.transformation.TransformationParameter` object
            belonging to ``_trf``, or a string with its name.
        :param tr_from_comp: a dictionary with the names of new or old
            computation parameters as keys, and
            :py:class:`~reikna.core.transformation.TransformationParameter` objects
            (or their names) as values.
        :returns: this computation object (modified).
        :raises ValueError: if ``_comp_connector`` does not belong to this computation,
            if a transformation parameter object does not belong to ``_trf``,
            if the main connector name is also passed as a keyword,
            or if the same transformation parameter is used for more than one connection.

        .. note::

            The resulting parameter order is determined by traversing
            the graph of connections depth-first (starting from the initial computation parameters),
            with the additional condition: the nodes do not change their order
            in the same branching level (i.e. in the list of computation or
            transformation parameters, both of which are ordered).

            For example, consider a computation with parameters ``(a, b, c, d)``.
            If you connect a transformation ``(a', c) -> a``, the resulting computation
            will have the signature ``(a', b, c, d)`` (as opposed to ``(a', c, b, d)``
            it would have for the pure depth-first traversal).
        """
        # Extract connector name
        if isinstance(_comp_connector, ComputationParameter):
            if not _comp_connector.belongs_to(self):
                raise ValueError("The connection target must belong to this computation.")
        param_name = str(_comp_connector)

        # Extract transformation parameters names
        if param_name in tr_from_comp:
            raise ValueError(
                "Parameter '" + param_name + "' cannot be supplied " +
                "both as the main connector and one of the child connections")
        # ``tr_from_comp`` is the fresh **kwargs dict of this call, so adding to it is local.
        tr_from_comp[param_name] = _tr_connector

        comp_from_tr = {}
        for comp_connection_name, tr_connection in tr_from_comp.items():
            check_external_parameter_name(comp_connection_name)

            if isinstance(tr_connection, TransformationParameter):
                if not tr_connection.belongs_to(_trf):
                    raise ValueError(
                        "The transformation parameter must belong to the provided transformation")
            tr_connection_name = str(tr_connection)

            # ``comp_from_tr`` is keyed by the transformation parameter name, so reusing one
            # would silently overwrite (and drop) an earlier connection.
            if tr_connection_name in comp_from_tr:
                raise ValueError(
                    "Transformation parameter '" + tr_connection_name + "' cannot be connected "
                    + "to both '" + comp_from_tr[tr_connection_name] + "' and '"
                    + comp_connection_name + "'")
            comp_from_tr[tr_connection_name] = comp_connection_name

        self._tr_tree.connect(param_name, _trf, comp_from_tr)
        self._update_attributes()
        return self

    def _translate_tree(self, translator):
        """
        Return the result of calling ``translate(translator)`` on this computation's
        transformation tree.
        """
        return self._tr_tree.translate(translator)

    def _get_plan(self, tr_tree, translator, thread, fast_math):
        """
        Build a plan for ``tr_tree`` by calling :py:meth:`_build_plan` with a plan factory,
        the device parameters of ``thread`` and one
        :py:class:`~reikna.core.computation.KernelArgument` per root parameter of ``tr_tree``.
        """
        def plan_factory():
            # Each call creates a new, independent plan bound to the same tree/thread.
            return ComputationPlan(tr_tree, translator, thread, fast_math)

        args = [
            KernelArgument(param.name, param.annotation.type)
            for param in tr_tree.get_root_parameters()]
        return self._build_plan(plan_factory, thread.device_params, *args)

    def compile(self, thread, fast_math=False):
        """
        Compiles the computation with the given :py:class:`~reikna.cluda.api.Thread` object
        and returns a :py:class:`~reikna.core.computation.ComputationCallable` object.
        If ``fast_math`` is enabled, the compilation of all kernels is performed using
        the compiler options for fast and imprecise mathematical functions.
        """
        translator = Translator.identity()
        return self._get_plan(self._tr_tree, translator, thread, fast_math).finalize()

    def _build_plan(self, plan_factory, device_params, *args):
        """
        Derived classes override this method.
        It is called by :py:meth:`compile` and
        supposed to return a :py:class:`~reikna.core.computation.ComputationPlan` object.

        :param plan_factory: a callable returning a new
            :py:class:`~reikna.core.computation.ComputationPlan` object.
        :param device_params: a :py:class:`~reikna.cluda.api.DeviceParameters` object corresponding
            to the thread the computation is being compiled for.
        :param args: :py:class:`~reikna.core.computation.KernelArgument` objects,
            corresponding to ``root_parameters`` specified during the creation
            of this computation object.
        """
        raise NotImplementedError
# NOTE(review): this file appears to define ``Computation`` more than once; the later
# definition shadows the earlier one at import time — confirm the duplication is intended.
class Computation:
    """
    A base class for computations, intended to be subclassed.

    :param root_parameters: a list of :py:class:`~reikna.core.Parameter` objects.

    .. py:attribute:: signature

        A :py:class:`~reikna.core.Signature` object representing current computation signature
        (taking into account connected transformations).

    .. py:attribute:: parameter

        A named tuple of :py:class:`~reikna.core.computation.ComputationParameter` objects
        corresponding to parameters from the current :py:attr:`signature`.
    """

    def __init__(self, root_parameters):
        # Materialise once: the parameters are iterated here for validation and then
        # handed to the tree, so a one-shot iterable would otherwise arrive exhausted.
        root_parameters = list(root_parameters)
        for param in root_parameters:
            check_external_parameter_name(param.name)
        self._tr_tree = TransformationTree(root_parameters)
        self._update_attributes()

    def _update_attributes(self):
        """
        Updates ``signature`` and ``parameter`` attributes.
        Called by the methods that change the signature.
        """
        leaf_params = self._tr_tree.get_leaf_parameters()
        self.signature = Signature(leaf_params)
        self.parameter = make_parameter_container(self, leaf_params)

    # The names are underscored to avoid name conflicts with ``tr_from_comp`` keys
    # (where the user can introduce new parameter names)
    def connect(self, _comp_connector, _trf, _tr_connector, **tr_from_comp):
        """
        Connect a transformation to the computation.

        :param _comp_connector: connection target ---
            a :py:class:`~reikna.core.computation.ComputationParameter` object
            belonging to this computation object, or a string with its name.
        :param _trf: a :py:class:`~reikna.core.Transformation` object.
        :param _tr_connector: connector on the side of the transformation ---
            a :py:class:`~reikna.core.transformation.TransformationParameter` object
            belonging to ``_trf``, or a string with its name.
        :param tr_from_comp: a dictionary with the names of new or old
            computation parameters as keys, and
            :py:class:`~reikna.core.transformation.TransformationParameter` objects
            (or their names) as values.
        :returns: this computation object (modified).
        :raises ValueError: if ``_comp_connector`` does not belong to this computation,
            if a transformation parameter object does not belong to ``_trf``,
            if the main connector name is also passed as a keyword,
            or if the same transformation parameter is used for more than one connection.

        .. note::

            The resulting parameter order is determined by traversing
            the graph of connections depth-first (starting from the initial computation parameters),
            with the additional condition: the nodes do not change their order
            in the same branching level (i.e. in the list of computation or
            transformation parameters, both of which are ordered).

            For example, consider a computation with parameters ``(a, b, c, d)``.
            If you connect a transformation ``(a', c) -> a``, the resulting computation
            will have the signature ``(a', b, c, d)`` (as opposed to ``(a', c, b, d)``
            it would have for the pure depth-first traversal).
        """
        # Extract connector name
        if isinstance(_comp_connector, ComputationParameter):
            if not _comp_connector.belongs_to(self):
                raise ValueError(
                    "The connection target must belong to this computation.")
        param_name = str(_comp_connector)

        # Extract transformation parameters names
        if param_name in tr_from_comp:
            raise ValueError(
                "Parameter '" + param_name + "' cannot be supplied " +
                "both as the main connector and one of the child connections")
        # ``tr_from_comp`` is the fresh **kwargs dict of this call, so adding to it is local.
        tr_from_comp[param_name] = _tr_connector

        comp_from_tr = {}
        for comp_connection_name, tr_connection in tr_from_comp.items():
            check_external_parameter_name(comp_connection_name)

            if isinstance(tr_connection, TransformationParameter):
                if not tr_connection.belongs_to(_trf):
                    raise ValueError(
                        "The transformation parameter must belong to the provided transformation"
                    )
            tr_connection_name = str(tr_connection)

            # ``comp_from_tr`` is keyed by the transformation parameter name, so reusing one
            # would silently overwrite (and drop) an earlier connection.
            if tr_connection_name in comp_from_tr:
                raise ValueError(
                    "Transformation parameter '" + tr_connection_name + "' cannot be connected "
                    + "to both '" + comp_from_tr[tr_connection_name] + "' and '"
                    + comp_connection_name + "'")
            comp_from_tr[tr_connection_name] = comp_connection_name

        self._tr_tree.connect(param_name, _trf, comp_from_tr)
        self._update_attributes()
        return self

    def _translate_tree(self, translator):
        """
        Return the result of calling ``translate(translator)`` on this computation's
        transformation tree.
        """
        return self._tr_tree.translate(translator)

    def _get_plan(self, tr_tree, translator, thread, fast_math):
        """
        Build a plan for ``tr_tree`` by calling :py:meth:`_build_plan` with a plan factory,
        the device parameters of ``thread`` and one
        :py:class:`~reikna.core.computation.KernelArgument` per root parameter of ``tr_tree``.
        """
        def plan_factory():
            # Each call creates a new, independent plan bound to the same tree/thread.
            return ComputationPlan(tr_tree, translator, thread, fast_math)

        args = [
            KernelArgument(param.name, param.annotation.type)
            for param in tr_tree.get_root_parameters()
        ]
        return self._build_plan(plan_factory, thread.device_params, *args)

    def compile(self, thread, fast_math=False):
        """
        Compiles the computation with the given :py:class:`~reikna.cluda.api.Thread` object
        and returns a :py:class:`~reikna.core.computation.ComputationCallable` object.
        If ``fast_math`` is enabled, the compilation of all kernels is performed using
        the compiler options for fast and imprecise mathematical functions.
        """
        translator = Translator.identity()
        return self._get_plan(self._tr_tree, translator, thread, fast_math).finalize()

    def _build_plan(self, plan_factory, device_params, *args):
        """
        Derived classes override this method.
        It is called by :py:meth:`compile` and
        supposed to return a :py:class:`~reikna.core.computation.ComputationPlan` object.

        :param plan_factory: a callable returning a new
            :py:class:`~reikna.core.computation.ComputationPlan` object.
        :param device_params: a :py:class:`~reikna.cluda.api.DeviceParameters` object corresponding
            to the thread the computation is being compiled for.
        :param args: :py:class:`~reikna.core.computation.KernelArgument` objects,
            corresponding to ``root_parameters`` specified during the creation
            of this computation object.
        """
        raise NotImplementedError
def __init__(self, root_parameters):
    """
    Check every root parameter name with ``check_external_parameter_name``,
    store a ``TransformationTree`` built from ``root_parameters`` as ``self._tr_tree``,
    then call ``self._update_attributes()``.
    """
    # NOTE(review): this is a module-level function named ``__init__`` that mirrors
    # ``Computation.__init__``; it looks like a stray copy — confirm it is still needed.
    for external_name in (root_param.name for root_param in root_parameters):
        check_external_parameter_name(external_name)
    self._tr_tree = TransformationTree(root_parameters)
    self._update_attributes()