Example #1
    def test_run_std_server(self, mock_server):
        # Arrange.
        tf_config = {
            'cluster': self._cluster_spec(),
            'task': {
                'type': run_config_lib.TaskType.PS,
                'index': 1
            }
        }
        with test.mock.patch.dict('os.environ',
                                  {'TF_CONFIG': json.dumps(tf_config)}):
            config = RunConfig(
                master='host2:2222',
                num_cores=15,
                gpu_memory_fraction=0.314,
            )
        for est in self._estimators_for_tests(config):
            ex = Experiment(est,
                            train_input_fn='train_input',
                            eval_input_fn='eval_input')

            # Act.
            ex.run_std_server()

            # Assert.
            mock_server.assert_has_calls(
                [test.mock.call().start(),
                 test.mock.call().join()])
Example #2
 def __init__(self, config=None, max_evals=5, eval_dict=None):
     self.eval_count = 0
     self.fit_count = 0
     self._max_evals = max_evals
     self.export_count = 0
     self._config = config or RunConfig()
     self._model_dir = tempfile.mkdtemp()
     self._eval_dict = eval_dict
     tf_logging.info('Create Core Estimator')
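
The fragment above only shows the constructor of a test double. A hedged sketch of how such a fake estimator might count training and evaluation calls follows; the method names and return values are assumptions made for illustration, not the project's real code.

    # Hypothetical continuation of the fake estimator class above.
    def fit(self, *args, **kwargs):
        self.fit_count += 1
        return self

    def evaluate(self, *args, **kwargs):
        self.eval_count += 1
        return self._eval_dict or {'global_step': self.eval_count}

    def export_savedmodel(self, *args, **kwargs):
        self.export_count += 1
        return self._model_dir
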
Example #3
    def __init__(self,
                 model_fn=None,
                 model_dir=None,
                 config=None,
                 params=None):
        # Create a run configuration.
        if config is None:
            self._config = RunConfig()
            logging.info("Using default config.")
        else:
            if not isinstance(config, RunConfig):
                raise ValueError("config must be an instance of RunConfig, "
                                 "received {}.".format(config))
            self._config = config
        logging.info("Using config: {}".format(vars(self._config)))

        if (model_dir is not None) and (self._config.model_dir is not None):
            if model_dir != self._config.model_dir:
                # pylint: disable=g-doc-exception
                raise ValueError(
                    "model_dir is set both in the constructor and in RunConfig, "
                    "but with different values. In constructor: '{}', in RunConfig: "
                    "'{}'.".format(model_dir, self._config.model_dir))

        self._model_dir = model_dir or self._config.model_dir or generate_model_dir()
        if self._config.model_dir is None:
            self._config = self._config.replace(model_dir=self._model_dir)

        if self._config.session_config is None:
            self._session_config = config_pb2.ConfigProto(
                allow_soft_placement=True)
        else:
            self._session_config = self._config.session_config

        # Set device function depending if there are replicas or not.
        self._device_fn = _get_replica_device_setter(self._config)

        self._graph = None

        self._verify_model_fn_args(model_fn, params)

        self._model_fn = model_fn
        self._params = params or {}
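
A minimal, hedged usage sketch for this constructor follows: a toy model_fn and an Estimator built with an explicit model_dir. The function body, the directory path and the spec class are assumptions based on TF 1.x-era APIs, not code from this project.

import tensorflow as tf

def my_model_fn(features, labels, mode, params):
    # Hypothetical model_fn: one dense layer; loss/train_op only when labels exist.
    predictions = tf.layers.dense(features['x'], 1)
    loss, train_op = None, None
    if labels is not None:
        loss = tf.losses.mean_squared_error(labels, predictions)
        train_op = tf.train.GradientDescentOptimizer(params['learning_rate']).minimize(
            loss, global_step=tf.train.get_or_create_global_step())
    # EstimatorSpec is assumed to be the spec class this Estimator expects.
    return EstimatorSpec(mode=mode, predictions=predictions, loss=loss, train_op=train_op)

estimator = Estimator(model_fn=my_model_fn,
                      model_dir='/tmp/toy_model',  # hypothetical path
                      params={'learning_rate': 0.1})
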
Example #4
    def test_train_default_delay(self):
        for task_id in [0, 1, 3]:
            tf_config = {'task': {'index': task_id}}
            with test.mock.patch.dict('os.environ',
                                      {'TF_CONFIG': json.dumps(tf_config)}):
                config = RunConfig()
            for est in self._estimators_for_tests(config):
                ex = Experiment(est,
                                train_input_fn='train_input',
                                eval_input_fn='eval_input')

                sheep = SheepCounter()
                with test.mock.patch.object(time, 'time', sheep.time):
                    with test.mock.patch.object(time, 'sleep', sheep.sleep):
                        ex.train()
                        self.assertAlmostEqual(task_id * 5,
                                               sheep.time(),
                                               delta=1e-4)
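
The test patches time.time and time.sleep with a SheepCounter fake clock so the per-task start delay can be measured without real sleeping. Below is a hedged sketch of what such a fake clock could look like; the class is an illustrative assumption, not the project's actual helper.

class FakeClock(object):
    """Illustrative stand-in for SheepCounter: advances a virtual clock instead of sleeping."""

    def __init__(self):
        self._now = 0.0

    def time(self):
        return self._now

    def sleep(self, seconds):
        # Advance the virtual time rather than blocking the test.
        self._now += seconds
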
Example #5
    def __init__(self, model_fn, model_dir=None, config=None, params=None):
        # Create a run configuration.
        if config is None:
            self._config = RunConfig()
            logging.info("Using default config.")
        else:
            if not isinstance(config, RunConfig):
                raise ValueError("config must be an instance of RunConfig, "
                                 "received {}.".format(config))
            self._config = config

        if (model_dir is not None) and (self._config.model_dir is not None):
            if model_dir != self._config.model_dir:
                # pylint: disable=g-doc-exception
                raise ValueError(
                    "model_dir is set both in the constructor and in RunConfig, "
                    "but with different values. In constructor: '{}', in RunConfig: "
                    "'{}'.".format(model_dir, self._config.model_dir))

        self._model_dir = model_dir or self._config.model_dir or generate_model_dir()
        if self._config.model_dir is None:
            self._config = self._config.replace(model_dir=self._model_dir)
        logging.info("Using config: {}".format(vars(self._config)))

        if self._config.session_config is None:
            self._session_config = config_pb2.ConfigProto(allow_soft_placement=True)
        else:
            self._session_config = self._config.session_config

        # Set device function depending if there are replicas or not.
        self._device_fn = _get_replica_device_setter(self._config)

        self._graph = None

        self._verify_model_fn_args(model_fn, params)

        self._model_fn = model_fn
        self._params = params or {}
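
The constructor above falls back to ConfigProto(allow_soft_placement=True) when the run configuration carries no session_config. A hedged sketch of supplying one explicitly follows; it assumes this RunConfig accepts a session_config argument like TensorFlow's, and the option values and my_model_fn are illustrative.

from tensorflow.core.protobuf import config_pb2

# Illustrative session configuration; the chosen options are assumptions.
session_config = config_pb2.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False)
session_config.gpu_options.per_process_gpu_memory_fraction = 0.5

config = RunConfig(session_config=session_config)
estimator = Estimator(model_fn=my_model_fn, config=config)  # my_model_fn: hypothetical
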
Example #6
class Estimator(object):
    """Estimator class is the basic TensorFlow model trainer/evaluator.

    Constructs an `Estimator` instance.

    Args:
        model_fn: Model function. Follows the signature:
            * Args:
                * `features`: single `Tensor` or `dict` of `Tensor`s
                     (depending on data passed to `fit`),
                * `labels`: `Tensor` or `dict` of `Tensor`s (for multi-head models).
                    If mode is `ModeKeys.PREDICT`, `labels=None` will be passed.
                    If the `model_fn`'s signature does not accept `mode`,
                    the `model_fn` must still be able to handle `labels=None`.
                * `mode`: Specifies if this is training, evaluation or prediction. See `ModeKeys`.
                * `params`: Optional `dict` of hyperparameters. Will receive what
                    is passed to Estimator in the `params` parameter. This allows
                    configuring Estimators from hyperparameter tuning.
                * `config`: Optional configuration object. Will receive what is passed
                    to Estimator in `config` parameter, or the default `config`.
                    Allows updating things in your model_fn based on configuration
                    such as `num_ps_replicas`.
                * `model_dir`: Optional directory where model parameters, graph, etc.
                    are saved. Will receive what is passed to Estimator in
                    `model_dir` parameter, or the default `model_dir`. Allows
                    updating things in your model_fn that expect model_dir, such as
                    training hooks.

            * Returns:
               `EstimatorSpec`

            Supports the following signatures for the function:

                * `(features, labels, mode)`
                * `(features, labels, mode, params)`
                * `(features, labels, mode, params, config)`
                * `(features, labels, mode, params, config, model_dir)`

        model_dir: Directory to save model parameters, graph, etc. This can
            also be used to load checkpoints from the directory into an estimator to
            continue training a previously saved model.
        config: Configuration object.
        params: `dict` of hyperparameters that will be passed into `model_fn`.
                  Keys are names of parameters, values are basic python types.
    Raises:
        ValueError: parameters of `model_fn` don't match `params`.
    """
    def __init__(self,
                 model_fn=None,
                 model_dir=None,
                 config=None,
                 params=None):
        # Create a run configuration.
        if config is None:
            self._config = RunConfig()
            logging.info("Using default config.")
        else:
            if not isinstance(config, RunConfig):
                raise ValueError("config must be an instance of RunConfig, "
                                 "received {}.".format(config))
            self._config = config
        logging.info("Using config: {}".format(vars(self._config)))

        if (model_dir is not None) and (self._config.model_dir is not None):
            if model_dir != self._config.model_dir:
                # pylint: disable=g-doc-exception
                raise ValueError(
                    "model_dir is set both in the constructor and in RunConfig, "
                    "but with different values. In constructor: '{}', in RunConfig: "
                    "'{}'.".format(model_dir, self._config.model_dir))

        self._model_dir = model_dir or self._config.model_dir or generate_model_dir()
        if self._config.model_dir is None:
            self._config = self._config.replace(model_dir=self._model_dir)

        if self._config.session_config is None:
            self._session_config = config_pb2.ConfigProto(
                allow_soft_placement=True)
        else:
            self._session_config = self._config.session_config

        # Set device function depending if there are replicas or not.
        self._device_fn = _get_replica_device_setter(self._config)

        self._graph = None

        self._verify_model_fn_args(model_fn, params)

        self._model_fn = model_fn
        self._params = params or {}

    @property
    def model_dir(self):
        return self._model_dir

    @property
    def config(self):
        return copy.deepcopy(self._config)

    @property
    def params(self):
        return copy.deepcopy(self._params)

    @staticmethod
    def _verify_model_fn_args(model_fn, params):
        """Verifies model fn arguments."""

        MODEL_FN_ARGS = {'features', 'labels', 'mode', 'params', 'config'}

        if model_fn is not None:
            # Check number of arguments of the given function matches requirements.
            model_fn_args = get_arguments(model_fn)
            if 'features' not in model_fn_args:
                raise ValueError(
                    'model_fn `{}` must include features argument.'.format(
                        model_fn))
            if 'labels' not in model_fn_args:
                raise ValueError(
                    'model_fn `{}` must include labels argument.'.format(
                        model_fn))

            if params is not None and 'params' not in model_fn_args:
                raise ValueError(
                    "Estimator's model_fn `{}` does not include params argument, "
                    "but params `{}` are passed.".format(model_fn, params))
            if params is None and 'params' in model_fn_args:
                logging.warning(
                    "Estimator's model_fn (%s) includes params "
                    "argument, but params are not passed to Estimator.",
                    model_fn)
        else:
            raise ValueError("`model_fn` must be provided to Estimator.")

        if 'self' in model_fn_args:
            model_fn_args.remove('self')

        non_valid_args = set(model_fn_args) - MODEL_FN_ARGS
        if non_valid_args:
            raise ValueError(
                "model_fn `{}` has the following unexpected args: {}".format(
                    model_fn, non_valid_args))

    def _call_model_fn(self, features, labels, mode):
        """Calls the model function, passing `mode`, `params` and `config` only when accepted.

        Args:
            features: features dict.
            labels: labels dict.
            mode: ModeKeys

        Returns:
            An `EstimatorSpec` object returned by `model_fn`.

        Raises:
            ValueError: if model_fn returns invalid objects.
        """
        model_fn_args = get_arguments(self._model_fn)
        kwargs = {}
        if 'mode' in model_fn_args:
            kwargs['mode'] = mode
        if 'params' in model_fn_args:
            kwargs['params'] = self.params
        if 'config' in model_fn_args:
            kwargs['config'] = self.config
        model_fn_results = self._model_fn(features=features,
                                          labels=labels,
                                          **kwargs)

        if not isinstance(model_fn_results, model_fn_lib.EstimatorSpec):
            raise ValueError('model_fn should return an EstimatorSpec.')

        return model_fn_results

    def export_savedmodel(self,
                          export_dir_base,
                          serving_input_receiver_fn,
                          assets_extra=None,
                          as_text=False,
                          checkpoint_path=None):
        """Exports inference graph as a SavedModel into given dir.
        This method builds a new graph by first calling the
        serving_input_receiver_fn to obtain feature `Tensor`s, and then calling
        this `Estimator`'s model_fn to generate the model graph based on those
        features. It restores the given checkpoint (or, lacking that, the most
        recent checkpoint) into this graph in a fresh session.  Finally it creates
        a timestamped export directory below the given export_dir_base, and writes
        a `SavedModel` into it containing a single `MetaGraphDef` saved from this
        session.
        The exported `MetaGraphDef` will provide one `SignatureDef` for each
        element of the export_outputs dict returned from the model_fn, named using
        the same keys.  One of these keys is always
        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, indicating which
        signature will be served when a serving request does not specify one.
        For each signature, the outputs are provided by the corresponding
        `ExportOutput`s, and the inputs are always the input receivers provided by
        the serving_input_receiver_fn.
        Extra assets may be written into the SavedModel via the assets_extra
        argument.  This should be a dict, where each key gives a destination path
        (including the filename) relative to the assets.extra directory.  The
        corresponding value gives the full path of the source file to be copied.
        For example, the simple case of copying a single file without renaming it
        is specified as `{'my_asset_file.txt': '/path/to/my_asset_file.txt'}`.
        Args:
          export_dir_base: A string containing a directory in which to create
            timestamped subdirectories containing exported SavedModels.
          serving_input_receiver_fn: A function that takes no argument and
            returns a `ServingInputReceiver`.
          assets_extra: A dict specifying how to populate the assets.extra directory
            within the exported SavedModel, or `None` if no extra assets are needed.
          as_text: whether to write the SavedModel proto in text format.
          checkpoint_path: The checkpoint path to export.  If `None` (the default),
            the most recent checkpoint found within the model directory is chosen.
        Returns:
          The string path to the exported directory.
        Raises:
          ValueError: if no serving_input_receiver_fn is provided, no export_outputs
              are provided, or no checkpoint can be found.
        """
        if serving_input_receiver_fn is None:
            raise ValueError('serving_input_receiver_fn must be defined.')

        with ops.Graph().as_default() as g:
            training.get_or_create_global_step(g)
            random_seed.set_random_seed(self._config.tf_random_seed)
            serving_input_receiver = serving_input_receiver_fn()

            # Call the model_fn and collect the export_outputs.
            estimator_spec = self._call_model_fn(
                features=serving_input_receiver.features,
                labels=None,
                mode=model_fn_lib.ModeKeys.PREDICT)

            # Build the SignatureDefs from receivers and all outputs
            signature_def_map = build_all_signature_defs(
                serving_input_receiver.receiver_tensors,
                estimator_spec.export_outputs)

            if not checkpoint_path:
                # Locate the latest checkpoint
                checkpoint_path = saver.latest_checkpoint(self._model_dir)
            if not checkpoint_path:
                raise ValueError("Couldn't find trained model at %s." %
                                 self._model_dir)

            export_dir = get_timestamped_export_dir(export_dir_base)

            with tf_session.Session() as session:

                saver_for_restore = estimator_spec.scaffold.saver or saver.Saver(
                    sharded=True)
                saver_for_restore.restore(session, checkpoint_path)

                # pylint: disable=protected-access
                local_init_op = (
                    estimator_spec.scaffold.local_init_op
                    or monitored_session.Scaffold._default_local_init_op())
                # pylint: enable=protected-access

                # Perform the export
                builder = saved_model_builder.SavedModelBuilder(export_dir)
                builder.add_meta_graph_and_variables(
                    session, [tag_constants.SERVING],
                    signature_def_map=signature_def_map,
                    assets_collection=ops.get_collection(
                        ops.GraphKeys.ASSET_FILEPATHS),
                    legacy_init_op=local_init_op)
                builder.save(as_text)

            # Add the extra assets
            if assets_extra:
                assets_extra_path = os.path.join(
                    compat.as_bytes(export_dir),
                    compat.as_bytes('assets.extra'))
                for dest_relative, source in assets_extra.items():
                    dest_absolute = os.path.join(
                        compat.as_bytes(assets_extra_path),
                        compat.as_bytes(dest_relative))
                    dest_path = os.path.dirname(dest_absolute)
                    gfile.MakeDirs(dest_path)
                    gfile.Copy(source, dest_absolute)

            return export_dir

    @staticmethod
    def _check_hooks(hooks):
        """Returns hooks if all are SessionRunHook, raises TypeError otherwise."""
        hooks = list(hooks or [])
        for h in hooks:
            if not isinstance(h, SessionRunHook):
                raise TypeError(
                    "Hooks must be a SessionRunHook, given: {}".format(h))

        return hooks

    def train(self, input_fn=None, steps=None, hooks=None, max_steps=None):
        """Trains a model given training data provided by `input_fn`.

        Args:
            input_fn: Input function returning a tuple of:
                features - `Tensor` or dictionary of string feature name to `Tensor`.
                labels - `Tensor` or dictionary of `Tensor` with labels.
            steps: Number of steps for which to train the model. If `None`, train forever.
                `steps` works incrementally: two calls to `fit(steps=10)` result in
                20 steps of training in total. If you do not want incremental behaviour,
                set `max_steps` instead. If set, `max_steps` must be `None`.
            hooks: List of `BaseMonitor` subclass instances.
                Used for callbacks inside the training loop.
            max_steps: Number of total steps for which to train model. If `None`,
                train forever. If set, `steps` must be `None`.

            Two calls to `fit(steps=100)` mean 200 training iterations.
            On the other hand, two calls to `fit(max_steps=100)` mean
            that the second call will not do any iterations since the first call did all 100 steps.

        Returns:
            `self`, for chaining.
        """
        if (steps is not None) and (max_steps is not None):
            raise ValueError('Can not provide both steps and max_steps.')
        if steps is not None and steps <= 0:
            raise ValueError('Must specify steps > 0, given: {}'.format(steps))
        if max_steps is not None and max_steps <= 0:
            raise ValueError(
                'Must specify max_steps > 0, given: {}'.format(max_steps))

        if max_steps is not None:
            try:
                start_step = load_variable(self._model_dir,
                                           ops.GraphKeys.GLOBAL_STEP)
                if max_steps <= start_step:
                    logging.info(
                        'Skipping training since max_steps has already been reached.')
                    return self
            except:  # pylint: disable=bare-except
                pass

        hooks = self._check_hooks(hooks)
        if steps is not None or max_steps is not None:
            hooks.append(plx_hooks.StopAtStepHook(steps, max_steps))

        loss = self._train_model(input_fn=input_fn, hooks=hooks)
        logging.info('Loss for final step: %s.', loss)
        return self

    def evaluate(self,
                 input_fn=None,
                 steps=None,
                 hooks=None,
                 checkpoint_path=None,
                 name=None):
        """Evaluates given model with provided evaluation data.

        Stop conditions - we evaluate on the given input data until one of the
        following:
        - If `steps` is provided, and `steps` batches of size `batch_size` are
        processed.
        - If `input_fn` is provided, and it raises an end-of-input
        exception (`OutOfRangeError` or `StopIteration`).
        - If all items from the provided input have been processed.

        The return value is a dict containing the metrics specified in `metrics`, as
        well as an entry `global_step` which contains the value of the global step
        for which this evaluation was performed.

        Args:
            input_fn: Input function returning a tuple of:
                features - Dictionary of string feature name to `Tensor` or `Tensor`.
                labels - `Tensor` or dictionary of `Tensor` with labels.
                If `steps` is not provided, this should raise `OutOfRangeError` or
                `StopIteration` after the desired amount of data (e.g., one epoch) has
                been provided. See "Stop conditions" above for specifics.
            steps: Number of steps for which to evaluate the model. If `None`, evaluate
                until `input_fn` raises an end-of-input exception.
                See "Stop conditions" above for specifics.
            name: Name of the evaluation if user needs to run multiple evaluations on
                different data sets, such as on training data vs test data.
            checkpoint_path: Path of a specific checkpoint to evaluate. If `None`,
                the latest checkpoint in `model_dir` is used.
            hooks: List of `SessionRunHook` subclass instances.
                Used for callbacks inside the evaluation call.

        Raises:
            ValueError: If `metrics` is not `None` or `dict`.

        Returns:
            Returns `dict` with evaluation results.
        """
        hooks = self._check_hooks(hooks)
        if steps is not None:
            if steps <= 0:
                raise ValueError(
                    'Must specify steps > 0, given: {}'.format(steps))
            hooks.append(evaluation._StopAfterNEvalsHook(num_evals=steps))
        return self._evaluate_model(input_fn=input_fn,
                                    name=name,
                                    checkpoint_path=checkpoint_path,
                                    hooks=hooks)

    def predict(self,
                input_fn=None,
                predict_keys=None,
                hooks=None,
                checkpoint_path=None):
        """Returns predictions for given features.

        Args:
            input_fn: Input function returning features which is a dictionary of
                string feature name to `Tensor` or `SparseTensor`. If it returns a
                tuple, first item is extracted as features. Prediction continues until
                `input_fn` raises an end-of-input exception (`OutOfRangeError` or `StopIteration`).
            predict_keys: list of `str`, name of the keys to predict. It is used if
                the `EstimatorSpec.predictions` is a `dict`. If `predict_keys` is used then rest
                of the predictions will be filtered from the dictionary. If `None`, returns all.
            hooks: List of `SessionRunHook` subclass instances. Used for callbacks
                inside the prediction call.
            checkpoint_path: Path of a specific checkpoint to predict. If `None`, the
                latest checkpoint in `model_dir` is used.

        Yields:
            Evaluated values of `predictions` tensors.

        Raises:
            ValueError: Could not find a trained model in model_dir.
            ValueError: if batch lengths of predictions are not the same.
            ValueError: If there is a conflict between `predict_keys` and `predictions`.
                For example if `predict_keys` is not `None`
                but `EstimatorSpec.predictions` is not a `dict`.
        """
        hooks = self._check_hooks(hooks)
        # Check that model has been trained.
        if not checkpoint_path:
            checkpoint_path = saver.latest_checkpoint(self._model_dir)
        if not checkpoint_path:
            raise ValueError("Could not find trained model at %s." %
                             self._model_dir)

        with ops.Graph().as_default() as g:
            random_seed.set_random_seed(self._config.tf_random_seed)
            training.get_or_create_global_step(g)
            features = self._get_features_from_input_fn(input_fn)
            estimator_spec = self._call_model_fn(features, None,
                                                 ModeKeys.PREDICT)
            predictions = self._extract_keys(estimator_spec.predictions,
                                             predict_keys)
            with monitored_session.MonitoredSession(
                    session_creator=monitored_session.ChiefSessionCreator(
                        checkpoint_filename_with_path=checkpoint_path,
                        scaffold=estimator_spec.scaffold,
                        config=self._session_config),
                    hooks=hooks) as mon_sess:
                while not mon_sess.should_stop():
                    preds_evaluated = mon_sess.run(predictions)
                    if not isinstance(predictions, dict):
                        for pred in preds_evaluated:
                            yield pred
                    else:
                        for i in range(extract_batch_length(preds_evaluated)):
                            yield {
                                key: value[i]
                                for key, value in six.iteritems(
                                    preds_evaluated)
                            }

    def get_variable_value(self, name):
        """Returns value of the variable given by name.

        Args:
            name: string, name of the tensor.

        Returns:
            Numpy array - value of the tensor.
        """
        return load_variable(self.model_dir, name)

    def get_variable_names(self):
        """Returns list of all variable names in this model.

        Returns:
            List of names.
        """
        return [name for name, _ in list_variables(self.model_dir)]

    @staticmethod
    def _extract_metric_update_ops(eval_dict):
        """Separate update operations from metric value operations."""
        update_ops = []
        value_ops = {}
        # Sort metrics lexicographically so graph is identical every time.
        for name, metric_ops in sorted(six.iteritems(eval_dict)):
            value_ops[name] = metric_ops[0]
            update_ops.append(metric_ops[1])

        if update_ops:
            update_op = control_flow_ops.group(*update_ops)
        else:
            update_op = None

        return update_op, value_ops

    @staticmethod
    def _get_features_from_input_fn(input_fn):
        result = input_fn()
        if not ops.get_default_graph().get_collection(
                ops.GraphKeys.QUEUE_RUNNERS):
            logging.warning('Input graph does not contain a QueueRunner. '
                            'That means predict yields forever. '
                            'This is probably a mistake.')
        if isinstance(result, (list, tuple)):
            return result[0]
        return result

    @staticmethod
    def _extract_keys(predictions, predict_keys):
        if not predict_keys:
            return predictions
        if not isinstance(predictions, dict):
            raise ValueError(
                "predict_keys argument is not valid in case of non-dict predictions."
            )
        existing_keys = predictions.keys()
        predictions = {
            key: value
            for key, value in six.iteritems(predictions) if key in predict_keys
        }
        if not predictions:
            raise ValueError("Expected to run at least one output from {}, "
                             "provided {}.".format(existing_keys,
                                                   predict_keys))
        return predictions

    def _train_model(self, input_fn, hooks):
        all_hooks = []
        self._graph = ops.Graph()
        with self._graph.as_default() as g, g.device(self._device_fn):
            random_seed.set_random_seed(self._config.tf_random_seed)
            global_step = training.get_or_create_global_step(g)
            features, labels = input_fn()
            estimator_spec = self._call_model_fn(features, labels,
                                                 ModeKeys.TRAIN)
            all_hooks.extend([
                plx_hooks.NanTensorHook(estimator_spec.loss),
                plx_hooks.LoggingTensorHook(
                    {
                        'loss': estimator_spec.loss,
                        'step': global_step
                    },
                    every_n_iter=100)
            ])
            all_hooks.extend(hooks)
            all_hooks.extend(estimator_spec.training_hooks)

            scaffold = estimator_spec.scaffold or monitored_session.Scaffold()
            if not (scaffold.saver
                    or ops.get_collection(ops.GraphKeys.SAVERS)):
                ops.add_to_collection(
                    ops.GraphKeys.SAVERS,  # TODO remove non restorable vars
                    saver.Saver(
                        sharded=True,  # TODO `var_list`
                        max_to_keep=self._config.keep_checkpoint_max,
                        defer_build=True))

            chief_hooks = []
            if self._config.save_checkpoints_secs or self._config.save_checkpoints_steps:
                saver_hook_exists = any([
                    isinstance(h, plx_hooks.CheckpointSaverHook)
                    for h in (all_hooks + estimator_spec.training_hooks +
                              chief_hooks +
                              estimator_spec.training_chief_hooks)
                ])
                if not saver_hook_exists:
                    chief_hooks = [
                        plx_hooks.CheckpointSaverHook(
                            self._model_dir,
                            save_secs=self._config.save_checkpoints_secs,
                            save_steps=self._config.save_checkpoints_steps,
                            scaffold=scaffold)
                    ]
            with monitored_session.MonitoredTrainingSession(
                    master=self._config.master,
                    is_chief=self._config.is_chief,
                    checkpoint_dir=self._model_dir,
                    scaffold=scaffold,
                    hooks=all_hooks + estimator_spec.training_hooks,
                    chief_only_hooks=chief_hooks +
                    estimator_spec.training_chief_hooks,
                    save_checkpoint_secs=0,  # Saving is handled by a hook.
                    save_summaries_steps=self._config.save_summary_steps,
                    config=self._session_config) as mon_sess:
                loss = None
                while not mon_sess.should_stop():
                    _, loss = mon_sess.run(
                        [estimator_spec.train_op, estimator_spec.loss])
            summary_io.SummaryWriterCache.clear()
            return loss

    def _evaluate_model(self,
                        input_fn,
                        hooks=None,
                        checkpoint_path=None,
                        name=''):
        # Check that model has been trained (if nothing has been set explicitly).
        if not checkpoint_path:
            latest_path = saver.latest_checkpoint(self._model_dir)
            if not latest_path:
                raise ValueError("Could not find trained model at {}.".format(
                    self._model_dir))
            checkpoint_path = latest_path

        # Setup output directory.
        eval_dir = os.path.join(self._model_dir,
                                'eval' if not name else 'eval_' + name)

        with ops.Graph().as_default() as g:
            random_seed.set_random_seed(self._config.tf_random_seed)
            global_step = training.create_global_step(g)
            features, labels = input_fn()

            estimator_spec = self._call_model_fn(features, labels,
                                                 ModeKeys.EVAL)
            if model_fn_lib.MetricKeys.LOSS in estimator_spec.eval_metric_ops:
                raise ValueError(
                    "Metric with name `{}` is not allowed, because Estimator "
                    "already defines a default metric "
                    "with the same name.".format(model_fn_lib.MetricKeys.LOSS))
            estimator_spec.eval_metric_ops[
                model_fn_lib.MetricKeys.LOSS] = metrics_lib.streaming_mean(
                    estimator_spec.loss)
            update_op, eval_dict = self._extract_metric_update_ops(
                estimator_spec.eval_metric_ops)

            if ops.GraphKeys.GLOBAL_STEP in eval_dict:
                raise ValueError(
                    "Metric with name `global_step` is not allowed, because "
                    "Estimator already defines a default metric with the same name."
                )
            eval_dict[ops.GraphKeys.GLOBAL_STEP] = global_step

            eval_results = evaluation._evaluate_once(
                checkpoint_path=checkpoint_path,
                master=self._config.evaluation_master,
                scaffold=estimator_spec.scaffold,
                eval_ops=update_op,
                final_ops=eval_dict,
                hooks=hooks,
                config=self._session_config)

            self._write_dict_to_summary(
                output_dir=eval_dir,
                dictionary=eval_results,
                current_global_step=eval_results[ops.GraphKeys.GLOBAL_STEP])

            return eval_results

    @staticmethod
    def _write_dict_to_summary(output_dir, dictionary, current_global_step):
        """Writes a `dict` into a summary file in the given output directory.

        Args:
            output_dir: `str`, directory to write the summary file in.
            dictionary: the `dict` to be written to summary file.
            current_global_step: `int`, the current global step.
        """
        logging.info('Saving dict for global step %d: %s', current_global_step,
                     dict_to_str(dictionary))
        summary_writer = summary_io.SummaryWriterCache.get(output_dir)
        summary_proto = summary_pb2.Summary()
        for key in dictionary:
            if dictionary[key] is None or key == ops.GraphKeys.GLOBAL_STEP:
                continue
            value = summary_proto.value.add()
            value.tag = key
            if isinstance(dictionary[key], (np.float32, float)):
                value.simple_value = float(dictionary[key])
            elif isinstance(dictionary[key], (int, np.int64, np.int32)):
                value.simple_value = int(dictionary[key])
            else:
                logging.warn(
                    'Skipping summary for %s, must be a '
                    'float, np.float32, int, int32, or int64.', key)
        summary_writer.add_summary(summary_proto, current_global_step)
        summary_writer.flush()
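
Driving the class above is typically a train/evaluate/predict/export loop. The sketch below is hedged: my_model_fn, train_input_fn, eval_input_fn, serving_receiver_fn and the directory paths are hypothetical placeholders, not names from this project.

# Hedged usage sketch; every name referenced here is a hypothetical placeholder.
estimator = Estimator(model_fn=my_model_fn, model_dir='/tmp/example_model')

estimator.train(input_fn=train_input_fn, max_steps=1000)
metrics = estimator.evaluate(input_fn=eval_input_fn, steps=100)
print('eval loss:', metrics.get('loss'))

for prediction in estimator.predict(input_fn=eval_input_fn):
    pass  # consume the generator; each item is one prediction

export_dir = estimator.export_savedmodel('/tmp/exports', serving_receiver_fn)
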
Example #7
class Estimator(object):
    """Estimator class is a model trainer/evaluator.

    Constructs an `Estimator` instance.

    Args:
        model_fn: Model function. Follows the signature:
            * Args:
                * `features`: single `Tensor` or `dict` of `Tensor`s
                     (depending on data passed to `fit`),
                * `labels`: `Tensor` or `dict` of `Tensor`s (for multi-head models).
                    If mode is `Modes.PREDICT`, `labels=None` will be passed.
                    If the `model_fn`'s signature does not accept `mode`,
                    the `model_fn` must still be able to handle `labels=None`.
                * `mode`: Specifies if this is training, evaluation or prediction. See `Modes`.
                * `params`: Optional `dict` of hyperparameters. Will receive what
                    is passed to Estimator in the `params` parameter. This allows
                    configuring Estimators from hyperparameter tuning.
                * `config`: Optional configuration object. Will receive what is passed
                    to Estimator in `config` parameter, or the default `config`.
                    Allows updating things in your model_fn based on configuration
                    such as `num_ps_replicas`.
                * `model_dir`: Optional directory where model parameters, graph, etc.
                    are saved. Will receive what is passed to Estimator in
                    `model_dir` parameter, or the default `model_dir`. Allows
                    updating things in your model_fn that expect model_dir, such as
                    training hooks.

            * Returns:
               `EstimatorSpec`

            Supports the following signatures for the function:

                * `(features, labels, mode)`
                * `(features, labels, mode, params)`
                * `(features, labels, mode, params, config)`
                * `(features, labels, mode, params, config, model_dir)`

        model_dir: Directory to save model parameters, graph, etc. This can
            also be used to load checkpoints from the directory into an estimator to
            continue training a previously saved model.
        config: Configuration object.
        params: `dict` of hyperparameters that will be passed into `model_fn`.
                  Keys are names of parameters, values are basic python types.
    Raises:
        ValueError: parameters of `model_fn` don't match `params`.
    """
    def __init__(self, model_fn, model_dir=None, config=None, params=None):
        # Create a run configuration.
        if config is None:
            self._config = RunConfig()
            logging.info("Using default config.")
        else:
            if not isinstance(config, RunConfig):
                raise ValueError("config must be an instance of RunConfig, "
                                 "received {}.".format(config))
            self._config = config

        if (model_dir is not None) and (self._config.model_dir is not None):
            if model_dir != self._config.model_dir:
                # pylint: disable=g-doc-exception
                raise ValueError(
                    "model_dir is set both in the constructor and in RunConfig, "
                    "but with different values. In constructor: '{}', in RunConfig: "
                    "'{}'.".format(model_dir, self._config.model_dir))

        self._model_dir = model_dir or self._config.model_dir or generate_model_dir()
        if self._config.model_dir is None:
            self._config = self._config.replace(model_dir=self._model_dir)
        logging.info("Using config: {}".format(vars(self._config)))

        if self._config.session_config is None:
            self._session_config = config_pb2.ConfigProto(allow_soft_placement=True)
        else:
            self._session_config = self._config.session_config

        # Set device function depending if there are replicas or not.
        self._device_fn = _get_replica_device_setter(self._config)

        self._graph = None

        self._verify_model_fn_args(model_fn, params)

        self._model_fn = model_fn
        self._params = params or {}

    @property
    def model_dir(self):
        return self._model_dir

    @property
    def config(self):
        return copy.deepcopy(self._config)

    @property
    def params(self):
        return copy.deepcopy(self._params)

    @staticmethod
    def _verify_model_fn_args(model_fn, params):
        """Verifies model fn arguments."""

        MODEL_FN_ARGS = {'features', 'labels', 'mode', 'params', 'config'}

        if model_fn is not None:
            # Check number of arguments of the given function matches requirements.
            model_fn_args = get_arguments(model_fn)
            if 'features' not in model_fn_args:
                raise ValueError('model_fn `{}` must include features argument.'.format(model_fn))
            if 'labels' not in model_fn_args:
                raise ValueError('model_fn `{}` must include labels argument.'.format(model_fn))

            if params is not None and 'params' not in model_fn_args:
                raise ValueError("Estimator's model_fn `{}` does not include params argument, "
                                 "but params `{}` are passed.".format(model_fn, params))
            if params is None and 'params' in model_fn_args:
                logging.warning("Estimator's model_fn (%s) includes params "
                                "argument, but params are not passed to Estimator.", model_fn)
        else:
            raise ValueError("`model_fn` must be provided to Estimator.")

        if 'self' in model_fn_args:
            model_fn_args.remove('self')

        non_valid_args = set(model_fn_args) - MODEL_FN_ARGS
        if non_valid_args:
            raise ValueError("model_fn `{}` has the following unexpected args: {}".format(
                model_fn, non_valid_args))

    def _call_model_fn(self, features, labels, mode):
        """Calls the model function, passing `mode`, `params` and `config` only when accepted.

        Args:
            features: features dict.
            labels: labels dict.
            mode: Modes

        Returns:
            An `EstimatorSpec` object returned by `model_fn`.

        Raises:
            ValueError: if model_fn returns invalid objects.
        """
        model_fn_args = get_arguments(self._model_fn)
        kwargs = {}
        if 'mode' in model_fn_args:
            kwargs['mode'] = mode
        if 'params' in model_fn_args:
            kwargs['params'] = self.params
        if 'config' in model_fn_args:
            kwargs['config'] = self.config
        model_fn_results = self._model_fn(features=features, labels=labels, **kwargs)

        if not isinstance(model_fn_results, EstimatorSpec):
            raise ValueError('model_fn should return an EstimatorSpec.')

        return model_fn_results

    def export_savedmodel(self, export_dir_base, serving_input_receiver_fn, assets_extra=None,
                          as_text=False, checkpoint_path=None):
        """Exports inference graph as a SavedModel into given dir.
        This method builds a new graph by first calling the serving_input_receiver_fn to
        obtain feature `Tensor`s, and then calling this `Estimator`'s model_fn
        to generate the model graph based on those features. It restores the given checkpoint
        (or, lacking that, the most recent checkpoint) into this graph in a fresh session.
        Finally it creates a timestamped export directory below the given export_dir_base,
        and writes a `SavedModel` into it containing a single `MetaGraphDef` saved from this
        session.
        The exported `MetaGraphDef` will provide one `SignatureDef` for each element of the
        export_outputs dict returned from the model_fn, named using the same keys.
        One of these keys is always signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY,
        indicating which signature will be served when a serving request does not specify one.
        For each signature, the outputs are provided by the corresponding `ExportOutput`s,
        and the inputs are always the input receivers provided by the serving_input_receiver_fn.
        Extra assets may be written into the SavedModel via the assets_extra argument.
        This should be a dict, where each key gives a destination path (including the filename)
        relative to the assets.extra directory.  The corresponding value gives the full path of
        the source file to be copied. For example, the simple case of copying a single file without
        renaming it is specified as `{'my_asset_file.txt': '/path/to/my_asset_file.txt'}`.

        Args:
            export_dir_base: A string containing a directory in which to create
                timestamped subdirectories containing exported SavedModels.
            serving_input_receiver_fn: A function that takes no argument and
                returns a `ServingInputReceiver`.
            assets_extra: A dict specifying how to populate the assets.extra directory
                within the exported SavedModel, or `None` if no extra assets are needed.
            as_text: whether to write the SavedModel proto in text format.
            checkpoint_path: The checkpoint path to export.  If `None` (the default),
                the most recent checkpoint found within the model directory is chosen.
        Returns:
            The string path to the exported directory.
        Raises:
            ValueError: if no serving_input_receiver_fn is provided, no export_outputs
                are provided, or no checkpoint can be found.
        """
        if serving_input_receiver_fn is None:
            raise ValueError('serving_input_receiver_fn must be defined.')

        with ops.Graph().as_default() as g:
            training.get_or_create_global_step(g)
            random_seed.set_random_seed(self._config.tf_random_seed)
            serving_input_receiver = serving_input_receiver_fn()

            # Call the model_fn and collect the export_outputs.
            estimator_spec = self._call_model_fn(
                features=serving_input_receiver.features,
                labels=None,
                mode=Modes.PREDICT)

            # Build the SignatureDefs from receivers and all outputs
            signature_def_map = build_all_signature_defs(
                serving_input_receiver.receiver_tensors,
                estimator_spec.export_outputs)

            if not checkpoint_path:
                # Locate the latest checkpoint
                checkpoint_path = saver.latest_checkpoint(self._model_dir)
            if not checkpoint_path:
                raise ValueError("Couldn't find trained model at %s." % self._model_dir)

            export_dir = get_timestamped_export_dir(export_dir_base)

            with tf_session.Session() as session:
                saver_for_restore = estimator_spec.scaffold.saver or saver.Saver(sharded=True)
                saver_for_restore.restore(session, checkpoint_path)
                local_init_op = (estimator_spec.scaffold.local_init_op or
                                 monitored_session.Scaffold._default_local_init_op())
                # Perform the export
                builder = saved_model_builder.SavedModelBuilder(export_dir)
                builder.add_meta_graph_and_variables(
                    session, [tag_constants.SERVING],
                    signature_def_map=signature_def_map,
                    assets_collection=ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS),
                    legacy_init_op=local_init_op)
                builder.save(as_text)

            # Add the extra assets
            if assets_extra:
                assets_extra_path = os.path.join(compat.as_bytes(export_dir),
                                                 compat.as_bytes('assets.extra'))
                for dest_relative, source in assets_extra.items():
                    dest_absolute = os.path.join(compat.as_bytes(assets_extra_path),
                                                 compat.as_bytes(dest_relative))
                    dest_path = os.path.dirname(dest_absolute)
                    gfile.MakeDirs(dest_path)
                    gfile.Copy(source, dest_absolute)

            return export_dir

    @staticmethod
    def _check_hooks(hooks):
        """Returns hooks if all are SessionRunHook, raises TypeError otherwise."""
        hooks = list(hooks or [])
        for h in hooks:
            if not isinstance(h, SessionRunHook):
                raise TypeError("Hooks must be a SessionRunHook, given: {}".format(h))

        return hooks

    def _prepare_train(self, steps=None, hooks=None, max_steps=None):
        """Checks train specifications (steps and hooks) and returns the train hooks.

        Args:
            steps: Number of steps for which to train the model. If `None`, train forever.
                `steps` works incrementally: two calls to `fit(steps=10)` result in
                20 steps of training in total. If you do not want incremental behaviour,
                set `max_steps` instead. If set, `max_steps` must be `None`.
            hooks: List of `BaseMonitor` subclass instances.
                Used for callbacks inside the training loop.
            max_steps: Number of total steps for which to train model. If `None`,
                train forever. If set, `steps` must be `None`.

        Returns:
            `list`: An updated version of hooks.
        """
        if (steps is not None) and (max_steps is not None):
            raise ValueError("Can not provide both steps and max_steps.")
        if steps is not None and steps <= 0:
            raise ValueError("Must specify steps > 0, given: {}".format(steps))
        if max_steps is not None and max_steps <= 0:
            raise ValueError("Must specify max_steps > 0, given: {}".format(max_steps))

        hooks = self._check_hooks(hooks)
        if steps is not None or max_steps is not None:
            hooks.append(plx_hooks.StopAtStepHook(steps, max_steps))

        return hooks

    def train(self, input_fn=None, steps=None, hooks=None, max_steps=None):
        """Trains a model given training data provided by `input_fn`.

        Args:
            input_fn: Input function returning a tuple of:
                features - `Tensor` or dictionary of string feature name to `Tensor`.
                labels - `Tensor` or dictionary of `Tensor` with labels.
            steps: Number of steps for which to train the model. If `None`, train forever.
                `steps` works incrementally: two calls to `fit(steps=10)` result in
                20 steps of training in total. If you do not want incremental behaviour,
                set `max_steps` instead. If set, `max_steps` must be `None`.
            hooks: List of `BaseMonitor` subclass instances.
                Used for callbacks inside the training loop.
            max_steps: Number of total steps for which to train model. If `None`,
                train forever. If set, `steps` must be `None`.

            Two calls to `fit(steps=100)` mean 200 training iterations.
            On the other hand, two calls to `fit(max_steps=100)` mean
            that the second call will not do any iterations since the first call did all 100 steps.

        Returns:
            `self`, for chaining.
        """
        if max_steps is not None:
            try:
                start_step = load_variable(self._model_dir, ops.GraphKeys.GLOBAL_STEP)
                if max_steps <= start_step:
                logging.info('Skipping training since max_steps has already been reached.')
                    return self
            except:  # pylint: disable=bare-except
                pass

        hooks = self._prepare_train(steps, hooks, max_steps)
        loss = self._train_model(input_fn=input_fn, hooks=hooks)
        logging.info('Loss for final step: %s.', loss)
        return self

    def evaluate(self, input_fn=None, steps=None, hooks=None, checkpoint_path=None, name=None):
        """Evaluates given model with provided evaluation data.

        Stop conditions - we evaluate on the given input data until one of the
        following:
        - If `steps` is provided, and `steps` batches of size `batch_size` are processed.
        - If `input_fn` is provided, and it raises an end-of-input
        exception (`OutOfRangeError` or `StopIteration`).
        - If all items from the provided input have been processed.

        Args:
            input_fn: Input function returning a tuple of:
                features - Dictionary of string feature name to `Tensor` or `Tensor`.
                labels - `Tensor` or dictionary of `Tensor` with labels.
                If `steps` is not provided, this should raise `OutOfRangeError` or
                `StopIteration` after the desired amount of data (e.g., one epoch) has
                been provided. See "Stop conditions" above for specifics.
            steps: Number of steps for which to evaluate the model. If `None`, evaluate
                until `input_fn` raises an end-of-input exception.
                See "Stop conditions" above for specifics.
            checkpoint_path: Path of a specific checkpoint to evaluate. If `None`,
                the latest checkpoint in `model_dir` is used.
            hooks: List of `SessionRunHook` subclass instances.
                Used for callbacks inside the evaluation call.
            name: Name of the evaluation if user needs to run multiple evaluations on
                different data sets, such as on training data vs test data.

        Raises:
            ValueError: If `metrics` is not `None` or `dict`.

        Returns:
            Returns `dict` with evaluation results; the metrics specified in `metrics`, as
            well as an entry `global_step` which contains the value of the global step
            for which this evaluation was performed.
        """
        hooks = self._check_hooks(hooks)
        if steps is not None:
            if steps <= 0:
                raise ValueError('Must specify steps > 0, given: {}'.format(steps))
            hooks.append(plx_hooks.StopAfterNEvalsHook(num_evals=steps))
        return self._evaluate_model(
            input_fn=input_fn, name=name, checkpoint_path=checkpoint_path, hooks=hooks)

    def predict(self, input_fn=None, predict_keys=None, hooks=None, checkpoint_path=None):
        """Returns predictions for given features with `PREDICT` mode.

        Args:
            input_fn: Input function returning features which is a dictionary of
                string feature name to `Tensor` or `SparseTensor`. If it returns a
                tuple, first item is extracted as features. Prediction continues until
                `input_fn` raises an end-of-input exception (`OutOfRangeError` or `StopIteration`).
            predict_keys: list of `str`, name of the keys to predict. It is used if
                the `EstimatorSpec.predictions` is a `dict`. If `predict_keys` is used then rest
                of the predictions will be filtered from the dictionary. If `None`, returns all.
            hooks: List of `SessionRunHook` subclass instances. Used for callbacks
                inside the prediction call.
            checkpoint_path: Path of a specific checkpoint to predict. If `None`, the
                latest checkpoint in `model_dir` is used.

        Yields:
            Evaluated values of `predictions` tensors.

        Raises:
            ValueError: Could not find a trained model in model_dir.
            ValueError: if batch lengths of predictions are not the same.
            ValueError: If there is a conflict between `predict_keys` and `predictions`.
                For example if `predict_keys` is not `None`
                but `EstimatorSpec.predictions` is not a `dict`.
        """
        return self._infer_model(Modes.PREDICT, input_fn=input_fn, predict_keys=predict_keys,
                                 hooks=hooks, checkpoint_path=checkpoint_path)
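
    # A minimal usage sketch of `predict` (hypothetical `estimator`, `predict_input_fn`, and
    # 'results' prediction key); the call returns a generator, so predictions are produced
    # lazily as the loop consumes them:
    #
    #     for prediction in estimator.predict(input_fn=predict_input_fn,
    #                                         predict_keys=['results']):
    #         print(prediction['results'])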

    def generate(self, input_fn=None, predict_keys=None, hooks=None, checkpoint_path=None):
        """Returns predictions for given features with `GENERATE` mode.

        Args:
            input_fn: Input function returning features which is a dictionary of
                string feature name to `Tensor` or `SparseTensor`. If it returns a
                tuple, first item is extracted as features. Prediction continues until
                `input_fn` raises an end-of-input exception (`OutOfRangeError` or `StopIteration`).
            predict_keys: list of `str`, names of the keys to predict. It is used if
                `EstimatorSpec.predictions` is a `dict`. If `predict_keys` is used, then the
                rest of the predictions will be filtered from the dictionary. If `None`, returns all.
            hooks: List of `SessionRunHook` subclass instances. Used for callbacks
                inside the prediction call.
            checkpoint_path: Path of a specific checkpoint to predict. If `None`, the
                latest checkpoint in `model_dir` is used.

        Yields:
            Evaluated values of `predictions` tensors.

        Raises:
            ValueError: Could not find a trained model in `model_dir`.
            ValueError: If the batch lengths of the predictions are not the same.
            ValueError: If there is a conflict between `predict_keys` and `predictions`.
                For example if `predict_keys` is not `None`
                but `EstimatorSpec.predictions` is not a `dict`.
        """
        return self._infer_model(Modes.GENERATE, input_fn=input_fn, predict_keys=predict_keys,
                                 hooks=hooks, checkpoint_path=checkpoint_path)

    def encode(self, input_fn=None, predict_keys=None, hooks=None, checkpoint_path=None):
        """Returns predictions for given features with `ENCODE` mode.

        Args:
            input_fn: Input function returning features which is a dictionary of
                string feature name to `Tensor` or `SparseTensor`. If it returns a
                tuple, first item is extracted as features. Prediction continues until
                `input_fn` raises an end-of-input exception (`OutOfRangeError` or `StopIteration`).
            predict_keys: list of `str`, names of the keys to predict. It is used if
                `EstimatorSpec.predictions` is a `dict`. If `predict_keys` is used, then the
                rest of the predictions will be filtered from the dictionary. If `None`, returns all.
            hooks: List of `SessionRunHook` subclass instances. Used for callbacks
                inside the prediction call.
            checkpoint_path: Path of a specific checkpoint to predict. If `None`, the
                latest checkpoint in `model_dir` is used.

        Yields:
            Evaluated values of `predictions` tensors.

        Raises:
            ValueError: Could not find a trained model in `model_dir`.
            ValueError: If the batch lengths of the predictions are not the same.
            ValueError: If there is a conflict between `predict_keys` and `predictions`.
                For example if `predict_keys` is not `None`
                but `EstimatorSpec.predictions` is not a `dict`.
        """
        return self._infer_model(Modes.ENCODE, input_fn=input_fn, predict_keys=predict_keys,
                                 hooks=hooks, checkpoint_path=checkpoint_path)

    def get_variable_value(self, name):
        """Returns value of the variable given by name.

        Args:
            name: string, name of the variable.

        Returns:
            Numpy array - value of the tensor.
        """
        return load_variable(self.model_dir, name)

    def get_variable_names(self):
        """Returns list of all variable names in this model.

        Returns:
            List of names.
        """
        return [name for name, _ in list_variables(self.model_dir)]
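
    # A hedged example of inspecting a trained model with the two helpers above
    # (hypothetical `estimator` instance):
    #
    #     for var_name in estimator.get_variable_names():
    #         print(var_name, estimator.get_variable_value(var_name).shape)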

    @staticmethod
    def _extract_metric_update_ops(eval_dict):
        """Separate update operations from metric value operations."""
        update_ops = []
        value_ops = {}
        # Sort metrics lexicographically so graph is identical every time.
        for name, metric_ops in sorted(six.iteritems(eval_dict)):
            value_ops[name] = metric_ops[0]
            update_ops.append(metric_ops[1])

        if update_ops:
            update_op = control_flow_ops.group(*update_ops)
        else:
            update_op = None

        return update_op, value_ops
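
    # `_extract_metric_update_ops` assumes the usual TF metric convention where each entry
    # of `eval_dict` is a `(value_op, update_op)` pair, e.g. as returned by `tf.metrics.*`
    # (the `labels`/`predictions` tensors below are hypothetical):
    #
    #     eval_dict = {
    #         'accuracy': tf.metrics.accuracy(labels, predictions),        # (value, update)
    #         'mae': tf.metrics.mean_absolute_error(labels, predictions),  # (value, update)
    #     }
    #     # -> `update_op` groups the two update ops; `value_ops` maps names to value tensors.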

    @staticmethod
    def _get_features_from_input_fn(input_fn):
        result = input_fn()
        if not ops.get_default_graph().get_collection(ops.GraphKeys.QUEUE_RUNNERS):
            logging.warning('Input graph does not contain a QueueRunner. '
                            'That means predict yields forever. '
                            'This is probably a mistake.')
        if isinstance(result, (list, tuple)):
            return result[0]
        return result

    @staticmethod
    def _extract_keys(predictions, predict_keys):
        if not predict_keys:
            return predictions
        if not isinstance(predictions, dict):
            raise ValueError("predict_keys argument is not valid in case of non-dict predictions.")
        existing_keys = predictions.keys()
        predictions = {
            key: value
            for key, value in six.iteritems(predictions) if key in predict_keys
        }
        if not predictions:
            raise ValueError("Expected to run at least one output from {}, "
                             "provided {}.".format(existing_keys, predict_keys))
        return predictions
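
    # A small illustration of `_extract_keys` (the `class_tensor`/`prob_tensor` names are
    # placeholders for real prediction tensors):
    #
    #     preds = {'classes': class_tensor, 'probabilities': prob_tensor}
    #     _extract_keys(preds, ['classes'])         # -> {'classes': class_tensor}
    #     _extract_keys(preds, None)                # -> preds, unchanged
    #     _extract_keys(class_tensor, ['classes'])  # -> ValueError (non-dict predictions)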

    def _train_model(self, input_fn, hooks):
        all_hooks = []
        self._graph = ops.Graph()
        with self._graph.as_default() as g, g.device(self._device_fn):
            random_seed.set_random_seed(self._config.tf_random_seed)
            global_step = training.get_or_create_global_step(g)
            with ops.device('/cpu:0'):
                features, labels = input_fn()
            estimator_spec = self._call_model_fn(features, labels, Modes.TRAIN)
            ops.add_to_collection(ops.GraphKeys.LOSSES, estimator_spec.loss)
            all_hooks.extend([
                plx_hooks.NanTensorHook(estimator_spec.loss),
                plx_hooks.StepLoggingTensorHook(
                    {
                        'loss': estimator_spec.loss,
                        'step': global_step
                    },
                    every_n_iter=100)
            ])
            all_hooks.extend(hooks)
            all_hooks.extend(estimator_spec.training_hooks)

            scaffold = estimator_spec.scaffold or monitored_session.Scaffold()
            if not (scaffold.saver or ops.get_collection(ops.GraphKeys.SAVERS)):
                ops.add_to_collection(ops.GraphKeys.SAVERS,  # TODO remove non restorable vars
                                      saver.Saver(sharded=True,  # TODO `var_list`
                                                  max_to_keep=self._config.keep_checkpoint_max,
                                                  defer_build=True))

            chief_hooks = []
            if self._config.save_checkpoints_secs or self._config.save_checkpoints_steps:
                saver_hook_exists = any(
                    [isinstance(h, plx_hooks.StepCheckpointSaverHook)
                     for h in (all_hooks +
                               chief_hooks +
                               list(estimator_spec.training_chief_hooks))])
                if not saver_hook_exists:
                    chief_hooks += [
                        plx_hooks.StepCheckpointSaverHook(
                            self._model_dir,
                            save_secs=self._config.save_checkpoints_secs,
                            save_steps=self._config.save_checkpoints_steps,
                            scaffold=scaffold)
                    ]
            if self._config.save_summary_steps:
                saver_hook_exists = any(
                    [isinstance(h, plx_hooks.StepSummarySaverHook)
                     for h in (all_hooks +
                               chief_hooks +
                               list(estimator_spec.training_chief_hooks))])
                if not saver_hook_exists:
                    chief_hooks += [
                        plx_hooks.StepSummarySaverHook(
                            scaffold=scaffold,
                            save_steps=self._config.save_summary_steps,
                            output_dir=self._model_dir,
                        )
                    ]

            with monitored_session.MonitoredTrainingSession(
                    master=self._config.master,
                    is_chief=self._config.is_chief,
                    checkpoint_dir=self._model_dir,
                    scaffold=scaffold,
                    hooks=all_hooks,
                    chief_only_hooks=chief_hooks + list(estimator_spec.training_chief_hooks),
                    save_checkpoint_secs=0,  # Saving checkpoint is handled by a hook.
                    save_summaries_steps=0,  # Saving summaries is handled by a hook.
                    config=self._session_config) as mon_sess:
                loss = None
                while not mon_sess.should_stop():
                    _, loss = mon_sess.run([estimator_spec.train_op, estimator_spec.loss])
            summary_io.SummaryWriterCache.clear()
            return loss
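
    # `_train_model` expects `model_fn` to return a spec exposing `loss`, `train_op`,
    # `training_hooks`, `training_chief_hooks` and an optional `scaffold`. A hedged TF1-style
    # sketch of such a model_fn (layer sizes, feature names, and the spec class are
    # illustrative assumptions, not part of this module):
    #
    #     def model_fn(features, labels, mode):
    #         logits = tf.layers.dense(features['x'], units=10)
    #         loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
    #         train_op = tf.train.GradientDescentOptimizer(0.01).minimize(
    #             loss, global_step=tf.train.get_global_step())
    #         return EstimatorSpec(mode=mode, loss=loss, train_op=train_op)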

    def _evaluate_model(self, input_fn, hooks=None, checkpoint_path=None, name=''):
        # Check that model has been trained (if nothing has been set explicitly).
        if not checkpoint_path:
            latest_path = saver.latest_checkpoint(self._model_dir)
            if not latest_path:
                error_message = "Could not find trained model at {}.".format(self._model_dir)
                raise EstimatorNotTrainedError(error_message)
            checkpoint_path = latest_path

        # Setup output directory.
        eval_dir = os.path.join(self._model_dir, 'eval' if not name else 'eval_' + name)

        with ops.Graph().as_default() as g:
            random_seed.set_random_seed(self._config.tf_random_seed)
            global_step = training.create_global_step(g)
            features, labels = input_fn()

            estimator_spec = self._call_model_fn(features, labels, Modes.EVAL)
            if MetricKeys.LOSS in estimator_spec.eval_metric_ops:
                raise ValueError("Metric with name `{}` is not allowed, because Estimator "
                                 "already defines a default metric "
                                 "with the same name.".format(MetricKeys.LOSS))
            estimator_spec.eval_metric_ops[
                MetricKeys.LOSS] = metrics_lib.streaming_mean(estimator_spec.loss)
            update_op, eval_dict = self._extract_metric_update_ops(estimator_spec.eval_metric_ops)

            if ops.GraphKeys.GLOBAL_STEP in eval_dict:
                raise ValueError("Metric with name `global_step` is not allowed, because "
                                 "Estimator already defines a default metric with the same name.")
            eval_dict[ops.GraphKeys.GLOBAL_STEP] = global_step

            eval_results = evaluation._evaluate_once(
                checkpoint_path=checkpoint_path,
                master=self._config.evaluation_master,
                scaffold=estimator_spec.scaffold,
                eval_ops=update_op,
                final_ops=eval_dict,
                hooks=hooks,
                config=self._session_config)

            self._write_dict_to_summary(
                output_dir=eval_dir,
                dictionary=eval_results,
                current_global_step=eval_results[ops.GraphKeys.GLOBAL_STEP])

            return eval_results
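
    # `_evaluate_model` reads its metrics from `estimator_spec.eval_metric_ops` and adds a
    # streaming mean of the loss under `MetricKeys.LOSS`. A hedged sketch of the EVAL branch
    # of a model_fn feeding it (`predicted_classes`, `loss`, and the spec class are
    # illustrative):
    #
    #     if mode == Modes.EVAL:
    #         eval_metric_ops = {'accuracy': tf.metrics.accuracy(labels, predicted_classes)}
    #         return EstimatorSpec(mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)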

    def _infer_model(self, mode, input_fn=None, predict_keys=None, hooks=None, checkpoint_path=None):
        """Returns predictions for given features given an inference mode.

        Args:
            mode: The inference mode to use; possible values: `Modes.PREDICT`,
                `Modes.GENERATE`, `Modes.ENCODE`.
            input_fn: Input function returning features which is a dictionary of
                string feature name to `Tensor` or `SparseTensor`. If it returns a
                tuple, first item is extracted as features. Prediction continues until
                `input_fn` raises an end-of-input exception (`OutOfRangeError` or `StopIteration`).
            predict_keys: list of `str`, names of the keys to predict. It is used if
                `EstimatorSpec.predictions` is a `dict`. If `predict_keys` is used, then the
                rest of the predictions will be filtered from the dictionary. If `None`, returns all.
            hooks: List of `SessionRunHook` subclass instances. Used for callbacks
                inside the prediction call.
            checkpoint_path: Path of a specific checkpoint to predict. If `None`, the
                latest checkpoint in `model_dir` is used.

        Yields:
            Evaluated values of `predictions` tensors.

        Raises:
            ValueError: Could not find a trained model in `model_dir`.
            ValueError: If the batch lengths of the predictions are not the same.
            ValueError: If there is a conflict between `predict_keys` and `predictions`.
                For example if `predict_keys` is not `None`
                but `EstimatorSpec.predictions` is not a `dict`.
        """
        hooks = self._check_hooks(hooks)
        # Check that model has been trained.
        if not checkpoint_path:
            checkpoint_path = saver.latest_checkpoint(self._model_dir)
        if not checkpoint_path:
            raise ValueError("Could not find trained model at %s." % self._model_dir)

        with ops.Graph().as_default() as g:
            random_seed.set_random_seed(self._config.tf_random_seed)
            training.get_or_create_global_step(g)
            features = self._get_features_from_input_fn(input_fn)
            estimator_spec = self._call_model_fn(features, None, mode)
            predictions = self._extract_keys(estimator_spec.predictions, predict_keys)
            with monitored_session.MonitoredSession(
                    session_creator=monitored_session.ChiefSessionCreator(
                        checkpoint_filename_with_path=checkpoint_path,
                        scaffold=estimator_spec.scaffold,
                        config=self._session_config),
                    hooks=hooks) as mon_sess:
                while not mon_sess.should_stop():
                    preds_evaluated = mon_sess.run(predictions)
                    if not isinstance(predictions, dict):
                        for pred in preds_evaluated:
                            yield pred
                    else:
                        for i in xrange(extract_batch_length(preds_evaluated)):
                            yield {key: value[i] for key, value in six.iteritems(preds_evaluated)}
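
    # The yield logic above unbatches the evaluated predictions: with dict predictions it
    # yields one dict per example ({key: value[i]}), with a single tensor it yields one row
    # per example. Illustrative shapes (hypothetical): predictions = {'classes': (32,),
    # 'probs': (32, 10)} -> 32 dicts, each with a scalar 'classes' and a length-10 'probs'.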

    @staticmethod
    def _write_dict_to_summary(output_dir,
                               dictionary,
                               current_global_step):
        """Writes a `dict` into summary file in given output directory.

          Args:
            output_dir: `str`, directory to write the summary file in.
            dictionary: the `dict` to be written to summary file.
            current_global_step: `int`, the current global step.
          """
        logging.info('Saving dict for global step %d: %s', current_global_step,
                     dict_to_str(dictionary))
        summary_writer = summary_io.SummaryWriterCache.get(output_dir)
        summary_proto = summary_pb2.Summary()
        for key in dictionary:
            if dictionary[key] is None or key == ops.GraphKeys.GLOBAL_STEP:
                continue
            value = summary_proto.value.add()
            value.tag = key
            if isinstance(dictionary[key], (np.float32, float)):
                value.simple_value = float(dictionary[key])
            elif isinstance(dictionary[key], (int, np.int64, np.int32)):
                value.simple_value = int(dictionary[key])
            else:
                logging.warning('Skipping summary for %s, must be a '
                                'float, np.float32, int, np.int32, or np.int64.', key)
        summary_writer.add_summary(summary_proto, current_global_step)
        summary_writer.flush()
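
    # A hedged example of the input `_write_dict_to_summary` typically receives from
    # `_evaluate_model` (values illustrative): dictionary={'loss': 0.32, 'accuracy': 0.91,
    # 'global_step': 1000}. The 'global_step' entry is skipped, and the remaining scalar
    # metrics are written as `simple_value` summaries tagged with step 1000.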