def test_distribution_collection_normalization():
    """Check that ``denormalize`` undoes ``normalize`` on sampled data."""
    collection = DistributionCollection([
        ('const', Constant(5), 2),
        ('bern', Bernoulli(), 5),
    ])
    batch_size = 10
    sampled = collection.sample(batch_size)
    # Round trip: normalize the samples, then map them back again.
    restored = collection.denormalize(collection.normalize(sampled))
    assert (sampled == restored).all()
Example #2
0
def test_augmentation_data_generator(tmpdir):
    """Store samples in an HDF5 dataset and read them back as augmentation batches.

    The dataset is written, closed and reopened; its batches are then piped
    through ``bit_split``, ``zip_dataset_iterators`` and
    ``augmentation_data_generator`` and checked for length and shape.
    """
    dist = DistributionCollection(examplary_tag_distribution())
    path = str(tmpdir.join("dset.hdf5"))
    nb_samples = 6000

    # Write normalized labels together with random images and scores.
    writer = DistributionHDF5Dataset(path, nb_samples=nb_samples,
                                     distribution=dist)
    normalized_labels = dist.normalize(dist.sample(nb_samples))
    fake_images = np.random.random((nb_samples, 1, 8, 8))
    scores = np.random.random((nb_samples, 1))
    writer.append(labels=normalized_labels, fake=fake_images,
                  discriminator=scores)
    writer.close()

    # Reopen the file without passing a distribution.
    reader = DistributionHDF5Dataset(path)
    batch_size = 32
    requested = ['labels', 'fake']
    first_named = next(reader.iter(batch_size, requested))
    assert 'labels' in first_named
    first_full = next(reader.iter(batch_size))
    assert first_full['labels'].dtype.names == tuple(dist.names)

    def split_iterator(bs):
        return bit_split(dataset_iterator(reader, bs))

    iterator_factories = [split_iterator]

    def data_gen(bs):
        return zip_dataset_iterators(iterator_factories, bs)

    label_names = ['bit_0', 'bit_11', 'x_rotation']
    aug_gen = augmentation_data_generator(data_gen, 'fake', label_names)
    first_batch = next(aug_gen(batch_size))
    assert len(first_batch[0]) == 32
    assert len(first_batch[1]) == len(label_names)

    # Iterate over twice the number of stored samples, then stop.
    max_batches = 2*nb_samples // batch_size
    for index, batch in enumerate(aug_gen(batch_size)):
        if index == max_batches:
            break
        assert batch is not None
        assert batch[0].shape == (batch_size, 1, 8, 8)
        assert len(batch[1]) == len(label_names)
def test_distribution_collection_serialization():
    """Check that a collection survives a ``to_json`` / ``load_from_json`` round trip."""
    original = DistributionCollection([
        ('norm', Normal(5, 2), 2),
        ('bern', Bernoulli(), 5),
    ])
    restored = load_from_json(original.to_json())
    assert restored.dtype == original.dtype
    assert restored.norm_dtype == original.norm_dtype

    # Both collections must normalize the same samples identically.
    samples = original.sample(50)
    expected_norm = original.normalize(samples)
    actual_norm = restored.normalize(samples)
    for name in original.names:
        np.testing.assert_allclose(expected_norm[name], actual_norm[name])

    # Independent draws are only compared through their mean and std.
    nb_draws = 10000
    tolerance = 0.10
    drawn = original.sample(nb_draws)
    drawn_restored = restored.sample(nb_draws)
    for name in original.names:
        mean_gap = abs(drawn[name].mean() - drawn_restored[name].mean())
        assert mean_gap <= tolerance
        std_gap = abs(drawn[name].std() - drawn_restored[name].std())
        assert std_gap <= tolerance

    # Every public attribute has to match after the round trip.
    compared_attrs = ('names', 'dtype', 'norm_dtype', 'distributions',
                      'normalizations', 'nb_elems')
    for attr in compared_attrs:
        assert getattr(restored, attr) == getattr(original, attr), attr
def test_distribution_collection_sampling():
    """Check the shape and value range of samples for each named distribution."""
    collection = DistributionCollection([
        ('const', Constant(5), 2),
        ('bern', Bernoulli(), 5),
    ])
    batch_size = 10
    sample = collection.sample(batch_size)

    # The constant distribution must yield 5 everywhere.
    const_values = sample["const"]
    assert const_values.shape == (batch_size, 2)
    assert (const_values == 5).all()

    # Bernoulli samples may only contain zeros and ones.
    bern_values = sample["bern"]
    assert bern_values.shape == (batch_size, 5)
    assert ((bern_values == 1) | (bern_values == 0)).all()
Example #5
0
def test_store_samples(tmpdir):
    """Check that ``StoreSamples.on_epoch_end`` creates ``00000.hdf5`` in ``tmpdir``."""
    dist = DistributionCollection(examplary_tag_distribution())
    batch_size = 64
    record = dist.normalize(dist.sample(batch_size))
    # Join the fields of the structured record along the last axis.
    columns = [record[field] for field in record.dtype.names]
    labels = np.concatenate(columns, axis=-1)
    fakes = np.random.random((batch_size, 1, 8, 8))
    store = StoreSamples(str(tmpdir), dist)
    epoch_logs = {'samples': {'labels': labels, 'fake': fakes}}
    store.on_epoch_end(0, logs=epoch_logs)
    assert tmpdir.join("00000.hdf5").exists()
Example #6
0
 def __init__(self, model_path):
     """Load the model and its metadata from ``model_path``.

     Sets ``model``, ``uses_hist_equalization``, ``distribution`` and the
     private ``_predict`` wrapper on the instance.
     """
     self.model = load_model(model_path)
     # Third argument is presumably the fallback when the attribute is
     # missing from the file -- TODO confirm against get_hdf5_attr.
     hist_equalization = get_hdf5_attr(
         model_path, 'decoder_uses_hist_equalization', True)
     self.uses_hist_equalization = hist_equalization
     distribution = DistributionCollection.from_hdf5(model_path)
     self.distribution = distribution
     output_names = self.model.output_names
     self._predict = predict_wrapper(self.model.predict, output_names)
     # NOTE(review): private model API; presumably builds the predict
     # function eagerly -- confirm why this is needed here.
     self.model._make_predict_function()
Example #7
0
def test_distribution_hdf5_dataset(tmpdir):
    """Check how ``DistributionHDF5Dataset`` stores and yields labels.

    Creating a new dataset without a distribution must raise. After
    appending, every distribution name is contained in the dataset, while a
    batch exposes a single ``'labels'`` entry whose dtype is the
    distribution's ``norm_dtype``.
    """
    # No distribution is passed for a new file, so construction must fail.
    with pytest.raises(Exception):
        DistributionHDF5Dataset(
            str(tmpdir.join('dataset_no_distribution.hdf5')), nb_samples=1000)

    dist = DistributionCollection(examplary_tag_distribution(nb_bits=12))
    labels = dist.sample(32)
    image = np.random.random((32, 1, 8, 8))
    dset = DistributionHDF5Dataset(
        str(tmpdir.join('dataset.hdf5')), distribution=dist, nb_samples=1000)
    dset.append(labels=labels, image=image)
    for name in dist.names:
        assert name in dset

    # Fetch the first batch with next() rather than a for/break loop: an
    # iterator that yields nothing would silently skip every assertion below,
    # whereas next() fails the test.
    batch = next(dset.iter(batch_size=32))
    for name in dist.names:
        assert name not in batch
    assert 'labels' in batch
    assert batch['labels'].dtype == dist.norm_dtype