예제 #1
0
    def setUp(self):
        """Builds model components and writes a small ELWC TFRecord file."""
        super(KerasModelToEstimatorTest, self).setUp()
        # Feature columns plus the custom objects returned alongside them.
        (self._context_feature_columns, self._example_feature_columns,
         self._custom_objects) = _get_feature_columns()
        # Remove the label feature from the example feature columns.
        del self._example_feature_columns[_LABEL_FEATURE]

        self._network = _DummyUnivariateRankingNetwork(
            context_feature_columns=self._context_feature_columns,
            example_feature_columns=self._example_feature_columns)
        self._loss = losses.get(
            losses.RankingLossKey.SOFTMAX_LOSS,
            reduction=tf.compat.v2.losses.Reduction.SUM_OVER_BATCH_SIZE)
        self._eval_metrics = metrics.default_keras_metrics()
        self._optimizer = tf.keras.optimizers.Adagrad(learning_rate=0.1)
        self._config = tf.estimator.RunConfig(
            keep_checkpoint_max=2, save_checkpoints_secs=2)

        # (Re)create a TFRecord file holding 20 copies of the same ELWC proto.
        temp_dir = tf.compat.v1.test.get_temp_dir()
        self._data_file = os.path.join(temp_dir, 'test_elwc.tfrecord')
        serialized_elwc = _ELWC_PROTO.SerializeToString()
        if tf.io.gfile.exists(self._data_file):
            tf.io.gfile.remove(self._data_file)
        with tf.io.TFRecordWriter(self._data_file) as writer:
            for _ in range(20):
                writer.write(serialized_elwc)
예제 #2
0
 def setUp(self):
   """Creates the feature columns, features, network and training config."""
   super(KerasModelToEstimatorTest, self).setUp()
   # Inputs and the dummy ranking network built on top of them.
   self.context_feature_columns = _context_feature_columns()
   self.example_feature_columns = _example_feature_columns()
   self.features = _features()
   self.network = _DummyUnivariateRankingNetwork(
       context_feature_columns=self.context_feature_columns,
       example_feature_columns=self.example_feature_columns)
   # Softmax ranking loss averaged over the batch, default metrics, Adagrad.
   batch_mean = tf.compat.v2.losses.Reduction.SUM_OVER_BATCH_SIZE
   self.loss = losses.get(
       losses.RankingLossKey.SOFTMAX_LOSS, reduction=batch_mean)
   self.eval_metrics = metrics.default_keras_metrics()
   self.optimizer = tf.keras.optimizers.Adagrad(learning_rate=0.1)
   self.config = tf.estimator.RunConfig(keep_checkpoint_max=2,
                                        save_checkpoints_secs=2)
예제 #3
0
 def test_default_keras_metrics(self):
   """Checks that the defaults are 11 `tf.keras.metrics.Metric` objects."""
   keras_metric_list = metrics_lib.default_keras_metrics()
   self.assertLen(keras_metric_list, 11)
   for keras_metric in keras_metric_list:
     self.assertIsInstance(keras_metric, tf.keras.metrics.Metric)
예제 #4
0
    def _model_fn(features, labels, mode, params, config):
        """Defines an `Estimator` `model_fn`.

        Clones `model` (and, outside PREDICT mode, its loss, metrics and
        optimizer) through `_clone_fn` and wires them into an `EstimatorSpec`.
        `model`, `_clone_fn` and `weights_feature_name` are free variables
        that are not defined in this function.

        Args:
          features: Mapping of input features passed to the ranker. In
            TRAIN/EVAL mode the weights feature, if configured, is popped from
            this mapping in place.
          labels: Labels passed to the loss and to every metric.
          mode: A `tf.compat.v1.estimator.ModeKeys` value.
          params: Unused.
          config: Unused.

        Returns:
          A `tf.compat.v1.estimator.EstimatorSpec`. In PREDICT mode only
          `predictions` is set. Otherwise loss and eval metric ops are set, and
          in TRAIN mode a train op as well.

        Raises:
          ValueError: If `weights_feature_name` is set, `mode` is not PREDICT,
            and that feature is missing from `features`.
        """
        del config, params  # Unused; required by the model_fn signature.

        mode_keys = tf.compat.v1.estimator.ModeKeys
        # In Estimator, all sub-graphs need to be constructed inside the
        # model_fn, so the ranker, loss, metrics and optimizer are cloned here.
        ranker = tf.keras.models.clone_model(model, clone_function=_clone_fn)
        is_training = mode == mode_keys.TRAIN

        weights = None
        if weights_feature_name and mode != mode_keys.PREDICT:
            if weights_feature_name not in features:
                raise ValueError(
                    "weights_feature '{0}' can not be found in 'features'.".
                    format(weights_feature_name))
            weights = utils.reshape_to_2d(features.pop(weights_feature_name))

        logits = ranker(features, training=is_training)

        if mode == mode_keys.PREDICT:
            return tf.compat.v1.estimator.EstimatorSpec(
                mode=mode, predictions=logits)

        cloned_loss = _clone_fn(model.loss)
        total_loss = cloned_loss(labels, logits, sample_weight=weights)

        # Default metrics are appended because model.metrics does not contain
        # custom metrics.
        keras_metrics = [_clone_fn(model_metric)
                         for model_metric in model.metrics]
        keras_metrics.extend(metrics.default_keras_metrics())
        eval_metric_ops = {}
        for keras_metric in keras_metrics:
            keras_metric.update_state(labels, logits, sample_weight=weights)
            eval_metric_ops[keras_metric.name] = keras_metric

        train_op = None
        if is_training:
            optimizer = _clone_fn(model.optimizer)
            # Use the global step as the optimizer's `iterations` counter.
            optimizer.iterations = (
                tf.compat.v1.train.get_or_create_global_step())
            # Gather both unconditional updates (inputs=None) and updates
            # conditioned on `features`; layers such as BatchNormalization
            # keep these separate from the minimize op.
            unconditional_updates = ranker.get_updates_for(None)
            feature_updates = ranker.get_updates_for(features)
            update_ops = unconditional_updates + feature_updates
            minimize_op = optimizer.get_updates(
                loss=total_loss, params=ranker.trainable_variables)[0]
            train_op = tf.group(minimize_op, *update_ops)

        return tf.compat.v1.estimator.EstimatorSpec(
            mode=mode,
            predictions=logits,
            loss=total_loss,
            train_op=train_op,
            eval_metric_ops=eval_metric_ops)