示例#1
0
    def train(self,
              input_fn,
              steps=None,
              hooks=None,
              max_steps=None,
              saving_listeners=None,
              save_best_ckpt=False):
        """Run training in graph mode and return this estimator.

        Args:
          input_fn: Input function, forwarded unchanged to `_train_model`.
          steps: Optional positive number of steps, handed to
            `StopAtStepHook`. Mutually exclusive with `max_steps`.
          hooks: Optional hooks, validated by
            `estimator_lib._check_hooks_type`.
          max_steps: Optional positive step limit, handed to
            `StopAtStepHook`. If the global step loaded from the model
            directory has already reached it, training is skipped.
          saving_listeners: Optional listeners, validated by
            `estimator_lib._check_listeners_type`.
          save_best_ckpt: Forwarded unchanged to `_train_model`.

        Returns:
          `self`, so calls can be chained.

        Raises:
          ValueError: If both `steps` and `max_steps` are given, or if the
            one that is given is not strictly positive.
        """
        with context.graph_mode():
            steps_given = steps is not None
            max_steps_given = max_steps is not None

            if steps_given and max_steps_given:
                raise ValueError('Can not provide both steps and max_steps.')
            if steps_given and steps <= 0:
                raise ValueError(
                    'Must specify steps > 0, given: {}'.format(steps))

            if max_steps_given:
                if max_steps <= 0:
                    raise ValueError(
                        'Must specify max_steps > 0, given: {}'.format(
                            max_steps))
                # Nothing to do when the checkpointed global step is already
                # at or beyond the requested limit.
                current_step = _load_global_step_from_checkpoint_dir(
                    self._model_dir)
                if max_steps <= current_step:
                    logging.info(
                        'Skipping training since max_steps has already saved.')
                    return self

            train_hooks = estimator_lib._check_hooks_type(hooks)
            train_hooks.append(training.StopAtStepHook(steps, max_steps))
            listeners = estimator_lib._check_listeners_type(saving_listeners)

            final_loss = self._train_model(input_fn, train_hooks, listeners,
                                           save_best_ckpt)
            logging.info('Loss for final step: %s.', final_loss)
            return self
示例#2
0
    def predict(
        self,
        input_fn,
        predict_keys=None,
        hooks=None,
        checkpoint_dir=None,
        yield_single_examples=True,
    ):
        """Yield predictions computed from a checkpoint directory.

        This is a generator: nothing below runs (including the
        `checkpoint_dir` check) until the first item is requested.

        Args:
          input_fn: Input function passed to `_get_features_from_input_fn`
            in PREDICT mode.
          predict_keys: Optional keys passed to `_extract_keys` to select a
            subset of `estimator_spec.predictions`.
          hooks: Optional hooks, validated by `estimator._check_hooks_type`
            and run inside the prediction session.
          checkpoint_dir: Directory handed to `MonitoredTrainingSession` as
            `checkpoint_dir`. Required.
          yield_single_examples: If false, each evaluated batch is yielded
            as is. Otherwise the batch is split: non-dict predictions are
            iterated element by element, dict predictions are yielded as
            one dict per index up to `_extract_batch_length`.

        Yields:
          Evaluated predictions, per batch or per example as described
          above.

        Raises:
          ValueError: If `checkpoint_dir` is empty or `None`.
        """
        with context.graph_mode():
            hooks = estimator._check_hooks_type(hooks)
            # A checkpoint location is mandatory for this method.
            if not checkpoint_dir:
                raise ValueError("No checkpoint_dir")
            with ops.Graph().as_default() as g, g.device(self._device_fn):
                random_seed.set_random_seed(self._config.tf_random_seed)
                self._create_and_assert_global_step(g)
                features, input_hooks = self._get_features_from_input_fn(
                    input_fn, model_fn_lib.ModeKeys.PREDICT
                )
                estimator_spec = self._call_model_fn(
                    features,
                    None,
                    model_fn_lib.ModeKeys.PREDICT,
                    self.config,
                )

                predictions = self._extract_keys(
                    estimator_spec.predictions, predict_keys
                )
                all_hooks = list(input_hooks)
                all_hooks.extend(hooks)
                all_hooks.extend(
                    list(estimator_spec.prediction_hooks or [])
                )
                # NOTE(review): `args` and `config` are not defined in this
                # method; presumably they are module-level globals. Confirm
                # they are in scope (the rest of the method reads
                # `self._config` / `self.config`).
                with training.MonitoredTrainingSession(
                    is_chief=args.worker_type == "chief",
                    master=config.master,
                    checkpoint_dir=checkpoint_dir,
                    # Previously `all_hooks` was built but never handed to
                    # the session, so input, user and prediction hooks were
                    # silently ignored.
                    hooks=all_hooks,
                    config=config.session_config,
                ) as mon_sess:

                    while not mon_sess.should_stop():
                        preds_evaluated = mon_sess.run(predictions)
                        if not yield_single_examples:
                            yield preds_evaluated
                        elif not isinstance(predictions, dict):
                            for pred in preds_evaluated:
                                yield pred
                        else:
                            for i in range(
                                self._extract_batch_length(preds_evaluated)
                            ):
                                yield {
                                    key: value[i]
                                    for key, value in six.iteritems(
                                        preds_evaluated
                                    )
                                }
示例#3
0
    def __init__(self,
                 estimator,
                 input_fn,
                 steps=None,
                 hooks=None,
                 name=None,
                 every_n_iter=100):
        """Set up an `InMemoryEvaluatorHook`.

        Args:
          estimator: The estimator to evaluate. Its `config`, `model_dir`
            and `_convert_eval_steps_to_hooks` are read here.
          input_fn: Input function for evaluation; stored for later use.
            Equivalent to the `input_fn` argument of `estimator.evaluate`.
          steps: Equivalent to the `steps` argument of `estimator.evaluate`.
            Stored, and also converted into hooks via
            `estimator._convert_eval_steps_to_hooks`.
          hooks: Equivalent to the `hooks` argument of `estimator.evaluate`.
            Validated by `estimator_lib._check_hooks_type`.
          name: Optional evaluation name. When given, results go under
            `<model_dir>/eval_<name>`, otherwise under `<model_dir>/eval`.
          every_n_iter: Positive `int`; used as `every_steps` of the
            internal `SecondOrStepTimer`.

        Raises:
          ValueError: If `every_n_iter` is `None` or not positive, or if
            the estimator's config reports parameter servers
            (`num_ps_replicas > 0`) or more than one worker
            (`num_worker_replicas > 1`).
        """
        if every_n_iter is None or every_n_iter <= 0:
            raise ValueError('invalid every_n_iter=%s.' % every_n_iter)

        run_config = estimator.config
        is_distributed = (run_config.num_ps_replicas > 0
                          or run_config.num_worker_replicas > 1)
        if is_distributed:
            raise ValueError(
                'InMemoryEvaluator supports only single machine (aka Local) setting.'
            )

        self._estimator = estimator
        self._input_fn = input_fn
        self._steps = steps
        self._name = name
        self._every_n_iter = every_n_iter

        # Unnamed evaluations share the plain 'eval' sub-directory.
        eval_subdir = 'eval_' + name if name else 'eval'
        self._eval_dir = os.path.join(self._estimator.model_dir, eval_subdir)

        self._graph = None

        self._hooks = estimator_lib._check_hooks_type(hooks)
        step_hooks = self._estimator._convert_eval_steps_to_hooks(steps)
        self._hooks.extend(step_hooks)

        self._timer = tf.compat.v1.train.SecondOrStepTimer(
            every_steps=every_n_iter)
示例#4
0
    def predict_with_guide(self,
                           input_fn,
                           predict_keys=None,
                           hooks=None,
                           checkpoint_path=None,
                           latest_filename=None,
                           yield_single_examples=True):
        """Yield predictions while feeding inputs through a `FeedGuideHook`.

        Inputs are built from `input_fn` in EVAL mode (so labels are
        available), but the model function is called in PREDICT mode on
        placeholders that mirror those inputs. A `FeedGuideHook` receives
        the placeholders, the real input tensors, `self.model_dir` and the
        prediction tensors; how it feeds them is defined by that class.

        This is a generator: nothing runs until the first item is requested.

        Args:
          input_fn: Input function; must produce a dict of features (its
            `.items()` is used) and a single labels tensor with `dtype` and
            `shape`.
          predict_keys: `None` or a list of prediction keys. For a list, the
            keys of the first model instance's `metrics_dict` are added.
            For `None`, all predictions not in `metrics_dict` are selected,
            plus the entries of `metrics_eval`. The caller's list is never
            modified.
          hooks: Optional hooks, validated by
            `estimator_lib._check_hooks_type`.
          checkpoint_path: Passed to `self._checkpoint_path` together with
            `latest_filename` to resolve the checkpoint to restore.
          latest_filename: See `checkpoint_path`.
          yield_single_examples: If false, each evaluated batch is yielded
            as is. Otherwise the batch is split into per-example items.

        Yields:
          Evaluated predictions, per batch or per example.

        Raises:
          TypeError: If `predict_keys` is neither `None` nor a list.
        """
        hooks = estimator_lib._check_hooks_type(hooks)

        checkpoint_path = self._checkpoint_path(checkpoint_path,
                                                latest_filename)

        with ops.Graph().as_default() as g:
            random_seed.set_random_seed(self._config.tf_random_seed)
            self._create_and_assert_global_step(g)
            features, labels, input_hooks = self._get_features_and_labels_from_input_fn(
                input_fn, model_fn_lib.ModeKeys.EVAL)

            # Placeholders with the same dtype/shape as the real inputs; the
            # model is built on these instead of on the input pipeline.
            features_ph = {
                key: array_ops.placeholder(value.dtype, value.shape, name=key)
                for key, value in features.items()
            }
            labels_ph = array_ops.placeholder(labels.dtype,
                                              labels.shape,
                                              name="labels")
            feed_guide_hook = FeedGuideHook(features_ph, labels_ph, features,
                                            labels, self.model_dir)

            estimator_spec = self._call_model_fn(features_ph, labels_ph,
                                                 model_fn_lib.ModeKeys.PREDICT,
                                                 self.config)

            if isinstance(predict_keys, list):
                model = self.params["model_instances"][0]
                # Build a new list: the original `+=` extended the caller's
                # list in place, so it grew on every call.
                predict_keys = predict_keys + list(model.metrics_dict.keys())
            elif predict_keys is None:
                # Volume (3D) evaluation does not need the model's metric
                # tensors; the XXXPred outputs are used to generate the 3D
                # prediction instead.
                model = self.params["model_instances"][0]
                predict_keys = [
                    x for x in estimator_spec.predictions
                    if x not in model.metrics_dict
                ]
                predict_keys.extend(list(model.metrics_eval))
            else:
                raise TypeError(
                    "predict_keys must be None(for 3d eval) or a list(for 2d eval, "
                    "for example [\"Names\", \"Indices\"])")
            predictions = self._extract_keys(estimator_spec.predictions,
                                             predict_keys)
            feed_guide_hook.predictions = predictions

            all_hooks = list(input_hooks) + [feed_guide_hook]
            all_hooks.extend(hooks)
            all_hooks.extend(list(estimator_spec.prediction_hooks or []))

            with training.MonitoredSession(
                    session_creator=training.ChiefSessionCreator(
                        checkpoint_filename_with_path=checkpoint_path,
                        master=self._config.master,
                        scaffold=estimator_spec.scaffold,
                        config=self._session_config),
                    hooks=all_hooks) as mon_sess:
                while not mon_sess.should_stop():
                    preds_evaluated = mon_sess.run(predictions)
                    if not yield_single_examples:
                        yield preds_evaluated
                    elif not isinstance(predictions, dict):
                        for pred in preds_evaluated:
                            yield pred
                    else:
                        for i in range(
                                self._extract_batch_length(preds_evaluated)):
                            yield {
                                key: value[i]
                                for key, value in six.iteritems(
                                    preds_evaluated)
                            }