def setup_metrics(item_dataset, item_model):
  """Builds the multi-shot retrieval task with factorized top-k metrics.

  Embeds every item in `item_dataset` with `item_model`, registers
  overall/head/tail HR@k accuracy metrics for each k in `FLAGS.metrics_k`,
  and wraps them in a `task.MultiShotRetrievalTask`.

  Args:
    item_dataset: dataset of item ids to embed as retrieval candidates.
    item_model: Keras model mapping item ids to embeddings.

  Returns:
    A `task.MultiShotRetrievalTask` configured with the candidate index
    and metrics.
  """

  def item_map(batched_items):
    # expand_dims/squeeze keeps the rank the item model expects while
    # yielding a 2-D (batch, dim) embedding matrix.
    return tf.squeeze(item_model(tf.expand_dims(batched_items, axis=1)))

  candidates = item_dataset.batch(500).map(item_map)

  # One overall, one head, and one tail hit-rate metric per requested k.
  factorized_metrics = []
  for k in map(int, FLAGS.metrics_k):
    for prefix in ('', 'Head_', 'Tail_'):
      factorized_metrics.append(
          tf.keras.metrics.TopKCategoricalAccuracy(
              k=k, name=f'{prefix}HR@{k}'))

  candidates = task.MultiQueryStreaming(k=256).index_from_dataset(candidates)
  metrics = task.MultiQueryFactorizedTopK(
      candidates=candidates, metrics=factorized_metrics, k=256)
  # NOTE(review): `temperature` is a free variable resolved at call time
  # (presumably a module-level constant or flag) — confirm it is defined.
  return task.MultiShotRetrievalTask(metrics=metrics, temperature=temperature)
def test_density_weighted_retrieval_model(self):
  """Runs one iterative-training round end-to-end and checks the history."""
  user_model = parametric_attention.SimpleParametricAttention(
      output_dimension=2,
      input_embedding_dimension=2,
      vocab_size=self.num_items,
      num_representations=3,
      max_sequence_size=self.max_seq_size)
  item_model = tf.keras.Sequential(
      [tf.keras.layers.Embedding(input_dim=self.num_items, output_dim=2)])

  model = density_smoothed_retrieval.DensityWeightedRetrievalModel(
      user_model, item_model, task.MultiShotRetrievalTask(), self.num_items)
  self.assertIsInstance(model, tf.keras.Model)
  model.compile(optimizer=tf.keras.optimizers.Adagrad(learning_rate=0.1))

  def train_fn(train_dataset):
    # A single epoch over batches of 2 keeps the test fast.
    return model.fit(train_dataset.batch(2), epochs=1)

  history = model.iterative_training(
      train_fn, self.train_dataset, self.item_dataset,
      self.item_count_weights)
  # One epoch of training should produce exactly one loss entry.
  self.assertLen(history.history['loss'], 1)
def test_retrieval_task(self):
  """Checks the multi-shot retrieval loss on a tiny fixed example."""
  # Two queries, each with three 2-d representations (shape (2, 3, 2)).
  user_embeddings = tf.convert_to_tensor(
      np.arange(12).reshape(2, 3, 2), dtype=tf.float32)
  # Two candidate items with 2-d embeddings.
  item_embeddings = tf.convert_to_tensor(
      np.array([[0.0, 0.1], [0.2, 0.0]]), dtype=tf.float32)

  loss = task.MultiShotRetrievalTask()(user_embeddings, item_embeddings)
  # NOTE(review): assertAlmostEqual defaults to 7 decimal places; a float32
  # loss can vary at that precision across platforms — confirm tolerance.
  self.assertAlmostEqual(loss.numpy(), 1.1955092)