class Config(BasePairwiseModel.Config):
    """
    Configuration for a pairwise (two-text-column) model.

    Attributes:
        encode_relations (bool): if `false`, return the concatenation of the
            two representations; if `true`, also concatenate their pairwise
            absolute difference and pairwise elementwise product
            (à la arXiv:1705.02364). Default: `true`.
            NOTE(review): not declared in this class — presumably inherited
            from `BasePairwiseModel.Config`; confirm.
        shared_representations (bool): whether to use the same representation,
            with tied weights, for all the input subrepresentations.
            Default: `true`.
    """

    class ModelInput(BasePairwiseModel.Config.ModelInput):
        # One tensorizer per input text column.
        tokens1: TokenTensorizer.Config = TokenTensorizer.Config(column="text1")
        tokens2: TokenTensorizer.Config = TokenTensorizer.Config(column="text2")
        labels: LabelTensorizer.Config = LabelTensorizer.Config()
        # for metric reporter: joins both text columns into a single raw string
        raw_text: JoinStringTensorizer.Config = JoinStringTensorizer.Config(
            columns=["text1", "text2"]
        )

    inputs: ModelInput = ModelInput()
    embedding: WordEmbedding.Config = WordEmbedding.Config()
    # Document representation; defaults to BiLSTM with document attention.
    representation: Union[
        BiLSTMDocAttention.Config, DocNNRepresentation.Config
    ] = BiLSTMDocAttention.Config()
    shared_representations: bool = True
class Config(ConfigBase):
    """Bundles the representation, decoder, and output-layer sub-configs
    for a document classification component."""

    # Document representation; defaults to BiLSTM with document attention.
    representation: Union[
        PureDocAttention.Config,
        BiLSTMDocAttention.Config,
        DocNNRepresentation.Config,
    ] = BiLSTMDocAttention.Config()
    decoder: MLPDecoder.Config = MLPDecoder.Config()
    output_layer: ClassificationOutputLayer.Config = (
        ClassificationOutputLayer.Config()
    )
class Config(Model.Config):
    """Full model config: input tensorizers plus embedding, representation,
    decoder, and classification output layer."""

    class ModelInput(Model.Config.ModelInput):
        # Single text column and its labels.
        tokens: TokenTensorizer.Config = TokenTensorizer.Config()
        labels: LabelTensorizer.Config = LabelTensorizer.Config()

    inputs: ModelInput = ModelInput()
    embedding: WordEmbedding.Config = WordEmbedding.Config()
    # Document representation; defaults to BiLSTM with document attention.
    representation: Union[
        PureDocAttention.Config,
        BiLSTMDocAttention.Config,
        DocNNRepresentation.Config,
    ] = BiLSTMDocAttention.Config()
    decoder: MLPDecoder.Config = MLPDecoder.Config()
    output_layer: ClassificationOutputLayer.Config = (
        ClassificationOutputLayer.Config()
    )
def _create_dummy_model(self):
    """Build a DocModel_Deprecated whose modules are configured to save to
    this test's fixture paths."""
    # Name the sub-configs instead of nesting everything in one call.
    model_config = DocModel_Deprecated.Config(
        representation=BiLSTMDocAttention.Config(
            save_path=self.representation_path
        ),
        decoder=MLPDecoder.Config(save_path=self.decoder_path),
    )
    feature_config = FeatureConfig(
        word_feat=WordEmbedding.Config(
            embed_dim=300, save_path=self.word_embedding_path
        ),
        save_path=self.embedding_path,
    )
    return create_model(
        model_config, feature_config, self._create_dummy_meta_data()
    )
def test_load_save(self):
    """Save a model's modules to disk, reload them into a second model, and
    verify that saved and loaded parameters match each other while both
    differ from a freshly (randomly) initialised model.

    Fix: the parameter-comparison loop was copy-pasted verbatim three times
    (embedding / representation / decoder); it is now a single loop over the
    module names.
    """
    # Minimal metadata stub: tiny text vocab plus a 3-way label space.
    text_field_meta = FieldMeta()
    text_field_meta.vocab = VocabStub()
    text_field_meta.vocab_size = 4
    text_field_meta.unk_token_idx = 1
    text_field_meta.pad_token_idx = 0
    text_field_meta.pretrained_embeds_weight = None
    label_meta = FieldMeta()
    label_meta.vocab = VocabStub()
    label_meta.vocab_size = 3
    metadata = CommonMetadata()
    metadata.features = {DatasetFieldName.TEXT_FIELD: text_field_meta}
    metadata.target = label_meta

    saved_model = create_model(
        DocModel.Config(
            representation=BiLSTMDocAttention.Config(
                save_path=self.representation_path
            ),
            decoder=MLPDecoder.Config(save_path=self.decoder_path),
        ),
        FeatureConfig(save_path=self.embedding_path),
        metadata,
    )
    saved_model.save_modules()

    loaded_model = create_model(
        DocModel.Config(
            representation=BiLSTMDocAttention.Config(
                load_path=self.representation_path
            ),
            decoder=MLPDecoder.Config(load_path=self.decoder_path),
        ),
        FeatureConfig(load_path=self.embedding_path),
        metadata,
    )

    random_model = create_model(
        DocModel.Config(
            representation=BiLSTMDocAttention.Config(),
            decoder=MLPDecoder.Config(),
        ),
        FeatureConfig(),
        metadata,
    )

    # Loaded and saved modules should be equal. Neither should be equal to
    # a randomly initialised model. zip_longest ensures a length mismatch
    # surfaces as a None comparison failure rather than silent truncation.
    for module_name in ("embedding", "representation", "decoder"):
        for p1, p2, p3 in itertools.zip_longest(
            getattr(saved_model, module_name).parameters(),
            getattr(loaded_model, module_name).parameters(),
            getattr(random_model, module_name).parameters(),
        ):
            self.assertTrue(p1.equal(p2))
            self.assertFalse(p3.equal(p1))
            self.assertFalse(p3.equal(p2))