def create_experiment_fn():
    """Create an experiment function configured from the module-level FLAGS.

    Each keyword handed to `trainer_lib.create_experiment_fn` is read from a
    FLAGS attribute. Two values are transformed on the way: `data_dir` goes
    through `os.path.expanduser`, and `decode_hparams` is passed through
    `decoding.decode_hparams`.

    Returns:
      The value returned by `trainer_lib.create_experiment_fn`.
    """
    build = trainer_lib.create_experiment_fn
    # Collected in the same order the keywords are forwarded.
    options = dict(
        model_name=FLAGS.model,
        problem_name=FLAGS.problem,
        data_dir=os.path.expanduser(FLAGS.data_dir),
        train_steps=FLAGS.train_steps,
        eval_steps=FLAGS.eval_steps,
        min_eval_frequency=FLAGS.local_eval_frequency,
        schedule=FLAGS.schedule,
        eval_throttle_seconds=FLAGS.eval_throttle_seconds,
        export=FLAGS.export_saved_model,
        decode_hparams=decoding.decode_hparams(FLAGS.decode_hparams),
        use_tfdbg=FLAGS.tfdbg,
        use_dbgprofile=FLAGS.dbgprofile,
        eval_early_stopping_steps=FLAGS.eval_early_stopping_steps,
        eval_early_stopping_metric=FLAGS.eval_early_stopping_metric,
        eval_early_stopping_metric_delta=(
            FLAGS.eval_early_stopping_metric_delta),
        eval_early_stopping_metric_minimize=(
            FLAGS.eval_early_stopping_metric_minimize),
        eval_timeout_mins=FLAGS.eval_timeout_mins,
        eval_use_test_set=FLAGS.eval_use_test_set,
        use_tpu=FLAGS.use_tpu,
        use_tpu_estimator=FLAGS.use_tpu_estimator,
        use_xla=FLAGS.xla_compile,
        warm_start_from=FLAGS.warm_start_from,
        decode_from_file=FLAGS.decode_from_file,
        decode_to_file=FLAGS.decode_to_file,
        decode_reference=FLAGS.decode_reference,
        std_server_protocol=FLAGS.std_server_protocol,
    )
    return build(**options)
def train(self, use_tpu=False, schedule="train"):
    """Run training.

    Builds an experiment for this object's model, problem and data directory
    with fixed step settings (train_steps=10, eval_steps=1,
    min_eval_frequency=9), calls its `train()` method, and records the
    hparams that were used on `self.hparams`.

    Args:
      use_tpu: forwarded to both `create_experiment_fn` and
        `create_run_config`.
      schedule: forwarded to both `create_experiment_fn` and
        `create_run_config`.

    Returns:
      The hparams obtained from `registry.hparams(self.hparams_set)`.
    """
    experiment_builder = trainer_lib.create_experiment_fn(
        self.model_name,
        self.problem_name,
        self.data_dir,
        train_steps=10,
        eval_steps=1,
        min_eval_frequency=9,
        use_tpu=use_tpu,
        schedule=schedule,
    )
    config = trainer_lib.create_run_config(
        model_name=self.model_name,
        model_dir=self.model_dir,
        num_gpus=0,
        use_tpu=use_tpu,
        schedule=schedule,
    )
    model_hparams = registry.hparams(self.hparams_set)
    experiment = experiment_builder(config, model_hparams)
    experiment.train()
    # Keep the hparams on the instance only after training has been invoked.
    self.hparams = model_hparams
    return model_hparams
def train_supervised(problem, model_name, hparams, data_dir, output_dir,
                     train_steps, eval_steps, local_eval_frequency=None):
    """Train supervised.

    Builds an experiment through `trainer_lib.create_experiment_fn` and a run
    config rooted at `output_dir`, then invokes the experiment's `test()`
    method.

    Args:
      problem: forwarded as the second positional argument of
        `create_experiment_fn`.
      model_name: forwarded as the first positional argument of
        `create_experiment_fn`.
      hparams: passed to the experiment function together with the run config.
      data_dir: forwarded positionally to `create_experiment_fn`.
      output_dir: used as `model_dir` of the run config.
      train_steps: forwarded positionally to `create_experiment_fn`.
      eval_steps: forwarded positionally to `create_experiment_fn`.
      local_eval_frequency: used as `min_eval_frequency`; when None, the value
        of `FLAGS.local_eval_frequency` is used instead.
    """
    if local_eval_frequency is None:
        # Plain attribute access: the name is a constant, so `getattr` added
        # nothing over `FLAGS.local_eval_frequency`.
        local_eval_frequency = FLAGS.local_eval_frequency

    exp_fn = trainer_lib.create_experiment_fn(
        model_name,
        problem,
        data_dir,
        train_steps,
        eval_steps,
        min_eval_frequency=local_eval_frequency,
    )
    run_config = trainer_lib.create_run_config(model_dir=output_dir)
    exp = exp_fn(run_config, hparams)
    exp.test()
def testExperimentWithClass(self):
    """Build an experiment from a `TinyAlgo` instance and run `test()`.

    A `TinyAlgo` object (not a string) is passed in the problem position, and
    `TinyAlgo.data_dir` serves as both the data directory and the model
    directory.
    """
    problem_instance = algorithmic.TinyAlgo()
    experiment_builder = trainer_lib.create_experiment_fn(
        "transformer",
        problem_instance,
        algorithmic.TinyAlgo.data_dir,
        train_steps=1,
        eval_steps=1,
        min_eval_frequency=1,
        use_tpu=False,
    )
    config = trainer_lib.create_run_config(
        model_dir=algorithmic.TinyAlgo.data_dir,
        num_gpus=0,
        use_tpu=False,
    )
    tiny_hparams = registry.hparams("transformer_tiny_tpu")
    experiment_builder(config, tiny_hparams).test()
def train_supervised(problem, model_name, hparams, data_dir, output_dir,
                     train_steps, eval_steps, local_eval_frequency=None,
                     schedule="continuous_train_and_eval"):
    """Train supervised.

    Creates an experiment and invokes the method on it whose name is given by
    `schedule`.

    Args:
      problem: forwarded as the second positional argument of
        `create_experiment_fn`.
      model_name: forwarded to both `create_experiment_fn` and
        `create_run_config`.
      hparams: passed to the experiment function together with the run config.
      data_dir: forwarded positionally to `create_experiment_fn`.
      output_dir: used as `model_dir` of the run config.
      train_steps: forwarded positionally to `create_experiment_fn`.
      eval_steps: forwarded positionally to `create_experiment_fn`.
      local_eval_frequency: used as `min_eval_frequency`; falls back to
        `FLAGS.local_eval_frequency` when None.
      schedule: name of the experiment method to call with no arguments.
    """
    eval_frequency = local_eval_frequency
    if eval_frequency is None:
        eval_frequency = FLAGS.local_eval_frequency

    experiment_builder = trainer_lib.create_experiment_fn(
        model_name,
        problem,
        data_dir,
        train_steps,
        eval_steps,
        min_eval_frequency=eval_frequency,
    )
    config = trainer_lib.create_run_config(model_name, model_dir=output_dir)
    experiment = experiment_builder(config, hparams)

    # Look the schedule up by name on the experiment, then run it.
    run_schedule = getattr(experiment, schedule)
    run_schedule()
def testExperiment(self):
    """Build a "transformer" / "tiny_algo" experiment and run `test()`.

    `self.data_dir` is used as both the data directory and the model
    directory; every step setting is 1 and TPU use is disabled.
    """
    experiment_builder = trainer_lib.create_experiment_fn(
        "transformer",
        "tiny_algo",
        self.data_dir,
        train_steps=1,
        eval_steps=1,
        min_eval_frequency=1,
        use_tpu=False,
    )
    config = trainer_lib.create_run_config(
        model_dir=self.data_dir,
        num_gpus=0,
        use_tpu=False,
    )
    tiny_hparams = registry.hparams("transformer_tiny_tpu")
    experiment_builder(config, tiny_hparams).test()
def testExperiment(self):
    """Build a "transformer" / "tiny_algo" experiment and run `test()`.

    `self.data_dir` is used as both the data directory and the model
    directory. Here `registry.hparams(...)` returns a callable, which is
    invoked with no arguments to obtain the hparams object.
    """
    experiment_builder = trainer_lib.create_experiment_fn(
        "transformer",
        "tiny_algo",
        self.data_dir,
        train_steps=1,
        eval_steps=1,
        min_eval_frequency=1,
        use_tpu=False,
    )
    config = trainer_lib.create_run_config(
        model_dir=self.data_dir,
        num_gpus=0,
        use_tpu=False,
    )
    hparams_factory = registry.hparams("transformer_tiny_tpu")
    tiny_hparams = hparams_factory()
    experiment_builder(config, tiny_hparams).test()
def train_supervised(problem, model_name, hparams, data_dir, output_dir,
                     train_steps, eval_steps, local_eval_frequency=None,
                     schedule="continuous_train_and_eval"):
    """Train supervised.

    Builds an experiment through `trainer_lib.create_experiment_fn` and a run
    config rooted at `output_dir`, then calls the experiment method named by
    `schedule`.

    Args:
      problem: forwarded as the second positional argument of
        `create_experiment_fn`.
      model_name: forwarded to both `create_experiment_fn` and
        `create_run_config`.
      hparams: passed to the experiment function together with the run config.
      data_dir: forwarded positionally to `create_experiment_fn`.
      output_dir: used as `model_dir` of the run config.
      train_steps: forwarded positionally to `create_experiment_fn`.
      eval_steps: forwarded positionally to `create_experiment_fn`.
      local_eval_frequency: used as `min_eval_frequency`; when None, the value
        of `FLAGS.local_eval_frequency` is used instead.
      schedule: name of the experiment method to invoke with no arguments.
    """
    if local_eval_frequency is None:
        # The attribute name is a constant, so plain attribute access replaces
        # the needless `getattr(FLAGS, "local_eval_frequency")`.
        local_eval_frequency = FLAGS.local_eval_frequency

    exp_fn = trainer_lib.create_experiment_fn(
        model_name,
        problem,
        data_dir,
        train_steps,
        eval_steps,
        min_eval_frequency=local_eval_frequency,
    )
    run_config = trainer_lib.create_run_config(model_name, model_dir=output_dir)
    exp = exp_fn(run_config, hparams)
    # `schedule` is only known at runtime, so this lookup stays dynamic.
    getattr(exp, schedule)()
def testExperimentWithClass(self):
    """Build an experiment from a `TinyAlgo` instance and run `test()`.

    A `TinyAlgo` object (not a string) is passed in the problem position, and
    `TinyAlgo.data_dir` serves as both the data directory and the model
    directory. The run config is created with model_name="transformer".
    """
    problem_instance = algorithmic.TinyAlgo()
    experiment_builder = trainer_lib.create_experiment_fn(
        "transformer",
        problem_instance,
        algorithmic.TinyAlgo.data_dir,
        train_steps=1,
        eval_steps=1,
        min_eval_frequency=1,
        use_tpu=False,
    )
    config = trainer_lib.create_run_config(
        model_name="transformer",
        model_dir=algorithmic.TinyAlgo.data_dir,
        num_gpus=0,
        use_tpu=False,
    )
    tiny_hparams = registry.hparams("transformer_tiny_tpu")
    experiment_builder(config, tiny_hparams).test()
def create_experiment_fn():
    """Create an experiment function configured from the module-level FLAGS.

    The problem name comes from `get_problem_name()`; every other keyword is
    read from a FLAGS attribute. `data_dir` goes through `os.path.expanduser`
    and `decode_hparams` is passed through `decoding.decode_hparams`.

    Returns:
      The value returned by `trainer_lib.create_experiment_fn`.
    """
    build = trainer_lib.create_experiment_fn
    # Collected in the same order the keywords are forwarded.
    options = dict(
        model_name=FLAGS.model,
        problem_name=get_problem_name(),
        data_dir=os.path.expanduser(FLAGS.data_dir),
        train_steps=FLAGS.train_steps,
        eval_steps=FLAGS.eval_steps,
        min_eval_frequency=FLAGS.local_eval_frequency,
        schedule=FLAGS.schedule,
        export=FLAGS.export_saved_model,
        decode_hparams=decoding.decode_hparams(FLAGS.decode_hparams),
        use_tfdbg=FLAGS.tfdbg,
        use_dbgprofile=FLAGS.dbgprofile,
        eval_early_stopping_steps=FLAGS.eval_early_stopping_steps,
        eval_early_stopping_metric=FLAGS.eval_early_stopping_metric,
        eval_early_stopping_metric_delta=(
            FLAGS.eval_early_stopping_metric_delta),
        eval_early_stopping_metric_minimize=(
            FLAGS.eval_early_stopping_metric_minimize),
        use_tpu=FLAGS.use_tpu,
    )
    return build(**options)
def create_experiment_fn():
    """Return an experiment function built from FLAGS and `get_problem_name()`.

    All keywords except `problem_name` are taken from FLAGS attributes. The
    data directory has `os.path.expanduser` applied, and the decode hparams
    flag is converted by `decoding.decode_hparams` before being forwarded.

    Returns:
      Whatever `trainer_lib.create_experiment_fn` returns for these settings.
    """
    factory = trainer_lib.create_experiment_fn
    # Keyword order mirrors the order in which the values are forwarded.
    settings = dict(
        model_name=FLAGS.model,
        problem_name=get_problem_name(),
        data_dir=os.path.expanduser(FLAGS.data_dir),
        train_steps=FLAGS.train_steps,
        eval_steps=FLAGS.eval_steps,
        min_eval_frequency=FLAGS.local_eval_frequency,
        schedule=FLAGS.schedule,
        export=FLAGS.export_saved_model,
        decode_hparams=decoding.decode_hparams(FLAGS.decode_hparams),
        use_tfdbg=FLAGS.tfdbg,
        use_dbgprofile=FLAGS.dbgprofile,
        eval_early_stopping_steps=FLAGS.eval_early_stopping_steps,
        eval_early_stopping_metric=FLAGS.eval_early_stopping_metric,
        eval_early_stopping_metric_delta=(
            FLAGS.eval_early_stopping_metric_delta),
        eval_early_stopping_metric_minimize=(
            FLAGS.eval_early_stopping_metric_minimize),
        use_tpu=FLAGS.use_tpu,
    )
    return factory(**settings)
def create_experiment_fn(**kwargs):
    """Create an experiment function from FLAGS plus caller-supplied keywords.

    The explicitly listed keywords are read from FLAGS attributes (`data_dir`
    via `os.path.expanduser`, `decode_hparams` via `decoding.decode_hparams`).
    Anything in `kwargs` is forwarded in addition to them.

    Args:
      **kwargs: extra keyword arguments forwarded to
        `trainer_lib.create_experiment_fn`. A key that repeats one of the
        explicitly listed keywords makes the call raise TypeError.

    Returns:
      The value returned by `trainer_lib.create_experiment_fn`.
    """
    build = trainer_lib.create_experiment_fn
    flags = FLAGS
    # The keywords stay explicit in the call so that a duplicate in `kwargs`
    # is still rejected rather than silently overriding a flag value.
    experiment_fn = build(
        model_name=flags.model,
        problem_name=flags.problem,
        data_dir=os.path.expanduser(flags.data_dir),
        train_steps=flags.train_steps,
        eval_steps=flags.eval_steps,
        min_eval_frequency=flags.local_eval_frequency,
        schedule=flags.schedule,
        eval_throttle_seconds=flags.eval_throttle_seconds,
        export=flags.export_saved_model,
        decode_hparams=decoding.decode_hparams(flags.decode_hparams),
        use_tfdbg=flags.tfdbg,
        use_dbgprofile=flags.dbgprofile,
        eval_early_stopping_steps=flags.eval_early_stopping_steps,
        eval_early_stopping_metric=flags.eval_early_stopping_metric,
        eval_early_stopping_metric_delta=(
            flags.eval_early_stopping_metric_delta),
        eval_early_stopping_metric_minimize=(
            flags.eval_early_stopping_metric_minimize),
        use_tpu=flags.use_tpu,
        use_tpu_estimator=flags.use_tpu_estimator,
        use_xla=flags.xla_compile,
        **kwargs
    )
    return experiment_fn