def test_batch_normalization_broadcastable():
    # check if the broadcastable pattern is preserved by the optimizations
    x, dy, scale, bias, mean, var = (
        scalar(n).dimshuffle(["x"] * 5)
        for n in ("x", "dy", "scale", "bias", "mean", "var")
    )

    # forward pass
    out_train, x_mean, x_invstd = batchnorm.batch_normalization_train(
        x, scale, bias, "spatial"
    )
    out_test = batchnorm.batch_normalization_test(x, scale, bias, mean, var, "spatial")
    # backward pass
    grads_train = aet.grad(None, wrt=[x, scale, bias], known_grads={out_train: dy})
    grads_test = aet.grad(None, wrt=[x, scale, bias], known_grads={out_test: dy})
    # compile
    f = aesara.function(
        [x, scale, bias, mean, var, dy],
        [out_train, x_mean, x_invstd, out_test] + grads_train + grads_test,
    )
    assert not any(
        isinstance(
            n.op,
            (
                batchnorm.AbstractBatchNormTrain,
                batchnorm.AbstractBatchNormInference,
                batchnorm.AbstractBatchNormTrainGrad,
            ),
        )
        for n in f.maker.fgraph.toposort()
    )

def test_batch_normalization_train_without_running_averages():
    # compile and run batch_normalization_train without running averages
    utt.seed_rng()

    x, scale, bias, dy = (
        tensor4("x"),
        tensor4("scale"),
        tensor4("bias"),
        tensor4("dy"),
    )
    data_shape = (5, 10, 30, 25)
    param_shape = (1, 10, 30, 25)

    # forward pass
    out, x_mean, x_invstd = batchnorm.batch_normalization_train(
        x, scale, bias, "per-activation"
    )
    # backward pass
    grads = aet.grad(None, wrt=[x, scale, bias], known_grads={out: dy})
    # compile
    f = aesara.function([x, scale, bias, dy], [out, x_mean, x_invstd] + grads)
    # check if the abstract Ops have been replaced
    assert not any(
        isinstance(
            n.op,
            (
                batchnorm.AbstractBatchNormTrain,
                batchnorm.AbstractBatchNormInference,
                batchnorm.AbstractBatchNormTrainGrad,
            ),
        )
        for n in f.maker.fgraph.toposort()
    )
    # run
    X = 4 + 3 * np.random.randn(*data_shape).astype(aesara.config.floatX)
    Dy = -1 + 2 * np.random.randn(*data_shape).astype(aesara.config.floatX)
    Scale = np.random.randn(*param_shape).astype(aesara.config.floatX)
    Bias = np.random.randn(*param_shape).astype(aesara.config.floatX)
    f(X, Scale, Bias, Dy)

def test_batch_normalization_train_broadcast():
    for axes in ("per-activation", "spatial", (1, 2, 3, 4)):
        for vartype in (tensor5, tensor4, tensor3, matrix, vector):
            x = vartype("x")
            ndim = x.ndim
            eps = 5e-3  # some non-standard value to test if it's used
            running_average_factor = 0.3

            # remove non-existing axes
            if isinstance(axes, tuple):
                axes = tuple(i for i in axes if i < ndim)
            if len(axes) == 0:
                continue

            # convert axes to explicit list
            if axes == "per-activation":
                axes2 = (0,)
            elif axes == "spatial":
                axes2 = (0,) + tuple(range(2, ndim))
            else:
                axes2 = axes

            # compute axes for parameter tensors
            non_bc_axes = tuple(i for i in range(ndim) if i not in axes2)
            params_dimshuffle = ["x"] * ndim
            for i, axis in enumerate(non_bc_axes):
                params_dimshuffle[axis] = i

            # construct non-broadcasted parameter variables
            param_type = TensorType(x.dtype, (False,) * len(non_bc_axes))
            scale, bias, running_mean, running_var = (
                param_type(n)
                for n in ("scale", "bias", "running_mean", "running_var")
            )

            # broadcast parameter variables
            scale_bc = scale.dimshuffle(params_dimshuffle)
            bias_bc = bias.dimshuffle(params_dimshuffle)
            running_mean_bc = running_mean.dimshuffle(params_dimshuffle)
            running_var_bc = running_var.dimshuffle(params_dimshuffle)

            # batch_normalization_train with original, non-broadcasted variables
            train_non_bc = batchnorm.batch_normalization_train(
                x,
                scale,
                bias,
                axes,
                eps,
                running_average_factor,
                running_mean,
                running_var,
            )
            # batch_normalization_train with broadcasted variables
            train_bc = batchnorm.batch_normalization_train(
                x,
                scale_bc,
                bias_bc,
                axes,
                eps,
                running_average_factor,
                running_mean_bc,
                running_var_bc,
            )
            train_bc = tuple(
                [train_bc[0]]  # out
                + [r.dimshuffle(non_bc_axes) for r in train_bc[1:]]
            )

            # batch_normalization_test with original, non-broadcasted variables
            test_non_bc = batchnorm.batch_normalization_test(
                x, scale, bias, running_mean, running_var, axes, eps
            )
            # batch_normalization_test with broadcasted variables
            test_bc = batchnorm.batch_normalization_test(
                x, scale_bc, bias_bc, running_mean_bc, running_var_bc, axes, eps
            )

            # subtract the results of the non-broadcasted and broadcasted calls
            results_non_bc = train_non_bc + (test_non_bc,)
            results_bc = train_bc + (test_bc,)
            results = [abs(r - r_bc) for (r, r_bc) in zip(results_non_bc, results_bc)]

            # compile to compute all differences
            f = aesara.function(
                [x, scale, bias, running_mean, running_var], aet_sum(sum(results))
            )

            # the paired ops are exactly the same, so the optimizer should have
            # collapsed the sum of differences to a constant zero
            nodes = f.maker.fgraph.toposort()
            if aesara.config.mode != "FAST_COMPILE":
                assert len(nodes) == 1
                assert isinstance(nodes[0].op, aesara.compile.DeepCopyOp)

            inputs = [
                np.asarray(np.random.random((4,) * n), x.dtype)
                for n in [
                    x.ndim,
                    scale.ndim,
                    bias.ndim,
                    running_mean.ndim,
                    running_var.ndim,
                ]
            ]
            assert 0.0 == f(*inputs)

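# Illustrative note (not executed by the test above): an example of the parameter
# broadcast pattern built in test_batch_normalization_train_broadcast. For a
# tensor4 input with "spatial" normalization, axes2 == (0, 2, 3), so the only
# non-normalized axis is 1, non_bc_axes == (1,), and
# params_dimshuffle == ["x", 0, "x", "x"]; a vector parameter of length C is
# therefore dimshuffled to shape (1, C, 1, 1) before being passed to
# batch_normalization_train.
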
def test_batch_normalization_train():
    for axes in ("per-activation", "spatial", (1, 2, 3, 4)):
        for vartype in (tensor5, tensor3, vector):
            x, scale, bias, running_mean, running_var = (
                vartype(n)
                for n in ("x", "scale", "bias", "running_mean", "running_var")
            )
            ndim = x.ndim
            eps = 5e-3  # some non-standard value to test if it's used
            running_average_factor = 0.3

            # remove non-existing axes
            if isinstance(axes, tuple):
                axes = tuple(i for i in axes if i < ndim)
            if len(axes) == 0:
                continue

            # forward pass
            (
                out,
                x_mean,
                x_invstd,
                out_running_mean,
                out_running_var,
            ) = batchnorm.batch_normalization_train(
                x,
                scale,
                bias,
                axes,
                eps,
                running_average_factor,
                running_mean,
                running_var,
            )
            # reference forward pass
            if axes == "per-activation":
                axes2 = (0,)
            elif axes == "spatial":
                axes2 = (0,) + tuple(range(2, ndim))
            else:
                axes2 = axes
            x_mean2 = x.mean(axis=axes2, keepdims=True)
            x_var2 = x.var(axis=axes2, keepdims=True)
            x_invstd2 = aet.reciprocal(aet.sqrt(x_var2 + eps))
            scale2 = aet.addbroadcast(scale, *axes2)
            bias2 = aet.addbroadcast(bias, *axes2)
            out2 = (x - x_mean2) * (scale2 * x_invstd2) + bias2
            m = aet.cast(
                aet.prod(x.shape) / aet.prod(scale.shape), aesara.config.floatX
            )
            out_running_mean2 = (
                running_mean * (1 - running_average_factor)
                + x_mean2 * running_average_factor
            )
            out_running_var2 = (
                running_var * (1 - running_average_factor)
                + (m / (m - 1)) * x_var2 * running_average_factor
            )

            # backward pass
            dy = vartype("dy")
            grads = aet.grad(None, wrt=[x, scale, bias], known_grads={out: dy})
            # reference backward pass
            grads2 = aet.grad(None, wrt=[x, scale, bias], known_grads={out2: dy})

            # second-order backward pass
            dx = vartype("dinputs")
            dscale = vartype("dscale")
            dbias = vartype("dbias")
            grad_grads = aet.grad(
                None,
                wrt=[x, dy, scale],
                known_grads=OrderedDict(
                    {grads[0]: dx, grads[1]: dscale, grads[2]: dbias}
                ),
                consider_constant=[
                    x,
                    dy,
                    scale,
                    bias,
                    x_mean,
                    x_invstd,
                    running_mean,
                    running_var,
                ],
                return_disconnected="zero",
            )
            # reference second-order backward pass
            grad_grads2 = aet.grad(
                None,
                wrt=[x, dy, scale],
                known_grads=OrderedDict(
                    {grads2[0]: dx, grads2[1]: dscale, grads2[2]: dbias}
                ),
                consider_constant=[
                    x,
                    dy,
                    scale,
                    bias,
                    x_mean2,
                    x_var2,
                    running_mean,
                    running_var,
                ],
                return_disconnected="zero",
            )

            # compile
            f = aesara.function(
                [x, scale, bias, running_mean, running_var, dy, dx, dscale, dbias],
                [
                    out,
                    x_mean,
                    x_invstd,
                    out_running_mean,
                    out_running_var,
                    out2,
                    x_mean2,
                    x_invstd2,
                    out_running_mean2,
                    out_running_var2,
                ]
                + grads
                + grads2
                + grad_grads
                + grad_grads2,
            )

            # check if the abstract Ops have been replaced
            assert not any(
                isinstance(
                    n.op,
                    (
                        batchnorm.AbstractBatchNormTrain,
                        batchnorm.AbstractBatchNormInference,
                        batchnorm.AbstractBatchNormTrainGrad,
                    ),
                )
                for n in f.maker.fgraph.toposort()
            )

            # run
            for data_shape in (
                (5, 10, 30, 40, 10),
                (4, 3, 1, 1, 1),
                (2, 3, 5, 5, 5),
            ):
                data_shape = data_shape[:ndim]
                param_shape = tuple(
                    1 if d in axes2 else s for d, s in enumerate(data_shape)
                )

                rng = np.random.default_rng(1234)
                X = 4 + 3 * rng.random(data_shape).astype(aesara.config.floatX)
                Dy = -1 + 2 * rng.random(data_shape).astype(aesara.config.floatX)
                Scale = rng.random(param_shape).astype(aesara.config.floatX)
                Bias = rng.random(param_shape).astype(aesara.config.floatX)
                Running_mean = rng.random(param_shape).astype(aesara.config.floatX)
                Running_var = rng.random(param_shape).astype(aesara.config.floatX)
                Dx = 4 + 3 * rng.random(data_shape).astype(aesara.config.floatX)
                Dscale = -1 + 2 * rng.random(param_shape).astype(aesara.config.floatX)
                Dbias = rng.random(param_shape).astype(aesara.config.floatX)
                outputs = f(
                    X, Scale, Bias, Running_mean, Running_var, Dy, Dx, Dscale, Dbias
                )
                # compare outputs
                utt.assert_allclose(outputs[0], outputs[0 + 5])  # out
                utt.assert_allclose(outputs[1], outputs[1 + 5])  # mean
                utt.assert_allclose(outputs[2], outputs[2 + 5])  # invstd
                utt.assert_allclose(outputs[3], outputs[3 + 5])  # running_mean
                utt.assert_allclose(
                    np.nan_to_num(outputs[4]), np.nan_to_num(outputs[4 + 5])
                )  # running_var
                # compare gradients
                utt.assert_allclose(outputs[10], outputs[10 + 3], atol=1e-4)  # dx
                utt.assert_allclose(
                    outputs[11], outputs[11 + 3], rtol=2e-4, atol=1e-4
                )  # dscale
                utt.assert_allclose(outputs[12], outputs[12 + 3])  # dbias
                # compare second-order gradients
                utt.assert_allclose(outputs[16], outputs[16 + 3], atol=1e-4)  # ddx
                utt.assert_allclose(outputs[17], outputs[17 + 3])  # ddy
                utt.assert_allclose(
                    outputs[18], outputs[18 + 3], rtol=3e-4, atol=1e-4
                )  # ddscale

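# A minimal NumPy sketch (not used by the tests above, names are illustrative only)
# of the training-mode reference computation that test_batch_normalization_train
# builds symbolically: normalize over `axes`, scale and shift, and update the
# running statistics with the unbiased m / (m - 1) variance correction. Assumes
# numpy is imported as `np`, as elsewhere in this file.
def _numpy_batchnorm_train_reference(
    x, scale, bias, axes, eps, factor, running_mean, running_var
):
    mean = x.mean(axis=axes, keepdims=True)
    var = x.var(axis=axes, keepdims=True)
    invstd = 1.0 / np.sqrt(var + eps)
    out = (x - mean) * (scale * invstd) + bias
    # m is the number of elements each statistic is averaged over; the unbiased
    # correction below requires m > 1
    m = x.size / scale.size
    new_running_mean = running_mean * (1 - factor) + mean * factor
    new_running_var = running_var * (1 - factor) + (m / (m - 1)) * var * factor
    return out, mean, invstd, new_running_mean, new_running_var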