示例#1
0
def test_equal_computations():
    # Regression test that originated from a bug report by a Theano user.
    # The NoneConst constant must compare equal to itself.
    none_const = tensor.type_other.NoneConst
    assert equal_computations([none_const], [none_const])

    # Two independently built max_and_argmax graphs over the same input must
    # be recognised as the same computation, output by output.
    x = tensor.matrix()
    assert equal_computations(tensor.max_and_argmax(x),
                              tensor.max_and_argmax(x))
示例#2
0
def is_same_graph(var1, var2, givens=None):
    """
    Return True iff Variables `var1` and `var2` perform the same computation.

    "Performing the same computation" is a structural notion: both graphs
    must have the same shape. For example, (x * (y * z)) and ((x * y) * z)
    are reported as different.

    The returned value always comes from `is_same_graph_with_merge`. When
    possible, `equal_computations` is run as well and the two answers are
    asserted to agree. This is deliberately redundant. The cross-check
    exists to validate that both implementations are equivalent, so that
    one of them can eventually be removed. The price is efficiency.

    The cross-check runs when `givens` is empty, or when every substitution
    maps a variable found only in one graph to a variable found only in the
    other graph. Any other kind of substitution skips the cross-check.

    Parameters
    ----------
    var1
        The first Variable to compare.
    var2
        The second Variable to compare.
    givens
        Substitutions to apply to the graphs, similar to the `givens`
        argument of `theano.function`. It may be a dict or anything `dict()`
        accepts (e.g. a list of pairs). The substitutions belong to neither
        `var1` nor `var2` specifically. A substituted variable present in
        both graphs is replaced in both.

    Raises
    ------
    AssertionError
        If the cross-check runs and the two comparisons disagree (only when
        assertions are enabled).

    Examples
    --------

        ======  ======  ======  ======
        var1    var2    givens  output
        ======  ======  ======  ======
        x + 1   x + 1   {}      True
        x + 1   y + 1   {}      False
        x + 1   y + 1   {x: y}  True
        ======  ======  ======  ======

    """
    if givens is None:
        givens = {}
    if not isinstance(givens, dict):
        givens = dict(givens)

    # Authoritative answer: the merge-based comparison.
    merged_equal = is_same_graph_with_merge(var1=var1, var2=var2,
                                            givens=givens)

    cross_check = True
    if not givens:
        xs_subst = None
        ys_subst = None
    else:
        # `equal_computations` needs each substitution expressed as a pair
        # (variable of var1's graph, variable of var2's graph). To build
        # those pairs, collect every variable reachable in each graph.
        graph_inputs = [inputs([var1]), inputs([var2])]
        first_graph = set(variables(graph_inputs[0], [var1]))
        second_graph = set(variables(graph_inputs[1], [var2]))

        def where(v):
            # (found in var1's graph, found in var2's graph)
            return v in first_graph, v in second_graph

        xs_subst = []
        ys_subst = []
        for old, new in givens.items():
            old_in_1, old_in_2 = where(old)
            new_in_1, new_in_2 = where(new)
            if old_in_1 and not old_in_2 and new_in_2 and not new_in_1:
                # `old` lives only in var1's graph, `new` only in var2's.
                xs_subst.append(old)
                ys_subst.append(new)
            elif old_in_2 and not old_in_1 and new_in_1 and not new_in_2:
                # `old` lives only in var2's graph, `new` only in var1's.
                xs_subst.append(new)
                ys_subst.append(old)
            else:
                # Any other pattern is outside what `equal_computations`
                # handles here, so the cross-check is skipped.
                cross_check = False
                break

    if cross_check:
        graph_equal = equal_computations(xs=[var1],
                                         ys=[var2],
                                         in_xs=xs_subst,
                                         in_ys=ys_subst)
        assert graph_equal == merged_equal
    return merged_equal
示例#3
0
def test_equal_computations():
    a, b = tensor.iscalars(2)

    # Output lists of different lengths must be rejected.
    with pytest.raises(ValueError):
        equal_computations([a], [a, b])

    # Each case is (xs, ys, expected truthiness), listed in the original
    # checking order. Plain Python ints and numpy arrays may appear on either
    # side; they are expected to match the equivalent symbolic constant.
    cases = [
        ([a], [a], True),
        ([tensor.as_tensor(1)], [tensor.as_tensor(1)], True),
        ([b], [a], False),
        ([tensor.as_tensor(1)], [tensor.as_tensor(2)], False),
        ([2], [2], True),
        ([np.r_[2, 1]], [np.r_[2, 1]], True),
        ([np.r_[2, 1]], [tensor.as_tensor(np.r_[2, 1])], True),
        ([tensor.as_tensor(np.r_[2, 1])], [np.r_[2, 1]], True),
        ([2], [a], False),
        ([np.r_[2, 1]], [a], False),
        ([a], [2], False),
        ([a], [np.r_[2, 1]], False),
    ]
    for xs, ys, expected in cases:
        assert bool(equal_computations(xs, ys)) is expected, (xs, ys)

    # The NoneConst constant compares equal to itself.
    none_const = tensor.type_other.NoneConst
    assert equal_computations([none_const], [none_const])

    # Separately built multi-output graphs over the same input match.
    x = tensor.matrix()
    assert equal_computations(tensor.max_and_argmax(x),
                              tensor.max_and_argmax(x))