コード例 #1
0
ファイル: train.py プロジェクト: lucifer2288/google-research
    def setup_metrics(item_dataset, item_model):
        """Build the retrieval task, with hit-rate metrics, over all items.

        Every item in `item_dataset` is embedded with `item_model` (in batches
        of 500) and indexed by a `task.MultiQueryStreaming` retriever. For each
        cutoff in `FLAGS.metrics_k`, three `TopKCategoricalAccuracy` metrics are
        created, named `HR@k`, `Head_HR@k` and `Tail_HR@k`.

        Args:
          item_dataset: Dataset of candidate items; it is batched and fed
            through `item_model`.
          item_model: Model applied to the items to produce their embeddings.

        Returns:
          A `task.MultiShotRetrievalTask` using the factorized top-k metrics.
        """
        # Number of candidates retrieved by the index and by the metrics.
        top_k = 256

        def item_map(batched_items):
            # The ids are given an extra axis before calling item_model. Squeeze
            # only that axis, so a final batch of size 1 (or an embedding of
            # size 1) keeps its batch/feature dimension.
            # NOTE(review): assumes item_model returns (batch, 1, dim).
            embeddings = item_model(tf.expand_dims(batched_items, axis=1))
            return tf.squeeze(embeddings, axis=1)

        candidates = item_dataset.batch(500).map(item_map)

        factorized_metrics = []
        for k in [int(value) for value in FLAGS.metrics_k]:
            # Same order and names as before: HR@k, Head_HR@k, Tail_HR@k.
            for prefix in ('', 'Head_', 'Tail_'):
                factorized_metrics.append(
                    tf.keras.metrics.TopKCategoricalAccuracy(
                        k=k, name=f'{prefix}HR@{k}'))

        candidates = task.MultiQueryStreaming(
            k=top_k).index_from_dataset(candidates)
        metrics = task.MultiQueryFactorizedTopK(candidates=candidates,
                                                metrics=factorized_metrics,
                                                k=top_k)
        # NOTE(review): `temperature` is not a parameter of this function; it
        # must be supplied by an enclosing or module-level scope.
        retrieval_task = task.MultiShotRetrievalTask(metrics=metrics,
                                                     temperature=temperature)

        return retrieval_task
コード例 #2
0
    def test_density_weighted_retrieval_model(self):
        """One round of iterative training yields a single-epoch loss history."""
        user_tower = parametric_attention.SimpleParametricAttention(
            output_dimension=2,
            input_embedding_dimension=2,
            vocab_size=self.num_items,
            num_representations=3,
            max_sequence_size=self.max_seq_size)
        embedding_layer = tf.keras.layers.Embedding(input_dim=self.num_items,
                                                    output_dim=2)
        item_tower = tf.keras.Sequential([embedding_layer])

        retrieval_task = task.MultiShotRetrievalTask()
        model_cls = density_smoothed_retrieval.DensityWeightedRetrievalModel
        retrieval_model = model_cls(user_tower, item_tower, retrieval_task,
                                    self.num_items)
        self.assertIsInstance(retrieval_model, tf.keras.Model)

        adagrad = tf.keras.optimizers.Adagrad(learning_rate=0.1)
        retrieval_model.compile(optimizer=adagrad)

        # Training callback handed to iterative_training: one epoch over
        # batches of two examples.
        def fit_one_epoch(dataset):
            return retrieval_model.fit(dataset.batch(2), epochs=1)

        history = retrieval_model.iterative_training(
            fit_one_epoch,
            self.train_dataset,
            self.item_dataset,
            self.item_count_weights)

        self.assertLen(history.history['loss'], 1)
コード例 #3
0
  def test_retrieval_task(self):
    """MultiShotRetrievalTask produces a known loss on a fixed small input."""
    # User embeddings of shape (2, 3, 2) holding the values 0..11.
    user_embeddings = tf.convert_to_tensor(
        np.arange(12).reshape(2, 3, 2), dtype=tf.float32)
    # Two item embeddings of dimension 2.
    item_embeddings = tf.convert_to_tensor(
        np.array([[0.0, 0.1], [0.2, 0.0]]), dtype=tf.float32)
    loss = task.MultiShotRetrievalTask()(user_embeddings, item_embeddings)

    # The inputs are float32, so the loss presumably is too (~7 significant
    # digits). The default places=7 would fail on a one-ulp difference, so
    # compare at 5 decimal places to stay robust across platforms/kernels.
    self.assertAlmostEqual(loss.numpy(), 1.1955092, places=5)