def test_argmax_pushdown():
    x = matrix()
    for sm in [softmax_graph, softmax_legacy]:
        # test that the max_and_argmax is pushed down if the max is not used
        out = max_and_argmax(sm(exp(tanh(sigmoid(x)))), axis=-1)[1]

        fgraph = FunctionGraph([x], [out])
        optdb.query(OPT_FAST_RUN).optimize(fgraph)

        assert len(fgraph.toposort()) == 1
        assert isinstance(fgraph.toposort()[0].op, Argmax)
        assert check_stack_trace(fgraph, ops_to_check=Argmax)

        x = matrix()
        # test that the max_and_argmax is not pushed down if the max is used
        out = max_and_argmax(sm(exp(tanh(sigmoid(x)))), axis=-1)[0]

        fgraph = FunctionGraph([x], [out])

        assert hasattr(fgraph.outputs[0].tag, "trace")

        optdb.query(OPT_FAST_RUN).optimize(fgraph)

        assert len(fgraph.toposort()) == 3
        assert isinstance(fgraph.toposort()[0].op, Elemwise)
        assert isinstance(fgraph.toposort()[1].op, Softmax)
        assert isinstance(fgraph.toposort()[2].op, CAReduce)
        assert isinstance(
            fgraph.toposort()[2].op.scalar_op, aesara.scalar.ScalarMaximum
        )
def test_argmax_pushdown_bias():
    x = matrix()
    b = vector()

    out = argmax(softmax_with_bias(x, b), axis=-1)
    fgraph = FunctionGraph([x, b], [out])

    optdb.query(OPT_FAST_RUN).optimize(fgraph)

    types_to_check = (DimShuffle, Elemwise, Argmax)
    assert len(fgraph.toposort()) == 3

    for i, type in enumerate(types_to_check):
        assert isinstance(fgraph.toposort()[i].op, type)
    assert check_stack_trace(fgraph, ops_to_check=types_to_check)

    x = matrix()
    b = vector()
    out = max_and_argmax(softmax_with_bias(x, b), axis=-1)[0]
    fgraph = FunctionGraph([x, b], [out])

    optdb.query(OPT_FAST_RUN).optimize(fgraph)

    assert len(fgraph.toposort()) == 2
    assert isinstance(fgraph.toposort()[0].op, SoftmaxWithBias)
    assert isinstance(fgraph.toposort()[1].op, CAReduce)
    assert isinstance(
        fgraph.toposort()[1].op.scalar_op, aesara.scalar.ScalarMaximum
    )
    assert check_stack_trace(fgraph, ops_to_check=(SoftmaxWithBias, CAReduce))
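# Not part of the original test suite: a minimal, illustrative sketch of the
# same argmax-pushdown rewrite as seen through `aesara.function`, which applies
# the FAST_RUN rewrites queried above. The helper name `_demo_argmax_pushdown`
# is hypothetical; it reuses the module-level names these tests already rely on
# (`matrix`, `softmax`, `argmax`, `aesara`).
def _demo_argmax_pushdown():
    x = matrix("x")
    # Only the argmax output is used, so the softmax is expected to be dropped
    # and the compiled graph reduced to a single Argmax node.
    f = aesara.function([x], argmax(softmax(x), axis=-1))
    print([node.op for node in f.maker.fgraph.toposort()])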
def test_xent_thing_int32(self):
    x = matrix("x")
    y = lvector("y")
    yi = aet.cast(y, "int32")
    expressions = [
        aet_sum(-log(softmax(x)[aet.arange(yi.shape[0]), yi])),
        -aet_sum(log(softmax(x)[aet.arange(yi.shape[0]), yi])),
        -aet_sum(log(softmax(x))[aet.arange(yi.shape[0]), yi]),
        aet_sum(-log(softmax(x))[aet.arange(yi.shape[0]), yi]),
    ]

    for expr in expressions:
        fgraph = FunctionGraph([x, y], [expr])
        optdb.query(OPT_FAST_RUN).optimize(fgraph)

        ops = [node.op for node in fgraph.toposort()]
        assert len(ops) == 5
        assert crossentropy_softmax_argmax_1hot_with_bias in ops
        assert not [1 for o in ops if isinstance(o, AdvancedSubtensor)]

        # Also verify the gradient wrt x
        fgraph = FunctionGraph([x, y], [grad(expr, x)])
        optdb.query(OPT_FAST_RUN).optimize(fgraph)

        ops = [node.op for node in fgraph.toposort()]
        assert len(ops) == 3
        assert crossentropy_softmax_1hot_with_bias_dx in ops
        assert softmax_legacy in ops
        assert softmax_grad_legacy not in ops
def test_softmax_optimizations(self):
    x = matrix("x")
    one_of_n = lvector("one_of_n")
    op = crossentropy_categorical_1hot

    fgraph = FunctionGraph([x, one_of_n], [op(softmax_legacy(x), one_of_n)])
    assert fgraph.outputs[0].owner.op == op

    optdb.query(OPT_FAST_RUN).optimize(fgraph)

    assert fgraph.outputs[0].owner.op == crossentropy_softmax_argmax_1hot_with_bias
def test_softmax_optimizations_w_bias2(self):
    x = matrix("x")
    b = vector("b")
    c = vector("c")
    one_of_n = lvector("one_of_n")
    op = crossentropy_categorical_1hot

    fgraph = FunctionGraph(
        [x, b, c, one_of_n], [op(softmax_legacy(add(x, b, c)), one_of_n)]
    )
    assert fgraph.outputs[0].owner.op == op

    optdb.query(OPT_FAST_RUN).optimize(fgraph)

    assert len(fgraph.toposort()) == 2
    assert fgraph.outputs[0].owner.op == crossentropy_softmax_argmax_1hot_with_bias
def test_softmax_grad_optimizations(self):
    x = matrix("x")
    one_of_n = lvector("one_of_n")
    op = crossentropy_categorical_1hot
    xe = op(softmax_legacy(x), one_of_n)
    sum_xe = aet_sum(xe)
    g_x = grad(sum_xe, x)
    fgraph = FunctionGraph([x, one_of_n], [g_x])
    assert check_stack_trace(
        fgraph,
        ops_to_check=[crossentropy_softmax_1hot_with_bias_dx, softmax_legacy],
    )

    optdb.query(OPT_FAST_RUN).optimize(fgraph)

    ops = {node.op for node in fgraph.toposort()}
    assert crossentropy_softmax_argmax_1hot_with_bias not in ops
    assert crossentropy_softmax_1hot_with_bias_dx in ops
    assert softmax_legacy in ops
    assert softmax_grad_legacy not in ops
def test_logsoftmax_grad_true_div_elemwise(self):
    # Check that the gradient of an expression that looks like log(softmax)
    # but uses an elemwise op other than true_div is not rewritten.
    x = matrix("x")
    y = log(softmax(x))
    g = grad(y.sum(), x)

    softmax_grad_node = g.owner
    assert softmax_grad_node.op == softmax_grad_legacy
    true_div_node = softmax_grad_node.inputs[0].owner
    assert true_div_node.op == true_div

    # Replace the elemwise true_div op with an elemwise add.
    new_g = softmax_grad_legacy(
        add(*true_div_node.inputs), softmax_grad_node.inputs[1]
    )

    fgraph = FunctionGraph([x], [new_g])
    optdb.query(OPT_FAST_RUN).optimize(fgraph)

    assert softmax_grad_legacy in [n.op for n in fgraph.toposort()]
def test_crossentropy_softmax_1hot_with_bias_dxcale_cost(self):
    x = matrix("x")
    y = lvector("y")
    a = scalar("a")

    def validate_grad_graph(func):
        # The gradient graph should no longer contain the softmax grad op
        has_cx1hotdx = False
        has_softmax = False
        has_softmaxdx = False
        for node in func.maker.fgraph.toposort():
            if node.op == crossentropy_softmax_1hot_with_bias_dx:
                has_cx1hotdx = True
            if node.op == softmax_legacy:
                has_softmax = True
            if node.op == softmax_grad_legacy:
                has_softmaxdx = True

        assert has_cx1hotdx
        assert has_softmax
        assert not has_softmaxdx

    # Cases to test
    expressions = [
        a * aet_sum(-log(softmax(x)[aet.arange(y.shape[0]), y])),
        -a * aet_sum(log(softmax(x)[aet.arange(y.shape[0]), y])),
        a * (-aet_sum(log(softmax(x)[aet.arange(y.shape[0]), y]))),
        a * aet_sum(log(softmax(x)[aet.arange(y.shape[0]), y])),
        a * aet_sum(-log(softmax(x))[aet.arange(y.shape[0]), y]),
        -a * aet_sum(log(softmax(x))[aet.arange(y.shape[0]), y]),
        a * (-aet_sum(log(softmax(x))[aet.arange(y.shape[0]), y])),
        a * aet_sum(log(softmax(x))[aet.arange(y.shape[0]), y]),
        a * mean(-log(softmax(x)[aet.arange(y.shape[0]), y])),
        -a * mean(log(softmax(x)[aet.arange(y.shape[0]), y])),
        a * (-mean(log(softmax(x)[aet.arange(y.shape[0]), y]))),
        a * mean(log(softmax(x)[aet.arange(y.shape[0]), y])),
        a * mean(-log(softmax(x))[aet.arange(y.shape[0]), y]),
        -a * mean(log(softmax(x))[aet.arange(y.shape[0]), y]),
        a * (-mean(log(softmax(x))[aet.arange(y.shape[0]), y])),
        a * mean(log(softmax(x))[aet.arange(y.shape[0]), y]),
    ]

    for expr in expressions:
        fgraph = FunctionGraph([x, y, a], [expr])
        optdb.query(OPT_FAST_RUN).optimize(fgraph)

        assert 5 <= len(fgraph.toposort()) <= 10

        ops = {node.op for node in fgraph.toposort()}
        assert crossentropy_softmax_argmax_1hot_with_bias in ops
        assert softmax_legacy not in ops

        # Verify the gradient wrt x
        fgraph = FunctionGraph([x, y, a], [grad(expr, x)])
        optdb.query(OPT_FAST_RUN).optimize(fgraph)

        assert 3 <= len(fgraph.toposort()) <= 6

        ops = {node.op for node in fgraph.toposort()}
        assert crossentropy_softmax_1hot_with_bias_dx in ops
        assert softmax_legacy in ops
        assert softmax_grad_legacy not in ops

        # Verify the gradient when providing output gradient
        fgraph = FunctionGraph(
            [x, y, a], [grad(expr, x, known_grads={expr: a * x.sum()})]
        )
        optdb.query(OPT_FAST_RUN).optimize(fgraph)

        assert 6 <= len(fgraph.toposort()) <= 8

        ops = {node.op for node in fgraph.toposort()}
        assert crossentropy_softmax_1hot_with_bias_dx in ops
        assert softmax_legacy in ops
        assert softmax_grad_legacy not in ops
def test_get_rid_of_advanced_indexing_version_of_xent(self):
    x = matrix("x")
    b = vector("b")
    y = lvector("y")

    # Basic case
    expressions = [
        aet_sum(-log(softmax(x)[aet.arange(y.shape[0]), y])),
        -aet_sum(log(softmax(x)[aet.arange(y.shape[0]), y])),
        -aet_sum(log(softmax(x))[aet.arange(y.shape[0]), y]),
        aet_sum(-log(softmax(x))[aet.arange(y.shape[0]), y]),
    ]

    for expr in expressions:
        fgraph = FunctionGraph([x, y], [expr])
        optdb.query(OPT_FAST_RUN).optimize(fgraph)

        ops = [node.op for node in fgraph.toposort()]
        assert len(ops) == 4
        assert crossentropy_softmax_argmax_1hot_with_bias in ops
        assert not [1 for o in ops if isinstance(o, AdvancedSubtensor)]

        # Also verify the gradient wrt x
        fgraph = FunctionGraph([x, y], [grad(expr, x)])
        optdb.query(OPT_FAST_RUN).optimize(fgraph)

        ops = [node.op for node in fgraph.toposort()]
        assert len(ops) == 2
        assert crossentropy_softmax_1hot_with_bias_dx in ops
        assert softmax_legacy in ops
        assert softmax_grad_legacy not in ops

    # Test that a biased softmax is optimized correctly
    bias_expressions = [
        aet_sum(-log(softmax(x + b)[aet.arange(y.shape[0]), y])),
        -aet_sum(log(softmax(b + x)[aet.arange(y.shape[0]), y])),
        -aet_sum(log(softmax(x + b))[aet.arange(y.shape[0]), y]),
        aet_sum(-log(softmax(b + x))[aet.arange(y.shape[0]), y]),
    ]

    for expr in bias_expressions:
        fgraph = FunctionGraph([x, b, y], [expr, x])
        optdb.query(OPT_FAST_RUN).optimize(fgraph)

        ops = [node.op for node in fgraph.toposort()]
        assert len(ops) == 2  # [big_op, sum]
        assert crossentropy_softmax_argmax_1hot_with_bias in ops

        fgraph = FunctionGraph([x, b, y], [grad(expr, x)])
        optdb.query(OPT_FAST_RUN).optimize(fgraph)

        ops = [node.op for node in fgraph.toposort()]
        assert len(ops) == 2
        assert crossentropy_softmax_1hot_with_bias_dx in ops
        assert softmax_with_bias in ops
        assert softmax_grad_legacy not in ops

    # Test that using "mean" instead of sum works, too
    mean_expressions = [
        mean(-log(softmax(x)[aet.arange(y.shape[0]), y])),
        -mean(log(softmax(x)[aet.arange(y.shape[0]), y])),
        -mean(log(softmax(x))[aet.arange(y.shape[0]), y]),
        mean(-log(softmax(x))[aet.arange(y.shape[0]), y]),
    ]

    for expr in mean_expressions:
        fgraph = FunctionGraph([x, y], [expr])
        optdb.query(OPT_FAST_RUN).optimize(fgraph)

        ops = [node.op for node in fgraph.toposort()]
        assert len(ops) == 6
        assert crossentropy_softmax_argmax_1hot_with_bias in ops
        assert not [1 for o in ops if isinstance(o, AdvancedSubtensor)]

        fgraph = FunctionGraph([x, y], [grad(expr, x)])
        optdb.query(OPT_FAST_RUN).optimize(fgraph)

        ops = [node.op for node in fgraph.toposort()]
        # There is an extra dimshuffle in this graph, and there is no good
        # rewrite rule to get rid of it.
        assert len(ops) == 5
        assert crossentropy_softmax_1hot_with_bias_dx in ops
        assert softmax_legacy in ops
        assert softmax_grad_legacy not in ops

    mean_bias_expressions = [
        mean(-log(softmax(x + b)[aet.arange(y.shape[0]), y])),
        -mean(log(softmax(b + x)[aet.arange(y.shape[0]), y])),
        -mean(log(softmax(x + b))[aet.arange(y.shape[0]), y]),
        mean(-log(softmax(b + x))[aet.arange(y.shape[0]), y]),
    ]

    for expr in mean_bias_expressions:
        fgraph = FunctionGraph([x, b, y], [expr])
        optdb.query(OPT_FAST_RUN).optimize(fgraph)

        ops = [node.op for node in fgraph.toposort()]
        assert len(ops) == 4
        assert crossentropy_softmax_argmax_1hot_with_bias in ops
        assert not [1 for o in ops if isinstance(o, AdvancedSubtensor)]

        fgraph = FunctionGraph([x, b, y], [grad(expr, x)])
        optdb.query(OPT_FAST_RUN).optimize(fgraph)

        ops = [node.op for node in fgraph.toposort()]
        assert len(ops) == 5
        assert crossentropy_softmax_1hot_with_bias_dx in ops
        assert softmax_with_bias in ops
        assert softmax_grad_legacy not in ops
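# Not part of the original test suite: a hedged sketch of the cross-entropy /
# softmax fusion that the tests above assert on at the FunctionGraph level,
# observed here through `aesara.function`. The helper name `_demo_xent_fusion`
# is hypothetical; it reuses the module-level names already used by these tests
# (`matrix`, `lvector`, `softmax`, `log`, `aet`, `grad`, `aesara`).
def _demo_xent_fusion():
    x = matrix("x")
    y = lvector("y")
    # Negative log-likelihood written with advanced indexing; FAST_RUN is
    # expected to fuse softmax + indexing + log into
    # crossentropy_softmax_argmax_1hot_with_bias (and its dx op in the grad).
    nll = -log(softmax(x))[aet.arange(y.shape[0]), y].sum()
    f = aesara.function([x, y], [nll, grad(nll, x)])
    print([node.op for node in f.maker.fgraph.toposort()])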