Example #1
0
 def generate_batch_spectrum(self, param_vector, noise=True):
     """Simulate spectra for a batch of parameters and pack them for a model.

     Args:
         param_vector: 2-D array of parameters. Columns 6 and 7 are read
             here to build the continuum term; the whole array is also
             forwarded to ``me_model`` and ``normalize_output``.
         noise: Forwarded to ``me_model`` as ``with_noise``.

     Returns:
         dict with
           * ``'X'``: ``[rescaled, cont]`` -- the rescaled spectrum and a
             ``(batch, 1)`` continuum column, both float tensors on
             ``self.device``;
           * ``'Y'``: the normalized parameters as a float tensor on
             ``self.device``.

     Raises:
         NotImplementedError: if ``'rescale'`` is not part of
             ``self.hps.transform_type``.
     """
     # Fail fast on unsupported transforms. Previously the exception object
     # was created but never raised, so execution fell through and crashed
     # later with a NameError on the undefined ``rescaled``.
     if 'rescale' not in self.hps.transform_type:
         raise NotImplementedError('Only rescale transform')

     line_vec = (6302.5, 2.5, 1)
     # 56 sample points, expressed as 1000 * (offset from line_vec[0]).
     line_arg = 1000 * (np.linspace(6302.0692255, 6303.2544205, 56) -
                        line_vec[0])
     spectrum = me_model(param_vector,
                         line_arg,
                         line_vec,
                         with_ff=True,
                         with_noise=noise)

     # Continuum term built from parameter columns 6 and 7, scaled by the
     # per-sample spectrum maximum and the configured continuum scale.
     cont = param_vector[:, 6] + line_vec[2] * param_vector[:, 7]
     # 224 == 56 * 4: each sample is flattened to one row before the max.
     # NOTE(review): presumably 56 points x 4 profiles -- confirm against
     # the shape returned by me_model.
     cont = cont * np.amax(spectrum.reshape(-1, 224),
                           axis=1) / self.hps.cont_scale
     cont = torch.from_numpy(cont.reshape(-1, 1)).float().to(self.device)

     y = normalize_output(param_vector,
                          mode=self.hps.mode,
                          logB=self.hps.logB)
     y = torch.from_numpy(y).float().to(self.device)

     # Multiply by the four configured factors: the last axis is swapped to
     # the front so the (4, 1, 1) factor array broadcasts, then swapped back.
     rescaled = (np.swapaxes(spectrum, 0, 2) *
                 np.array(self.hps.factors).reshape(4, 1, 1)).swapaxes(
                     0, 2)
     if 'mlp' in self.hps.transform_type:
         # The MLP variant takes one flat 224-long row per sample
         # (Fortran-order flattening).
         rescaled = rescaled.reshape(-1, 224, order='F')
     rescaled = torch.from_numpy(rescaled).float().to(self.device)

     data = {'X': [rescaled, cont], 'Y': y}
     return data
 def test_normalize_output_array_no_angle(self):
     """Range-normalized parameters from the FITS base stay within [0, 1].

     Runs ``normalize_output`` without ``angle_transformation``. The name
     carries the ``_no_angle`` suffix because another test in this file is
     called ``test_normalize_output_array_angle``; two methods sharing one
     name in the same scope shadow each other, so only the later one would
     ever be collected and run.
     """
     project_path = get_project_root()
     filename = os.path.join(project_path, 'data/small_parameters_base.fits')
     # Context manager closes the FITS handle even when an assertion fails;
     # the data is used inside the block so it is read while the file is open.
     with fits.open(filename) as hdul:
         params = hdul[0].data[5:11]
         sample = normalize_output(params, mode='range', logB=True)
         assert sample.min() >= 0
         assert sample.max() <= 1
 def test_normalize_output_array_angle(self):
     """Check range normalization with the angle transformation enabled.

     Reads rows 5..10 of the primary HDU of the small parameter base,
     normalizes them with ``angle_transformation=True`` and verifies that:
       * every value lies in [0, 1);
       * element ``[0, 1]`` equals the sine of the corresponding input
         (taken in degrees) to within 1% relative tolerance.
     """
     project_path = get_project_root()
     filename = os.path.join(project_path, 'data/small_parameters_base.fits')
     # Context manager closes the FITS handle even when an assertion fails;
     # ``params`` is still needed by the last assertion, so all checks stay
     # inside the block while the file is open.
     with fits.open(filename) as hdul:
         params = hdul[0].data[5:11]
         sample = normalize_output(params, mode='range', logB=True,
                                   angle_transformation=True)
         assert sample.min() >= 0
         assert sample.max() < 1
         expected = np.sin(params[0, 1] * np.pi / 180)
         assert sample[0, 1] == pytest.approx(expected, 0.01)
 def test_normalize_output_range_angle(self, sample_from_database):
     """Check range normalization of a database sample with angle transform.

     Args:
         sample_from_database: fixture providing a mapping whose ``'Y'``
             entry holds the parameters to normalize.

     Verifies that element ``[0][1]`` is approximately 1 (10% relative
     tolerance) and that every normalized value lies strictly inside (0, 1).
     """
     y = sample_from_database['Y']
     sample = normalize_output(y, mode='range', logB=True,
                               angle_transformation=True)
     assert sample[0][1] == pytest.approx(1, 0.1)
     # Reduce over all elements with the array's own min/max. The builtins
     # iterate only the first axis of a multi-dimensional array, yielding a
     # row (not a scalar) whose truth value is ambiguous.
     assert sample.min() > 0
     assert sample.max() < 1
 def test_normalize_output_range_no_angle(self, sample_from_database):
     """Check range normalization of a database sample without angle transform.

     Args:
         sample_from_database: fixture providing a mapping whose ``'Y'``
             entry holds the parameters to normalize.

     Verifies that every normalized value lies strictly inside (0, 1).
     """
     y = sample_from_database['Y']
     sample = normalize_output(y, mode='range', logB=True)
     # Reduce over all elements with the array's own min/max. The builtins
     # iterate only the first axis of a multi-dimensional array, yielding a
     # row (not a scalar) whose truth value is ambiguous.
     assert sample.min() > 0
     assert sample.max() < 1