def test_task_graph(self):
  """Checks Retrieval loss and factorized top-k metrics in TF1 graph mode.

  Mirrors the eager-mode test: builds the same two-query fixture inside an
  explicit graph, runs it in a `tf.compat.v1.Session`, and asserts both the
  loss value and the metric results.
  """
  with tf.Graph().as_default():
    with tf.compat.v1.Session() as sess:
      query = tf.constant([[1, 2, 3], [2, 3, 4]], dtype=tf.float32)
      candidate = tf.constant([[1, 1, 1], [1, 1, 0]], dtype=tf.float32)
      candidate_dataset = tf.data.Dataset.from_tensor_slices(
          np.array([[0, 0, 0]] * 20, dtype=np.float32))

      task = retrieval.Retrieval(metrics=metrics.FactorizedTopK(
          candidates=candidate_dataset.batch(16),
          metrics=[
              tf.keras.metrics.TopKCategoricalAccuracy(
                  k=5, name="factorized_categorical_accuracy_at_5")
          ]))

      # Query/candidate dot products are [[6, 3], [9, 5]]; with the (i, i)
      # entries as positives this is the same loss the eager test expects.
      expected_loss = -np.log(_sigmoid(3.0)) - np.log(1 - _sigmoid(4.0))
      expected_metrics = {
          "factorized_categorical_accuracy_at_5": 1.0,
      }

      loss = task(query_embeddings=query, candidate_embeddings=candidate)

      # In graph mode nothing is initialized automatically, so the task's
      # and each metric's variables are initialized explicitly before use.
      sess.run([var.initializer for var in task.variables])
      for metric in task.metrics:
        sess.run([var.initializer for var in metric.variables])

      # Evaluating the loss tensor is what executes the call (and, presumably,
      # the metric updates it depends on) before results are read below.
      loss_value = sess.run(loss)

      metrics_ = {
          metric.name: sess.run(metric.result()) for metric in task.metrics
      }

      # Previously the loss was run but never checked in graph mode.
      self.assertAllClose(expected_loss, loss_value)
      self.assertAllClose(expected_metrics, metrics_)
def test_task(self):
  """Checks Retrieval loss and factorized top-k accuracy in eager mode."""
  query = tf.constant([[1, 2, 3], [2, 3, 4]], dtype=tf.float32)
  candidate = tf.constant([[1, 1, 1], [1, 1, 0]], dtype=tf.float32)

  # Twenty all-zero candidate vectors, consumed in batches of 16.
  zero_vectors = np.array([[0, 0, 0]] * 20, dtype=np.float32)
  candidate_batches = tf.data.Dataset.from_tensor_slices(zero_vectors).batch(16)
  top_k_accuracy = tf.keras.metrics.TopKCategoricalAccuracy(
      k=5, name="factorized_categorical_accuracy_at_5")
  factorized_metric = metrics.FactorizedTopK(
      candidates=candidate_batches, metrics=[top_k_accuracy])
  task = retrieval.Retrieval(metrics=factorized_metric)

  # Query x candidate dot products: [[6, 3], [9, 5]]. Taking entry (i, i) as
  # the positive, row 0 contributes -log(sigmoid(6 - 3)) and row 1
  # contributes -log(1 - sigmoid(9 - 5)) to the summed cross-entropy.
  expected_loss = -np.log(_sigmoid(3.0)) - np.log(1 - _sigmoid(4.0))
  expected_metrics = {"factorized_categorical_accuracy_at_5": 1.0}

  loss = task(query_embeddings=query, candidate_embeddings=candidate)

  metrics_ = {}
  for metric in task.metrics:
    metrics_[metric.name] = metric.result().numpy()

  self.assertIsNotNone(loss)
  self.assertAllClose(expected_loss, loss)
  self.assertAllClose(expected_metrics, metrics_)