def test_batch_normalization_broadcastable():
    """Check that the broadcastable pattern is preserved by the optimizations.

    All six inputs are scalars dimshuffled up to 5 broadcastable dimensions;
    after compilation none of the abstract batchnorm Ops may remain in the
    graph.
    """
    # Scalars broadcast to 5 dimensions (every axis broadcastable).
    x, dy, scale, bias, mean, var = (
        scalar(n).dimshuffle(["x"] * 5)
        for n in ("x", "dy", "scale", "bias", "mean", "var")
    )

    # forward pass
    out_train, x_mean, x_invstd = batchnorm.batch_normalization_train(
        x, scale, bias, "spatial"
    )
    out_test = batchnorm.batch_normalization_test(
        x, scale, bias, mean, var, "spatial"
    )
    # backward pass
    grads_train = aet.grad(
        None, wrt=[x, scale, bias], known_grads={out_train: dy}
    )
    grads_test = aet.grad(
        None, wrt=[x, scale, bias], known_grads={out_test: dy}
    )
    # compile
    f = aesara.function(
        [x, scale, bias, mean, var, dy],
        [out_train, x_mean, x_invstd, out_test] + grads_train + grads_test,
    )
    # The optimizer must have replaced every abstract batchnorm Op.
    # (Generator expression instead of a throwaway list inside any().)
    assert not any(
        isinstance(
            n.op,
            (
                batchnorm.AbstractBatchNormTrain,
                batchnorm.AbstractBatchNormInference,
                batchnorm.AbstractBatchNormTrainGrad,
            ),
        )
        for n in f.maker.fgraph.toposort()
    )
def test_batch_normalization_test():
    """Compare batch_normalization_test against a hand-written reference.

    For each axes specification and input rank, builds both the batchnorm
    graph and an explicit reference expression, checks that the abstract Ops
    were optimized away, and compares outputs and all five gradients
    numerically.
    """
    for axes in ("per-activation", "spatial", (1, 2, 3, 4)):
        for vartype in (tensor5, tensor3, vector):
            x, scale, bias, mean, var = (
                vartype(n) for n in ("x", "scale", "bias", "mean", "var")
            )
            ndim = x.ndim
            eps = 5e-3  # some non-standard value to test if it's used

            # Remove non-existing axes. Filter into a local name instead of
            # rebinding the outer loop variable `axes` (the original only
            # worked because the vartypes happen to descend in rank).
            if isinstance(axes, tuple):
                test_axes = tuple(i for i in axes if i < ndim)
                if len(test_axes) == 0:
                    continue
            else:
                test_axes = axes

            # forward pass
            out = batchnorm.batch_normalization_test(
                x, scale, bias, mean, var, test_axes, eps
            )
            # reference forward pass: normalize explicitly over axes2
            if test_axes == "per-activation":
                axes2 = (0,)
            elif test_axes == "spatial":
                axes2 = (0,) + tuple(range(2, ndim))
            else:
                axes2 = test_axes
            scale2, bias2, mean2, var2 = (
                aet.addbroadcast(t, *axes2) for t in (scale, bias, mean, var)
            )
            out2 = (x - mean2) * (scale2 / aet.sqrt(var2 + eps)) + bias2
            # backward pass
            dy = vartype("dy")
            grads = aet.grad(
                None, wrt=[x, scale, bias, mean, var], known_grads={out: dy}
            )
            # reference backward pass
            grads2 = aet.grad(
                None, wrt=[x, scale, bias, mean, var], known_grads={out2: dy}
            )
            # compile
            f = aesara.function(
                [x, scale, bias, mean, var, dy], [out, out2] + grads + grads2
            )
            # check if the abstract Ops have been replaced
            assert not any(
                isinstance(
                    n.op,
                    (
                        batchnorm.AbstractBatchNormTrain,
                        batchnorm.AbstractBatchNormInference,
                        batchnorm.AbstractBatchNormTrainGrad,
                    ),
                )
                for n in f.maker.fgraph.toposort()
            )
            # run
            for data_shape in (
                (10, 20, 30, 40, 10),
                (4, 3, 1, 1, 1),
                (1, 1, 5, 5, 5),
            ):
                data_shape = data_shape[:ndim]
                # Parameters are size 1 along every normalized axis.
                param_shape = tuple(
                    1 if d in axes2 else s for d, s in enumerate(data_shape)
                )
                rng = np.random.default_rng(1234)
                X = 4 + 3 * rng.random(data_shape).astype(aesara.config.floatX)
                Dy = -1 + 2 * rng.random(data_shape).astype(aesara.config.floatX)
                Scale = rng.random(param_shape).astype(aesara.config.floatX)
                Bias = rng.random(param_shape).astype(aesara.config.floatX)
                Mean = rng.random(param_shape).astype(aesara.config.floatX)
                Var = rng.random(param_shape).astype(aesara.config.floatX)
                outputs = f(X, Scale, Bias, Mean, Var, Dy)
                # compare outputs
                utt.assert_allclose(outputs[0], outputs[1])  # out
                # compare gradients
                utt.assert_allclose(outputs[2], outputs[2 + 5], atol=4e-5)  # dx
                utt.assert_allclose(outputs[3], outputs[3 + 5], atol=4e-5)  # dscale
                utt.assert_allclose(outputs[4], outputs[4 + 5])  # dbias
                utt.assert_allclose(outputs[5], outputs[5 + 5])  # dmean
                utt.assert_allclose(
                    outputs[6], outputs[6 + 5], rtol=2e-3, atol=4e-5
                )  # dvar
def test_batch_normalization_train_broadcast():
    """Check that batchnorm gives identical graphs for broadcasted and
    non-broadcasted parameter variables.

    Builds train and test graphs with both parameter layouts, subtracts the
    paired results, and asserts the optimizer collapses the summed absolute
    difference to a constant zero.
    """
    for axes in ("per-activation", "spatial", (1, 2, 3, 4)):
        for vartype in (tensor5, tensor4, tensor3, matrix, vector):
            x = vartype("x")
            ndim = x.ndim
            eps = 5e-3  # some non-standard value to test if it's used
            running_average_factor = 0.3

            # Remove non-existing axes. Filter into a local name instead of
            # rebinding the outer loop variable `axes` (the original only
            # worked because the vartypes happen to descend in rank).
            if isinstance(axes, tuple):
                test_axes = tuple(i for i in axes if i < ndim)
                if len(test_axes) == 0:
                    continue
            else:
                test_axes = axes

            # convert axes to explicit list
            if test_axes == "per-activation":
                axes2 = (0,)
            elif test_axes == "spatial":
                axes2 = (0,) + tuple(range(2, ndim))
            else:
                axes2 = test_axes

            # compute axes for parameter tensors
            non_bc_axes = tuple(i for i in range(ndim) if i not in axes2)
            params_dimshuffle = ["x"] * ndim
            for i, axis in enumerate(non_bc_axes):
                params_dimshuffle[axis] = i

            # construct non-broadcasted parameter variables
            param_type = TensorType(x.dtype, (False,) * len(non_bc_axes))
            scale, bias, running_mean, running_var = (
                param_type(n)
                for n in ("scale", "bias", "running_mean", "running_var")
            )

            # broadcast parameter variables
            scale_bc = scale.dimshuffle(params_dimshuffle)
            bias_bc = bias.dimshuffle(params_dimshuffle)
            running_mean_bc = running_mean.dimshuffle(params_dimshuffle)
            running_var_bc = running_var.dimshuffle(params_dimshuffle)

            # batch_normalization_train with original, non-broadcasted variables
            train_non_bc = batchnorm.batch_normalization_train(
                x,
                scale,
                bias,
                test_axes,
                eps,
                running_average_factor,
                running_mean,
                running_var,
            )
            # batch_normalization_train with broadcasted variables
            train_bc = batchnorm.batch_normalization_train(
                x,
                scale_bc,
                bias_bc,
                test_axes,
                eps,
                running_average_factor,
                running_mean_bc,
                running_var_bc,
            )
            # Undo the broadcast on everything except the output itself.
            train_bc = tuple(
                [train_bc[0]]  # out
                + [r.dimshuffle(non_bc_axes) for r in train_bc[1:]]
            )

            # batch_normalization_test with original, non-broadcasted variables
            test_non_bc = batchnorm.batch_normalization_test(
                x, scale, bias, running_mean, running_var, test_axes, eps
            )
            # batch_normalization_test with broadcasted variables
            test_bc = batchnorm.batch_normalization_test(
                x,
                scale_bc,
                bias_bc,
                running_mean_bc,
                running_var_bc,
                test_axes,
                eps,
            )

            # subtract the results of the non-broadcasted and broadcasted calls
            results_non_bc = train_non_bc + (test_non_bc,)
            results_bc = train_bc + (test_bc,)
            results = [
                abs(r - r_bc) for (r, r_bc) in zip(results_non_bc, results_bc)
            ]

            # compile to compute all differences
            f = aesara.function(
                [x, scale, bias, running_mean, running_var], aet_sum(sum(results))
            )

            # the paired ops are exactly the same, so the optimizer should have
            # collapsed the sum of differences to a constant zero
            nodes = f.maker.fgraph.toposort()
            if aesara.config.mode != "FAST_COMPILE":
                assert len(nodes) == 1
                assert isinstance(nodes[0].op, aesara.compile.DeepCopyOp)

            # Evaluate on arbitrary data: the difference must be exactly zero.
            inputs = [
                np.asarray(np.random.random(((4,) * n)), x.dtype)
                for n in [
                    x.ndim,
                    scale.ndim,
                    bias.ndim,
                    running_mean.ndim,
                    running_var.ndim,
                ]
            ]
            assert 0.0 == f(*inputs)