Beispiel #1
0
    def test_task_graph(self):
        # Graph-mode (TF1 Session) counterpart of the eager retrieval test:
        # build the Retrieval task inside an explicit graph, initialise its
        # variables, run the loss op once, then check both the loss value and
        # the factorized top-5 accuracy. The loss value was previously computed
        # but never asserted.
        with tf.Graph().as_default():
            with tf.compat.v1.Session() as sess:
                query = tf.constant([[1, 2, 3], [2, 3, 4]], dtype=tf.float32)
                candidate = tf.constant([[1, 1, 1], [1, 1, 0]],
                                        dtype=tf.float32)
                candidate_dataset = tf.data.Dataset.from_tensor_slices(
                    np.array([[0, 0, 0]] * 20, dtype=np.float32))

                task = retrieval.Retrieval(metrics=metrics.FactorizedTopK(
                    candidates=candidate_dataset.batch(16),
                    metrics=[
                        tf.keras.metrics.TopKCategoricalAccuracy(
                            k=5, name="factorized_categorical_accuracy_at_5")
                    ]))

                # All-pair scores are [[6, 3], [9, 5]]. With candidate i as the
                # positive for query i, the summed softmax cross-entropy is
                # -log(sigmoid(3)) - log(1 - sigmoid(4)), which equals
                # log(1 + e^-3) + log(1 + e^4). It is written with numpy
                # directly so this test needs no helper function.
                expected_loss = np.log1p(np.exp(-3.0)) + np.log1p(np.exp(4.0))
                expected_metrics = {
                    "factorized_categorical_accuracy_at_5": 1.0,
                }

                loss = task(query_embeddings=query,
                            candidate_embeddings=candidate)

                # In graph mode, variables must be initialised explicitly
                # before any op that reads them is run.
                sess.run([var.initializer for var in task.variables])
                for metric in task.metrics:
                    sess.run([var.initializer for var in metric.variables])
                # Run the loss op exactly once. The metric check below relies
                # on this single run also having applied the metric updates.
                loss_value = sess.run(loss)

                metrics_ = {
                    metric.name: sess.run(metric.result())
                    for metric in task.metrics
                }

                self.assertAllClose(expected_loss, loss_value)
                self.assertAllClose(expected_metrics, metrics_)
Beispiel #2
0
    def test_task(self):
        # Eager-mode check of the Retrieval task: a single forward call must
        # produce the expected loss and leave the factorized top-5 accuracy
        # at 1.0.
        query_embeddings = tf.constant(
            [[1, 2, 3], [2, 3, 4]], dtype=tf.float32)
        candidate_embeddings = tf.constant(
            [[1, 1, 1], [1, 1, 0]], dtype=tf.float32)

        # Twenty all-zero candidate vectors, handed to the metric in batches
        # of 16.
        zero_candidates = np.array([[0, 0, 0]] * 20, dtype=np.float32)
        candidate_batches = tf.data.Dataset.from_tensor_slices(
            zero_candidates).batch(16)

        top_k_accuracy = tf.keras.metrics.TopKCategoricalAccuracy(
            k=5, name="factorized_categorical_accuracy_at_5")
        task = retrieval.Retrieval(
            metrics=metrics.FactorizedTopK(
                candidates=candidate_batches, metrics=[top_k_accuracy]))

        # Query-candidate scores are [[6, 3], [9, 5]]. With candidate i as the
        # positive for query i, each row's two-way softmax cross-entropy
        # reduces to -log(sigmoid(positive - negative)):
        #   row 0: -log(sigmoid(6 - 3))
        #   row 1: -log(sigmoid(5 - 9)) = -log(1 - sigmoid(4))
        # The expected value is the sum of the two rows.
        expected_loss = (-np.log(_sigmoid(3.0))
                         - np.log(1 - _sigmoid(4.0)))
        # Every dataset candidate is the zero vector and so scores 0, below
        # both positives (6 and 5). Each positive should therefore rank
        # within the top 5.
        expected_metrics = {"factorized_categorical_accuracy_at_5": 1.0}

        loss = task(query_embeddings=query_embeddings,
                    candidate_embeddings=candidate_embeddings)

        observed_metrics = {}
        for metric in task.metrics:
            observed_metrics[metric.name] = metric.result().numpy()

        self.assertIsNotNone(loss)
        self.assertAllClose(expected_loss, loss)
        self.assertAllClose(expected_metrics, observed_metrics)