def example_config(cls):
    """Return a sample task config: doc + word labels over a bagging
    ensemble that wraps a single deprecated doc model."""
    label_configs = [DocLabelConfig(), WordLabelConfig()]
    ensemble = BaggingDocEnsemble_Deprecated.Config(
        models=[DocModel_Deprecated.Config()]
    )
    return cls.Config(labels=label_configs, model=ensemble)
def test_freeze_all_embedding(self):
    """Freezing the whole feature config must disable gradients on every
    embedding parameter of the created model."""
    frozen_features = FeatureConfig(freeze=True)
    model = create_model(
        DocModel_Deprecated.Config(),
        frozen_features,
        metadata=mock_metadata(),
    )
    for weight in model.embedding.parameters():
        self.assertFalse(weight.requires_grad)
class Config(Task_Deprecated.Config):
    """Task configuration wiring the deprecated doc model to its trainer,
    input/label configs, data handler, metric reporter, and an optional
    dense-feature exporter.

    NOTE: field declaration order is preserved from the original — config
    frameworks may rely on it.
    """

    model: DocModel_Deprecated.Config = DocModel_Deprecated.Config()
    trainer: Trainer.Config = Trainer.Config()
    features: DocClassification.ModelInputConfig = (
        DocClassification.ModelInputConfig()
    )
    labels: DocClassification.TargetConfig = (
        DocClassification.TargetConfig()
    )
    data_handler: DocClassificationDataHandler.Config = (
        DocClassificationDataHandler.Config()
    )
    metric_reporter: ClassificationMetricReporter.Config = (
        ClassificationMetricReporter.Config()
    )
    # No exporter by default.
    exporter: Optional[DenseFeatureExporter.Config] = None
def _create_dummy_model(self):
    """Build a model whose sub-modules are configured to save their state
    to this test's scratch paths."""
    representation_cfg = BiLSTMDocAttention.Config(
        save_path=self.representation_path
    )
    decoder_cfg = MLPDecoder.Config(save_path=self.decoder_path)
    feature_cfg = FeatureConfig(
        word_feat=WordEmbedding.Config(
            embed_dim=300, save_path=self.word_embedding_path
        ),
        save_path=self.embedding_path,
    )
    model_cfg = DocModel_Deprecated.Config(
        representation=representation_cfg, decoder=decoder_cfg
    )
    return create_model(model_cfg, feature_cfg, self._create_dummy_meta_data())
def test_freeze_word_embedding(self):
    """Freezing only the word feature must stop gradients on the word
    embedding while leaving its MLP and the dict embedding trainable."""
    model = create_model(
        DocModel_Deprecated.Config(),
        FeatureConfig(
            word_feat=WordFeatConfig(freeze=True, mlp_layer_dims=[4]),
            dict_feat=DictFeatConfig(),
        ),
        metadata=mock_metadata(),
    )
    word_module = model.embedding[0]
    # The frozen word embedding itself must not require gradients ...
    for weight in word_module.word_embedding.parameters():
        self.assertFalse(weight.requires_grad)
    # ... but its trailing MLP layers stay trainable.
    for weight in word_module.mlp.parameters():
        self.assertTrue(weight.requires_grad)
    # The dict feature embedding is untouched by the word-level freeze.
    for weight in model.embedding[1].parameters():
        self.assertTrue(weight.requires_grad)
def test_load_save(self):
    """Modules saved via save_modules() must load back with identical
    parameters, and both must differ from a freshly initialised model.

    Defect fixed: the saved/loaded/random parameter-comparison loop was
    copy-pasted three times (embedding, representation, decoder); it is
    now a single private helper applied per module name.
    """
    # Minimal metadata: a text field with a 4-token vocab and a 3-class target.
    text_field_meta = FieldMeta()
    text_field_meta.vocab = VocabStub()
    text_field_meta.vocab_size = 4
    text_field_meta.unk_token_idx = 1
    text_field_meta.pad_token_idx = 0
    text_field_meta.pretrained_embeds_weight = None
    label_meta = FieldMeta()
    label_meta.vocab = VocabStub()
    label_meta.vocab_size = 3
    metadata = CommonMetadata()
    metadata.features = {DatasetFieldName.TEXT_FIELD: text_field_meta}
    metadata.target = label_meta

    saved_model = create_model(
        DocModel.Config(
            representation=BiLSTMDocAttention.Config(
                save_path=self.representation_path
            ),
            decoder=MLPDecoder.Config(save_path=self.decoder_path),
        ),
        FeatureConfig(save_path=self.embedding_path),
        metadata,
    )
    saved_model.save_modules()

    loaded_model = create_model(
        DocModel.Config(
            representation=BiLSTMDocAttention.Config(
                load_path=self.representation_path
            ),
            decoder=MLPDecoder.Config(load_path=self.decoder_path),
        ),
        FeatureConfig(load_path=self.embedding_path),
        metadata,
    )

    random_model = create_model(
        DocModel.Config(
            representation=BiLSTMDocAttention.Config(),
            decoder=MLPDecoder.Config(),
        ),
        FeatureConfig(),
        metadata,
    )

    # Loaded and saved modules should be equal. Neither should be equal to
    # a randomly initialised model.
    for module_name in ("embedding", "representation", "decoder"):
        self._check_saved_loaded_random(
            getattr(saved_model, module_name),
            getattr(loaded_model, module_name),
            getattr(random_model, module_name),
        )

def _check_saved_loaded_random(self, saved, loaded, random_init):
    """Assert that saved/loaded submodules have pairwise-equal parameters
    and that both differ from the randomly initialised submodule.

    zip_longest (rather than zip) also catches a parameter-count mismatch
    by raising on the None fill value.
    """
    for saved_p, loaded_p, random_p in itertools.zip_longest(
        saved.parameters(), loaded.parameters(), random_init.parameters()
    ):
        self.assertTrue(saved_p.equal(loaded_p))
        self.assertFalse(random_p.equal(saved_p))
        self.assertFalse(random_p.equal(loaded_p))