예제 #1
0
def test_build_one_hot(fake_geo_data):
    """Check that ``create_one_hot`` buckets divergence around a threshold.

    ``A_DIV`` is replaced with seeded random values scaled so that many lie
    beyond +/- ``threshold``. Each channel of the resulting one-hot array
    ``A`` is then checked against the DIV values at the pixels it marks:

    * channel 0 -> DIV above ``threshold``
    * channel 1 -> DIV between ``-threshold`` and ``threshold``
    * channel 2 -> DIV below ``-threshold``

    Exact boundary handling (values equal to +/- threshold) is not pinned;
    with continuous random input such values essentially never occur.
    """
    dataroot, DIV_datas, Vx_datas, Vy_datas = fake_geo_data

    p = GeoPickler(dataroot)

    p.collect_all()

    p.group_by_series()

    data_dict = p.get_data_dict(0, 0)

    threshold = 1000

    # Seeded for reproducibility; the large scale makes values on both sides
    # of +/- threshold common, so every channel is exercised.
    rng = np.random.RandomState(0)
    DIV = rng.randn(*data_dict['A_DIV'].shape) * 20000

    data_dict['A_DIV'] = DIV

    p.create_one_hot(data_dict, threshold)

    one_hot = data_dict['A']

    # Boolean-mask indexing picks out the DIV values at the pixels each
    # channel marks; every one of them must lie in that channel's range.
    assert np.all(DIV[one_hot[:, :, 0] != 0] > threshold)
    assert np.all(np.abs(DIV[one_hot[:, :, 1] != 0]) <= threshold)
    assert np.all(DIV[one_hot[:, :, 2] != 0] < -threshold)
예제 #2
0
def test_build_data_dict(fake_geo_data):
    """The data dict at index (0, 0) should hold the first fixture arrays.

    Each of ``A_DIV``, ``A_Vx`` and ``A_Vy`` is compared element-wise with
    element 0 of the corresponding fixture list.
    """
    dataroot, DIV_datas, Vx_datas, Vy_datas = fake_geo_data

    pickler = GeoPickler(dataroot)
    pickler.collect_all()
    pickler.group_by_series()

    data_dict = pickler.get_data_dict(0, 0)

    expected_by_key = (
        ('A_DIV', DIV_datas),
        ('A_Vx', Vx_datas),
        ('A_Vy', Vy_datas),
    )
    for key, datas in expected_by_key:
        assert (data_dict[key] == datas[0]).all()
예제 #3
0
def test_resized(fake_geo_data):
    """With ``row_height=18`` every grid should come out as 18 x 36.

    The one-hot array ``A`` is expected to have the same grid plus a
    trailing axis of length 3.
    """
    dataroot, DIV_datas, Vx_datas, Vy_datas = fake_geo_data

    pickler = GeoPickler(dataroot, row_height=18)
    pickler.collect_all()
    pickler.group_by_series()

    data_dict = pickler.get_data_dict(0, 0)
    pickler.create_one_hot(data_dict, 1000)

    grid_shape = (18, 36)
    for key in ('A_DIV', 'A_Vx', 'A_Vy'):
        assert data_dict[key].shape == grid_shape
    assert data_dict['A'].shape == grid_shape + (3,)
예제 #4
0
def test_normalises_continuous_data(fake_geo_data):
    """After normalisation each continuous field should span exactly [-1, 1].

    Checks that the maximum is 1.0 and the minimum is -1.0 for ``A_DIV``,
    ``A_Vx`` and ``A_Vy``.
    """
    dataroot, _, _, _ = fake_geo_data

    pickler = GeoPickler(dataroot)
    pickler.collect_all()
    pickler.group_by_series()

    data_dict = pickler.get_data_dict(0, 0)
    pickler.normalise_continuous_data(data_dict)

    for key in ('A_DIV', 'A_Vx', 'A_Vy'):
        flat = data_dict[key].ravel()
        assert np.max(flat) == 1.0
        assert np.min(flat) == -1.0
예제 #5
0
def test_mask_params_stored_in_dict(fake_geo_data):
    """``get_mask_loc`` should record its mask parameters in the data dict.

    The size and minimum-pixel arguments passed in must be readable back
    under the ``mask_size`` and ``min_pix_in_mask`` keys.
    """
    dataroot, DIV_datas, Vx_datas, Vy_datas = fake_geo_data

    pickler = GeoPickler(dataroot)
    pickler.collect_all()
    pickler.group_by_series()

    data_dict = pickler.get_data_dict(0, 0)

    # Replace the divergence field with large random values before one-hot
    # encoding, as the mask computation runs on the encoded array.
    grid_shape = data_dict['A_DIV'].shape
    data_dict['A_DIV'] = np.random.randn(*grid_shape) * 20000

    pickler.create_one_hot(data_dict, 1000)

    mask_size, min_pix_in_mask = 4, 6
    pickler.get_mask_loc(data_dict, mask_size, min_pix_in_mask)

    assert data_dict['mask_size'] == mask_size
    assert data_dict['min_pix_in_mask'] == min_pix_in_mask
예제 #6
0
def test_mask_location(fake_geo_data):
    """Check that ``get_mask_loc`` records exactly the qualifying windows.

    A window is the ``mask_size`` x ``mask_size`` patch of the one-hot array
    ``A`` whose top-left corner is ``(y, x)``. The test expects ``(y, x)`` to
    appear in ``mask_locs`` if and only if both channel 0 and channel 2 have
    at least ``min_pix_in_mask`` set pixels inside that window.
    """
    dataroot, DIV_datas, Vx_datas, Vy_datas = fake_geo_data

    mask_size = 4
    min_pix_in_mask = 6

    p = GeoPickler(dataroot)

    p.collect_all()

    p.group_by_series()

    data_dict = p.get_data_dict(0, 0)

    # Seeded so the set of qualifying windows (and the non-empty check
    # below) is reproducible from run to run.
    rng = np.random.RandomState(0)
    DIV = rng.randn(*data_dict['A_DIV'].shape) * 20000

    data_dict['A_DIV'] = DIV

    p.create_one_hot(data_dict, 1000)

    p.get_mask_loc(data_dict, mask_size, min_pix_in_mask)

    one_hot = data_dict['A']

    mask_loc = data_dict['mask_locs']

    assert len(mask_loc) > 0

    # NOTE(review): these ranges stop one short of the last window start
    # (shape - mask_size). Confirm whether get_mask_loc is meant to consider
    # that position before extending the loops with + 1.
    for x in range(one_hot.shape[1] - mask_size):
        for y in range(one_hot.shape[0] - mask_size):
            window = one_hot[y:y + mask_size, x:x + mask_size]
            sum1 = np.sum(window[:, :, 0])
            sum2 = np.sum(window[:, :, 2])

            if (y, x) in mask_loc:
                assert sum1 >= min_pix_in_mask
                assert sum2 >= min_pix_in_mask
            else:
                assert sum1 < min_pix_in_mask or sum2 < min_pix_in_mask