Ejemplo n.º 1
0
def test_compute_metrics():
    """Check that ``compute_metrics`` returns a result of shape (11, 3).

    The reference parameters and their names are read with
    ``open_param_file``; the predicted data is taken from the first HDU of
    the ``*_common.fits`` file. Both files live in
    ``res_experiments/predictions`` under the project root, and the same
    directory is used for the ``test.csv`` save path handed to
    ``compute_metrics``.
    """
    predictions_dir = get_project_root() / 'res_experiments' / 'predictions'
    refer_path = predictions_dir / '20170905_030404.fits'
    pred_path = predictions_dir / '20170905_030404_common.fits'
    save_path = predictions_dir / 'test.csv'

    refer, names = open_param_file(refer_path)

    # Use the FITS file as a context manager so the handle is released even
    # if compute_metrics raises; the data is consumed while the file is open.
    with fits.open(pred_path) as predicted:
        predicted_data = predicted[0].data
        df = compute_metrics(refer, predicted_data, names, save_path)

    assert df.shape == (11, 3)
Ejemplo n.º 2
0
def test_compute_mean_spectrum():
    """Smoke test for ``compute_mean_spectrum`` on ``data/parameters_base.fits``.

    Calls the function with ``batch_size=1000`` and ``nbatches=2``. There are
    no assertions: the test passes when the call completes and its result
    unpacks into exactly two values.
    """
    params_file = get_project_root() / 'data' / 'parameters_base.fits'
    options = {'batch_size': 1000, 'nbatches': 2}
    mean_spectrum, std_spectrum = compute_mean_spectrum(params_file, **options)
Ejemplo n.º 3
0
 def sample_from_database(self):
     """Return the first item of a ``SpectrumDataset`` in 'database' mode.

     The dataset is built from ``data/small_parameters_base.fits`` located
     under the project root.
     """
     params_file = get_project_root() / 'data' / 'small_parameters_base.fits'
     dataset = SpectrumDataset(param_path=params_file, source='database')
     return dataset[0]
Ejemplo n.º 4
0
 def test_init_refer_dataset(self):
     """Check a ``SpectrumDataset`` built with ``source='refer'``.

     The dataset is created from ``data/hinode_source/20140926_170005.fits``
     under the project root. The test verifies that ``param_source`` is a
     list, that the dataset length is 446976, and that the first sample has
     a float at ``X[1]``, 224 elements in ``X[0]`` and 11 elements in ``Y``.
     """
     filename = get_project_root() / 'data' / 'hinode_source' / '20140926_170005.fits'
     sobj = SpectrumDataset(param_path=filename, source='refer')
     assert isinstance(sobj.param_source, list)

     # Index the dataset once and reuse the sample for every check instead
     # of rebuilding it for each assertion.
     sample = sobj[0]
     assert isinstance(sample['X'][1], float)
     # Use the built-in len() rather than calling the dunder directly.
     assert len(sobj) == 446976
     assert sample['X'][0].size == 224
     assert sample['Y'].size == 11
Ejemplo n.º 5
0
 def test_init_dataset(self):
     """Check a ``PregenSpectrumDataset`` built from the first 10 parameter rows.

     The rows come from the first HDU of ``data/small_parameters_base.fits``
     under the project root, and the dataset uses the transform returned by
     ``mlp_batch_rescale()``. The first sample must have ``X[1]`` with a
     leading dimension of 1 and ``Y`` with a leading dimension of 11.
     """
     filename = get_project_root() / 'data' / 'small_parameters_base.fits'
     transform = mlp_batch_rescale()

     # Keep the FITS file open only while its data is in use; the context
     # manager closes the handle even when an assertion fails.
     with fits.open(filename) as hdul:
         param_array = hdul[0].data[:10]
         sobj = PregenSpectrumDataset(data_arr=param_array, transform=transform)
         sample = sobj[0]
         assert sample['X'][1].shape[0] == 1
         assert sample['Y'].shape[0] == 11
Ejemplo n.º 6
0
 def test_init_database_dataset(self):
     """Check a ``SpectrumDataset`` built with ``source='database'``.

     The dataset is created from ``data/small_parameters_base.fits`` under
     the project root. The test verifies that ``param_source`` has 11
     columns, that the dataset length equals the number of rows in
     ``param_source``, and that the first sample holds a 224-element
     ndarray at ``X[0]``, a float at ``X[1]`` and an 11-element ndarray at
     ``Y``.
     """
     filename = get_project_root() / 'data' / 'small_parameters_base.fits'
     sobj = SpectrumDataset(param_path=filename, source='database')
     assert sobj.param_source.shape[1] == 11

     # Index the dataset once and reuse the sample for every check instead
     # of rebuilding it for each assertion.
     sample = sobj[0]
     assert isinstance(sample['X'][1], float)
     assert isinstance(sample['X'][0], np.ndarray)
     assert isinstance(sample['Y'], np.ndarray)
     # Use the built-in len() rather than calling the dunder directly.
     assert len(sobj) == sobj.param_source.shape[0]
     assert sample['X'][0].size == 224
     assert sample['Y'].size == 11
Ejemplo n.º 7
0
def test_create_small_dataset():
    """Check that ``create_small_dataset`` writes a file with 10000 rows.

    Reads ``data/parameters_base.fits`` and writes
    ``data/small_parameters_base.fits`` (both under the project root) with
    ``size=10000``, then confirms that the first HDU of the written file has
    10000 entries along its first axis.
    """
    data_dir = get_project_root() / 'data'
    filename = data_dir / 'parameters_base.fits'
    savename = data_dir / 'small_parameters_base.fits'
    create_small_dataset(filename, savename, size=10000)

    # Context manager closes the FITS handle even if the assertion fails.
    with fits.open(savename) as hdul:
        assert hdul[0].data.shape[0] == 10000
Ejemplo n.º 8
0
def test_download_from_google_disc():
    """Check that the downloaded file opens as FITS with array data.

    Downloads the file with the hard-coded id to
    ``data/small_parameters_base.fits`` under the project root (the
    destination is passed as a string) and verifies that the first HDU's
    data is a NumPy array.
    """
    file_id = '19jkSXHxAPWZvfgo5oxSmvEagme5YLY33'
    dest_path = str(get_project_root() / 'data' / 'small_parameters_base.fits')
    download_from_google_disc(fileid=file_id, dest=dest_path)

    # Context manager closes the FITS handle even if the assertion fails.
    with fits.open(dest_path) as hdul:
        # isinstance is the idiomatic type check and also accepts ndarray
        # subclasses, unlike an exact type() comparison.
        assert isinstance(hdul[0].data, np.ndarray)
Ejemplo n.º 9
0
def test_get_project_root():
    """``get_project_root`` must return the grandparent directory of this file."""
    actual_root = get_project_root()
    this_file = Path(__file__)
    expected_root = this_file.parent.parent
    assert expected_root == actual_root
Ejemplo n.º 10
0
 def test_init_dataset(self):
     """Smoke test: build a ``SpectrumDataset`` from the first 10 parameter rows.

     The rows come from the first HDU of ``data/small_parameters_base.fits``
     under the project root. The test passes when construction completes
     without raising; nothing about the resulting dataset is asserted.
     """
     filename = get_project_root() / 'data' / 'small_parameters_base.fits'

     # Keep the FITS file open only while its data is in use; the context
     # manager closes the handle even if construction raises.
     with fits.open(filename) as hdul:
         param_array = hdul[0].data[:10]
         SpectrumDataset(data_arr=param_array)
Ejemplo n.º 11
0
def test_open_param_file():
    """Check the array shape and name count returned by ``open_param_file``.

    Opens ``res_experiments/predictions/20170905_030404.fits`` under the
    project root and expects a (512, 485, 11) array together with 11 names.
    """
    expected_shape = (512, 485, 11)
    expected_name_count = 11

    fits_path = get_project_root() / 'res_experiments' / 'predictions' / '20170905_030404.fits'
    refer, names = open_param_file(fits_path)

    assert refer.shape == expected_shape
    assert len(names) == expected_name_count