def _build_train_spec(self, checkpoint_path):
  train_hooks = [
      hooks.LogParametersCountHook()]
  if checkpoint_path is not None:
    train_hooks.append(hooks.LoadWeightsFromCheckpointHook(checkpoint_path))
  if self._hvd is not None:
    # With Horovod, make sure all workers start from the same weights.
    train_hooks.append(self._hvd.BroadcastGlobalVariablesHook(0))

  train_steps = self._config["train"].get("train_steps")
  if train_steps is not None and self._hvd is not None:
    # Each Horovod worker reads its own data shard, so divide the total
    # step count across workers.
    train_steps //= self._hvd.size()

  train_spec = tf.estimator.TrainSpec(
      input_fn=estimator_util.make_input_fn(
          self._model,
          tf.estimator.ModeKeys.TRAIN,
          self._config["train"]["batch_size"],
          features_file=self._config["data"]["train_features_file"],
          labels_file=self._config["data"].get("train_labels_file"),
          batch_type=self._config["train"]["batch_type"],
          batch_multiplier=self._num_devices,
          bucket_width=self._config["train"]["bucket_width"],
          maximum_features_length=self._config["train"].get("maximum_features_length"),
          maximum_labels_length=self._config["train"].get("maximum_labels_length"),
          shuffle_buffer_size=self._config["train"]["sample_buffer_size"],
          single_pass=self._config["train"].get("single_pass", False),
          num_shards=self._hvd.size() if self._hvd is not None else 1,
          shard_index=self._hvd.rank() if self._hvd is not None else 0,
          num_threads=self._config["train"].get("num_threads"),
          prefetch_buffer_size=self._config["train"].get("prefetch_buffer_size"),
          return_dataset=False),
      max_steps=train_steps,
      hooks=train_hooks)
  return train_spec
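# Usage sketch: a ``TrainSpec`` built here is typically paired with the
# ``EvalSpec`` from ``_build_eval_spec`` below and handed to
# ``tf.estimator.train_and_evaluate``. The wiring shown is an assumption for
# illustration; only the spec builders themselves appear in this excerpt.
#
#   estimator = self._make_estimator()
#   train_spec = self._build_train_spec(checkpoint_path=None)
#   eval_spec = self._build_eval_spec()
#   tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)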
def infer(self,
          features_file,
          predictions_file=None,
          checkpoint_path=None,
          log_time=False):
  """Runs inference.

  Args:
    features_file: The file(s) to infer from.
    predictions_file: If set, predictions are saved in this file.
    checkpoint_path: Path of a specific checkpoint to predict. If ``None``,
      the latest is used.
    log_time: If ``True``, several time metrics will be printed in the logs
      at the end of the inference loop.
  """
  if checkpoint_path is not None and tf.gfile.IsDirectory(checkpoint_path):
    checkpoint_path = tf.train.latest_checkpoint(checkpoint_path)
  input_fn = estimator_util.make_input_fn(
      self._model,
      tf.estimator.ModeKeys.PREDICT,
      self._config["infer"]["batch_size"],
      features_file=features_file,
      bucket_width=self._config["infer"]["bucket_width"],
      num_threads=self._config["infer"].get("num_threads"),
      prefetch_buffer_size=self._config["infer"].get("prefetch_buffer_size"),
      return_dataset=False)

  if predictions_file:
    stream = io.open(predictions_file, encoding="utf-8", mode="w")
  else:
    stream = sys.stdout

  infer_hooks = []
  if log_time:
    infer_hooks.append(hooks.LogPredictionTimeHook())

  ordered_writer = None
  write_fn = lambda prediction: (
      self._model.print_prediction(
          prediction, params=self._config["infer"], stream=stream))

  estimator = self._make_estimator()
  for prediction in estimator.predict(
      input_fn=input_fn,
      checkpoint_path=checkpoint_path,
      hooks=infer_hooks):
    # If the index is part of the prediction, predictions may be out of order.
    if "index" in prediction:
      if ordered_writer is None:
        ordered_writer = OrderRestorer(
            index_fn=lambda prediction: prediction["index"],
            callback_fn=write_fn)
      ordered_writer.push(prediction)
    else:
      write_fn(prediction)

  if predictions_file:
    stream.close()
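# Usage sketch (hypothetical: the ``Runner`` constructor arguments and file
# paths below are assumptions, not defined in this excerpt):
#
#   runner = Runner(model, config)
#   # Decode a test file and write ordered predictions to disk, logging
#   # timing metrics at the end of the loop:
#   runner.infer("data/test.txt",
#                predictions_file="predictions.txt",
#                log_time=True)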
def _build_eval_spec(self):
  eval_spec = tf.estimator.EvalSpec(
      input_fn=estimator_util.make_input_fn(
          self._model,
          tf.estimator.ModeKeys.EVAL,
          self._config["eval"]["batch_size"],
          features_file=self._config["data"]["eval_features_file"],
          labels_file=self._config["data"].get("eval_labels_file"),
          num_threads=self._config["eval"].get("num_threads"),
          prefetch_buffer_size=self._config["eval"].get("prefetch_buffer_size")),
      steps=None,  # Evaluate until the evaluation dataset is exhausted.
      exporters=_make_exporters(
          self._config["eval"]["exporters"],
          estimator_util.make_serving_input_fn(self._model),
          assets_extra=self._get_model_assets()),
      # Minimum number of seconds between consecutive evaluations.
      throttle_secs=self._config["eval"]["eval_delay"])
  return eval_spec
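# Configuration sketch: the keys read above map to the user's configuration,
# e.g. (all values below are illustrative assumptions):
#
#   config = {
#       "data": {
#           "eval_features_file": "valid.src",
#           "eval_labels_file": "valid.tgt",
#       },
#       "eval": {
#           "batch_size": 32,
#           "eval_delay": 3600,   # wait at least 1 hour between evaluations
#           "exporters": "last",
#       },
#   }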
def input_fn(self,
             mode,
             batch_size,
             metadata=None,
             features_file=None,
             labels_file=None,
             batch_type="examples",
             batch_multiplier=1,
             bucket_width=None,
             single_pass=False,
             num_threads=None,
             sample_buffer_size=None,
             prefetch_buffer_size=None,
             maximum_features_length=None,
             maximum_labels_length=None,
             num_shards=1,
             shard_index=0):
  """Returns an input function.

  Args:
    mode: A ``tf.estimator.ModeKeys`` mode.
    batch_size: The batch size to use.
    metadata: A dictionary containing additional metadata set by the user.
      Required if ``Model.initialize()`` has not been called.
    features_file: The file containing input features.
    labels_file: The file containing output labels.
    batch_type: The training batching strategy to use: can be "examples" or
      "tokens".
    batch_multiplier: The batch size multiplier to prepare splitting across
      replicated graph parts.
    bucket_width: The width of the length buckets to select batch candidates
      from. ``None`` to not constrain batch formation.
    single_pass: If ``True``, makes a single pass over the training data.
    num_threads: The number of elements processed in parallel.
    sample_buffer_size: The number of elements from which to sample.
    prefetch_buffer_size: The number of batches to prefetch asynchronously.
      If ``None``, use an automatically tuned value on TensorFlow 1.8+ and 1
      on older versions.
    maximum_features_length: The maximum length or list of maximum lengths of
      the features sequence(s). ``None`` to not constrain the length.
    maximum_labels_length: The maximum length of the labels sequence.
      ``None`` to not constrain the length.
    num_shards: The number of data shards (usually the number of workers in a
      distributed setting).
    shard_index: The shard index this input pipeline should read from.

  Returns:
    A callable that returns the next element.

  See Also:
    ``tf.estimator.Estimator``.
  """
  if metadata is not None:
    self.initialize(metadata)
  return estimator.make_input_fn(
      self,
      mode,
      batch_size,
      features_file,
      labels_file=labels_file,
      batch_type=batch_type,
      batch_multiplier=batch_multiplier,
      bucket_width=bucket_width,
      maximum_features_length=maximum_features_length,
      maximum_labels_length=maximum_labels_length,
      shuffle_buffer_size=sample_buffer_size,
      single_pass=single_pass,
      num_shards=num_shards,
      shard_index=shard_index,
      num_threads=num_threads,
      prefetch_buffer_size=prefetch_buffer_size,
      return_dataset=False)
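# Usage sketch (``model``, ``metadata``, and the file paths are assumptions;
# in practice they come from the user's model definition and data
# configuration):
#
#   input_fn = model.input_fn(
#       tf.estimator.ModeKeys.TRAIN,
#       batch_size=3072,
#       metadata=metadata,
#       features_file="train.src",
#       labels_file="train.tgt",
#       batch_type="tokens",       # batch by token count, not example count
#       bucket_width=1,            # group sequences of similar length
#       sample_buffer_size=500000)
#   # Per the docstring, calling the returned function yields the next
#   # element (here a (features, labels) pair since labels_file is set):
#   features, labels = input_fn()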