def test_mask_params_stored_in_dict(fake_geo_data):
    """``get_mask_loc`` must record the size and threshold it was given.

    After building a one-hot encoding from a noisy divergence field, calling
    ``get_mask_loc(d, 4, 6)`` should leave ``d['mask_size'] == 4`` and
    ``d['min_pix_in_mask'] == 6``.
    """
    root, _div, _vx, _vy = fake_geo_data

    pickler = GeoPickler(root)
    pickler.collect_all()
    pickler.group_by_series()

    sample = pickler.get_data_dict(0, 0)
    noise_shape = sample['A_DIV'].shape
    # Overwrite the divergence with large-amplitude Gaussian noise.
    sample['A_DIV'] = np.random.randn(*noise_shape) * 20000

    pickler.create_one_hot(sample, 1000)
    pickler.get_mask_loc(sample, 4, 6)

    expected = {'mask_size': 4, 'min_pix_in_mask': 6}
    for key, value in expected.items():
        assert sample[key] == value
def test_mask_location(fake_geo_data):
    """Check that ``get_mask_loc`` flags exactly the qualifying windows.

    A ``mask_size`` x ``mask_size`` window anchored at ``(y, x)`` must appear
    in ``data_dict['mask_locs']`` if and only if one-hot channels 0 and 2
    each contain at least ``min_pix`` set pixels inside that window.
    """
    dataroot, _, _, _ = fake_geo_data
    mask_size = 4
    min_pix = 6

    p = GeoPickler(dataroot)
    p.collect_all()
    p.group_by_series()
    data_dict = p.get_data_dict(0, 0)

    # Replace the divergence field with large-amplitude noise. A local, seeded
    # RNG keeps the test reproducible without mutating numpy's global state.
    rng = np.random.RandomState(0)
    data_dict['A_DIV'] = rng.randn(*data_dict['A_DIV'].shape) * 20000

    p.create_one_hot(data_dict, 1000)
    p.get_mask_loc(data_dict, mask_size, min_pix)

    one_hot = data_dict['A']
    mask_loc = data_dict['mask_locs']
    assert len(mask_loc) > 0

    # NOTE(review): the ranges stop at shape - mask_size, so the last possible
    # anchor row/column is not checked. Kept as in the original test; confirm
    # against get_mask_loc's own scan range before widening by one.
    for x in range(one_hot.shape[1] - mask_size):
        for y in range(one_hot.shape[0] - mask_size):
            window = one_hot[y:y + mask_size, x:x + mask_size]
            # Compute each channel sum once instead of re-slicing per assert.
            sum0 = np.sum(window[:, :, 0])
            sum2 = np.sum(window[:, :, 2])
            if (y, x) in mask_loc:
                assert sum0 >= min_pix
                assert sum2 >= min_pix
            else:
                assert sum0 < min_pix or sum2 < min_pix