Exemplo n.º 1
0
    def test_density_weighted_retrieval_model(self):
        """Train a DensityWeightedRetrievalModel for one epoch via iterative_training.

        Builds a parametric-attention user tower and an embedding item tower,
        wraps them in the density-weighted retrieval model, and checks that
        the returned training history holds exactly one loss value.
        """
        # Both towers use the same embedding width.
        embedding_dim = 2

        query_tower = parametric_attention.SimpleParametricAttention(
            output_dimension=embedding_dim,
            input_embedding_dimension=embedding_dim,
            vocab_size=self.num_items,
            num_representations=3,
            max_sequence_size=self.max_seq_size)
        item_embedding = tf.keras.layers.Embedding(
            input_dim=self.num_items, output_dim=embedding_dim)
        candidate_tower = tf.keras.Sequential([item_embedding])

        retrieval_task = task.MultiShotRetrievalTask()
        model = density_smoothed_retrieval.DensityWeightedRetrievalModel(
            query_tower, candidate_tower, retrieval_task, self.num_items)

        self.assertIsInstance(model, tf.keras.Model)

        optimizer = tf.keras.optimizers.Adagrad(learning_rate=0.1)
        model.compile(optimizer=optimizer)

        def fit_one_epoch(dataset):
            # Callback handed to iterative_training: one epoch, batches of two.
            batched = dataset.batch(2)
            return model.fit(batched, epochs=1)

        history = model.iterative_training(
            fit_one_epoch,
            self.train_dataset,
            self.item_dataset,
            self.item_count_weights)

        # A single epoch was requested, so one loss entry is expected.
        losses = history.history['loss']
        self.assertLen(losses, 1)
Exemplo n.º 2
0
    def test_parametric_attention_model_with_multiple_representations(self):
        """Check the output shape of SimpleParametricAttention with 3 representations.

        Feeds a random integer batch through the model and asserts the result
        has shape (batch, num_representations, output_dimension).
        """
        batch_size = 10
        vocab_size = 10
        sequence_length = 20
        representation_count = 3
        output_dim = 2

        attention = parametric_attention.SimpleParametricAttention(
            output_dimension=output_dim,
            input_embedding_dimension=2,
            vocab_size=vocab_size,
            num_representations=representation_count,
            max_sequence_size=sequence_length)

        # Random ids drawn from [0, vocab_size); only the shape is asserted.
        token_ids = np.random.randint(
            low=0, high=vocab_size, size=(batch_size, sequence_length))
        result = attention(tf.convert_to_tensor(token_ids))

        self.assertIsInstance(attention, tf.keras.Model)
        expected_shape = [batch_size, representation_count, output_dim]
        self.assertSequenceEqual(result.numpy().shape, expected_shape)