Code Example #1
File: test_scripts.py Project: berleon/deepdecoder
def test_augmentation_data_generator(tmpdir):
    dist = DistributionCollection(examplary_tag_distribution())
    dset_fname = str(tmpdir.join("dset.hdf5"))
    samples = 6000
    dset = DistributionHDF5Dataset(dset_fname, nb_samples=samples,
                                   distribution=dist)
    labels = dist.sample(samples)
    labels = dist.normalize(labels)
    fake = np.random.random((samples, 1, 8, 8))
    discriminator = np.random.random((samples, 1))
    dset.append(labels=labels, fake=fake, discriminator=discriminator)
    dset.close()

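    # Reopen the dataset and check that iterated batches expose the structured 'labels' field.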
    dset = DistributionHDF5Dataset(dset_fname)
    bs = 32
    names = ['labels', 'fake']
    assert 'labels' in next(dset.iter(bs, names))
    assert next(dset.iter(bs))['labels'].dtype.names == tuple(dist.names)

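    # Build an augmentation data generator that yields batches of 'fake' images and the selected label columns.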
    dset_iters = [lambda bs: bit_split(dataset_iterator(dset, bs))]
    data_gen = lambda bs: zip_dataset_iterators(dset_iters, bs)
    label_names = ['bit_0', 'bit_11', 'x_rotation']
    aug_gen = augmentation_data_generator(data_gen, 'fake', label_names)
    outs = next(aug_gen(bs))
    assert len(outs[0]) == 32
    assert len(outs[1]) == len(label_names)

    gen = aug_gen(bs)
    for i, batch in enumerate(gen):
        if i == 2*samples // bs:
            break
        assert batch is not None
        assert batch[0].shape == (bs, 1, 8, 8)
        assert len(batch[1]) == len(label_names)
Code Example #2
def test_distribution_collection_normalization():
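    # Sampling, normalizing, and then denormalizing should round-trip exactly.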
    dists = DistributionCollection([('const', Constant(5), 2), ('bern', Bernoulli(), 5)])
    bs = 10
    arr = dists.sample(bs)
    norm_arr = dists.normalize(arr)
    denorm_arr = dists.denormalize(norm_arr)
    assert (arr == denorm_arr).all()
Code Example #3
def test_distribution_collection_serialization():
    dists = DistributionCollection([('norm', Normal(5, 2), 2), ('bern', Bernoulli(), 5)])
    json_str = dists.to_json()
    dists_loaded = load_from_json(json_str)
    assert dists_loaded.dtype == dists.dtype
    assert dists_loaded.norm_dtype == dists.norm_dtype
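    # Normalization must agree between the original collection and the JSON-loaded copy.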
    n = 50
    arr = dists.sample(n)
    arr_norm = dists.normalize(arr)
    arr_norm_loaded = dists_loaded.normalize(arr)
    for name in dists.names:
        np.testing.assert_allclose(arr_norm[name], arr_norm_loaded[name])

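    # With a large sample, the loaded collection should reproduce mean and std to within 0.1.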
    n = 10000
    arr = dists.sample(n)
    arr_loaded = dists_loaded.sample(n)
    for name in dists.names:
        assert abs(arr[name].mean() - arr_loaded[name].mean()) <= 0.10
        assert abs(arr[name].std() - arr_loaded[name].std()) <= 0.10

    assert dists_loaded.names == dists.names
    assert dists_loaded.dtype == dists.dtype
    assert dists_loaded.norm_dtype == dists.norm_dtype
    assert dists_loaded.distributions == dists.distributions
    assert dists_loaded.normalizations == dists.normalizations
    assert dists_loaded.nb_elems == dists.nb_elems
Code Example #4
def test_distribution_collection_sampling():
    dists = DistributionCollection([('const', Constant(5), 2), ('bern', Bernoulli(), 5)])

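    # Each named entry is sampled with shape (batch_size, nb_elems).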
    bs = 10
    arr = dists.sample(bs)
    assert arr["const"].shape == (bs, 2)
    assert (arr["const"] == 5).all()

    assert arr["bern"].shape == (bs, 5)
    assert (np.logical_or(arr["bern"] == 1, arr["bern"] == 0)).all()
Code Example #5
def test_store_samples(tmpdir):
    dist = DistributionCollection(examplary_tag_distribution())
    bs = 64
    labels_record = dist.sample(bs)
    labels_record = dist.normalize(labels_record)
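    # Flatten the structured label record into a single array along the last axis.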
    labels = []
    for name in labels_record.dtype.names:
        labels.append(labels_record[name])
    labels = np.concatenate(labels, axis=-1)
    fakes = np.random.random((bs, 1, 8, 8))
    store = StoreSamples(str(tmpdir), dist)
    store.on_epoch_end(0, logs={'samples': {'labels': labels, 'fake': fakes}})
    assert tmpdir.join("00000.hdf5").exists()
Code Example #6
def test_generated_3d_tags():
    artist = BlackWhiteArtist(0, 255, 0, 1)
    label_dist = DistributionCollection(examplary_tag_distribution())
    bs = 8
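    # Render a batch of 64x64 tag images from labels sampled out of the distribution.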
    labels, grids = generated_3d_tags(label_dist, batch_size=bs, artist=artist)
    assert grids.shape == (bs, 1, 64, 64)
    assert list(labels.dtype.names) == label_dist.names
Code Example #7
File: gt_to_hdf5.py Project: janek/bb_utils
def run(gt_file,
        videos,
        images,
        visualize_debug,
        output,
        fix_utc_2014,
        nb_bits=12):
    """
    Converts bb_binary ground truth Cap'n Proto files to hdf5 files and
    extracts the corresponding rois from videos or images.
    """
    def get_filenames(f):
        if f is None:
            return []
        else:
            return [line.rstrip('\n') for line in f.readlines()]

    gen_factory = FrameGeneratorFactory(get_filenames(videos),
                                        get_filenames(images))
    if os.path.exists(output):
        os.remove(output)

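    # Labels are stored with one Bernoulli-distributed entry per tag bit.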
    distribution = DistributionCollection([('bits', Bernoulli(), nb_bits)])
    dset = DistributionHDF5Dataset(output, distribution)
    camIdxs = []
    periods = []
    for fname in gt_file:
        fc = load_frame_container(fname)
        camIdx, start_dt, end_dt = parse_video_fname(fname)
        if fix_utc_2014 and start_dt.year == 2014:
            start_dt -= timedelta(hours=2)
        gt_frames = []
        gen = gen_factory.get_generator(camIdx, start_dt)
        for frame, (video_frame, video_filename) in zip(fc.frames, gen):
            gt = {}
            np_frame = convert_frame_to_numpy(frame)
            rois, mask, positions = extract_gt_rois(np_frame, video_frame,
                                                    start_dt)
            for name in np_frame.dtype.names:
                gt[name] = np_frame[name][mask]
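            # Convert decoded tag ids to bit vectors and scale bits and image crops into the [-1, 1] range.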
            bits = [int_id_to_binary(id)[::-1] for id in gt["decodedId"]]
            gt["bits"] = 2 * np.array(bits, dtype=np.float) - 1
            gt["tags"] = 2 * (rois / 255.).astype(np.float16) - 1
            gt['filename'] = os.path.basename(video_filename)
            gt['camIdx'] = camIdx
            gt_frames.append(gt)
            print('.', end='', flush=True)
        print()
        gt_period = GTPeriod(camIdx, start_dt, end_dt, fname, gt_frames)

        periods.append(
            [int(gt_period.start.timestamp()),
             int(gt_period.end.timestamp())])
        camIdxs.append(gt_period.camIdx)
        append_gt_to_hdf5(gt_period, dset)

    dset.attrs['periods'] = np.array(periods)
    dset.attrs['camIdxs'] = np.array(camIdxs)
    visualize_detection_tiles(dset, os.path.splitext(output)[0])
    dset.close()
Code Example #8
def __init__(self, model_path):
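    # Load the decoder model and the label distribution stored in the same HDF5 file.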
    self.model = load_model(model_path)
    self.uses_hist_equalization = get_hdf5_attr(
        model_path, 'decoder_uses_hist_equalization', True)
    self.distribution = DistributionCollection.from_hdf5(model_path)
    self._predict = predict_wrapper(self.model.predict, self.model.output_names)
    self.model._make_predict_function()
Code Example #9
File: test_data.py Project: berleon/deepdecoder
def test_distribution_hdf5_dataset(tmpdir):
    with pytest.raises(Exception):
        DistributionHDF5Dataset(
            str(tmpdir.join('dataset_no_distribution.hdf5')), nb_samples=1000)

    dist = DistributionCollection(examplary_tag_distribution(nb_bits=12))
    labels = dist.sample(32)
    image = np.random.random((32, 1, 8, 8))
    dset = DistributionHDF5Dataset(
        str(tmpdir.join('dataset.hdf5')), distribution=dist, nb_samples=1000)
    dset.append(labels=labels, image=image)
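    # The file stores one dataset per distribution name, while iterated batches group them into a single structured 'labels' array.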
    for name in dist.names:
        assert name in dset
    for batch in dset.iter(batch_size=32):
        for name in dist.names:
            assert name not in batch
        assert 'labels' in batch
        assert batch['labels'].dtype == dist.norm_dtype
        break
Code Example #10
def default_tag_distribution():
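    # Start from examplary_tag_distribution() and override the geometry-related parameters.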
    tag_dist_params = examplary_tag_distribution()
    angle = to_radians(60)
    tag_dist_params["x_rotation"] = TruncNormal(-angle, angle, 0, angle / 2)
    tag_dist_params["y_rotation"] = TruncNormal(-angle, angle, 0, angle / 2)
    tag_dist_params["radius"] = TruncNormal(22, 26, 24.0, 1.3)
    tag_dist_params["center"] = (Uniform(-16, 16), 2)

    tag_dist_params["bulge_factor"] = Uniform(0.4, 0.8)
    tag_dist_params["focal_length"] = Uniform(2, 4)
    tag_dist_params["inner_ring_radius"] = Uniform(0.42, 0.48)
    tag_dist_params["middle_ring_radius"] = Constant(0.8)
    tag_dist_params["outer_ring_radius"] = Constant(1.)
    return DistributionCollection(tag_dist_params)