def test_argmax_pushdown_bias():
    """Check how the FAST_RUN rewrites treat reductions over ``softmax_with_bias``.

    Two graphs are optimized with ``optdb.query(OPT_FAST_RUN)``:

    1. ``argmax(softmax_with_bias(x, b), axis=-1)``: the softmax node must be
       removed, leaving exactly ``DimShuffle``, ``Elemwise``, ``Argmax`` in
       topological order. Argmax is unchanged by the row-wise softmax, so it
       can be applied to the biased input directly.
    2. The max output (index 0) of ``max_and_argmax`` over the same
       expression: the softmax must be kept, because the maximum *value*
       depends on it. The graph is ``SoftmaxWithBias`` followed by a
       ``CAReduce`` whose scalar op is ``ScalarMaximum``.

    In both cases the checked ops must carry a stack trace.
    """
    # --- Case 1: argmax is pushed below the softmax -----------------------
    x = matrix()
    b = vector()
    out = argmax(softmax_with_bias(x, b), axis=-1)
    fgraph = FunctionGraph([x, b], [out])
    optdb.query(OPT_FAST_RUN).optimize(fgraph)

    types_to_check = (DimShuffle, Elemwise, Argmax)
    # Compute the topological order once rather than once per assertion.
    topo = fgraph.toposort()
    assert len(topo) == len(types_to_check), topo
    for node, op_type in zip(topo, types_to_check):
        assert isinstance(node.op, op_type), (node, op_type)
    assert check_stack_trace(fgraph, ops_to_check=types_to_check)

    # --- Case 2: max is NOT pushed below the softmax ----------------------
    x = matrix()
    b = vector()
    # Output 0 of max_and_argmax is the max value, not the index.
    out = max_and_argmax(softmax_with_bias(x, b), axis=-1)[0]
    fgraph = FunctionGraph([x, b], [out])
    optdb.query(OPT_FAST_RUN).optimize(fgraph)

    topo = fgraph.toposort()
    assert len(topo) == 2, topo
    softmax_node, max_node = topo
    assert isinstance(softmax_node.op, SoftmaxWithBias), softmax_node
    assert isinstance(max_node.op, CAReduce), max_node
    assert isinstance(max_node.op.scalar_op, aesara.scalar.ScalarMaximum), max_node
    assert check_stack_trace(fgraph, ops_to_check=(SoftmaxWithBias, CAReduce))
def f(a, b):
    """Apply ``softmax_with_bias`` to ``(a, b)`` and return column 3 of the result.

    NOTE(review): looks like a helper for a test or gradient check defined
    elsewhere in this file; confirm against its caller.
    """
    probabilities = softmax_with_bias(a, b)
    column_index = 3
    return probabilities[:, column_index]