Exemplo n.º 1
0
    def _infer_model(self, mode, input_fn=None, predict_keys=None, hooks=None, checkpoint_path=None):
        """Yields predictions for given features under a given inference mode.

        A fresh graph is built, the model function is called with `mode`, the
        checkpoint is restored into a monitored session and the prediction
        tensors are evaluated batch by batch; evaluated batches are split so
        that one example is yielded at a time.

        Note that this is a generator: nothing below (including the checkpoint
        lookup and its `ValueError`) runs until the first item is requested.

        Args:
            mode: The inference to use, possible values: PREDICT, GENERATE, ENCODE.
            input_fn: Input function returning features which is a dictionary of
                string feature name to `Tensor` or `SparseTensor`. If it returns a
                tuple, first item is extracted as features. Prediction continues until
                `input_fn` raises an end-of-input exception (`OutOfRangeError` or `StopIteration`).
            predict_keys: list of `str`, name of the keys to predict. It is used if
                the `EstimatorSpec.predictions` is a `dict`. If `predict_keys` is used then rest
                of the predictions will be filtered from the dictionary. If `None`, returns all.
            hooks: List of `SessionRunHook` subclass instances. Used for callbacks
                inside the prediction call.
            checkpoint_path: Path of a specific checkpoint to predict. If `None`, the
                latest checkpoint in `model_dir` is used.

        Yields:
            Evaluated values of `predictions` tensors, one example at a time.
            When the predictions are a `dict`, each yielded item is a `dict`
            with the same keys.

        Raises:
            ValueError: Could not find a trained model in model_dir.
            ValueError: if batch length of predictions are not same.
            ValueError: If there is a conflict between `predict_keys` and `predictions`.
                For example if `predict_keys` is not `None`
                but `EstimatorSpec.predictions` is not a `dict`.
        """
        hooks = self._check_hooks(hooks)
        # Check that model has been trained.
        if not checkpoint_path:
            checkpoint_path = saver.latest_checkpoint(self._model_dir)
        if not checkpoint_path:
            raise ValueError("Could not find trained model at %s." % self._model_dir)

        with ops.Graph().as_default() as g:
            random_seed.set_random_seed(self._config.tf_random_seed)
            training.get_or_create_global_step(g)
            features = self._get_features_from_input_fn(input_fn)
            estimator_spec = self._call_model_fn(features, None, mode)
            predictions = self._extract_keys(estimator_spec.predictions, predict_keys)
            with monitored_session.MonitoredSession(
                    session_creator=monitored_session.ChiefSessionCreator(
                        checkpoint_filename_with_path=checkpoint_path,
                        scaffold=estimator_spec.scaffold,
                        config=self._session_config),
                    hooks=hooks) as mon_sess:
                while not mon_sess.should_stop():
                    preds_evaluated = mon_sess.run(predictions)
                    if not isinstance(predictions, dict):
                        # Single tensor: iterate over the batch dimension.
                        for pred in preds_evaluated:
                            yield pred
                    else:
                        # Dict of tensors: re-assemble one dict per example.
                        # `range` (not the Python 2-only `xrange`) keeps this
                        # working on Python 3, matching `predict`.
                        for i in range(extract_batch_length(preds_evaluated)):
                            yield {key: value[i] for key, value in six.iteritems(preds_evaluated)}
Exemplo n.º 2
0
  def build_subnetwork(self,
                       features,
                       logits_dimension,
                       training,
                       iteration_step,
                       summary,
                       previous_ensemble=None):
    """Builds a small dense subnetwork while asserting the build context.

    Before building anything, this checks that the required arguments were
    supplied, that no trainable variables exist yet, that every global-step
    accessor resolves to the subnetwork's iteration step, and that the
    `tf.summary` functions have been replaced by the scoped fakes.

    Args:
      features: Input features; only checked for not being `None`.
      logits_dimension: Size of the logits layer. In multi-head mode each of
        the two heads gets `logits_dimension // 2` units.
      training: Training flag; only checked for not being `None`.
      iteration_step: Iteration step tensor; only checked for not being `None`.
      summary: Summary object; only checked for not being `None`.
      previous_ensemble: Unused.

    Returns:
      A `Subnetwork` with complexity 2 and no persisted tensors. Its
      `last_layer` is the logits when `self._use_logits_last_layer` is set,
      otherwise the dummy input tensor (a per-head dict in multi-head mode).
    """
    assert features is not None
    assert training is not None
    assert iteration_step is not None
    assert summary is not None

    # Trainable variables collection should always be empty when
    # build_subnetwork is called.
    assert not tf_compat.v1.get_collection(
        tf_compat.v1.GraphKeys.TRAINABLE_VARIABLES)

    # Subnetworks get iteration steps instead of global steps.
    step_name = "subnetwork_test/iteration_step"
    assert step_name == tf_compat.tensor_name(
        tf_compat.v1.train.get_global_step())
    assert step_name == tf_compat.tensor_name(train.get_global_step())
    assert step_name == tf_compat.tensor_name(training_util.get_global_step())
    assert step_name == tf_compat.tensor_name(tf_v1.train.get_global_step())
    assert step_name == tf_compat.tensor_name(
        tf_compat.v1.train.get_or_create_global_step())
    assert step_name == tf_compat.tensor_name(train.get_or_create_global_step())
    assert step_name == tf_compat.tensor_name(
        training_util.get_or_create_global_step())
    assert step_name == tf_compat.tensor_name(
        tf_v1.train.get_or_create_global_step())

    # Subnetworks get scoped summaries.
    assert "fake_scalar" == tf_compat.v1.summary.scalar("scalar", 1.)
    assert "fake_image" == tf_compat.v1.summary.image("image", 1.)
    assert "fake_histogram" == tf_compat.v1.summary.histogram("histogram", 1.)
    assert "fake_audio" == tf_compat.v1.summary.audio("audio", 1., 1.)
    last_layer = tu.dummy_tensor(shape=(2, 3))

    def logits_fn(logits_dim):
      # Seeded initializer keeps the created kernel deterministic.
      return tf_compat.v1.layers.dense(
          last_layer,
          units=logits_dim,
          kernel_initializer=tf_compat.v1.glorot_uniform_initializer(
              seed=self._seed))

    if self._multi_head:
      # Floor division keeps `units` an integer; true division would hand
      # the dense layer a float.
      head_dimension = logits_dimension // 2
      logits = {
          "head1": logits_fn(head_dimension),
          "head2": logits_fn(head_dimension)
      }
      last_layer = {"head1": last_layer, "head2": last_layer}
    else:
      logits = logits_fn(logits_dimension)

    return Subnetwork(
        last_layer=logits if self._use_logits_last_layer else last_layer,
        logits=logits,
        complexity=2,
        persisted_tensors={})
Exemplo n.º 3
0
  def _train_model(self, checkpoint_dir, num_steps):
    """Fits a simple classifier and writes its checkpoint.

    The data is configured so that roughly 300 steps suffice for the model to
    memorize it (i.e. reach 100% accuracy); runs of at least 300 steps are
    therefore checked for a small final loss.

    Args:
      checkpoint_dir: Directory the checkpoint is written to.
      num_steps: Number of steps to train for.
    """
    with ops.Graph().as_default():
      random_seed.set_random_seed(0)
      inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
      labels = constant_op.constant(self._labels, dtype=dtypes.float32)

      predictions = logistic_classifier(inputs)
      loss_tensor = losses.log_loss(labels=labels, predictions=predictions)

      sgd = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)
      global_step = training.get_or_create_global_step()
      update_op = sgd.minimize(loss_tensor, global_step)

      # The hook ends the loop once `num_steps` steps have been taken.
      stop_hook = basic_session_run_hooks.StopAtStepHook(num_steps)
      with monitored_session.MonitoredTrainingSession(
          checkpoint_dir=checkpoint_dir, hooks=[stop_hook]) as sess:
        final_loss = None
        while not sess.should_stop():
          final_loss = sess.run([update_op, loss_tensor])[1]

        if num_steps >= 300:
          assert final_loss < .015
Exemplo n.º 4
0
    def _build_train_op(self, loss):
        """Returns the op that optimizes `loss` with this model's optimizer."""
        opt = self._build_optimizer()
        step = training.get_or_create_global_step()
        # The learning rate is left unset and no summaries are requested.
        return tf.contrib.layers.optimize_loss(
            loss=loss,
            global_step=step,
            learning_rate=None,
            clip_gradients=self._clip_gradients_fn,
            optimizer=opt,
            summaries=[])
Exemplo n.º 5
0
    def _build_train_op(self, loss):
        """Assembles the training op for `loss` via `optimize_loss`."""
        optimizer = self._build_optimizer()
        # Collect the arguments first so the final call reads as one step.
        optimize_kwargs = dict(
            loss=loss,
            global_step=training.get_or_create_global_step(),
            learning_rate=None,
            clip_gradients=self._clip_gradients_fn,
            optimizer=optimizer,
            summaries=[],
        )
        return tf.contrib.layers.optimize_loss(**optimize_kwargs)
def model_fn(features, labels, mode, params):
    """Builds the `EstimatorSpec` for `mode`; used as an Estimator's model_fn.

    Args:
        features: The image input, or a dict holding it under the key "image".
        labels: Labels used for the loss (and accuracy) in TRAIN and EVAL.
        mode: One of the `estimator.ModeKeys` values.
        params: Dict providing the "data_format" passed to `Model`.

    Returns:
        The `EstimatorSpec` for PREDICT, TRAIN or EVAL. Any other mode falls
        through and yields `None`.
    """
    network = Model(params["data_format"])
    image = features["image"] if isinstance(features, dict) else features

    if mode == estimator.ModeKeys.PREDICT:
        logits = network(image, training=False)
        outputs = {
            "classes": math_ops.argmax(logits, axis=1),
            "probabilities": nn.softmax(logits),
        }
        exports = {"classify": estimator.export.PredictOutput(outputs)}
        return estimator.EstimatorSpec(
            mode=estimator.ModeKeys.PREDICT,
            predictions=outputs,
            export_outputs=exports)

    if mode == estimator.ModeKeys.TRAIN:
        adam = train.AdamOptimizer(learning_rate=1e-4)

        logits = network(image, training=True)
        xent = losses.sparse_softmax_cross_entropy(labels=labels,
                                                   logits=logits)
        step = train.get_or_create_global_step()
        return estimator.EstimatorSpec(
            mode=estimator.ModeKeys.TRAIN,
            loss=xent,
            train_op=adam.minimize(xent, step))

    if mode == estimator.ModeKeys.EVAL:
        logits = network(image, training=False)
        xent = losses.sparse_softmax_cross_entropy(labels=labels,
                                                   logits=logits)
        predicted_classes = math_ops.argmax(logits, axis=1)
        accuracy = ops.metrics.accuracy(labels=labels,
                                        predictions=predicted_classes)
        return estimator.EstimatorSpec(
            mode=estimator.ModeKeys.EVAL,
            loss=xent,
            eval_metric_ops={"accuracy": accuracy})
def model_fn(features, labels, mode, params):
  """Model function handed to an Estimator; returns the spec for `mode`.

  Args:
    features: Either the image input itself or a dict with an "image" entry.
    labels: Labels consumed by the TRAIN and EVAL branches.
    mode: An `estimator.ModeKeys` value selecting the branch.
    params: Dict whose "data_format" entry configures `Model`.

  Returns:
    An `EstimatorSpec` for PREDICT, TRAIN or EVAL; `None` for any other mode.
  """
  keys = estimator.ModeKeys
  net = Model(params["data_format"])

  # Unwrap the image when features come as a dict.
  image = features
  if isinstance(features, dict):
    image = features["image"]

  if mode == keys.PREDICT:
    logits = net(image, training=False)
    preds = {
        "classes": math_ops.argmax(logits, axis=1),
        "probabilities": nn.softmax(logits),
    }
    return estimator.EstimatorSpec(
        mode=keys.PREDICT,
        predictions=preds,
        export_outputs={"classify": estimator.export.PredictOutput(preds)})

  if mode == keys.TRAIN:
    opt = train.AdamOptimizer(learning_rate=1e-4)

    logits = net(image, training=True)
    train_loss = losses.sparse_softmax_cross_entropy(
        labels=labels, logits=logits)
    minimize_op = opt.minimize(train_loss, train.get_or_create_global_step())
    return estimator.EstimatorSpec(
        mode=keys.TRAIN, loss=train_loss, train_op=minimize_op)

  if mode == keys.EVAL:
    logits = net(image, training=False)
    eval_loss = losses.sparse_softmax_cross_entropy(
        labels=labels, logits=logits)
    metrics = {
        "accuracy":
            ops.metrics.accuracy(
                labels=labels, predictions=math_ops.argmax(logits, axis=1)),
    }
    return estimator.EstimatorSpec(
        mode=keys.EVAL, loss=eval_loss, eval_metric_ops=metrics)
Exemplo n.º 8
0
    def _train_model(self, checkpoint_dir, num_steps):
        """Trains a simple classification model and checkpoints it.

        The data has been configured such that after around 300 steps the
        model has memorized the dataset (i.e. 100% accuracy can be expected),
        so the final loss is checked whenever at least 300 steps were run.

        Args:
            checkpoint_dir: The directory where the checkpoint is written to.
            num_steps: The number of steps to train for.
        """
        with ops.Graph().as_default():
            random_seed.set_random_seed(0)
            features = constant_op.constant(self._inputs, dtype=dtypes.float32)
            targets = constant_op.constant(self._labels, dtype=dtypes.float32)

            probabilities = logistic_classifier(features)
            log_loss = losses.log_loss(labels=targets,
                                       predictions=probabilities)

            optimizer = gradient_descent.GradientDescentOptimizer(
                learning_rate=1.0)
            step = training.get_or_create_global_step()
            minimize_op = optimizer.minimize(log_loss, step)

            # Training stops once the hook has seen `num_steps` steps.
            hooks = [basic_session_run_hooks.StopAtStepHook(num_steps)]
            with monitored_session.MonitoredTrainingSession(
                    checkpoint_dir=checkpoint_dir, hooks=hooks) as session:
                last_loss = None
                while not session.should_stop():
                    last_loss = session.run([minimize_op, log_loss])[1]

                if num_steps >= 300:
                    assert last_loss < .015
    def test_build_ensemble_spec(
            self,
            want_logits,
            want_loss=None,
            want_adanet_loss=None,
            want_ensemble_trainable_vars=None,
            adanet_lambda=0.,
            adanet_beta=0.,
            ensemble_spec_fn=lambda: None,
            use_bias=False,
            use_logits_last_layer=False,
            mixture_weight_type=MixtureWeightType.MATRIX,
            mixture_weight_initializer=tf_compat.v1.zeros_initializer(),
            warm_start_mixture_weights=True,
            subnetwork_builder_class=_Builder,
            mode=tf.estimator.ModeKeys.TRAIN,
            multi_head=False,
            want_subnetwork_trainable_vars=2):
        """Builds a subnetwork spec and an ensemble spec and checks the result.

        Inside a fresh graph this builds a subnetwork spec through
        `_SubnetworkManager` and an ensemble spec through `_EnsembleBuilder`,
        then verifies the number of trainable variables, that the real global
        step and `tf.summary` functions are visible outside the builders'
        context, and the resulting logits, bias and losses.

        Args:
            want_logits: Expected ensemble logits (compared with `atol=1e-3`).
            want_loss: Expected `ensemble_spec.loss` after the training steps
                (compared to 3 places).
            want_adanet_loss: Expected `ensemble_spec.adanet_loss` after the
                training steps (compared to 3 places).
            want_ensemble_trainable_vars: Expected length of the `var_list`
                passed to the mixture-weights train op function.
            adanet_lambda: Forwarded to `ComplexityRegularizedEnsembler`.
            adanet_beta: Forwarded to `ComplexityRegularizedEnsembler`.
            ensemble_spec_fn: Zero-argument callable returning the previous
                ensemble spec, or a falsy value when there is none.
            use_bias: Forwarded to `ComplexityRegularizedEnsembler`; also
                selects whether the bias is expected to be non-zero.
            use_logits_last_layer: Forwarded to the subnetwork builder.
            mixture_weight_type: Forwarded to `ComplexityRegularizedEnsembler`.
            mixture_weight_initializer: Forwarded to
                `ComplexityRegularizedEnsembler`.
            warm_start_mixture_weights: Forwarded to
                `ComplexityRegularizedEnsembler`.
            subnetwork_builder_class: Class instantiated to create the
                subnetwork builder.
            mode: Mode passed to both spec builders. In PREDICT mode the test
                returns early after checking the logits and that loss,
                adanet loss and train op are `None`.
            multi_head: When true, a two-head `MultiHead` and per-head label
                dicts are used.
            want_subnetwork_trainable_vars: Expected length of the `var_list`
                passed to the subnetwork train op function.
        """
        seed = 64

        if multi_head:
            head = multi_head_lib.MultiHead(heads=[
                binary_class_head.BinaryClassHead(
                    name="head1", loss_reduction=tf_compat.SUM),
                binary_class_head.BinaryClassHead(name="head2",
                                                  loss_reduction=tf_compat.SUM)
            ])
        else:
            head = binary_class_head.BinaryClassHead(
                loss_reduction=tf_compat.SUM)
        builder = _EnsembleBuilder(head=head)

        def _subnetwork_train_op_fn(loss, var_list):
            # Train op factory for the subnetwork; also asserts the context
            # (variables, global step, summaries) it is called in.
            self.assertLen(var_list, want_subnetwork_trainable_vars)
            self.assertEqual(
                var_list,
                tf_compat.v1.get_collection(
                    tf_compat.v1.GraphKeys.TRAINABLE_VARIABLES))
            # Subnetworks get iteration steps instead of global steps.
            self.assertEqual("subnetwork_test/iteration_step",
                             tf_compat.v1.train.get_global_step().op.name)

            # Subnetworks get scoped summaries.
            self.assertEqual("fake_scalar",
                             tf_compat.v1.summary.scalar("scalar", 1.))
            self.assertEqual("fake_image",
                             tf_compat.v1.summary.image("image", 1.))
            self.assertEqual("fake_histogram",
                             tf_compat.v1.summary.histogram("histogram", 1.))
            self.assertEqual("fake_audio",
                             tf_compat.v1.summary.audio("audio", 1., 1.))
            optimizer = tf_compat.v1.train.GradientDescentOptimizer(
                learning_rate=.1)
            return optimizer.minimize(loss, var_list=var_list)

        def _mixture_weights_train_op_fn(loss, var_list):
            # Train op factory for the ensemble's mixture weights; also
            # asserts the context it is called in.
            self.assertLen(var_list, want_ensemble_trainable_vars)
            self.assertEqual(
                var_list,
                tf_compat.v1.get_collection(
                    tf_compat.v1.GraphKeys.TRAINABLE_VARIABLES))
            # Ensembles get iteration steps instead of global steps.
            self.assertEqual("ensemble_test/iteration_step",
                             tf_compat.v1.train.get_global_step().op.name)

            # Ensembles get scoped summaries.
            self.assertEqual("fake_scalar",
                             tf_compat.v1.summary.scalar("scalar", 1.))
            self.assertEqual("fake_image",
                             tf_compat.v1.summary.image("image", 1.))
            self.assertEqual("fake_histogram",
                             tf_compat.v1.summary.histogram("histogram", 1.))
            self.assertEqual("fake_audio",
                             tf_compat.v1.summary.audio("audio", 1., 1.))
            optimizer = tf_compat.v1.train.GradientDescentOptimizer(
                learning_rate=.1)
            return optimizer.minimize(loss, var_list=var_list)

        # The previous ensemble is only available when the supplied function
        # returns a (truthy) previous ensemble spec.
        previous_ensemble = None
        previous_ensemble_spec = ensemble_spec_fn()
        if previous_ensemble_spec:
            previous_ensemble = previous_ensemble_spec.ensemble

        subnetwork_manager = _SubnetworkManager(head)
        subnetwork_builder = subnetwork_builder_class(
            _subnetwork_train_op_fn,
            _mixture_weights_train_op_fn,
            use_logits_last_layer,
            seed,
            multi_head=multi_head)

        with tf.Graph().as_default() as g:
            # A trainable variable to later verify that creating models does not
            # affect the global variables collection.
            _ = tf_compat.v1.get_variable("some_var", 0., trainable=True)

            features = {"x": tf.constant([[1.], [2.]])}
            if multi_head:
                labels = {
                    "head1": tf.constant([0, 1]),
                    "head2": tf.constant([0, 1])
                }
            else:
                labels = tf.constant([0, 1])

            subnetwork_spec = subnetwork_manager.build_subnetwork_spec(
                name="test",
                subnetwork_builder=subnetwork_builder,
                iteration_step=tf_compat.v1.train.get_or_create_global_step(),
                summary=_FakeSummary(),
                features=features,
                mode=mode,
                labels=labels,
                previous_ensemble=previous_ensemble)
            ensemble_spec = builder.build_ensemble_spec(
                # Note: when ensemble_spec is not None and warm_start_mixture_weights
                # is True, we need to make sure that the bias and mixture weights are
                # already saved to the checkpoint_dir.
                name="test",
                previous_ensemble_spec=previous_ensemble_spec,
                candidate=EnsembleCandidate("foo", [subnetwork_builder], None),
                ensembler=ComplexityRegularizedEnsembler(
                    mixture_weight_type=mixture_weight_type,
                    mixture_weight_initializer=mixture_weight_initializer,
                    warm_start_mixture_weights=warm_start_mixture_weights,
                    model_dir=self.test_subdirectory,
                    adanet_lambda=adanet_lambda,
                    adanet_beta=adanet_beta,
                    use_bias=use_bias),
                subnetwork_specs=[subnetwork_spec],
                summary=_FakeSummary(),
                features=features,
                iteration_number=1,
                iteration_step=tf_compat.v1.train.get_or_create_global_step(),
                labels=labels,
                mode=mode)

            with tf_compat.v1.Session(graph=g).as_default() as sess:
                sess.run(tf_compat.v1.global_variables_initializer())

                # Equals the number of subnetwork and ensemble trainable variables,
                # plus the one 'some_var' created earlier.
                # NOTE(review): `want_ensemble_trainable_vars` defaults to None,
                # which would make this addition fail; callers appear to always
                # pass it explicitly -- confirm.
                self.assertLen(
                    tf_compat.v1.trainable_variables(),
                    want_subnetwork_trainable_vars +
                    want_ensemble_trainable_vars + 1)

                # Get the real global step outside a subnetwork's context.
                self.assertEqual("global_step",
                                 tf_compat.v1.train.get_global_step().op.name)
                self.assertEqual("global_step",
                                 train.get_global_step().op.name)
                self.assertEqual("global_step",
                                 tf_v1.train.get_global_step().op.name)
                self.assertEqual("global_step",
                                 training_util.get_global_step().op.name)
                self.assertEqual(
                    "global_step",
                    tf_compat.v1.train.get_or_create_global_step().op.name)
                self.assertEqual("global_step",
                                 train.get_or_create_global_step().op.name)
                self.assertEqual(
                    "global_step",
                    tf_v1.train.get_or_create_global_step().op.name)
                self.assertEqual(
                    "global_step",
                    training_util.get_or_create_global_step().op.name)

                # Get global tf.summary outside a subnetwork's context.
                self.assertNotEqual("fake_scalar",
                                    tf_compat.v1.summary.scalar("scalar", 1.))
                self.assertNotEqual("fake_image",
                                    tf_compat.v1.summary.image("image", 1.))
                self.assertNotEqual(
                    "fake_histogram",
                    tf_compat.v1.summary.histogram("histogram", 1.))
                self.assertNotEqual(
                    "fake_audio", tf_compat.v1.summary.audio("audio", 1., 1.))

                # In PREDICT mode there is nothing to train: check the logits
                # and that no loss / train op was produced, then stop.
                if mode == tf.estimator.ModeKeys.PREDICT:
                    self.assertAllClose(want_logits,
                                        sess.run(
                                            ensemble_spec.ensemble.logits),
                                        atol=1e-3)
                    self.assertIsNone(ensemble_spec.loss)
                    self.assertIsNone(ensemble_spec.adanet_loss)
                    self.assertIsNone(ensemble_spec.train_op)
                    self.assertIsNotNone(ensemble_spec.export_outputs)
                    return

                # Verify that train_op works, previous loss should be greater than loss
                # after a train op.
                loss = sess.run(ensemble_spec.loss)
                train_op = tf.group(subnetwork_spec.train_op.train_op,
                                    ensemble_spec.train_op.train_op)
                for _ in range(3):
                    sess.run(train_op)
                self.assertGreater(loss, sess.run(ensemble_spec.loss))

                self.assertAllClose(want_logits,
                                    sess.run(ensemble_spec.ensemble.logits),
                                    atol=1e-3)

                # Bias should learn a non-zero value when used.
                bias = sess.run(ensemble_spec.ensemble.bias)
                if isinstance(bias, dict):
                    # Multi-head bias: collapse to the sum of absolute values.
                    bias = sum(abs(b) for b in bias.values())
                if use_bias:
                    self.assertNotEqual(0., bias)
                else:
                    self.assertAlmostEqual(0., bias)

                self.assertAlmostEqual(want_loss,
                                       sess.run(ensemble_spec.loss),
                                       places=3)
                self.assertAlmostEqual(want_adanet_loss,
                                       sess.run(ensemble_spec.adanet_loss),
                                       places=3)
Exemplo n.º 10
0
    def _train_model(self, input_fn, hooks):
        """Builds the training graph and runs the training loop.

        A new graph is created and stored on `self._graph`. The model function
        is called in TRAIN mode, NaN-checking and loss/step logging hooks are
        installed together with the caller's hooks and the spec's training
        hooks, a sharded saver is registered when none exists, and a
        checkpoint saver hook is added for the chief when checkpointing is
        configured and no such hook is present yet.

        Args:
            input_fn: Callable returning a `(features, labels)` pair.
            hooks: Iterable of extra hooks to run during training; `None` is
                treated as no extra hooks.

        Returns:
            The loss evaluated at the last executed training step, or `None`
            if the session requested a stop before any step ran.
        """
        all_hooks = []
        self._graph = ops.Graph()
        with self._graph.as_default() as g, g.device(self._device_fn):
            random_seed.set_random_seed(self._config.tf_random_seed)
            global_step = training.get_or_create_global_step(g)
            features, labels = input_fn()
            estimator_spec = self._call_model_fn(features, labels,
                                                 ModeKeys.TRAIN)
            all_hooks.extend([
                plx_hooks.NanTensorHook(estimator_spec.loss),
                plx_hooks.LoggingTensorHook(
                    {
                        'loss': estimator_spec.loss,
                        'step': global_step
                    },
                    every_n_iter=100)
            ])
            all_hooks.extend(hooks or [])
            # From here on `all_hooks` already contains the spec's training
            # hooks; they must not be appended a second time below, or every
            # one of them would run twice per step.
            all_hooks.extend(estimator_spec.training_hooks)
            # Normalised to a list so concatenation works whether the spec
            # stores its chief hooks as a list or as a tuple.
            training_chief_hooks = list(estimator_spec.training_chief_hooks)

            scaffold = estimator_spec.scaffold or monitored_session.Scaffold()
            if not (scaffold.saver
                    or ops.get_collection(ops.GraphKeys.SAVERS)):
                ops.add_to_collection(
                    ops.GraphKeys.SAVERS,  # TODO remove non restorable vars
                    saver.Saver(
                        sharded=True,  # TODO `var_list`
                        max_to_keep=self._config.keep_checkpoint_max,
                        defer_build=True))

            chief_hooks = []
            if self._config.save_checkpoints_secs or self._config.save_checkpoints_steps:
                # Only add a checkpoint saver when nobody supplied one.
                saver_hook_exists = any(
                    isinstance(h, plx_hooks.CheckpointSaverHook)
                    for h in all_hooks + training_chief_hooks)
                if not saver_hook_exists:
                    chief_hooks = [
                        plx_hooks.CheckpointSaverHook(
                            self._model_dir,
                            save_secs=self._config.save_checkpoints_secs,
                            save_steps=self._config.save_checkpoints_steps,
                            scaffold=scaffold)
                    ]
            with monitored_session.MonitoredTrainingSession(
                    master=self._config.master,
                    is_chief=self._config.is_chief,
                    checkpoint_dir=self._model_dir,
                    scaffold=scaffold,
                    hooks=all_hooks,
                    chief_only_hooks=chief_hooks + training_chief_hooks,
                    save_checkpoint_secs=0,  # Saving is handled by a hook.
                    save_summaries_steps=self._config.save_summary_steps,
                    config=self._session_config) as mon_sess:
                loss = None
                while not mon_sess.should_stop():
                    _, loss = mon_sess.run(
                        [estimator_spec.train_op, estimator_spec.loss])
            summary_io.SummaryWriterCache.clear()
            return loss
Exemplo n.º 11
0
    def predict(self,
                input_fn=None,
                predict_keys=None,
                hooks=None,
                checkpoint_path=None):
        """Yields predictions for the features produced by `input_fn`.

        Args:
            input_fn: Input function returning features, a dictionary of string
                feature name to `Tensor` or `SparseTensor`. If it returns a tuple,
                the first item is taken as features. Prediction continues until
                `input_fn` raises an end-of-input exception (`OutOfRangeError` or
                `StopIteration`).
            predict_keys: list of `str`, names of the keys to predict. Used when
                `EstimatorSpec.predictions` is a `dict`; all other predictions are
                filtered out of the dictionary. If `None`, everything is returned.
            hooks: List of `SessionRunHook` subclass instances, used for callbacks
                inside the prediction call.
            checkpoint_path: Path of a specific checkpoint to predict from. If
                `None`, the latest checkpoint in `model_dir` is used.

        Yields:
            Evaluated values of `predictions` tensors.

        Raises:
            ValueError: Could not find a trained model in model_dir.
            ValueError: if batch length of predictions are not same.
            ValueError: If there is a conflict between `predict_keys` and
                `predictions`, for example if `predict_keys` is not `None` but
                `EstimatorSpec.predictions` is not a `dict`.
        """
        hooks = self._check_hooks(hooks)
        # A trained model is required: fall back to the newest checkpoint.
        checkpoint_path = checkpoint_path or saver.latest_checkpoint(
            self._model_dir)
        if not checkpoint_path:
            raise ValueError("Could not find trained model at %s." %
                             self._model_dir)

        with ops.Graph().as_default() as graph:
            random_seed.set_random_seed(self._config.tf_random_seed)
            training.get_or_create_global_step(graph)
            features = self._get_features_from_input_fn(input_fn)
            spec = self._call_model_fn(features, None, ModeKeys.PREDICT)
            predictions = self._extract_keys(spec.predictions, predict_keys)

            session_creator = monitored_session.ChiefSessionCreator(
                checkpoint_filename_with_path=checkpoint_path,
                scaffold=spec.scaffold,
                config=self._session_config)
            with monitored_session.MonitoredSession(
                    session_creator=session_creator,
                    hooks=hooks) as session:
                while not session.should_stop():
                    batch = session.run(predictions)
                    if isinstance(predictions, dict):
                        # Emit one dict per example in the evaluated batch.
                        batch_length = extract_batch_length(batch)
                        for index in range(batch_length):
                            yield {
                                name: values[index]
                                for name, values in six.iteritems(batch)
                            }
                    else:
                        for example in batch:
                            yield example
Exemplo n.º 12
0
    def export_savedmodel(self,
                          export_dir_base,
                          serving_input_receiver_fn,
                          assets_extra=None,
                          as_text=False,
                          checkpoint_path=None):
        """Exports inference graph as a SavedModel into given dir.
        This method builds a new graph by first calling the
        serving_input_receiver_fn to obtain feature `Tensor`s, and then calling
        this `Estimator`'s model_fn to generate the model graph based on those
        features. It restores the given checkpoint (or, lacking that, the most
        recent checkpoint) into this graph in a fresh session.  Finally it creates
        a timestamped export directory below the given export_dir_base, and writes
        a `SavedModel` into it containing a single `MetaGraphDef` saved from this
        session.
        The exported `MetaGraphDef` will provide one `SignatureDef` for each
        element of the export_outputs dict returned from the model_fn, named using
        the same keys.  One of these keys is always
        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, indicating which
        signature will be served when a serving request does not specify one.
        For each signature, the outputs are provided by the corresponding
        `ExportOutput`s, and the inputs are always the input receivers provided by
        the serving_input_receiver_fn.
        Extra assets may be written into the SavedModel via the extra_assets
        argument.  This should be a dict, where each key gives a destination path
        (including the filename) relative to the assets.extra directory.  The
        corresponding value gives the full path of the source file to be copied.
        For example, the simple case of copying a single file without renaming it
        is specified as `{'my_asset_file.txt': '/path/to/my_asset_file.txt'}`.
        Args:
          export_dir_base: A string containing a directory in which to create
            timestamped subdirectories containing exported SavedModels.
          serving_input_receiver_fn: A function that takes no argument and
            returns a `ServingInputReceiver`.
          assets_extra: A dict specifying how to populate the assets.extra directory
            within the exported SavedModel, or `None` if no extra assets are needed.
          as_text: whether to write the SavedModel proto in text format.
          checkpoint_path: The checkpoint path to export.  If `None` (the default),
            the most recent checkpoint found within the model directory is chosen.
        Returns:
          The string path to the exported directory.
        Raises:
          ValueError: if no serving_input_receiver_fn is provided, no export_outputs
              are provided, or no checkpoint can be found.
        """
        if serving_input_receiver_fn is None:
            raise ValueError('serving_input_receiver_fn must be defined.')

        with ops.Graph().as_default() as g:
            training.get_or_create_global_step(g)
            random_seed.set_random_seed(self._config.tf_random_seed)
            serving_input_receiver = serving_input_receiver_fn()

            # Call the model_fn and collect the export_outputs.
            estimator_spec = self._call_model_fn(
                features=serving_input_receiver.features,
                labels=None,
                mode=model_fn_lib.ModeKeys.PREDICT)

            # Build the SignatureDefs from receivers and all outputs
            signature_def_map = build_all_signature_defs(
                serving_input_receiver.receiver_tensors,
                estimator_spec.export_outputs)

            if not checkpoint_path:
                # Locate the latest checkpoint
                checkpoint_path = saver.latest_checkpoint(self._model_dir)
            if not checkpoint_path:
                raise ValueError("Couldn't find trained model at %s." %
                                 self._model_dir)

            export_dir = get_timestamped_export_dir(export_dir_base)

            with tf_session.Session() as session:

                saver_for_restore = estimator_spec.scaffold.saver or saver.Saver(
                    sharded=True)
                saver_for_restore.restore(session, checkpoint_path)

                # pylint: disable=protected-access
                local_init_op = (
                    estimator_spec.scaffold.local_init_op
                    or monitored_session.Scaffold._default_local_init_op())
                # pylint: enable=protected-access

                # Perform the export
                builder = saved_model_builder.SavedModelBuilder(export_dir)
                builder.add_meta_graph_and_variables(
                    session, [tag_constants.SERVING],
                    signature_def_map=signature_def_map,
                    assets_collection=ops.get_collection(
                        ops.GraphKeys.ASSET_FILEPATHS),
                    legacy_init_op=local_init_op)
                builder.save(as_text)

            # Add the extra assets
            if assets_extra:
                assets_extra_path = os.path.join(
                    compat.as_bytes(export_dir),
                    compat.as_bytes('assets.extra'))
                for dest_relative, source in assets_extra.items():
                    dest_absolute = os.path.join(
                        compat.as_bytes(assets_extra_path),
                        compat.as_bytes(dest_relative))
                    dest_path = os.path.dirname(dest_absolute)
                    gfile.MakeDirs(dest_path)
                    gfile.Copy(source, dest_absolute)

            return export_dir
Exemplo n.º 13
0
    def _train_model(self, input_fn, hooks):
        """Build the training graph and run the train op until the session stops.

        A fresh graph is created and kept on `self._graph`. Inside it the random
        seed is set, the global step is created, `input_fn` is called under the
        `/cpu:0` device, and the model function is called in `Modes.TRAIN`.
        Checkpoint and summary saver hooks are added as chief-only hooks when the
        config asks for them and no hook of the same type is already present.

        Args:
            input_fn: Callable taking no arguments and returning a
                `(features, labels)` pair.
            hooks: Iterable of extra session hooks; they are appended after the
                built-in NaN and step-logging hooks.

        Returns:
            The loss fetched by the last executed training step, or `None` when
            the session requested a stop before any step ran.
        """
        self._graph = ops.Graph()
        with self._graph.as_default() as g, g.device(self._device_fn):
            random_seed.set_random_seed(self._config.tf_random_seed)
            global_step = training.get_or_create_global_step(g)
            with ops.device('/cpu:0'):
                features, labels = input_fn()
            estimator_spec = self._call_model_fn(features, labels, Modes.TRAIN)
            ops.add_to_collection(ops.GraphKeys.LOSSES, estimator_spec.loss)

            # Built-in hooks come first, then the caller's, then the model's.
            logged_tensors = {
                'loss': estimator_spec.loss,
                'step': global_step
            }
            all_hooks = [
                plx_hooks.NanTensorHook(estimator_spec.loss),
                plx_hooks.StepLoggingTensorHook(logged_tensors, every_n_iter=100),
            ]
            all_hooks.extend(hooks)
            all_hooks.extend(estimator_spec.training_hooks)

            scaffold = estimator_spec.scaffold or monitored_session.Scaffold()
            # Register a default saver only when neither the scaffold nor the
            # graph's SAVERS collection already provides one.
            if not scaffold.saver and not ops.get_collection(ops.GraphKeys.SAVERS):
                # TODO remove non restorable vars
                # TODO `var_list`
                default_saver = saver.Saver(
                    sharded=True,
                    max_to_keep=self._config.keep_checkpoint_max,
                    defer_build=True)
                ops.add_to_collection(ops.GraphKeys.SAVERS, default_saver)

            chief_hooks = []

            def _hook_registered(hook_type):
                # Search every hook known so far, chief-only ones included.
                known = (all_hooks +
                         chief_hooks +
                         list(estimator_spec.training_chief_hooks))
                return any(isinstance(h, hook_type) for h in known)

            wants_checkpoints = (self._config.save_checkpoints_secs or
                                 self._config.save_checkpoints_steps)
            if wants_checkpoints:
                if not _hook_registered(plx_hooks.StepCheckpointSaverHook):
                    chief_hooks.append(
                        plx_hooks.StepCheckpointSaverHook(
                            self._model_dir,
                            save_secs=self._config.save_checkpoints_secs,
                            save_steps=self._config.save_checkpoints_steps,
                            scaffold=scaffold))

            if self._config.save_summary_steps:
                if not _hook_registered(plx_hooks.StepSummarySaverHook):
                    chief_hooks.append(
                        plx_hooks.StepSummarySaverHook(
                            scaffold=scaffold,
                            save_steps=self._config.save_summary_steps,
                            output_dir=self._model_dir,
                        ))

            # The session's own checkpoint/summary saving is switched off (0)
            # because the hooks above are responsible for both.
            session_kwargs = dict(
                master=self._config.master,
                is_chief=self._config.is_chief,
                checkpoint_dir=self._model_dir,
                scaffold=scaffold,
                hooks=all_hooks,
                chief_only_hooks=chief_hooks + list(estimator_spec.training_chief_hooks),
                save_checkpoint_secs=0,
                save_summaries_steps=0,
                config=self._session_config)
            fetches = [estimator_spec.train_op, estimator_spec.loss]
            with monitored_session.MonitoredTrainingSession(**session_kwargs) as mon_sess:
                last_loss = None
                while not mon_sess.should_stop():
                    _, last_loss = mon_sess.run(fetches)
            summary_io.SummaryWriterCache.clear()
            return last_loss
Exemplo n.º 14
0
    def export_savedmodel(self, export_dir_base, serving_input_receiver_fn, assets_extra=None,
                          as_text=False, checkpoint_path=None):
        """Export the inference graph as a SavedModel under `export_dir_base`.

        A new graph is built by calling `serving_input_receiver_fn` to obtain the
        feature tensors and then calling this estimator's model function in
        `Modes.PREDICT` on those features. The given checkpoint (or, when none is
        given, the latest checkpoint in the model directory) is restored into a
        fresh session, and a SavedModel tagged for serving is written into a
        timestamped directory below `export_dir_base`.

        The signature map is built from the receiver tensors and the model's
        `export_outputs`.

        Extra files can be shipped with the model through `assets_extra`: each key
        is a destination path (including the filename) relative to the
        `assets.extra` directory and each value is the full path of the source
        file to copy, e.g. `{'my_asset_file.txt': '/path/to/my_asset_file.txt'}`.

        Args:
            export_dir_base: A string containing a directory in which to create
                timestamped subdirectories containing exported SavedModels.
            serving_input_receiver_fn: A function that takes no argument and
                returns a `ServingInputReceiver`.
            assets_extra: A dict specifying how to populate the assets.extra
                directory within the exported SavedModel, or `None` if no extra
                assets are needed.
            as_text: whether to write the SavedModel proto in text format.
            checkpoint_path: The checkpoint path to export. If `None` (the
                default), the most recent checkpoint found within the model
                directory is chosen.

        Returns:
            The string path to the exported directory.

        Raises:
            ValueError: if `serving_input_receiver_fn` is `None` or no checkpoint
                can be found.
        """
        if serving_input_receiver_fn is None:
            raise ValueError('serving_input_receiver_fn must be defined.')

        with ops.Graph().as_default() as g:
            training.get_or_create_global_step(g)
            random_seed.set_random_seed(self._config.tf_random_seed)
            receiver = serving_input_receiver_fn()

            # Run the model function in PREDICT mode to obtain the export outputs.
            spec = self._call_model_fn(
                features=receiver.features, labels=None, mode=Modes.PREDICT)

            # One signature per export output, all fed by the same receivers.
            signatures = build_all_signature_defs(
                receiver.receiver_tensors, spec.export_outputs)

            # Fall back to the newest checkpoint in the model directory.
            checkpoint_path = checkpoint_path or saver.latest_checkpoint(self._model_dir)
            if not checkpoint_path:
                raise ValueError("Couldn't find trained model at %s." % self._model_dir)

            export_dir = get_timestamped_export_dir(export_dir_base)

            with tf_session.Session() as session:
                restorer = spec.scaffold.saver
                if not restorer:
                    restorer = saver.Saver(sharded=True)
                restorer.restore(session, checkpoint_path)

                init_op = spec.scaffold.local_init_op
                if not init_op:
                    init_op = monitored_session.Scaffold._default_local_init_op()

                # Write the meta graph and variables, then the SavedModel proto.
                asset_tensors = ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS)
                model_builder = saved_model_builder.SavedModelBuilder(export_dir)
                model_builder.add_meta_graph_and_variables(
                    session, [tag_constants.SERVING],
                    signature_def_map=signatures,
                    assets_collection=asset_tensors,
                    legacy_init_op=init_op)
                model_builder.save(as_text)

            # Copy any extra assets below <export_dir>/assets.extra.
            if assets_extra:
                extra_root = os.path.join(
                    compat.as_bytes(export_dir), compat.as_bytes('assets.extra'))
                for relative_dest, source_file in assets_extra.items():
                    target = os.path.join(
                        compat.as_bytes(extra_root), compat.as_bytes(relative_dest))
                    gfile.MakeDirs(os.path.dirname(target))
                    gfile.Copy(source_file, target)

            return export_dir
    def test_build_ensemble_spec(
            self,
            want_logits,
            want_loss=None,
            want_adanet_loss=None,
            want_ensemble_trainable_vars=None,
            adanet_lambda=0.,
            adanet_beta=0.,
            ensemble_spec_fn=lambda: None,
            use_bias=False,
            use_logits_last_layer=False,
            mixture_weight_type=MixtureWeightType.MATRIX,
            mixture_weight_initializer=tf_compat.v1.zeros_initializer(),
            warm_start_mixture_weights=True,
            subnetwork_builder_class=_Builder,
            mode=tf.estimator.ModeKeys.TRAIN,
            multi_head=False,
            want_subnetwork_trainable_vars=2,
            ensembler_class=ComplexityRegularizedEnsembler,
            my_ensemble_index=None,
            want_replay_indices=None,
            want_predictions=None,
            export_subnetworks=False,
            previous_ensemble_spec=None,
            previous_iteration_checkpoint=None):
        """Builds a subnetwork spec and an ensemble spec and checks the result.

        The specs are built in a fresh graph. The test then verifies that the
        real global step and summary functions are visible outside the builder
        scopes. In PREDICT mode it checks logits and export outputs and returns;
        otherwise it runs three training steps and checks that the loss went
        down, then compares logits, bias, loss, AdaNet loss and predictions with
        the expected values.

        Args:
            want_logits: Expected value of `ensemble_spec.ensemble.logits`
                (compared with `atol=1e-3`).
            want_loss: Expected `ensemble_spec.loss` after training (3 places).
            want_adanet_loss: Expected `ensemble_spec.adanet_loss` after training
                (3 places).
            want_ensemble_trainable_vars: Expected number of variables handed to
                the mixture-weights train-op fn. NOTE(review): the default `None`
                cannot be added to an int in the trainable-variable count check
                below, so callers presumably always supply it - confirm.
            adanet_lambda: Forwarded to the ensembler when `ensembler_class` is
                `ComplexityRegularizedEnsembler`.
            adanet_beta: Same as `adanet_lambda`.
            ensemble_spec_fn: Zero-argument callable returning the previous
                ensemble spec, or a falsy value when there is none.
            use_bias: Forwarded to `ComplexityRegularizedEnsembler`; also selects
                whether the learned bias is expected to be non-zero.
            use_logits_last_layer: Forwarded positionally to
                `subnetwork_builder_class`.
            mixture_weight_type: Forwarded to `ComplexityRegularizedEnsembler`.
            mixture_weight_initializer: Forwarded to
                `ComplexityRegularizedEnsembler`. The default is evaluated once,
                when the method is defined.
            warm_start_mixture_weights: Forwarded to
                `ComplexityRegularizedEnsembler`.
            subnetwork_builder_class: Class used to construct the subnetwork
                builder under test.
            mode: Estimator mode passed to both spec builders.
            multi_head: When true, a two-head `MultiHead` and dict labels are
                used instead of a single binary head.
            want_subnetwork_trainable_vars: Expected number of variables handed
                to the subnetwork train-op fn.
            ensembler_class: Ensembler class to instantiate.
            my_ensemble_index: Forwarded to `build_ensemble_spec`.
            want_replay_indices: When truthy, compared with
                `ensemble_spec.architecture.replay_indices`.
            want_predictions: When truthy, compared with
                `ensemble_spec.ensemble.predictions` (`atol=1e-3`).
            export_subnetworks: Passed to `_EnsembleBuilder` as both
                `export_subnetwork_logits` and `export_subnetwork_last_layer`;
                in PREDICT mode it also enables the export-output checks.
            previous_ensemble_spec: NOTE(review): this argument is overwritten by
                `ensemble_spec_fn()` before it is read, so the passed value is
                ignored - confirm whether that is intended.
            previous_iteration_checkpoint: Forwarded to `build_ensemble_spec`.
        """
        seed = 64

        if multi_head:
            head = multi_head_lib.MultiHead(heads=[
                binary_class_head.BinaryClassHead(
                    name="head1", loss_reduction=tf_compat.SUM),
                binary_class_head.BinaryClassHead(name="head2",
                                                  loss_reduction=tf_compat.SUM)
            ])
        else:
            head = binary_class_head.BinaryClassHead(
                loss_reduction=tf_compat.SUM)
        builder = _EnsembleBuilder(
            head=head,
            export_subnetwork_logits=export_subnetworks,
            export_subnetwork_last_layer=export_subnetworks)

        def _subnetwork_train_op_fn(loss, var_list):
            # The subnetwork must be handed exactly the graph's trainable
            # variables at this point.
            self.assertLen(var_list, want_subnetwork_trainable_vars)
            self.assertEqual(
                var_list,
                tf_compat.v1.get_collection(
                    tf_compat.v1.GraphKeys.TRAINABLE_VARIABLES))
            # Subnetworks get iteration steps instead of global steps.
            self.assertEqual("subnetwork_test/iteration_step",
                             tf_compat.v1.train.get_global_step().op.name)

            # Subnetworks get scoped summaries: the summary calls are expected
            # to return the fake values (presumably from `_FakeSummary`).
            self.assertEqual("fake_scalar",
                             tf_compat.v1.summary.scalar("scalar", 1.))
            self.assertEqual("fake_image",
                             tf_compat.v1.summary.image("image", 1.))
            self.assertEqual("fake_histogram",
                             tf_compat.v1.summary.histogram("histogram", 1.))
            self.assertEqual("fake_audio",
                             tf_compat.v1.summary.audio("audio", 1., 1.))
            optimizer = tf_compat.v1.train.GradientDescentOptimizer(
                learning_rate=.1)
            return optimizer.minimize(loss, var_list=var_list)

        def _mixture_weights_train_op_fn(loss, var_list):
            # Same checks as the subnetwork train-op fn, but for the ensemble's
            # mixture weights.
            self.assertLen(var_list, want_ensemble_trainable_vars)
            self.assertEqual(
                var_list,
                tf_compat.v1.get_collection(
                    tf_compat.v1.GraphKeys.TRAINABLE_VARIABLES))
            # The ensemble scope gets its own iteration step instead of the
            # global step.
            self.assertEqual("ensemble_test/iteration_step",
                             tf_compat.v1.train.get_global_step().op.name)

            # Summary calls made here are expected to return the fake values too.
            self.assertEqual("fake_scalar",
                             tf_compat.v1.summary.scalar("scalar", 1.))
            self.assertEqual("fake_image",
                             tf_compat.v1.summary.image("image", 1.))
            self.assertEqual("fake_histogram",
                             tf_compat.v1.summary.histogram("histogram", 1.))
            self.assertEqual("fake_audio",
                             tf_compat.v1.summary.audio("audio", 1., 1.))
            # Nothing to optimize when there are no mixture-weight variables.
            if not var_list:
                return tf.no_op()
            optimizer = tf_compat.v1.train.GradientDescentOptimizer(
                learning_rate=.1)
            return optimizer.minimize(loss, var_list=var_list)

        previous_ensemble = None
        # NOTE(review): this rebinds the `previous_ensemble_spec` parameter, so
        # whatever the caller passed for it is discarded.
        previous_ensemble_spec = ensemble_spec_fn()
        if previous_ensemble_spec:
            previous_ensemble = previous_ensemble_spec.ensemble

        subnetwork_manager = _SubnetworkManager(head)
        subnetwork_builder = subnetwork_builder_class(
            _subnetwork_train_op_fn,
            _mixture_weights_train_op_fn,
            use_logits_last_layer,
            seed,
            multi_head=multi_head)

        with tf.Graph().as_default() as g:
            tf_compat.v1.train.get_or_create_global_step()
            # A trainable variable to later verify that creating models does not
            # affect the global variables collection.
            _ = tf_compat.v1.get_variable("some_var", shape=0, trainable=True)

            features = {"x": tf.constant([[1.], [2.]])}
            if multi_head:
                # One label tensor per head, keyed by head name.
                labels = {
                    "head1": tf.constant([0, 1]),
                    "head2": tf.constant([0, 1])
                }
            else:
                labels = tf.constant([0, 1])

            session_config = tf.compat.v1.ConfigProto(
                gpu_options=tf.compat.v1.GPUOptions(allow_growth=True))

            subnetwork_spec = subnetwork_manager.build_subnetwork_spec(
                name="test",
                subnetwork_builder=subnetwork_builder,
                summary=_FakeSummary(),
                features=features,
                mode=mode,
                labels=labels,
                previous_ensemble=previous_ensemble)
            # Only pass constructor arguments the chosen ensembler class is
            # given here; other classes are constructed with no kwargs.
            ensembler_kwargs = {}
            if ensembler_class is ComplexityRegularizedEnsembler:
                ensembler_kwargs.update({
                    "mixture_weight_type": mixture_weight_type,
                    "mixture_weight_initializer": mixture_weight_initializer,
                    "warm_start_mixture_weights": warm_start_mixture_weights,
                    "model_dir": self.test_subdirectory,
                    "adanet_lambda": adanet_lambda,
                    "adanet_beta": adanet_beta,
                    "use_bias": use_bias
                })
            if ensembler_class is MeanEnsembler:
                ensembler_kwargs.update(
                    {"add_mean_last_layer_predictions": True})
            ensemble_spec = builder.build_ensemble_spec(
                # Note: when ensemble_spec is not None and warm_start_mixture_weights
                # is True, we need to make sure that the bias and mixture weights are
                # already saved to the checkpoint_dir.
                name="test",
                previous_ensemble_spec=previous_ensemble_spec,
                candidate=EnsembleCandidate("foo", [subnetwork_builder], None),
                ensembler=ensembler_class(**ensembler_kwargs),
                subnetwork_specs=[subnetwork_spec],
                summary=_FakeSummary(),
                features=features,
                iteration_number=1,
                labels=labels,
                my_ensemble_index=my_ensemble_index,
                mode=mode,
                previous_iteration_checkpoint=previous_iteration_checkpoint)

            # Replay indices are only checked when the caller supplied a truthy
            # expectation.
            if want_replay_indices:
                self.assertAllEqual(want_replay_indices,
                                    ensemble_spec.architecture.replay_indices)

            with tf_compat.v1.Session(
                    graph=g, config=session_config).as_default() as sess:
                sess.run(tf_compat.v1.global_variables_initializer())

                # Equals the number of subnetwork and ensemble trainable variables,
                # plus the one 'some_var' created earlier.
                # NOTE(review): raises TypeError if want_ensemble_trainable_vars
                # is left at its default of None.
                self.assertLen(
                    tf_compat.v1.trainable_variables(),
                    want_subnetwork_trainable_vars +
                    want_ensemble_trainable_vars + 1)

                # Get the real global step outside a subnetwork's context.
                # Every access path to the global-step getters is checked.
                self.assertEqual("global_step",
                                 tf_compat.v1.train.get_global_step().op.name)
                self.assertEqual("global_step",
                                 train.get_global_step().op.name)
                self.assertEqual("global_step",
                                 tf_v1.train.get_global_step().op.name)
                self.assertEqual("global_step",
                                 training_util.get_global_step().op.name)
                self.assertEqual(
                    "global_step",
                    tf_compat.v1.train.get_or_create_global_step().op.name)
                self.assertEqual("global_step",
                                 train.get_or_create_global_step().op.name)
                self.assertEqual(
                    "global_step",
                    tf_v1.train.get_or_create_global_step().op.name)
                self.assertEqual(
                    "global_step",
                    training_util.get_or_create_global_step().op.name)

                # Get global tf.summary outside a subnetwork's context.
                self.assertNotEqual("fake_scalar",
                                    tf_compat.v1.summary.scalar("scalar", 1.))
                self.assertNotEqual("fake_image",
                                    tf_compat.v1.summary.image("image", 1.))
                self.assertNotEqual(
                    "fake_histogram",
                    tf_compat.v1.summary.histogram("histogram", 1.))
                self.assertNotEqual(
                    "fake_audio", tf_compat.v1.summary.audio("audio", 1., 1.))

                # PREDICT mode: no loss or train op is expected; check logits
                # and (optionally) the exported subnetwork outputs, then stop.
                if mode == tf.estimator.ModeKeys.PREDICT:
                    self.assertAllClose(want_logits,
                                        sess.run(
                                            ensemble_spec.ensemble.logits),
                                        atol=1e-3)
                    self.assertIsNone(ensemble_spec.loss)
                    self.assertIsNone(ensemble_spec.adanet_loss)
                    self.assertIsNone(ensemble_spec.train_op)
                    self.assertIsNotNone(ensemble_spec.export_outputs)
                    if not export_subnetworks:
                        return
                    if not multi_head:
                        # Single head: exported logits/last layer for subnetwork
                        # "test" must match the subnetwork's own tensors.
                        subnetwork_logits = sess.run(
                            ensemble_spec.export_outputs[
                                _EnsembleBuilder.
                                _SUBNETWORK_LOGITS_EXPORT_SIGNATURE].outputs)
                        self.assertAllClose(
                            subnetwork_logits["test"],
                            sess.run(subnetwork_spec.subnetwork.logits))
                        subnetwork_last_layer = sess.run(
                            ensemble_spec.export_outputs[
                                _EnsembleBuilder.
                                _SUBNETWORK_LAST_LAYER_EXPORT_SIGNATURE].
                            outputs)
                        self.assertAllClose(
                            subnetwork_last_layer["test"],
                            sess.run(subnetwork_spec.subnetwork.last_layer))
                    else:
                        # Multi head: export signatures are keyed per head; only
                        # head1's values are compared.
                        self.assertIn("subnetwork_logits_head2",
                                      ensemble_spec.export_outputs)
                        subnetwork_logits_head1 = sess.run(
                            ensemble_spec.
                            export_outputs["subnetwork_logits_head1"].outputs)
                        self.assertAllClose(
                            subnetwork_logits_head1["test"],
                            sess.run(
                                subnetwork_spec.subnetwork.logits["head1"]))
                        # NOTE(review): this repeats the "subnetwork_logits_head2"
                        # check above; a last-layer key may have been intended
                        # here - confirm.
                        self.assertIn("subnetwork_logits_head2",
                                      ensemble_spec.export_outputs)
                        subnetwork_last_layer_head1 = sess.run(
                            ensemble_spec.export_outputs[
                                "subnetwork_last_layer_head1"].outputs)
                        self.assertAllClose(
                            subnetwork_last_layer_head1["test"],
                            sess.run(subnetwork_spec.subnetwork.
                                     last_layer["head1"]))
                    return

                # Verify that train_op works, previous loss should be greater than loss
                # after a train op.
                loss = sess.run(ensemble_spec.loss)
                train_op = tf.group(subnetwork_spec.train_op.train_op,
                                    ensemble_spec.train_op.train_op)
                for _ in range(3):
                    sess.run(train_op)
                self.assertGreater(loss, sess.run(ensemble_spec.loss))

                self.assertAllClose(want_logits,
                                    sess.run(ensemble_spec.ensemble.logits),
                                    atol=1e-3)

                if ensembler_class is ComplexityRegularizedEnsembler:
                    # Bias should learn a non-zero value when used.
                    bias = sess.run(ensemble_spec.ensemble.bias)
                    if isinstance(bias, dict):
                        # Collapse a dict of bias values into one magnitude.
                        bias = sum(abs(b) for b in bias.values())
                    if use_bias:
                        self.assertNotEqual(0., bias)
                    else:
                        self.assertAlmostEqual(0., bias)

                self.assertAlmostEqual(want_loss,
                                       sess.run(ensemble_spec.loss),
                                       places=3)
                self.assertAlmostEqual(want_adanet_loss,
                                       sess.run(ensemble_spec.adanet_loss),
                                       places=3)

                # Predictions are only checked when an expectation was supplied.
                if want_predictions:
                    self.assertAllClose(
                        want_predictions,
                        sess.run(ensemble_spec.ensemble.predictions),
                        atol=1e-3)
Exemplo n.º 16
0
    def _train_model(self, input_fn, hooks):
        """Build the training graph and run the train op until the session stops.

        Inside a fresh graph the random seed is set, the global step is created,
        features and labels are obtained through
        `_get_features_and_labels_from_input_fn`, and the model function is
        called in `Modes.TRAIN`. Checkpoint and summary saver hooks are added as
        chief-only hooks when the config asks for them and no hook of the same
        type is already present.

        Args:
            input_fn: Input function handed to
                `_get_features_and_labels_from_input_fn`.
            hooks: Iterable of extra session hooks; they are appended after the
                built-in NaN and step-logging hooks.

        Returns:
            The loss fetched by the last executed training step, or `None` when
            the session requested a stop before any step ran.
        """
        all_hooks = []
        with ops.Graph().as_default() as g, g.device(self._device_fn):
            random_seed.set_random_seed(self._config.tf_random_seed)
            global_step = training.get_or_create_global_step(g)
            features, labels = self._get_features_and_labels_from_input_fn(
                input_fn, Modes.TRAIN)
            estimator_spec = self._call_model_fn(features, labels, Modes.TRAIN)
            ops.add_to_collection(ops.GraphKeys.LOSSES, estimator_spec.loss)
            all_hooks.extend([
                plx_hooks.NanTensorHook(estimator_spec.loss),
                plx_hooks.StepLoggingTensorHook(
                    {
                        'loss': estimator_spec.loss,
                        'step': global_step
                    },
                    every_n_iter=100)
            ])
            all_hooks.extend(hooks)
            all_hooks.extend(estimator_spec.training_hooks)

            # Fall back to a default scaffold so that a spec without one does
            # not fail on `scaffold.saver` below.
            scaffold = estimator_spec.scaffold or monitored_session.Scaffold()
            # Register a default saver only when neither the scaffold nor the
            # graph's SAVERS collection already provides one.
            if not (scaffold.saver
                    or ops.get_collection(ops.GraphKeys.SAVERS)):
                ops.add_to_collection(
                    ops.GraphKeys.SAVERS,  # TODO remove non restorable vars
                    saver.Saver(
                        sharded=True,
                        max_to_keep=self._config.keep_checkpoint_max,
                        keep_checkpoint_every_n_hours=(
                            self._config.keep_checkpoint_every_n_hours),
                        defer_build=True,
                        save_relative_paths=True))

            # Read the spec's chief hooks once so the same hooks are used for
            # the duplicate checks and for the session itself.
            spec_chief_hooks = list(estimator_spec.training_chief_hooks)
            chief_hooks = []

            def _hook_registered(hook_type):
                # Search every hook known so far, chief-only ones included.
                return any(
                    isinstance(h, hook_type)
                    for h in all_hooks + chief_hooks + spec_chief_hooks)

            if self._config.save_checkpoints_secs or self._config.save_checkpoints_steps:
                if not _hook_registered(plx_hooks.StepCheckpointSaverHook):
                    chief_hooks.append(
                        plx_hooks.StepCheckpointSaverHook(
                            self._model_dir,
                            save_secs=self._config.save_checkpoints_secs,
                            save_steps=self._config.save_checkpoints_steps,
                            scaffold=scaffold))
            if self._config.save_summary_steps:
                if not _hook_registered(plx_hooks.StepSummarySaverHook):
                    chief_hooks.append(
                        plx_hooks.StepSummarySaverHook(
                            scaffold=scaffold,
                            save_steps=self._config.save_summary_steps,
                            output_dir=self._model_dir,
                        ))

            with monitored_session.MonitoredTrainingSession(
                    master=self._config.master,
                    is_chief=self._config.is_chief,
                    checkpoint_dir=self._model_dir,
                    scaffold=scaffold,
                    hooks=all_hooks,
                    chief_only_hooks=chief_hooks + spec_chief_hooks,
                    save_checkpoint_secs=0,  # Saving checkpoint is handled by a hook.
                    save_summaries_steps=0,  # Saving summaries is handled by a hook.
                    config=self._session_config) as mon_sess:
                loss = None
                while not mon_sess.should_stop():
                    _, loss = mon_sess.run(
                        [estimator_spec.train_op, estimator_spec.loss])
            return loss
Exemplo n.º 17
0
 def _global_step(self):
     """Return the result of `training.get_or_create_global_step()`."""
     step = training.get_or_create_global_step()
     return step
 def bad_input_fn():
   """Create the global step, then return a one-element (features, labels) dataset.

   NOTE(review): presumably "bad" because the input function itself creates
   the global step - confirm against the calling test.
   """
   training.get_or_create_global_step()
   features = {'x': constant_op.constant([[1], [1]], dtype=dtypes.int64)}
   labels = constant_op.constant([[1], [1]], dtype=dtypes.float32)
   return dataset_ops.Dataset.from_tensors((features, labels))
Exemplo n.º 19
0
    def _train_model(self, env, hooks):
        """Build the training graph, then run episodes until the session stops.

        Args:
            env: Passed to `_prepare_input_fn` and to every `run_episode` call.
            hooks: Iterable of hooks added to the session's `hooks` list, after
                the NaN and step-logging hooks created here.

        Returns:
            The result of the last `run_episode` call, or `None` when the
            session reports `should_stop()` before the first episode.
        """
        self._graph = ops.Graph()
        with self._graph.as_default() as graph, graph.device(self._device_fn):
            random_seed.set_random_seed(self._config.tf_random_seed)

            # Counters, plus the ops that advance the episode/timestep ones.
            step = training.get_or_create_global_step(graph)
            episode = get_or_create_global_episode(graph)
            timestep = get_or_create_global_timestep(graph)
            bump_episode = tf.assign_add(episode, 1)
            bump_timestep = tf.assign_add(timestep, 1)
            noop = tf.no_op(name='no_run_hooks')

            with ops.device('/cpu:0'):
                features, labels = self._prepare_input_fn(Modes.TRAIN, env)
            spec = self._call_model_fn(features, labels, Modes.TRAIN)
            ops.add_to_collection(ops.GraphKeys.LOSSES, spec.loss)

            # Hooks handed to the session through its `hooks` argument.
            nan_guard = plx_hooks.NanTensorHook(spec.loss)
            step_logger = plx_hooks.StepLoggingTensorHook(
                {
                    'loss': spec.loss,
                    'step': step,
                    'timestep': timestep,
                    'global_episode': episode,
                    'max_reward': labels['max_reward'],
                    'min_reward': labels['min_reward'],
                    'total_reward': labels['total_reward'],
                },
                every_n_iter=100)
            session_hooks = [nan_guard, step_logger]
            session_hooks.extend(hooks)
            session_hooks.extend(spec.training_hooks)

            scaffold = spec.scaffold or monitored_session.Scaffold()
            if not scaffold.saver and not ops.get_collection(ops.GraphKeys.SAVERS):
                # TODO remove non restorable vars
                # TODO `var_list`
                default_saver = saver.Saver(
                    sharded=True,
                    max_to_keep=self._config.keep_checkpoint_max,
                    defer_build=True)
                ops.add_to_collection(ops.GraphKeys.SAVERS, default_saver)

            # Hooks handed to the session through `chief_only_hooks`.
            episode_logger = plx_hooks.EpisodeLoggingTensorHook(
                {
                    'loss': spec.loss,
                    'step': step,
                    'global_timestep': timestep,
                    'global_episode': episode,
                    'max_reward': labels['max_reward'],
                    'min_reward': labels['min_reward'],
                    'total_reward': labels['total_reward'],
                },
                every_n_episodes=100)  # TODO: save every episode?
            episode_counter = plx_hooks.EpisodeCounterHook(output_dir=self.model_dir)
            chief_only = [episode_logger, episode_counter]

            def _already_has(hook_type):
                # Scan every hook gathered so far for an instance of `hook_type`.
                gathered = (session_hooks +
                            chief_only +
                            list(spec.training_chief_hooks))
                return any(isinstance(hook, hook_type) for hook in gathered)

            checkpoints_requested = (self._config.save_checkpoints_secs or
                                     self._config.save_checkpoints_steps)
            if checkpoints_requested and not _already_has(
                    plx_hooks.EpisodeCheckpointSaverHook):
                chief_only.append(
                    plx_hooks.EpisodeCheckpointSaverHook(
                        self._model_dir,
                        save_episodes=100,  # TODO: save every episode?
                        scaffold=scaffold))

            if self._config.save_summary_steps and not _already_has(
                    plx_hooks.EpisodeSummarySaverHook):
                chief_only.append(
                    plx_hooks.EpisodeSummarySaverHook(
                        scaffold=scaffold,
                        save_episodes=100,  # TODO: save every episode?
                        output_dir=self._model_dir,
                    ))

            session_args = dict(
                master=self._config.master,
                is_chief=self._config.is_chief,
                checkpoint_dir=self._model_dir,
                scaffold=scaffold,
                hooks=session_hooks,
                chief_only_hooks=chief_only + list(spec.training_chief_hooks),
                # Saving checkpoint is handled by a hook.
                save_checkpoint_secs=0,
                # Saving summaries is handled by a hook.
                save_summaries_steps=0,
                config=self._session_config)
            with monitored_session.MonitoredTrainingSession(**session_args) as session:
                last_loss = None
                while not session.should_stop():
                    last_loss = self.run_episode(
                        env=env,
                        sess=session,
                        features=features,
                        labels=labels,
                        no_run_hooks=noop,
                        global_step=step,
                        update_episode_op=bump_episode,
                        update_timestep_op=bump_timestep,
                        estimator_spec=spec)
            summary_io.SummaryWriterCache.clear()
            return last_loss
Exemplo n.º 20
0
 def bad_input_fn():
     """Input function that calls the global-step helper before its dataset.

     Returns:
         The result of `dataset_ops.Dataset.from_tensors` applied to a
         `(features, labels)` tuple, where `features` maps 'x' to an int64
         constant and `labels` is a float32 constant.
     """
     # Called only for its effect; the returned value is discarded.
     training.get_or_create_global_step()
     feature_map = {'x': constant_op.constant([[1], [1]], dtype=dtypes.int64)}
     label_tensor = constant_op.constant([[1], [1]], dtype=dtypes.float32)
     return dataset_ops.Dataset.from_tensors((feature_map, label_tensor))
Exemplo n.º 21
0
    def _train_model(self, env, first_update, update_frequency, hooks):
        """Build the training graph, then run episodes until the session stops.

        Args:
            env: Passed to `_prepare_input_fn` and to every `run_episode` call.
            first_update: Forwarded unchanged to `run_episode`.
            update_frequency: Forwarded unchanged to `run_episode`.
            hooks: Iterable of hooks added to the session's `hooks` list, after
                the NaN and step-logging hooks created here.

        Returns:
            The result of the last `run_episode` call, or `None` when the
            session reports `should_stop()` before the first episode.
        """
        self._graph = ops.Graph()
        with self._graph.as_default() as graph, graph.device(self._device_fn):
            random_seed.set_random_seed(self._config.tf_random_seed)

            # Counters, plus the ops that advance the episode/timestep ones.
            step = training.get_or_create_global_step(graph)
            episode = get_or_create_global_episode(graph)
            timestep = get_or_create_global_timestep(graph)
            bump_episode = tf.assign_add(episode, 1)
            bump_timestep = tf.assign_add(timestep, 1)
            noop = tf.no_op(name='no_run_hooks')

            with ops.device('/cpu:0'):
                features, labels = self._prepare_input_fn(Modes.TRAIN, env)
            spec = self._call_model_fn(features, labels, Modes.TRAIN)
            ops.add_to_collection(ops.GraphKeys.LOSSES, spec.loss)

            # Hooks handed to the session through its `hooks` argument.
            nan_guard = plx_hooks.NanTensorHook(spec.loss)
            step_logger = plx_hooks.StepLoggingTensorHook(
                {
                    'loss': spec.loss,
                    'step': step,
                    'timestep': timestep,
                    'global_episode': episode,
                    'max_reward': labels['max_reward'],
                    'min_reward': labels['min_reward'],
                    'total_reward': labels['total_reward'],
                },
                every_n_iter=100)
            session_hooks = [nan_guard, step_logger]
            session_hooks.extend(hooks)
            session_hooks.extend(spec.training_hooks)

            scaffold = spec.scaffold or monitored_session.Scaffold()
            if not scaffold.saver and not ops.get_collection(ops.GraphKeys.SAVERS):
                # TODO remove non restorable vars
                # TODO `var_list`
                default_saver = saver.Saver(
                    sharded=True,
                    max_to_keep=self._config.keep_checkpoint_max,
                    defer_build=True)
                ops.add_to_collection(ops.GraphKeys.SAVERS, default_saver)

            # Hooks handed to the session through `chief_only_hooks`.
            episode_logger = plx_hooks.EpisodeLoggingTensorHook(
                {
                    'loss': spec.loss,
                    'step': step,
                    'global_timestep': timestep,
                    'global_episode': episode,
                    'max_reward': labels['max_reward'],
                    'min_reward': labels['min_reward'],
                    'total_reward': labels['total_reward'],
                },
                every_n_episodes=1)  # TODO: save every episode?
            episode_counter = plx_hooks.EpisodeCounterHook(output_dir=self.model_dir)
            chief_only = [episode_logger, episode_counter]

            def _already_has(hook_type):
                # Scan every hook gathered so far for an instance of `hook_type`.
                gathered = (session_hooks + chief_only +
                            list(spec.training_chief_hooks))
                return any(isinstance(hook, hook_type) for hook in gathered)

            checkpoints_requested = (self._config.save_checkpoints_secs or
                                     self._config.save_checkpoints_steps)
            if checkpoints_requested and not _already_has(
                    plx_hooks.EpisodeCheckpointSaverHook):
                chief_only.append(
                    plx_hooks.EpisodeCheckpointSaverHook(
                        self._model_dir,
                        save_episodes=1,  # TODO: save every episode?
                        scaffold=scaffold))

            if self._config.save_summary_steps and not _already_has(
                    plx_hooks.EpisodeSummarySaverHook):
                chief_only.append(
                    plx_hooks.EpisodeSummarySaverHook(
                        scaffold=scaffold,
                        save_episodes=1,  # TODO: save every episode?
                        output_dir=self._model_dir,
                    ))

            session_args = dict(
                master=self._config.master,
                is_chief=self._config.is_chief,
                checkpoint_dir=self._model_dir,
                scaffold=scaffold,
                hooks=session_hooks,
                chief_only_hooks=chief_only + list(spec.training_chief_hooks),
                # Saving checkpoint is handled by a hook.
                save_checkpoint_secs=0,
                # Saving summaries is handled by a hook.
                save_summaries_steps=0,
                config=self._session_config)
            with monitored_session.MonitoredTrainingSession(**session_args) as session:
                last_loss = None
                while not session.should_stop():
                    last_loss = self.run_episode(
                        env=env,
                        sess=session,
                        features=features,
                        labels=labels,
                        no_run_hooks=noop,
                        global_step=step,
                        update_episode_op=bump_episode,
                        update_timestep_op=bump_timestep,
                        first_update=first_update,
                        update_frequency=update_frequency,
                        estimator_spec=spec)
            summary_io.SummaryWriterCache.clear()
            return last_loss