def test_reshape_to_2d(self):
  """Checks that `utils.reshape_to_2d` produces the expected rank-2 values.

  Two inputs are covered:
    * a [2, 3, 1] tensor, expected to evaluate to [[1, 2, 3], [4, 5, 6]];
    * a rank-1 tensor [1, 2, 3], expected to evaluate to [[1], [2], [3]].
  """
  cases = []
  # Rank-3 input whose trailing dimension has size 1.
  rank3_input = tf.constant([[[1], [2], [3]], [[4], [5], [6]]])
  cases.append((utils.reshape_to_2d(rank3_input), [[1, 2, 3], [4, 5, 6]]))
  # Rank-1 input.
  rank1_input = tf.constant([1, 2, 3])
  cases.append((utils.reshape_to_2d(rank1_input), [[1], [2], [3]]))

  with tf.compat.v1.Session() as sess:
    for reshaped, expected in cases:
      self.assertAllEqual(sess.run(reshaped), expected)
def _get_weights(features):
  """Returns the weights Tensor taken from `features`, reshaped to 2-D.

  Args:
    features: Mapping of feature name to value. It is only read when
      `weights_feature_name` (taken from the enclosing scope) is truthy.

  Returns:
    None when `weights_feature_name` is falsy. Otherwise, the value of
    `features[weights_feature_name]` converted to a Tensor and passed through
    `utils.reshape_to_2d`.
  """
  if not weights_feature_name:
    return None
  raw_weights = tf.convert_to_tensor(features[weights_feature_name])
  # Convert weights to a 2-D Tensor.
  return utils.reshape_to_2d(raw_weights)
def _loss_fn(labels, logits, features):
  """Computes a single loss or weighted combination of losses.

  Args:
    labels: A `Tensor` of the same shape as `logits` representing relevance.
    logits: A `Tensor` with shape [batch_size, list_size]. Each value is the
      ranking score of the corresponding item.
    features: Dict of Tensors of shape [batch_size, list_size, ...] for
      per-example features and shape [batch_size, ...] for non-example context
      features.

  Returns:
    An op for a single loss or weighted combination of multiple losses.

  Raises:
    ValueError: If `loss_keys` is empty or contains an invalid key, or if
      `loss_weights` is given but its length differs from that of
      `loss_keys`.
  """
  weights = None
  if weights_feature_name:
    weights = tf.convert_to_tensor(features[weights_feature_name])
    # Convert weights to a 2-D Tensor.
    weights = utils.reshape_to_2d(weights)

  loss_kwargs = {
      'labels': labels,
      'logits': logits,
      'weights': weights,
      'reduction': reduction,
      'name': name,
  }
  if extra_args is not None:
    loss_kwargs.update(extra_args)
  loss_kwargs_with_lambda_weight = loss_kwargs.copy()
  loss_kwargs_with_lambda_weight['lambda_weight'] = lambda_weight
  loss_kwargs_with_lambda_weight_and_seed = (
      loss_kwargs_with_lambda_weight.copy())
  loss_kwargs_with_lambda_weight_and_seed['seed'] = seed

  key_to_fn = {
      RankingLossKey.PAIRWISE_HINGE_LOSS:
          (_pairwise_hinge_loss, loss_kwargs_with_lambda_weight),
      RankingLossKey.PAIRWISE_LOGISTIC_LOSS:
          (_pairwise_logistic_loss, loss_kwargs_with_lambda_weight),
      RankingLossKey.PAIRWISE_SOFT_ZERO_ONE_LOSS:
          (_pairwise_soft_zero_one_loss, loss_kwargs_with_lambda_weight),
      RankingLossKey.SOFTMAX_LOSS:
          (_softmax_loss, loss_kwargs_with_lambda_weight),
      RankingLossKey.SIGMOID_CROSS_ENTROPY_LOSS:
          (_sigmoid_cross_entropy_loss, loss_kwargs),
      RankingLossKey.MEAN_SQUARED_LOSS: (_mean_squared_loss, loss_kwargs),
      RankingLossKey.LIST_MLE_LOSS:
          (_list_mle_loss, loss_kwargs_with_lambda_weight_and_seed),
      RankingLossKey.APPROX_NDCG_LOSS: (_approx_ndcg_loss, loss_kwargs),
  }

  # Validate everything before creating any loss op, so a bad configuration
  # does not leave partially-built ops in the graph.
  keys = list(loss_keys)
  if not keys:
    raise ValueError('loss_keys must not be empty.')
  for loss_key in keys:
    if loss_key not in key_to_fn:
      raise ValueError('Invalid loss_key: {}.'.format(loss_key))
  # `zip` below would silently drop unmatched losses or weights.
  if loss_weights and len(loss_weights) != len(keys):
    raise ValueError(
        'loss_keys and loss_weights must have the same size, got {} and {}.'
        .format(len(keys), len(loss_weights)))

  # Obtain the list of loss ops.
  loss_ops = []
  for loss_key in keys:
    loss_fn, kwargs = key_to_fn[loss_key]
    loss_ops.append(loss_fn(**kwargs))

  # Compute weighted combination of losses.
  if loss_weights:
    loss_ops = [
        tf.multiply(loss_op, loss_weight)
        for loss_op, loss_weight in zip(loss_ops, loss_weights)
    ]
  return tf.add_n(loss_ops)
def _model_fn(features, labels, mode, params, config):
  """Defines an `Estimator` `model_fn`.

  Clones `model` together with its loss, metrics and optimizer, so that every
  sub-graph is built inside this function.

  Args:
    features: Mapping of feature name to Tensor that is fed to the cloned
      ranker. When `weights_feature_name` is set and `mode` is not PREDICT,
      that entry is used as sample weights and withheld from the ranker. The
      caller's mapping is not modified.
    labels: Labels passed to the loss and to every metric.
    mode: A `tf.compat.v1.estimator.ModeKeys` value.
    params: Unused.
    config: Unused.

  Returns:
    A `tf.compat.v1.estimator.EstimatorSpec`. In PREDICT mode it carries only
    the predictions. Otherwise it also carries the loss and eval metric ops,
    plus a train op when training.

  Raises:
    ValueError: If `weights_feature_name` is set, `mode` is not PREDICT and
      the feature is missing from `features`.
  """
  del [config, params]

  # In Estimator, all sub-graphs need to be constructed inside the model_fn.
  # Hence, ranker, losses, metrics and optimizer are cloned inside this
  # function.
  ranker = tf.keras.models.clone_model(model, clone_function=_clone_fn)
  training = (mode == tf.compat.v1.estimator.ModeKeys.TRAIN)

  weights = None
  if weights_feature_name and mode != tf.compat.v1.estimator.ModeKeys.PREDICT:
    if weights_feature_name not in features:
      raise ValueError(
          "weights_feature '{0}' can not be found in 'features'.".format(
              weights_feature_name))
    # Pop from a shallow copy so the caller's dict is left untouched. The
    # weights are removed so they are not fed to the ranker as an input.
    features = dict(features)
    weights = utils.reshape_to_2d(features.pop(weights_feature_name))

  logits = ranker(features, training=training)

  if mode == tf.compat.v1.estimator.ModeKeys.PREDICT:
    return tf.compat.v1.estimator.EstimatorSpec(mode=mode, predictions=logits)

  loss = _clone_fn(model.loss)
  total_loss = loss(labels, logits, sample_weight=weights)

  keras_metrics = [_clone_fn(metric) for metric in model.metrics]
  # Adding default metrics here as model.metrics does not contain custom
  # metrics.
  keras_metrics += metrics.default_keras_metrics()
  eval_metric_ops = {}
  for keras_metric in keras_metrics:
    keras_metric.update_state(labels, logits, sample_weight=weights)
    eval_metric_ops[keras_metric.name] = keras_metric

  train_op = None
  if training:
    optimizer = _clone_fn(model.optimizer)
    optimizer.iterations = tf.compat.v1.train.get_or_create_global_step()
    # Get both the unconditional updates (the None part)
    # and the input-conditional updates (the features part).
    # These updates are for layers like BatchNormalization, which have
    # separate update and minimize ops.
    update_ops = (
        ranker.get_updates_for(None) + ranker.get_updates_for(features))
    minimize_op = optimizer.get_updates(
        loss=total_loss, params=ranker.trainable_variables)[0]
    train_op = tf.group(minimize_op, *update_ops)

  return tf.compat.v1.estimator.EstimatorSpec(
      mode=mode,
      predictions=logits,
      loss=total_loss,
      train_op=train_op,
      eval_metric_ops=eval_metric_ops)