def test_task_graph(self):
        """Graph mode: checks the loss and every configured metric.

        Mirrors ``test_task`` but builds the computation inside an explicit
        ``tf.Graph`` and evaluates it through a ``tf.compat.v1.Session``.
        It asserts on the loss as well as the metrics, so both execution
        modes are verified against the same expectations.
        """
        with tf.Graph().as_default():
            with tf.compat.v1.Session() as sess:
                task = ranking.Ranking(
                    metrics=[tf.keras.metrics.BinaryAccuracy(name="accuracy")],
                    label_metrics=[tf.keras.metrics.Mean(name="label_mean")],
                    prediction_metrics=[
                        tf.keras.metrics.Mean(name="prediction_mean")
                    ])
                predictions = tf.constant([[1], [0.3]], dtype=tf.float32)
                labels = tf.constant([[1], [1]], dtype=tf.float32)

                # Standard log loss formula; both labels are 1, so each
                # example contributes -log(prediction).
                expected_loss = -(math.log(1) + math.log(0.3)) / 2.0
                expected_metrics = {
                    "accuracy": 0.5,
                    "label_mean": 1.0,
                    "prediction_mean": 0.65
                }

                loss = task(predictions=predictions, labels=labels)

                # Graph mode: variables must be initialized explicitly
                # before any op that reads them is run.
                sess.run([var.initializer for var in task.variables])
                # Fetch the loss value so it can be asserted below. The
                # metrics are read only after this fetch. NOTE(review): this
                # ordering presumably matters because Ranking ties its metric
                # updates to the loss op — confirm.
                loss_value = sess.run(loss)

                metrics = {
                    metric.name: sess.run(metric.result())
                    for metric in task.metrics
                }

        self.assertAllClose(expected_loss, loss_value)
        self.assertAllClose(expected_metrics, metrics)
    def test_task(self):
        """Eager mode: checks the log loss and every configured metric."""
        ranking_task = ranking.Ranking(
            metrics=[tf.keras.metrics.BinaryAccuracy(name="accuracy")],
            label_metrics=[tf.keras.metrics.Mean(name="label_mean")],
            prediction_metrics=[tf.keras.metrics.Mean(name="prediction_mean")],
        )

        # Two positive examples: one predicted at 1.0, one at 0.3.
        predictions = tf.constant([[1], [0.3]], dtype=tf.float32)
        labels = tf.constant([[1], [1]], dtype=tf.float32)

        # Standard log loss: with every label equal to 1 it is the mean of
        # -log(p) over the predictions.
        want_loss = sum(-math.log(p) for p in (1.0, 0.3)) / 2.0
        # prediction_mean is (1.0 + 0.3) / 2; label_mean is 1 since all
        # labels are 1.
        want_metrics = dict(accuracy=0.5, label_mean=1.0, prediction_mean=0.65)

        got_loss = ranking_task(predictions=predictions, labels=labels)

        got_metrics = {}
        for metric in ranking_task.metrics:
            got_metrics[metric.name] = metric.result().numpy()

        self.assertIsNotNone(got_loss)
        self.assertAllClose(want_loss, got_loss)
        self.assertAllClose(want_metrics, got_metrics)