def test_uninitialized_add():
    """Merging a populated accumulator into an empty one, in-place or binary, yields the populated one."""
    values = list(range(100))
    in_place = BatchedMoments()
    populated = BatchedMoments()(values)
    in_place += populated
    binary = BatchedMoments() + populated
    assert populated == in_place == binary
def test_batches():
    """Feeding data batch by batch matches a one-shot pass on every prefix and on the whole."""
    data = list(range(100))
    batchsize = 10
    incremental = BatchedMoments()
    last_start = len(data) // batchsize * batchsize
    for start in range(0, last_start, batchsize):
        stop = start + batchsize
        incremental(data[start:stop])
        # After each batch, the running result must equal a fresh pass over the prefix.
        assert BatchedMoments()(data[:stop]) == incremental
    assert BatchedMoments()(data) == incremental
def test_multiprocessing_add():
    """Moments computed in a worker pool and merged equal moments merged sequentially."""
    batchsize = 10
    samples = batchsize ** 4
    batches = [list(range(start, start + batchsize))
               for start in range(0, samples, batchsize)]
    head, tail = batches[0], batches[1:]

    # Parallel path: seed with the first batch, then merge worker results as they finish.
    parallel = BatchedMoments()(head)
    with Pool(processes=multiprocessing.cpu_count()) as pool:
        for partial in pool.imap_unordered(BatchedMoments(), tail):
            parallel += partial

    # Sequential path over the same batches.
    sequential = BatchedMoments()(head)
    for batch in tail:
        sequential += BatchedMoments()(batch)

    assert sequential == parallel
def test_axis_shape_random_nd():
    """Supplying both axis and shape initializes immediately and stores both verbatim."""
    ndim = randint(1, 6)
    shp = (4,) * ndim
    # Random non-empty subset of distinct axis indices, in random order.
    axs = tuple(sample(range(ndim), randint(1, ndim)))
    moments = BatchedMoments(axis=axs, shape=shp)
    assert moments._initialized is True
    assert moments.shape == shp
    assert moments.axis == axs
def test_axis_None():
    """With axis=None and no shape, the accumulator stays uninitialized until first use."""
    moments = BatchedMoments(axis=None)
    # Initialization is deferred to the first update call ...
    assert moments._initialized is False
    # ... so no moments shape has been determined yet.
    assert moments._moments_shape is None
def test_axis_random_nd():
    """An explicit multi-axis tuple without a shape is stored but leaves the accumulator lazy."""
    # Axis indices must be distinct: a tuple like (4, 4, 4) names the same
    # axis repeatedly, which NumPy reductions reject once data arrives.
    axs = tuple(range(randint(1, 16)))
    bm = BatchedMoments(axis=axs)
    assert bm._initialized is False
    assert bm.axis == axs
def test_axis_0d():
    """An empty axis tuple without a shape is stored but does not initialize the accumulator."""
    empty_axes = ()
    moments = BatchedMoments(axis=empty_axes)
    assert moments._initialized is False
    assert moments.axis == empty_axes
def test_reduce_matches_full():
    """Moments reduced over the batch axis equal moments of the whole dataset.

    Renamed from a second ``test_batches``, which silently replaced the
    earlier test of that name so that pytest never collected it.
    """
    data = [list(range(100))] * 10
    full = BatchedMoments()(data)
    # axis=1 presumably keeps one set of moments per row; reduce(0) merges
    # them across rows. Confirm against BatchedMoments.reduce.
    reduced = BatchedMoments(axis=1)(data).reduce(0)
    assert full == reduced
def test_shape_random_nd():
    """A shape alone initializes the accumulator; with axis=None the moments are scalar."""
    shp = (4,) * randint(1, 16)
    moments = BatchedMoments(shape=shp)
    assert moments._initialized is True
    # axis is None, so everything is reduced and the moments shape is ().
    assert moments.shape == ()
def test_add():
    """Summing accumulators over two halves equals one pass over all the data."""
    data = list(range(100))
    mid = len(data) // 2
    whole = BatchedMoments()(data)
    first_half = BatchedMoments()(data[:mid])
    second_half = BatchedMoments()(data[mid:])
    assert whole == first_half + second_half
def test_axis_shape_0d():
    """An empty axis together with an empty shape still counts as initialized."""
    scalar = ()
    moments = BatchedMoments(axis=scalar, shape=scalar)
    assert moments._initialized is True
    assert moments.shape == scalar
    assert moments.axis == scalar
def test_correctness_1():
    """Skewness of a symmetric sample (expected 0.0) matches scipy.stats.skew."""
    sample_data = np.array([1, 2, 3, 4, 5])
    moments = BatchedMoments(axis=0)(sample_data)
    expected = skew(sample_data)
    assert np.allclose(expected, moments.skewness, equal_nan=True)
def test_correctness_2():
    """Skewness of an asymmetric sample (expected ~0.26506) matches scipy.stats.skew."""
    sample_data = np.array([2, 8, 0, 4, 1, 9, 9, 0])
    moments = BatchedMoments(axis=0)(sample_data)
    expected = skew(sample_data)
    assert np.allclose(expected, moments.skewness, equal_nan=True)
from batchedmoments import BatchedMoments

# Benchmark script: for batch sizes 2**0 .. 2**16, feed FashionMNIST through
# BatchedMoments and alongside it collect a "naive" estimate built from
# per-batch np.mean / np.std values.
# NOTE(review): Path, tempfile, datasets, transforms, DataLoader and np are
# used below but not imported in this view; presumably imported elsewhere in
# the file. Confirm.
if __name__ == '__main__':
    batchsize_exp = range(17)
    # NOTE(review): `results` is never appended to and `results_file` is never
    # written in the visible code; the rest of the script may be truncated.
    results = []
    results_file = Path("naive-bm-comparison.csv")
    # Only run the benchmark when the results CSV does not exist yet.
    if not results_file.exists():
        # The dataset is downloaded into a temporary directory that is deleted
        # when the `with` block exits.
        with tempfile.TemporaryDirectory() as root:
            image_data = datasets.FashionMNIST(str(root), download=True, train=True,
                                               transform=transforms.Compose(
                                                   [transforms.ToTensor()]))
            for exp in batchsize_exp:
                # Reduces over axes 0, 2, 3, keeping axis 1 (presumably the
                # channel axis of NCHW image batches; confirm).
                bm = BatchedMoments(axis=(0, 2, 3))
                # compare to naive solution
                means = []
                stds = []
                data_loader = DataLoader(image_data, batch_size=2**exp)
                for imgs, _ in data_loader:
                    imgs = imgs.numpy()
                    bm(imgs)
                    # np.mean / np.std without `axis` reduce over every axis,
                    # giving one scalar per batch.
                    means.append(np.mean(imgs))
                    stds.append(np.std(imgs))
                # Unweighted mean of per-batch means. A smaller final batch
                # (if any) gets the same weight as the full batches.
                naive_mean = np.mean(means, keepdims=True)
                # Mean of per-batch stds. This ignores the spread of the batch
                # means, so in general it differs from the std of all pixels.
                naive_std = np.mean(stds, keepdims=True)
def test_commutativity_add():
    """Adding an empty accumulator on either side leaves the populated one unchanged."""
    data = list(range(100))
    populated = BatchedMoments()(data)
    left = BatchedMoments() + populated
    right = populated + BatchedMoments()
    assert left == populated == right
def test_correctness_std():
    """Population standard deviation agrees with scipy.stats.tstd(ddof=0) on a normal sample.

    Renamed from ``test_correctness``, a name shared with the kurtosis test
    that follows it. The later definition replaced this one, so it never ran.
    """
    data = norm.rvs(size=1000, random_state=3)
    bm = BatchedMoments(axis=0)(data)
    # With no limits, tstd performs no trimming; ddof=0 makes it the
    # population std, which is what bm.std is compared against.
    assert np.allclose(tstd(data, ddof=0), bm.std, equal_nan=True)
def test_correctness():
    """Kurtosis of a seeded normal sample (expected ~-0.069287) matches scipy.stats.kurtosis."""
    sample_data = norm.rvs(size=1000, random_state=3)
    moments = BatchedMoments(axis=0)(sample_data)
    expected = kurtosis(sample_data)
    assert np.allclose(expected, moments.kurtosis, equal_nan=True)
def test_shape_0d():
    """An empty shape alone is enough to initialize the accumulator."""
    empty_shape = ()
    moments = BatchedMoments(shape=empty_shape)
    assert moments._initialized is True
    assert moments.shape == empty_shape