    # Helper from a PyText test case (num_word_classes is unused here; compare
    # _get_metadata at the end of this section). Vocab is torchtext's, Counter
    # comes from collections; FieldMeta, CommonMetadata, DatasetFieldName and
    # the W_VOCAB / UNK_IDX / PAD_IDX / W_VOCAB_SIZE constants are defined
    # elsewhere in the original test module.
    def _get_seq_metadata(self, num_doc_classes, num_word_classes):
        labels = []
        if num_doc_classes:
            vocab = Vocab(Counter())
            vocab.itos = ["C_{}".format(i) for i in range(num_doc_classes)]
            label_meta = FieldMeta()
            label_meta.vocab_size = num_doc_classes
            label_meta.vocab = vocab
            labels.append(label_meta)

        w_vocab = Vocab(Counter())
        w_vocab.itos = W_VOCAB

        seq_feat_meta = FieldMeta()
        seq_feat_meta.unk_token_idx = UNK_IDX
        seq_feat_meta.pad_token_idx = PAD_IDX
        seq_feat_meta.vocab_size = W_VOCAB_SIZE
        seq_feat_meta.vocab = w_vocab
        seq_feat_meta.vocab_export_name = "seq_tokens_vals"
        seq_feat_meta.pretrained_embeds_weight = None
        seq_feat_meta.dummy_model_input = SeqFeatureField.dummy_model_input

        meta = CommonMetadata()
        meta.features = {DatasetFieldName.TEXT_FIELD: seq_feat_meta}
        meta.target = labels
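        # Downstream code expects a bare FieldMeta when there is exactly one
        # label field (see _create_dummy_meta_data below), so unwrap the
        # one-element list.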
        if len(labels) == 1:
            [meta.target] = meta.target
        meta.label_names = [label.vocab.itos for label in labels]
        meta.feature_itos_map = {
            f.vocab_export_name: f.vocab.itos for f in meta.features.values()
        }
        return meta
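
    # A usage sketch (hypothetical call site; SeqNNModel is an assumption,
    # while create_model and FeatureConfig appear in test_load_save below):
    #
    #   meta = self._get_seq_metadata(num_doc_classes=2, num_word_classes=0)
    #   model = create_model(SeqNNModel.Config(), FeatureConfig(), meta)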

# Example 2

    def _create_dummy_meta_data(self):
        # Minimal stand-in metadata: a 4-token text vocabulary (PAD at index 0,
        # UNK at index 1) and a 3-class label field.
        text_field_meta = FieldMeta()
        text_field_meta.vocab = VocabStub()
        text_field_meta.vocab_size = 4
        text_field_meta.unk_token_idx = 1
        text_field_meta.pad_token_idx = 0
        text_field_meta.pretrained_embeds_weight = None
        label_meta = FieldMeta()
        label_meta.vocab = VocabStub()
        label_meta.vocab_size = 3
        metadata = CommonMetadata()
        metadata.features = {DatasetFieldName.TEXT_FIELD: text_field_meta}
        metadata.target = label_meta
        return metadata
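
    # Usage sketch: the dummy metadata stands in for a real data handler when
    # constructing small models in unit tests, exactly as test_load_save does
    # below:
    #
    #   metadata = self._create_dummy_meta_data()
    #   model = create_model(DocModel.Config(), FeatureConfig(), metadata)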

# Example 3

    def test_load_save(self):
        text_field_meta = FieldMeta()
        text_field_meta.vocab = VocabStub()
        text_field_meta.vocab_size = 4
        text_field_meta.unk_token_idx = 1
        text_field_meta.pad_token_idx = 0
        text_field_meta.pretrained_embeds_weight = None
        label_meta = FieldMeta()
        label_meta.vocab = VocabStub()
        label_meta.vocab_size = 3
        metadata = CommonMetadata()
        metadata.features = {DatasetFieldName.TEXT_FIELD: text_field_meta}
        metadata.target = label_meta

        saved_model = create_model(
            DocModel.Config(
                representation=BiLSTMDocAttention.Config(
                    save_path=self.representation_path),
                decoder=MLPDecoder.Config(save_path=self.decoder_path),
            ),
            FeatureConfig(save_path=self.embedding_path),
            metadata,
        )
        saved_model.save_modules()
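        # save_modules() serializes every submodule whose config sets a
        # save_path; the load_path configs below then make create_model restore
        # those weights rather than initialize new ones.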

        loaded_model = create_model(
            DocModel.Config(
                representation=BiLSTMDocAttention.Config(
                    load_path=self.representation_path),
                decoder=MLPDecoder.Config(load_path=self.decoder_path),
            ),
            FeatureConfig(load_path=self.embedding_path),
            metadata,
        )

        random_model = create_model(
            DocModel.Config(representation=BiLSTMDocAttention.Config(),
                            decoder=MLPDecoder.Config()),
            FeatureConfig(),
            metadata,
        )

        # Loaded and saved modules should be equal. Neither should be equal to
        # a randomly initialised model.
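        # zip_longest, not zip: if two modules ever disagree on parameter
        # count, the shorter side yields None and the .equal() call fails
        # loudly instead of zip silently truncating the comparison.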

        for module_name in ("embedding", "representation", "decoder"):
            for p1, p2, p3 in itertools.zip_longest(
                    getattr(saved_model, module_name).parameters(),
                    getattr(loaded_model, module_name).parameters(),
                    getattr(random_model, module_name).parameters(),
            ):
                self.assertTrue(p1.equal(p2))
                self.assertFalse(p3.equal(p1))
                self.assertFalse(p3.equal(p2))

    def _get_metadata(self, num_doc_classes, num_word_classes):
        labels = []
        if num_doc_classes:
            vocab = Vocab(Counter())
            vocab.itos = ["C_{}".format(i) for i in range(num_doc_classes)]
            label_meta = FieldMeta()
            label_meta.vocab_size = num_doc_classes
            label_meta.vocab = vocab
            labels.append(label_meta)

        if num_word_classes:
            vocab = Vocab(Counter())
            vocab.itos = ["W_{}".format(i) for i in range(num_word_classes)]
            label_meta = FieldMeta()
            label_meta.vocab_size = num_word_classes
            label_meta.vocab = vocab
            label_meta.pad_token_idx = 0
            labels.append(label_meta)

        w_vocab = Vocab(Counter())
        dict_vocab = Vocab(Counter())
        c_vocab = Vocab(Counter())
        d_vocab = Vocab(Counter())
        w_vocab.itos = W_VOCAB
        dict_vocab.itos = DICT_VOCAB
        c_vocab.itos = CHAR_VOCAB
        d_vocab.itos = []

        text_feat_meta = FieldMeta()
        text_feat_meta.unk_token_idx = UNK_IDX
        text_feat_meta.pad_token_idx = PAD_IDX
        text_feat_meta.vocab_size = W_VOCAB_SIZE
        text_feat_meta.vocab = w_vocab
        text_feat_meta.vocab_export_name = "tokens_vals"
        text_feat_meta.pretrained_embeds_weight = None
        text_feat_meta.dummy_model_input = TextFeatureField.dummy_model_input

        dict_feat_meta = FieldMeta()
        dict_feat_meta.vocab_size = DICT_VOCAB_SIZE
        dict_feat_meta.vocab = dict_vocab
        dict_feat_meta.vocab_export_name = "dict_vals"
        dict_feat_meta.pretrained_embeds_weight = None
        dict_feat_meta.dummy_model_input = DictFeatureField.dummy_model_input

        char_feat_meta = FieldMeta()
        char_feat_meta.vocab_size = CHAR_VOCAB_SIZE
        char_feat_meta.vocab = c_vocab
        char_feat_meta.vocab_export_name = "char_vals"
        char_feat_meta.pretrained_embeds_weight = None
        char_feat_meta.dummy_model_input = CharFeatureField.dummy_model_input

        dense_feat_meta = FieldMeta()
        dense_feat_meta.vocab_size = 0
        dense_feat_meta.vocab = d_vocab
        dense_feat_meta.vocab_export_name = "dense_vals"
        dense_feat_meta.pretrained_embeds_weight = None
        # Dense features carry no vocab; the dummy input's width is fixed at
        # DENSE_FEATURE_DIM (two identical rows of ones).
        dense_feat_meta.dummy_model_input = torch.tensor(
            [[1.0] * DENSE_FEATURE_DIM, [1.0] * DENSE_FEATURE_DIM],
            dtype=torch.float,
            device="cpu",
        )

        meta = CommonMetadata()
        meta.features = {
            DatasetFieldName.TEXT_FIELD: text_feat_meta,
            DatasetFieldName.DICT_FIELD: dict_feat_meta,
            DatasetFieldName.CHAR_FIELD: char_feat_meta,
            DatasetFieldName.DENSE_FIELD: dense_feat_meta,
        }
        meta.target = labels
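        # As in _get_seq_metadata: unwrap a single label field to a bare
        # FieldMeta.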
        if len(labels) == 1:
            [meta.target] = meta.target
        meta.label_names = [label.vocab.itos for label in labels]
        meta.feature_itos_map = {
            f.vocab_export_name: f.vocab.itos for f in meta.features.values()
        }
        return meta
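
    # Usage sketch (hypothetical; whether a given model consumes all four
    # feature types depends on its FeatureConfig, which is an assumption here):
    #
    #   meta = self._get_metadata(num_doc_classes=2, num_word_classes=0)
    #   model = create_model(DocModel.Config(), FeatureConfig(), meta)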