Ejemplo n.º 1
0
    def replace(self, r, new_r, reason=None, verbose=None):
        """
        Make every user of `r` in this FunctionGraph use `new_r` instead.

        This is the main interface for rewriting the subgraph: each client
        of `r` (an Apply node input, or a graph output slot marked
        ``"output"``) is rewired through ``self.change_input``.

        Parameters
        ----------
        r
            The variable to replace.
        new_r
            The replacement. If ``new_r.type`` differs from ``r.type``,
            the result of ``r.type.convert_variable(new_r)`` is used instead.
        reason
            Name of the optimization/operation in progress; printed in
            verbose mode and forwarded to ``change_input``.
        verbose
            If true, print `reason`, `r` and `new_r`. None means use
            ``config.optimizer_verbose``.

        Raises
        ------
        Exception
            If `r` has an ``fgraph`` attribute that is not this graph.
        toolbox.BadOptimization
            If `new_r` cannot be converted to a variable of ``r.type``.
        AssertionError
            If test values are enabled and `r` and `new_r` have test values
            of different shapes.

        Notes
        -----
        If `r` is not in ``self.variables``, the method returns without
        touching the graph (after the type check has passed).
        """
        chatty = config.optimizer_verbose if verbose is None else verbose
        if chatty:
            print(reason, r, new_r)

        if hasattr(r, "fgraph") and r.fgraph is not self:
            msg = ("Cannot replace %s because it does not belong "
                   "to this FunctionGraph" % r)
            raise Exception(msg, str(reason))

        if r.type != new_r.type:
            converted = r.type.convert_variable(new_r)
            # Conversion may fail outright (None) or yield a type that
            # still does not match; both cases are rejected.
            if converted is None or converted.type != r.type:
                # The same bookkeeping dicts are passed to both dumps,
                # presumably so node ids stay consistent between them.
                done, used_ids = {}, {}

                def _dump(variable):
                    # Text rendering of `variable` (depth=6, types shown).
                    buffer = StringIO()
                    return theano.compile.debugmode.debugprint(
                        variable,
                        prefix="  ",
                        depth=6,
                        file=buffer,
                        done=done,
                        print_type=True,
                        used_ids=used_ids,
                    ).getvalue()

                old_text = _dump(r)
                new_text = _dump(new_r)
                detail = (str(reason)
                          + ". The type of the replacement must be the same.")
                raise toolbox.BadOptimization(
                    r, new_r, None, None, detail, old_text, new_text
                )
            new_r = converted

        if r not in self.variables:
            # Not part of this graph: return quietly instead of raising,
            # which keeps optimizations of multiple-output ops simpler.
            return

        if theano.config.compute_test_value != "off":
            try:
                old_value = theano.gof.op.get_test_value(r)
                new_value = theano.gof.op.get_test_value(new_r)
            except TestValueError:
                # Presumably raised when a test value is unavailable; the
                # shape comparison is simply skipped in that case.
                pass
            else:
                old_shape = getattr(old_value, "shape", None)
                new_shape = getattr(new_value, "shape", None)
                if old_shape != new_shape:
                    raise AssertionError(
                        "The replacement variable has a test value with "
                        "a shape different from the original variable's "
                        "test value. Original: %s, new: %s"
                        % (old_shape, new_shape),
                        r,
                        new_r,
                        str(reason),
                    )

        # Iterate over a copy: the client list may change while inputs are
        # being rewired.
        for client, index in list(r.clients):
            assert (client == "output" and self.outputs[index] is r) or (
                client.inputs[index] is r
            )
            self.change_input(client, index, new_r, reason=reason)
Ejemplo n.º 2
0
    def replace(self, var, new_var, reason=None, verbose=None):
        """Replace a variable in the `FunctionGraph`.

        This is the main interface to manipulate the subgraph in
        `FunctionGraph`: every client of `var` (an Apply node input, or a
        graph output slot marked ``"output"``) is rewired to use `new_var`
        instead, via ``self.change_input``.

        Parameters
        ----------
        var : theano.gof.graph.Variable
            The variable to be replaced.
        new_var : theano.gof.graph.Variable
            The variable to replace `var`. When its type differs from
            ``var.type``, ``var.type.convert_variable(new_var)`` is used.
        reason : str, optional
            The name of the optimization or operation in progress. It is
            printed in verbose mode and forwarded to ``change_input``.
        verbose : bool, optional
            Print `reason`, `var`, and `new_var`. When None, the value of
            ``config.optimizer_verbose`` is used.

        Raises
        ------
        toolbox.BadOptimization
            If `new_var` cannot be converted to a variable of ``var.type``.
        AssertionError
            If test values are enabled and the test values of `var` and
            `new_var` have different shapes.

        Notes
        -----
        If `var` is not in ``self.variables``, the call returns without
        modifying the graph (after the type check has passed).
        """
        if config.optimizer_verbose if verbose is None else verbose:
            print(reason, var, new_var)

        if var.type != new_var.type:
            candidate = var.type.convert_variable(new_var)
            # Reject both a failed conversion (None) and one whose type
            # still differs from the original.
            if candidate is None or candidate.type != var.type:
                # Both renderings share these dicts, presumably to keep
                # node ids consistent across the two outputs.
                seen = {}
                ids = {}

                def _render(node_var):
                    # Debug text for `node_var` (depth=6, types shown).
                    return theano.compile.debugmode.debugprint(
                        node_var,
                        prefix="  ",
                        depth=6,
                        file=StringIO(),
                        done=seen,
                        print_type=True,
                        used_ids=ids,
                    ).getvalue()

                before = _render(var)
                after = _render(new_var)
                message = (str(reason)
                           + ". The type of the replacement must be the same.")
                raise toolbox.BadOptimization(
                    var, new_var, None, None, message, before, after
                )
            new_var = candidate

        if var not in self.variables:
            # Absent variables are ignored rather than reported; this makes
            # optimizations over multiple-output ops easier to implement.
            return

        if theano.config.compute_test_value != "off":
            try:
                old_value = theano.gof.op.get_test_value(var)
                new_value = theano.gof.op.get_test_value(new_var)
            except TestValueError:
                # Presumably signals a missing test value; skip the check.
                pass
            else:
                old_shape = getattr(old_value, "shape", None)
                new_shape = getattr(new_value, "shape", None)
                if old_shape != new_shape:
                    raise AssertionError(
                        "The replacement variable has a test value with "
                        "a shape different from the original variable's "
                        f"test value. Original: {old_shape}, new: {new_shape}"
                    )

        # Walk a snapshot of the clients; change_input may update the
        # client mapping during the loop.
        for client, position in list(self.clients[var]):
            assert (client == "output" and self.outputs[position] is var) or (
                client.inputs[position] is var
            )
            self.change_input(client, position, new_var, reason=reason)