示例#1
0
def test_argmax_pushdown_bias():
    """Test the argmax push-down rewrite on ``softmax_with_bias`` graphs.

    Two graphs are built from a matrix ``x`` and a vector ``b`` and optimized
    with ``optdb.query(OPT_FAST_RUN)``:

    1. ``argmax(softmax_with_bias(x, b), axis=-1)``: the optimized graph must
       contain exactly three nodes whose ops are, in topological order, a
       ``DimShuffle``, an ``Elemwise`` and an ``Argmax``.
    2. ``max_and_argmax(softmax_with_bias(x, b), axis=-1)[0]`` (only the max
       output is used): the optimized graph must contain exactly two nodes,
       a ``SoftmaxWithBias`` followed by a ``CAReduce`` whose scalar op is
       ``ScalarMaximum``.

    In both cases ``check_stack_trace`` must succeed for the checked ops.
    """
    # Case 1: argmax applied to softmax_with_bias.
    x = matrix()
    b = vector()
    out = argmax(softmax_with_bias(x, b), axis=-1)
    fgraph = FunctionGraph([x, b], [out])

    optdb.query(OPT_FAST_RUN).optimize(fgraph)

    types_to_check = (DimShuffle, Elemwise, Argmax)
    # Compute the topological order once and reuse it for every check,
    # instead of re-sorting the graph on each assertion.
    nodes = fgraph.toposort()
    assert len(nodes) == len(types_to_check)
    # `op_type` (not `type`) avoids shadowing the builtin.
    for node, op_type in zip(nodes, types_to_check):
        assert isinstance(node.op, op_type)
    assert check_stack_trace(fgraph, ops_to_check=types_to_check)

    # Case 2: only the max output of max_and_argmax is consumed.
    x = matrix()
    b = vector()
    out = max_and_argmax(softmax_with_bias(x, b), axis=-1)[0]
    fgraph = FunctionGraph([x, b], [out])

    optdb.query(OPT_FAST_RUN).optimize(fgraph)

    nodes = fgraph.toposort()
    assert len(nodes) == 2
    assert isinstance(nodes[0].op, SoftmaxWithBias)
    assert isinstance(nodes[1].op, CAReduce)
    assert isinstance(nodes[1].op.scalar_op, aesara.scalar.ScalarMaximum)
    assert check_stack_trace(fgraph, ops_to_check=(SoftmaxWithBias, CAReduce))
示例#2
0
 def f(a, b):
     """Return ``softmax_with_bias(a, b)`` indexed with ``[:, 3]``."""
     probs = softmax_with_bias(a, b)
     return probs[:, 3]