Exemplo n.º 1
0
class Computation:
    """
    A base class for computations, intended to be subclassed.

    :param root_parameters: a list of :py:class:`~reikna.core.Parameter` objects.

    .. py:attribute:: signature

        A :py:class:`~reikna.core.Signature` object representing current computation signature
        (taking into account connected transformations).

    .. py:attribute:: parameter

        A named tuple of :py:class:`~reikna.core.computation.ComputationParameter` objects
        corresponding to parameters from the current :py:attr:`signature`.
    """

    def __init__(self, root_parameters):
        # Each root parameter name goes through ``check_external_parameter_name``
        # (presumably raising on an invalid name) before the tree is built.
        for param in root_parameters:
            check_external_parameter_name(param.name)
        self._tr_tree = TransformationTree(root_parameters)
        self._update_attributes()

    def _update_attributes(self):
        """
        Updates ``signature`` and ``parameter`` attributes.
        Called by the methods that change the signature.
        """
        leaf_params = self._tr_tree.get_leaf_parameters()
        self.signature = Signature(leaf_params)
        self.parameter = make_parameter_container(self, leaf_params)

    # The names are underscored to avoid name conflicts with ``tr_from_comp`` keys
    # (where the user can introduce new parameter names)
    def connect(self, _comp_connector, _trf, _tr_connector, **tr_from_comp):
        """
        Connect a transformation to the computation.

        :param _comp_connector: connection target ---
            a :py:class:`~reikna.core.computation.ComputationParameter` object
            belonging to this computation object, or a string with its name.
        :param _trf: a :py:class:`~reikna.core.Transformation` object.
        :param _tr_connector: connector on the side of the transformation ---
            a :py:class:`~reikna.core.transformation.TransformationParameter` object
            belonging to ``_trf``, or a string with its name.
        :param tr_from_comp: a dictionary with the names of new or old
            computation parameters as keys, and
            :py:class:`~reikna.core.transformation.TransformationParameter` objects
            (or their names) as values.
        :returns: this computation object (modified).
        :raises ValueError: if ``_comp_connector`` is a
            :py:class:`~reikna.core.computation.ComputationParameter` of another computation;
            if the connector name is also given as a key of ``tr_from_comp``;
            if a :py:class:`~reikna.core.transformation.TransformationParameter`
            does not belong to ``_trf``;
            or if the same transformation parameter is connected to more than one
            computation parameter.

        .. note::

            The resulting parameter order is determined by traversing
            the graph of connections depth-first (starting from the initial computation parameters),
            with the additional condition: the nodes do not change their order
            in the same branching level (i.e. in the list of computation or
            transformation parameters, both of which are ordered).

            For example, consider a computation with parameters ``(a, b, c, d)``.
            If you connect a transformation ``(a', c) -> a``, the resulting computation
            will have the signature ``(a', b, c, d)`` (as opposed to ``(a', c, b, d)``
            it would have for the pure depth-first traversal).
        """

        # Extract connector name
        if isinstance(_comp_connector, ComputationParameter):
            if not _comp_connector.belongs_to(self):
                raise ValueError("The connection target must belong to this computation.")
        param_name = str(_comp_connector)

        # Extract transformation parameters names

        if param_name in tr_from_comp:
            raise ValueError(
                "Parameter '" + param_name + "' cannot be supplied " +
                "both as the main connector and one of the child connections")

        # ``tr_from_comp`` is this call's own ``**kwargs`` dict,
        # so adding the main connection to it is not visible to the caller.
        tr_from_comp[param_name] = _tr_connector
        comp_from_tr = {}
        for comp_connection_name, tr_connection in tr_from_comp.items():
            check_external_parameter_name(comp_connection_name)
            if isinstance(tr_connection, TransformationParameter):
                if not tr_connection.belongs_to(_trf):
                    raise ValueError(
                        "The transformation parameter must belong to the provided transformation")
            tr_connection_name = str(tr_connection)
            # ``comp_from_tr`` is keyed by the transformation parameter name;
            # a repeated name would silently overwrite (and drop) an earlier connection.
            if tr_connection_name in comp_from_tr:
                raise ValueError(
                    "Transformation parameter '" + tr_connection_name + "' cannot be " +
                    "connected to more than one computation parameter ('" +
                    comp_from_tr[tr_connection_name] + "' and '" + comp_connection_name + "')")
            comp_from_tr[tr_connection_name] = comp_connection_name

        self._tr_tree.connect(param_name, _trf, comp_from_tr)
        self._update_attributes()
        return self

    def _translate_tree(self, translator):
        """
        Returns the result of ``self._tr_tree.translate(translator)``.
        """
        return self._tr_tree.translate(translator)

    def _get_plan(self, tr_tree, translator, thread, fast_math):
        """
        Builds a plan for ``tr_tree`` by calling :py:meth:`_build_plan` with
        a factory of :py:class:`~reikna.core.computation.ComputationPlan` objects,
        the device parameters of ``thread``, and one
        :py:class:`~reikna.core.computation.KernelArgument` per root parameter of ``tr_tree``.
        """
        def plan_factory():
            # Each call creates a fresh plan bound to the same tree/translator/thread.
            return ComputationPlan(tr_tree, translator, thread, fast_math)

        args = [
            KernelArgument(param.name, param.annotation.type)
            for param in tr_tree.get_root_parameters()]
        return self._build_plan(plan_factory, thread.device_params, *args)

    def compile(self, thread, fast_math=False):
        """
        Compiles the computation with the given :py:class:`~reikna.cluda.api.Thread` object
        and returns a :py:class:`~reikna.core.computation.ComputationCallable` object.
        If ``fast_math`` is enabled, the compilation of all kernels is performed using
        the compiler options for fast and imprecise mathematical functions.
        """
        translator = Translator.identity()
        return self._get_plan(self._tr_tree, translator, thread, fast_math).finalize()

    def _build_plan(self, plan_factory, device_params, *args):
        """
        Derived classes override this method.
        It is called (via ``_get_plan``) by :py:meth:`compile` and
        supposed to return a :py:class:`~reikna.core.computation.ComputationPlan` object.

        :param plan_factory: a callable returning a new
            :py:class:`~reikna.core.computation.ComputationPlan` object.
        :param device_params: a :py:class:`~reikna.cluda.api.DeviceParameters` object corresponding
            to the thread the computation is being compiled for.
        :param args: :py:class:`~reikna.core.computation.KernelArgument` objects,
            corresponding to ``parameters`` specified during the creation
            of this computation object.
        """
        raise NotImplementedError
Exemplo n.º 2
0
class Computation:
    """
    A base class for computations, intended to be subclassed.

    :param root_parameters: a list of :py:class:`~reikna.core.Parameter` objects.

    .. py:attribute:: signature

        A :py:class:`~reikna.core.Signature` object representing current computation signature
        (taking into account connected transformations).

    .. py:attribute:: parameter

        A named tuple of :py:class:`~reikna.core.computation.ComputationParameter` objects
        corresponding to parameters from the current :py:attr:`signature`.
    """
    def __init__(self, root_parameters):
        # Each root parameter name goes through ``check_external_parameter_name``
        # (presumably raising on an invalid name) before the tree is built.
        for param in root_parameters:
            check_external_parameter_name(param.name)
        self._tr_tree = TransformationTree(root_parameters)
        self._update_attributes()

    def _update_attributes(self):
        """
        Updates ``signature`` and ``parameter`` attributes.
        Called by the methods that change the signature.
        """
        leaf_params = self._tr_tree.get_leaf_parameters()
        self.signature = Signature(leaf_params)
        self.parameter = make_parameter_container(self, leaf_params)

    # The names are underscored to avoid name conflicts with ``tr_from_comp`` keys
    # (where the user can introduce new parameter names)
    def connect(self, _comp_connector, _trf, _tr_connector, **tr_from_comp):
        """
        Connect a transformation to the computation.

        :param _comp_connector: connection target ---
            a :py:class:`~reikna.core.computation.ComputationParameter` object
            belonging to this computation object, or a string with its name.
        :param _trf: a :py:class:`~reikna.core.Transformation` object.
        :param _tr_connector: connector on the side of the transformation ---
            a :py:class:`~reikna.core.transformation.TransformationParameter` object
            belonging to ``_trf``, or a string with its name.
        :param tr_from_comp: a dictionary with the names of new or old
            computation parameters as keys, and
            :py:class:`~reikna.core.transformation.TransformationParameter` objects
            (or their names) as values.
        :returns: this computation object (modified).
        :raises ValueError: if ``_comp_connector`` is a
            :py:class:`~reikna.core.computation.ComputationParameter` of another computation;
            if the connector name is also given as a key of ``tr_from_comp``;
            if a :py:class:`~reikna.core.transformation.TransformationParameter`
            does not belong to ``_trf``;
            or if the same transformation parameter is connected to more than one
            computation parameter.

        .. note::

            The resulting parameter order is determined by traversing
            the graph of connections depth-first (starting from the initial computation parameters),
            with the additional condition: the nodes do not change their order
            in the same branching level (i.e. in the list of computation or
            transformation parameters, both of which are ordered).

            For example, consider a computation with parameters ``(a, b, c, d)``.
            If you connect a transformation ``(a', c) -> a``, the resulting computation
            will have the signature ``(a', b, c, d)`` (as opposed to ``(a', c, b, d)``
            it would have for the pure depth-first traversal).
        """

        # Extract connector name
        if isinstance(_comp_connector, ComputationParameter):
            if not _comp_connector.belongs_to(self):
                raise ValueError(
                    "The connection target must belong to this computation.")
        param_name = str(_comp_connector)

        # Extract transformation parameters names

        if param_name in tr_from_comp:
            raise ValueError(
                "Parameter '" + param_name + "' cannot be supplied " +
                "both as the main connector and one of the child connections")

        # ``tr_from_comp`` is this call's own ``**kwargs`` dict,
        # so adding the main connection to it is not visible to the caller.
        tr_from_comp[param_name] = _tr_connector
        comp_from_tr = {}
        for comp_connection_name, tr_connection in tr_from_comp.items():
            check_external_parameter_name(comp_connection_name)
            if isinstance(tr_connection, TransformationParameter):
                if not tr_connection.belongs_to(_trf):
                    raise ValueError(
                        "The transformation parameter must belong to the provided transformation"
                    )
            tr_connection_name = str(tr_connection)
            # ``comp_from_tr`` is keyed by the transformation parameter name;
            # a repeated name would silently overwrite (and drop) an earlier
            # connection.
            if tr_connection_name in comp_from_tr:
                raise ValueError(
                    "Transformation parameter '" + tr_connection_name +
                    "' cannot be connected to more than one computation " +
                    "parameter ('" + comp_from_tr[tr_connection_name] +
                    "' and '" + comp_connection_name + "')")
            comp_from_tr[tr_connection_name] = comp_connection_name

        self._tr_tree.connect(param_name, _trf, comp_from_tr)
        self._update_attributes()
        return self

    def _translate_tree(self, translator):
        """
        Returns the result of ``self._tr_tree.translate(translator)``.
        """
        return self._tr_tree.translate(translator)

    def _get_plan(self, tr_tree, translator, thread, fast_math):
        """
        Builds a plan for ``tr_tree`` by calling :py:meth:`_build_plan` with
        a factory of :py:class:`~reikna.core.computation.ComputationPlan` objects,
        the device parameters of ``thread``, and one
        :py:class:`~reikna.core.computation.KernelArgument` per root parameter
        of ``tr_tree``.
        """
        def plan_factory():
            # Each call creates a fresh plan bound to the same
            # tree/translator/thread.
            return ComputationPlan(tr_tree, translator, thread, fast_math)

        args = [
            KernelArgument(param.name, param.annotation.type)
            for param in tr_tree.get_root_parameters()
        ]
        return self._build_plan(plan_factory, thread.device_params, *args)

    def compile(self, thread, fast_math=False):
        """
        Compiles the computation with the given :py:class:`~reikna.cluda.api.Thread` object
        and returns a :py:class:`~reikna.core.computation.ComputationCallable` object.
        If ``fast_math`` is enabled, the compilation of all kernels is performed using
        the compiler options for fast and imprecise mathematical functions.
        """
        translator = Translator.identity()
        return self._get_plan(self._tr_tree, translator, thread,
                              fast_math).finalize()

    def _build_plan(self, plan_factory, device_params, *args):
        """
        Derived classes override this method.
        It is called (via ``_get_plan``) by :py:meth:`compile` and
        supposed to return a :py:class:`~reikna.core.computation.ComputationPlan` object.

        :param plan_factory: a callable returning a new
            :py:class:`~reikna.core.computation.ComputationPlan` object.
        :param device_params: a :py:class:`~reikna.cluda.api.DeviceParameters` object corresponding
            to the thread the computation is being compiled for.
        :param args: :py:class:`~reikna.core.computation.KernelArgument` objects,
            corresponding to ``parameters`` specified during the creation
            of this computation object.
        """
        raise NotImplementedError
Exemplo n.º 3
0
 def __init__(self, root_parameters):
     """
     Build the computation from its root parameters.

     Each parameter's ``name`` is passed to ``check_external_parameter_name``
     first; the transformation tree is then constructed from
     ``root_parameters`` and the derived attributes are refreshed.
     """
     names = (root_param.name for root_param in root_parameters)
     for external_name in names:
         check_external_parameter_name(external_name)
     tree = TransformationTree(root_parameters)
     self._tr_tree = tree
     self._update_attributes()
Exemplo n.º 4
0
 def __init__(self, root_parameters):
     """
     Build the computation from its root parameters.

     Every ``param.name`` is validated through
     ``check_external_parameter_name`` before a ``TransformationTree`` is
     created from ``root_parameters``; ``_update_attributes`` then runs.
     """
     for root_param in root_parameters:
         candidate = root_param.name
         check_external_parameter_name(candidate)
     self._tr_tree = TransformationTree(root_parameters)
     self._update_attributes()