import os

import tensorflow as tf

from tensor2tensor.utils import t2t_model
from tensor2tensor.utils import trainer_lib

flags = tf.flags
FLAGS = flags.FLAGS


def create_experiment(run_config,
                      hparams,
                      model_name,
                      problem_name,
                      data_dir,
                      train_steps,
                      eval_steps,
                      min_eval_frequency=2000,
                      eval_throttle_seconds=600,
                      schedule="train_and_evaluate",
                      export=False,
                      decode_hparams=None,
                      use_tfdbg=False,
                      use_dbgprofile=False,
                      eval_early_stopping_steps=None,
                      eval_early_stopping_metric=None,
                      eval_early_stopping_metric_delta=None,
                      eval_early_stopping_metric_minimize=True,
                      eval_timeout_mins=240,
                      use_tpu=False,
                      use_tpu_estimator=False,
                      use_xla=False,
                      additional_train_hooks=None,
                      additional_eval_hooks=None,
                      warm_start_from=None,
                      decode_from_file=None,
                      decode_to_file=None,
                      decode_reference=None,
                      std_server_protocol=None):
  """Create Experiment."""
  # HParams
  hparams.add_hparam("model_dir", run_config.model_dir)
  hparams.add_hparam("data_dir", data_dir)
  hparams.add_hparam("train_steps", train_steps)
  hparams.add_hparam("eval_steps", eval_steps)
  hparams.add_hparam("schedule", schedule)
  hparams.add_hparam("warm_start_from", warm_start_from)
  hparams.add_hparam("std_server_protocol", std_server_protocol)
  hparams.add_hparam("eval_freq_in_steps", min_eval_frequency)
  hparams.add_hparam("eval_timeout_mins", eval_timeout_mins)
  if decode_hparams is not None:
    decode_hparams.add_hparam("decode_from_file", decode_from_file)
    decode_hparams.add_hparam("decode_to_file", decode_to_file)
    decode_hparams.add_hparam("decode_reference", decode_reference)
  trainer_lib.add_problem_hparams(hparams, problem_name)

  # Estimator
  estimator = trainer_lib.create_estimator(
      model_name,
      hparams,
      run_config,
      schedule=schedule,
      decode_hparams=decode_hparams,
      use_tpu=use_tpu,
      use_tpu_estimator=use_tpu_estimator,
      use_xla=use_xla)

  # Input fns from Problem
  problem = hparams.problem
  # dataset_kwargs lets the problem cap the number of training records read.
  train_input_fn = problem.make_estimator_input_fn(
      tf.estimator.ModeKeys.TRAIN,
      hparams,
      dataset_kwargs={"max_records": FLAGS.train_data_size})
  eval_input_fn = problem.make_estimator_input_fn(
      tf.estimator.ModeKeys.EVAL, hparams)

  # Export
  exporter = None
  if export:

    def compare_fn(best_eval_result, current_eval_result):
      metric = eval_early_stopping_metric or "loss"
      return current_eval_result[metric] < best_eval_result[metric]

    exporter = tf.estimator.BestExporter(
        name="best",
        serving_input_receiver_fn=lambda: problem.serving_input_fn(hparams),
        compare_fn=compare_fn,
        assets_extra=problem.export_assets)

  # Hooks
  validation_monitor_kwargs = dict(
      input_fn=eval_input_fn,
      eval_steps=eval_steps,
      every_n_steps=min_eval_frequency,
      early_stopping_rounds=eval_early_stopping_steps,
      early_stopping_metric=eval_early_stopping_metric,
      early_stopping_metric_minimize=eval_early_stopping_metric_minimize)
  dbgprofile_kwargs = {"output_dir": run_config.model_dir}
  early_stopping_kwargs = dict(
      events_dir=os.path.join(run_config.model_dir, "eval_continuous"),
      tag=eval_early_stopping_metric,
      num_plateau_steps=eval_early_stopping_steps,
      plateau_decrease=eval_early_stopping_metric_minimize,
      plateau_delta=eval_early_stopping_metric_delta,
      every_n_steps=min_eval_frequency)

  # Eval on TPU Pods is not supported yet
  if use_tpu and run_config.tpu_config.num_shards > 8 and "eval" in schedule:
    raise ValueError("Eval is not currently supported on a TPU Pod")

  # In-process eval (and possible early stopping)
  if schedule == "continuous_train_and_eval" and min_eval_frequency:
    tf.logging.warn("ValidationMonitor only works with "
                    "--schedule=train_and_evaluate")
  use_validation_monitor = (
      schedule == "train_and_evaluate" and min_eval_frequency)
  # Distributed early stopping
  local_schedules = ["train_and_evaluate", "continuous_train_and_eval"]
  use_early_stopping = (
      schedule not in local_schedules and eval_early_stopping_steps)
  train_hooks, eval_hooks = trainer_lib.create_hooks(
      use_tfdbg=use_tfdbg,
      use_dbgprofile=use_dbgprofile,
      dbgprofile_kwargs=dbgprofile_kwargs,
      use_validation_monitor=use_validation_monitor,
      validation_monitor_kwargs=validation_monitor_kwargs,
      use_early_stopping=use_early_stopping,
      early_stopping_kwargs=early_stopping_kwargs)

  hook_context = trainer_lib.HookContext(
      estimator=estimator, problem=problem, hparams=hparams)

  train_hooks += t2t_model.T2TModel.get_train_hooks(model_name, hook_context)
  eval_hooks += t2t_model.T2TModel.get_eval_hooks(model_name, hook_context)
  if additional_train_hooks:
    train_hooks += additional_train_hooks
  if additional_eval_hooks:
    eval_hooks += additional_eval_hooks

  train_hooks = tf.contrib.learn.monitors.replace_monitors_with_hooks(
      train_hooks, estimator)
  eval_hooks = tf.contrib.learn.monitors.replace_monitors_with_hooks(
      eval_hooks, estimator)

  train_spec = tf.estimator.TrainSpec(
      train_input_fn, max_steps=train_steps, hooks=train_hooks)
  eval_spec = tf.estimator.EvalSpec(
      eval_input_fn,
      steps=eval_steps,
      hooks=eval_hooks,
      start_delay_secs=0 if hparams.schedule == "evaluate" else 120,
      throttle_secs=eval_throttle_seconds,
      exporters=exporter)

  return trainer_lib.T2TExperiment(estimator, hparams, train_spec, eval_spec,
                                   use_validation_monitor, decode_hparams)
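
# A minimal usage sketch for create_experiment. Assumptions: the model and
# problem names and paths below are illustrative, and the RunConfig comes
# from trainer_lib.create_run_config (whose exact signature varies across
# versions):
#
#   run_config = trainer_lib.create_run_config(model_dir="/tmp/t2t_train")
#   hparams = trainer_lib.create_hparams("transformer_base")
#   exp = create_experiment(
#       run_config=run_config,
#       hparams=hparams,
#       model_name="transformer",
#       problem_name="translate_ende_wmt32k",
#       data_dir="/tmp/t2t_data",
#       train_steps=1000,
#       eval_steps=100)
#   exp.train_and_evaluate()
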
def run_std_server():
  # Only run_std_server() is exercised on this code path, so placeholder
  # arguments suffice for the unused constructor fields.
  exp = trainer_lib.T2TExperiment(*([None] * 5))
  exp.run_std_server()
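
# Sketch: a parameter-server task would typically set TF_CONFIG before
# calling run_std_server(), following the standard distributed-Estimator
# convention (hostnames below are placeholders):
#
#   import json
#   os.environ["TF_CONFIG"] = json.dumps({
#       "cluster": {"ps": ["ps0:2222"], "worker": ["worker0:2222"]},
#       "task": {"type": "ps", "index": 0},
#   })
#   run_std_server()  # blocks, serving as a parameter server
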
def create_experiment(run_config,
                      hparams,
                      model_name,
                      params,
                      problem_instance,
                      data_dir,
                      train_steps,
                      eval_steps,
                      min_eval_frequency=2000,
                      eval_throttle_seconds=600,
                      schedule="train_and_evaluate",
                      export=False,
                      decode_hparams=None,
                      use_tfdbg=False,
                      use_dbgprofile=False,
                      use_validation_monitor=False,
                      eval_early_stopping_steps=None,
                      eval_early_stopping_metric=None,
                      eval_early_stopping_metric_delta=None,
                      eval_early_stopping_metric_minimize=True,
                      autotune=False,
                      use_tpu=False):
  """Create Experiment."""
  # HParams
  hparams.add_hparam("model_dir", params.model_dir)
  hparams.add_hparam("data_dir", data_dir)
  hparams.add_hparam("train_steps", train_steps)
  hparams.add_hparam("eval_steps", eval_steps)
  hparams.add_hparam("schedule", schedule)
  trainer_lib.add_problem_hparams(hparams, problem_instance)

  # Estimator
  estimator = trainer_lib.create_estimator(
      model_name,
      hparams,
      run_config,
      schedule=schedule,
      decode_hparams=decode_hparams,
      use_tpu=use_tpu)

  # Input fns from Problem
  problem = hparams.problem
  train_input_fn = problem.make_estimator_input_fn(
      tf.estimator.ModeKeys.TRAIN, hparams)
  eval_input_fn = problem.make_estimator_input_fn(
      tf.estimator.ModeKeys.EVAL, hparams)

  # Export
  if export:
    tf.logging.warn("Exporting from the trainer is deprecated. "
                    "See serving/export.py.")

  # Hooks
  validation_monitor_kwargs = dict(
      input_fn=eval_input_fn,
      eval_steps=eval_steps,
      every_n_steps=min_eval_frequency,
      early_stopping_rounds=eval_early_stopping_steps,
      early_stopping_metric=eval_early_stopping_metric,
      early_stopping_metric_minimize=eval_early_stopping_metric_minimize)
  dbgprofile_kwargs = {"output_dir": run_config.model_dir}
  early_stopping_kwargs = dict(
      events_dir=os.path.join(run_config.model_dir, "eval_continuous"),
      tag=eval_early_stopping_metric,
      num_plateau_steps=eval_early_stopping_steps,
      plateau_decrease=eval_early_stopping_metric_minimize,
      plateau_delta=eval_early_stopping_metric_delta,
      every_n_steps=min_eval_frequency)

  # In-process eval (and possible early stopping)
  if schedule == "continuous_train_and_eval" and min_eval_frequency:
    tf.logging.warn("ValidationMonitor only works with "
                    "--schedule=train_and_evaluate")
  use_validation_monitor = (
      schedule == "train_and_evaluate" and min_eval_frequency)
  # Distributed early stopping
  local_schedules = ["train_and_evaluate", "continuous_train_and_eval"]
  use_early_stopping = (
      schedule not in local_schedules and eval_early_stopping_steps)
  train_hooks, eval_hooks = trainer_lib.create_hooks(
      use_tfdbg=use_tfdbg,
      use_dbgprofile=use_dbgprofile,
      dbgprofile_kwargs=dbgprofile_kwargs,
      use_validation_monitor=use_validation_monitor,
      validation_monitor_kwargs=validation_monitor_kwargs,
      use_early_stopping=use_early_stopping,
      early_stopping_kwargs=early_stopping_kwargs)
  train_hooks += t2t_model.T2TModel.get_train_hooks(model_name)
  eval_hooks += t2t_model.T2TModel.get_eval_hooks(model_name)

  train_hooks = tf.contrib.learn.monitors.replace_monitors_with_hooks(
      train_hooks, estimator)
  eval_hooks = tf.contrib.learn.monitors.replace_monitors_with_hooks(
      eval_hooks, estimator)

  train_spec = tf.estimator.TrainSpec(
      train_input_fn, max_steps=train_steps, hooks=train_hooks)
  eval_spec = tf.estimator.EvalSpec(
      eval_input_fn,
      steps=eval_steps,
      hooks=eval_hooks,
      start_delay_secs=0 if hparams.schedule == "evaluate" else 120,
      throttle_secs=eval_throttle_seconds)

  if autotune:
    # Hook kwargs are only passed when not on TPU.
    hooks_kwargs = {"train_monitors": train_hooks, "eval_hooks": eval_hooks}
    return tf.contrib.learn.Experiment(
        estimator=estimator,
        train_input_fn=train_input_fn,
        eval_input_fn=eval_input_fn,
        train_steps=train_steps,
        eval_steps=eval_steps,
        min_eval_frequency=min_eval_frequency,
        train_steps_per_iteration=min(min_eval_frequency, train_steps),
        eval_delay_secs=0 if schedule == "evaluate" else 120,
        **(hooks_kwargs if not use_tpu else {}))
  return trainer_lib.T2TExperiment(estimator, hparams, train_spec, eval_spec,
                                   use_validation_monitor, decode_hparams)
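
# Note on the conditional unpacking in the Experiment call above:
# `**(a if cond else b)` expands whichever dict the condition selects, so
# with use_tpu=True no hook kwargs are passed at all. A quick illustration:
#
#   def f(**kw):
#     return kw
#   f(**({"x": 1} if True else {}))   # -> {"x": 1}
#   f(**({"x": 1} if False else {}))  # -> {}
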