def test_update_mean_lmbda():
    """Check that a batch ``update_mean_lmbda`` call matches per-row updates.

    The running mean starts at the column mean of ``X``. Each row ``x`` of
    ``X2`` is then folded in as ``(1 - lmbda) * mean + lmbda * x``. The
    result must match one ``update_mean_lmbda`` call over all of ``X2``.
    """
    # Local seeded generator: keeps the test reproducible without mutating
    # numpy's global RNG state, which other tests may rely on.
    rng = np.random.RandomState(0)
    N = 100
    D = 2
    X = rng.randn(N, D)
    X2 = rng.randn(N, D)
    lmbdas = np.ones(N) * 0.1

    old_mean = np.mean(X, axis=0)
    full_mean = update_mean_lmbda(X2, old_mean, lmbdas)

    # Reference result: apply the same update one row at a time.
    for x, lmbda in zip(X2, lmbdas):
        old_mean = (1 - lmbda) * old_mean + lmbda * x

    assert_allclose(old_mean, full_mean)
def test_weights_to_lmbdas_produces_mean_weighted():
    """Check that lambdas from ``weights_to_lmbdas`` give the weighted mean.

    ``X[0]`` seeds the running mean and carries weight ``weights[0]``. The
    remaining samples are folded in with ``update_mean_lmbda``, using lambdas
    derived from their weights. The result must equal the batch weighted
    mean ``sum(w_i * x_i) / sum(w_i)``.
    """
    # Local seeded generator for reproducibility; global RNG left untouched.
    rng = np.random.RandomState(1)
    N = 20
    X = rng.randn(N)
    weights = rng.rand(N)

    sum_old_weights = weights[0]
    lmbdas = weights_to_lmbdas(sum_old_weights, weights[1:])

    old_mean = X[0]
    full_mean = update_mean_lmbda(X[1:], old_mean, lmbdas)

    # Batch weighted mean, computed with a vectorised elementwise product.
    full_mean_batch = np.sum(weights * X, axis=0) / np.sum(weights)
    assert_allclose(full_mean_batch, full_mean)
def test_weights_to_lmbdas_produces_mean():
    """Check that unit weights make the lambda updates give the plain mean.

    ``X[0]`` seeds the running mean with weight 1. Every later row also gets
    weight 1, so folding them in with ``update_mean_lmbda`` must give the
    unweighted column mean of ``X``.
    """
    # Local seeded generator for reproducibility; global RNG left untouched.
    rng = np.random.RandomState(2)
    N = 30
    D = 2
    X = rng.randn(N, D)
    full_mean_batch = np.mean(X, axis=0)

    sum_old_weights = 1
    new_weights = np.ones(N - 1)
    lmbdas = weights_to_lmbdas(sum_old_weights, new_weights)

    old_mean = X[0]
    full_mean = update_mean_lmbda(X[1:], old_mean, lmbdas)

    assert_allclose(full_mean_batch, full_mean)