import numpy as np
import pytest

# NOTE: the import paths below are inferred from the helper names and may
# need adjusting to the project's actual module layout.
from diktya.distributions import (Bernoulli, Constant, DistributionCollection,
                                  Normal, examplary_tag_distribution,
                                  load_from_json)
from deepdecoder.data import (DistributionHDF5Dataset,
                              augmentation_data_generator, bit_split,
                              dataset_iterator, zip_dataset_iterators)
from deepdecoder.train import StoreSamples


def test_distribution_collection_serialization():
    dists = DistributionCollection([
        ('norm', Normal(5, 2), 2),
        ('bern', Bernoulli(), 5),
    ])
    json_str = dists.to_json()
    dists_loaded = load_from_json(json_str)

    # Normalization is deterministic, so the loaded collection must produce
    # exactly the same normalized values.
    n = 50
    arr = dists.sample(n)
    arr_norm = dists.normalize(arr)
    arr_norm_loaded = dists_loaded.normalize(arr)
    for name in dists.names:
        np.testing.assert_allclose(arr_norm[name], arr_norm_loaded[name])

    # Sampling is stochastic, so only compare the first two moments within a
    # loose tolerance on a large sample.
    n = 10000
    arr = dists.sample(n)
    arr_loaded = dists_loaded.sample(n)
    for name in dists.names:
        assert abs(arr[name].mean() - arr_loaded[name].mean()) <= 0.10
        assert abs(arr[name].std() - arr_loaded[name].std()) <= 0.10

    assert dists_loaded.names == dists.names
    assert dists_loaded.dtype == dists.dtype
    assert dists_loaded.norm_dtype == dists.norm_dtype
    assert dists_loaded.distributions == dists.distributions
    assert dists_loaded.normalizations == dists.normalizations
    assert dists_loaded.nb_elems == dists.nb_elems

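# A minimal sketch, assuming `to_json` returns a plain str: the JSON
# round-trip above extends naturally to files. The test name and path
# handling are illustrative only.
def test_distribution_collection_file_roundtrip_sketch(tmpdir):
    dists = DistributionCollection([('norm', Normal(5, 2), 2)])
    path = tmpdir.join("dists.json")
    path.write(dists.to_json())
    loaded = load_from_json(path.read())
    assert loaded.names == dists.names
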
def test_augmentation_data_generator(tmpdir):
    dist = DistributionCollection(examplary_tag_distribution())
    dset_fname = str(tmpdir.join("dset.hdf5"))
    samples = 6000
    dset = DistributionHDF5Dataset(dset_fname, nb_samples=samples,
                                   distribution=dist)
    labels = dist.normalize(dist.sample(samples))
    fake = np.random.random((samples, 1, 8, 8))
    discriminator = np.random.random((samples, 1))
    dset.append(labels=labels, fake=fake, discriminator=discriminator)
    dset.close()

    dset = DistributionHDF5Dataset(dset_fname)
    bs = 32
    names = ['labels', 'fake']
    assert 'labels' in next(dset.iter(bs, names))
    assert next(dset.iter(bs))['labels'].dtype.names == tuple(dist.names)

    dset_iters = [lambda bs: bit_split(dataset_iterator(dset, bs))]
    data_gen = lambda bs: zip_dataset_iterators(dset_iters, bs)
    label_names = ['bit_0', 'bit_11', 'x_rotation']
    aug_gen = augmentation_data_generator(data_gen, 'fake', label_names)

    outs = next(aug_gen(bs))
    assert len(outs[0]) == bs
    assert len(outs[1]) == len(label_names)

    # Iterate for two epochs worth of batches to make sure the generator
    # wraps around the dataset.
    gen = aug_gen(bs)
    for i, batch in enumerate(gen):
        if i == 2 * samples // bs:
            break
        assert batch is not None
        assert batch[0].shape == (bs, 1, 8, 8)
        assert len(batch[1]) == len(label_names)

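# Hedged usage note: a generator factory like `aug_gen` above is typically
# consumed by a training loop; with the Keras 1.x API, for example, this
# might look like (illustrative, not part of this suite):
#
#     model.fit_generator(aug_gen(32), samples_per_epoch=6000, nb_epoch=10)
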
def test_distribution_collection_normalization():
    dists = DistributionCollection([
        ('const', Constant(5), 2),
        ('bern', Bernoulli(), 5),
    ])
    bs = 10
    arr = dists.sample(bs)
    norm_arr = dists.normalize(arr)
    denorm_arr = dists.denormalize(norm_arr)
    # normalize followed by denormalize must be the identity.
    assert (arr == denorm_arr).all()

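# A minimal sketch, assuming continuous distributions normalize via an
# invertible (affine-like) transform: the round-trip above should also hold
# for Normal, up to floating-point error, hence a tolerance instead of exact
# equality. The test name and `atol` value are illustrative.
def test_normalization_roundtrip_continuous_sketch():
    dists = DistributionCollection([('norm', Normal(0, 1), 3)])
    arr = dists.sample(10)
    roundtrip = dists.denormalize(dists.normalize(arr))
    np.testing.assert_allclose(roundtrip['norm'], arr['norm'], atol=1e-6)
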
def test_distribution_collection_sampling():
    dists = DistributionCollection([
        ('const', Constant(5), 2),
        ('bern', Bernoulli(), 5),
    ])
    bs = 10
    arr = dists.sample(bs)
    assert arr["const"].shape == (bs, 2)
    assert (arr["const"] == 5).all()
    assert arr["bern"].shape == (bs, 5)
    assert np.logical_or(arr["bern"] == 1, arr["bern"] == 0).all()

def test_store_samples(tmpdir):
    dist = DistributionCollection(examplary_tag_distribution())
    bs = 64
    labels_record = dist.normalize(dist.sample(bs))
    # Flatten the structured record array into a plain 2-D array.
    labels = np.concatenate([labels_record[name]
                             for name in labels_record.dtype.names], axis=-1)
    fakes = np.random.random((bs, 1, 8, 8))
    store = StoreSamples(str(tmpdir), dist)
    store.on_epoch_end(0, logs={'samples': {'labels': labels, 'fake': fakes}})
    assert tmpdir.join("00000.hdf5").exists()

def test_distribution_hdf5_dataset(tmpdir):
    # A dataset without a distribution must be rejected.
    with pytest.raises(Exception):
        DistributionHDF5Dataset(
            str(tmpdir.join('dataset_no_distribution.hdf5')), nb_samples=1000)

    dist = DistributionCollection(examplary_tag_distribution(nb_bits=12))
    labels = dist.sample(32)
    image = np.random.random((32, 1, 8, 8))
    dset = DistributionHDF5Dataset(str(tmpdir.join('dataset.hdf5')),
                                   distribution=dist, nb_samples=1000)
    dset.append(labels=labels, image=image)
    for name in dist.names:
        assert name in dset
    # Iteration merges the per-name columns into one structured 'labels' array.
    for batch in dset.iter(batch_size=32):
        for name in dist.names:
            assert name not in batch
        assert 'labels' in batch
        assert batch['labels'].dtype == dist.norm_dtype
        break

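# A minimal sketch of a shared pytest fixture for the tag-distribution setup
# that several tests above rebuild by hand; the fixture name `tag_dist` is
# hypothetical.
@pytest.fixture
def tag_dist():
    return DistributionCollection(examplary_tag_distribution())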