Esempio n. 1
0
def create_experiment_fn():
    """Build an experiment function whose settings all come from FLAGS.

    Every argument handed to ``trainer_lib.create_experiment_fn`` is taken
    from a command-line flag. Two flags are post-processed first:
    ``data_dir`` is user-expanded with ``os.path.expanduser``, and
    ``decode_hparams`` is parsed through ``decoding.decode_hparams``.

    Returns:
      The value returned by ``trainer_lib.create_experiment_fn``. Elsewhere in
      this file that value is called as ``exp_fn(run_config, hparams)``;
      confirm against trainer_lib.
    """
    experiment_kwargs = dict(
        # Model / data selection.
        model_name=FLAGS.model,
        problem_name=FLAGS.problem,
        data_dir=os.path.expanduser(FLAGS.data_dir),
        # Run length and scheduling.
        train_steps=FLAGS.train_steps,
        eval_steps=FLAGS.eval_steps,
        min_eval_frequency=FLAGS.local_eval_frequency,
        schedule=FLAGS.schedule,
        eval_throttle_seconds=FLAGS.eval_throttle_seconds,
        export=FLAGS.export_saved_model,
        decode_hparams=decoding.decode_hparams(FLAGS.decode_hparams),
        # Debugging aids.
        use_tfdbg=FLAGS.tfdbg,
        use_dbgprofile=FLAGS.dbgprofile,
        # Early stopping on an eval metric.
        eval_early_stopping_steps=FLAGS.eval_early_stopping_steps,
        eval_early_stopping_metric=FLAGS.eval_early_stopping_metric,
        eval_early_stopping_metric_delta=(
            FLAGS.eval_early_stopping_metric_delta),
        eval_early_stopping_metric_minimize=(
            FLAGS.eval_early_stopping_metric_minimize),
        eval_timeout_mins=FLAGS.eval_timeout_mins,
        eval_use_test_set=FLAGS.eval_use_test_set,
        # Hardware / compilation.
        use_tpu=FLAGS.use_tpu,
        use_tpu_estimator=FLAGS.use_tpu_estimator,
        use_xla=FLAGS.xla_compile,
        warm_start_from=FLAGS.warm_start_from,
        # Decoding I/O.
        decode_from_file=FLAGS.decode_from_file,
        decode_to_file=FLAGS.decode_to_file,
        decode_reference=FLAGS.decode_reference,
        std_server_protocol=FLAGS.std_server_protocol,
    )
    return trainer_lib.create_experiment_fn(**experiment_kwargs)
Esempio n. 2
0
  def train(self, use_tpu=False, schedule="train"):
    """Run a short training job and remember the hparams that were used.

    Builds an experiment for ``self.model_name`` / ``self.problem_name`` on
    ``self.data_dir``. It uses 10 train steps, 1 eval step and
    ``min_eval_frequency=9``, with a run config rooted at ``self.model_dir``
    and no GPUs. It then calls ``.train()`` on the resulting experiment.

    Args:
      use_tpu: forwarded to both the experiment factory and the run config.
      schedule: forwarded to both as well. Note that ``.train()`` is invoked
        regardless of this value.

    Returns:
      The hparams object obtained from ``registry.hparams(self.hparams_set)``.
      It is also stored on ``self.hparams`` after training succeeds.
    """
    experiment_fn = trainer_lib.create_experiment_fn(
        self.model_name,
        self.problem_name,
        self.data_dir,
        train_steps=10,
        eval_steps=1,
        min_eval_frequency=9,
        use_tpu=use_tpu,
        schedule=schedule,
    )
    config = trainer_lib.create_run_config(
        model_name=self.model_name,
        model_dir=self.model_dir,
        num_gpus=0,
        use_tpu=use_tpu,
        schedule=schedule,
    )
    hp = registry.hparams(self.hparams_set)

    experiment_fn(config, hp).train()

    # Only record the hparams once training has returned normally.
    self.hparams = hp
    return hp
Esempio n. 3
0
def train_supervised(problem, model_name, hparams, data_dir, output_dir,
                     train_steps, eval_steps, local_eval_frequency=None,
                     schedule="test"):
  """Build a supervised experiment and run one of its schedules.

  Args:
    problem: problem passed to ``trainer_lib.create_experiment_fn``.
    model_name: model name passed to ``trainer_lib.create_experiment_fn``.
    hparams: hyperparameters handed to the experiment function.
    data_dir: directory the experiment reads its data from.
    output_dir: model directory used for the run config.
    train_steps: number of training steps for the experiment.
    eval_steps: number of evaluation steps for the experiment.
    local_eval_frequency: ``min_eval_frequency`` for the experiment. When
      ``None``, the value of ``FLAGS.local_eval_frequency`` is used.
    schedule: name of the experiment method to invoke. Defaults to
      ``"test"``, the method this function always ran before the parameter
      existed, so existing callers are unaffected.
  """
  if local_eval_frequency is None:
    # Direct attribute access; getattr with a constant name adds nothing.
    local_eval_frequency = FLAGS.local_eval_frequency

  exp_fn = trainer_lib.create_experiment_fn(
      model_name, problem, data_dir, train_steps, eval_steps,
      min_eval_frequency=local_eval_frequency
  )
  run_config = trainer_lib.create_run_config(model_dir=output_dir)
  exp = exp_fn(run_config, hparams)
  # Dispatch by name so callers can pick e.g. "train" or
  # "continuous_train_and_eval" instead of the hard-coded test run.
  getattr(exp, schedule)()
Esempio n. 4
0
 def testExperimentWithClass(self):
     """Smoke-test an experiment built from a Problem instance, not a name.

     Passes an ``algorithmic.TinyAlgo()`` object as the problem, runs one
     train/eval step without TPU and with zero GPUs, then calls ``exp.test()``.
     """
     experiment_fn = trainer_lib.create_experiment_fn(
         "transformer", algorithmic.TinyAlgo(), algorithmic.TinyAlgo.data_dir,
         train_steps=1, eval_steps=1, min_eval_frequency=1, use_tpu=False)
     config = trainer_lib.create_run_config(
         model_dir=algorithmic.TinyAlgo.data_dir,
         num_gpus=0,
         use_tpu=False,
     )
     experiment = experiment_fn(config, registry.hparams("transformer_tiny_tpu"))
     experiment.test()
Esempio n. 5
0
def train_supervised(problem, model_name, hparams, data_dir, output_dir,
                     train_steps, eval_steps, local_eval_frequency=None,
                     schedule="continuous_train_and_eval"):
  """Build a supervised experiment and run the schedule named by `schedule`.

  ``local_eval_frequency`` falls back to ``FLAGS.local_eval_frequency`` when
  it is ``None``. ``schedule`` is looked up as a method on the experiment
  object and called with no arguments.
  """
  eval_frequency = (FLAGS.local_eval_frequency
                    if local_eval_frequency is None
                    else local_eval_frequency)

  experiment_fn = trainer_lib.create_experiment_fn(
      model_name,
      problem,
      data_dir,
      train_steps,
      eval_steps,
      min_eval_frequency=eval_frequency,
  )
  config = trainer_lib.create_run_config(model_name, model_dir=output_dir)
  run_schedule = getattr(experiment_fn(config, hparams), schedule)
  run_schedule()
Esempio n. 6
0
 def testExperiment(self):
   """Smoke-test a tiny transformer experiment on the "tiny_algo" problem.

   Uses ``self.data_dir`` for both data and model dir, one train/eval step,
   no TPU and zero GPUs, then runs ``exp.test()``.
   """
   experiment_fn = trainer_lib.create_experiment_fn(
       "transformer", "tiny_algo", self.data_dir,
       train_steps=1, eval_steps=1, min_eval_frequency=1, use_tpu=False)
   config = trainer_lib.create_run_config(
       model_dir=self.data_dir,
       num_gpus=0,
       use_tpu=False,
   )
   experiment = experiment_fn(config, registry.hparams("transformer_tiny_tpu"))
   experiment.test()
 def testExperiment(self):
     """Smoke-test a tiny transformer experiment on the "tiny_algo" problem.

     In this variant ``registry.hparams(...)`` returns a factory, which is
     called to obtain the hparams. One train/eval step is run without TPU and
     with zero GPUs, followed by ``exp.test()``.
     """
     experiment_fn = trainer_lib.create_experiment_fn(
         "transformer", "tiny_algo", self.data_dir,
         train_steps=1, eval_steps=1, min_eval_frequency=1, use_tpu=False)
     config = trainer_lib.create_run_config(
         model_dir=self.data_dir, num_gpus=0, use_tpu=False)
     hparams_factory = registry.hparams("transformer_tiny_tpu")
     experiment = experiment_fn(config, hparams_factory())
     experiment.test()
def train_supervised(problem, model_name, hparams, data_dir, output_dir,
                     train_steps, eval_steps, local_eval_frequency=None,
                     schedule="continuous_train_and_eval"):
  """Build a supervised experiment and run the schedule named by `schedule`.

  Args:
    problem: problem passed to ``trainer_lib.create_experiment_fn``.
    model_name: model name, passed to both the experiment factory and
      ``trainer_lib.create_run_config``.
    hparams: hyperparameters handed to the experiment function.
    data_dir: directory the experiment reads its data from.
    output_dir: model directory used for the run config.
    train_steps: number of training steps for the experiment.
    eval_steps: number of evaluation steps for the experiment.
    local_eval_frequency: ``min_eval_frequency`` for the experiment. When
      ``None``, the value of ``FLAGS.local_eval_frequency`` is used.
    schedule: name of the experiment method to call (no arguments).

  Raises:
    AttributeError: if the experiment has no method named ``schedule``.
  """
  if local_eval_frequency is None:
    # Plain attribute access; getattr with a constant name is equivalent
    # but obscures the dependency on this flag.
    local_eval_frequency = FLAGS.local_eval_frequency

  exp_fn = trainer_lib.create_experiment_fn(
      model_name, problem, data_dir, train_steps, eval_steps,
      min_eval_frequency=local_eval_frequency
  )
  run_config = trainer_lib.create_run_config(model_name, model_dir=output_dir)
  exp = exp_fn(run_config, hparams)
  getattr(exp, schedule)()
Esempio n. 9
0
 def testExperimentWithClass(self):
   """Smoke-test an experiment whose problem is a TinyAlgo instance.

   The run config is given ``model_name="transformer"`` explicitly. One
   train/eval step is run without TPU and with zero GPUs, then ``exp.test()``.
   """
   experiment_fn = trainer_lib.create_experiment_fn(
       "transformer", algorithmic.TinyAlgo(), algorithmic.TinyAlgo.data_dir,
       train_steps=1, eval_steps=1, min_eval_frequency=1, use_tpu=False)
   config = trainer_lib.create_run_config(
       model_name="transformer", model_dir=algorithmic.TinyAlgo.data_dir,
       num_gpus=0, use_tpu=False)
   experiment = experiment_fn(config, registry.hparams("transformer_tiny_tpu"))
   experiment.test()
Esempio n. 10
0
def create_experiment_fn():
  """Build an experiment function configured from command-line flags.

  The problem name comes from ``get_problem_name()``. ``data_dir`` is
  user-expanded, ``decode_hparams`` is parsed through
  ``decoding.decode_hparams``, and every other argument is read straight
  from ``FLAGS``.
  """
  experiment_kwargs = dict(
      model_name=FLAGS.model,
      problem_name=get_problem_name(),
      data_dir=os.path.expanduser(FLAGS.data_dir),
      train_steps=FLAGS.train_steps,
      eval_steps=FLAGS.eval_steps,
      min_eval_frequency=FLAGS.local_eval_frequency,
      schedule=FLAGS.schedule,
      export=FLAGS.export_saved_model,
      decode_hparams=decoding.decode_hparams(FLAGS.decode_hparams),
      use_tfdbg=FLAGS.tfdbg,
      use_dbgprofile=FLAGS.dbgprofile,
      eval_early_stopping_steps=FLAGS.eval_early_stopping_steps,
      eval_early_stopping_metric=FLAGS.eval_early_stopping_metric,
      eval_early_stopping_metric_delta=FLAGS.eval_early_stopping_metric_delta,
      eval_early_stopping_metric_minimize=(
          FLAGS.eval_early_stopping_metric_minimize),
      use_tpu=FLAGS.use_tpu,
  )
  return trainer_lib.create_experiment_fn(**experiment_kwargs)
Esempio n. 11
0
def create_experiment_fn():
  """Return ``trainer_lib.create_experiment_fn`` applied to flag values.

  Only two inputs are transformed before being forwarded:
  ``FLAGS.data_dir`` (via ``os.path.expanduser``) and
  ``FLAGS.decode_hparams`` (via ``decoding.decode_hparams``). The problem
  name is obtained from ``get_problem_name()``.
  """
  expanded_data_dir = os.path.expanduser(FLAGS.data_dir)
  parsed_decode_hparams = decoding.decode_hparams(FLAGS.decode_hparams)
  return trainer_lib.create_experiment_fn(
      model_name=FLAGS.model,
      problem_name=get_problem_name(),
      data_dir=expanded_data_dir,
      train_steps=FLAGS.train_steps,
      eval_steps=FLAGS.eval_steps,
      min_eval_frequency=FLAGS.local_eval_frequency,
      schedule=FLAGS.schedule,
      export=FLAGS.export_saved_model,
      decode_hparams=parsed_decode_hparams,
      use_tfdbg=FLAGS.tfdbg,
      use_dbgprofile=FLAGS.dbgprofile,
      eval_early_stopping_steps=FLAGS.eval_early_stopping_steps,
      eval_early_stopping_metric=FLAGS.eval_early_stopping_metric,
      eval_early_stopping_metric_delta=FLAGS.eval_early_stopping_metric_delta,
      eval_early_stopping_metric_minimize=(
          FLAGS.eval_early_stopping_metric_minimize),
      use_tpu=FLAGS.use_tpu)
Esempio n. 12
0
def create_experiment_fn(**kwargs):
    """Build an experiment function from flags, with keyword overrides.

    The flag-derived arguments below act as defaults for
    ``trainer_lib.create_experiment_fn``. ``data_dir`` is user-expanded and
    ``decode_hparams`` is parsed through ``decoding.decode_hparams`` before
    being forwarded.

    Args:
      **kwargs: additional keyword arguments for
        ``trainer_lib.create_experiment_fn``. A key that matches one of the
        flag-derived arguments (e.g. ``use_tpu``) replaces the flag value.
        Previously such a key raised ``TypeError`` for a duplicate keyword.

    Returns:
      The value returned by ``trainer_lib.create_experiment_fn``.
    """
    experiment_kwargs = dict(
        model_name=FLAGS.model,
        problem_name=FLAGS.problem,
        data_dir=os.path.expanduser(FLAGS.data_dir),
        train_steps=FLAGS.train_steps,
        eval_steps=FLAGS.eval_steps,
        min_eval_frequency=FLAGS.local_eval_frequency,
        schedule=FLAGS.schedule,
        eval_throttle_seconds=FLAGS.eval_throttle_seconds,
        export=FLAGS.export_saved_model,
        decode_hparams=decoding.decode_hparams(FLAGS.decode_hparams),
        use_tfdbg=FLAGS.tfdbg,
        use_dbgprofile=FLAGS.dbgprofile,
        eval_early_stopping_steps=FLAGS.eval_early_stopping_steps,
        eval_early_stopping_metric=FLAGS.eval_early_stopping_metric,
        eval_early_stopping_metric_delta=(
            FLAGS.eval_early_stopping_metric_delta),
        eval_early_stopping_metric_minimize=(
            FLAGS.eval_early_stopping_metric_minimize),
        use_tpu=FLAGS.use_tpu,
        use_tpu_estimator=FLAGS.use_tpu_estimator,
        use_xla=FLAGS.xla_compile,
    )
    # Caller-supplied values win. Writing ``f(a=1, **{"a": 2})`` directly
    # would raise "got multiple values for keyword argument 'a'".
    experiment_kwargs.update(kwargs)
    return trainer_lib.create_experiment_fn(**experiment_kwargs)