def run(callbacks=None):
  """Build params and input pipelines from FLAGS, then run the executor.

  Params start from `config_factory.config_generator(FLAGS.model)` and are
  overridden, in order, by `FLAGS.config_file` and `FLAGS.params_override`
  (both with `is_strict=True`), then by the `strategy_type` / `model_dir`
  flags (with `is_strict=False`). The result is validated and locked.

  Args:
    callbacks: Optional callbacks, forwarded unchanged to `run_executor`.

  Returns:
    Whatever `run_executor` returns.

  Raises:
    ValueError: If neither a training nor an eval file pattern is available
      from the flags or from the params.
  """
  # NOTE: XLA session configuration
  # (keras_utils.set_session_config(enable_xla=FLAGS.enable_xla)) is
  # currently not applied in this entry point.
  params = config_factory.config_generator(FLAGS.model)
  params = params_dict.override_params_dict(
      params, FLAGS.config_file, is_strict=True)
  params = params_dict.override_params_dict(
      params, FLAGS.params_override, is_strict=True)
  # Non-strict, presumably so these keys are accepted even when the base
  # config does not declare them — confirm against ParamsDict.override.
  params.override(
      {
          'strategy_type': FLAGS.strategy_type,
          'model_dir': FLAGS.model_dir,
      },
      is_strict=False)
  params.validate()
  params.lock()

  pp = pprint.PrettyPrinter()
  params_str = pp.pformat(params.as_dict())
  # Pass the value as a lazy %-style argument rather than pre-formatting.
  logging.info('Model Parameters: %s', params_str)

  train_input_fn = None
  eval_input_fn = None
  # Flags take precedence; otherwise fall back to the config values.
  training_file_pattern = (
      FLAGS.training_file_pattern or params.train.train_file_pattern)
  eval_file_pattern = FLAGS.eval_file_pattern or params.eval.eval_file_pattern
  if not training_file_pattern and not eval_file_pattern:
    raise ValueError('Must provide at least one of training_file_pattern and '
                     'eval_file_pattern.')

  if training_file_pattern:
    # Use global batch size for single host.
    train_input_fn = input_reader.InputFn(
        file_pattern=training_file_pattern,
        params=params,
        mode=input_reader.ModeKeys.TRAIN,
        batch_size=params.train.batch_size)

  if eval_file_pattern:
    eval_input_fn = input_reader.InputFn(
        file_pattern=eval_file_pattern,
        params=params,
        mode=input_reader.ModeKeys.PREDICT_WITH_GT,
        batch_size=params.eval.batch_size,
        num_examples=params.eval.eval_samples)

  # NOTE(review): eval_input_fn is built but not passed to run_executor, and
  # the mode is fixed to TRAIN. An eval-only invocation therefore reaches
  # run_executor with train_input_fn=None. Confirm this is intended.
  return run_executor(
      params,
      mode=ModeKeys.TRAIN,
      train_input_fn=train_input_fn,
      callbacks=callbacks)
def _get_params_from_flags(flags_obj: flags.FlagValues):
  """Assemble the validated, locked ParamsDict for a run from parsed flags.

  Starts from `configs.get_config` for the lower-cased model type and
  dataset. It then layers three overrides on top, in order:
  `flags_obj.config_file`, `flags_obj.params_override`, and a dict built
  from individual flag values.

  Args:
    flags_obj: Parsed flag values.

  Returns:
    The final ParamsDict, after `validate()` and `lock()`.
  """
  model_name = flags_obj.model_type.lower()
  dataset_name = flags_obj.dataset.lower()
  base_params = configs.get_config(model=model_name, dataset=dataset_name)

  data_dir = flags_obj.data_dir
  flag_values = dict(
      model_dir=flags_obj.model_dir,
      mode=flags_obj.mode,
      model=dict(name=model_name),
      runtime=dict(run_eagerly=flags_obj.run_eagerly, tpu=flags_obj.tpu),
      train_dataset=dict(data_dir=data_dir),
      validation_dataset=dict(data_dir=data_dir),
      train=dict(time_history=dict(log_steps=flags_obj.log_steps)),
  )
  override_sources = (
      flags_obj.config_file,
      flags_obj.params_override,
      flag_values,
  )

  printer = pprint.PrettyPrinter()
  logging.info('Base params: %s', printer.pformat(base_params.as_dict()))

  merged = base_params
  for source in override_sources:
    logging.info('Overriding params: %s', source)
    # Non-strict: dynamic dict parameters may introduce keys not in the base.
    merged = params_dict.override_params_dict(merged, source, is_strict=False)

  merged.validate()
  merged.lock()
  logging.info('Final model parameters: %s', printer.pformat(merged.as_dict()))
  return merged
def test_override_params_dict_using_yaml_string(self):
  """An inline YAML string overrides only the keys it names."""
  base = params_dict.ParamsDict({
      'a': 1,
      'b': 2.5,
      'c': [3, 4],
      'd': 'hello',
      'e': False,
  })
  yaml_text = "'b': 5.2\n'c': [30, 40]"

  result = params_dict.override_params_dict(base, yaml_text, is_strict=True)

  self.assertEqual(1, result.a)  # untouched
  self.assertEqual(5.2, result.b)  # overridden
  self.assertEqual([30, 40], result.c)  # overridden
  self.assertEqual('hello', result.d)  # untouched
  self.assertEqual(False, result.e)  # untouched
def test_override_params_dict_using_yaml_file(self):
  """A path to a YAML file overrides only the keys the file names."""
  base = params_dict.ParamsDict({
      'a': 1,
      'b': 2.5,
      'c': [3, 4],
      'd': 'hello',
      'e': False,
  })
  yaml_path = self.write_temp_file(
      'params.yaml', r"""
      b: 5.2
      c: [30, 40]
      """)

  result = params_dict.override_params_dict(base, yaml_path, is_strict=True)

  self.assertEqual(1, result.a)  # untouched
  self.assertEqual(5.2, result.b)  # overridden from file
  self.assertEqual([30, 40], result.c)  # overridden from file
  self.assertEqual('hello', result.d)  # untouched
  self.assertEqual(False, result.e)  # untouched
def test_override_params_dict_using_csv_string(self):
  """Comma-separated `dotted.key=value` overrides reach nested fields."""
  base = params_dict.ParamsDict({
      'a': 1,
      'b': {'b1': 2, 'b2': [2, 3]},
      'd': {'d1': {'d2': 'hello'}},
      'e': False,
  })
  csv_text = "b.b2=[3,4], d.d1.d2='hi, world', e=gs://test"

  result = params_dict.override_params_dict(base, csv_text, is_strict=True)

  self.assertEqual(1, result.a)
  self.assertEqual(2, result.b.b1)
  self.assertEqual([3, 4], result.b.b2)
  # The quoted value keeps its embedded comma.
  self.assertEqual('hi, world', result.d.d1.d2)
  # The unquoted URL-like value arrives as a string, replacing the bool.
  self.assertEqual('gs://test', result.e)
def test_override_params_dict_using_json_string(self):
  """A brace-delimited (unquoted-key) mapping string overrides nested keys."""
  base = params_dict.ParamsDict({
      'a': 1,
      'b': {'b1': 2, 'b2': [2, 3]},
      'd': {'d1': {'d2': 'hello'}},
      'e': False,
  })
  mapping_text = "{ b: { b2: [3, 4] }, d: { d1: { d2: 'hi' } } }"

  result = params_dict.override_params_dict(
      base, mapping_text, is_strict=True)

  self.assertEqual(1, result.a)  # untouched
  self.assertEqual(2, result.b.b1)  # sibling of an overridden key, untouched
  self.assertEqual([3, 4], result.b.b2)  # overridden
  self.assertEqual('hi', result.d.d1.d2)  # overridden (deeply nested)
  self.assertEqual(False, result.e)  # untouched