Example #1
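None of these snippets include their imports; they appear to be taken from Aesara's batch-normalization test suite, so a preamble along the following lines is assumed throughout (module paths as in Aesara 2.x, before aesara.tensor.nnet was deprecated; adjust to the version in use):

from collections import OrderedDict

import numpy as np

import aesara
import aesara.tensor as aet
from aesara.tensor import sum as aet_sum
from aesara.tensor.nnet import batchnorm
from aesara.tensor.type import (
    TensorType,
    matrix,
    scalar,
    tensor3,
    tensor4,
    tensor5,
    vector,
)
from tests import unittest_tools as utt  # test helpers shipped with the Aesara repository
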
def test_batch_normalization_broadcastable():
    # check if the broadcastable pattern is preserved by the optimizations
    x, dy, scale, bias, mean, var = (scalar(n).dimshuffle(["x"] * 5)
                                     for n in ("x", "dy", "scale", "bias",
                                               "mean", "var"))

    # forward pass
    out_train, x_mean, x_invstd = batchnorm.batch_normalization_train(
        x, scale, bias, "spatial")
    out_test = batchnorm.batch_normalization_test(x, scale, bias, mean, var,
                                                  "spatial")
    # backward pass
    grads_train = aet.grad(None,
                           wrt=[x, scale, bias],
                           known_grads={out_train: dy})
    grads_test = aet.grad(None,
                          wrt=[x, scale, bias],
                          known_grads={out_test: dy})
    # compile
    f = aesara.function(
        [x, scale, bias, mean, var, dy],
        [out_train, x_mean, x_invstd, out_test] + grads_train + grads_test,
    )
    assert not any([
        isinstance(
            n.op,
            (
                batchnorm.AbstractBatchNormTrain,
                batchnorm.AbstractBatchNormInference,
                batchnorm.AbstractBatchNormTrainGrad,
            ),
        ) for n in f.maker.fgraph.toposort()
    ])
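
If the assertion ever fails, it helps to see what the rewritten graph actually contains. A minimal inspection sketch, reusing the compiled function f from the example above (the output of debugprint varies between versions):

# List the op types left in the optimized graph; after the batch-norm rewrites
# run, none of them should be the Abstract* placeholder ops asserted against above.
op_names = [type(node.op).__name__ for node in f.maker.fgraph.toposort()]
print(op_names)

# Full structural dump of the compiled graph, if more detail is needed:
aesara.printing.debugprint(f)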
Example #2
def test_batch_normalization_train_without_running_averages():
    # compile and run batch_normalization_train without running averages
    utt.seed_rng()

    x, scale, bias, dy = (
        tensor4("x"),
        tensor4("scale"),
        tensor4("bias"),
        tensor4("dy"),
    )
    data_shape = (5, 10, 30, 25)
    param_shape = (1, 10, 30, 25)

    # forward pass
    out, x_mean, x_invstd = batchnorm.batch_normalization_train(
        x, scale, bias, "per-activation"
    )
    # backward pass
    grads = aet.grad(None, wrt=[x, scale, bias], known_grads={out: dy})
    # compile
    f = aesara.function([x, scale, bias, dy], [out, x_mean, x_invstd] + grads)
    # check if the abstract Ops have been replaced
    assert not any(
        [
            isinstance(
                n.op,
                (
                    batchnorm.AbstractBatchNormTrain,
                    batchnorm.AbstractBatchNormInference,
                    batchnorm.AbstractBatchNormTrainGrad,
                ),
            )
            for n in f.maker.fgraph.toposort()
        ]
    )
    # run
    X = 4 + 3 * np.random.randn(*data_shape).astype(aesara.config.floatX)
    Dy = -1 + 2 * np.random.randn(*data_shape).astype(aesara.config.floatX)
    Scale = np.random.randn(*param_shape).astype(aesara.config.floatX)
    Bias = np.random.randn(*param_shape).astype(aesara.config.floatX)
    f(X, Scale, Bias, Dy)
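
As a sanity check on the forward outputs, the batch statistics can be recomputed in plain NumPy. This is only an illustrative sketch: the epsilon is an assumption here, since the example relies on batch_normalization_train's default value, which is not shown above.

# Hypothetical numerical cross-check of the forward pass compiled above.
eps = 1e-4  # assumed -- match the epsilon actually used by batch_normalization_train
out_np, mean_np = f(X, Scale, Bias, Dy)[:2]

# "per-activation" normalizes over the batch axis only (axis 0).
ref_mean = X.mean(axis=0, keepdims=True)
ref_var = X.var(axis=0, keepdims=True)
ref_out = (X - ref_mean) * (Scale / np.sqrt(ref_var + eps)) + Bias

np.testing.assert_allclose(mean_np, ref_mean, rtol=1e-4, atol=1e-4)
np.testing.assert_allclose(out_np, ref_out, rtol=1e-3, atol=1e-3)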
Example #3
def test_batch_normalization_train_broadcast():
    for axes in ("per-activation", "spatial", (1, 2, 3, 4)):
        for vartype in (tensor5, tensor4, tensor3, matrix, vector):
            x = vartype("x")
            ndim = x.ndim
            eps = 5e-3  # some non-standard value to test if it's used
            running_average_factor = 0.3

            # remove non-existing axes
            if isinstance(axes, tuple):
                axes = tuple(i for i in axes if i < ndim)
            if len(axes) == 0:
                continue

            # convert axes to explicit list
            if axes == "per-activation":
                axes2 = (0, )
            elif axes == "spatial":
                axes2 = (0, ) + tuple(range(2, ndim))
            else:
                axes2 = axes

            # compute axes for parameter tensors
            non_bc_axes = tuple(i for i in range(ndim) if i not in axes2)
            params_dimshuffle = ["x"] * ndim
            for i, axis in enumerate(non_bc_axes):
                params_dimshuffle[axis] = i

            # construct non-broadcasted parameter variables
            param_type = TensorType(x.dtype, (False, ) * len(non_bc_axes))
            scale, bias, running_mean, running_var = (param_type(n)
                                                      for n in ("scale",
                                                                "bias",
                                                                "running_mean",
                                                                "running_var"))

            # broadcast parameter variables
            scale_bc = scale.dimshuffle(params_dimshuffle)
            bias_bc = bias.dimshuffle(params_dimshuffle)
            running_mean_bc = running_mean.dimshuffle(params_dimshuffle)
            running_var_bc = running_var.dimshuffle(params_dimshuffle)

            # batch_normalization_train with original, non-broadcasted variables
            train_non_bc = batchnorm.batch_normalization_train(
                x,
                scale,
                bias,
                axes,
                eps,
                running_average_factor,
                running_mean,
                running_var,
            )
            # batch_normalization_train with broadcasted variables
            train_bc = batchnorm.batch_normalization_train(
                x,
                scale_bc,
                bias_bc,
                axes,
                eps,
                running_average_factor,
                running_mean_bc,
                running_var_bc,
            )
            train_bc = tuple(
                [train_bc[0]]  # out
                + [r.dimshuffle(non_bc_axes) for r in train_bc[1:]]
            )

            # batch_normalization_test with original, non-broadcasted variables
            test_non_bc = batchnorm.batch_normalization_test(
                x, scale, bias, running_mean, running_var, axes, eps)
            # batch_normalization_test with broadcasted variables
            test_bc = batchnorm.batch_normalization_test(
                x, scale_bc, bias_bc, running_mean_bc, running_var_bc, axes,
                eps)

            # subtract the results of the non-broadcasted and broadcasted calls
            results_non_bc = train_non_bc + (test_non_bc, )
            results_bc = train_bc + (test_bc, )
            results = [
                abs(r - r_bc) for (r, r_bc) in zip(results_non_bc, results_bc)
            ]

            # compile to compute all differences
            f = aesara.function([x, scale, bias, running_mean, running_var],
                                aet_sum(sum(results)))

            # the paired ops are exactly the same, so the optimizer should have
            # collapsed the sum of differences to a constant zero
            nodes = f.maker.fgraph.toposort()
            if aesara.config.mode != "FAST_COMPILE":
                assert len(nodes) == 1
                assert isinstance(nodes[0].op, aesara.compile.DeepCopyOp)
            inputs = [
                np.asarray(np.random.random(((4, ) * n)), x.dtype) for n in [
                    x.ndim,
                    scale.ndim,
                    bias.ndim,
                    running_mean.ndim,
                    running_var.ndim,
                ]
            ]
            assert 0.0 == f(*inputs)
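
The parameter-dimshuffle logic in the middle of this example is easiest to follow on a concrete case. A worked instance of that code (not additional API) for a 4-dimensional input with "spatial" normalization:

# Worked instance of the pattern construction above for ndim=4, axes="spatial".
ndim = 4
axes2 = (0,) + tuple(range(2, ndim))  # (0, 2, 3): reduce over batch and spatial axes
non_bc_axes = tuple(i for i in range(ndim) if i not in axes2)  # (1,): the channel axis
params_dimshuffle = ["x"] * ndim
for i, axis in enumerate(non_bc_axes):
    params_dimshuffle[axis] = i
print(params_dimshuffle)  # ['x', 0, 'x', 'x']
# i.e. a 1-d parameter vector of length C is broadcast to shape (1, C, 1, 1).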
Example #4
def test_batch_normalization_train():

    for axes in ("per-activation", "spatial", (1, 2, 3, 4)):
        for vartype in (tensor5, tensor3, vector):
            x, scale, bias, running_mean, running_var = (vartype(n) for n in (
                "x", "scale", "bias", "running_mean", "running_var"))
            ndim = x.ndim
            eps = 5e-3  # some non-standard value to test if it's used
            running_average_factor = 0.3

            # remove non-existing axes
            if isinstance(axes, tuple):
                axes = tuple(i for i in axes if i < ndim)
            if len(axes) == 0:
                continue

            # forward pass
            (
                out,
                x_mean,
                x_invstd,
                out_running_mean,
                out_running_var,
            ) = batchnorm.batch_normalization_train(
                x,
                scale,
                bias,
                axes,
                eps,
                running_average_factor,
                running_mean,
                running_var,
            )
            # reference forward pass
            if axes == "per-activation":
                axes2 = (0, )
            elif axes == "spatial":
                axes2 = (0, ) + tuple(range(2, ndim))
            else:
                axes2 = axes
            x_mean2 = x.mean(axis=axes2, keepdims=True)
            x_var2 = x.var(axis=axes2, keepdims=True)
            x_invstd2 = aet.reciprocal(aet.sqrt(x_var2 + eps))
            scale2 = aet.addbroadcast(scale, *axes2)
            bias2 = aet.addbroadcast(bias, *axes2)
            out2 = (x - x_mean2) * (scale2 * x_invstd2) + bias2
            m = aet.cast(
                aet.prod(x.shape) / aet.prod(scale.shape),
                aesara.config.floatX)
            out_running_mean2 = (running_mean * (1 - running_average_factor) +
                                 x_mean2 * running_average_factor)
            out_running_var2 = (running_var * (1 - running_average_factor) +
                                (m / (m - 1)) * x_var2 * running_average_factor)
            # backward pass
            dy = vartype("dy")
            grads = aet.grad(None, wrt=[x, scale, bias], known_grads={out: dy})
            # reference backward pass
            grads2 = aet.grad(None,
                              wrt=[x, scale, bias],
                              known_grads={out2: dy})
            # second-order backward pass
            dx = vartype("dinputs")
            dscale = vartype("dscale")
            dbias = vartype("dbias")
            grad_grads = aet.grad(
                None,
                wrt=[x, dy, scale],
                known_grads=OrderedDict({
                    grads[0]: dx,
                    grads[1]: dscale,
                    grads[2]: dbias
                }),
                consider_constant=[
                    x,
                    dy,
                    scale,
                    bias,
                    x_mean,
                    x_invstd,
                    running_mean,
                    running_var,
                ],
                return_disconnected="zero",
            )
            # reference second-order backward pass
            grad_grads2 = aet.grad(
                None,
                wrt=[x, dy, scale],
                known_grads=OrderedDict({
                    grads2[0]: dx,
                    grads2[1]: dscale,
                    grads2[2]: dbias
                }),
                consider_constant=[
                    x,
                    dy,
                    scale,
                    bias,
                    x_mean2,
                    x_var2,
                    running_mean,
                    running_var,
                ],
                return_disconnected="zero",
            )
            # compile
            f = aesara.function(
                [
                    x, scale, bias, running_mean, running_var, dy, dx, dscale,
                    dbias
                ],
                [
                    out,
                    x_mean,
                    x_invstd,
                    out_running_mean,
                    out_running_var,
                    out2,
                    x_mean2,
                    x_invstd2,
                    out_running_mean2,
                    out_running_var2,
                ] + grads + grads2 + grad_grads + grad_grads2,
            )
            # check if the abstract Ops have been replaced
            assert not any([
                isinstance(
                    n.op,
                    (
                        batchnorm.AbstractBatchNormTrain,
                        batchnorm.AbstractBatchNormInference,
                        batchnorm.AbstractBatchNormTrainGrad,
                    ),
                ) for n in f.maker.fgraph.toposort()
            ])
            # run
            for data_shape in (
                (5, 10, 30, 40, 10),
                (4, 3, 1, 1, 1),
                (2, 3, 5, 5, 5),
            ):
                data_shape = data_shape[:ndim]
                param_shape = tuple(1 if d in axes2 else s
                                    for d, s in enumerate(data_shape))

                rng = np.random.default_rng(1234)

                X = 4 + 3 * rng.random(data_shape).astype(aesara.config.floatX)
                Dy = -1 + 2 * rng.random(data_shape).astype(aesara.config.floatX)
                Scale = rng.random(param_shape).astype(aesara.config.floatX)
                Bias = rng.random(param_shape).astype(aesara.config.floatX)
                Running_mean = rng.random(param_shape).astype(aesara.config.floatX)
                Running_var = rng.random(param_shape).astype(aesara.config.floatX)
                Dx = 4 + 3 * rng.random(data_shape).astype(aesara.config.floatX)
                Dscale = -1 + 2 * rng.random(param_shape).astype(aesara.config.floatX)
                Dbias = rng.random(param_shape).astype(aesara.config.floatX)

                outputs = f(X, Scale, Bias, Running_mean, Running_var, Dy, Dx,
                            Dscale, Dbias)
                # compare outputs
                utt.assert_allclose(outputs[0], outputs[0 + 5])  # out
                utt.assert_allclose(outputs[1], outputs[1 + 5])  # mean
                utt.assert_allclose(outputs[2], outputs[2 + 5])  # invstd
                utt.assert_allclose(outputs[3], outputs[3 + 5])  # running_mean
                utt.assert_allclose(
                    np.nan_to_num(outputs[4]), np.nan_to_num(outputs[4 + 5])
                )  # running_var
                # compare gradients
                utt.assert_allclose(outputs[10], outputs[10 + 3],
                                    atol=1e-4)  # dx
                utt.assert_allclose(outputs[11],
                                    outputs[11 + 3],
                                    rtol=2e-4,
                                    atol=1e-4)  # dscale
                utt.assert_allclose(outputs[12], outputs[12 + 3])  # dbias
                # compare second-order gradients
                utt.assert_allclose(outputs[16], outputs[16 + 3],
                                    atol=1e-4)  # ddx
                utt.assert_allclose(outputs[17], outputs[17 + 3])  # ddy
                utt.assert_allclose(outputs[18],
                                    outputs[18 + 3],
                                    rtol=3e-4,
                                    atol=1e-4)  # ddscale
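
The running-average reference in this example applies the unbiased correction m / (m - 1) to the batch variance, where m is the number of elements reduced per statistic. A plain NumPy restatement of that update, mirroring out_running_mean2 and out_running_var2 above (illustrative only; the helper name is made up):

import numpy as np

def update_running_stats(running_mean, running_var, batch_mean, batch_var, factor, m):
    # m = elements averaged per statistic, e.g. np.prod(data_shape) / np.prod(param_shape)
    new_mean = running_mean * (1 - factor) + batch_mean * factor
    new_var = running_var * (1 - factor) + (m / (m - 1)) * batch_var * factor
    return new_mean, new_var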