Esempio n. 1
0
 def test_reshape_to_2d(self):
     """Checks `utils.reshape_to_2d` on a 3-D and a 1-D constant tensor.

     The nested 2x3x1 constant is expected to come back as a 2x3 matrix,
     and the length-3 vector is expected to come back as a 3x1 column.
     """
     # Build all graph ops first; they are only evaluated in the session.
     rank3_input = tf.constant([[[1], [2], [3]], [[4], [5], [6]]])
     rank3_output = utils.reshape_to_2d(rank3_input)
     rank1_input = tf.constant([1, 2, 3])
     rank1_output = utils.reshape_to_2d(rank1_input)

     expected_from_rank3 = [[1, 2, 3], [4, 5, 6]]
     expected_from_rank1 = [[1], [2], [3]]

     with tf.compat.v1.Session() as sess:
         self.assertAllEqual(sess.run(rank3_output), expected_from_rank3)
         self.assertAllEqual(sess.run(rank1_output), expected_from_rank1)
 def _get_weights(features):
     """Returns the weights from `features` as a 2-D tensor, or None.

     `weights_feature_name` is read from the enclosing scope. When it is
     falsy no lookup is performed and None is returned.

     Args:
       features: Mapping indexed by `weights_feature_name` to obtain the raw
         weights value.

     Returns:
       The result of `utils.reshape_to_2d` applied to the weights converted
       to a tensor, or None when no weights feature name is configured.
     """
     if not weights_feature_name:
         return None
     raw_weights = tf.convert_to_tensor(features[weights_feature_name])
     # Downstream consumers expect 2-D weights, so normalize the rank here.
     return utils.reshape_to_2d(raw_weights)
    def _loss_fn(labels, logits, features):
        """Computes a single loss or a weighted combination of losses.

        The losses to build are selected by `loss_keys`, and optionally
        scaled by `loss_weights`; both are read from the enclosing scope, as
        are `weights_feature_name`, `reduction`, `name`, `extra_args`,
        `lambda_weight` and `seed`.

        Args:
          labels: A `Tensor` of the same shape as `logits` representing
            relevance.
          logits: A `Tensor` with shape [batch_size, list_size]. Each value is
            the ranking score of the corresponding item.
          features: Dict of Tensors of shape [batch_size, list_size, ...] for
            per-example features and shape [batch_size, ...] for non-example
            context features.

        Returns:
          An op for a single loss or weighted combination of multiple losses.

        Raises:
          ValueError: If an entry of `loss_keys` is not a supported loss key.
        """
        if weights_feature_name:
            # Per-item weights come from the features and are made 2-D.
            weights = utils.reshape_to_2d(
                tf.convert_to_tensor(features[weights_feature_name]))
        else:
            weights = None

        # Keyword arguments accepted by every loss implementation.
        base_kwargs = {
            'labels': labels,
            'logits': logits,
            'weights': weights,
            'reduction': reduction,
            'name': name,
        }
        if extra_args is not None:
            base_kwargs.update(extra_args)

        # Some losses additionally take `lambda_weight`, and list-MLE also
        # takes `seed`; each variant is an independent copy of the base set.
        lambda_kwargs = dict(base_kwargs, lambda_weight=lambda_weight)
        lambda_seed_kwargs = dict(lambda_kwargs, seed=seed)

        # Dispatch table: loss key -> (loss function, keyword arguments).
        key_to_fn = {
            RankingLossKey.PAIRWISE_HINGE_LOSS: (
                _pairwise_hinge_loss, lambda_kwargs),
            RankingLossKey.PAIRWISE_LOGISTIC_LOSS: (
                _pairwise_logistic_loss, lambda_kwargs),
            RankingLossKey.PAIRWISE_SOFT_ZERO_ONE_LOSS: (
                _pairwise_soft_zero_one_loss, lambda_kwargs),
            RankingLossKey.SOFTMAX_LOSS: (
                _softmax_loss, lambda_kwargs),
            RankingLossKey.SIGMOID_CROSS_ENTROPY_LOSS: (
                _sigmoid_cross_entropy_loss, base_kwargs),
            RankingLossKey.MEAN_SQUARED_LOSS: (
                _mean_squared_loss, base_kwargs),
            RankingLossKey.LIST_MLE_LOSS: (
                _list_mle_loss, lambda_seed_kwargs),
            RankingLossKey.APPROX_NDCG_LOSS: (
                _approx_ndcg_loss, base_kwargs),
        }

        # Build one loss op per requested key, in the order requested.
        loss_ops = []
        for loss_key in loss_keys:
            entry = key_to_fn.get(loss_key)
            if entry is None:
                raise ValueError('Invalid loss_key: {}.'.format(loss_key))
            loss_fn, kwargs = entry
            loss_ops.append(loss_fn(**kwargs))

        # Without explicit weights the losses are summed unscaled.
        if not loss_weights:
            return tf.add_n(loss_ops)

        scaled_losses = [
            tf.multiply(loss_op, loss_weight)
            for loss_op, loss_weight in zip(loss_ops, loss_weights)
        ]
        return tf.add_n(scaled_losses)
Esempio n. 4
0
    def _model_fn(features, labels, mode, params, config):
        """Defines an `Estimator` `model_fn`.

        `model`, `_clone_fn`, `weights_feature_name` and `metrics` are read
        from the enclosing scope.

        Args:
          features: Mapping of input features passed to the cloned ranker.
            Outside of PREDICT mode the entry named `weights_feature_name`
            (when configured) is popped from this mapping before the ranker
            is called.
          labels: Labels passed to the loss and to every metric.
          mode: One of the `tf.compat.v1.estimator.ModeKeys` values.
          params: Unused.
          config: Unused.

        Returns:
          An `EstimatorSpec`. In PREDICT mode it carries only the predictions;
          otherwise it also carries the loss, the eval metrics and, in TRAIN
          mode, the train op.

        Raises:
          ValueError: If `weights_feature_name` is configured but absent from
            `features` outside of PREDICT mode.
        """
        del config, params  # Unused.

        # In Estimator, all sub-graphs need to be constructed inside the
        # model_fn, so the ranker, loss, metrics and optimizer are all cloned
        # here rather than reused from `model`.
        ranker = tf.keras.models.clone_model(model, clone_function=_clone_fn)

        mode_keys = tf.compat.v1.estimator.ModeKeys
        estimator_spec = tf.compat.v1.estimator.EstimatorSpec
        training = mode == mode_keys.TRAIN

        weights = None
        if weights_feature_name and mode != mode_keys.PREDICT:
            if weights_feature_name not in features:
                raise ValueError(
                    "weights_feature '{0}' can not be found in 'features'.".
                    format(weights_feature_name))
            # Remove the weights so they are not fed to the ranker as input.
            raw_weights = features.pop(weights_feature_name)
            weights = utils.reshape_to_2d(raw_weights)

        logits = ranker(features, training=training)

        if mode == mode_keys.PREDICT:
            return estimator_spec(mode=mode, predictions=logits)

        loss = _clone_fn(model.loss)
        total_loss = loss(labels, logits, sample_weight=weights)

        keras_metrics = [_clone_fn(metric) for metric in model.metrics]
        # Adding default metrics here as model.metrics does not contain custom
        # metrics.
        keras_metrics.extend(metrics.default_keras_metrics())

        eval_metric_ops = {}
        for keras_metric in keras_metrics:
            keras_metric.update_state(labels, logits, sample_weight=weights)
            eval_metric_ops[keras_metric.name] = keras_metric

        train_op = None
        if training:
            optimizer = _clone_fn(model.optimizer)
            global_step = tf.compat.v1.train.get_or_create_global_step()
            optimizer.iterations = global_step
            # Collect both the unconditional updates (the None part) and the
            # input-conditional updates (the features part). These are for
            # layers like BatchNormalization, which have separate update and
            # minimize ops.
            unconditional_updates = ranker.get_updates_for(None)
            conditional_updates = ranker.get_updates_for(features)
            update_ops = unconditional_updates + conditional_updates
            minimize_op = optimizer.get_updates(
                loss=total_loss, params=ranker.trainable_variables)[0]
            train_op = tf.group(minimize_op, *update_ops)

        return estimator_spec(
            mode=mode,
            predictions=logits,
            loss=total_loss,
            train_op=train_op,
            eval_metric_ops=eval_metric_ops)