Ejemplo n.º 1
0
def run(callbacks=None):
  """Assembles model params from flags and launches training via `run_executor`.

  The params generated for `FLAGS.model` are overridden, in order, by
  `FLAGS.config_file` and `FLAGS.params_override` (both with
  `is_strict=True`); then `strategy_type` and `model_dir` are copied from the
  flags (with `is_strict=False`). The result is validated, locked and logged.

  Each file pattern is taken from its flag when that flag is truthy, otherwise
  from `params.train.train_file_pattern` / `params.eval.eval_file_pattern`.

  Args:
    callbacks: optional value forwarded unchanged to `run_executor`.

  Returns:
    The value returned by `run_executor`.

  Raises:
    ValueError: if neither a training nor an eval file pattern is available.
  """
  params = config_factory.config_generator(FLAGS.model)
  for override_source in (FLAGS.config_file, FLAGS.params_override):
    params = params_dict.override_params_dict(
        params, override_source, is_strict=True)
  params.override(
      {'strategy_type': FLAGS.strategy_type, 'model_dir': FLAGS.model_dir},
      is_strict=False)
  params.validate()
  params.lock()

  printer = pprint.PrettyPrinter()
  logging.info('Model Parameters: %s', printer.pformat(params.as_dict()))

  train_pattern = (
      FLAGS.training_file_pattern or params.train.train_file_pattern)
  eval_pattern = FLAGS.eval_file_pattern or params.eval.eval_file_pattern
  if not (train_pattern or eval_pattern):
    raise ValueError('Must provide at least one of training_file_pattern and '
                     'eval_file_pattern.')

  # Single host: the global batch size is used as-is.
  train_input_fn = (
      input_reader.InputFn(
          file_pattern=train_pattern,
          params=params,
          mode=input_reader.ModeKeys.TRAIN,
          batch_size=params.train.batch_size)
      if train_pattern else None)

  # NOTE(review): eval_input_fn is built here but never passed to
  # run_executor below, and the mode is fixed to TRAIN. As a result, an
  # eval-only configuration (no training pattern) reaches run_executor with
  # train_input_fn=None. Confirm whether eval_input_fn should be forwarded.
  eval_input_fn = (
      input_reader.InputFn(
          file_pattern=eval_pattern,
          params=params,
          mode=input_reader.ModeKeys.PREDICT_WITH_GT,
          batch_size=params.eval.batch_size,
          num_examples=params.eval.eval_samples)
      if eval_pattern else None)

  return run_executor(
      params,
      mode=ModeKeys.TRAIN,
      train_input_fn=train_input_fn,
      callbacks=callbacks)
Ejemplo n.º 2
0
def _get_params_from_flags(flags_obj: flags.FlagValues):
    """Build the final, locked ParamsDict described by the command-line flags.

    The base config is looked up using the lower-cased ``model_type`` and
    ``dataset`` flags. Three override layers are then applied in order, each
    with ``is_strict=False``:

    1. the ``config_file`` flag;
    2. the ``params_override`` flag;
    3. a dict of values copied directly from individual flags.

    In the third layer, ``data_dir`` feeds both the training and the
    validation dataset. Each layer is logged before it is applied.

    Args:
      flags_obj: parsed flag values to read from.

    Returns:
      The resulting ParamsDict, after ``validate()`` and ``lock()``.
    """
    model_name = flags_obj.model_type.lower()
    dataset_name = flags_obj.dataset.lower()
    params = configs.get_config(model=model_name, dataset=dataset_name)

    from_flags = dict(
        model_dir=flags_obj.model_dir,
        mode=flags_obj.mode,
        model=dict(name=model_name),
        runtime=dict(run_eagerly=flags_obj.run_eagerly, tpu=flags_obj.tpu),
        train_dataset=dict(data_dir=flags_obj.data_dir),
        validation_dataset=dict(data_dir=flags_obj.data_dir),
        train=dict(time_history=dict(log_steps=flags_obj.log_steps)),
    )
    layers = (flags_obj.config_file, flags_obj.params_override, from_flags)

    printer = pprint.PrettyPrinter()
    logging.info('Base params: %s', printer.pformat(params.as_dict()))

    for layer in layers:
        logging.info('Overriding params: %s', layer)
        # Non-strict so that the overrides may carry dynamic dict parameters.
        params = params_dict.override_params_dict(
            params, layer, is_strict=False)

    params.validate()
    params.lock()
    logging.info('Final model parameters: %s',
                 printer.pformat(params.as_dict()))
    return params
Ejemplo n.º 3
0
 def test_override_params_dict_using_yaml_string(self):
     """An inline YAML mapping string overrides only the keys it names."""
     base = {'a': 1, 'b': 2.5, 'c': [3, 4], 'd': 'hello', 'e': False}
     original = params_dict.ParamsDict(base)
     yaml_override = "'b': 5.2\n'c': [30, 40]"
     result = params_dict.override_params_dict(
         original, yaml_override, is_strict=True)
     self.assertEqual(1, result.a)
     self.assertEqual(5.2, result.b)
     self.assertEqual([30, 40], result.c)
     self.assertEqual('hello', result.d)
     self.assertEqual(False, result.e)
Ejemplo n.º 4
0
 def test_override_params_dict_using_yaml_file(self):
     """Overrides loaded from a temporary YAML file replace only listed keys."""
     base = {'a': 1, 'b': 2.5, 'c': [3, 4], 'd': 'hello', 'e': False}
     original = params_dict.ParamsDict(base)
     yaml_text = r"""
     b: 5.2
     c: [30, 40]
     """
     # Presumably write_temp_file returns the path of the file it wrote.
     yaml_file = self.write_temp_file('params.yaml', yaml_text)
     result = params_dict.override_params_dict(
         original, yaml_file, is_strict=True)
     self.assertEqual(1, result.a)
     self.assertEqual(5.2, result.b)
     self.assertEqual([30, 40], result.c)
     self.assertEqual('hello', result.d)
     self.assertEqual(False, result.e)
Ejemplo n.º 5
0
 def test_override_params_dict_using_csv_string(self):
     """Comma-separated ``key=value`` overrides reach nested keys.

     A quoted value keeps its embedded comma. The unquoted ``gs://test``
     replaces ``e`` even though ``e`` started out as a bool.
     """
     nested_b = {'b1': 2, 'b2': [2, 3]}
     nested_d = {'d1': {'d2': 'hello'}}
     original = params_dict.ParamsDict(
         {'a': 1, 'b': nested_b, 'd': nested_d, 'e': False})
     csv_override = "b.b2=[3,4], d.d1.d2='hi, world', e=gs://test"
     result = params_dict.override_params_dict(
         original, csv_override, is_strict=True)
     self.assertEqual(1, result.a)
     self.assertEqual(2, result.b.b1)
     self.assertEqual([3, 4], result.b.b2)
     self.assertEqual('hi, world', result.d.d1.d2)
     self.assertEqual('gs://test', result.e)
Ejemplo n.º 6
0
 def test_override_params_dict_using_json_string(self):
     """A JSON-like string with unquoted keys overrides nested values.

     Keys absent from the override (``a``, ``b.b1``, ``e``) keep their
     original values.
     """
     nested_b = {'b1': 2, 'b2': [2, 3]}
     nested_d = {'d1': {'d2': 'hello'}}
     original = params_dict.ParamsDict(
         {'a': 1, 'b': nested_b, 'd': nested_d, 'e': False})
     json_override = "{ b: { b2: [3, 4] }, d: { d1: { d2: 'hi' } } }"
     result = params_dict.override_params_dict(
         original, json_override, is_strict=True)
     self.assertEqual(1, result.a)
     self.assertEqual(2, result.b.b1)
     self.assertEqual([3, 4], result.b.b2)
     self.assertEqual('hi', result.d.d1.d2)
     self.assertEqual(False, result.e)