示例#1
0
    def check_integrity(self):
        """
        Verify that the cached bookkeeping agrees with the actual graph.

        Call this for a diagnosis if things go awry.

        The graph between ``self.inputs`` and ``self.outputs`` is recomputed
        with ``graph.ops`` and ``graph.variables`` and compared against the
        cached state. The checks, in order, are:

        - ``self.apply_nodes`` equals the nodes returned by ``graph.ops``;
        - every node, and every input of every node, has ``fgraph`` set to
          this object, and ``(node, i)`` is listed in the ``clients`` of
          ``node.inputs[i]``;
        - ``self.variables`` holds the same elements as ``graph.variables``;
        - every variable without an owner is either in ``self.inputs`` or a
          ``graph.Constant``;
        - every variable has ``fgraph`` set to this object;
        - every ``(node, i)`` entry in a variable's ``clients`` points back
          at that variable, either through ``self.outputs[i]`` (when
          ``node`` is the string ``"output"``) or through ``node.inputs[i]``.

        Raises
        ------
        Exception
            On the first inconsistency found. The arguments are a short
            message followed by the offending objects.

        """
        nodes = graph.ops(self.inputs, self.outputs)
        if self.apply_nodes != nodes:
            missing = nodes.difference(self.apply_nodes)
            excess = self.apply_nodes.difference(nodes)
            raise Exception(
                "The nodes are inappropriately cached. missing, in excess: ",
                missing,
                excess,
            )
        for node in nodes:
            if node.fgraph is not self:
                raise Exception("Node should belong to the FunctionGraph.",
                                node)
            for i, variable in enumerate(node.inputs):
                if variable.fgraph is not self:
                    raise Exception(
                        "Input of node should belong to the FunctionGraph.",
                        variable,
                        (node, i),
                    )
                if (node, i) not in variable.clients:
                    raise Exception("Inconsistent clients list.", (node, i),
                                    variable.clients)
        variables = set(graph.variables(self.inputs, self.outputs))
        # Normalize the cached collection to a set once and use that copy for
        # both the comparison and the differences, so the error report below
        # cannot itself fail when `self.variables` is not a set.
        cached_variables = set(self.variables)
        if cached_variables != variables:
            missing = variables.difference(cached_variables)
            excess = cached_variables.difference(variables)
            raise Exception(
                "The variables are inappropriately cached. missing, in excess: ",
                missing,
                excess,
            )
        for variable in variables:
            if (variable.owner is None and variable not in self.inputs
                    and not isinstance(variable, graph.Constant)):
                raise Exception("Undeclared input.", variable)
            if variable.fgraph is not self:
                raise Exception("Variable should belong to the FunctionGraph.",
                                variable)
            for node, i in variable.clients:
                # The string "output" marks a client that is a graph output;
                # `i` then indexes `self.outputs` rather than a node's inputs.
                if node == "output":
                    if self.outputs[i] is not variable:
                        raise Exception("Inconsistent clients list.", variable,
                                        self.outputs[i])
                    continue
                if node not in nodes:
                    raise Exception("Client not in FunctionGraph.", variable,
                                    (node, i))
                if node.inputs[i] is not variable:
                    raise Exception("Inconsistent clients list.", variable,
                                    node.inputs[i])
示例#2
0
文件: fg.py 项目: aalmah/Theano
    def check_integrity(self):
        """
        Verify that the cached bookkeeping agrees with the actual graph.

        Call this for a diagnosis if things go awry.

        The graph between ``self.inputs`` and ``self.outputs`` is recomputed
        with ``graph.ops`` and ``graph.variables`` and compared against the
        cached state. The checks, in order, are:

        - ``self.apply_nodes`` equals the nodes returned by ``graph.ops``;
        - every node, and every input of every node, has ``fgraph`` set to
          this object, and ``(node, i)`` is listed in the ``clients`` of
          ``node.inputs[i]``;
        - ``self.variables`` holds the same elements as ``graph.variables``;
        - every variable without an owner is either in ``self.inputs`` or a
          ``graph.Constant``;
        - every variable has ``fgraph`` set to this object;
        - every ``(node, i)`` entry in a variable's ``clients`` points back
          at that variable, either through ``self.outputs[i]`` (when
          ``node`` is the string ``'output'``) or through ``node.inputs[i]``.

        Raises
        ------
        Exception
            On the first inconsistency found. The arguments are a short
            message followed by the offending objects.

        """
        nodes = graph.ops(self.inputs, self.outputs)
        if self.apply_nodes != nodes:
            missing = nodes.difference(self.apply_nodes)
            excess = self.apply_nodes.difference(nodes)
            raise Exception(
                "The nodes are inappropriately cached. missing, in excess: ",
                missing, excess)
        for node in nodes:
            if node.fgraph is not self:
                raise Exception("Node should belong to the FunctionGraph.",
                                node)
            for i, variable in enumerate(node.inputs):
                if variable.fgraph is not self:
                    raise Exception(
                        "Input of node should belong to the FunctionGraph.",
                        variable, (node, i))
                if (node, i) not in variable.clients:
                    raise Exception("Inconsistent clients list.",
                                    (node, i), variable.clients)
        variables = set(graph.variables(self.inputs, self.outputs))
        # Normalize the cached collection to a set once and use that copy for
        # both the comparison and the differences, so the error report below
        # cannot itself fail when `self.variables` is not a set.
        cached_variables = set(self.variables)
        if cached_variables != variables:
            missing = variables.difference(cached_variables)
            excess = cached_variables.difference(variables)
            raise Exception(
                "The variables are inappropriately cached. missing, in excess: ",
                missing, excess)
        for variable in variables:
            if (variable.owner is None and
                    variable not in self.inputs and
                    not isinstance(variable, graph.Constant)):
                raise Exception("Undeclared input.", variable)
            if variable.fgraph is not self:
                raise Exception("Variable should belong to the FunctionGraph.",
                                variable)
            for node, i in variable.clients:
                # The string 'output' marks a client that is a graph output;
                # `i` then indexes `self.outputs` rather than a node's inputs.
                if node == 'output':
                    if self.outputs[i] is not variable:
                        raise Exception("Inconsistent clients list.",
                                        variable, self.outputs[i])
                    continue
                if node not in nodes:
                    raise Exception("Client not in FunctionGraph.",
                                    variable, (node, i))
                if node.inputs[i] is not variable:
                    raise Exception("Inconsistent clients list.",
                                    variable, node.inputs[i])
示例#3
0
def test_variables_and_orphans():
    """Check the results of `variables` and `orphans` on a small two-op graph.

    The graph is ``o2 = MyOp(r3, o1)`` with ``o1 = MyOp(r1, r2)``. Only the
    first two leaves are passed as inputs, so the third leaf is the single
    expected orphan.
    """
    leaf_a = MyVariable(1)
    leaf_b = MyVariable(2)
    leaf_c = MyVariable(3)

    inner = MyOp(leaf_a, leaf_b)
    inner.name = "o1"

    outer = MyOp(leaf_c, inner)
    outer.name = "o2"

    # Create both results before consuming either of them.
    vars_res = variables([leaf_a, leaf_b], [outer])
    orphans_res = orphans([leaf_a, leaf_b], [outer])
    found_vars, found_orphans = list(vars_res), list(orphans_res)

    assert found_vars == [outer, inner, leaf_c, leaf_b, leaf_a]
    assert found_orphans == [leaf_c]
示例#4
0
def is_same_graph(var1, var2, givens=None):
    """
    Tell whether Variables `var1` and `var2` perform the same computation.

    "Same computation" means the two variables share the same graph, so for
    instance comparing (x * (y * z)) with ((x * y) * z) gives False.

    The answer returned is the one from `is_same_graph_with_merge`. When
    possible, `equal_computations` is also called and its answer is asserted
    to be equal. This is deliberately not efficient: the point is to verify
    that both functions agree, so that one of them can eventually be removed.

    Parameters
    ----------
    var1
        The first Variable to compare.
    var2
        The second Variable to compare.
    givens
        Substitutions to apply, similar to the `givens` argument of
        `theano.function`. May be None, a dict, or anything `dict()` accepts.
        It is tied to neither `var1` nor `var2`: a substitution may affect
        both graphs if the substituted variable is present in both.

    Returns
    -------
    The result of `is_same_graph_with_merge` for the given arguments.

    Raises
    ------
    AssertionError
        If `equal_computations` was called and disagrees with
        `is_same_graph_with_merge`.

    Examples
    --------

        ======  ======  ======  ======
        var1    var2    givens  output
        ======  ======  ======  ======
        x + 1   x + 1   {}      True
        x + 1   y + 1   {}      False
        x + 1   y + 1   {x: y}  True
        ======  ======  ======  ======

    """
    if givens is None:
        givens = {}
    elif not isinstance(givens, dict):
        givens = dict(givens)

    # The merge-based answer is the one that gets returned.
    merge_result = is_same_graph_with_merge(var1=var1, var2=var2,
                                            givens=givens)

    # Without substitutions, `equal_computations` is always called and is
    # given no input pairs.
    cross_check = True
    xs_inputs = None
    ys_inputs = None

    if givens:
        # `equal_computations` needs paired input lists. They can only be
        # built when every substitution maps a variable of one graph to a
        # variable of the other graph; otherwise it is not called.
        xs_inputs = []
        ys_inputs = []

        # All variables found in the graph of `var1` and of `var2`.
        graph_inputs = [inputs([var1]), inputs([var2])]
        vars_of_1 = set(variables(graph_inputs[0], [var1]))
        vars_of_2 = set(variables(graph_inputs[1], [var2]))

        for to_replace, replace_by in givens.items():
            old_in_1 = to_replace in vars_of_1
            old_in_2 = to_replace in vars_of_2
            new_in_1 = replace_by in vars_of_1
            new_in_2 = replace_by in vars_of_2

            old_only_in_1 = old_in_1 and not old_in_2
            old_only_in_2 = old_in_2 and not old_in_1
            new_only_in_1 = new_in_1 and not new_in_2
            new_only_in_2 = new_in_2 and not new_in_1

            if old_only_in_1 and new_only_in_2:
                # A variable of `var1` is replaced by one from `var2`.
                xs_inputs.append(to_replace)
                ys_inputs.append(replace_by)
            elif old_only_in_2 and new_only_in_1:
                # A variable of `var2` is replaced by one from `var1`.
                xs_inputs.append(replace_by)
                ys_inputs.append(to_replace)
            else:
                # Any other situation: skip `equal_computations`.
                cross_check = False
                break

    if cross_check:
        other_result = equal_computations(xs=[var1],
                                          ys=[var2],
                                          in_xs=xs_inputs,
                                          in_ys=ys_inputs)
        assert other_result == merge_result

    return merge_result