Ejemplo n.º 1
0
 class Config(Optimizer.Config):
     """Configuration for an optimizer that wraps an inner ``optimizer``.

     NOTE(review): the ``swa_`` field prefix suggests this configures
     Stochastic Weight Averaging around the inner optimizer — confirm
     against the optimizer class that owns this Config.
     """

     # Config of the wrapped (inner) optimizer; any of the listed optimizer
     # configs is accepted, and SGD is used when none is specified.
     optimizer: Union[
         SGD.Config,
         Adam.Config,
         AdamW.Config,
         Adagrad.Config,
         RAdam.Config,
         Lamb.Config,
     ] = SGD.Config()
     # Presumably the point at which averaging begins — TODO confirm the
     # unit (steps vs. epochs) against the owning optimizer.
     start: int = 10
     # Presumably the interval between averaging updates — TODO confirm.
     frequency: int = 5
     # Learning rate for the averaging phase (assumed from the name). It is
     # Optional, so None is accepted; what None means is not visible here.
     swa_learning_rate: Optional[float] = 0.05
Ejemplo n.º 2
0
    def test_user_embedding_updates(self):
        """Check that per-user embedding rows are trained independently.

        Trains a PersonalizedDocModel on a single-user dataset with SGD, then
        compares the user embedding table before and after training: the
        table keeps its two rows, row 0 (the unknown user) is unchanged and
        row 1 (the only real user) is updated.
        """
        task_cls = DocumentClassificationTask
        config = self._get_pytext_config(
            test_file_name=TestFileName.TEST_PERSONALIZATION_SINGLE_USER_TSV,
            task_class=task_cls,
            model_class=PersonalizedDocModel,
        )
        # SGD changes only the user embeddings which have non-zero gradients.
        config.task.trainer.optimizer = SGD.Config()
        task = task_cls.from_config(config.task)

        # Deep-copy so the snapshot is unaffected by updates made in training.
        before = copy.deepcopy(task.model.user_embedding.weight)
        trained_model, _ = task.train(config)
        after = trained_model.user_embedding.weight

        self.assertEqual(
            len(before),
            2,
            "There should be 2 user embeddings, including the unknown user.",
        )
        self.assertEqual(
            len(before),
            len(after),
            "Length of user embeddings should not be changed by the training.",
        )

        # (assertion, embedding row, failure message): exactly one row moves.
        row_checks = (
            (self.assertTrue, 0, "Unknown user embedding should not change."),
            (self.assertFalse, 1, "The only user embedding should change."),
        )
        for check, row, message in row_checks:
            check(torch.equal(before[row], after[row]), message)
Ejemplo n.º 3
0
 class Config(Optimizer.Config):
     """Configuration for an optimizer that wraps an inner ``optimizer``.

     NOTE(review): the ``swa_`` field prefix suggests this configures
     Stochastic Weight Averaging around the inner optimizer — confirm
     against the optimizer class that owns this Config.
     """

     # Config of the wrapped (inner) optimizer: SGD or Adam, defaulting to SGD.
     optimizer: Union[SGD.Config, Adam.Config] = SGD.Config()
     # Presumably the point at which averaging begins — TODO confirm the
     # unit (steps vs. epochs) against the owning optimizer.
     start: int = 10
     # Presumably the interval between averaging updates — TODO confirm.
     frequency: int = 5
     # Learning rate for the averaging phase (assumed from the name).
     swa_learning_rate: float = 0.05