Code example #1
    def test_load_config(self):
        # Test that load_config returns the config and that the config check
        # fails when a required section is missing.
        cfg = model_trainer.load_config(self.root_path + 'test.json')

        del cfg['validation_params']
        temp_cfg_path = self.root_path + 'temp.json'

        with open(temp_cfg_path, 'w') as json_f:
            json.dump(cfg, json_f, indent=4)

        with self.assertRaises(RuntimeError):
            model_trainer.load_config(temp_cfg_path)

        os.remove(temp_cfg_path)
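The RuntimeError above is triggered because load_config checks that the required top-level sections of the JSON config are present. A minimal sketch of that kind of check (the function name and the required keys here are illustrative assumptions, not ovejero's actual implementation):

import json


def load_config_sketch(config_path,
                       required_keys=('training_params', 'validation_params')):
    # Sketch only: read the JSON config and verify that the top-level
    # sections the rest of the pipeline relies on are present.
    with open(config_path, 'r') as json_f:
        cfg = json.load(json_f)
    for key in required_keys:
        if key not in cfg:
            raise RuntimeError('Config is missing required section %s' % key)
    return cfg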
Code example #2
    def setUp(self):
        # Set up the paths to the test config file and the test baobab config
        # used to initialize the class.
        self.root_path = os.path.dirname(
            os.path.abspath(__file__)) + '/test_data/'
        self.config_path = self.root_path + 'test.json'
        self.baobab_cfg_path = self.root_path + 'test_baobab_cfg.py'

        # Initialize the config
        self.cfg = model_trainer.load_config(self.config_path)

        # Also initialize the baobab config
        self.baobab_cfg = configs.BaobabConfig.from_file(self.baobab_cfg_path)

        # A few values from the bnn_inference tests that need to be used
        # again here.
        self.lens_params = [
            'external_shear_gamma_ext', 'external_shear_psi_ext',
            'lens_mass_center_x', 'lens_mass_center_y', 'lens_mass_e1',
            'lens_mass_e2', 'lens_mass_gamma', 'lens_mass_theta_E'
        ]
        self.normalization_constants_path = self.root_path + 'norms.csv'
        self.tf_record_path = self.root_path + 'tf_record_test'

        # Create the normalization file that would have been made
        # during training.
        self.final_params = [
            'external_shear_g1', 'external_shear_g2', 'lens_mass_center_x',
            'lens_mass_center_y', 'lens_mass_e1', 'lens_mass_e2',
            'lens_mass_gamma', 'lens_mass_theta_E_log'
        ]
        model_trainer.prepare_tf_record(self.cfg, self.root_path,
                                        self.tf_record_path, self.final_params,
                                        'train')
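setUp leaves norms.csv and the tf_record file on disk; the matching cleanup is not part of this excerpt. A minimal tearDown sketch under that assumption (whether the real suite cleans up here or inside each test is not shown):

    def tearDown(self):
        # Sketch only: remove the files generated in setUp so repeated runs
        # start from a clean test_data directory.
        for path in [self.normalization_constants_path, self.tf_record_path]:
            if os.path.exists(path):
                os.remove(path)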
Code example #3
    def test_get_normed_pixel_scale(self):
        # Test that get_normed_pixel_scale rescales the pixel_scale as we
        # would expect.
        cfg = model_trainer.load_config(self.root_path + 'test.json')
        # The original pixel scale
        pixel_scale = 0.051
        # Test if normalizing the lens parameters works correctly.
        normalized_param_path = self.root_path + 'normed_metadata.csv'
        normalization_constants_path = self.root_path + 'norms.csv'
        train_or_test = 'train'
        lens_params = [
            'external_shear_gamma_ext', 'external_shear_psi_ext',
            'lens_mass_center_x', 'lens_mass_center_y', 'lens_mass_e1',
            'lens_mass_e2', 'lens_mass_gamma', 'lens_mass_theta_E'
        ]
        lens_params_path = self.root_path + 'metadata.csv'
        data_tools.normalize_lens_parameters(lens_params,
                                             lens_params_path,
                                             normalized_param_path,
                                             normalization_constants_path,
                                             train_or_test=train_or_test)

        # New pixel scale
        normed_pixel_scale = model_trainer.get_normed_pixel_scale(
            cfg, pixel_scale)

        lens_params_csv = pd.read_csv(lens_params_path, index_col=None)
        norm_params_csv = pd.read_csv(normalized_param_path, index_col=None)

        self.assertAlmostEqual(
            np.std(lens_params_csv['lens_mass_center_x'] / pixel_scale -
                   norm_params_csv['lens_mass_center_x'] /
                   normed_pixel_scale['lens_mass_center_x']), 0)
        self.assertAlmostEqual(
            np.std(lens_params_csv['lens_mass_center_y'] / pixel_scale -
                   norm_params_csv['lens_mass_center_y'] /
                   normed_pixel_scale['lens_mass_center_y']), 0)

        # Clean up the files now that we're done
        os.remove(normalized_param_path)
        os.remove(normalization_constants_path)
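The two assertions above encode the identity the test relies on: if a center parameter is z-score normalized, then dividing the raw values by pixel_scale and the normalized values by the normalized pixel scale should differ only by a constant offset, which forces normed_pixel_scale = pixel_scale / std. A standalone numeric check of that identity (assuming z-score normalization; this is not ovejero code):

import numpy as np

pixel_scale = 0.051
rng = np.random.default_rng(0)
center_x = rng.normal(loc=0.0, scale=0.1, size=1000)

# z-score normalize the parameter and rescale the pixel scale accordingly.
normed_center_x = (center_x - center_x.mean()) / center_x.std()
normed_pixel_scale = pixel_scale / center_x.std()

# The difference is a constant shift, so its standard deviation vanishes.
residual = center_x / pixel_scale - normed_center_x / normed_pixel_scale
assert np.isclose(np.std(residual), 0.0)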
Code example #4
    def test_model_loss_builder_diag(self):
        # Test that the model and loss returned from model_loss_builder
        # agree with what is expected.
        cfg = model_trainer.load_config(self.root_path + 'test.json')
        cfg['training_params']['dropout_type'] = 'concrete'
        final_params = cfg['training_params']['final_params']
        num_params = len(final_params)

        tf.keras.backend.clear_session()
        gc.collect()

        cfg['training_params']['bnn_type'] = 'diag'
        model, loss = model_trainer.model_loss_builder(cfg)
        y_true = np.ones((1, num_params))
        y_pred = np.ones((1, 2 * num_params))
        yptf = tf.constant(y_pred, dtype=tf.float32)
        yttf = tf.constant(y_true, dtype=tf.float32)

        # Check that the loss function has the right dimensions. More rigorous
        # tests of the loss function can be found in test_bnn_alexnet.
        loss(yttf, yptf)
        self.assertEqual(len(model.layers), 13)
        self.assertEqual(model.layers[-1].output_shape[-1], y_pred.shape[-1])
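The diag model outputs 2 * num_params values per example: a predicted mean and a log-variance for each parameter, which is why y_pred has shape (1, 2 * num_params) above. A minimal diagonal Gaussian negative log-likelihood illustrating that split (a sketch, up to an additive constant, not ovejero's actual loss implementation):

import numpy as np


def diag_gaussian_nll_sketch(y_true, y_pred):
    # Sketch only: split the network output into per-parameter means and
    # log-variances, then evaluate a diagonal Gaussian negative log-likelihood.
    num_params = y_true.shape[-1]
    mean = y_pred[..., :num_params]
    log_var = y_pred[..., num_params:]
    return 0.5 * np.sum(
        np.exp(-log_var) * (y_true - mean) ** 2 + log_var, axis=-1)


print(diag_gaussian_nll_sketch(np.ones((1, 8)), np.ones((1, 16))).shape)  # (1,)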
Code example #5
# First specify the root and config paths
root_path = os.getcwd()[:-20]
config_path = args.config_path

# We also need the paths to the baobab configs for the interim and target omega
interim_baobab_omega_path = os.path.join(
    root_path, 'configs/baobab_configs/train_diagonal.py')
target_ovejero_omega_path = args.target_ovejero_omega_path
target_baobab_omega_path = args.target_baobab_omega_path

test_dataset_path = args.test_dataset_path
test_dataset_tf_record_path = args.test_dataset_tf_record_path

# Check that the config has what you need
cfg = model_trainer.load_config(config_path)

# If we're using the empirical config, we need the transformation dictionary
if 'empirical' in test_dataset_path:
    train_to_test_param_map = dict(
        orig_params=['lens_mass_e1', 'lens_mass_e2'],
        transform_func=ellipticity2phi_q,
        new_params=['lens_mass_phi', 'lens_mass_q'])
    n_walkers = 200
else:
    train_to_test_param_map = None
    n_walkers = 50
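The transform_func above maps the training ellipticities (lens_mass_e1, lens_mass_e2) to the (phi, q) parametrization used for the empirical test set. Assuming ellipticity2phi_q follows the usual lenstronomy-style convention, the conversion looks roughly like this sketch:

import numpy as np


def ellipticity2phi_q_sketch(e1, e2):
    # Sketch only: convert complex ellipticity components to a position
    # angle phi and an axis ratio q, clipping extreme ellipticities.
    phi = 0.5 * np.arctan2(e2, e1)
    c = np.minimum(np.sqrt(e1 ** 2 + e2 ** 2), 0.9999)
    q = (1.0 - c) / (1.0 + c)
    return phi, q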


# Correct any path issues.
def recursive_str_checker(cfg_dict):