def parse_gin(restore_dir):
  """Parse gin config from --gin_file, --gin_param, and the model directory.

  Configs are parsed in increasing order of precedence, so each later parse
  can override bindings set by an earlier one:
    1. Optimization defaults: 'optimization/base_tpu.gin' when FLAGS.tpu is
       truthy, otherwise 'optimization/base.gin'.
    2. Evaluation defaults: 'eval/basic.gin'.
    3. The latest operative config located via
       `train_util.get_latest_operative_config(restore_dir)`, only if that
       file exists (i.e. the model has already trained).
    4. The user's --gin_file files and --gin_param bindings.
  Steps 3 and 4 use `skip_unknown=True`, so unknown configurables in those
  sources are skipped rather than raising.

  Args:
    restore_dir: Directory searched for a previously saved operative config.
  """
  # GIN_PATH is registered first, followed by any user-provided folders.
  search_folders = [GIN_PATH] + FLAGS.gin_search_path
  for folder in search_folders:
    gin.add_config_file_search_path(folder)

  with gin.unlock_config():
    # 1. Optimization defaults depend on whether a TPU was requested.
    if bool(FLAGS.tpu):
      optimization_file = 'base_tpu.gin'
    else:
      optimization_file = 'base.gin'
    gin.parse_config_file(os.path.join('optimization', optimization_file))

    # 2. Evaluation defaults.
    gin.parse_config_file('eval/basic.gin')

    # 3. Restore bindings from a previous run when an operative config exists.
    saved_config = train_util.get_latest_operative_config(restore_dir)
    if tf.io.gfile.exists(saved_config):
      logging.info('Using operative config: %s', saved_config)
      # Remote (e.g. cloud storage) paths are converted to local ones first.
      local_config = cloud.make_file_paths_local(saved_config, GIN_PATH)
      gin.parse_config_file(local_config, skip_unknown=True)

    # 4. User files and command-line bindings take final precedence.
    user_files = cloud.make_file_paths_local(FLAGS.gin_file, GIN_PATH)
    gin.parse_config_files_and_bindings(
        user_files, FLAGS.gin_param, skip_unknown=True)
def test_single_path_in_list_handling(self, download_from_gstorage_function):
  """Tests that function returns a single-element list if given one.

  The one gs:// entry must be downloaded exactly once, and the result must
  stay a list (not collapse to a bare string) holding the local file name.
  """
  remote_path = 'gs://bucket-name/bucket/dir/some_file.gin'
  path = cloud.make_file_paths_local([remote_path], 'gin/search/path')
  # Verify *which* file was fetched, not merely that one download happened;
  # the second argument (download destination) is not pinned here.
  download_from_gstorage_function.assert_called_once_with(
      remote_path, mock.ANY)
  self.assertNotIsInstance(path, str)
  self.assertListEqual(path, ['some_file.gin'])
def test_more_paths_in_list_handling(self, download_from_gstorage_function):
  """Tests that function handles both local and gstorage paths in one list.

  Only the two gs:// entries should trigger a download (in input order).
  The returned list keeps the input order, with each gs:// entry replaced by
  its file name and the local entry left unchanged.
  """
  first_remote = 'gs://bucket-name/bucket/dir/first_file.gin'
  second_remote = 'gs://bucket-name/bucket/dir/second_file.gin'
  mixed_inputs = [first_remote, 'local_file.gin', second_remote]

  local_paths = cloud.make_file_paths_local(mixed_inputs, 'gin/search/path')

  # Exactly two downloads, and they must be the two remote files.
  self.assertEqual(download_from_gstorage_function.call_count, 2)
  expected_calls = [
      mock.call(first_remote, mock.ANY),
      mock.call(second_remote, mock.ANY),
  ]
  download_from_gstorage_function.assert_has_calls(expected_calls)

  expected_paths = ['first_file.gin', 'local_file.gin', 'second_file.gin']
  self.assertListEqual(local_paths, expected_paths)
def test_single_local_path_handling(self, download_from_gstorage_function):
  """Tests that a plain local path passes through untouched.

  No download may be attempted, and the same path string must come back.
  """
  local_name = 'local_file.gin'
  result = cloud.make_file_paths_local(local_name, 'gin/search/path')
  download_from_gstorage_function.assert_not_called()
  self.assertEqual(result, local_name)
def test_single_path_handling(self, download_from_gstorage_function):
  """Tests that function returns a single value if given single value.

  A bare gs:// path (not wrapped in a list) must be downloaded exactly once,
  and the result must be the plain local file name rather than a list.
  """
  remote_path = 'gs://bucket-name/bucket/dir/some_file.gin'
  path = cloud.make_file_paths_local(remote_path, 'gin/search/path')
  # Pin the downloaded source path as well as the call count; the download
  # destination argument is left unconstrained.
  download_from_gstorage_function.assert_called_once_with(
      remote_path, mock.ANY)
  self.assertEqual(path, 'some_file.gin')