Example 1
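Defines an example task configuration: a bagging document ensemble (BaggingDocEnsemble_Deprecated) whose member models are DocModel_Deprecated configs, with labels at both the document and word level.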
 def example_config(cls):
     return cls.Config(
         labels=[DocLabelConfig(), WordLabelConfig()],
         model=BaggingDocEnsemble_Deprecated.Config(
             models=[DocModel_Deprecated.Config()]
         ),
     )
Example 2
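A unit test checking that freeze=True on the whole FeatureConfig disables gradients for every embedding parameter.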
 def test_freeze_all_embedding(self):
     model = create_model(
         DocModel_Deprecated.Config(),
         FeatureConfig(freeze=True),
         metadata=mock_metadata(),
     )
     for param in model.embedding.parameters():
         self.assertFalse(param.requires_grad)
Example 3
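A task-level Config declaring typed defaults for the model, trainer, input features, target labels, data handler, metric reporter, and an optional dense-feature exporter.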
 class Config(Task_Deprecated.Config):
     model: DocModel_Deprecated.Config = DocModel_Deprecated.Config()
     trainer: Trainer.Config = Trainer.Config()
     features: DocClassification.ModelInputConfig = (
         DocClassification.ModelInputConfig())
     labels: DocClassification.TargetConfig = (
         DocClassification.TargetConfig())
     data_handler: DocClassificationDataHandler.Config = (
         DocClassificationDataHandler.Config())
     metric_reporter: ClassificationMetricReporter.Config = (
         ClassificationMetricReporter.Config())
     exporter: Optional[DenseFeatureExporter.Config] = None
Example 4
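A test helper that builds a dummy model whose representation, decoder, and embedding modules are each given a save_path, so their weights can be persisted separately.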
 def _create_dummy_model(self):
     return create_model(
         DocModel_Deprecated.Config(
             representation=BiLSTMDocAttention.Config(
                 save_path=self.representation_path),
             decoder=MLPDecoder.Config(save_path=self.decoder_path),
         ),
         FeatureConfig(
             word_feat=WordEmbedding.Config(
                 embed_dim=300, save_path=self.word_embedding_path),
             save_path=self.embedding_path,
         ),
         self._create_dummy_meta_data(),
     )
Example 5
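A test verifying that freezing only the word feature stops gradients for the word embedding itself, while the word feature's MLP projection and the dict feature embedding remain trainable.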
    def test_freeze_word_embedding(self):
        model = create_model(
            DocModel_Deprecated.Config(),
            FeatureConfig(
                word_feat=WordFeatConfig(freeze=True, mlp_layer_dims=[4]),
                dict_feat=DictFeatConfig(),
            ),
            metadata=mock_metadata(),
        )
        # word embedding
        for param in model.embedding[0].word_embedding.parameters():
            self.assertFalse(param.requires_grad)
        for param in model.embedding[0].mlp.parameters():
            self.assertTrue(param.requires_grad)

        # dict feat embedding
        for param in model.embedding[1].parameters():
            self.assertTrue(param.requires_grad)
Example 6
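A save/load round-trip test: one model saves its modules to disk, a second loads them from the same paths, and a third is freshly initialized. The loaded parameters must equal the saved ones, and both must differ from the random model's.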
    def test_load_save(self):
        text_field_meta = FieldMeta()
        text_field_meta.vocab = VocabStub()
        text_field_meta.vocab_size = 4
        text_field_meta.unk_token_idx = 1
        text_field_meta.pad_token_idx = 0
        text_field_meta.pretrained_embeds_weight = None
        label_meta = FieldMeta()
        label_meta.vocab = VocabStub()
        label_meta.vocab_size = 3
        metadata = CommonMetadata()
        metadata.features = {DatasetFieldName.TEXT_FIELD: text_field_meta}
        metadata.target = label_meta

        saved_model = create_model(
            DocModel.Config(
                representation=BiLSTMDocAttention.Config(
                    save_path=self.representation_path),
                decoder=MLPDecoder.Config(save_path=self.decoder_path),
            ),
            FeatureConfig(save_path=self.embedding_path),
            metadata,
        )
        saved_model.save_modules()

        loaded_model = create_model(
            DocModel.Config(
                representation=BiLSTMDocAttention.Config(
                    load_path=self.representation_path),
                decoder=MLPDecoder.Config(load_path=self.decoder_path),
            ),
            FeatureConfig(load_path=self.embedding_path),
            metadata,
        )

        random_model = create_model(
            DocModel.Config(representation=BiLSTMDocAttention.Config(),
                            decoder=MLPDecoder.Config()),
            FeatureConfig(),
            metadata,
        )

        # Loaded and saved modules should be equal. Neither should be equal to
        # a randomly initialised model.
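        # Note: itertools.zip_longest (rather than zip) makes any mismatch in
        # parameter counts raise on the None fill value instead of being
        # silently truncated.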

        for p1, p2, p3 in itertools.zip_longest(
                saved_model.embedding.parameters(),
                loaded_model.embedding.parameters(),
                random_model.embedding.parameters(),
        ):
            self.assertTrue(p1.equal(p2))
            self.assertFalse(p3.equal(p1))
            self.assertFalse(p3.equal(p2))

        for p1, p2, p3 in itertools.zip_longest(
                saved_model.representation.parameters(),
                loaded_model.representation.parameters(),
                random_model.representation.parameters(),
        ):
            self.assertTrue(p1.equal(p2))
            self.assertFalse(p3.equal(p1))
            self.assertFalse(p3.equal(p2))

        for p1, p2, p3 in itertools.zip_longest(
                saved_model.decoder.parameters(),
                loaded_model.decoder.parameters(),
                random_model.decoder.parameters(),
        ):
            self.assertTrue(p1.equal(p2))
            self.assertFalse(p3.equal(p1))
            self.assertFalse(p3.equal(p2))