Example #1
0
def test_batch_normalization_broadcastable():
    """Check that the batch-norm optimizations handle fully broadcastable inputs.

    Every input is a scalar dimshuffled to a 5-d, all-broadcastable variable.
    The forward and backward graphs of ``batch_normalization_train`` and
    ``batch_normalization_test`` (mode ``"spatial"``) are compiled together,
    and the test asserts that none of the abstract batch-normalization Ops
    remain in the compiled graph.
    """
    # check if the broadcastable pattern is preserved by the optimizations
    x, dy, scale, bias, mean, var = (
        scalar(n).dimshuffle(["x"] * 5)
        for n in ("x", "dy", "scale", "bias", "mean", "var")
    )

    # forward pass
    out_train, x_mean, x_invstd = batchnorm.batch_normalization_train(
        x, scale, bias, "spatial"
    )
    out_test = batchnorm.batch_normalization_test(
        x, scale, bias, mean, var, "spatial"
    )
    # backward pass
    grads_train = aet.grad(
        None, wrt=[x, scale, bias], known_grads={out_train: dy}
    )
    grads_test = aet.grad(
        None, wrt=[x, scale, bias], known_grads={out_test: dy}
    )
    # compile
    f = aesara.function(
        [x, scale, bias, mean, var, dy],
        [out_train, x_mean, x_invstd, out_test] + grads_train + grads_test,
    )
    # The optimizer is expected to replace every abstract Op. Collect any
    # survivors so a failure shows which nodes were left behind instead of a
    # bare ``False``.
    abstract_ops = (
        batchnorm.AbstractBatchNormTrain,
        batchnorm.AbstractBatchNormInference,
        batchnorm.AbstractBatchNormTrainGrad,
    )
    leftover = [
        node
        for node in f.maker.fgraph.toposort()
        if isinstance(node.op, abstract_ops)
    ]
    assert not leftover, leftover
Example #2
0
def test_batch_normalization_test():
    """Compare ``batch_normalization_test`` with an explicit reference graph.

    For each normalization mode in ``("per-activation", "spatial",
    (1, 2, 3, 4))`` and each input type in ``(tensor5, tensor3, vector)``,
    the output and the gradients w.r.t. ``x``, ``scale``, ``bias``, ``mean``
    and ``var`` are compared against those of
    ``(x - mean) * (scale / sqrt(var + eps)) + bias``, with the parameters
    made broadcastable along the normalized axes. The test also checks that
    no abstract batch-normalization Op survives compilation.
    """
    # Abstract Ops that the optimizer is expected to replace.
    abstract_ops = (
        batchnorm.AbstractBatchNormTrain,
        batchnorm.AbstractBatchNormInference,
        batchnorm.AbstractBatchNormTrainGrad,
    )
    for axes in ("per-activation", "spatial", (1, 2, 3, 4)):
        for vartype in (tensor5, tensor3, vector):
            x, scale, bias, mean, var = (
                vartype(n) for n in ("x", "scale", "bias", "mean", "var")
            )
            ndim = x.ndim
            eps = 5e-3  # some non-standard value to test if it's used

            # Remove non-existing axes. Bind the result to a new name rather
            # than rebinding the outer loop variable ``axes``. Rebinding it
            # made the result depend on the order of the ``vartype``s: a
            # low-rank type visited first would truncate the axes for every
            # later one.
            if isinstance(axes, tuple):
                cur_axes = tuple(i for i in axes if i < ndim)
            else:
                cur_axes = axes
            if len(cur_axes) == 0:
                continue

            # forward pass
            out = batchnorm.batch_normalization_test(
                x, scale, bias, mean, var, cur_axes, eps
            )
            # reference forward pass: translate the mode into explicit axes
            if cur_axes == "per-activation":
                axes2 = (0,)
            elif cur_axes == "spatial":
                axes2 = (0,) + tuple(range(2, ndim))
            else:
                axes2 = cur_axes
            scale2, bias2, mean2, var2 = (
                aet.addbroadcast(t, *axes2) for t in (scale, bias, mean, var)
            )
            out2 = (x - mean2) * (scale2 / aet.sqrt(var2 + eps)) + bias2
            # backward pass
            dy = vartype("dy")
            grads = aet.grad(
                None, wrt=[x, scale, bias, mean, var], known_grads={out: dy}
            )
            # reference backward pass
            grads2 = aet.grad(
                None, wrt=[x, scale, bias, mean, var], known_grads={out2: dy}
            )
            # compile
            f = aesara.function(
                [x, scale, bias, mean, var, dy],
                [out, out2] + grads + grads2,
            )
            # check if the abstract Ops have been replaced; report survivors
            leftover = [
                node
                for node in f.maker.fgraph.toposort()
                if isinstance(node.op, abstract_ops)
            ]
            assert not leftover, leftover
            # run
            for data_shape in (
                (10, 20, 30, 40, 10),
                (4, 3, 1, 1, 1),
                (1, 1, 5, 5, 5),
            ):
                data_shape = data_shape[:ndim]
                # parameters have extent 1 along every normalized axis
                param_shape = tuple(
                    1 if d in axes2 else s for d, s in enumerate(data_shape)
                )
                rng = np.random.default_rng(1234)
                X = 4 + 3 * rng.random(data_shape).astype(aesara.config.floatX)
                Dy = -1 + 2 * rng.random(data_shape).astype(
                    aesara.config.floatX
                )
                Scale = rng.random(param_shape).astype(aesara.config.floatX)
                Bias = rng.random(param_shape).astype(aesara.config.floatX)
                Mean = rng.random(param_shape).astype(aesara.config.floatX)
                Var = rng.random(param_shape).astype(aesara.config.floatX)
                # outputs = [out, out2, *grads (5 items), *grads2 (5 items)]:
                # grads[i] is at index 2 + i and its reference at 2 + i + 5.
                outputs = f(X, Scale, Bias, Mean, Var, Dy)
                # compare outputs
                utt.assert_allclose(outputs[0], outputs[1])  # out
                # compare gradients
                utt.assert_allclose(
                    outputs[2], outputs[2 + 5], atol=4e-5
                )  # dx
                utt.assert_allclose(
                    outputs[3], outputs[3 + 5], atol=4e-5
                )  # dscale
                utt.assert_allclose(outputs[4], outputs[4 + 5])  # dbias
                utt.assert_allclose(outputs[5], outputs[5 + 5])  # dmean
                utt.assert_allclose(
                    outputs[6], outputs[6 + 5], rtol=2e-3, atol=4e-5
                )  # dvar
Example #3
0
def test_batch_normalization_train_broadcast():
    """Check that explicitly broadcast parameters yield equivalent graphs.

    ``batch_normalization_train`` and ``batch_normalization_test`` are called
    once with compact (non-broadcastable) parameter variables and once with
    the same parameters dimshuffled to the full rank of ``x``. The absolute
    differences of all paired results are summed. Outside FAST_COMPILE the
    optimizer must reduce that sum to a constant, leaving a single
    ``DeepCopyOp``. In every mode the compiled function must return exactly
    zero.
    """
    for axes in ("per-activation", "spatial", (1, 2, 3, 4)):
        for vartype in (tensor5, tensor4, tensor3, matrix, vector):
            x = vartype("x")
            ndim = x.ndim
            eps = 5e-3  # some non-standard value to test if it's used
            running_average_factor = 0.3

            # Remove non-existing axes without rebinding the outer loop
            # variable ``axes``, so the result does not depend on the order
            # in which the ``vartype``s are visited.
            if isinstance(axes, tuple):
                cur_axes = tuple(i for i in axes if i < ndim)
            else:
                cur_axes = axes
            if len(cur_axes) == 0:
                continue

            # convert axes to explicit list
            if cur_axes == "per-activation":
                axes2 = (0,)
            elif cur_axes == "spatial":
                axes2 = (0,) + tuple(range(2, ndim))
            else:
                axes2 = cur_axes

            # Parameters keep an extent only along the non-normalized axes.
            # Compact parameter dim i maps onto data axis non_bc_axes[i];
            # every normalized axis becomes a broadcast ("x") dimension.
            non_bc_axes = tuple(i for i in range(ndim) if i not in axes2)
            params_dimshuffle = ["x"] * ndim
            for i, axis in enumerate(non_bc_axes):
                params_dimshuffle[axis] = i

            # construct non-broadcasted parameter variables
            param_type = TensorType(x.dtype, (False,) * len(non_bc_axes))
            scale, bias, running_mean, running_var = (
                param_type(n)
                for n in ("scale", "bias", "running_mean", "running_var")
            )

            # broadcast parameter variables
            scale_bc = scale.dimshuffle(params_dimshuffle)
            bias_bc = bias.dimshuffle(params_dimshuffle)
            running_mean_bc = running_mean.dimshuffle(params_dimshuffle)
            running_var_bc = running_var.dimshuffle(params_dimshuffle)

            # batch_normalization_train with original, non-broadcasted variables
            train_non_bc = batchnorm.batch_normalization_train(
                x,
                scale,
                bias,
                cur_axes,
                eps,
                running_average_factor,
                running_mean,
                running_var,
            )
            # batch_normalization_train with broadcasted variables
            train_bc = batchnorm.batch_normalization_train(
                x,
                scale_bc,
                bias_bc,
                cur_axes,
                eps,
                running_average_factor,
                running_mean_bc,
                running_var_bc,
            )
            # Keep the normalized output (index 0) as is. Dimshuffle the
            # remaining, parameter-shaped outputs back to the compact
            # layout so they line up with the non-broadcasted results.
            train_bc = tuple(
                [train_bc[0]] + [r.dimshuffle(non_bc_axes) for r in train_bc[1:]]
            )

            # batch_normalization_test with original, non-broadcasted variables
            test_non_bc = batchnorm.batch_normalization_test(
                x, scale, bias, running_mean, running_var, cur_axes, eps
            )
            # batch_normalization_test with broadcasted variables
            test_bc = batchnorm.batch_normalization_test(
                x,
                scale_bc,
                bias_bc,
                running_mean_bc,
                running_var_bc,
                cur_axes,
                eps,
            )

            # subtract the results of the non-broadcasted and broadcasted calls
            results_non_bc = train_non_bc + (test_non_bc,)
            results_bc = train_bc + (test_bc,)
            results = [
                abs(r - r_bc) for (r, r_bc) in zip(results_non_bc, results_bc)
            ]

            # compile to compute all differences
            f = aesara.function(
                [x, scale, bias, running_mean, running_var],
                aet_sum(sum(results)),
            )

            # the paired ops are exactly the same, so the optimizer should have
            # collapsed the sum of differences to a constant zero
            nodes = f.maker.fgraph.toposort()
            if aesara.config.mode != "FAST_COMPILE":
                assert len(nodes) == 1
                assert isinstance(nodes[0].op, aesara.compile.DeepCopyOp)
            # Seeded generator for reproducible inputs (consistent with the
            # other batch-norm tests) instead of the global np.random state.
            rng = np.random.default_rng(1234)
            inputs = [
                np.asarray(rng.random((4,) * var.ndim), x.dtype)
                for var in (x, scale, bias, running_mean, running_var)
            ]
            assert 0.0 == f(*inputs)