def create_hooks(use_tfdbg=False,
                 use_dbgprofile=False,
                 dbgprofile_kwargs=None,
                 use_validation_monitor=False,
                 validation_monitor_kwargs=None,
                 use_early_stopping=False,
                 early_stopping_kwargs=None):
    """Create train and eval hooks for Experiment.

    Args:
      use_tfdbg: if truthy, add one shared `debug.LocalCLIDebugHook` to both
        the train and the eval hooks.
      use_dbgprofile: if truthy, add a `tf.train.ProfilerHook` to the train
        hooks.
      dbgprofile_kwargs: optional mapping of ProfilerHook keyword arguments;
        overrides the defaults (save_steps=10, show_dataflow=True,
        show_memory=True). `None` means "no overrides".
      use_validation_monitor: if truthy, add a ValidationMonitor to the train
        hooks.
      validation_monitor_kwargs: optional mapping of extra ValidationMonitor
        keyword arguments. `None` is treated as an empty mapping; anything
        the monitor itself requires must still be supplied by the caller.
      use_early_stopping: if truthy, add one shared
        `metrics_hook.EarlyStoppingHook` to both the train and eval hooks.
      early_stopping_kwargs: optional mapping of EarlyStoppingHook keyword
        arguments. `None` is treated as an empty mapping; anything the hook
        itself requires must still be supplied by the caller.

    Returns:
      A `(train_hooks, eval_hooks)` tuple of two separate lists.
    """
    train_hooks = []
    eval_hooks = []

    if use_tfdbg:
        hook = debug.LocalCLIDebugHook()
        train_hooks.append(hook)
        eval_hooks.append(hook)

    if use_dbgprofile:
        # Recorded traces can be visualized with chrome://tracing/
        # The memory/tensor lifetime is also profiled
        tf.logging.info("Using ProfilerHook")
        defaults = dict(save_steps=10, show_dataflow=True, show_memory=True)
        # The kwargs parameter defaults to None; dict.update(None) would raise
        # TypeError, so fall back to an empty mapping.
        defaults.update(dbgprofile_kwargs or {})
        train_hooks.append(tf.train.ProfilerHook(**defaults))

    if use_validation_monitor:
        tf.logging.info("Using ValidationMonitor")
        # The monitor receives the eval_hooks list object itself, so hooks
        # appended to that list further down are part of the same list.
        train_hooks.append(
            contrib.learn().monitors.ValidationMonitor(
                hooks=eval_hooks, **(validation_monitor_kwargs or {})))

    if use_early_stopping:
        tf.logging.info("Using EarlyStoppingHook")
        hook = metrics_hook.EarlyStoppingHook(**(early_stopping_kwargs or {}))
        # Adding to both training and eval so that eval aborts as well
        train_hooks.append(hook)
        eval_hooks.append(hook)

    return train_hooks, eval_hooks
def create_experiment(run_config,
                      hparams,
                      model_name,
                      problem_name,
                      data_dir,
                      train_steps,
                      eval_steps,
                      min_eval_frequency=2000,
                      eval_throttle_seconds=600,
                      schedule="train_and_evaluate",
                      export=False,
                      decode_hparams=None,
                      use_tfdbg=False,
                      use_dbgprofile=False,
                      eval_early_stopping_steps=None,
                      eval_early_stopping_metric=None,
                      eval_early_stopping_metric_delta=None,
                      eval_early_stopping_metric_minimize=True,
                      eval_timeout_mins=240,
                      eval_use_test_set=False,
                      use_tpu=False,
                      use_tpu_estimator=False,
                      use_xla=False,
                      export_saved_model_api_version=1,
                      use_guarantee_const_getter=False,
                      additional_train_hooks=None,
                      additional_eval_hooks=None,
                      warm_start_from=None,
                      decode_from_file="",
                      decode_to_file="",
                      decode_reference="",
                      std_server_protocol=None):
    """Create Experiment.

    Records the run settings on `hparams`, builds the estimator, the train and
    eval input functions, an optional best-model exporter and the hook lists,
    then bundles everything into a `T2TExperiment`.

    Args:
      run_config: run configuration; `model_dir` is read from it (and
        `tpu_config.num_shards` when `use_tpu` is set) and it is forwarded to
        `create_estimator`.
      hparams: hyperparameter object. Mutated in place: run settings are
        added with `add_hparam` and `add_problem_hparams` is applied to it.
      model_name: model name, forwarded to `create_estimator` and used to look
        up model-specific train/eval hooks.
      problem_name: problem name, forwarded to `add_problem_hparams`.
      data_dir: stored on `hparams` as "data_dir".
      train_steps: stored on `hparams` and used as `max_steps` of the
        TrainSpec.
      eval_steps: stored on `hparams`, used as `steps` of the EvalSpec and as
        the validation monitor's `eval_steps`.
      min_eval_frequency: stored on `hparams` as "eval_freq_in_steps" and used
        as `every_n_steps` for the validation monitor / early-stopping hook. A
        falsy value disables the validation monitor.
      eval_throttle_seconds: `throttle_secs` of the EvalSpec.
      schedule: schedule name; stored on `hparams` and used to decide which
        hooks are enabled.
      export: if truthy, attach a `tf.estimator.BestExporter` named "best".
      decode_hparams: optional decode hyperparameters; when given, the
        decode_* arguments below are recorded on it.
      use_tfdbg: forwarded to `create_hooks`.
      use_dbgprofile: forwarded to `create_hooks`.
      eval_early_stopping_steps: plateau length for early stopping; a falsy
        value disables the distributed early-stopping hook.
      eval_early_stopping_metric: metric used for early stopping and, when
        exporting, for comparing eval results (falls back to "loss").
      eval_early_stopping_metric_delta: forwarded as `plateau_delta`.
      eval_early_stopping_metric_minimize: forwarded as
        `early_stopping_metric_minimize` / `plateau_decrease`.
      eval_timeout_mins: stored on `hparams` as "eval_timeout_mins".
      eval_use_test_set: if truthy, the eval input fn is built with
        dataset_split="test" (otherwise `None`).
      use_tpu: forwarded to `create_estimator` and to the serving input fn.
      use_tpu_estimator: forwarded to `create_estimator`.
      use_xla: forwarded to `create_estimator`.
      export_saved_model_api_version: forwarded to `create_estimator`.
      use_guarantee_const_getter: forwarded to `create_estimator`.
      additional_train_hooks: optional extra hooks appended to the train hooks.
      additional_eval_hooks: optional extra hooks appended to the eval hooks.
      warm_start_from: stored on `hparams` as "warm_start_from".
      decode_from_file: added to `decode_hparams` as "decode_from_file".
      decode_to_file: copied to `decode_hparams.decode_to_file` only if that
        attribute is currently falsy.
      decode_reference: copied to `decode_hparams.decode_reference` only if
        that attribute is currently falsy.
      std_server_protocol: stored on `hparams` as "std_server_protocol".

    Returns:
      A `T2TExperiment` wrapping the estimator, hparams, train spec, eval
      spec, the validation-monitor flag and `decode_hparams`.

    Raises:
      ValueError: if `use_tpu` is set, `run_config.tpu_config.num_shards` is
        greater than 8 and `schedule` contains "eval".
    """
    # HParams
    hparams.add_hparam("model_dir", run_config.model_dir)
    hparams.add_hparam("data_dir", data_dir)
    hparams.add_hparam("train_steps", train_steps)
    hparams.add_hparam("eval_steps", eval_steps)
    hparams.add_hparam("schedule", schedule)
    hparams.add_hparam("warm_start_from", warm_start_from)
    hparams.add_hparam("std_server_protocol", std_server_protocol)
    hparams.add_hparam("eval_freq_in_steps", min_eval_frequency)
    hparams.add_hparam("eval_timeout_mins", eval_timeout_mins)
    if decode_hparams is not None:
        decode_hparams.add_hparam("decode_from_file", decode_from_file)
        # Values already present on decode_hparams win over the arguments.
        if decode_to_file and not decode_hparams.decode_to_file:
            decode_hparams.decode_to_file = decode_to_file
        if decode_reference and not decode_hparams.decode_reference:
            decode_hparams.decode_reference = decode_reference
    add_problem_hparams(hparams, problem_name)

    # Estimator
    estimator = create_estimator(
        model_name,
        hparams,
        run_config,
        schedule=schedule,
        decode_hparams=decode_hparams,
        use_tpu=use_tpu,
        use_tpu_estimator=use_tpu_estimator,
        use_xla=use_xla,
        export_saved_model_api_version=export_saved_model_api_version,
        use_guarantee_const_getter=use_guarantee_const_getter)

    # Input fns from Problem
    problem = hparams.problem
    train_input_fn = problem.make_estimator_input_fn(
        tf.estimator.ModeKeys.TRAIN, hparams)

    dataset_split = "test" if eval_use_test_set else None
    dataset_kwargs = {"dataset_split": dataset_split}
    eval_input_fn = problem.make_estimator_input_fn(
        tf.estimator.ModeKeys.EVAL, hparams, dataset_kwargs=dataset_kwargs)

    # Export
    exporter = None
    if export:
        def compare_fn(best_eval_result, current_eval_result):
            # NOTE(review): always treats "smaller is better", regardless of
            # eval_early_stopping_metric_minimize -- confirm this is intended
            # for metrics that should be maximized.
            metric = eval_early_stopping_metric or "loss"
            return current_eval_result[metric] < best_eval_result[metric]

        def serving_input_receiver_fn():
            # The exporter calls this with no arguments, so the values that
            # problem.serving_input_fn needs are captured from the enclosing
            # scope instead of being declared as required parameters.
            return problem.serving_input_fn(hparams, decode_hparams, use_tpu)

        exporter = tf.estimator.BestExporter(
            name="best",
            serving_input_receiver_fn=serving_input_receiver_fn,
            compare_fn=compare_fn,
            assets_extra=problem.export_assets)

    # Hooks
    validation_monitor_kwargs = dict(
        input_fn=eval_input_fn,
        eval_steps=eval_steps,
        every_n_steps=min_eval_frequency,
        early_stopping_rounds=eval_early_stopping_steps,
        early_stopping_metric=eval_early_stopping_metric,
        early_stopping_metric_minimize=eval_early_stopping_metric_minimize)
    dbgprofile_kwargs = {"output_dir": run_config.model_dir}
    early_stopping_kwargs = dict(
        events_dir=os.path.join(run_config.model_dir, "eval_continuous"),
        tag=eval_early_stopping_metric,
        num_plateau_steps=eval_early_stopping_steps,
        plateau_decrease=eval_early_stopping_metric_minimize,
        plateau_delta=eval_early_stopping_metric_delta,
        every_n_steps=min_eval_frequency)

    # Eval on TPU Pods is not supported yet
    if use_tpu and run_config.tpu_config.num_shards > 8 and "eval" in schedule:
        raise ValueError("Eval is not currently supported on a TPU Pod")

    # In-process eval (and possible early stopping)
    if schedule == "continuous_train_and_eval" and min_eval_frequency:
        tf.logging.warn("ValidationMonitor only works with "
                        "--schedule=train_and_evaluate")
    use_validation_monitor = (
        schedule == "train_and_evaluate" and min_eval_frequency)

    # Distributed early stopping
    local_schedules = ["train_and_evaluate", "continuous_train_and_eval"]
    use_early_stopping = (
        schedule not in local_schedules and eval_early_stopping_steps)

    train_hooks, eval_hooks = create_hooks(
        use_tfdbg=use_tfdbg,
        use_dbgprofile=use_dbgprofile,
        dbgprofile_kwargs=dbgprofile_kwargs,
        use_validation_monitor=use_validation_monitor,
        validation_monitor_kwargs=validation_monitor_kwargs,
        use_early_stopping=use_early_stopping,
        early_stopping_kwargs=early_stopping_kwargs)

    # Model-specific hooks, then any caller-supplied ones.
    hook_context = HookContext(
        estimator=estimator, problem=problem, hparams=hparams)

    train_hooks += t2t_model.T2TModel.get_train_hooks(model_name, hook_context)
    eval_hooks += t2t_model.T2TModel.get_eval_hooks(model_name, hook_context)
    if additional_train_hooks:
        train_hooks += additional_train_hooks
    if additional_eval_hooks:
        eval_hooks += additional_eval_hooks

    train_hooks = contrib.learn().monitors.replace_monitors_with_hooks(
        train_hooks, estimator)
    eval_hooks = contrib.learn().monitors.replace_monitors_with_hooks(
        eval_hooks, estimator)

    train_spec = tf.estimator.TrainSpec(
        train_input_fn, max_steps=train_steps, hooks=train_hooks)
    eval_spec = tf.estimator.EvalSpec(
        eval_input_fn,
        steps=eval_steps,
        hooks=eval_hooks,
        # No start delay when the schedule is exactly "evaluate".
        start_delay_secs=0 if hparams.schedule == "evaluate" else 120,
        throttle_secs=eval_throttle_seconds,
        exporters=exporter)

    return T2TExperiment(estimator, hparams, train_spec, eval_spec,
                         use_validation_monitor, decode_hparams)
def create_run_config(model_name,
                      master="",
                      model_dir=None,
                      iterations_per_loop=1000,
                      num_shards=8,
                      log_device_placement=False,
                      save_checkpoints_steps=1000,
                      save_checkpoints_secs=None,
                      keep_checkpoint_max=20,
                      keep_checkpoint_every_n_hours=10000,
                      num_gpus=1,
                      gpu_order="",
                      num_async_replicas=1,
                      enable_graph_rewriter=False,
                      gpu_mem_fraction=0.95,
                      no_data_parallelism=False,
                      optionally_use_dist_strat=False,
                      daisy_chain_variables=True,
                      schedule="continuous_train_and_eval",
                      worker_job="/job:localhost",
                      worker_id=0,
                      ps_replicas=0,
                      ps_job="/job:ps",
                      ps_gpu=0,
                      random_seed=None,
                      sync=False,
                      tpu_infeed_sleep_secs=None,
                      use_tpu=False,
                      use_tpu_estimator=False,
                      xla_jit_level=tf.OptimizerOptions.OFF,
                      inter_op_parallelism_threads=0,
                      log_step_count_steps=100,
                      intra_op_parallelism_threads=0,
                      tpu_config_extra_kwargs=None,
                      cloud_tpu_name="",
                      cloud_tpu_zone=None):
    """Create RunConfig, TPUConfig, and Parallelism object.

    Args:
      model_name: model name; used to ask
        `t2t_model.T2TModel.has_symmetric_shards` whether a distribution
        strategy may be used.
      master: passed as both "master" and "evaluation_master" unless one of
        the branches below replaces or removes them.
      model_dir: passed as "model_dir".
      iterations_per_loop: TPUConfig `iterations_per_loop`.
      num_shards: TPUConfig `num_shards`.
      log_device_placement, enable_graph_rewriter, gpu_mem_fraction,
        xla_jit_level, inter_op_parallelism_threads,
        intra_op_parallelism_threads: forwarded to `create_session_config`
        (together with `use_tpu`).
      save_checkpoints_steps: passed as "save_checkpoints_steps" unless
        `save_checkpoints_secs` is truthy.
      save_checkpoints_secs: passed as "save_checkpoints_secs"; when truthy,
        "save_checkpoints_steps" is dropped from the arguments.
      keep_checkpoint_max: passed as "keep_checkpoint_max".
      keep_checkpoint_every_n_hours: passed as
        "keep_checkpoint_every_n_hours".
      num_gpus: forwarded to `devices.data_parallelism` as `worker_gpu`.
      gpu_order: forwarded to `devices.data_parallelism`.
      num_async_replicas: stored in `config.t2t_device_info` and forwarded to
        `devices.data_parallelism` as `worker_replicas`.
      no_data_parallelism: forwarded to `devices.data_parallelism`; a truthy
        value also rules out the distribution strategy.
      optionally_use_dist_strat: allow MirroredStrategy when the other
        conditions hold (symmetric shards, data parallelism allowed,
        ps_replicas == 0, ps_gpu == 0, num_async_replicas == 1).
      daisy_chain_variables, schedule, worker_job, worker_id, ps_replicas,
        ps_job, ps_gpu, sync: forwarded to `devices.data_parallelism`.
      random_seed: passed as "tf_random_seed".
      tpu_infeed_sleep_secs: TPUConfig `initial_infeed_sleep_secs`.
      use_tpu: selects the TPU RunConfig and is stored on the result as
        `config.use_tpu`; when falsy, device info and data parallelism are
        attached to the config.
      use_tpu_estimator: also selects the TPU RunConfig.
      log_step_count_steps: passed as "log_step_count_steps".
      tpu_config_extra_kwargs: optional mapping merged over the TPUConfig
        keyword arguments.
      cloud_tpu_name: when set and `master` is empty (and the KUBE endpoint
        variable is absent), a TPUClusterResolver is passed as "cluster"
        instead of master/evaluation_master.
      cloud_tpu_zone: zone passed to the TPUClusterResolver.

    Returns:
      The constructed run config, with `use_tpu` set on it and, when not
      using TPU, `t2t_device_info` and `data_parallelism` as well.
    """
    session_config = create_session_config(
        log_device_placement=log_device_placement,
        enable_graph_rewriter=enable_graph_rewriter,
        gpu_mem_fraction=gpu_mem_fraction,
        use_tpu=use_tpu,
        xla_jit_level=xla_jit_level,
        inter_op_parallelism_threads=inter_op_parallelism_threads,
        intra_op_parallelism_threads=intra_op_parallelism_threads)
    run_config_args = {
        "master": master,
        "evaluation_master": master,
        "model_dir": model_dir,
        "session_config": session_config,
        "save_summary_steps": 100,
        "save_checkpoints_steps": save_checkpoints_steps,
        "save_checkpoints_secs": save_checkpoints_secs,
        "keep_checkpoint_max": keep_checkpoint_max,
        "keep_checkpoint_every_n_hours": keep_checkpoint_every_n_hours,
        "tf_random_seed": random_seed,
        "log_step_count_steps": log_step_count_steps,
    }
    # Only one of the two checkpoint triggers is passed through.
    if save_checkpoints_secs:
        del run_config_args["save_checkpoints_steps"]
    run_config_cls = contrib.learn().RunConfig

    if use_tpu or use_tpu_estimator:
        # If using TPUEstimator, use TPU RunConfig, add TPUConfig, and add
        # additional args.
        tpu_config_kwargs = {
            "iterations_per_loop": iterations_per_loop,
            "num_shards": num_shards,
            "per_host_input_for_training": True,
            "initial_infeed_sleep_secs": tpu_infeed_sleep_secs,
        }
        if tpu_config_extra_kwargs is not None:
            tpu_config_kwargs.update(tpu_config_extra_kwargs)
        run_config_cls = contrib.tpu().RunConfig
        tpu_config = contrib.tpu().TPUConfig(**tpu_config_kwargs)
        run_config_args["tpu_config"] = tpu_config
        if not master and "KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS" in os.environ:
            # If running on TPU but no master is set and the KUBE env var is
            # present then we're running on ML Engine. Set the master.
            run_config_args["master"] = os.environ[
                "KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS"]
            run_config_args["evaluation_master"] = run_config_args["master"]
        elif not master and cloud_tpu_name:
            # Update run_config to use cluster instead of
            # master/evaluation_master as we need the cluster spec to use
            # Cloud Pods
            tpu_cluster_resolver = contrib.cluster_resolver(
            ).TPUClusterResolver(tpu=cloud_tpu_name, zone=cloud_tpu_zone)
            run_config_args["cluster"] = tpu_cluster_resolver
            del run_config_args["master"]
            del run_config_args["evaluation_master"]
    elif is_cloud_async_distributed():
        run_config_cls = tf.estimator.RunConfig
        del run_config_args["master"]
        del run_config_args["evaluation_master"]

    # tf.estimator RunConfig construction got totally broken in TF2.
    # we now have to specify master in a global environment variable
    if contrib.is_tf2:
        # The branches above may already have removed these keys, so drop
        # them tolerantly instead of with `del` (which would raise KeyError).
        run_config_args.pop("evaluation_master", None)
        run_config_args.pop("master", None)

    config = run_config_cls(**run_config_args)

    # If not using TPU, add device info for data_parallelism
    config.use_tpu = use_tpu
    if not use_tpu:
        config.t2t_device_info = {
            "num_async_replicas": num_async_replicas,
        }
        use_distribution_strategy = (
            optionally_use_dist_strat and
            t2t_model.T2TModel.has_symmetric_shards(model_name) and
            not no_data_parallelism and ps_replicas == 0 and ps_gpu == 0 and
            num_async_replicas == 1)

        if use_distribution_strategy:
            tf.logging.info(
                "Configuring MirroredStrategy DistributionStrategy to "
                "replicate the model.")
            distribution = contrib.distribute().MirroredStrategy()
            config = config.replace(train_distribute=distribution)
            config.data_parallelism = None
        else:
            tf.logging.info(
                "Configuring DataParallelism to replicate the model.")
            config.data_parallelism = devices.data_parallelism(
                daisy_chain_variables=daisy_chain_variables,
                ps_replicas=ps_replicas,
                ps_job=ps_job,
                ps_gpu=ps_gpu,
                schedule=schedule,
                sync=sync,
                worker_gpu=num_gpus,
                worker_replicas=num_async_replicas,
                worker_id=worker_id,
                gpu_order=gpu_order,
                worker_job=worker_job,
                no_data_parallelism=no_data_parallelism)

    return config