示例#1
0
def main(_):
  model_class = models.get_model_class(FLAGS.model)

  # Look up the model configuration.
  assert (FLAGS.config_name is None) != (FLAGS.config_json is None), (
      "Exactly one of --config_name or --config_json is required.")
  config = (
      models.get_model_config(FLAGS.model, FLAGS.config_name)
      if FLAGS.config_name else config_util.parse_json(FLAGS.config_json))

  config = configdict.ConfigDict(config)

  # Create the estimator.
  estimator = estimator_util.create_estimator(
      model_class, config.hparams, model_dir=FLAGS.model_dir)

  # Create an input function that reads the evaluation dataset.
  input_fn = estimator_util.create_input_fn(
      file_pattern=FLAGS.eval_files,
      input_config=config.inputs,
      mode=tf.estimator.ModeKeys.EVAL)

  # Run evaluation. This will log the result to stderr and also write a summary
  # file in the model_dir.
  estimator_util.evaluate(estimator, input_fn, eval_name=FLAGS.eval_name)
示例#2
0
def main(_):
  model_class = models.get_model_class(FLAGS.model)

  # Look up the model configuration.
  assert (FLAGS.config_name is None) != (FLAGS.config_json is None), (
      "Exactly one of --config_name or --config_json is required.")
  config = (
      models.get_model_config(FLAGS.model, FLAGS.config_name)
      if FLAGS.config_name else config_util.parse_json(FLAGS.config_json))

  config = configdict.ConfigDict(config)

  # Create the estimator.
  estimator = estimator_util.create_estimator(
      model_class, config.hparams, model_dir=FLAGS.model_dir)
  
  # Print no. of trainable parameters to console.
  var_names = [v for v in estimator.get_variable_names()]
  n_params = np.sum([len(estimator.get_variable_value(v).flatten()) for v in var_names])
  print("Trainable parameters in model:", int(n_params))
  
  # Create an input function that reads the evaluation dataset.
  input_fn = estimator_util.create_input_fn(
      file_pattern=FLAGS.eval_files,
      input_config=config.inputs,
      mode=tf.estimator.ModeKeys.EVAL)

  # Run evaluation. This will log the result to stderr and also write a summary
  # file in the model_dir.
  estimator_util.evaluate(estimator, input_fn, eval_name=FLAGS.eval_name)
def main(_):
  model_class = models.get_model_class(FLAGS.model)

  # Look up the model configuration.
  assert (FLAGS.config_name is None) != (FLAGS.config_json is None), (
      "Exactly one of --config_name or --config_json is required.")
  config = (
      models.get_model_config(FLAGS.model, FLAGS.config_name)
      if FLAGS.config_name else config_util.parse_json(FLAGS.config_json))

  config = configdict.ConfigDict(config)
  config_util.log_and_save_config(config, FLAGS.model_dir)

  # Create the estimator.
  run_config = tf.estimator.RunConfig(keep_checkpoint_max=1)
  estimator = estimator_util.create_estimator(model_class, config.hparams,
                                              run_config, FLAGS.model_dir)

  # Create an input function that reads the training dataset. We iterate through
  # the dataset once at a time if we are alternating with evaluation, otherwise
  # we iterate infinitely.
  train_input_fn = estimator_util.create_input_fn(
      file_pattern=FLAGS.train_files,
      input_config=config.inputs,
      mode=tf.estimator.ModeKeys.TRAIN,
      shuffle_values_buffer=FLAGS.shuffle_buffer_size,
      repeat=1 if FLAGS.eval_files else None)

  if not FLAGS.eval_files:
    estimator.train(train_input_fn, max_steps=FLAGS.train_steps)
  else:
    eval_input_fn = estimator_util.create_input_fn(
        file_pattern=FLAGS.eval_files,
        input_config=config.inputs,
        mode=tf.estimator.ModeKeys.EVAL)
    eval_args = {
        "val": (eval_input_fn, None)  # eval_name: (input_fn, eval_steps)
    }

    for _ in estimator_runner.continuous_train_and_eval(
        estimator=estimator,
        train_input_fn=train_input_fn,
        eval_args=eval_args,
        train_steps=FLAGS.train_steps):
      # continuous_train_and_eval() yields evaluation metrics after each
      # training epoch. We don't do anything here.
      pass