Example 1
def local_abstract_batch_norm_train_grad(fgraph, node):
    if not isinstance(node.op, AbstractBatchNormTrainGrad):
        return None

    x, dy, scale, x_mean, x_invstd, epsilon = node.inputs
    axes = node.op.axes
    if min(axes) < 0 or max(axes) > x.ndim:
        return None
    if (not isinstance(x.type, TensorType)
            or not isinstance(dy.type, TensorType)
            or not isinstance(scale.type, TensorType)
            or not isinstance(x_mean.type, TensorType)
            or not isinstance(x_invstd.type, TensorType)
            or not isinstance(epsilon.type, TensorType)):
        return None

    x_diff = x - x_mean
    mean_dy_x_diff = mean(dy * x_diff, axis=axes, keepdims=True)
    c = (dy * x_invstd) - x_diff * (mean_dy_x_diff * (x_invstd**3))

    g_wrt_inputs = scale * (c - mean(c, axis=axes, keepdims=True))
    g_wrt_scale = aet_sum(dy * x_invstd * x_diff, axis=axes, keepdims=True)
    g_wrt_bias = aet_sum(dy, axis=axes, keepdims=True)
    results = [g_wrt_inputs, g_wrt_scale, g_wrt_bias]

    results = [
        aet.patternbroadcast(r, r_orig.broadcastable)
        for (r, r_orig) in zip(results, node.outputs)
    ]

    for var in aesara.graph.basic.vars_between(node.inputs, results):
        if var not in node.inputs:
            copy_stack_trace(node.outputs[0], var)
    return results
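The rewrite above replaces an AbstractBatchNormTrainGrad node with plain tensor operations. As a rough illustration of the formulas it emits, here is a NumPy sketch (illustrative only, not part of Aesara) that evaluates the same three gradients on concrete arrays, assuming the batch mean and inverse standard deviation are computed from x itself:

import numpy as np

def batch_norm_train_grad_reference(x, dy, scale, axes=(0,), eps=1e-5):
    # NumPy mirror of the symbolic expressions built by the rewrite.
    x_mean = x.mean(axis=axes, keepdims=True)
    x_invstd = 1.0 / np.sqrt(x.var(axis=axes, keepdims=True) + eps)

    x_diff = x - x_mean
    mean_dy_x_diff = np.mean(dy * x_diff, axis=axes, keepdims=True)
    c = (dy * x_invstd) - x_diff * (mean_dy_x_diff * (x_invstd ** 3))

    g_wrt_inputs = scale * (c - np.mean(c, axis=axes, keepdims=True))
    g_wrt_scale = np.sum(dy * x_invstd * x_diff, axis=axes, keepdims=True)
    g_wrt_bias = np.sum(dy, axis=axes, keepdims=True)
    return g_wrt_inputs, g_wrt_scale, g_wrt_bias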
Example 2
def test_GpuCrossentropySoftmaxArgmax1HotWithBias():
    # This is a basic test for GpuCrossentropySoftmaxArgmax1HotWithBias.
    # We check that we loop when there are too many threads.

    n_in = 1000
    batch_size = 4097
    n_out = 1250

    if not isinstance(mode_with_gpu, aesara.compile.debugmode.DebugMode):
        n_in = 4098
        n_out = 4099

    y = lvector("y")

    b = fvector("b")

    # We precompute the dot product with a big shape beforehand so that the
    # GpuCrossentropySoftmax1HotWithBiasDx test does not fail with the error
    # "the launch timed out and was terminated" on GPU cards that are not
    # powerful enough. We need the big shape to check the corner case.
    dot_result = fmatrix("dot_result")

    xx = np.asarray(np.random.rand(batch_size, n_in), dtype=np.float32)
    yy = np.ones((batch_size, ), dtype="int32")
    b_values = np.zeros((n_out, ), dtype="float32")
    W_values = np.asarray(np.random.rand(n_in, n_out), dtype="float32")

    dot_value = np.asarray(np.dot(xx, W_values), dtype="float32")
    del W_values
    p_y_given_x = aesara.tensor.nnet.softmax(dot_result + b)
    y_pred = argmax(p_y_given_x, axis=-1)
    loss = -mean(log(p_y_given_x)[aet.arange(y.shape[0]), y])
    dW = grad(loss, dot_result)
    classify = aesara.function(inputs=[y, b, dot_result],
                               outputs=[loss, y_pred, dW],
                               mode=mode_without_gpu)
    classify_gpu = aesara.function(inputs=[y, b, dot_result],
                                   outputs=[loss, y_pred, dW],
                                   mode=mode_with_gpu)

    assert any([
        isinstance(node.op,
                   aesara.tensor.nnet.CrossentropySoftmaxArgmax1HotWithBias)
        for node in classify.maker.fgraph.toposort()
    ])
    assert any([
        isinstance(node.op, GpuCrossentropySoftmaxArgmax1HotWithBias)
        for node in classify_gpu.maker.fgraph.toposort()
    ])

    out = classify(yy, b_values, dot_value)
    gout = classify_gpu(yy, b_values, dot_value)

    assert len(out) == len(gout) == 3
    utt.assert_allclose(out[0], gout[0])
    utt.assert_allclose(out[2], gout[2], atol=3e-6)
    utt.assert_allclose(out[1], gout[1])
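The compiled functions above reduce to a well-known closed form: the gradient of the mean negative log-likelihood with respect to the pre-softmax input is (softmax - one_hot(y)) / batch_size. A minimal NumPy sketch of what classify computes, useful as a mental model for the fused op (illustrative only, not part of the test):

import numpy as np

def xent_softmax_reference(dot_result, b, y):
    # Stabilized softmax over the biased activations.
    z = dot_result + b
    z = z - z.max(axis=1, keepdims=True)
    p = np.exp(z)
    p /= p.sum(axis=1, keepdims=True)

    n = y.shape[0]
    loss = -np.mean(np.log(p[np.arange(n), y]))
    y_pred = p.argmax(axis=1)

    # d(loss)/d(dot_result) = (softmax - one_hot(y)) / batch_size
    dW = p.copy()
    dW[np.arange(n), y] -= 1.0
    dW /= n
    return loss, y_pred, dW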
Example 3
    def setup_gpu_op(self,
                     activations,
                     labels,
                     input_length,
                     compute_grad=True):
        gpu_ctc_cost = gpu_ctc(activations, labels, input_length)
        outputs = [gpu_ctc_cost]
        if compute_grad:
            # Symbolic gradient of CTC cost
            gpu_ctc_grad = grad(mean(gpu_ctc_cost), activations)
            outputs += [gpu_ctc_grad]
        return aesara.function([], outputs, mode=mode_with_gpu)
Example 4
    def test_crossentropy_softmax_1hot_with_bias_dxcale_cost(self):
        x = matrix("x")
        y = lvector("y")
        a = scalar("a")

        def validate_grad_graph(func):
            # The graph of the gradient should not have softmaxgrad anymore
            has_cx1hotdx = False
            has_softmax = False
            has_softmaxdx = False
            for node in func.maker.fgraph.toposort():
                if node.op == crossentropy_softmax_1hot_with_bias_dx:
                    has_cx1hotdx = True
                if node.op == softmax_legacy:
                    has_softmax = True
                if node.op == softmax_grad_legacy:
                    has_softmaxdx = True

            assert has_cx1hotdx
            assert has_softmax
            assert not has_softmaxdx

        # Cases to test
        expressions = [
            a * aet_sum(-log(softmax(x)[aet.arange(y.shape[0]), y])),
            -a * aet_sum(log(softmax(x)[aet.arange(y.shape[0]), y])),
            a * (-aet_sum(log(softmax(x)[aet.arange(y.shape[0]), y]))),
            a * aet_sum(log(softmax(x)[aet.arange(y.shape[0]), y])),
            a * aet_sum(-log(softmax(x))[aet.arange(y.shape[0]), y]),
            -a * aet_sum(log(softmax(x))[aet.arange(y.shape[0]), y]),
            a * (-aet_sum(log(softmax(x))[aet.arange(y.shape[0]), y])),
            a * aet_sum(log(softmax(x))[aet.arange(y.shape[0]), y]),
            a * mean(-log(softmax(x)[aet.arange(y.shape[0]), y])),
            -a * mean(log(softmax(x)[aet.arange(y.shape[0]), y])),
            a * (-mean(log(softmax(x)[aet.arange(y.shape[0]), y]))),
            a * mean(log(softmax(x)[aet.arange(y.shape[0]), y])),
            a * mean(-log(softmax(x))[aet.arange(y.shape[0]), y]),
            -a * mean(log(softmax(x))[aet.arange(y.shape[0]), y]),
            a * (-mean(log(softmax(x))[aet.arange(y.shape[0]), y])),
            a * mean(log(softmax(x))[aet.arange(y.shape[0]), y]),
        ]

        for expr in expressions:
            fgraph = FunctionGraph([x, y, a], [expr])
            optdb.query(OPT_FAST_RUN).optimize(fgraph)

            assert 5 <= len(fgraph.toposort()) <= 10

            ops = {node.op for node in fgraph.toposort()}
            assert crossentropy_softmax_argmax_1hot_with_bias in ops
            assert softmax_legacy not in ops

            # Verify the gradient wrt x
            fgraph = FunctionGraph([x, y, a], [grad(expr, x)])
            optdb.query(OPT_FAST_RUN).optimize(fgraph)

            assert 3 <= len(fgraph.toposort()) <= 6

            ops = {node.op for node in fgraph.toposort()}
            assert crossentropy_softmax_1hot_with_bias_dx in ops
            assert softmax_legacy in ops
            assert softmax_grad_legacy not in ops

            # Verify the gradient when providing output gradient
            fgraph = FunctionGraph(
                [x, y, a], [grad(expr, x, known_grads={expr: a * x.sum()})])
            optdb.query(OPT_FAST_RUN).optimize(fgraph)

            assert 6 <= len(fgraph.toposort()) <= 8

            ops = {node.op for node in fgraph.toposort()}
            assert crossentropy_softmax_1hot_with_bias_dx in ops
            assert softmax_legacy in ops
            assert softmax_grad_legacy not in ops
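All of the expressions in this test are algebraic variants of the same negative log-likelihood, which is why a single fused op can replace them. The quantity being indexed, -log(softmax(x))[i, y_i], equals logsumexp(x[i]) - x[i, y_i]; a short NumPy check of that identity (illustrative only):

import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((5, 7))
y = rng.integers(0, 7, size=5)

m = x.max(axis=1, keepdims=True)
softmax = np.exp(x - m)
softmax /= softmax.sum(axis=1, keepdims=True)
lhs = -np.log(softmax[np.arange(5), y])

logsumexp = np.log(np.exp(x - m).sum(axis=1)) + m.ravel()
rhs = logsumexp - x[np.arange(5), y]

assert np.allclose(lhs, rhs)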
Example 5
    def test_get_rid_of_advanced_indexing_version_of_xent(self):
        x = matrix("x")
        b = vector("b")
        y = lvector("y")

        # Basic case
        expressions = [
            aet_sum(-log(softmax(x)[aet.arange(y.shape[0]), y])),
            -aet_sum(log(softmax(x)[aet.arange(y.shape[0]), y])),
            -aet_sum(log(softmax(x))[aet.arange(y.shape[0]), y]),
            aet_sum(-log(softmax(x))[aet.arange(y.shape[0]), y]),
        ]
        for expr in expressions:

            fgraph = FunctionGraph([x, y], [expr])
            optdb.query(OPT_FAST_RUN).optimize(fgraph)

            ops = [node.op for node in fgraph.toposort()]
            assert len(ops) == 4
            assert crossentropy_softmax_argmax_1hot_with_bias in ops
            assert not [1 for o in ops if isinstance(o, AdvancedSubtensor)]

            # Also verify the gradient wrt x
            fgraph = FunctionGraph([x, y], [grad(expr, x)])
            optdb.query(OPT_FAST_RUN).optimize(fgraph)

            ops = [node.op for node in fgraph.toposort()]
            assert len(ops) == 2
            assert crossentropy_softmax_1hot_with_bias_dx in ops
            assert softmax_legacy in ops
            assert softmax_grad_legacy not in ops

        # Test that a biased softmax is optimized correctly
        bias_expressions = [
            aet_sum(-log(softmax(x + b)[aet.arange(y.shape[0]), y])),
            -aet_sum(log(softmax(b + x)[aet.arange(y.shape[0]), y])),
            -aet_sum(log(softmax(x + b))[aet.arange(y.shape[0]), y]),
            aet_sum(-log(softmax(b + x))[aet.arange(y.shape[0]), y]),
        ]

        for expr in bias_expressions:
            fgraph = FunctionGraph([x, b, y], [expr, x])
            optdb.query(OPT_FAST_RUN).optimize(fgraph)

            ops = [node.op for node in fgraph.toposort()]
            assert len(ops) == 2  # [big_op, sum]
            assert crossentropy_softmax_argmax_1hot_with_bias in ops

            fgraph = FunctionGraph([x, b, y], [grad(expr, x)])
            optdb.query(OPT_FAST_RUN).optimize(fgraph)

            ops = [node.op for node in fgraph.toposort()]
            assert len(ops) == 2
            assert crossentropy_softmax_1hot_with_bias_dx in ops
            assert softmax_with_bias in ops
            assert softmax_grad_legacy not in ops

        # Test that using "mean" instead of sum works, too
        mean_expressions = [
            mean(-log(softmax(x)[aet.arange(y.shape[0]), y])),
            -mean(log(softmax(x)[aet.arange(y.shape[0]), y])),
            -mean(log(softmax(x))[aet.arange(y.shape[0]), y]),
            mean(-log(softmax(x))[aet.arange(y.shape[0]), y]),
        ]

        for expr in mean_expressions:

            fgraph = FunctionGraph([x, y], [expr])
            optdb.query(OPT_FAST_RUN).optimize(fgraph)

            ops = [node.op for node in fgraph.toposort()]
            assert len(ops) == 6
            assert crossentropy_softmax_argmax_1hot_with_bias in ops
            assert not [1 for o in ops if isinstance(o, AdvancedSubtensor)]

            fgraph = FunctionGraph([x, y], [grad(expr, x)])
            optdb.query(OPT_FAST_RUN).optimize(fgraph)

            ops = [node.op for node in fgraph.toposort()]
            assert len(ops) == 5
            # there's an extra dimshuffle in there
            # but I can't think of a good rule to get rid of it
            assert crossentropy_softmax_1hot_with_bias_dx in ops
            assert softmax_legacy in ops
            assert softmax_grad_legacy not in ops

        mean_bias_expressions = [
            mean(-log(softmax(x + b)[aet.arange(y.shape[0]), y])),
            -mean(log(softmax(b + x)[aet.arange(y.shape[0]), y])),
            -mean(log(softmax(x + b))[aet.arange(y.shape[0]), y]),
            mean(-log(softmax(b + x))[aet.arange(y.shape[0]), y]),
        ]

        for expr in mean_bias_expressions:

            fgraph = FunctionGraph([x, b, y], [expr])
            optdb.query(OPT_FAST_RUN).optimize(fgraph)

            ops = [node.op for node in fgraph.toposort()]
            assert len(ops) == 4
            assert crossentropy_softmax_argmax_1hot_with_bias in ops
            assert not [1 for o in ops if isinstance(o, AdvancedSubtensor)]

            fgraph = FunctionGraph([x, b, y], [grad(expr, x)])
            optdb.query(OPT_FAST_RUN).optimize(fgraph)

            ops = [node.op for node in fgraph.toposort()]
            assert len(ops) == 5
            assert crossentropy_softmax_1hot_with_bias_dx in ops
            assert softmax_with_bias in ops
            assert softmax_grad_legacy not in ops
Example 6
    def grad(self, inp, grads):
        x, dy, scale, x_mean, x_invstd, epsilon = inp
        ddinputs, ddscale, ddbias = grads

        x_diff = x - x_mean
        mean_dy_x_diff = mean(dy * x_diff, axis=self.axes, keepdims=True)

        # compute gradients given each of the output gradients
        g_wrt_x = 0
        g_wrt_dy = 0
        g_wrt_scale = 0
        g_wrt_x_mean = 0
        g_wrt_x_invstd = 0

        if not isinstance(ddinputs.type, aesara.gradient.DisconnectedType):
            ccc = scale * (ddinputs - mean(ddinputs, axis=self.axes, keepdims=True))
            ddd = (x_invstd ** 3) * (
                ccc * mean(dy * x_diff, axis=self.axes, keepdims=True)
                + dy * mean(ccc * x_diff, axis=self.axes, keepdims=True)
            )

            g_wrt_x = g_wrt_x - ddd
            g_wrt_dy = g_wrt_dy + (
                (ccc * x_invstd)
                - (
                    (x_invstd ** 3)
                    * x_diff
                    * mean(ccc * x_diff, axis=self.axes, keepdims=True)
                )
            )

            eee = (dy * x_invstd) - ((x_invstd ** 3) * x_diff * mean_dy_x_diff)
            g_wrt_scale = g_wrt_scale + tt_sum(
                ddinputs * (eee - mean(eee, axis=self.axes, keepdims=True)),
                axis=self.axes,
                keepdims=True,
            )

            g_wrt_x_mean = g_wrt_x_mean + tt_sum(ddd, axis=self.axes, keepdims=True)
            g_wrt_x_invstd = g_wrt_x_invstd + tt_sum(
                ccc * (dy - 3 * (x_invstd ** 2) * x_diff * mean_dy_x_diff),
                axis=self.axes,
                keepdims=True,
            )

        if not isinstance(ddscale.type, aesara.gradient.DisconnectedType):
            g_wrt_x = g_wrt_x + (x_invstd * ddscale * dy)
            g_wrt_dy = g_wrt_dy + (x_invstd * ddscale * x_diff)
            g_wrt_x_mean = g_wrt_x_mean - (
                x_invstd * ddscale * tt_sum(dy, axis=self.axes, keepdims=True)
            )
            g_wrt_x_invstd = g_wrt_x_invstd + (
                ddscale * tt_sum(dy * x_diff, axis=self.axes, keepdims=True)
            )

        if not isinstance(ddbias.type, aesara.gradient.DisconnectedType):
            g_wrt_dy = g_wrt_dy + aet.fill(dy, ddbias)

        # depending on which output gradients are given,
        # some inputs should be disconnected
        results = [
            g_wrt_x,
            g_wrt_dy,
            g_wrt_scale,
            g_wrt_x_mean,
            g_wrt_x_invstd,
            aesara.gradient.DisconnectedType()(),
        ]
        return [
            aesara.gradient.DisconnectedType()() if (type(r) == int and r == 0) else r
            for r in results
        ]