def _get_seq_metadata(self, num_doc_classes, num_word_classes):
    """Build a ``CommonMetadata`` holding a single sequence-token text feature.

    Args:
        num_doc_classes: Number of document-level classes. When truthy, one
            label meta with ``itos`` ``["C_0", "C_1", ...]`` is added.
        num_word_classes: Accepted for signature parity with
            ``_get_metadata``; it is not read here, so no word-level label
            is ever created.

    Returns:
        A ``CommonMetadata`` whose ``features`` maps the text field to the
        sequence-token meta. ``target`` is the label meta itself when exactly
        one label exists, otherwise the (possibly empty) list of label metas.
    """
    label_metas = []
    if num_doc_classes:
        doc_vocab = Vocab(Counter())
        doc_vocab.itos = ["C_{}".format(idx) for idx in range(num_doc_classes)]
        doc_label = FieldMeta()
        doc_label.vocab_size = num_doc_classes
        doc_label.vocab = doc_vocab
        label_metas.append(doc_label)

    token_vocab = Vocab(Counter())
    token_vocab.itos = W_VOCAB

    seq_tokens = FieldMeta()
    seq_tokens.unk_token_idx = UNK_IDX
    seq_tokens.pad_token_idx = PAD_IDX
    seq_tokens.vocab_size = W_VOCAB_SIZE
    seq_tokens.vocab = token_vocab
    seq_tokens.vocab_export_name = "seq_tokens_vals"
    seq_tokens.pretrained_embeds_weight = None
    seq_tokens.dummy_model_input = SeqFeatureField.dummy_model_input

    meta = CommonMetadata()
    meta.features = {DatasetFieldName.TEXT_FIELD: seq_tokens}
    # A lone label is stored unwrapped; zero or several labels stay a list.
    meta.target = label_metas[0] if len(label_metas) == 1 else label_metas
    meta.label_names = [label.vocab.itos for label in label_metas]
    meta.feature_itos_map = {
        field.vocab_export_name: field.vocab.itos
        for field in meta.features.values()
    }
    return meta
def mock_metadata():
    """Return a minimal ``CommonMetadata`` instance for model tests.

    A single ``FieldMeta`` (``VocabStub`` vocab, vocab size 10, unk index 0,
    pad index 1, no pretrained embeddings) is shared by the ``word_feat`` and
    ``dict_feat`` features and is also used as the target.
    """
    # Instantiate rather than binding the bare class. Otherwise the
    # assignments below would set class attributes shared by every
    # CommonMetadata object.
    meta = CommonMetadata()
    field_meta = FieldMeta()
    field_meta.vocab = VocabStub()
    field_meta.vocab_size = 10
    field_meta.pretrained_embeds_weight = None
    field_meta.unk_token_idx = 0
    field_meta.pad_token_idx = 1
    meta.features = {"word_feat": field_meta, "dict_feat": field_meta}
    meta.target = field_meta
    return meta
def _create_dummy_meta_data(self):
    """Build a small ``CommonMetadata`` with one text feature and one label.

    The text feature uses a ``VocabStub`` vocab of size 4 with pad index 0,
    unk index 1 and no pretrained embeddings. The target label uses a
    ``VocabStub`` vocab of size 3.
    """

    def field_meta(vocab_size, **extra_attrs):
        # Every meta here gets a fresh VocabStub; extras are applied in order.
        meta_obj = FieldMeta()
        meta_obj.vocab = VocabStub()
        meta_obj.vocab_size = vocab_size
        for attr_name, attr_value in extra_attrs.items():
            setattr(meta_obj, attr_name, attr_value)
        return meta_obj

    text_meta = field_meta(
        4,
        unk_token_idx=1,
        pad_token_idx=0,
        pretrained_embeds_weight=None,
    )
    target_meta = field_meta(3)

    metadata = CommonMetadata()
    metadata.features = {DatasetFieldName.TEXT_FIELD: text_meta}
    metadata.target = target_meta
    return metadata
def test_load_save(self):
    """Saved and reloaded modules match; a freshly initialised model differs.

    Three ``DocModel``s are built from the same metadata:

    1. One configured with ``save_path``s, which then calls ``save_modules``.
    2. One configured with matching ``load_path``s.
    3. One with default configs.

    The embedding, representation and decoder parameters are then compared
    module by module.
    """
    text_meta = FieldMeta()
    text_meta.vocab = VocabStub()
    text_meta.vocab_size = 4
    text_meta.unk_token_idx = 1
    text_meta.pad_token_idx = 0
    text_meta.pretrained_embeds_weight = None

    label_meta = FieldMeta()
    label_meta.vocab = VocabStub()
    label_meta.vocab_size = 3

    metadata = CommonMetadata()
    metadata.features = {DatasetFieldName.TEXT_FIELD: text_meta}
    metadata.target = label_meta

    def build(rep_kwargs, dec_kwargs, emb_kwargs):
        # All three models share this layout; only save/load paths differ.
        return create_model(
            DocModel.Config(
                representation=BiLSTMDocAttention.Config(**rep_kwargs),
                decoder=MLPDecoder.Config(**dec_kwargs),
            ),
            FeatureConfig(**emb_kwargs),
            metadata,
        )

    saved_model = build(
        {"save_path": self.representation_path},
        {"save_path": self.decoder_path},
        {"save_path": self.embedding_path},
    )
    saved_model.save_modules()
    loaded_model = build(
        {"load_path": self.representation_path},
        {"load_path": self.decoder_path},
        {"load_path": self.embedding_path},
    )
    random_model = build({}, {}, {})

    # Loaded and saved modules should be equal. Neither should be equal to
    # a randomly initialised model. zip_longest pads a shorter parameter
    # list with None, so a count mismatch fails on the `.equal` call.
    for module_name in ("embedding", "representation", "decoder"):
        for saved_p, loaded_p, random_p in itertools.zip_longest(
            getattr(saved_model, module_name).parameters(),
            getattr(loaded_model, module_name).parameters(),
            getattr(random_model, module_name).parameters(),
        ):
            self.assertTrue(saved_p.equal(loaded_p))
            self.assertFalse(random_p.equal(saved_p))
            self.assertFalse(random_p.equal(loaded_p))
def _get_metadata(self, num_doc_classes, num_word_classes):
    """Build a ``CommonMetadata`` with text, dict, char and dense features.

    Args:
        num_doc_classes: Number of document-level classes. When truthy, a
            label meta with ``itos`` ``["C_0", "C_1", ...]`` is added.
        num_word_classes: Number of word-level classes. When truthy, a label
            meta with ``itos`` ``["W_0", "W_1", ...]`` and pad index 0 is
            added after the document label.

    Returns:
        A ``CommonMetadata``. ``target`` is the label meta itself when exactly
        one label exists, otherwise the (possibly empty) list of label metas.
    """

    def make_label(count, name_fmt):
        # Label vocab is just the formatted class names, in index order.
        label_vocab = Vocab(Counter())
        label_vocab.itos = [name_fmt.format(idx) for idx in range(count)]
        label = FieldMeta()
        label.vocab_size = count
        label.vocab = label_vocab
        return label

    def make_feature(vocab_size, vocab, export_name, dummy_input):
        # Shared shape of every feature meta: no pretrained embeddings.
        feature = FieldMeta()
        feature.vocab_size = vocab_size
        feature.vocab = vocab
        feature.vocab_export_name = export_name
        feature.pretrained_embeds_weight = None
        feature.dummy_model_input = dummy_input
        return feature

    labels = []
    if num_doc_classes:
        labels.append(make_label(num_doc_classes, "C_{}"))
    if num_word_classes:
        word_label = make_label(num_word_classes, "W_{}")
        word_label.pad_token_idx = 0
        labels.append(word_label)

    word_vocab = Vocab(Counter())
    dict_vocab = Vocab(Counter())
    char_vocab = Vocab(Counter())
    dense_vocab = Vocab(Counter())
    word_vocab.itos = W_VOCAB
    dict_vocab.itos = DICT_VOCAB
    char_vocab.itos = CHAR_VOCAB
    dense_vocab.itos = []

    text_meta = make_feature(
        W_VOCAB_SIZE, word_vocab, "tokens_vals", TextFeatureField.dummy_model_input
    )
    text_meta.unk_token_idx = UNK_IDX
    text_meta.pad_token_idx = PAD_IDX
    dict_meta = make_feature(
        DICT_VOCAB_SIZE, dict_vocab, "dict_vals", DictFeatureField.dummy_model_input
    )
    char_meta = make_feature(
        CHAR_VOCAB_SIZE, char_vocab, "char_vals", CharFeatureField.dummy_model_input
    )
    # The dummy dense input has a fixed shape: two rows of DENSE_FEATURE_DIM
    # ones, so it only matches models configured with that dimension.
    dense_meta = make_feature(
        0,
        dense_vocab,
        "dense_vals",
        torch.tensor(
            [[1.0] * DENSE_FEATURE_DIM for _ in range(2)],
            dtype=torch.float,
            device="cpu",
        ),
    )

    meta = CommonMetadata()
    meta.features = {
        DatasetFieldName.TEXT_FIELD: text_meta,
        DatasetFieldName.DICT_FIELD: dict_meta,
        DatasetFieldName.CHAR_FIELD: char_meta,
        DatasetFieldName.DENSE_FIELD: dense_meta,
    }
    # A lone label is stored unwrapped; zero or several labels stay a list.
    meta.target = labels[0] if len(labels) == 1 else labels
    meta.label_names = [label.vocab.itos for label in labels]
    meta.feature_itos_map = {
        feature.vocab_export_name: feature.vocab.itos
        for feature in meta.features.values()
    }
    return meta