Code example #1

import warnings

# Assumed module context: in Theano, `rebuild_collect_shared` is defined in
# theano.compile.pfunc, and DEPRECATED_ARG is a module-level sentinel object.
from theano.compile.pfunc import rebuild_collect_shared

DEPRECATED_ARG = object()

def clone(
    output, replace=None, strict=True, share_inputs=True, copy_inputs=DEPRECATED_ARG
):
    """
    Function that allows replacing subgraphs of a computational graph.

    It returns a copy of the initial subgraph with the corresponding
    substitutions.

    Parameters
    ----------
    output : Theano Variables (or Theano expressions)
        Theano expression that represents the computational graph.
    replace : dict or list of (Variable, Variable) pairs, optional
        Describes which subgraphs should be replaced by what: each key (or
        first element) is replaced by the corresponding value.
    strict : bool
        If True (default), a replacement variable must have the same type as
        the variable it replaces.
    share_inputs : bool
        If True, use the same inputs (and shared variables) as the original
        graph. If False, clone them. Note that cloned shared variables still
        use the same underlying storage, so they will always have the same
        value.
    copy_inputs
        Deprecated, use share_inputs.

    """
    if copy_inputs is not DEPRECATED_ARG:
        warnings.warn(
            "In `clone()`, the argument `copy_inputs` has been deprecated "
            "and renamed to `share_inputs`"
        )
        # The caller used the deprecated `copy_inputs`, so `share_inputs`
        # must still be at its default value.
        assert share_inputs
        share_inputs = copy_inputs

    if isinstance(replace, dict):
        items = list(replace.items())
    elif isinstance(replace, (list, tuple)):
        items = replace
    elif replace is None:
        items = []
    else:
        raise ValueError(
            "replace is neither a dictionary, list, tuple, nor None! "
            f"The value provided is {replace}, of type {type(replace)}."
        )
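    # Two-pass substitution: `tmp_replace` maps each replaced variable to a
    # fresh placeholder of the same type, and `new_replace` maps each
    # placeholder to its final replacement value.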
    tmp_replace = [(x, x.type()) for x, _ in items]
    new_replace = [(x, y) for ((_, x), (_, y)) in zip(tmp_replace, items)]
    _, _outs, _ = rebuild_collect_shared(
        output, [], tmp_replace, [], strict, share_inputs
    )

    # Second pass: swap the placeholders introduced above for the final
    # values. Doing the substitution in two passes keeps replacement
    # expressions that mention the replaced variables themselves (e.g.
    # swapping {x: y, y: x}) from being rewritten a second time.
    _, outs, _ = rebuild_collect_shared(
        _outs, [], new_replace, [], strict, share_inputs
    )

    return outs
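
A minimal usage sketch (assuming a working Theano installation; the variable
names are illustrative):

import theano
import theano.tensor as tt

x = tt.vector("x")
y = x ** 2 + x

# Replace x by x + 1 everywhere in y's graph. The two-pass substitution
# above keeps the x inside the replacement value `x + 1` from being
# rewritten again.
z = theano.clone(y, replace={x: x + 1})

f = theano.function([x], z)
print(f([1.0, 2.0]))  # computes (x + 1) ** 2 + (x + 1) -> [6., 12.]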
Code example #2

# Assumed module context: this appears to be a method of Theano's
# OpFromGraph (theano/compile/builders.py), where the surrounding module
# provides:
#   from theano import gof
#   from theano.compile.sharedvalue import SharedVariable
#   from theano.compile.pfunc import rebuild_collect_shared
    def __init__(
        self,
        inputs,
        outputs,
        inline=False,
        lop_overrides="default",
        grad_overrides="default",
        rop_overrides="default",
        connection_pattern=None,
        name=None,
        **kwargs,
    ):
        if not isinstance(outputs, list):
            raise TypeError(f"outputs must be list, got {type(outputs)}")
        for i in inputs + outputs:
            if not isinstance(i, gof.Variable):
                raise TypeError(
                    "inputs and outputs must be Variable instances", i)
        if "updates" in kwargs or "givens" in kwargs:
            raise TypeError("updates and givens are not allowed here")
        self.is_inline = inline
        # To support shared variables correctly, the inner function must not
        # see them directly; they are lifted to explicit inputs below.
        # Otherwise the gradient computation would be wrong.
        self.shared_inputs = [
            var for var in gof.graph.inputs(outputs)
            if isinstance(var, SharedVariable)
        ]
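        # Fresh placeholders of matching types stand in for the shared
        # variables inside the inner graph.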
        shared_vars = [var.type() for var in self.shared_inputs]

        new = rebuild_collect_shared(
            outputs,
            inputs=inputs + shared_vars,
            replace=dict(zip(self.shared_inputs, shared_vars)),
            copy_inputs_over=False,
        )
        (
            local_inputs,
            local_outputs,
            [clone_d, update_d, update_expr, shared_inputs],
        ) = new
        assert len(local_inputs) == len(inputs) + len(self.shared_inputs)
        assert len(local_outputs) == len(outputs)
        assert not update_d
        assert not update_expr
        assert not shared_inputs

        self.local_inputs = local_inputs
        self.local_outputs = local_outputs
        self.inputs = inputs
        self.outputs = outputs
        self.kwargs = kwargs
        self.input_types = [inp.type for inp in inputs]
        self.output_types = [out.type for out in outputs]
        if lop_overrides != "default":
            if grad_overrides != "default":
                raise ValueError(
                    "lop_overrides and grad_overrides are mutually exclusive")
            else:
                self.set_lop_overrides(lop_overrides)
                self._lop_type = "lop"
        elif grad_overrides != "default":
            self.set_lop_overrides(grad_overrides)
            self._lop_type = "grad"
        else:
            self.set_lop_overrides("default")
            self._lop_type = "lop"
        self.set_rop_overrides(rop_overrides)

        self._connection_pattern = connection_pattern

        if name is not None:
            assert isinstance(name, str), "name must be None or string object"
        self.name = name
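
Assuming this is indeed OpFromGraph, a minimal usage sketch (adapted from the
pattern in Theano's documentation):

import theano
import theano.tensor as tt
from theano import OpFromGraph

x, y, z = tt.scalars("x", "y", "z")
e = x + y * z

# Wrap the subgraph as a reusable Op. Any shared variables appearing in `e`
# would be lifted to extra inputs by the constructor shown above.
op = OpFromGraph([x, y, z], [e])

# The new Op composes like any other Theano Op.
e2 = op(x, y, z) + op(z, y, x)
fn = theano.function([x, y, z], [e2])
print(fn(1.0, 2.0, 3.0))  # (1 + 2*3) + (3 + 2*1) -> [array(12.0)]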