Example 1
    def __init__(
        self,
        inputs: List[Variable],
        outputs: List[Variable],
        inline: bool = False,
        lop_overrides: str = "default",
        grad_overrides: str = "default",
        rop_overrides: str = "default",
        connection_pattern: Optional[List[List[bool]]] = None,
        name: Optional[str] = None,
        **kwargs,
    ):
        """
        Parameters
        ----------
        inputs
            The inputs to the graph.
        outputs
            The outputs to the graph.
        inline
            Defaults to ``False``.

            ``True``: Use the :class:`Op`'s original graph during compilation;
            the :class:`Op` itself will not be visible in the compiled graph,
            only its internal graph.

            ``False``: Use a pre-compiled function inside.
        grad_overrides
            Defaults to ``'default'``.
            This argument is mutually exclusive with ``lop_overrides``.

            ``'default'``: Do not override; use the default :meth:`Op.grad` result.

            `OpFromGraph`: Override with another `OpFromGraph`; it should
            accept inputs in the same order and with the same types as the
            ``inputs`` and ``output_grads`` arguments one would specify in the
            :meth:`Op.grad` method.

            `callable`: Should take two args: ``inputs`` and ``output_grads``.
            Each argument is expected to be a list of :class:`Variable`, and it
            must return a list of :class:`Variable`.
        lop_overrides
            Defaults to ``'default'``.

            This argument is mutually exclusive with ``grad_overrides``.

            These options are similar to ``grad_overrides`` above, but for the
            :meth:`Op.L_op` method.

            ``'default'``: Do not override; use the default :meth:`Op.L_op` result.

            `OpFromGraph`: Override with another `OpFromGraph`; it should
            accept inputs in the same order and with the same types as the
            ``inputs``, ``outputs``, and ``output_grads`` arguments one would
            specify in the :meth:`Op.grad` method.

            `callable`: Should take three args: ``inputs``, ``outputs``, and
            ``output_grads``. Each argument is expected to be a list of
            :class:`Variable`, and it must return a list of :class:`Variable`.

            `NullType` instance: Treat as non-differentiable.

            `DisconnectedType` instance: Treat as a disconnected gradient,
            which is numerically zero.

            ``list``: Each `OpFromGraph`/callable must return a single
            :class:`Variable`. Each list element corresponds to the gradient
            of a specific input, and the length of the list must equal the
            number of inputs.

        rop_overrides
            One of ``{'default', OpFromGraph, callable, Variable}``.

            Defaults to ``'default'``.

            ``'default'``: Do not override; use the default :meth:`Op.R_op` result.

            `OpFromGraph`: Override with another `OpFromGraph`; it should
            accept inputs in the same order and with the same types as the
            ``inputs`` and ``eval_points`` arguments one would specify in the
            :meth:`Op.R_op` method.

            `callable`: Should take two args: ``inputs`` and ``eval_points``.
            Each argument is expected to be a list of :class:`Variable`, and it
            must return a list of :class:`Variable`.

            `NullType` instance: Treat as non-differentiable.

            `DisconnectedType` instance: Treat as zero, since
            `DisconnectedType` is not yet supported in :meth:`Op.R_op`.

            ``list``: Each `OpFromGraph`/callable must return a single
            :class:`Variable <aesara.graph.basic.Variable>`. Each list element
            corresponds to a specific output of :meth:`Op.R_op`, and the
            length of the list must equal the number of outputs.
        connection_pattern
            If not ``None``, this will be used as the connection pattern for
            this :class:`Op`.
        name
            A name for debugging purposes.
        kwargs
            See :func:`aesara.function` for additional arguments; these only
            apply when ``inline`` is ``False``.
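
        Examples
        --------
        A minimal construction sketch (the variable names are illustrative):

        >>> import aesara.tensor as at
        >>> from aesara.compile.builders import OpFromGraph
        >>> x, y = at.scalars("x", "y")
        >>> op = OpFromGraph([x, y], [x + y, x * y], name="add_mul")
        >>> a, b = at.scalars("a", "b")
        >>> sum_out, prod_out = op(a, b)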
        """

        if not (isinstance(inputs, list) and isinstance(outputs, list)):
            raise TypeError("Inputs and outputs must be lists")

        for i in inputs + outputs:
            if not isinstance(i, Variable):
                raise TypeError(
                    f"Inputs and outputs must be Variable instances; got {i}")

        for i in inputs:
            if isinstance(i, Constant):
                raise TypeError(f"Constants not allowed as inputs; {i}")
            if isinstance(i, SharedVariable):
                raise TypeError(
                    f"SharedVariables not allowed as inputs; {i}")

        for var in graph_inputs(outputs, inputs):
            if var not in inputs and not isinstance(
                    var, (Constant, SharedVariable)):
                raise MissingInputError(
                    f"OpFromGraph is missing an input: {var}")

        if "updates" in kwargs or "givens" in kwargs:
            raise NotImplementedError(
                "Updates and givens are not allowed here")

        self.is_inline = inline

        # To correctly support shared variables, the inner function should
        # not see them; otherwise there is a problem with the gradient.
        self.shared_inputs = [
            var for var in graph_inputs(outputs)
            if isinstance(var, SharedVariable)
        ]

        inputs, outputs = replace_nominals_with_dummies(inputs, outputs)

        # The inputs should be `NominalVariable`s, so that graphs can be merged
        replacements = {}
        for n, v in enumerate(inputs):
            replacements[v] = NominalVariable(n, v.type)

        # Shared inputs are likewise replaced with `NominalVariable`s,
        # numbered after the regular inputs.
        shared_vars = [
            NominalVariable(n, var.type)
            for n, var in enumerate(self.shared_inputs, start=len(inputs) + 1)
        ]

        replacements.update(dict(zip(self.shared_inputs, shared_vars)))

        # Clone the graph onto the nominal inputs; `rebuild_collect_shared`
        # also collects any implicit updates and leftover shared variables,
        # all of which are asserted empty below.
        new = rebuild_collect_shared(
            cast(Sequence[Variable], outputs),
            inputs=inputs + shared_vars,
            replace=replacements,
            copy_inputs_over=False,
        )
        (
            local_inputs,
            local_outputs,
            (clone_d, update_d, update_expr, shared_inputs),
        ) = new

        assert len(local_inputs) == len(inputs) + len(self.shared_inputs)
        assert len(local_outputs) == len(outputs)
        assert not update_d
        assert not update_expr
        assert not shared_inputs

        self.fgraph = FunctionGraph(local_inputs, local_outputs, clone=False)
        self.kwargs = kwargs
        self.input_types = [inp.type for inp in inputs]
        self.output_types = [out.type for out in outputs]

        self.lop_overrides = lop_overrides
        self.grad_overrides = grad_overrides
        self.rop_overrides = rop_overrides

        # `lop_overrides` and `grad_overrides` are mutually exclusive; either
        # one is routed through `set_lop_overrides`, and `_lop_type` records
        # which override signature was supplied.
        if lop_overrides != "default":
            if grad_overrides != "default":
                raise ValueError(
                    "lop_overrides and grad_overrides are mutually exclusive")
            self.set_lop_overrides(lop_overrides)
            self._lop_type = "lop"
        elif grad_overrides != "default":
            self.set_lop_overrides(grad_overrides)
            self._lop_type = "grad"
        else:
            self.set_lop_overrides("default")
            self._lop_type = "lop"
        self.set_rop_overrides(rop_overrides)
        self.set_rop_overrides(rop_overrides)

        self._connection_pattern = connection_pattern

        if name is not None:
            assert isinstance(name, str), "name must be None or string object"
        self.name = name
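
The dispatch above routes either ``lop_overrides`` or ``grad_overrides`` through ``set_lop_overrides``. A hedged sketch of supplying a custom gradient as a callable, following the signature described in the docstring (the gradient expression itself is illustrative, not from the source):

    import aesara.tensor as at
    from aesara.compile.builders import OpFromGraph

    x = at.scalar("x")

    # Callable override: receives `inputs` and `output_grads` and must
    # return one gradient Variable per input. Here it reproduces the true
    # gradient of x**2.
    def custom_grad(inputs, output_grads):
        (inp,) = inputs
        (g_out,) = output_grads
        return [2 * inp * g_out]

    op = OpFromGraph([x], [x**2], grad_overrides=custom_grad)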
Example 2
    def __init__(
        self,
        inputs,
        outputs,
        inline=False,
        lop_overrides="default",
        grad_overrides="default",
        rop_overrides="default",
        connection_pattern=None,
        name=None,
        **kwargs,
    ):
        if not isinstance(outputs, list):
            raise TypeError(f"outputs must be list, got {type(outputs)}")
        for i in inputs + outputs:
            if not isinstance(i, Variable):
                raise TypeError(
                    f"inputs and outputs must be Variable instances; got {i}")
        if "updates" in kwargs or "givens" in kwargs:
            raise TypeError("updates and givens are not allowed here")
        self.is_inline = inline
        # To correctly support shared variables, the inner function should
        # not see them; otherwise there is a problem with the gradient.
        self.shared_inputs = [
            var for var in graph_inputs(outputs)
            if isinstance(var, SharedVariable)
        ]
        # Fresh dummy variables that stand in for the shared inputs inside
        # the cloned graph.
        shared_vars = [var.type() for var in self.shared_inputs]

        new = rebuild_collect_shared(
            outputs,
            inputs=inputs + shared_vars,
            replace=dict(zip(self.shared_inputs, shared_vars)),
            copy_inputs_over=False,
        )
        (
            local_inputs,
            local_outputs,
            [clone_d, update_d, update_expr, shared_inputs],
        ) = new
        assert len(local_inputs) == len(inputs) + len(self.shared_inputs)
        assert len(local_outputs) == len(outputs)
        assert not update_d
        assert not update_expr
        assert not shared_inputs

        self.local_inputs = local_inputs
        self.local_outputs = local_outputs
        self.inputs = inputs
        self.outputs = outputs
        self.kwargs = kwargs
        self.input_types = [inp.type for inp in inputs]
        self.output_types = [out.type for out in outputs]
        if lop_overrides != "default":
            if grad_overrides != "default":
                raise ValueError(
                    "lop_overrides and grad_overrides are mutually exclusive")
            else:
                self.set_lop_overrides(lop_overrides)
                self._lop_type = "lop"
        elif grad_overrides != "default":
            self.set_lop_overrides(grad_overrides)
            self._lop_type = "grad"
        else:
            self.set_lop_overrides("default")
            self._lop_type = "lop"
        self.set_rop_overrides(rop_overrides)

        self._connection_pattern = connection_pattern

        if name is not None:
            assert isinstance(name, str), "name must be None or string object"
        self.name = name
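
In both versions, shared variables reachable from ``outputs`` are collected into ``self.shared_inputs`` and replaced with stand-ins, so callers never pass them explicitly. A minimal sketch of that behavior (assuming a standard Aesara install; names are illustrative):

    import numpy as np
    import aesara
    import aesara.tensor as at
    from aesara.compile.builders import OpFromGraph

    x = at.scalar("x")
    w = aesara.shared(np.asarray(2.0), name="w")

    # `w` is not listed among the inputs; OpFromGraph collects it from the
    # output graph and threads it through internally.
    op = OpFromGraph([x], [w * x])
    f = aesara.function([x], op(x))
    assert np.isclose(f(3.0), 6.0)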