Example #1
0
def parse_gin(restore_dir):
    """Parse gin config from --gin_file, --gin_param, and the model directory.

    Sources are parsed in increasing order of precedence (later calls
    override earlier ones):
      1. Optimization defaults: ``optimization/base.gin``, or
         ``optimization/base_tpu.gin`` when ``--tpu`` is set.
      2. Evaluation defaults: ``eval/basic.gin``.
      3. The latest operative config found in ``restore_dir``, if any.
      4. User config files (``--gin_file``) and bindings (``--gin_param``).

    Any gs:// paths among the operative config and user config files are
    first made local via ``cloud.make_file_paths_local`` (relative to
    ``GIN_PATH``). Unknown configurables are skipped for sources 3 and 4.

    Args:
      restore_dir: Directory searched for a previously saved operative config.
    """
    # Add user folders to the gin search path.
    for gin_search_path in [GIN_PATH] + FLAGS.gin_search_path:
        gin.add_config_file_search_path(gin_search_path)

    # Parse gin configs, later calls override earlier ones.
    with gin.unlock_config():
        # Optimization defaults.
        use_tpu = bool(FLAGS.tpu)
        opt_default = 'base_tpu.gin' if use_tpu else 'base.gin'
        gin.parse_config_file(os.path.join('optimization', opt_default))
        eval_default = 'eval/basic.gin'
        gin.parse_config_file(eval_default)

        # Load operative_config if it exists (model has already trained).
        # NOTE(review): a "latest file" lookup may signal "nothing saved yet"
        # by raising FileNotFoundError or returning an empty/None path rather
        # than a non-existent one; all of these mean "no operative config",
        # which is expected on a fresh run and must not abort config parsing.
        try:
            operative_config = train_util.get_latest_operative_config(
                restore_dir)
        except FileNotFoundError:
            operative_config = None

        if operative_config and tf.io.gfile.exists(operative_config):
            logging.info('Using operative config: %s', operative_config)
            operative_config = cloud.make_file_paths_local(
                operative_config, GIN_PATH)
            gin.parse_config_file(operative_config, skip_unknown=True)
        else:
            logging.info('No operative config found in %s', restore_dir)

        # User gin config and user hyperparameters from flags.
        gin_file = cloud.make_file_paths_local(FLAGS.gin_file, GIN_PATH)
        gin.parse_config_files_and_bindings(gin_file,
                                            FLAGS.gin_param,
                                            skip_unknown=True)
Example #2
0
 def test_single_path_in_list_handling(self,
                                       download_from_gstorage_function):
     """Tests that function returns a single-element list if given one.

     A one-element list holding a gs:// path should trigger exactly one
     download of that path and come back as a list (not a bare string)
     containing only the file's basename.
     """
     gs_path = 'gs://bucket-name/bucket/dir/some_file.gin'
     path = cloud.make_file_paths_local([gs_path], 'gin/search/path')
     # Pin which file was fetched, not merely that some download happened.
     download_from_gstorage_function.assert_called_once_with(gs_path, mock.ANY)
     self.assertNotIsInstance(path, str)
     self.assertListEqual(path, ['some_file.gin'])
Example #3
0
 def test_more_paths_in_list_handling(self,
                                      download_from_gstorage_function):
     """Checks a list mixing gs:// and local paths.

     Exactly the two gs:// entries must be downloaded, in input order, and the
     result must preserve input order with each gs:// entry replaced by its
     basename while the local entry is passed through unchanged.
     """
     first_remote = 'gs://bucket-name/bucket/dir/first_file.gin'
     second_remote = 'gs://bucket-name/bucket/dir/second_file.gin'
     mixed_inputs = [first_remote, 'local_file.gin', second_remote]

     local_paths = cloud.make_file_paths_local(mixed_inputs, 'gin/search/path')

     expected_downloads = [
         mock.call(first_remote, mock.ANY),
         mock.call(second_remote, mock.ANY),
     ]
     self.assertEqual(download_from_gstorage_function.call_count, 2)
     download_from_gstorage_function.assert_has_calls(expected_downloads)

     expected_paths = ['first_file.gin', 'local_file.gin', 'second_file.gin']
     self.assertListEqual(local_paths, expected_paths)
Example #4
0
 def test_single_local_path_handling(self, download_from_gstorage_function):
     """Checks that a bare local path is returned untouched.

     No download may be attempted, and the very same string comes back.
     """
     local_path = 'local_file.gin'
     result = cloud.make_file_paths_local(local_path, 'gin/search/path')

     download_from_gstorage_function.assert_not_called()
     self.assertEqual(result, local_path)
Example #5
0
 def test_single_path_handling(self, download_from_gstorage_function):
     """Tests that function returns a single value if given single value.

     A bare gs:// path (not wrapped in a list) should trigger exactly one
     download of that path and come back as a plain string holding the
     file's basename.
     """
     gs_path = 'gs://bucket-name/bucket/dir/some_file.gin'
     path = cloud.make_file_paths_local(gs_path, 'gin/search/path')
     # Pin which file was fetched, not merely that some download happened.
     download_from_gstorage_function.assert_called_once_with(gs_path, mock.ANY)
     self.assertEqual(path, 'some_file.gin')