Example #1
    def test_payload(self):
        self.args.data_config.ent_embeddings.extend(
            [DottedDict({"key": "kg_emb"}), DottedDict({"key": "kg_rel"})]
        )
        emb_payload = EmbeddingPayloadMock(self.args, self.entity_symbols)
        emb_payload.linear_layers = nn.ModuleDict()
        emb_payload.linear_layers["project_embedding"] = NoopCat()
        emb_payload.position_enc = nn.ModuleDict()
        emb_payload.position_enc["alias"] = Noop()
        emb_payload.position_enc["alias_last_token"] = Noop()
        emb_payload.position_enc["alias_position_cat"] = NoopTakeFirst()

        batch_size = 3
        # Max aliases is set to 5 in the config so we need one position per alias per batch
        start_idx_pair = torch.zeros(batch_size, emb_payload.M)
        end_idx_pair = torch.zeros(batch_size, emb_payload.M)

        entity_embedding = {
            "kg_emb": torch.randn(batch_size, emb_payload.M, emb_payload.K, 5),
            "kg_rel": torch.randn(batch_size, emb_payload.M, emb_payload.K, 6),
        }
        alias_list = emb_payload(
            start_idx_pair,
            end_idx_pair,
            entity_embedding["kg_emb"],
            entity_embedding["kg_rel"],
        )
        # The embeddings have been concatenated together
        assert torch.isclose(
            alias_list,
            torch.cat([entity_embedding["kg_emb"], entity_embedding["kg_rel"]], dim=3),
        ).all()
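
For reference, here is a plausible minimal sketch of the Noop helper modules used in this test and the next one, inferred from how they are called; the actual fixtures in the test suite may differ.

import torch
import torch.nn as nn

class Noop(nn.Module):
    # Stand-in for a positional encoder: returns its first input unchanged,
    # ignoring the position indices passed as extra arguments.
    def forward(self, tensor, *args, **kwargs):
        return tensor

class NoopCat(nn.Module):
    # Stand-in for the projection layer: concatenates the per-embedding tensors
    # along the embedding dimension (dim=3 for the batch x M x K x dim tensors here),
    # which is exactly what the assertion above checks.
    def forward(self, tensor_list):
        return torch.cat(tensor_list, dim=3)

class NoopTakeFirst(nn.Module):
    # Stand-in for the position concatenation: keeps only the first element of the list.
    def forward(self, tensor_list):
        return tensor_list[0]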
Example #2
    def test_kg_norm(self):
        emb_sizes = {"kg_emb": 5, "kg_rel": 6}
        sent_emb_size = 10
        self.args.data_config.ent_embeddings.extend(
            [DottedDict({"key": "kg_emb"}),
             DottedDict({"key": "kg_rel"})])
        emb_combiner = EmbCombinerProj(self.args, emb_sizes, sent_emb_size,
                                       self.word_symbols, self.entity_symbols)
        emb_combiner.linear_layers = nn.ModuleDict()
        emb_combiner.linear_layers['project_embedding'] = NoopCat()
        emb_combiner.position_enc = nn.ModuleDict()
        emb_combiner.position_enc['alias'] = Noop()
        emb_combiner.position_enc['alias_last_token'] = Noop()
        emb_combiner.position_enc['alias_position_cat'] = NoopTakeFirst()
        batch_size = 3
        num_words = 5
        sent_embedding = DottedDict(tensor=torch.randn(batch_size, num_words,
                                                       sent_emb_size),
                                    downstream_mask=None,
                                    mask=None,
                                    key="sent_emb",
                                    dim=10)
        # Max aliases is set to 5 in the config so we need one position per alias per batch
        alias_idx_pair_sent = [
            torch.tensor([[0] * emb_combiner.M, [0] * emb_combiner.M]),
            torch.tensor([[0] * emb_combiner.M, [0] * emb_combiner.M])
        ]

        entity_embedding = [
            DottedDict(tensor=torch.randn(batch_size, emb_combiner.M,
                                          emb_combiner.K, 5),
                       pos_in_sent=get_pos_in_sent(),
                       alias_indices=None,
                       mask=None,
                       normalize=True,
                       key="kg_emb",
                       dim=5),
            DottedDict(tensor=torch.randn(batch_size, emb_combiner.M,
                                          emb_combiner.K, 6),
                       pos_in_sent=get_pos_in_sent(),
                       alias_indices=None,
                       mask=None,
                       normalize=False,
                       key="kg_rel",
                       dim=6)
        ]
        norm_1 = entity_embedding[0].tensor.norm(p=2, dim=3)
        norm_2 = entity_embedding[1].tensor.norm(p=2, dim=3)
        entity_mask = None
        _, res_package = emb_combiner(sent_embedding, alias_idx_pair_sent,
                                      entity_embedding, entity_mask)
        alias_list = res_package.tensor
        # The embeddings have been normalized and concatenated together
        assert torch.isclose(alias_list[:, :, :, :5].norm(p=2, dim=3),
                             torch.ones_like(norm_1)).all()
        assert torch.isclose(alias_list[:, :, :, 5:].norm(p=2, dim=3),
                             norm_2).all()
Example #3
 def setUp(self) -> None:
     self.args = parser_utils.get_full_config(
         "test/run_args/test_embeddings.json")
     self.args.data_config.ent_embeddings = [
         DottedDict(
         {
             "key": "learned1",
             "load_class": "LearnedEntityEmb",
             "args": {
                 "learned_embedding_size": 5,
                 "tail_init": False
             }
         }),
         DottedDict(
         {
             "key": "learned2",
             "load_class": "LearnedEntityEmb",
             "args": {
                 "learned_embedding_size": 5,
                 "tail_init": False
             }
         }),
         DottedDict(
         {
             "key": "learned3",
             "load_class": "LearnedEntityEmb",
             "args": {
                 "learned_embedding_size": 5,
                 "tail_init": False
             }
         }),
         DottedDict(
         {
             "key": "learned4",
             "load_class": "LearnedEntityEmb",
             "args": {
                 "learned_embedding_size": 5,
                 "tail_init": False
             }
         }),
         DottedDict(
         {
             "key": "learned5",
             "load_class": "LearnedEntityEmb",
             "args": {
                 "learned_embedding_size": 5,
                 "tail_init": False
             }
         }),
     ]
     self.word_symbols = data_utils.load_wordsymbols(self.args.data_config)
     self.entity_symbols = EntitySymbolsSubclass()
Example #4
 def setUp(self) -> None:
     self.args = parser_utils.parse_boot_and_emm_args(
         "test/run_args/test_embeddings.json"
     )
     emmental.init(log_dir="test/temp_log", config=self.args)
     if not os.path.exists(emmental.Meta.log_path):
         os.makedirs(emmental.Meta.log_path)
     self.args.data_config.ent_embeddings = [
         DottedDict(
             {
                 "key": "learned1",
                 "load_class": "LearnedEntityEmb",
                 "dropout1d": 0.5,
                 "args": {"learned_embedding_size": 5, "tail_init": False},
             }
         ),
         DottedDict(
             {
                 "key": "learned2",
                 "dropout2d": 0.5,
                 "load_class": "LearnedEntityEmb",
                 "args": {"learned_embedding_size": 5, "tail_init": False},
             }
         ),
         DottedDict(
             {
                 "key": "learned3",
                 "load_class": "LearnedEntityEmb",
                 "freeze": True,
                 "args": {"learned_embedding_size": 5, "tail_init": False},
             }
         ),
         DottedDict(
             {
                 "key": "learned4",
                 "load_class": "LearnedEntityEmb",
                 "normalize": False,
                 "args": {"learned_embedding_size": 5, "tail_init": False},
             }
         ),
         DottedDict(
             {
                 "key": "learned5",
                 "load_class": "LearnedEntityEmb",
                 "cpu": True,
                 "args": {"learned_embedding_size": 5, "tail_init": False},
             }
         ),
     ]
     self.tokenizer = load_tokenizer()
     self.entity_symbols = EntitySymbolsSubclass()
Example #5
    def forward(self, word_indices, requires_grad=None):
        if requires_grad is None:
            requires_grad = self.requires_grad
        (batch_size, seq_length) = word_indices.shape
        # Map padded/unknown indices (< 0) to the last row, num_words_with_pad_unk - 1 (indexing starts at 0)
        word_indices_pos = torch.where(
            word_indices >= 0, word_indices,
            (torch.ones_like(word_indices, dtype=torch.long) *
             (self.num_words_with_pad_unk - 1)))

        if requires_grad:
            word_vectors = self.word_embedding(word_indices_pos)
        else:
            with torch.no_grad():
                word_vectors = self.word_embedding(word_indices_pos)
        if self.use_proj:
            word_vectors = self.proj(word_vectors)
        word_vectors = self.position_words(
            word_vectors,
            torch.arange(0, word_indices.shape[1]).repeat(batch_size, 1))
        word_vectors = self.layer_norm(word_vectors)
        word_vectors = self.dropout(word_vectors)
        packed_emb = DottedDict(
            tensor=word_vectors,
            # The mask for the standard word embedding is the same as the downstream mask
            mask=self.get_downstream_mask(word_indices),
            downstream_mask=self.get_downstream_mask(word_indices),
            key=self.get_key(),
            dim=self.get_dim())
        return packed_emb
Example #6
    def test_end2end_withtitle_accstep(self):
        self.args.data_config.ent_embeddings.append(
            DottedDict({
                "key": "title1",
                "load_class": "TitleEmb",
                "send_through_bert": True,
                "args": {
                    "proj": 6
                },
            }))
        # Just setting this for testing pipelines
        self.args.data_config.eval_accumulation_steps = 2
        self.args.run_config.dataset_threads = 2
        scores = run_model(mode="train", config=self.args)
        assert type(scores) is dict
        assert len(scores) > 0
        assert scores["model/all/train/loss"] < 0.08

        self.args["model_config"][
            "model_path"] = f"{emmental.Meta.log_path}/last_model.pth"
        emmental.Meta.config["model_config"][
            "model_path"] = f"{emmental.Meta.log_path}/last_model.pth"

        result_file, out_emb_file = run_model(mode="dump_embs",
                                              config=self.args)
        assert os.path.exists(result_file)
        results = [ujson.loads(li) for li in open(result_file)]
        assert 18 == len(results)  # 18 total sentences
        assert set(
            f for li in results for f in li["ctx_emb_ids"]
        ) == set(range(51))  # 51 total mentions
        assert os.path.exists(out_emb_file)
Example #7
 def forward(self, word_indices, entity_package, batch_prepped_data,
             batch_on_the_fly_data):
     word_package = self.word_emb(word_indices)
     sent_emb = self.sent_emb(word_package)
     sent_tensor = self.project_sent(sent_emb.tensor)
      # Word packages carry two masks: one used in the sentence embedding module (mask)
      # and one used in our attention network (downstream_mask).
      # At this point we are past the sentence embedding stage, so the downstream mask is
      # the one we always want; hence we set the package's mask to the downstream mask.
     sent_emb = DottedDict(tensor=sent_tensor,
                           mask=sent_emb.downstream_mask,
                           key=sent_emb.key,
                           dim=sent_emb.dim)
     entity_embs = []
     for entity_emb in self.entity_embs.values():
         forward_package = entity_package
         forward_package.key = entity_emb.key
         emb_package = entity_emb(forward_package, batch_prepped_data,
                                  batch_on_the_fly_data, sent_emb)
         # If some embeddings do not want to be added to the payload, they will return an empty emb_package
         # This happens for the kg bias (KGIndices) class for our kg attention network
         if len(emb_package) == 0:
             continue
         emb_package.tensor = model_utils.emb_2d_dropout(
             entity_emb.training, entity_emb.mask_perc, emb_package.tensor)
         entity_embs.append(emb_package)
     return sent_emb, entity_embs
Example #8
 def forward(self, word_package):
     attention_mask = word_package.mask
     word_vectors = word_package.tensor
     batch_size = word_vectors.shape[0]
     word_vectors = word_vectors.transpose(0, 1)
     out = word_vectors
     if self.requires_grad:
         for i in range(self.num_layers):
             out, weights = self.attention_modules[
                 f"stage_{i}_self_sentence"](out,
                                             key_mask=attention_mask,
                                             attn_mask=None)
             self.attention_weights[f"layer_{i}_sent"] = weights
     else:
         with torch.no_grad():
             for i in range(self.num_layers):
                 out, weights = self.attention_modules[
                     f"stage_{i}_self_sentence"](out,
                                                 key_mask=attention_mask,
                                                 attn_mask=None)
                 self.attention_weights[f"layer_{i}_sent"] = weights
     out = out.transpose(0, 1)
     assert out.shape[0] == batch_size
     emb = DottedDict(
         tensor=out,
          # Only the downstream mask is needed here, as this is post sentence embedding
         downstream_mask=word_package.downstream_mask,
         key=self.get_key(),
         dim=self.get_dim())
     return emb
Example #9
 def forward(self, alias_idx_pair_sent, word_indices, alias_indices,
             entity_indices, batch_prepped_data, batch_on_the_fly_data):
     # mask out padded candidates
     mask = entity_indices == -1
     entity_indices = torch.where(
         entity_indices >= 0, entity_indices,
         (torch.ones_like(entity_indices, dtype=torch.long) *
          (self.num_entities_with_pad_and_nocand - 1)))
     entity_package = DottedDict(tensor=entity_indices,
                                 pos_in_sent=alias_idx_pair_sent,
                                 alias_indices=alias_indices,
                                 mask=mask)
     sent_emb, entity_embs = self.emb_layer(word_indices, entity_package,
                                            batch_prepped_data,
                                            batch_on_the_fly_data)
     sent_emb, entity_embs = self.emb_combiner(
         sent_embedding=sent_emb,
         alias_idx_pair_sent=alias_idx_pair_sent,
         entity_embedding=entity_embs,
         entity_mask=mask)
     context_matrix_dict, backbone_out = self.attn_network(
         alias_idx_pair_sent, sent_emb, entity_embs, batch_prepped_data,
         batch_on_the_fly_data)
     res, final_entity_embs = self.slice_heads(
         context_matrix_dict,
         alias_idx_pair_sent=alias_idx_pair_sent,
         entity_pack=entity_package,
         sent_emb=sent_emb,
         batch_prepped=batch_prepped_data,
         raw_entity_emb=entity_embs)
     # update output dictionary with backbone out
     res[DISAMBIG].update(backbone_out[DISAMBIG])
     return res, entity_package, final_entity_embs
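
The torch.where call above redirects padded candidates (index -1) to the last entity row so they remain valid embedding indices, while the mask remembers which slots were padding. A standalone illustration with hypothetical sizes:

import torch

num_entities_with_pad_and_nocand = 6  # hypothetical table size incl. pad/no-candidate rows
entity_indices = torch.tensor([[2, -1, 4]])
mask = entity_indices == -1
entity_indices = torch.where(
    entity_indices >= 0,
    entity_indices,
    torch.ones_like(entity_indices, dtype=torch.long) * (num_entities_with_pad_and_nocand - 1),
)
assert entity_indices.tolist() == [[2, 5, 4]]   # -1 is remapped to the pad row (index 5)
assert mask.tolist() == [[False, True, False]]  # the mask still marks the padded slot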
Example #10
    def test_setup_incand(self):
        self.config["data_config"]["train_in_candidates"] = True
        self.alias_entity_table = AliasEntityTable(
            DottedDict(self.config["data_config"]), self.entity_symbols
        )

        gold_alias2entity_table = torch.tensor(
            [
                [1, 4, -1],
                [1, -1, -1],
                [4, 3, 2],
                [2, 1, 4],
                [4, 3, 2],
                [4, 3, 2],
                [4, 3, 2],
                [4, 3, 2],
                [4, 3, 2],
                [4, 3, 2],
                [4, 3, 2],
                [4, 3, 2],
                [4, 3, 2],
                [4, 3, 2],
                [4, 3, 2],
                [4, 3, 2],
                [4, 3, 2],
                [-1, -1, -1],
                [-1, -1, -1],
            ]
        )
        assert torch.equal(
            gold_alias2entity_table.long(),
            self.alias_entity_table.alias2entity_table.long(),
        )
Example #11
 def test_setup_notincand(self):
     self.alias_entity_table = AliasEntityTable(
         DottedDict(self.config["data_config"]), self.entity_symbols
     )
     gold_alias2entity_table = torch.tensor(
         [
             [0, 1, 4, -1],
             [0, 1, -1, -1],
             [0, 4, 3, 2],
             [0, 2, 1, 4],
             [0, 4, 3, 2],
             [0, 4, 3, 2],
             [0, 4, 3, 2],
             [0, 4, 3, 2],
             [0, 4, 3, 2],
             [0, 4, 3, 2],
             [0, 4, 3, 2],
             [0, 4, 3, 2],
             [0, 4, 3, 2],
             [0, 4, 3, 2],
             [0, 4, 3, 2],
             [0, 4, 3, 2],
             [0, 4, 3, 2],
             [-1, -1, -1, -1],
             [-1, -1, -1, -1],
         ]
     )
     assert torch.equal(
         gold_alias2entity_table.long(),
         self.alias_entity_table.alias2entity_table.long(),
     )
Example #12
File: layers.py, Project: syyunn/bootleg
    def forward(self, sent_emb, entity_package, entity_embs):
        batch, M, K = entity_package.tensor.shape
        # Get the alias word tensor and expand it across candidates for soft attention
        alias_word_tensor = model_utils.select_alias_word_sent(
            entity_package.pos_in_sent, sent_emb, index=0)
        alias_mask = entity_package.alias_indices == -1

        # batch x M x num_types
        batch_type_pred = self.prediction(alias_word_tensor)
        batch_type_weights = self.type_softmax(batch_type_pred)
        # batch x M x emb_size
        batch_type_embs = torch.matmul(batch_type_weights,
                                       self.type_embedding.unsqueeze(0))
        # mask out unk alias embeddings
        batch_type_embs[alias_mask] = 0
        batch_type_embs = batch_type_embs.unsqueeze(2).expand(
            batch, M, K, self.emb_size)

        res = DottedDict(tensor=batch_type_embs,
                         pos_in_sent=entity_package.pos_in_sent,
                         alias_indices=entity_package.alias_indices,
                         mask=entity_package.mask,
                         normalize=True)
        entity_embs.append(res)
        return entity_embs, batch_type_pred
Example #13
def createdBoolDottedDict(d_dict):
    if (type(d_dict) is DottedDict) or (type(d_dict) is dict):
        d_dict = DottedDict(d_dict)
    if type(d_dict) is str and is_json(d_dict):
        d_dict = DottedDict(ujson.loads(d_dict))
    if type(d_dict) is DottedDict:
        for k in d_dict:
            if d_dict[k] == "True":
                d_dict[k] = True
            elif d_dict[k] == "False":
                d_dict[k] = False
            elif (
                (type(d_dict[k]) is DottedDict)
                or (type(d_dict[k]) is dict)
                or (type(d_dict[k]) is str and is_json(d_dict[k]))
            ):
                d_dict[k] = createdBoolDottedDict(d_dict[k])
            elif type(d_dict[k]) is list:
                for i in range(len(d_dict[k])):
                    d_dict[k][i] = createdBoolDottedDict(d_dict[k][i])
    return d_dict
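
A minimal usage sketch for createdBoolDottedDict with hypothetical values, assuming the function and its helpers (DottedDict, is_json, ujson) are importable from the project's utils as in the snippet above:

raw = {"train_in_candidates": "True", "args": {"tail_init": "False", "proj": 6}}
cfg = createdBoolDottedDict(raw)
assert cfg.train_in_candidates is True  # "True"/"False" strings are coerced to booleans
assert cfg.args.tail_init is False      # coercion recurses into nested dicts
assert cfg.args.proj == 6               # other values are left unchanged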
Example #14
 def test_forward(self):
     self.alias_entity_table = AliasEntityTable(
         DottedDict(self.config), self.entity_symbols)
     # idx 0 is multi word alias 2, idx 1 is alias 1
     actual_indices = self.alias_entity_table.forward(
         torch.tensor([[[0, 1]]]))
     # 0 is for non-candidate, -1 is for padded value
     expected_tensor = torch.tensor([[[[0, 2, 1, 4], [0, 1, 4, -1]]]])
     assert torch.equal(actual_indices, expected_tensor)
Example #15
 def _package(self, tensor, pos_in_sent, alias_indices, mask):
     packed_emb = DottedDict(tensor=tensor,
                             pos_in_sent=pos_in_sent,
                             alias_indices=alias_indices,
                             mask=mask,
                             normalize=self.normalize,
                             key=self.get_key(),
                             dim=self.get_dim())
     return packed_emb
Example #16
 def setUp(self):
     self.entity_symbols = EntitySymbolsSubclass()
     self.hidden_size = 30
     self.learned_embedding_size = 50
     self.args = parser_utils.get_full_config(
         "test/run_args/test_embeddings.json")
     self.args.model_config.hidden_size = self.hidden_size
     emb_args = DottedDict({'learned_embedding_size': self.learned_embedding_size})
     self.learned_emb = LearnedEntityEmb(main_args=self.args, emb_args=emb_args,
         model_device='cpu', entity_symbols=self.entity_symbols,
         word_symbols=None, word_emb=None, key="learned")
Example #17
 def test_avg_unk(self):
     entity_ids = torch.tensor([[[3]]])
     entity_package = DottedDict(tensor=entity_ids,
                                 pos_in_sent=get_pos_in_sent(),
                                 alias_indices=None,
                                 mask=None,
                                 normalize=False,
                                 key="key",
                                 dim=0)
     actual_embs = self.average_titles(entity_package)
     expected_embs = torch.tensor([[[[0.5, 0.5] + [0] * 24]]])
     assert torch.equal(actual_embs, expected_embs)
Example #18
    def forward(self, sent_embedding, alias_idx_pair_sent, entity_embedding,
                entity_mask):
        batch_size = sent_embedding.tensor.shape[0]
        # Create list of all entity tensors
        alias_list = []
        alias_indices = None
        for embedding in entity_embedding:
            # Entity shape: batch_size x M x K x embedding_dim
            assert (embedding.tensor.shape[0] == batch_size)
            assert (embedding.tensor.shape[1] == self.M)
            assert (embedding.tensor.shape[2] == self.K)
            emb = embedding.tensor
            if alias_indices is not None:
                assert torch.equal(
                    alias_indices, embedding.alias_indices
                ), "Alias indices should not be different between embeddings in embCombiner"
            alias_indices = embedding.alias_indices
            # Normalize input embeddings
            if embedding.normalize:
                emb = model_utils.normalize_matrix(emb, dim=3)
                assert not torch.isnan(emb).any()
                assert not torch.isinf(emb).any()
            alias_list.append(emb)
        alias_tensor = self.linear_layers['project_embedding'](alias_list)
        alias_tensor_first = self.position_enc['alias'](
            alias_tensor,
            alias_idx_pair_sent[0].transpose(0, 1).repeat(self.K, 1, 1).transpose(0, 2).long())
        alias_tensor_last = self.position_enc['alias_last_token'](
            alias_tensor,
            alias_idx_pair_sent[1].transpose(0, 1).repeat(self.K, 1, 1).transpose(0, 2).long())
        alias_tensor = self.position_enc['alias_position_cat'](
            [alias_tensor_first, alias_tensor_last])

        proj_ent_embedding = DottedDict(
            tensor=alias_tensor,
            # Position of entities in sentence
            pos_in_sent=alias_idx_pair_sent,
            # Indexes of aliases
            alias_indices=alias_indices,
            # All entity embeddings have the same mask currently
            mask=embedding.mask,
            # Do not re-normalize this projected embedding downstream
            normalize=False,
            dim=alias_tensor.shape[-1])
        return sent_embedding, proj_ent_embedding
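
The normalize branch above calls model_utils.normalize_matrix; based on the unit-norm checks in Example #2, it is assumed to be L2 normalization along the given dimension, roughly equivalent to this standalone snippet:

import torch
import torch.nn.functional as F

emb = torch.randn(3, 5, 4, 6)       # batch_size x M x K x embedding_dim
emb = F.normalize(emb, p=2, dim=3)  # L2-normalize each embedding vector
assert torch.allclose(emb.norm(p=2, dim=3), torch.ones(3, 5, 4))  # unit norm per embedding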
Example #19
 def forward(self, word_package):
     attention_mask = word_package.mask
     word_vectors = word_package.tensor
     head_mask = [None] * self.num_layers
     if self.requires_grad:
         output = self.encoder(word_vectors, attention_mask, head_mask)[0]
     else:
         with torch.no_grad():
             output = self.encoder(word_vectors, attention_mask,
                                   head_mask)[0]
     emb = DottedDict(tensor=output,
                      downstream_mask=word_package.downstream_mask,
                      key=self.get_key(),
                      dim=self.get_dim())
     return emb
Example #20
 def test_forward_dimension(self):
     entity_ids = torch.tensor([[[0, 1], [2, 2]]])
     entity_package = DottedDict(tensor=entity_ids,
                                 pos_in_sent=get_pos_in_sent(),
                                 alias_indices=None,
                                 mask=None,
                                 normalize=False,
                                 key="key",
                                 dim=0)
     actual_out = self.learned_emb(entity_package,
                                   batch_prepped_data={},
                                   batch_on_the_fly_data={},
                                   sent_emb=None)
     assert ([
         entity_ids.shape[0], entity_ids.shape[1], entity_ids.shape[2],
         self.learned_embedding_size
     ] == list(actual_out.tensor.size()))
Example #21
 def forward(self, word_indices, requires_grad=None, token_id=0):
     if requires_grad is None:
         requires_grad = self.requires_grad
     token_type_ids = torch.ones_like(word_indices) * token_id
     if requires_grad:
         out = self.embeddings(word_indices, token_type_ids=token_type_ids)
     else:
         with torch.no_grad():
             out = self.embeddings(word_indices,
                                   token_type_ids=token_type_ids)
     packed_emb = DottedDict(
         tensor=out,
         mask=self.get_bert_mask(word_indices),
         downstream_mask=self.get_downstream_mask(word_indices),
         key=self.get_key(),
         dim=self.get_dim())
     return packed_emb
Example #22
File: model.py, Project: syyunn/bootleg
 def forward(self, alias_idx_pair_sent, word_indices, alias_indices,
             entity_indices, batch_prepped_data, batch_on_the_fly_data):
     # mask out padded candidates (last row)
     mask = entity_indices == -1
     entity_indices = torch.where(
         entity_indices >= 0, entity_indices,
         (torch.ones_like(entity_indices, dtype=torch.long) *
          (self.num_entities_with_pad_and_nocand - 1)))
     entity_package = DottedDict(tensor=entity_indices,
                                 pos_in_sent=alias_idx_pair_sent,
                                 alias_indices=alias_indices,
                                 mask=mask,
                                 normalize=False,
                                 key=None,
                                 dim=None)
     sent_emb, entity_embs = self.emb_layer(word_indices, entity_package,
                                            batch_prepped_data,
                                            batch_on_the_fly_data)
     res = self.attn_network(alias_idx_pair_sent, sent_emb, entity_embs,
                             batch_prepped_data, batch_on_the_fly_data)
     return res, entity_package, None
Example #23
 def setUp(self):
     emmental.init(log_dir="test/temp_log")
     if not os.path.exists(emmental.Meta.log_path):
         os.makedirs(emmental.Meta.log_path)
     self.entity_symbols = EntitySymbolsSubclass()
     self.hidden_size = 30
     self.learned_embedding_size = 50
     self.args = parser_utils.parse_boot_and_emm_args(
         "test/run_args/test_embeddings.json"
     )
     self.regularization_csv = os.path.join(
         self.args.data_config.data_dir, "test_reg.csv"
     )
     self.static_emb = os.path.join(self.args.data_config.data_dir, "static_emb.pt")
     self.qid2topkeid = os.path.join(
         self.args.data_config.data_dir, "test_eid2topk.json"
     )
     self.args.model_config.hidden_size = self.hidden_size
     self.args.data_config.ent_embeddings[0]["args"] = DottedDict(
         {"learned_embedding_size": self.learned_embedding_size}
     )
Example #24
    def test_adding_title_nobert(self):
        self.args.data_config.ent_embeddings.append(
            DottedDict(
                {
                    "key": "title1",
                    "load_class": "TitleEmb",
                    "normalize": False,
                    "args": {"proj": 5},
                }
            )
        )
        (
            task_flows,
            module_pool,
            module_device_dict,
            extra_bert_embedding_layers,
            to_add_to_payload,
            total_sizes,
        ) = get_embedding_tasks(self.args, self.entity_symbols)

        gold_module_device_dict = {"learned5": -1}

        gold_extra_bert_embedding_layers = []

        gold_to_add_to_payload = [
            ("embedding_learned1", 0),
            ("embedding_learned2", 0),
            ("embedding_learned3", 0),
            ("embedding_learned4", 0),
            ("embedding_learned5", 0),
            ("embedding_title1", 0),
        ]

        gold_total_sizes = {
            "learned1": 5,
            "learned2": 5,
            "learned3": 5,
            "learned4": 5,
            "learned5": 5,
            "title1": 5,
        }

        gold_task_flows = [
            {
                "name": f"embedding_learned1",
                "module": "learned1",
                "inputs": [
                    ("_input_", "entity_cand_eid"),
                    (
                        "_input_",
                        "batch_on_the_fly_kg_adj",
                    ),  # special kg adjacency embedding prepped in dataloader
                ],
            },
            {
                "name": f"embedding_learned2",
                "module": "learned2",
                "inputs": [
                    ("_input_", "entity_cand_eid"),
                    (
                        "_input_",
                        "batch_on_the_fly_kg_adj",
                    ),  # special kg adjacency embedding prepped in dataloader
                ],
            },
            {
                "name": f"embedding_learned3",
                "module": "learned3",
                "inputs": [
                    ("_input_", "entity_cand_eid"),
                    (
                        "_input_",
                        "batch_on_the_fly_kg_adj",
                    ),  # special kg adjacency embedding prepped in dataloader
                ],
            },
            {
                "name": f"embedding_learned4",
                "module": "learned4",
                "inputs": [
                    ("_input_", "entity_cand_eid"),
                    (
                        "_input_",
                        "batch_on_the_fly_kg_adj",
                    ),  # special kg adjacency embedding prepped in dataloader
                ],
            },
            {
                "name": f"embedding_learned5",
                "module": "learned5",
                "inputs": [
                    ("_input_", "entity_cand_eid"),
                    (
                        "_input_",
                        "batch_on_the_fly_kg_adj",
                    ),  # special kg adjacency embedding prepped in dataloader
                ],
            },
            {
                "name": f"embedding_title1",
                "module": "title1",
                "inputs": [
                    ("_input_", "entity_cand_eid"),
                    (
                        "_input_",
                        "batch_on_the_fly_kg_adj",
                    ),  # special kg adjacency embedding prepped in dataloader
                ],
            },
        ]
        # Asserts that the order is the same
        self.assertEqual(to_add_to_payload, gold_to_add_to_payload)
        self.assertEqual(extra_bert_embedding_layers, gold_extra_bert_embedding_layers)
        self.assertDictEqual(module_device_dict, gold_module_device_dict)
        self.assertDictEqual(total_sizes, gold_total_sizes)
        assert len(task_flows) == len(gold_task_flows)
        for li, r in zip(task_flows, gold_task_flows):
            self.assertDictEqual(li, r)
Example #25
    def test_adding_title(self):
        self.args.data_config.ent_embeddings.append(
            DottedDict(
                {
                    "key": "title1",
                    "load_class": "TitleEmb",
                    "send_through_bert": True,
                    "normalize": False,
                    "args": {"proj": 5},
                }
            )
        )
        (
            task_flows,
            module_pool,
            module_device_dict,
            extra_bert_embedding_layers,
            to_add_to_payload,
            total_sizes,
        ) = get_embedding_tasks(self.args, self.entity_symbols)

        gold_module_device_dict = {"learned5": -1}

        gold_extra_bert_embedding_layers = [TitleEmbMock()]

        gold_to_add_to_payload = [
            ("embedding_learned1", 0),
            ("embedding_learned2", 0),
            ("embedding_learned3", 0),
            ("embedding_learned4", 0),
            ("embedding_learned5", 0),
            ("bert", 2),
        ]

        gold_total_sizes = {
            "learned1": 5,
            "learned2": 5,
            "learned3": 5,
            "learned4": 5,
            "learned5": 5,
            "title1": 5,
        }

        gold_task_flows = [
            {
                "name": f"embedding_learned1",
                "module": "learned1",
                "inputs": [
                    ("_input_", "entity_cand_eid"),
                    (
                        "_input_",
                        "batch_on_the_fly_kg_adj",
                    ),  # special kg adjacency embedding prepped in dataloader
                ],
            },
            {
                "name": f"embedding_learned2",
                "module": "learned2",
                "inputs": [
                    ("_input_", "entity_cand_eid"),
                    (
                        "_input_",
                        "batch_on_the_fly_kg_adj",
                    ),  # special kg adjacency embedding prepped in dataloader
                ],
            },
            {
                "name": f"embedding_learned3",
                "module": "learned3",
                "inputs": [
                    ("_input_", "entity_cand_eid"),
                    (
                        "_input_",
                        "batch_on_the_fly_kg_adj",
                    ),  # special kg adjacency embedding prepped in dataloader
                ],
            },
            {
                "name": f"embedding_learned4",
                "module": "learned4",
                "inputs": [
                    ("_input_", "entity_cand_eid"),
                    (
                        "_input_",
                        "batch_on_the_fly_kg_adj",
                    ),  # special kg adjacency embedding prepped in dataloader
                ],
            },
            {
                "name": f"embedding_learned5",
                "module": "learned5",
                "inputs": [
                    ("_input_", "entity_cand_eid"),
                    (
                        "_input_",
                        "batch_on_the_fly_kg_adj",
                    ),  # special kg adjacency embedding prepped in dataloader
                ],
            },
        ]
        # Asserts that the order is the same
        self.assertEqual(to_add_to_payload, gold_to_add_to_payload)
        assert len(extra_bert_embedding_layers) == len(gold_extra_bert_embedding_layers)
        for i_l, i_r in zip(
            extra_bert_embedding_layers, gold_extra_bert_embedding_layers
        ):
            # These are class instances, so we can't compare with ==; instead check that key properties match
            assert type(i_l) is TitleEmb
            self.assertEqual(i_l.key, i_r.key)
            self.assertEqual(i_l.cpu, i_r.cpu)
            self.assertEqual(i_l.normalize, i_r.normalize)
            self.assertEqual(i_l.dropout1d_perc, i_r.dropout1d_perc)
            self.assertEqual(i_l.dropout2d_perc, i_r.dropout2d_perc)
        self.assertDictEqual(module_device_dict, gold_module_device_dict)
        self.assertDictEqual(total_sizes, gold_total_sizes)
        assert len(task_flows) == len(gold_task_flows)
        for li, r in zip(task_flows, gold_task_flows):
            self.assertDictEqual(li, r)