def _test_pipeline(self, mode, params=None):
  """Helper function to test the full model pipeline."""
  # Create source and target example
  source_len = 10
  target_len = self.max_decode_length + 10
  source = " ".join(np.random.choice(self.vocab_list, source_len))
  target = " ".join(np.random.choice(self.vocab_list, target_len))
  sources_file, targets_file = test_utils.create_temp_parallel_data(
      sources=[source], targets=[target])

  # Build model graph
  model = self.create_model(params)
  data_provider = lambda: data_utils.make_parallel_data_provider(
      [sources_file.name], [targets_file.name])
  input_fn = training_utils.create_input_fn(data_provider, self.batch_size)
  features, labels = input_fn()
  fetches = model(features, labels, None, mode)
  fetches = [_ for _ in fetches if _ is not None]

  with self.test_session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    sess.run(tf.tables_initializer())
    with tf.contrib.slim.queues.QueueRunners(sess):
      fetches_ = sess.run(fetches)

  sources_file.close()
  targets_file.close()

  return model, fetches_
def create_inference_graph(model, input_pipeline, batch_size=32):
  """Creates a graph to perform inference.

  Args:
    model: The model instance to run inference with.
    input_pipeline: An instance of `InputPipeline` that defines
      how to read and parse data.
    batch_size: The batch size used for inference.

  Returns:
    The return value of the model function, typically a tuple of
    (predictions, loss, train_op).
  """
  # TODO: This doesn't really belong here. How to get rid of this?
  if hasattr(model, "use_beam_search"):
    if model.use_beam_search:
      tf.logging.info("Setting batch size to 1 for beam search.")
      batch_size = 1

  input_fn = training_utils.create_input_fn(
      pipeline=input_pipeline,
      batch_size=batch_size,
      allow_smaller_final_batch=True)

  # Build the graph
  features, labels = input_fn()
  return model(features=features, labels=labels, params=None)
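# Usage sketch, not part of the original code: assumes a `model` object built
# elsewhere (e.g. an AttentionSeq2Seq constructed in INFER mode) and the
# ParallelTextInputPipeline used in other snippets here. The file path and
# the empty target_files list are illustrative assumptions.
pipeline = input_pipeline.ParallelTextInputPipeline(
    params={"source_files": ["data/test_sources.txt"], "target_files": []},
    mode=tf.contrib.learn.ModeKeys.INFER)
predictions, _, _ = create_inference_graph(
    model=model, input_pipeline=pipeline, batch_size=32)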
def create_inference_graph(model_dir, input_file, batch_size=32,
                           beam_width=None):
  """Creates a graph to perform inference.

  Args:
    model_dir: The output directory passed during training. This directory
      must contain model checkpoints.
    input_file: A source input file to read from.
    batch_size: The batch size used for inference.
    beam_width: The beam width for beam search. If None, no beam search
      is used.

  Returns:
    The return value of the model function, typically a tuple of
    (predictions, loss, train_op).
  """
  params_overrides = {}
  if beam_width is not None:
    tf.logging.info("Setting batch size to 1 for beam search.")
    batch_size = 1
    params_overrides["inference.beam_search.beam_width"] = beam_width

  # Pass the overrides so the beam width actually reaches the loaded model;
  # without this, params_overrides would be built but never used.
  model = load_model(model_dir, params_overrides)

  data_provider = lambda: data_utils.make_parallel_data_provider(
      data_sources_source=[input_file],
      data_sources_target=None,
      shuffle=False,
      num_epochs=1)

  input_fn = training_utils.create_input_fn(
      data_provider_fn=data_provider,
      batch_size=batch_size,
      allow_smaller_final_batch=True)

  # Build the graph
  features, labels = input_fn()
  return model(
      features=features,
      labels=labels,
      params=None,
      mode=tf.contrib.learn.ModeKeys.INFER)
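# Usage sketch, not part of the original code (paths are illustrative):
# passing beam_width forces batch_size to 1 and records the beam width in
# params_overrides before the model is loaded.
predictions, _, _ = create_inference_graph(
    model_dir="models/my_run",          # hypothetical checkpoint directory
    input_file="data/test_sources.txt",
    batch_size=32,                      # ignored: reset to 1 for beam search
    beam_width=5)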
def _test_with_args(self, **kwargs):
  """Helper function to test create_input_fn with keyword arguments"""
  sources_file, targets_file = test_utils.create_temp_parallel_data(
      sources=["Hello World ."], targets=["Goodbye ."])
  data_provider_fn = lambda: data_utils.make_parallel_data_provider(
      [sources_file.name], [targets_file.name])
  input_fn = training_utils.create_input_fn(
      data_provider_fn=data_provider_fn, **kwargs)
  features, labels = input_fn()

  with self.test_session() as sess:
    with tf.contrib.slim.queues.QueueRunners(sess):
      features_, labels_ = sess.run([features, labels])

  self.assertEqual(
      set(features_.keys()), set(["source_tokens", "source_len"]))
  self.assertEqual(
      set(labels_.keys()), set(["target_tokens", "target_len"]))
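# Usage sketch, not part of the original code: individual test cases only
# need to vary the keyword arguments they care about, since everything is
# forwarded to create_input_fn. The bucket_boundaries values are illustrative.
def test_without_buckets(self):
  self._test_with_args(batch_size=10)

def test_with_buckets(self):
  self._test_with_args(batch_size=10, bucket_boundaries=[5, 10])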
def test_copy_gen_model(record_path, vocab_path=None):
  tf.logging.set_verbosity(tf.logging.INFO)
  vocab = Vocab(vocab_path)
  batch_size = 2

  # Build model graph
  mode = tf.contrib.learn.ModeKeys.TRAIN
  params_ = CopyGenSeq2Seq.default_params().copy()
  params_.update(TEST_PARAMS)
  params_.update({
      "vocab_source": vocab_path,
      "vocab_target": vocab_path,
  })
  tf.logging.info(params_)
  model = CopyGenSeq2Seq(params=params_, mode=mode, vocab_instance=vocab)

  tf.logging.info(vocab_path)
  input_pipeline_ = input_pipeline.FeaturedTFRecordInputPipeline(
      params={
          "files": [record_path],
          "shuffle": True
      }, mode=mode)
  input_fn = training_utils.create_input_fn(
      pipeline=input_pipeline_, batch_size=batch_size)
  features, labels = input_fn()
  fetches = model(features, labels, None)
  fetches = [_ for _ in fetches if _ is not None]

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    sess.run(tf.tables_initializer())
    with tf.contrib.slim.queues.QueueRunners(sess):
      fetches_ = sess.run(fetches)

  return model, fetches_
def test_pipeline(self):
  # Create source and target example
  source_len = 10
  target_len = self.max_decode_length + 10
  source = " ".join(np.random.choice(self.vocab_list, source_len))
  target = " ".join(np.random.choice(self.vocab_list, target_len))
  tfrecords_file = test_utils.create_temp_tfrecords(
      source=source, target=target)

  # Build model graph
  model = self.create_model()
  featurizer = model.create_featurizer()
  data_provider = lambda: inputs.make_data_provider([tfrecords_file.name])
  input_fn = training_utils.create_input_fn(
      data_provider, featurizer, self.batch_size)
  features, labels = input_fn()
  predictions, loss, train_op = model(
      features, labels, None, tf.contrib.learn.ModeKeys.TRAIN)

  with self.test_session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    sess.run(tf.tables_initializer())
    with tf.contrib.slim.queues.QueueRunners(sess):
      predictions_, loss_, _ = sess.run([predictions, loss, train_op])

  # We have predictions for each target word plus the SEQUENCE_START token.
  # That's why it's `target_len + 1`.
  max_decode_length = model.params["target.max_seq_len"]
  expected_decode_len = np.minimum(target_len + 1, max_decode_length)

  np.testing.assert_array_equal(
      predictions_["logits"].shape,
      [self.batch_size, expected_decode_len,
       model.target_vocab_info.total_size])
  np.testing.assert_array_equal(
      predictions_["predictions"].shape,
      [self.batch_size, expected_decode_len])
  self.assertFalse(np.isnan(loss_))

  tfrecords_file.close()
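# Worked example for the shape check above (illustrative numbers only): with
# target_len = 15 and target.max_seq_len = 12, the decoder runs for
# min(15 + 1, 12) = 12 steps, so `logits` has shape
# [batch_size, 12, target_vocab_size] and `predictions` has [batch_size, 12].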
def _test_with_args(self, **kwargs):
  """Helper function to test create_input_fn with keyword arguments"""
  sources_file, targets_file = test_utils.create_temp_parallel_data(
      sources=["Hello World ."], targets=["Goodbye ."])
  pipeline = input_pipeline.ParallelTextInputPipeline(
      params={
          "source_files": [sources_file.name],
          "target_files": [targets_file.name]
      },
      mode=tf.contrib.learn.ModeKeys.TRAIN)
  input_fn = training_utils.create_input_fn(pipeline=pipeline, **kwargs)
  features, labels = input_fn()

  with self.test_session() as sess:
    with tf.contrib.slim.queues.QueueRunners(sess):
      features_, labels_ = sess.run([features, labels])

  self.assertEqual(
      set(features_.keys()), set(["source_tokens", "source_len"]))
  self.assertEqual(
      set(labels_.keys()), set(["target_tokens", "target_len"]))
def test_model(source_path, target_path, vocab_path):
  tf.logging.set_verbosity(tf.logging.INFO)
  batch_size = 2

  # Build model graph
  mode = tf.contrib.learn.ModeKeys.TRAIN
  params_ = AttentionSeq2Seq.default_params().copy()
  params_.update({
      "vocab_source": vocab_path,
      "vocab_target": vocab_path,
  })
  model = AttentionSeq2Seq(params=params_, mode=mode)

  tf.logging.info(vocab_path)
  input_pipeline_ = input_pipeline.ParallelTextInputPipeline(
      params={
          "source_files": [source_path],
          "target_files": [target_path]
      }, mode=mode)
  input_fn = training_utils.create_input_fn(
      pipeline=input_pipeline_, batch_size=batch_size)
  features, labels = input_fn()
  fetches = model(features, labels, None)
  fetches = [_ for _ in fetches if _ is not None]

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    sess.run(tf.tables_initializer())
    with tf.contrib.slim.queues.QueueRunners(sess):
      fetches_ = sess.run(fetches)

  return model, fetches_
def _test_pipeline(self, mode, params=None):
  """Helper function to test the full model pipeline."""
  # Create source and target example
  source_len = self.sequence_length + 5
  target_len = self.sequence_length + 10
  source = " ".join(np.random.choice(self.vocab_list, source_len))
  target = " ".join(np.random.choice(self.vocab_list, target_len))
  sources_file, targets_file = test_utils.create_temp_parallel_data(
      sources=[source], targets=[target])

  # Build model graph
  model = self.create_model(mode, params)
  input_pipeline_ = input_pipeline.ParallelTextInputPipeline(
      params={
          "source_files": [sources_file.name],
          "target_files": [targets_file.name]
      },
      mode=mode)
  input_fn = training_utils.create_input_fn(
      pipeline=input_pipeline_, batch_size=self.batch_size)
  features, labels = input_fn()
  fetches = model(features, labels, None)
  fetches = [_ for _ in fetches if _ is not None]

  with self.test_session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    sess.run(tf.tables_initializer())
    with tf.contrib.slim.queues.QueueRunners(sess):
      fetches_ = sess.run(fetches)

  sources_file.close()
  targets_file.close()

  return model, fetches_
def create_experiment(output_dir):
  """
  Creates a new Experiment instance.

  Args:
    output_dir: Output directory for model checkpoints and summaries.
  """
  config = run_config.RunConfig(
      tf_random_seed=FLAGS.tf_random_seed,
      save_checkpoints_secs=FLAGS.save_checkpoints_secs,
      save_checkpoints_steps=FLAGS.save_checkpoints_steps,
      keep_checkpoint_max=FLAGS.keep_checkpoint_max,
      keep_checkpoint_every_n_hours=FLAGS.keep_checkpoint_every_n_hours,
      gpu_memory_fraction=FLAGS.gpu_memory_fraction)
  config.tf_config.gpu_options.allow_growth = FLAGS.gpu_allow_growth
  config.tf_config.log_device_placement = FLAGS.log_device_placement

  train_options = training_utils.TrainOptions(
      model_class=FLAGS.model,
      model_params=FLAGS.model_params)
  # On the main worker, save training options
  if config.is_chief:
    gfile.MakeDirs(output_dir)
    train_options.dump(output_dir)

  bucket_boundaries = None
  if FLAGS.buckets:
    bucket_boundaries = list(map(int, FLAGS.buckets.split(",")))

  # Training data input pipeline
  train_input_pipeline = input_pipeline.make_input_pipeline_from_def(
      def_dict=FLAGS.input_pipeline_train,
      mode=tf.contrib.learn.ModeKeys.TRAIN)

  # Create training input function
  train_input_fn = training_utils.create_input_fn(
      pipeline=train_input_pipeline,
      batch_size=FLAGS.batch_size,
      bucket_boundaries=bucket_boundaries,
      mode=tf.contrib.learn.ModeKeys.TRAIN)

  # Development data input pipeline
  dev_input_pipeline = input_pipeline.make_input_pipeline_from_def(
      def_dict=FLAGS.input_pipeline_dev,
      mode=tf.contrib.learn.ModeKeys.EVAL,
      shuffle=False,
      num_epochs=1)

  # Create eval input function
  eval_input_fn = training_utils.create_input_fn(
      pipeline=dev_input_pipeline,
      batch_size=FLAGS.batch_size,
      allow_smaller_final_batch=True,
      mode=tf.contrib.learn.ModeKeys.EVAL)

  def model_fn(features, labels, params, mode):
    """Builds the model graph"""
    model = _create_from_dict({
        "class": train_options.model_class,
        "params": train_options.model_params
    }, models, mode=mode)
    return model(features, labels, params)

  estimator = tf.contrib.learn.Estimator(
      model_fn=model_fn,
      model_dir=output_dir,
      config=config,
      params=FLAGS.model_params)

  # Create hooks
  train_hooks = []
  for dict_ in FLAGS.hooks:
    hook = _create_from_dict(
        dict_, hooks,
        model_dir=estimator.model_dir,
        run_config=config)
    train_hooks.append(hook)

  # Create metrics
  eval_metrics = {}
  for dict_ in FLAGS.metrics:
    metric = _create_from_dict(dict_, metric_specs)
    eval_metrics[metric.name] = metric

  experiment = PatchedExperiment(
      estimator=estimator,
      train_input_fn=train_input_fn,
      eval_input_fn=eval_input_fn,
      min_eval_frequency=FLAGS.eval_every_n_steps,
      train_steps=FLAGS.train_steps,
      eval_steps=None,
      eval_metrics=eval_metrics,
      train_monitors=train_hooks)

  return experiment
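# Usage sketch, not part of the original code: an Experiment factory like
# this is normally handed to learn_runner, which calls it with the output
# directory. FLAGS.output_dir and FLAGS.schedule are assumed flags not shown
# in this document.
from tensorflow.contrib.learn import learn_runner

def main(_argv):
  learn_runner.run(
      experiment_fn=create_experiment,
      output_dir=FLAGS.output_dir,
      schedule=FLAGS.schedule)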
def create_experiment(output_dir):
  """
  Creates a new Experiment instance.

  Args:
    output_dir: Output directory for model checkpoints and summaries.
  """
  # Load vocabulary info
  source_vocab_info = inputs.get_vocab_info(FLAGS.vocab_source)
  target_vocab_info = inputs.get_vocab_info(FLAGS.vocab_target)

  # Create data providers
  train_data_provider = lambda: inputs.make_data_provider([FLAGS.data_train])
  dev_data_provider = lambda: inputs.make_data_provider([FLAGS.data_dev])

  # Find model class
  model_class = getattr(models, FLAGS.model)

  # Parse parameters and merge with defaults
  hparams = model_class.default_params()
  if FLAGS.hparams is not None:
    hparams = HParamsParser(hparams).parse(FLAGS.hparams)

  # Print hyperparameter values
  tf.logging.info("Model Hyperparameters")
  tf.logging.info("=" * 50)
  for param, value in sorted(hparams.items()):
    tf.logging.info("%s=%s", param, value)
  tf.logging.info("=" * 50)

  # Create model
  model = model_class(
      source_vocab_info=source_vocab_info,
      target_vocab_info=target_vocab_info,
      params=hparams)
  featurizer = model.create_featurizer()

  bucket_boundaries = None
  if FLAGS.buckets:
    bucket_boundaries = list(map(int, FLAGS.buckets.split(",")))

  # Create input functions
  train_input_fn = training_utils.create_input_fn(
      train_data_provider,
      featurizer,
      FLAGS.batch_size,
      bucket_boundaries=bucket_boundaries)
  eval_input_fn = training_utils.create_input_fn(
      dev_data_provider, featurizer, FLAGS.batch_size)

  def model_fn(features, labels, params, mode):
    """Builds the model graph"""
    return model(features, labels, params, mode)

  estimator = tf.contrib.learn.estimator.Estimator(
      model_fn=model_fn, model_dir=output_dir)

  # Create training Hooks
  model_analysis_hook = hooks.PrintModelAnalysisHook(
      filename=os.path.join(estimator.model_dir, "model_analysis.txt"))
  train_sample_hook = hooks.TrainSampleHook(
      every_n_steps=FLAGS.sample_every_n_steps)
  metadata_hook = hooks.MetadataCaptureHook(
      output_dir=os.path.join(estimator.model_dir, "metadata"), step=10)
  train_monitors = [model_analysis_hook, train_sample_hook, metadata_hook]

  experiment = tf.contrib.learn.experiment.Experiment(
      estimator=estimator,
      train_input_fn=train_input_fn,
      eval_input_fn=eval_input_fn,
      min_eval_frequency=FLAGS.eval_every_n_steps,
      train_steps=FLAGS.train_steps,
      eval_steps=FLAGS.eval_steps,
      train_monitors=train_monitors)

  return experiment
def test_copy_gen_model(source_path=None, target_path=None, vocab_path=None):
  tf.logging.set_verbosity(tf.logging.INFO)
  batch_size = 2
  input_depth = 4
  sequence_length = 10

  if vocab_path is None:
    # Create vocabulary
    vocab_list = [str(_) for _ in range(10)]
    vocab_list += ["笑う", "泣く", "了解", "はい", "^_^"]
    vocab_size = len(vocab_list)
    vocab_file = test_utils.create_temporary_vocab_file(vocab_list)
    vocab_info = vocab.get_vocab_info(vocab_file.name)
    vocab_path = vocab_file.name
    tf.logging.info(vocab_file.name)
  else:
    vocab_info = vocab.get_vocab_info(vocab_path)
    vocab_list = get_vocab_list(vocab_path)

  extend_vocab = vocab_list + ["中国", "爱", "你"]

  tf.contrib.framework.get_or_create_global_step()

  source_len = sequence_length + 5
  target_len = sequence_length + 10
  source = " ".join(np.random.choice(extend_vocab, source_len))
  target = " ".join(np.random.choice(extend_vocab, target_len))

  is_tmp_file = False
  if source_path is None and target_path is None:
    is_tmp_file = True
    sources_file, targets_file = test_utils.create_temp_parallel_data(
        sources=[source], targets=[target])
    source_path = sources_file.name
    target_path = targets_file.name

  # Build model graph
  mode = tf.contrib.learn.ModeKeys.TRAIN
  params_ = CopyGenSeq2Seq.default_params().copy()
  params_.update({
      "vocab_source": vocab_path,
      "vocab_target": vocab_path,
  })
  model = CopyGenSeq2Seq(params=params_, mode=mode)

  tf.logging.info(source_path)
  tf.logging.info(target_path)
  input_pipeline_ = input_pipeline.ParallelTextInputPipeline(
      params={
          "source_files": [source_path],
          "target_files": [target_path]
      }, mode=mode)
  input_fn = training_utils.create_input_fn(
      pipeline=input_pipeline_, batch_size=batch_size)
  features, labels = input_fn()
  fetches = model(features, labels, None)
  fetches = [_ for _ in fetches if _ is not None]

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    sess.run(tf.tables_initializer())
    with tf.contrib.slim.queues.QueueRunners(sess):
      fetches_ = sess.run(fetches)

  if is_tmp_file:
    sources_file.close()
    targets_file.close()

  return model, fetches_
def create_experiment(output_dir):
  """
  Creates a new Experiment instance.

  Args:
    output_dir: Output directory for model checkpoints and summaries.
  """
  config = run_config.RunConfig(
      tf_random_seed=FLAGS.tf_random_seed,
      save_checkpoints_secs=FLAGS.save_checkpoints_secs,
      save_checkpoints_steps=FLAGS.save_checkpoints_steps,
      keep_checkpoint_max=FLAGS.keep_checkpoint_max,
      keep_checkpoint_every_n_hours=FLAGS.keep_checkpoint_every_n_hours)

  # Load vocabulary info
  source_vocab_info = vocab.get_vocab_info(FLAGS.vocab_source)
  target_vocab_info = vocab.get_vocab_info(FLAGS.vocab_target)

  # Find model class
  model_class = getattr(models, FLAGS.model)

  # Parse parameters and merge with defaults
  hparams = model_class.default_params()
  if FLAGS.hparams is not None and isinstance(FLAGS.hparams, str):
    hparams = HParamsParser(hparams).parse(FLAGS.hparams)
  elif isinstance(FLAGS.hparams, dict):
    hparams.update(FLAGS.hparams)

  # Print hparams
  training_utils.print_hparams(hparams)

  # On the main worker, save training options and vocabulary
  if config.is_chief:
    # Copy vocabulary to output directory
    gfile.MakeDirs(output_dir)
    source_vocab_path = os.path.join(output_dir, "vocab_source")
    gfile.Copy(FLAGS.vocab_source, source_vocab_path, overwrite=True)
    target_vocab_path = os.path.join(output_dir, "vocab_target")
    gfile.Copy(FLAGS.vocab_target, target_vocab_path, overwrite=True)
    # Save train options
    train_options = training_utils.TrainOptions(
        hparams=hparams,
        model_class=FLAGS.model,
        source_vocab_path=source_vocab_path,
        target_vocab_path=target_vocab_path)
    train_options.dump(output_dir)

  # Create model
  model = model_class(
      source_vocab_info=source_vocab_info,
      target_vocab_info=target_vocab_info,
      params=hparams)

  bucket_boundaries = None
  if FLAGS.buckets:
    bucket_boundaries = list(map(int, FLAGS.buckets.split(",")))

  # Create training input function
  train_input_fn = training_utils.create_input_fn(
      data_provider_fn=functools.partial(
          data_utils.make_parallel_data_provider,
          data_sources_source=FLAGS.train_source,
          data_sources_target=FLAGS.train_target,
          shuffle=True,
          num_epochs=FLAGS.train_epochs,
          delimiter=FLAGS.delimiter),
      batch_size=FLAGS.batch_size,
      bucket_boundaries=bucket_boundaries)

  # Create eval input function
  eval_input_fn = training_utils.create_input_fn(
      data_provider_fn=functools.partial(
          data_utils.make_parallel_data_provider,
          data_sources_source=FLAGS.dev_source,
          data_sources_target=FLAGS.dev_target,
          shuffle=False,
          num_epochs=1,
          delimiter=FLAGS.delimiter),
      batch_size=FLAGS.batch_size)

  def model_fn(features, labels, params, mode):
    """Builds the model graph"""
    return model(features, labels, params, mode)

  estimator = tf.contrib.learn.estimator.Estimator(
      model_fn=model_fn, model_dir=output_dir, config=config)

  train_hooks = training_utils.create_default_training_hooks(
      estimator=estimator,
      sample_frequency=FLAGS.sample_every_n_steps,
      delimiter=FLAGS.delimiter)

  eval_metrics = {
      "log_perplexity": metrics.streaming_log_perplexity(),
      "bleu": metrics.make_bleu_metric_spec(),
  }

  experiment = tf.contrib.learn.experiment.Experiment(
      estimator=estimator,
      train_input_fn=train_input_fn,
      eval_input_fn=eval_input_fn,
      min_eval_frequency=FLAGS.eval_every_n_steps,
      train_steps=FLAGS.train_steps,
      eval_steps=None,
      eval_metrics=eval_metrics,
      train_monitors=train_hooks)

  return experiment
def create_experiment(output_dir):
  """
  Creates a new Experiment instance.

  Args:
    output_dir: Output directory for model checkpoints and summaries.
  """
  config = run_config.RunConfig(
      tf_random_seed=FLAGS.tf_random_seed,
      save_checkpoints_secs=FLAGS.save_checkpoints_secs,
      save_checkpoints_steps=FLAGS.save_checkpoints_steps,
      keep_checkpoint_max=FLAGS.keep_checkpoint_max,
      keep_checkpoint_every_n_hours=FLAGS.keep_checkpoint_every_n_hours,
      gpu_memory_fraction=FLAGS.gpu_memory_fraction)
  config.tf_config.gpu_options.allow_growth = FLAGS.gpu_allow_growth
  config.tf_config.log_device_placement = FLAGS.log_device_placement

  train_options = training_utils.TrainOptions(
      model_class=FLAGS.model,
      model_params=FLAGS.model_params)
  # On the main worker, save training options
  if config.is_chief:
    gfile.MakeDirs(output_dir)
    train_options.dump(output_dir)

  bucket_boundaries = None
  if FLAGS.buckets:
    bucket_boundaries = list(map(int, FLAGS.buckets.split(",")))

  # Training data input pipeline
  train_input_pipeline = input_pipeline.make_input_pipeline_from_def(
      def_dict=FLAGS.input_pipeline_train,
      mode=tf.contrib.learn.ModeKeys.TRAIN)

  # Create training input function
  train_input_fn = training_utils.create_input_fn(
      pipeline=train_input_pipeline,
      batch_size=FLAGS.batch_size,
      bucket_boundaries=bucket_boundaries,
      scope="train_input_fn")

  # Development data input pipeline
  dev_input_pipeline = input_pipeline.make_input_pipeline_from_def(
      def_dict=FLAGS.input_pipeline_dev,
      mode=tf.contrib.learn.ModeKeys.EVAL,
      shuffle=False,
      num_epochs=1)

  # Create eval input function
  eval_input_fn = training_utils.create_input_fn(
      pipeline=dev_input_pipeline,
      batch_size=FLAGS.batch_size,
      allow_smaller_final_batch=True,
      scope="dev_input_fn")

  def model_fn(features, labels, params, mode):
    """Builds the model graph"""
    model = _create_from_dict({
        "class": train_options.model_class,
        "params": train_options.model_params
    }, models, mode=mode)
    return model(features, labels, params)

  estimator = tf.contrib.learn.Estimator(
      model_fn=model_fn,
      model_dir=output_dir,
      config=config,
      params=FLAGS.model_params)

  # Create hooks
  train_hooks = []
  for dict_ in FLAGS.hooks:
    hook = _create_from_dict(
        dict_, hooks,
        model_dir=estimator.model_dir,
        run_config=config)
    train_hooks.append(hook)

  # Create metrics
  eval_metrics = {}
  for dict_ in FLAGS.metrics:
    metric = _create_from_dict(dict_, metric_specs)
    eval_metrics[metric.name] = metric

  experiment = PatchedExperiment(
      estimator=estimator,
      train_input_fn=train_input_fn,
      eval_input_fn=eval_input_fn,
      min_eval_frequency=FLAGS.eval_every_n_steps,
      train_steps=FLAGS.train_steps,
      eval_steps=None,
      eval_metrics=eval_metrics,
      train_monitors=train_hooks)

  return experiment
def create_estimator_and_specs(output_dir):
  session_config = tf.ConfigProto(
      log_device_placement=True, allow_soft_placement=True)
  session_config.gpu_options.allow_growth = FLAGS.gpu_allow_growth
  session_config.gpu_options.per_process_gpu_memory_fraction = (
      FLAGS.gpu_memory_fraction)

  config = tf.estimator.RunConfig(
      tf_random_seed=FLAGS.tf_random_seed,
      save_checkpoints_secs=FLAGS.save_checkpoints_secs,
      save_checkpoints_steps=FLAGS.save_checkpoints_steps,
      session_config=session_config,
      keep_checkpoint_max=FLAGS.keep_checkpoint_max,
      keep_checkpoint_every_n_hours=FLAGS.keep_checkpoint_every_n_hours)

  train_options = training_utils.TrainOptions(
      model_class=FLAGS.model,
      model_params=FLAGS.model_params)
  # On the main worker, save training options
  if config.is_chief:
    gfile.MakeDirs(output_dir)
    train_options.dump(output_dir)

  bucket_boundaries = None
  if FLAGS.buckets:
    bucket_boundaries = list(map(int, FLAGS.buckets.split(",")))

  # Training data input pipeline
  train_input_pipeline = input_pipeline.make_input_pipeline_from_def(
      def_dict=FLAGS.input_pipeline_train,
      mode=tf.contrib.learn.ModeKeys.TRAIN)

  # Create training input function
  train_input_fn = training_utils.create_input_fn(
      pipeline=train_input_pipeline,
      batch_size=FLAGS.batch_size,
      bucket_boundaries=bucket_boundaries,
      scope="train_input_fn")

  # Development data input pipeline
  dev_input_pipeline = input_pipeline.make_input_pipeline_from_def(
      def_dict=FLAGS.input_pipeline_dev,
      mode=tf.contrib.learn.ModeKeys.EVAL,
      shuffle=False,
      num_epochs=1)

  # Create eval input function
  eval_input_fn = training_utils.create_input_fn(
      pipeline=dev_input_pipeline,
      batch_size=FLAGS.batch_size,
      allow_smaller_final_batch=True,
      scope="dev_input_fn")

  def model_fn(features, labels, params, mode):
    """Builds the model graph"""
    model = _create_from_dict({
        "class": train_options.model_class,
        "params": train_options.model_params
    }, models, mode=mode)
    (predictions, loss, train_op) = model(features, labels, params)

    # Create metrics
    eval_metrics = {}
    for dict_ in FLAGS.metrics:
      metric = _create_from_dict(dict_, metric_specs)
      eval_metrics[metric.name] = metric(features, labels, predictions)

    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=eval_metrics)

  estimator = tf.estimator.Estimator(
      model_fn=model_fn,
      model_dir=output_dir,
      config=config,
      params=FLAGS.model_params)

  # Create hooks
  train_hooks = []
  for dict_ in FLAGS.hooks:
    hook = _create_from_dict(
        dict_, hooks,
        model_dir=estimator.model_dir,
        run_config=config)
    train_hooks.append(hook)

  train_spec = tf.estimator.TrainSpec(
      input_fn=train_input_fn,
      max_steps=FLAGS.train_steps,
      hooks=train_hooks)
  eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn)

  return (estimator, train_spec, eval_spec)
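# Usage sketch, not part of the original code (FLAGS.output_dir is an assumed
# flag): under the tf.estimator API, the returned triple is driven by
# tf.estimator.train_and_evaluate rather than a contrib Experiment.
def main(_argv):
  estimator, train_spec, eval_spec = create_estimator_and_specs(
      output_dir=FLAGS.output_dir)
  tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)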