def test_argmax_pushdown():
    x = matrix()

    for sm in [softmax_graph, softmax_legacy]:
        # Test that max_and_argmax is pushed down if the max is not used
        out = max_and_argmax(sm(exp(tanh(sigmoid(x)))), axis=-1)[1]
        fgraph = FunctionGraph([x], [out])
        optdb.query(OPT_FAST_RUN).optimize(fgraph)

        assert len(fgraph.toposort()) == 1
        assert isinstance(fgraph.toposort()[0].op, Argmax)
        assert check_stack_trace(fgraph, ops_to_check=Argmax)

        x = matrix()
        # Test that max_and_argmax is not pushed down if the max is used
        out = max_and_argmax(sm(exp(tanh(sigmoid(x)))), axis=-1)[0]
        fgraph = FunctionGraph([x], [out])

        assert hasattr(fgraph.outputs[0].tag, "trace")

        optdb.query(OPT_FAST_RUN).optimize(fgraph)

        assert len(fgraph.toposort()) == 3
        assert isinstance(fgraph.toposort()[0].op, Elemwise)
        assert isinstance(fgraph.toposort()[1].op, Softmax)
        assert isinstance(fgraph.toposort()[2].op, CAReduce)
        assert isinstance(fgraph.toposort()[2].op.scalar_op, aesara.scalar.ScalarMaximum)
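# A minimal numpy-only sketch (illustrative, not part of the test suite) of why
# the pushdown above is valid: softmax, like any chain of monotonically
# increasing elemwise ops, preserves the location of the row-wise maximum, so
# argmax(softmax(x), axis=-1) == argmax(x, axis=-1).
def _demo_monotonic_argmax():
    rng = np.random.default_rng(0)
    data = rng.normal(size=(3, 5))
    sm = np.exp(data) / np.exp(data).sum(axis=-1, keepdims=True)
    assert np.array_equal(np.argmax(sm, axis=-1), np.argmax(data, axis=-1))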
def test_equal_computations():
    a, b = iscalars(2)

    with pytest.raises(ValueError):
        equal_computations([a], [a, b])

    assert equal_computations([a], [a])
    assert equal_computations([at.as_tensor(1)], [at.as_tensor(1)])
    assert not equal_computations([b], [a])
    assert not equal_computations([at.as_tensor(1)], [at.as_tensor(2)])

    assert equal_computations([2], [2])
    assert equal_computations([np.r_[2, 1]], [np.r_[2, 1]])
    assert equal_computations([np.r_[2, 1]], [at.as_tensor(np.r_[2, 1])])
    assert equal_computations([at.as_tensor(np.r_[2, 1])], [np.r_[2, 1]])

    assert not equal_computations([2], [a])
    assert not equal_computations([np.r_[2, 1]], [a])
    assert not equal_computations([a], [2])
    assert not equal_computations([a], [np.r_[2, 1]])

    assert equal_computations([NoneConst], [NoneConst])

    m = matrix()
    max_argmax1 = max_and_argmax(m)
    max_argmax2 = max_and_argmax(m)
    assert equal_computations(max_argmax1, max_argmax2)
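# Illustrative follow-up (an assumption extrapolated from the cases above, not
# an existing test): equal_computations compares graph structure rather than
# Python object identity, so independently built expressions over the same
# inputs compare equal, while the same expression over different inputs does
# not.
def _demo_equal_computations_structure():
    a, b = iscalars(2)
    assert equal_computations([a + 1], [a + 1])
    assert not equal_computations([a + 1], [b + 1])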
def test_argmax_pushdown_bias():
    x = matrix()
    b = vector()

    out = argmax(softmax_with_bias(x, b), axis=-1)
    fgraph = FunctionGraph([x, b], [out])

    optdb.query(OPT_FAST_RUN).optimize(fgraph)

    types_to_check = (DimShuffle, Elemwise, Argmax)
    assert len(fgraph.toposort()) == 3

    for i, type_ in enumerate(types_to_check):
        assert isinstance(fgraph.toposort()[i].op, type_)
    assert check_stack_trace(fgraph, ops_to_check=types_to_check)

    x = matrix()
    b = vector()
    out = max_and_argmax(softmax_with_bias(x, b), axis=-1)[0]
    fgraph = FunctionGraph([x, b], [out])

    optdb.query(OPT_FAST_RUN).optimize(fgraph)

    assert len(fgraph.toposort()) == 2
    assert isinstance(fgraph.toposort()[0].op, SoftmaxWithBias)
    assert isinstance(fgraph.toposort()[1].op, CAReduce)
    assert isinstance(fgraph.toposort()[1].op.scalar_op, aesara.scalar.ScalarMaximum)
    assert check_stack_trace(fgraph, ops_to_check=(SoftmaxWithBias, CAReduce))
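# Hypothetical numpy analogue (not from the test suite) of the rewritten graph
# checked above: after the pushdown, only the broadcasted bias add
# (DimShuffle + Elemwise) and the Argmax remain, because
# argmax(softmax(x + b)) == argmax(x + b) row-wise.
def _demo_bias_pushdown():
    rng = np.random.default_rng(1)
    xv = rng.normal(size=(4, 6))
    bv = rng.normal(size=6)
    z = xv + bv  # the DimShuffle + Elemwise add that survives the rewrite
    sm = np.exp(z) / np.exp(z).sum(axis=-1, keepdims=True)
    assert np.array_equal(np.argmax(sm, axis=-1), np.argmax(z, axis=-1))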
def test_optimization(self):
    # If only the max output is used, the MaxAndArgmax op should be
    # replaced with a faster CAReduce.
    mode = aesara.compile.mode.get_default_mode().including(
        "canonicalize", "fast_run"
    )
    for axis in [0, 1, -1]:
        n = matrix()

        f = function([n], max_and_argmax(n, axis)[0], mode=mode)
        topo = f.maker.fgraph.toposort()
        assert len(topo) == 1
        assert isinstance(topo[0].op, CAReduce)

        f = function([n], max_and_argmax(n, axis), mode=mode)
        topo = f.maker.fgraph.toposort()
        assert len(topo) == 1
        assert isinstance(topo[0].op, MaxAndArgmax)
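# Illustrative sanity check (an assumption, not an existing test): the
# CAReduce-rewritten max-only function should still agree numerically with
# numpy's max along the same axis.
def _demo_max_only_matches_numpy():
    rng = np.random.default_rng(2)
    data = rng.normal(size=(3, 4)).astype(aesara.config.floatX)
    n = matrix()
    f = function([n], max_and_argmax(n, 0)[0])
    assert np.allclose(f(data), data.max(axis=0))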