def test_batch_normalization():
    """Compare `batchnorm.batch_normalization` against a NumPy-style reference.

    Checks both "low_mem" and "high_mem" modes, first with explicit
    mean/std inputs, then with mean/std computed symbolically from ``x``,
    and verifies gradients with ``utt.verify_grad`` in both settings.
    """

    # Pure-graph reference implementation: normalize then scale/shift.
    def bn_ref(x, G, B, M, V):
        n = (x - M) / V
        return n * G + B

    # NOTE(review): legacy global-seed API; sibling tests use
    # np.random.default_rng — kept as-is to preserve the exact fixture values.
    np.random.seed(1234)
    X = 1 + np.random.random([10, 20]).astype("float32")
    B = 1 + np.random.random([20]).astype("float32")
    G = 1 + np.random.random([20]).astype("float32")
    M = 1 + np.random.random([20]).astype("float32")
    V = 1 + np.random.random([20]).astype("float32")

    x = matrix("x")
    b = vector("b")
    g = vector("g")
    m = vector("m")
    v = vector("v")

    # Reference graph with explicit mean/std inputs.
    bn_ref_op = bn_ref(x, g, b, m, v)
    f_ref = aesara.function([x, g, b, m, v], [bn_ref_op])
    res_ref = f_ref(X, G, B, M, V)

    for mode in ["low_mem", "high_mem"]:
        bn_op = batchnorm.batch_normalization(x, g, b, m, v, mode=mode)
        f = aesara.function([x, g, b, m, v], [bn_op])
        res = f(X, G, B, M, V)
        utt.assert_allclose(res_ref, res)

        # Closure over the loop's `mode` so verify_grad exercises this mode.
        def bn_f(inputs, gamma, beta, mean, std):
            return batchnorm.batch_normalization(
                inputs, gamma, beta, mean, std, mode=mode
            )

        utt.verify_grad(bn_f, [X, G, B, M, V])

    # Second pass: mean/std are computed symbolically from `x` itself.
    bn_ref_op = bn_ref(
        x, g, b, x.mean(axis=0, keepdims=True), x.std(axis=0, keepdims=True)
    )
    # NOTE(review): input order is [x, b, g] but the call passes (X, G, B),
    # so `b` receives G and `g` receives B. The op under test below is
    # compiled and called identically, so the comparison stays consistent —
    # confirm this swap is intentional before "fixing" it.
    f_ref = aesara.function([x, b, g], [bn_ref_op])
    res_ref = f_ref(X, G, B)

    for mode in ["low_mem", "high_mem"]:
        bn_op = batchnorm.batch_normalization(
            x,
            g,
            b,
            x.mean(axis=0, keepdims=True),
            x.std(axis=0, keepdims=True),
            mode=mode,
        )
        f = aesara.function([x, b, g], [bn_op])
        res = f(X, G, B)
        utt.assert_allclose(res_ref, res)

        def bn_f(inputs, gamma, beta, mean, std):
            return batchnorm.batch_normalization(
                inputs, gamma, beta, mean, std, mode=mode
            )

        # Gradients with numerically precomputed (broadcastable) mean/std.
        utt.verify_grad(
            bn_f, [X, G, B, X.mean(axis=0)[np.newaxis], X.std(axis=0)[np.newaxis]]
        )
def bn_f(inputs, gamma, beta, mean, std):
    """Apply `batchnorm.batch_normalization` with the module-level `mode`.

    NOTE(review): ``mode`` is a free variable here — this module-level copy
    of the test-local closure only works if a global ``mode`` exists when it
    is called. It appears to be a stray duplicate of the closures defined
    inside the tests; confirm whether it is still referenced anywhere.
    """
    return batchnorm.batch_normalization(inputs, gamma, beta, mean, std, mode=mode)
def test_BNComposite():
    """Check `batch_normalization` under ``compute_test_value="raise"``.

    Attaches test values to every symbolic input so Aesara's test-value
    machinery is exercised while building the graph, then compares both
    memory modes against a NumPy-style reference graph.
    """
    with config.change_flags(compute_test_value="raise"):

        # Pure-graph reference implementation: normalize then scale/shift.
        def bn_ref(x, G, B, M, V):
            n = (x - M) / V
            return n * G + B

        rng = np.random.default_rng(1234)
        X = 1 + rng.random([10, 20]).astype("float32")
        B = 1 + rng.random([20]).astype("float32")
        G = 1 + rng.random([20]).astype("float32")
        M = 1 + rng.random([20]).astype("float32")
        V = 1 + rng.random([20]).astype("float32")

        x = matrix("x")
        b = vector("b")
        g = vector("g")
        m = vector("m")
        v = vector("v")

        # Test values (shape (2, 2)/(2,)) only need to be shape-consistent
        # for graph construction; the real data above is used at run time.
        x.tag.test_value = rng.random((2, 2)).astype(aesara.config.floatX)
        b.tag.test_value = rng.random((2)).astype(aesara.config.floatX)
        g.tag.test_value = rng.random((2)).astype(aesara.config.floatX)
        m.tag.test_value = rng.random((2)).astype(aesara.config.floatX)
        v.tag.test_value = rng.random((2)).astype(aesara.config.floatX)

        # NOTE(review): inputs are ordered [x, b, g, m, v] but called with
        # (X, G, B, M, V), swapping the values bound to `b` and `g`. Both
        # the reference and the op under test use the same ordering, so the
        # comparison is self-consistent — confirm the swap is intentional.
        bn_ref_op = bn_ref(x, g, b, m, v)
        f_ref = aesara.function([x, b, g, m, v], [bn_ref_op])
        res_ref = f_ref(X, G, B, M, V)

        for mode in ["low_mem", "high_mem"]:
            bn_op = batchnorm.batch_normalization(x, g, b, m, v, mode=mode)
            f = aesara.function([x, b, g, m, v], [bn_op])
            res = f(X, G, B, M, V)
            utt.assert_allclose(res_ref, res)
def conv_bn(inputs, gamma, beta, mean, std):
    """Per-feature-map batch normalization for 4-D (NCHW-style) inputs.

    Broadcasts the per-channel vectors across batch and spatial axes via
    ``dimshuffle("x", 0, "x", "x")`` before normalizing.

    NOTE(review): ``mode`` is a free variable — this module-level copy of
    the closure defined inside `test_bn_feature_maps` only works if a
    global ``mode`` exists at call time. It looks like a stray duplicate;
    confirm whether it is still referenced anywhere.
    """
    return batchnorm.batch_normalization(
        inputs,
        gamma.dimshuffle("x", 0, "x", "x"),
        beta.dimshuffle("x", 0, "x", "x"),
        mean.dimshuffle("x", 0, "x", "x"),
        std.dimshuffle("x", 0, "x", "x"),
        mode=mode,
    )
def test_bn_feature_maps():
    """Check per-channel `batch_normalization` on 4-D feature-map inputs.

    Per-channel vectors are broadcast over batch and spatial dimensions
    with ``dimshuffle("x", 0, "x", "x")``; both memory modes are compared
    against a NumPy-style reference graph, and gradients are verified.
    """

    # Pure-graph reference implementation: normalize then scale/shift.
    def bn_ref(x, G, B, M, V):
        n = (x - M) / V
        return n * G + B

    rng = np.random.default_rng(1234)
    X = 1 + rng.random([2, 3, 4, 4]).astype("float32")
    B = 1 + rng.random([3]).astype("float32")
    G = 1 + rng.random([3]).astype("float32")
    M = 1 + rng.random([3]).astype("float32")
    V = 1 + rng.random([3]).astype("float32")

    x = tensor4("x")
    b = vector("b")
    g = vector("g")
    m = vector("m")
    v = vector("v")

    # Reference graph with per-channel parameters broadcast to 4-D.
    bn_ref_op = bn_ref(
        x,
        g.dimshuffle("x", 0, "x", "x"),
        b.dimshuffle("x", 0, "x", "x"),
        m.dimshuffle("x", 0, "x", "x"),
        v.dimshuffle("x", 0, "x", "x"),
    )
    # NOTE(review): inputs are ordered [x, b, g, m, v] but called with
    # (X, G, B, M, V), swapping the values bound to `b` and `g`. The op
    # under test below is compiled and called the same way, so the
    # comparison stays self-consistent — confirm the swap is intentional.
    f_ref = aesara.function([x, b, g, m, v], [bn_ref_op])
    res_ref = f_ref(X, G, B, M, V)

    for mode in ["low_mem", "high_mem"]:
        bn_op = batchnorm.batch_normalization(
            x,
            g.dimshuffle("x", 0, "x", "x"),
            b.dimshuffle("x", 0, "x", "x"),
            m.dimshuffle("x", 0, "x", "x"),
            v.dimshuffle("x", 0, "x", "x"),
            mode=mode,
        )
        f = aesara.function([x, b, g, m, v], [bn_op])
        res = f(X, G, B, M, V)
        utt.assert_allclose(res_ref, res)

        # Closure over the loop's `mode` so verify_grad exercises this mode.
        def conv_bn(inputs, gamma, beta, mean, std):
            return batchnorm.batch_normalization(
                inputs,
                gamma.dimshuffle("x", 0, "x", "x"),
                beta.dimshuffle("x", 0, "x", "x"),
                mean.dimshuffle("x", 0, "x", "x"),
                std.dimshuffle("x", 0, "x", "x"),
                mode=mode,
            )

        utt.verify_grad(conv_bn, [X, G, B, M, V])