コード例 #1
0
def test_argmax_pushdown():
    """Check how ``OPT_FAST_RUN`` rewrites ``max_and_argmax`` over a softmax.

    For every softmax implementation under test, two graphs are built and
    optimized:

    * only the argmax output is used: the optimized graph is expected to
      consist of a single ``Argmax`` node that kept its stack trace;
    * only the max output is used: the optimized graph is expected to keep
      three nodes, in order an ``Elemwise``, a ``Softmax`` and a ``CAReduce``
      whose scalar op is ``ScalarMaximum``.
    """
    for sm in [softmax_graph, softmax_legacy]:
        # Case 1: the max output is unused, so only the argmax should remain.
        # A fresh input per sub-case keeps the two graphs independent.
        x = matrix()
        out = max_and_argmax(sm(exp(tanh(sigmoid(x)))), axis=-1)[1]
        fgraph = FunctionGraph([x], [out])
        optdb.query(OPT_FAST_RUN).optimize(fgraph)

        topo = fgraph.toposort()
        assert len(topo) == 1
        assert isinstance(topo[0].op, Argmax)
        assert check_stack_trace(fgraph, ops_to_check=Argmax)

        # Case 2: the max output is used, so the softmax must stay in place.
        x = matrix()
        out = max_and_argmax(sm(exp(tanh(sigmoid(x)))), axis=-1)[0]
        fgraph = FunctionGraph([x], [out])

        # The output must carry a stack trace before optimization starts.
        assert hasattr(fgraph.outputs[0].tag, "trace")

        optdb.query(OPT_FAST_RUN).optimize(fgraph)

        topo = fgraph.toposort()
        assert len(topo) == 3
        assert isinstance(topo[0].op, Elemwise)
        assert isinstance(topo[1].op, Softmax)
        assert isinstance(topo[2].op, CAReduce)
        assert isinstance(topo[2].op.scalar_op,
                          aesara.scalar.ScalarMaximum)
コード例 #2
0
ファイル: test_basic.py プロジェクト: mgorny/aesara
def test_equal_computations():
    """Exercise ``equal_computations`` on variables, constants and raw values.

    The call with output lists of different lengths is expected to raise
    ``ValueError``. Every other case is checked only for the truthiness of
    the returned value.
    """
    a, b = iscalars(2)

    with pytest.raises(ValueError):
        equal_computations([a], [a, b])

    # (expected truthiness, first list, second list), in execution order.
    cases = [
        # Symbolic variables and tensor constants.
        (True, [a], [a]),
        (True, [at.as_tensor(1)], [at.as_tensor(1)]),
        (False, [b], [a]),
        (False, [at.as_tensor(1)], [at.as_tensor(2)]),
        # Raw Python / NumPy values, alone or mixed with tensor constants.
        (True, [2], [2]),
        (True, [np.r_[2, 1]], [np.r_[2, 1]]),
        (True, [np.r_[2, 1]], [at.as_tensor(np.r_[2, 1])]),
        (True, [at.as_tensor(np.r_[2, 1])], [np.r_[2, 1]]),
        # Raw values against a symbolic variable.
        (False, [2], [a]),
        (False, [np.r_[2, 1]], [a]),
        (False, [a], [2]),
        (False, [a], [np.r_[2, 1]]),
        # The ``None`` constant compared with itself.
        (True, [NoneConst], [NoneConst]),
    ]
    for expected, first, second in cases:
        assert bool(equal_computations(first, second)) is expected

    # Two separate applications of the same multi-output op to one input.
    m = matrix()
    first_outputs = max_and_argmax(m)
    second_outputs = max_and_argmax(m)
    assert equal_computations(first_outputs, second_outputs)
コード例 #3
0
def test_argmax_pushdown_bias():
    """Check how ``OPT_FAST_RUN`` rewrites argmax/max over ``softmax_with_bias``.

    * ``argmax`` of the biased softmax: the optimized graph is expected to
      hold three nodes, in order a ``DimShuffle``, an ``Elemwise`` and an
      ``Argmax``, all with stack traces.
    * max output of ``max_and_argmax``: the optimized graph is expected to
      keep the ``SoftmaxWithBias`` node followed by a ``CAReduce`` whose
      scalar op is ``ScalarMaximum``, both with stack traces.
    """
    # Case 1: only the argmax is requested.
    x = matrix()
    b = vector()

    out = argmax(softmax_with_bias(x, b), axis=-1)
    fgraph = FunctionGraph([x, b], [out])

    optdb.query(OPT_FAST_RUN).optimize(fgraph)

    types_to_check = (DimShuffle, Elemwise, Argmax)
    topo = fgraph.toposort()
    assert len(topo) == len(types_to_check)

    # Pair each node with its expected op class; avoids shadowing ``type``.
    for node, op_type in zip(topo, types_to_check):
        assert isinstance(node.op, op_type)
    assert check_stack_trace(fgraph, ops_to_check=types_to_check)

    # Case 2: the max value itself is requested.
    x = matrix()
    b = vector()
    out = max_and_argmax(softmax_with_bias(x, b), axis=-1)[0]
    fgraph = FunctionGraph([x, b], [out])

    optdb.query(OPT_FAST_RUN).optimize(fgraph)

    topo = fgraph.toposort()
    assert len(topo) == 2
    assert isinstance(topo[0].op, SoftmaxWithBias)
    assert isinstance(topo[1].op, CAReduce)
    assert isinstance(topo[1].op.scalar_op,
                      aesara.scalar.ScalarMaximum)
    assert check_stack_trace(fgraph, ops_to_check=(SoftmaxWithBias, CAReduce))
コード例 #4
0
    def test_optimization(self):
        """Check the compiled graph of ``max_and_argmax`` for several axes.

        When only the max output is requested, the compiled graph is
        expected to be a single ``CAReduce`` node. When both outputs are
        requested, it is expected to be a single ``MaxAndArgmax`` node.
        """
        default_mode = aesara.compile.mode.get_default_mode()
        mode = default_mode.including("canonicalize", "fast_run")

        for axis in (0, 1, -1):
            n = matrix()

            # Max output only.
            max_only = function([n], max_and_argmax(n, axis)[0], mode=mode)
            nodes = max_only.maker.fgraph.toposort()
            assert len(nodes) == 1
            assert isinstance(nodes[0].op, CAReduce)

            # Both outputs.
            max_and_arg = function([n], max_and_argmax(n, axis), mode=mode)
            nodes = max_and_arg.maker.fgraph.toposort()
            assert len(nodes) == 1
            assert isinstance(nodes[0].op, MaxAndArgmax)