Example #1
def test_handles_ambiguous_series_no(mocker):
    # Enough so that there are duplicate series numbers
    dataroot, DIV_data, _, _ = fake_geo_data(11)

    for i, data in enumerate(DIV_data):
        DIV_data[i] = np.interp(data,
                                (np.min(data.ravel()), np.max(data.ravel())),
                                [-1, 1])

    p = GeoPickler(dataroot, 'out_dir')

    p.collect_all()

    p.group_by_series()

    mocker.patch('torch.save')

    p.pickle_series(0, 1, 1000, 4, 6)

    path = torch.save.call_args[0][1]
    data = torch.save.call_args[0][0]

    assert (path == 'out_dir/00001.pkl')

    assert ((data['A_DIV'] == DIV_data[1]).all())
    assert (all([
        key in data.keys() for key in [
            'A', 'A_DIV', 'A_Vx', 'A_Vy', 'mask_locs', 'folder_name',
            'series_number', 'mask_size', 'min_pix_in_mask'
        ]
    ]))
Example #2
def test_build_one_hot(fake_geo_data):
    dataroot, DIV_datas, Vx_datas, Vy_datas = fake_geo_data

    p = GeoPickler(dataroot)

    p.collect_all()

    p.group_by_series()

    data_dict = p.get_data_dict(0, 0)

    DIV = (np.random.randn(*data_dict['A_DIV'].shape) * 20000)

    data_dict['A_DIV'] = DIV

    p.create_one_hot(data_dict, 1000)

    one_hot = data_dict['A']

    # Every pixel set in a channel must fall in the corresponding DIV band:
    # channel 0 above the threshold, channel 1 within it, channel 2 below it.
    assert (np.all(DIV[np.where(one_hot[:, :, 0])] > 1000))
    assert (np.all(np.abs(DIV[np.where(one_hot[:, :, 1])]) < 1000))
    assert (np.all(DIV[np.where(one_hot[:, :, 2])] < -1000))
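The thresholding this test implies can be written down directly: channel 0 of the one-hot marks divergence above the threshold, channel 1 the band within it, and channel 2 divergence below the negative threshold. A minimal sketch of an equivalent construction (not necessarily how GeoPickler.create_one_hot is implemented, and the boolean dtype is an assumption):

import numpy as np

def one_hot_from_div(div, threshold):
    # Hypothetical helper: channel 0 above +threshold, channel 1 inside the
    # band, channel 2 below -threshold.
    one_hot = np.zeros(div.shape + (3,), dtype=bool)
    one_hot[:, :, 0] = div > threshold
    one_hot[:, :, 1] = np.abs(div) <= threshold
    one_hot[:, :, 2] = div < -threshold
    return one_hot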
Example #3
def test_skip_save_if_no_mask_locations(fake_geo_data, mocker):
    dataroot, _, _, _ = fake_geo_data

    p = GeoPickler(dataroot, 'out_dir')

    p.collect_all()

    p.group_by_series()

    mocker.patch('torch.save')

    p.pickle_series(0, 0, 1000, 4, 1000)

    torch.save.assert_not_called()
Example #4
def test_build_data_dict(fake_geo_data):
    dataroot, DIV_datas, Vx_datas, Vy_datas = fake_geo_data

    p = GeoPickler(dataroot)

    p.collect_all()

    p.group_by_series()

    data_dict = p.get_data_dict(0, 0)

    assert ((data_dict['A_DIV'] == DIV_datas[0]).all())
    assert ((data_dict['A_Vx'] == Vx_datas[0]).all())
    assert ((data_dict['A_Vy'] == Vy_datas[0]).all())
Example #5
def test_resized(fake_geo_data):
    dataroot, DIV_datas, Vx_datas, Vy_datas = fake_geo_data

    p = GeoPickler(dataroot, row_height=18)

    p.collect_all()

    p.group_by_series()

    data_dict = p.get_data_dict(0, 0)
    p.create_one_hot(data_dict, 1000)

    assert ((data_dict['A_DIV'].shape == (18, 36)))
    assert ((data_dict['A_Vx'].shape == (18, 36)))
    assert ((data_dict['A_Vy'].shape == (18, 36)))
    assert ((data_dict['A'].shape == (18, 36, 3)))
Example #6
def test_searches_subfolders():
    dataroot, _, _, _ = fake_geo_data(1)
    subfolder, _, _, _ = fake_geo_data(2)

    shutil.move(subfolder, dataroot)

    p = GeoPickler(dataroot)

    p.collect_all()

    p.group_by_series()

    assert (list(p.folders.keys()) == ['', os.path.basename(subfolder)])

    assert (len(p.get_folder_by_id(0)) == 1)
    assert (len(p.get_folder_by_id(1)) == 2)
Example #7
def test_matches_other_file_pattern():
    dataroot, _, _, _ = fake_geo_data(1)

    shutil.move(os.path.join(dataroot, 'serie100000_Vx.dat'),
                os.path.join(dataroot, 'serie1_0_Vx.dat'))
    shutil.move(os.path.join(dataroot, 'serie100000_Vy.dat'),
                os.path.join(dataroot, 'serie1_0_Vy.dat'))

    assert ('serie1_0_Vx.dat' in os.listdir(dataroot))
    assert ('serie1_0_Vy.dat' in os.listdir(dataroot))

    p = GeoPickler(dataroot)

    p.collect_all()
    p.group_by_series()

    assert (len(p.get_folder_by_id(0)) == 1)
Example #8
def test_normalises_continuous_data(fake_geo_data):
    dataroot, _, _, _ = fake_geo_data

    p = GeoPickler(dataroot)

    p.collect_all()

    p.group_by_series()

    data_dict = p.get_data_dict(0, 0)

    p.normalise_continuous_data(data_dict)

    assert (np.max(data_dict['A_DIV'].ravel()) == 1.0)
    assert (np.min(data_dict['A_DIV'].ravel()) == -1.0)

    assert (np.max(data_dict['A_Vx'].ravel()) == 1.0)
    assert (np.min(data_dict['A_Vx'].ravel()) == -1.0)

    assert (np.max(data_dict['A_Vy'].ravel()) == 1.0)
    assert (np.min(data_dict['A_Vy'].ravel()) == -1.0)
Example #9
def test_mask_params_stored_in_dict(fake_geo_data):
    dataroot, DIV_datas, Vx_datas, Vy_datas = fake_geo_data

    p = GeoPickler(dataroot)

    p.collect_all()

    p.group_by_series()

    data_dict = p.get_data_dict(0, 0)

    DIV = (np.random.randn(*data_dict['A_DIV'].shape) * 20000)

    data_dict['A_DIV'] = DIV

    p.create_one_hot(data_dict, 1000)

    p.get_mask_loc(data_dict, 4, 6)

    assert (data_dict['mask_size'] == 4)
    assert (data_dict['min_pix_in_mask'] == 6)
Example #10
def test_pickling_contains_all_data(fake_geo_data, mocker):
    dataroot, _, _, _ = fake_geo_data

    p = GeoPickler(dataroot, 'out_dir')

    p.collect_all()

    p.group_by_series()

    mocker.patch('torch.save')

    p.pickle_series(0, 0, 1000, 4, 6)

    path = torch.save.call_args[0][1]
    data = torch.save.call_args[0][0]

    assert (path == 'out_dir/00000.pkl')
    assert (all([
        key in data.keys() for key in [
            'A', 'A_DIV', 'A_Vx', 'A_Vy', 'mask_locs', 'folder_name',
            'series_number', 'mask_size', 'min_pix_in_mask'
        ]
    ]))
Example #11
def test_groups_by_series():
    dataroot, _, _, _ = fake_geo_data(3)

    p = GeoPickler(dataroot)

    p.collect_all()

    p.group_by_series()

    folder = p.get_folder_by_id(0)

    assert (len(folder) == 3)

    assert (p.get_series_in_folder(0) == [0, 1, 2])

    assert (folder[0] == [
        'serie100000_DIV.dat', 'serie100000_Vx.dat', 'serie100000_Vy.dat'
    ])
    assert (folder[1] == [
        'serie100001_DIV.dat', 'serie100001_Vx.dat', 'serie100001_Vy.dat'
    ])
    assert (folder[2] == [
        'serie100002_DIV.dat', 'serie100002_Vx.dat', 'serie100002_Vy.dat'
    ])
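The fake_geo_data helper these tests rely on is not shown in this listing (it appears both as a pytest fixture and as a plain function call). A minimal sketch of a factory with the same observable behaviour follows; the serie1NNNNN_*.dat naming matches the filenames asserted above, while the array size and the use of np.savetxt are assumptions:

import os
import tempfile

import numpy as np

def fake_geo_data(num_series, size=(36, 72)):
    # Hypothetical stand-in: writes one DIV/Vx/Vy triplet per series into a
    # fresh temporary directory and returns the directory plus the raw arrays.
    dataroot = tempfile.mkdtemp()
    DIV_datas, Vx_datas, Vy_datas = [], [], []
    for i in range(num_series):
        DIV, Vx, Vy = np.random.randn(3, *size)
        for suffix, data in (('DIV', DIV), ('Vx', Vx), ('Vy', Vy)):
            fname = 'serie1{:05d}_{}.dat'.format(i, suffix)
            np.savetxt(os.path.join(dataroot, fname), data)
        DIV_datas.append(DIV)
        Vx_datas.append(Vx)
        Vy_datas.append(Vy)
    return dataroot, DIV_datas, Vx_datas, Vy_datas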
Example #12
def test_mask_location(fake_geo_data):
    dataroot, DIV_datas, Vx_datas, Vy_datas = fake_geo_data

    p = GeoPickler(dataroot)

    p.collect_all()

    p.group_by_series()

    data_dict = p.get_data_dict(0, 0)

    DIV = (np.random.randn(*data_dict['A_DIV'].shape) * 20000)

    data_dict['A_DIV'] = DIV

    p.create_one_hot(data_dict, 1000)

    p.get_mask_loc(data_dict, 4, 6)

    one_hot = data_dict['A']

    mask_loc = data_dict['mask_locs']

    assert (len(mask_loc) > 0)

    for x in range(one_hot.shape[1] - 4):
        for y in range(one_hot.shape[0] - 4):
            sum1 = np.sum(one_hot[y:y + 4, x:x + 4, 0])
            sum2 = np.sum(one_hot[y:y + 4, x:x + 4, 2])

            if (y, x) in mask_loc:
                assert (sum1 >= 6)
                assert (sum2 >= 6)
            else:
                assert (sum1 < 6 or sum2 < 6)
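Going by these assertions, a valid mask location is any (y, x) window origin where a mask_size-by-mask_size patch contains at least min_pix pixels in both the positive and the negative one-hot channels. A sketch of that search as the test describes it (the function name is hypothetical, not the library's own get_mask_loc):

def find_mask_locations(one_hot, mask_size, min_pix):
    # Collect window origins whose patch has enough pixels in both extreme channels.
    locs = []
    for y in range(one_hot.shape[0] - mask_size):
        for x in range(one_hot.shape[1] - mask_size):
            window = one_hot[y:y + mask_size, x:x + mask_size]
            if window[:, :, 0].sum() >= min_pix and window[:, :, 2].sum() >= min_pix:
                locs.append((y, x))
    return locs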
Example #13
dataroot = os.path.expanduser(
    '/storage/Datasets/Geology-NicolasColtice/DS2-1810-RAW-DAT/train')
# dataroot = os.path.expanduser('/storage/Datasets/Geology-NicolasColtice/new_data/test')
# dataroot = os.path.expanduser('~/data/new_geo_data/test')
# dataroot = os.path.expanduser('~/data/new_geo_data/validation')

# out_dir = os.path.expanduser('/storage/Datasets/Geology-NicolasColtice/old_pytorch_records/train')
# out_dir = os.path.expanduser('/storage/Datasets/Geology-NicolasColtice/pytorch_records_new_thresh/test')
out_dir = os.path.expanduser(
    '/storage/Datasets/Geology-NicolasColtice/voronoi_geo_mix/train/0')
# out_dir = os.path.expanduser('~/data/geo_data_pkl/test')
# out_dir = os.path.expanduser('~/data/geo_data_pkl/validation')

p = GeoPickler(dataroot, out_dir, 256)

p.collect_all()

p.group_by_series()

groups = [(1, [10, 11, 12, 13, 18, 2, 3, 4, 5, 6, 9]),
          (2, [14, 15, 16, 17, 23]), (3, [19, 20, 21, 22])]

#thresholds = [0.045, 0.03, 0.03]

#thresholds = {str(folder): thresholds[i-1] for i, folders in groups for folder in folders }

thresholds = 1000

p.pickle_all(thresholds, 64, 10, verbose=True, skip_existing=True)
# p.pickle_all(1000, 100, 10, verbose=True, skip_existing=True)
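After a run like the one above, a quick sanity check is to load one of the pickles back and inspect it. The NNNNN.pkl naming and the key set follow the tests earlier in this listing; the exact path is only an assumption about what was written:

import os

import torch

# Hypothetical check: load the first pickle in out_dir and confirm it carries
# the expected keys and array shapes.
sample = torch.load(os.path.join(out_dir, '00000.pkl'))
print(sorted(sample.keys()))
print(sample['A'].shape, sample['A_DIV'].shape)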
Example #14
def test_pickling_preserves_folder_structure(mocker):
    dataroot, DIV_base, Vx_base, Vy_base = fake_geo_data(1)
    subfolder1, DIV_sub1, Vx_sub1, Vy_sub1 = fake_geo_data(2)
    subfolder2, DIV_sub2, Vx_sub2, Vy_sub2 = fake_geo_data(2)

    for data_group in [
            DIV_base, Vx_base, Vy_base, DIV_sub1, Vx_sub1, Vy_sub1,
            DIV_sub2, Vx_sub2, Vy_sub2
    ]:
        for i, data in enumerate(data_group):
            data_group[i] = np.interp(
                data, (np.min(data.ravel()), np.max(data.ravel())), [-1, 1])

    shutil.move(subfolder1, dataroot)
    shutil.move(subfolder2, dataroot)

    p = GeoPickler(dataroot, 'out_dir')

    p.collect_all()

    p.group_by_series()

    # Folders aren't always discovered in the same order, so compare sorted keys
    assert (sorted(p.folders.keys()) == sorted(
        ['', os.path.basename(subfolder1),
         os.path.basename(subfolder2)]))

    assert (len(p.get_folder_by_id(0)) == 1)
    assert (len(p.get_folder_by_id(1)) == 2)
    assert (len(p.get_folder_by_id(2)) == 2)

    dirs = list(p.folders.keys())

    mocker.patch('torch.save')

    # Set the minimum pixels-in-mask requirement to zero; otherwise the test can
    # fail at random when the generated data doesn't contain enough qualifying pixels
    p.pickle_all(1000, 4, 0)

    assert (len(torch.save.call_args_list) == 5)

    for args, kwargs in torch.save.call_args_list:
        series_number = args[0]['series_number']

        # It doesn't matter which order they're written in
        if args[1] == os.path.join('out_dir',
                                   '{:05}.pkl'.format(series_number)):
            assert ((args[0]['A_DIV'] == DIV_base[0]).all())
            assert ((args[0]['A_Vx'] == Vx_base[0]).all())
            assert ((args[0]['A_Vy'] == Vy_base[0]).all())
        elif args[1] == os.path.join('out_dir', os.path.basename(subfolder1),
                                     '{:05}.pkl'.format(series_number)):
            assert (series_number == 0 or series_number == 1)
            assert ((args[0]['A_DIV'] == DIV_sub1[series_number]).all())
            assert ((args[0]['A_Vx'] == Vx_sub1[series_number]).all())
            assert ((args[0]['A_Vy'] == Vy_sub1[series_number]).all())
        elif args[1] == os.path.join('out_dir', os.path.basename(subfolder2),
                                     '{:05}.pkl'.format(series_number)):
            assert (series_number == 0 or series_number == 1)
            assert ((args[0]['A_DIV'] == DIV_sub2[series_number]).all())
            assert ((args[0]['A_Vx'] == Vx_sub2[series_number]).all())
            assert ((args[0]['A_Vy'] == Vy_sub2[series_number]).all())
        else:
            assert False, args[1]