Esempio n. 1
0
        def _model_fn(features, labels, params, mode, config):
            """Build the ``EstimatorSpec`` matching the requested estimator mode.

            Args:
              features: Input features handed over by the estimator.
              labels: Input labels handed over by the estimator.
              params: Hyperparameters forwarded to the model, to
                ``self._initializer`` and to ``optimize``.
              mode: A ``tf.estimator.ModeKeys`` value.
              config: Run configuration forwarded to the model.

            Returns:
              A ``tf.estimator.EstimatorSpec`` for ``mode``.

            Raises:
              RuntimeError: if ``mode`` is neither TRAIN, EVAL nor PREDICT.
            """
            if mode == tf.estimator.ModeKeys.TRAIN:
                self._register_word_counters(features, labels)

                # Shard the inputs with the dispatcher, then run ``_loss_op``
                # through it; only this branch sets a scope initializer.
                sharded_features = dispatcher.shard(features)
                sharded_labels = dispatcher.shard(labels)
                scope_initializer = self._initializer(params)
                with tf.variable_scope(self.name, initializer=scope_initializer):
                    shard_losses = dispatcher(
                        _loss_op, sharded_features, sharded_labels,
                        params, mode, config)

                train_loss = _extract_loss(shard_losses)
                return tf.estimator.EstimatorSpec(
                    mode, loss=train_loss, train_op=optimize(train_loss, params))

            if mode == tf.estimator.ModeKeys.EVAL:
                with tf.variable_scope(self.name):
                    eval_logits, eval_predictions = self._build(
                        features, labels, params, mode, config=config)
                    raw_loss = self._compute_loss(
                        features, labels, eval_logits, params, mode)

                eval_loss = _extract_loss(raw_loss)
                metrics = self._compute_metrics(features, labels, eval_predictions)
                if eval_predictions is not None:
                    # Publish predictions in a graph collection so hooks can fetch them.
                    add_dict_to_collection("predictions", eval_predictions)
                return tf.estimator.EstimatorSpec(
                    mode, loss=eval_loss, eval_metric_ops=metrics)

            if mode == tf.estimator.ModeKeys.PREDICT:
                with tf.variable_scope(self.name):
                    _, serving_predictions = self._build(
                        features, labels, params, mode, config=config)

                # Expose the predictions under the default serving signature.
                serving_key = (
                    tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY)
                return tf.estimator.EstimatorSpec(
                    mode,
                    predictions=serving_predictions,
                    export_outputs={
                        serving_key: tf.estimator.export.PredictOutput(
                            serving_predictions),
                    })

            raise RuntimeError("Invalid mode")
Esempio n. 2
0
    def __call__(self, features, labels, params, mode, config):
        """Build the model graph and wrap it in a ``tf.estimator.EstimatorSpec``.

        Args:
          features: Input features handed over by the estimator.
          labels: Input labels handed over by the estimator.
          params: Hyperparameters forwarded to the initializer, the model,
            the loss computation and ``optimize``.
          mode: A ``tf.estimator.ModeKeys`` value. Any mode other than TRAIN
            and PREDICT receives the evaluation spec.
          config: Run configuration forwarded to ``self._build``.

        Returns:
          A ``tf.estimator.EstimatorSpec`` for ``mode``.

        See Also:
          ``tf.estimator.Estimator`` 's ``model_fn`` argument for more details
          about arguments and the returned value.
        """
        is_training = mode == tf.estimator.ModeKeys.TRAIN
        if is_training:
            self._register_word_counters(features, labels)

        with tf.variable_scope(
                self.name,
                initializer=self._initializer(params)) as model_scope:
            model_outputs, model_predictions = self._build(
                features, labels, params, mode, config)

        if model_predictions is not None:
            # Publish predictions in a graph collection so hooks can fetch them.
            add_dict_to_collection("predictions", model_predictions)

        if mode == tf.estimator.ModeKeys.PREDICT:
            # Expose the predictions under the default serving signature.
            signature_key = (
                tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY)
            return tf.estimator.EstimatorSpec(
                mode,
                predictions=model_predictions,
                export_outputs={
                    signature_key: tf.estimator.export.PredictOutput(
                        model_predictions),
                })

        # Re-enter the model's variable scope for the loss computation.
        with tf.variable_scope(model_scope):
            computed_loss = self._compute_loss(
                features, labels, model_outputs, params, mode)

        # A tuple result is unpacked as (loss, display_loss); the second
        # element is only used for the "loss" summary.
        if isinstance(computed_loss, tuple):
            loss, display_loss = computed_loss
        else:
            loss = display_loss = computed_loss
        tf.summary.scalar("loss", display_loss)

        if is_training:
            return tf.estimator.EstimatorSpec(
                mode, loss=loss, train_op=optimize(loss, params))

        eval_metric_ops = self._compute_metrics(
            features, labels, model_predictions)
        return tf.estimator.EstimatorSpec(
            mode, loss=loss, eval_metric_ops=eval_metric_ops)