Example #1
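All of the snippets below come from Aesara's test suite for the softmax and cross-entropy graph rewrites and assume the test module's imports are already in scope. The following is a minimal sketch of those imports; the exact module paths are an assumption and have moved between Aesara releases, so adjust them to your version. The check_stack_trace helper comes from Aesara's rewrite/test utilities and is omitted here because its location is version-dependent.

# Representative imports assumed by these snippets (module paths are an
# assumption and may differ across Aesara versions).
import aesara
import aesara.tensor as aet
from aesara.compile.mode import OPT_FAST_RUN, optdb
from aesara.gradient import grad
from aesara.graph.fg import FunctionGraph
from aesara.tensor import add, exp, log, sigmoid, tanh, true_div
from aesara.tensor import sum as aet_sum
from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from aesara.tensor.math import Argmax, argmax, max_and_argmax, mean
from aesara.tensor.nnet.basic import (
    Softmax,
    SoftmaxWithBias,
    crossentropy_categorical_1hot,
    crossentropy_softmax_1hot_with_bias_dx,
    crossentropy_softmax_argmax_1hot_with_bias,
    softmax,
    softmax_grad_legacy,
    softmax_graph,
    softmax_legacy,
    softmax_with_bias,
)
from aesara.tensor.subtensor import AdvancedSubtensor
from aesara.tensor.type import lvector, matrix, scalar, vector
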
def test_argmax_pushdown():
    x = matrix()
    for sm in [softmax_graph, softmax_legacy]:
        # test that the max_and_argmax is pushed down if the max is not used
        out = max_and_argmax(sm(exp(tanh(sigmoid(x)))), axis=-1)[1]
        fgraph = FunctionGraph([x], [out])
        optdb.query(OPT_FAST_RUN).optimize(fgraph)

        # After the rewrite only an Argmax node should remain: the softmax
        # and the monotonic elemwise chain feeding it are removed because
        # only the argmax is used.
        assert len(fgraph.toposort()) == 1
        assert isinstance(fgraph.toposort()[0].op, Argmax)
        assert check_stack_trace(fgraph, ops_to_check=Argmax)
        x = matrix()
        # test that the max_and_argmax is not pushed down if the max is used
        out = max_and_argmax(sm(exp(tanh(sigmoid(x)))), axis=-1)[0]
        fgraph = FunctionGraph([x], [out])

        assert hasattr(fgraph.outputs[0].tag, "trace")

        optdb.query(OPT_FAST_RUN).optimize(fgraph)

        # When the max value itself is used, the softmax cannot be removed:
        # the rewritten graph is Elemwise -> Softmax -> CAReduce(maximum).
        assert len(fgraph.toposort()) == 3
        assert isinstance(fgraph.toposort()[0].op, Elemwise)
        assert isinstance(fgraph.toposort()[1].op, Softmax)
        assert isinstance(fgraph.toposort()[2].op, CAReduce)
        assert isinstance(fgraph.toposort()[2].op.scalar_op,
                          aesara.scalar.ScalarMaximum)
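
The same pushdown can also be observed on a compiled function, since mode="FAST_RUN" applies the same rewrite database as optdb.query(OPT_FAST_RUN). A minimal sketch, not part of the test suite, assuming the imports listed under Example #1:

import numpy as np

# Sketch: after FAST_RUN compilation only an Argmax node should be needed
# when the max value itself is not used.
x = matrix()
out = max_and_argmax(softmax_graph(exp(tanh(sigmoid(x)))), axis=-1)[1]
f = aesara.function([x], out, mode="FAST_RUN")

assert any(isinstance(node.op, Argmax) for node in f.maker.fgraph.toposort())
print(f(np.random.random((3, 5)).astype(aesara.config.floatX)))
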
Example #2
def test_argmax_pushdown_bias():
    x = matrix()
    b = vector()

    out = argmax(softmax_with_bias(x, b), axis=-1)
    fgraph = FunctionGraph([x, b], [out])

    optdb.query(OPT_FAST_RUN).optimize(fgraph)

    types_to_check = (DimShuffle, Elemwise, Argmax)
    assert len(fgraph.toposort()) == 3

    for i, type_ in enumerate(types_to_check):
        assert isinstance(fgraph.toposort()[i].op, type_)
    assert check_stack_trace(fgraph, ops_to_check=types_to_check)

    x = matrix()
    b = vector()
    out = max_and_argmax(softmax_with_bias(x, b), axis=-1)[0]
    fgraph = FunctionGraph([x, b], [out])

    optdb.query(OPT_FAST_RUN).optimize(fgraph)

    assert len(fgraph.toposort()) == 2
    assert isinstance(fgraph.toposort()[0].op, SoftmaxWithBias)
    assert isinstance(fgraph.toposort()[1].op, CAReduce)
    assert isinstance(fgraph.toposort()[1].op.scalar_op,
                      aesara.scalar.ScalarMaximum)
    assert check_stack_trace(fgraph, ops_to_check=(SoftmaxWithBias, CAReduce))
Example #3
    def test_xent_thing_int32(self):
        x = matrix("x")
        y = lvector("y")
        yi = aet.cast(y, "int32")
        expressions = [
            aet_sum(-log(softmax(x)[aet.arange(yi.shape[0]), yi])),
            -aet_sum(log(softmax(x)[aet.arange(yi.shape[0]), yi])),
            -aet_sum(log(softmax(x))[aet.arange(yi.shape[0]), yi]),
            aet_sum(-log(softmax(x))[aet.arange(yi.shape[0]), yi]),
        ]

        for expr in expressions:
            fgraph = FunctionGraph([x, y], [expr])
            optdb.query(OPT_FAST_RUN).optimize(fgraph)

            ops = [node.op for node in fgraph.toposort()]
            assert len(ops) == 5
            assert crossentropy_softmax_argmax_1hot_with_bias in ops
            assert not [1 for o in ops if isinstance(o, AdvancedSubtensor)]

            # Also verify the gradient wrt x
            fgraph = FunctionGraph([x, y], [grad(expr, x)])
            optdb.query(OPT_FAST_RUN).optimize(fgraph)

            ops = [node.op for node in fgraph.toposort()]
            assert len(ops) == 3
            assert crossentropy_softmax_1hot_with_bias_dx in ops
            assert softmax_legacy in ops
            assert softmax_grad_legacy not in ops
Example #4
    def test_softmax_optimizations(self):
        x = matrix("x")
        one_of_n = lvector("one_of_n")
        op = crossentropy_categorical_1hot

        fgraph = FunctionGraph([x, one_of_n],
                               [op(softmax_legacy(x), one_of_n)])
        assert fgraph.outputs[0].owner.op == op

        optdb.query(OPT_FAST_RUN).optimize(fgraph)
        assert (
            fgraph.outputs[0].owner.op == crossentropy_softmax_argmax_1hot_with_bias
        )
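
Example #4 shows the rewrite replacing crossentropy_categorical_1hot(softmax(x), y) with the fused crossentropy_softmax_argmax_1hot_with_bias op. As a numerical sanity check (not part of the test suite), the compiled graph can be compared against a plain NumPy reference of the per-row cross-entropy -log(softmax(x))[i, y_i]; a hedged sketch, assuming the imports under Example #1:

import numpy as np

# Sketch: the FAST_RUN-compiled graph (which ends up using the fused op)
# should agree with a NumPy reference of -log(softmax(x))[i, y_i].
x = matrix("x")
one_of_n = lvector("one_of_n")
f = aesara.function(
    [x, one_of_n],
    crossentropy_categorical_1hot(softmax_legacy(x), one_of_n),
    mode="FAST_RUN",
)

rng = np.random.default_rng(0)
xv = rng.standard_normal((4, 3)).astype(aesara.config.floatX)
yv = rng.integers(0, 3, size=4)

sm = np.exp(xv - xv.max(axis=1, keepdims=True))
sm /= sm.sum(axis=1, keepdims=True)
ref = -np.log(sm[np.arange(4), yv])
assert np.allclose(f(xv, yv), ref, atol=1e-5)
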
Example #5
    def test_softmax_optimizations_w_bias2(self):
        x = matrix("x")
        b = vector("b")
        c = vector("c")
        one_of_n = lvector("one_of_n")
        op = crossentropy_categorical_1hot

        fgraph = FunctionGraph([x, b, c, one_of_n],
                               [op(softmax_legacy(add(x, b, c)), one_of_n)])
        assert fgraph.outputs[0].owner.op == op

        optdb.query(OPT_FAST_RUN).optimize(fgraph)

        assert len(fgraph.toposort()) == 2
        assert (
            fgraph.outputs[0].owner.op == crossentropy_softmax_argmax_1hot_with_bias
        )
Example #6
    def test_softmax_grad_optimizations(self):
        x = matrix("x")
        one_of_n = lvector("one_of_n")
        op = crossentropy_categorical_1hot
        xe = op(softmax_legacy(x), one_of_n)
        sum_xe = aet_sum(xe)
        g_x = grad(sum_xe, x)
        fgraph = FunctionGraph([x, one_of_n], [g_x])
        assert check_stack_trace(
            fgraph,
            ops_to_check=[
                crossentropy_softmax_1hot_with_bias_dx, softmax_legacy
            ],
        )

        optdb.query(OPT_FAST_RUN).optimize(fgraph)

        ops = {node.op for node in fgraph.toposort()}
        assert crossentropy_softmax_argmax_1hot_with_bias not in ops
        assert crossentropy_softmax_1hot_with_bias_dx in ops
        assert softmax_legacy in ops
        assert softmax_grad_legacy not in ops
Example #7
    def test_logsoftmax_grad_true_div_elemwise(self):
        # Check that the log(softmax) gradient rewrite does not fire when the
        # elemwise op feeding softmax_grad is not true_div (here it is
        # replaced by add): softmax_grad_legacy must remain in the graph.

        x = matrix("x")
        y = log(softmax(x))
        g = grad(y.sum(), x)

        softmax_grad_node = g.owner
        assert softmax_grad_node.op == softmax_grad_legacy
        true_div_node = softmax_grad_node.inputs[0].owner
        assert true_div_node.op == true_div

        # We replace the elemwise true_div op by an elemwise add.
        new_g = softmax_grad_legacy(add(*true_div_node.inputs),
                                    softmax_grad_node.inputs[1])

        fgraph = FunctionGraph([x], [new_g])
        optdb.query(OPT_FAST_RUN).optimize(fgraph)

        assert softmax_grad_legacy in [n.op for n in fgraph.toposort()]
Example #8
    def test_crossentropy_softmax_1hot_with_bias_dxcale_cost(self):
        x = matrix("x")
        y = lvector("y")
        a = scalar("a")

        def validate_grad_graph(func):
            # The gradient graph should no longer contain softmax_grad_legacy
            has_cx1hotdx = False
            has_softmax = False
            has_softmaxdx = False
            for node in func.maker.fgraph.toposort():
                if node.op == crossentropy_softmax_1hot_with_bias_dx:
                    has_cx1hotdx = True
                if node.op == softmax_legacy:
                    has_softmax = True
                if node.op == softmax_grad_legacy:
                    has_softmaxdx = True

            assert has_cx1hotdx
            assert has_softmax
            assert not has_softmaxdx

        # Cases to test
        expressions = [
            a * aet_sum(-log(softmax(x)[aet.arange(y.shape[0]), y])),
            -a * aet_sum(log(softmax(x)[aet.arange(y.shape[0]), y])),
            a * (-aet_sum(log(softmax(x)[aet.arange(y.shape[0]), y]))),
            a * aet_sum(log(softmax(x)[aet.arange(y.shape[0]), y])),
            a * aet_sum(-log(softmax(x))[aet.arange(y.shape[0]), y]),
            -a * aet_sum(log(softmax(x))[aet.arange(y.shape[0]), y]),
            a * (-aet_sum(log(softmax(x))[aet.arange(y.shape[0]), y])),
            a * aet_sum(log(softmax(x))[aet.arange(y.shape[0]), y]),
            a * mean(-log(softmax(x)[aet.arange(y.shape[0]), y])),
            -a * mean(log(softmax(x)[aet.arange(y.shape[0]), y])),
            a * (-mean(log(softmax(x)[aet.arange(y.shape[0]), y]))),
            a * mean(log(softmax(x)[aet.arange(y.shape[0]), y])),
            a * mean(-log(softmax(x))[aet.arange(y.shape[0]), y]),
            -a * mean(log(softmax(x))[aet.arange(y.shape[0]), y]),
            a * (-mean(log(softmax(x))[aet.arange(y.shape[0]), y])),
            a * mean(log(softmax(x))[aet.arange(y.shape[0]), y]),
        ]

        for expr in expressions:
            fgraph = FunctionGraph([x, y, a], [expr])
            optdb.query(OPT_FAST_RUN).optimize(fgraph)

            assert 5 <= len(fgraph.toposort()) <= 10

            ops = {node.op for node in fgraph.toposort()}
            assert crossentropy_softmax_argmax_1hot_with_bias in ops
            assert softmax_legacy not in ops

            # Verify the gradient wrt x
            fgraph = FunctionGraph([x, y, a], [grad(expr, x)])
            optdb.query(OPT_FAST_RUN).optimize(fgraph)

            assert 3 <= len(fgraph.toposort()) <= 6

            ops = {node.op for node in fgraph.toposort()}
            assert crossentropy_softmax_1hot_with_bias_dx in ops
            assert softmax_legacy in ops
            assert softmax_grad_legacy not in ops

            # Verify the gradient wrt x when an output gradient is supplied
            # via known_grads
            fgraph = FunctionGraph(
                [x, y, a], [grad(expr, x, known_grads={expr: a * x.sum()})])
            optdb.query(OPT_FAST_RUN).optimize(fgraph)

            assert 6 <= len(fgraph.toposort()) <= 8

            ops = {node.op for node in fgraph.toposort()}
            assert crossentropy_softmax_1hot_with_bias_dx in ops
            assert softmax_legacy in ops
            assert softmax_grad_legacy not in ops
Example #9
    def test_get_rid_of_advanced_indexing_version_of_xent(self):
        x = matrix("x")
        b = vector("b")
        y = lvector("y")

        # Basic case
        expressions = [
            aet_sum(-log(softmax(x)[aet.arange(y.shape[0]), y])),
            -aet_sum(log(softmax(x)[aet.arange(y.shape[0]), y])),
            -aet_sum(log(softmax(x))[aet.arange(y.shape[0]), y]),
            aet_sum(-log(softmax(x))[aet.arange(y.shape[0]), y]),
        ]
        for expr in expressions:
            fgraph = FunctionGraph([x, y], [expr])
            optdb.query(OPT_FAST_RUN).optimize(fgraph)

            ops = [node.op for node in fgraph.toposort()]
            assert len(ops) == 4
            assert crossentropy_softmax_argmax_1hot_with_bias in ops
            assert not [1 for o in ops if isinstance(o, AdvancedSubtensor)]

            # Also verify the gradient wrt x
            fgraph = FunctionGraph([x, y], [grad(expr, x)])
            optdb.query(OPT_FAST_RUN).optimize(fgraph)

            ops = [node.op for node in fgraph.toposort()]
            assert len(ops) == 2
            assert crossentropy_softmax_1hot_with_bias_dx in ops
            assert softmax_legacy in ops
            assert softmax_grad_legacy not in ops

        # Test that a biased softmax is optimized correctly
        bias_expressions = [
            aet_sum(-log(softmax(x + b)[aet.arange(y.shape[0]), y])),
            -aet_sum(log(softmax(b + x)[aet.arange(y.shape[0]), y])),
            -aet_sum(log(softmax(x + b))[aet.arange(y.shape[0]), y]),
            aet_sum(-log(softmax(b + x))[aet.arange(y.shape[0]), y]),
        ]

        for expr in bias_expressions:
            fgraph = FunctionGraph([x, b, y], [expr, x])
            optdb.query(OPT_FAST_RUN).optimize(fgraph)

            ops = [node.op for node in fgraph.toposort()]
            assert len(ops) == 2  # [big_op, sum]
            assert crossentropy_softmax_argmax_1hot_with_bias in ops

            fgraph = FunctionGraph([x, b, y], [grad(expr, x)])
            optdb.query(OPT_FAST_RUN).optimize(fgraph)

            ops = [node.op for node in fgraph.toposort()]
            assert len(ops) == 2
            assert crossentropy_softmax_1hot_with_bias_dx in ops
            assert softmax_with_bias in ops
            assert softmax_grad_legacy not in ops

        # Test that using "mean" instead of sum works, too
        mean_expressions = [
            mean(-log(softmax(x)[aet.arange(y.shape[0]), y])),
            -mean(log(softmax(x)[aet.arange(y.shape[0]), y])),
            -mean(log(softmax(x))[aet.arange(y.shape[0]), y]),
            mean(-log(softmax(x))[aet.arange(y.shape[0]), y]),
        ]

        for expr in mean_expressions:
            fgraph = FunctionGraph([x, y], [expr])
            optdb.query(OPT_FAST_RUN).optimize(fgraph)

            ops = [node.op for node in fgraph.toposort()]
            assert len(ops) == 6
            assert crossentropy_softmax_argmax_1hot_with_bias in ops
            assert not [1 for o in ops if isinstance(o, AdvancedSubtensor)]

            fgraph = FunctionGraph([x, y], [grad(expr, x)])
            optdb.query(OPT_FAST_RUN).optimize(fgraph)

            ops = [node.op for node in fgraph.toposort()]
            assert len(ops) == 5
            # An extra DimShuffle remains in the gradient graph; there is
            # currently no rewrite that removes it.
            assert crossentropy_softmax_1hot_with_bias_dx in ops
            assert softmax_legacy in ops
            assert softmax_grad_legacy not in ops

        mean_bias_expressions = [
            mean(-log(softmax(x + b)[aet.arange(y.shape[0]), y])),
            -mean(log(softmax(b + x)[aet.arange(y.shape[0]), y])),
            -mean(log(softmax(x + b))[aet.arange(y.shape[0]), y]),
            mean(-log(softmax(b + x))[aet.arange(y.shape[0]), y]),
        ]

        for expr in mean_bias_expressions:
            fgraph = FunctionGraph([x, b, y], [expr])
            optdb.query(OPT_FAST_RUN).optimize(fgraph)

            ops = [node.op for node in fgraph.toposort()]
            assert len(ops) == 4
            assert crossentropy_softmax_argmax_1hot_with_bias in ops
            assert not [1 for o in ops if isinstance(o, AdvancedSubtensor)]

            fgraph = FunctionGraph([x, b, y], [grad(expr, x)])
            optdb.query(OPT_FAST_RUN).optimize(fgraph)

            ops = [node.op for node in fgraph.toposort()]
            assert len(ops) == 5
            assert crossentropy_softmax_1hot_with_bias_dx in ops
            assert softmax_with_bias in ops
            assert softmax_grad_legacy not in ops
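
When one of these structural assertions fails, printing the rewritten graph is usually the quickest way to see which form the expression was left in. A short sketch (not part of the test suite) using Aesara's debugprint on the mean/bias variant, assuming the imports under Example #1:

from aesara.printing import debugprint

# Sketch: inspect what the mean/bias variant is rewritten into.
x = matrix("x")
b = vector("b")
y = lvector("y")
expr = mean(-log(softmax(x + b)[aet.arange(y.shape[0]), y]))

fgraph = FunctionGraph([x, b, y], [expr])
optdb.query(OPT_FAST_RUN).optimize(fgraph)

# The printout shows the fused crossentropy_softmax_argmax_1hot_with_bias
# node if the rewrite fired, or the original log/softmax/AdvancedSubtensor
# chain if it did not.
debugprint(fgraph.outputs[0])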