예제 #1
0
    class Config(BasePairwiseModel.Config):
        """
        Configuration for a pairwise model: two text inputs (columns `text1`
        and `text2`) plus labels, an embedding, and a representation.

        Attributes:
            encode_relations (bool): if `false`, return the concatenation of the two
                representations; if `true`, also concatenate their pairwise absolute
                difference and pairwise elementwise product (à la arXiv:1705.02364).
                Default: `true`. NOTE(review): not declared in this class;
                presumably inherited from `BasePairwiseModel.Config` -- confirm.
            inputs (ModelInput): tensorizer configs for the model inputs.
            embedding (WordEmbedding.Config): word embedding configuration.
            representation: either a `BiLSTMDocAttention.Config` or a
                `DocNNRepresentation.Config`. Default: `BiLSTMDocAttention.Config()`.
            shared_representations (bool): whether to use the same representation,
                with tied weights, for all the input subrepresentations.
                Default: `true`.
        """

        class ModelInput(BasePairwiseModel.Config.ModelInput):
            """Tensorizer configs: one token tensorizer per text column, plus labels."""

            # Each side of the pair is tokenized from its own column.
            tokens1: TokenTensorizer.Config = TokenTensorizer.Config(column="text1")
            tokens2: TokenTensorizer.Config = TokenTensorizer.Config(column="text2")
            labels: LabelTensorizer.Config = LabelTensorizer.Config()
            # for metric reporter
            raw_text: JoinStringTensorizer.Config = JoinStringTensorizer.Config(
                columns=["text1", "text2"]
            )

        inputs: ModelInput = ModelInput()
        embedding: WordEmbedding.Config = WordEmbedding.Config()
        representation: Union[
            BiLSTMDocAttention.Config, DocNNRepresentation.Config
        ] = BiLSTMDocAttention.Config()
        # Documented above; the docstring previously called this `tied_representation`.
        shared_representations: bool = True
예제 #2
0
 class Config(ConfigBase):
     """Configuration declaring representation, decoder and output-layer sub-configs.

     Attributes:
         representation: one of `PureDocAttention.Config`,
             `BiLSTMDocAttention.Config` or `DocNNRepresentation.Config`.
             Default: `BiLSTMDocAttention.Config()`.
         decoder: `MLPDecoder.Config`. Default: `MLPDecoder.Config()`.
         output_layer: `ClassificationOutputLayer.Config`.
             Default: `ClassificationOutputLayer.Config()`.
     """

     representation: Union[
         PureDocAttention.Config,
         BiLSTMDocAttention.Config,
         DocNNRepresentation.Config,
     ] = BiLSTMDocAttention.Config()
     decoder: MLPDecoder.Config = MLPDecoder.Config()
     output_layer: ClassificationOutputLayer.Config = ClassificationOutputLayer.Config()
예제 #3
0
    class Config(Model.Config):
        """Configuration declaring inputs, embedding, representation, decoder
        and output-layer sub-configs.

        Attributes:
            inputs: tensorizer configs for the model inputs (tokens and labels).
            embedding: `WordEmbedding.Config`. Default: `WordEmbedding.Config()`.
            representation: one of `PureDocAttention.Config`,
                `BiLSTMDocAttention.Config` or `DocNNRepresentation.Config`.
                Default: `BiLSTMDocAttention.Config()`.
            decoder: `MLPDecoder.Config`. Default: `MLPDecoder.Config()`.
            output_layer: `ClassificationOutputLayer.Config`.
                Default: `ClassificationOutputLayer.Config()`.
        """

        class ModelInput(Model.Config.ModelInput):
            """Tensorizer configs for the token input and the labels."""

            tokens: TokenTensorizer.Config = TokenTensorizer.Config()
            labels: LabelTensorizer.Config = LabelTensorizer.Config()

        inputs: ModelInput = ModelInput()
        embedding: WordEmbedding.Config = WordEmbedding.Config()
        representation: Union[
            PureDocAttention.Config,
            BiLSTMDocAttention.Config,
            DocNNRepresentation.Config,
        ] = BiLSTMDocAttention.Config()
        decoder: MLPDecoder.Config = MLPDecoder.Config()
        output_layer: ClassificationOutputLayer.Config = ClassificationOutputLayer.Config()
예제 #4
0
 def _create_dummy_model(self):
     """Create a model from a `DocModel_Deprecated` config with save paths set.

     The representation, decoder, word embedding and feature config each
     receive the corresponding ``self.*_path`` attribute as ``save_path``.
     The word embedding is configured with ``embed_dim=300``.

     Returns:
         Whatever ``create_model`` returns for the assembled model config,
         feature config and the metadata from ``_create_dummy_meta_data()``.
     """
     # Sub-configs are built in the same order the nested call evaluated them.
     model_config = DocModel_Deprecated.Config(
         representation=BiLSTMDocAttention.Config(
             save_path=self.representation_path
         ),
         decoder=MLPDecoder.Config(save_path=self.decoder_path),
     )
     word_feature = WordEmbedding.Config(
         embed_dim=300, save_path=self.word_embedding_path
     )
     feature_config = FeatureConfig(
         word_feat=word_feature, save_path=self.embedding_path
     )
     dummy_metadata = self._create_dummy_meta_data()
     return create_model(model_config, feature_config, dummy_metadata)
예제 #5
0
    def test_load_save(self):
        """Check that module parameters survive a save/load round trip.

        Three models are built from the same metadata: one configured with
        save paths (and saved via ``save_modules()``), one configured with the
        matching load paths, and one with default configs. For the embedding,
        representation and decoder modules, the saved and loaded parameters
        must be equal, and neither may equal the default-config model's.
        """
        # Stub metadata for the text feature and the label target.
        text_meta = FieldMeta()
        text_meta.vocab = VocabStub()
        text_meta.vocab_size = 4
        text_meta.unk_token_idx = 1
        text_meta.pad_token_idx = 0
        text_meta.pretrained_embeds_weight = None

        target_meta = FieldMeta()
        target_meta.vocab = VocabStub()
        target_meta.vocab_size = 3

        metadata = CommonMetadata()
        metadata.features = {DatasetFieldName.TEXT_FIELD: text_meta}
        metadata.target = target_meta

        # Model whose modules are written out to the configured paths.
        saving_config = DocModel.Config(
            representation=BiLSTMDocAttention.Config(
                save_path=self.representation_path
            ),
            decoder=MLPDecoder.Config(save_path=self.decoder_path),
        )
        saved_model = create_model(
            saving_config, FeatureConfig(save_path=self.embedding_path), metadata
        )
        saved_model.save_modules()

        # Model whose modules are read back from those same paths.
        loading_config = DocModel.Config(
            representation=BiLSTMDocAttention.Config(
                load_path=self.representation_path
            ),
            decoder=MLPDecoder.Config(load_path=self.decoder_path),
        )
        loaded_model = create_model(
            loading_config, FeatureConfig(load_path=self.embedding_path), metadata
        )

        # Baseline model with default configs and no load path.
        baseline_config = DocModel.Config(
            representation=BiLSTMDocAttention.Config(), decoder=MLPDecoder.Config()
        )
        random_model = create_model(baseline_config, FeatureConfig(), metadata)

        # Loaded and saved modules should be equal. Neither should be equal to
        # a randomly initialised model.
        for module_name in ("embedding", "representation", "decoder"):
            parameter_triples = itertools.zip_longest(
                getattr(saved_model, module_name).parameters(),
                getattr(loaded_model, module_name).parameters(),
                getattr(random_model, module_name).parameters(),
            )
            for saved_param, loaded_param, random_param in parameter_triples:
                self.assertTrue(saved_param.equal(loaded_param))
                self.assertFalse(random_param.equal(saved_param))
                self.assertFalse(random_param.equal(loaded_param))