def testParallelVarianceOneAtATime(self):
    """Folds rows into running statistics one observation at a time.

    The first row seeds the running statistics as a one-element set. Each
    remaining row is then merged in as a single observation (n=1, m2=0). The
    final running mean and population variance (m2 / n) must match NumPy's
    column-wise mean and variance.
    """
    x = np.random.randn(5, 10)
    # Running sufficient statistics, seeded with the first row. `m2_c` holds
    # the fourth output of `parallel_variance_calculation`. It is fed back in
    # as `m2_b_c`, matching how the other tests in this file call the
    # function, so the term is not dropped between merge steps.
    n, avg, m2, m2_c = 1, x[0], 0, 0
    for new_row in x[1:]:
        # The running statistics are passed as set `b`, because `m2_b_c` is
        # the only compensation argument used with this function in this
        # file. The new row is set `a`: one observation, zero spread.
        n, avg, m2, m2_c = tensor_normalizer.parallel_variance_calculation(
            n_a=1,
            avg_a=new_row,
            m2_a=0,
            n_b=n,
            avg_b=avg,
            m2_b=m2,
            m2_b_c=m2_c)
    var = m2 / n
    self.assertAllClose(avg, x.mean(axis=0))
    self.assertAllClose(var, x.var(axis=0))
def testParallelVarianceForOneGroup(self):
    """Merging a group with an empty group must leave its statistics intact.

    Set `b` is empty (zero count, zero mean, zero m2, zero compensation).
    The merged count, mean and variance must therefore equal those of the
    single input group.
    """
    num_rows = 5
    data = tf.constant(np.random.randn(num_rows, 10))
    mean, variance = tf.nn.moments(data, axes=[0])
    # Sum of squared deviations for the group, recovered from its variance.
    second_moment = variance * num_rows
    merged = tensor_normalizer.parallel_variance_calculation(
        num_rows, mean, second_moment,
        n_b=0, avg_b=0, m2_b=0, m2_b_c=0)
    merged_n, merged_mean, merged_m2, _ = merged
    merged_variance = merged_m2 / num_rows
    mean, variance, merged_mean, merged_variance = self.evaluate(
        (mean, variance, merged_mean, merged_variance))
    self.assertEqual(merged_n, 5)
    self.assertAllClose(merged_mean, mean)
    self.assertAllClose(merged_variance, variance)
def testParallelVarianceCombinesGroupsFromMoments(self):
    """Combining two groups summarized via `tf.nn.moments` matches the union.

    Each group's sufficient statistics (count, mean, m2 = variance * count)
    come from `tf.nn.moments`. The merged mean and population variance
    (m2 / n) must equal the moments of the concatenated data.

    Previously this method was named `testParallelVarianceCombinesGroups`.
    Another method of that exact name follows it in this file and replaced
    it at class creation, so this body never ran. The distinct name lets
    both tests be collected.
    """
    x1 = tf.constant(np.random.randn(5, 10))
    x2 = tf.constant(np.random.randn(15, 10))
    n1 = 5
    n2 = 15
    avg1, var1 = tf.nn.moments(x1, axes=[0])
    avg2, var2 = tf.nn.moments(x2, axes=[0])
    m2_1 = var1 * n1
    m2_2 = var2 * n2
    # Use the same 7-argument / 4-output call form as the other tests in this
    # file. Group `b` starts with no accumulated compensation, so its
    # compensation term is zero.
    n, avg, m2, _ = tensor_normalizer.parallel_variance_calculation(
        n1, avg1, m2_1, n2, avg2, m2_2, tf.zeros_like(m2_2))
    var = m2 / n
    avg_true, var_true = tf.nn.moments(tf.concat((x1, x2), axis=0), axes=[0])
    avg, var, avg_true, var_true = self.evaluate(
        (avg, var, avg_true, var_true))
    self.assertAllClose(avg, avg_true)
    self.assertAllClose(var, var_true)
def testParallelVarianceCombinesGroups(self):
    """Merging two groups' statistics must reproduce the union's statistics.

    Each group is summarized by its count, column mean and sum of squared
    deviations (m2). The merged mean and m2 / n are then compared against
    `tf.nn.moments` over the concatenated data.
    """

    def _mean_and_m2(values):
        # Column mean and sum of squared deviations from that mean.
        mean = tf.math.reduce_mean(values, axis=[0])
        spread = tf.math.reduce_sum(
            tf.math.squared_difference(values, mean), axis=[0])
        return mean, spread

    first = tf.constant(np.random.randn(5, 10))
    second = tf.constant(np.random.randn(15, 10))
    mean_first, m2_first = _mean_and_m2(first)
    mean_second, m2_second = _mean_and_m2(second)
    # Zero compensation term shaped like the second group's m2.
    zero_compensation = m2_second * 0.
    n, avg, m2, _ = tensor_normalizer.parallel_variance_calculation(
        5, mean_first, m2_first, 15, mean_second, m2_second,
        zero_compensation)
    var = m2 / n
    union = tf.concat((first, second), axis=0)
    expected_avg, expected_var = tf.nn.moments(union, axes=[0])
    avg, var, expected_avg, expected_var = self.evaluate(
        (avg, var, expected_avg, expected_var))
    self.assertAllClose(avg, expected_avg)
    self.assertAllClose(var, expected_var)