def test_payload(self):
    self.args.data_config.ent_embeddings.extend(
        [DottedDict({"key": "kg_emb"}), DottedDict({"key": "kg_rel"})]
    )
    emb_payload = EmbeddingPayloadMock(self.args, self.entity_symbols)

    emb_payload.linear_layers = nn.ModuleDict()
    emb_payload.linear_layers["project_embedding"] = NoopCat()
    emb_payload.position_enc = nn.ModuleDict()
    emb_payload.position_enc["alias"] = Noop()
    emb_payload.position_enc["alias_last_token"] = Noop()
    emb_payload.position_enc["alias_position_cat"] = NoopTakeFirst()

    batch_size = 3
    # Max aliases is set to 5 in the config so we need one position per alias per batch
    start_idx_pair = torch.zeros(batch_size, emb_payload.M)
    end_idx_pair = torch.zeros(batch_size, emb_payload.M)
    entity_embedding = {
        "kg_emb": torch.randn(batch_size, emb_payload.M, emb_payload.K, 5),
        "kg_rel": torch.randn(batch_size, emb_payload.M, emb_payload.K, 6),
    }

    alias_list = emb_payload(
        start_idx_pair,
        end_idx_pair,
        entity_embedding["kg_emb"],
        entity_embedding["kg_rel"],
    )
    # The embeddings have been concatenated together
    assert torch.isclose(
        alias_list,
        torch.cat([entity_embedding["kg_emb"], entity_embedding["kg_rel"]], dim=3),
    ).all()
def test_kg_norm(self):
    emb_sizes = {"kg_emb": 5, "kg_rel": 6}
    sent_emb_size = 10
    self.args.data_config.ent_embeddings.extend(
        [DottedDict({"key": "kg_emb"}), DottedDict({"key": "kg_rel"})]
    )
    emb_combiner = EmbCombinerProj(
        self.args, emb_sizes, sent_emb_size, self.word_symbols, self.entity_symbols
    )
    emb_combiner.linear_layers = nn.ModuleDict()
    emb_combiner.linear_layers['project_embedding'] = NoopCat()
    emb_combiner.position_enc = nn.ModuleDict()
    emb_combiner.position_enc['alias'] = Noop()
    emb_combiner.position_enc['alias_last_token'] = Noop()
    emb_combiner.position_enc['alias_position_cat'] = NoopTakeFirst()

    batch_size = 3
    num_words = 5
    sent_embedding = DottedDict(
        tensor=torch.randn(batch_size, num_words, sent_emb_size),
        downstream_mask=None,
        mask=None,
        key="sent_emb",
        dim=10,
    )
    # Max aliases is set to 5 in the config so we need one position per alias per batch
    alias_idx_pair_sent = [
        torch.tensor([[0] * emb_combiner.M, [0] * emb_combiner.M]),
        torch.tensor([[0] * emb_combiner.M, [0] * emb_combiner.M]),
    ]
    entity_embedding = [
        DottedDict(
            tensor=torch.randn(batch_size, emb_combiner.M, emb_combiner.K, 5),
            pos_in_sent=get_pos_in_sent(),
            alias_indices=None,
            mask=None,
            normalize=True,
            key="kg_emb",
            dim=5,
        ),
        DottedDict(
            tensor=torch.randn(batch_size, emb_combiner.M, emb_combiner.K, 6),
            pos_in_sent=get_pos_in_sent(),
            alias_indices=None,
            mask=None,
            normalize=False,
            key="kg_rel",
            dim=6,
        ),
    ]
    norm_1 = entity_embedding[0].tensor.norm(p=2, dim=3)
    norm_2 = entity_embedding[1].tensor.norm(p=2, dim=3)
    entity_mask = None
    _, res_package = emb_combiner(
        sent_embedding, alias_idx_pair_sent, entity_embedding, entity_mask
    )
    alias_list = res_package.tensor
    # The embeddings have been normalized and concatenated together
    assert torch.isclose(
        alias_list[:, :, :, :5].norm(p=2, dim=3), torch.ones_like(norm_1)
    ).all()
    assert torch.isclose(alias_list[:, :, :, 5:].norm(p=2, dim=3), norm_2).all()
def setUp(self) -> None:
    self.args = parser_utils.get_full_config("test/run_args/test_embeddings.json")
    self.args.data_config.ent_embeddings = [
        DottedDict(
            {
                "key": "learned1",
                "load_class": "LearnedEntityEmb",
                "args": {"learned_embedding_size": 5, "tail_init": False},
            }
        ),
        DottedDict(
            {
                "key": "learned2",
                "load_class": "LearnedEntityEmb",
                "args": {"learned_embedding_size": 5, "tail_init": False},
            }
        ),
        DottedDict(
            {
                "key": "learned3",
                "load_class": "LearnedEntityEmb",
                "args": {"learned_embedding_size": 5, "tail_init": False},
            }
        ),
        DottedDict(
            {
                "key": "learned4",
                "load_class": "LearnedEntityEmb",
                "args": {"learned_embedding_size": 5, "tail_init": False},
            }
        ),
        DottedDict(
            {
                "key": "learned5",
                "load_class": "LearnedEntityEmb",
                "args": {"learned_embedding_size": 5, "tail_init": False},
            }
        ),
    ]
    self.word_symbols = data_utils.load_wordsymbols(self.args.data_config)
    self.entity_symbols = EntitySymbolsSubclass()
def setUp(self) -> None:
    self.args = parser_utils.parse_boot_and_emm_args(
        "test/run_args/test_embeddings.json"
    )
    emmental.init(log_dir="test/temp_log", config=self.args)
    if not os.path.exists(emmental.Meta.log_path):
        os.makedirs(emmental.Meta.log_path)
    self.args.data_config.ent_embeddings = [
        DottedDict(
            {
                "key": "learned1",
                "load_class": "LearnedEntityEmb",
                "dropout1d": 0.5,
                "args": {"learned_embedding_size": 5, "tail_init": False},
            }
        ),
        DottedDict(
            {
                "key": "learned2",
                "dropout2d": 0.5,
                "load_class": "LearnedEntityEmb",
                "args": {"learned_embedding_size": 5, "tail_init": False},
            }
        ),
        DottedDict(
            {
                "key": "learned3",
                "load_class": "LearnedEntityEmb",
                "freeze": True,
                "args": {"learned_embedding_size": 5, "tail_init": False},
            }
        ),
        DottedDict(
            {
                "key": "learned4",
                "load_class": "LearnedEntityEmb",
                "normalize": False,
                "args": {"learned_embedding_size": 5, "tail_init": False},
            }
        ),
        DottedDict(
            {
                "key": "learned5",
                "load_class": "LearnedEntityEmb",
                "cpu": True,
                "args": {"learned_embedding_size": 5, "tail_init": False},
            }
        ),
    ]
    self.tokenizer = load_tokenizer()
    self.entity_symbols = EntitySymbolsSubclass()
def forward(self, word_indices, requires_grad=None):
    if requires_grad is None:
        requires_grad = self.requires_grad
    (batch_size, seq_length) = word_indices.shape
    # num_words_with_pad_unk - 1 because index starts at 0
    word_indices_pos = torch.where(
        word_indices >= 0,
        word_indices,
        (
            torch.ones_like(word_indices, dtype=torch.long)
            * (self.num_words_with_pad_unk - 1)
        ),
    )
    if requires_grad:
        word_vectors = self.word_embedding(word_indices_pos)
    else:
        with torch.no_grad():
            word_vectors = self.word_embedding(word_indices_pos)
    if self.use_proj:
        word_vectors = self.proj(word_vectors)
    word_vectors = self.position_words(
        word_vectors, torch.arange(0, word_indices.shape[1]).repeat(batch_size, 1)
    )
    word_vectors = self.layer_norm(word_vectors)
    word_vectors = self.dropout(word_vectors)
    packed_emb = DottedDict(
        tensor=word_vectors,
        # The mask for the standard word embedding is the same as the downstream mask
        mask=self.get_downstream_mask(word_indices),
        downstream_mask=self.get_downstream_mask(word_indices),
        key=self.get_key(),
        dim=self.get_dim(),
    )
    return packed_emb
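# Illustrative sketch (standalone, not part of the module above): the torch.where remap
# used in forward() sends negative (padded) word indices to the last row of the embedding
# table, assumed here to be the pad/unk row, so nn.Embedding never sees a negative index.
# The vocabulary size and index values below are made up for the example.
import torch

num_words_with_pad_unk = 6
word_indices = torch.tensor([[0, 3, -1, -1]])
word_indices_pos = torch.where(
    word_indices >= 0,
    word_indices,
    torch.ones_like(word_indices, dtype=torch.long) * (num_words_with_pad_unk - 1),
)
# word_indices_pos == tensor([[0, 3, 5, 5]]); padded slots now point at the pad/unk row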
def test_end2end_withtitle_accstep(self):
    self.args.data_config.ent_embeddings.append(
        DottedDict(
            {
                "key": "title1",
                "load_class": "TitleEmb",
                "send_through_bert": True,
                "args": {"proj": 6},
            }
        )
    )
    # Just setting this for testing pipelines
    self.args.data_config.eval_accumulation_steps = 2
    self.args.run_config.dataset_threads = 2
    scores = run_model(mode="train", config=self.args)
    assert type(scores) is dict
    assert len(scores) > 0
    assert scores["model/all/train/loss"] < 0.08

    self.args["model_config"]["model_path"] = f"{emmental.Meta.log_path}/last_model.pth"
    emmental.Meta.config["model_config"]["model_path"] = f"{emmental.Meta.log_path}/last_model.pth"

    result_file, out_emb_file = run_model(mode="dump_embs", config=self.args)
    assert os.path.exists(result_file)
    results = [ujson.loads(li) for li in open(result_file)]
    assert 18 == len(results)  # 18 total sentences
    assert set(
        [f for li in results for f in li["ctx_emb_ids"]]
    ) == set(range(51))  # 51 distinct mention embedding ids
    assert os.path.exists(out_emb_file)
def forward(self, word_indices, entity_package, batch_prepped_data, batch_on_the_fly_data):
    word_package = self.word_emb(word_indices)
    sent_emb = self.sent_emb(word_package)
    sent_tensor = self.project_sent(sent_emb.tensor)
    # WordPackages carry two masks: one used inside the sentence embedding module (mask)
    # and one used in our attention network (downstream_mask). Past the sentence embedding
    # stage the downstream mask is the only one we want, so we set mask to the downstream mask.
    sent_emb = DottedDict(
        tensor=sent_tensor,
        mask=sent_emb.downstream_mask,
        key=sent_emb.key,
        dim=sent_emb.dim,
    )
    entity_embs = []
    for entity_emb in self.entity_embs.values():
        forward_package = entity_package
        forward_package.key = entity_emb.key
        emb_package = entity_emb(
            forward_package, batch_prepped_data, batch_on_the_fly_data, sent_emb
        )
        # Embeddings that do not want to be added to the payload return an empty emb_package.
        # This happens for the kg bias (KGIndices) class used by our kg attention network.
        if len(emb_package) == 0:
            continue
        emb_package.tensor = model_utils.emb_2d_dropout(
            entity_emb.training, entity_emb.mask_perc, emb_package.tensor
        )
        entity_embs.append(emb_package)
    return sent_emb, entity_embs
def forward(self, word_package):
    attention_mask = word_package.mask
    word_vectors = word_package.tensor
    batch_size = word_vectors.shape[0]
    word_vectors = word_vectors.transpose(0, 1)
    out = word_vectors
    if self.requires_grad:
        for i in range(self.num_layers):
            out, weights = self.attention_modules[f"stage_{i}_self_sentence"](
                out, key_mask=attention_mask, attn_mask=None
            )
            self.attention_weights[f"layer_{i}_sent"] = weights
    else:
        with torch.no_grad():
            for i in range(self.num_layers):
                out, weights = self.attention_modules[f"stage_{i}_self_sentence"](
                    out, key_mask=attention_mask, attn_mask=None
                )
                self.attention_weights[f"layer_{i}_sent"] = weights
    out = out.transpose(0, 1)
    assert out.shape[0] == batch_size
    emb = DottedDict(
        tensor=out,
        # Make the main mask also be the downstream mask as this is post sentence embedding
        downstream_mask=word_package.downstream_mask,
        key=self.get_key(),
        dim=self.get_dim(),
    )
    return emb
def forward(
    self,
    alias_idx_pair_sent,
    word_indices,
    alias_indices,
    entity_indices,
    batch_prepped_data,
    batch_on_the_fly_data,
):
    # Mask out padded candidates
    mask = entity_indices == -1
    entity_indices = torch.where(
        entity_indices >= 0,
        entity_indices,
        (
            torch.ones_like(entity_indices, dtype=torch.long)
            * (self.num_entities_with_pad_and_nocand - 1)
        ),
    )
    entity_package = DottedDict(
        tensor=entity_indices,
        pos_in_sent=alias_idx_pair_sent,
        alias_indices=alias_indices,
        mask=mask,
    )
    sent_emb, entity_embs = self.emb_layer(
        word_indices, entity_package, batch_prepped_data, batch_on_the_fly_data
    )
    sent_emb, entity_embs = self.emb_combiner(
        sent_embedding=sent_emb,
        alias_idx_pair_sent=alias_idx_pair_sent,
        entity_embedding=entity_embs,
        entity_mask=mask,
    )
    context_matrix_dict, backbone_out = self.attn_network(
        alias_idx_pair_sent, sent_emb, entity_embs, batch_prepped_data, batch_on_the_fly_data
    )
    res, final_entity_embs = self.slice_heads(
        context_matrix_dict,
        alias_idx_pair_sent=alias_idx_pair_sent,
        entity_pack=entity_package,
        sent_emb=sent_emb,
        batch_prepped=batch_prepped_data,
        raw_entity_emb=entity_embs,
    )
    # Update output dictionary with backbone out
    res[DISAMBIG].update(backbone_out[DISAMBIG])
    return res, entity_package, final_entity_embs
def test_setup_incand(self):
    self.config["data_config"]["train_in_candidates"] = True
    self.alias_entity_table = AliasEntityTable(
        DottedDict(self.config["data_config"]), self.entity_symbols
    )
    gold_alias2entity_table = torch.tensor(
        [
            [1, 4, -1],
            [1, -1, -1],
            [4, 3, 2],
            [2, 1, 4],
            [4, 3, 2],
            [4, 3, 2],
            [4, 3, 2],
            [4, 3, 2],
            [4, 3, 2],
            [4, 3, 2],
            [4, 3, 2],
            [4, 3, 2],
            [4, 3, 2],
            [4, 3, 2],
            [4, 3, 2],
            [4, 3, 2],
            [4, 3, 2],
            [-1, -1, -1],
            [-1, -1, -1],
        ]
    )
    assert torch.equal(
        gold_alias2entity_table.long(),
        self.alias_entity_table.alias2entity_table.long(),
    )
def test_setup_notincand(self):
    self.alias_entity_table = AliasEntityTable(
        DottedDict(self.config["data_config"]), self.entity_symbols
    )
    gold_alias2entity_table = torch.tensor(
        [
            [0, 1, 4, -1],
            [0, 1, -1, -1],
            [0, 4, 3, 2],
            [0, 2, 1, 4],
            [0, 4, 3, 2],
            [0, 4, 3, 2],
            [0, 4, 3, 2],
            [0, 4, 3, 2],
            [0, 4, 3, 2],
            [0, 4, 3, 2],
            [0, 4, 3, 2],
            [0, 4, 3, 2],
            [0, 4, 3, 2],
            [0, 4, 3, 2],
            [0, 4, 3, 2],
            [0, 4, 3, 2],
            [0, 4, 3, 2],
            [-1, -1, -1, -1],
            [-1, -1, -1, -1],
        ]
    )
    assert torch.equal(
        gold_alias2entity_table.long(),
        self.alias_entity_table.alias2entity_table.long(),
    )
def forward(self, sent_emb, entity_package, entity_embs):
    batch, M, K = entity_package.tensor.shape
    # Get alias tensor and expand to be for each candidate for soft attn
    alias_word_tensor = model_utils.select_alias_word_sent(
        entity_package.pos_in_sent, sent_emb, index=0
    )
    alias_mask = entity_package.alias_indices == -1
    # batch x M x num_types
    batch_type_pred = self.prediction(alias_word_tensor)
    batch_type_weights = self.type_softmax(batch_type_pred)
    # batch x M x emb_size
    batch_type_embs = torch.matmul(
        batch_type_weights, self.type_embedding.unsqueeze(0)
    )
    # Mask out unk alias embeddings
    batch_type_embs[alias_mask] = 0
    batch_type_embs = batch_type_embs.unsqueeze(2).expand(batch, M, K, self.emb_size)
    res = DottedDict(
        tensor=batch_type_embs,
        pos_in_sent=entity_package.pos_in_sent,
        alias_indices=entity_package.alias_indices,
        mask=entity_package.mask,
        normalize=True,
    )
    entity_embs.append(res)
    return entity_embs, batch_type_pred
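# Illustrative sketch (standalone, with assumed shapes): how the softmax type weights
# above mix the rows of a type embedding table into one soft type embedding per alias.
# type_logits and type_embedding stand in for self.prediction(...) and self.type_embedding.
import torch

batch, M, num_types, emb_size = 1, 2, 3, 4
type_logits = torch.randn(batch, M, num_types)      # batch x M x num_types
type_weights = torch.softmax(type_logits, dim=-1)   # rows sum to 1 per alias
type_embedding = torch.randn(num_types, emb_size)   # learned type embedding table
soft_type_embs = torch.matmul(type_weights, type_embedding.unsqueeze(0))
# soft_type_embs: batch x M x emb_size, a convex combination of the type rows per alias
assert soft_type_embs.shape == (batch, M, emb_size)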
def createdBoolDottedDict(d_dict):
    if (type(d_dict) is DottedDict) or (type(d_dict) is dict):
        d_dict = DottedDict(d_dict)
    if type(d_dict) is str and is_json(d_dict):
        d_dict = DottedDict(ujson.loads(d_dict))
    if type(d_dict) is DottedDict:
        for k in d_dict:
            if d_dict[k] == "True":
                d_dict[k] = True
            elif d_dict[k] == "False":
                d_dict[k] = False
            elif (
                (type(d_dict[k]) is DottedDict)
                or (type(d_dict[k]) is dict)
                or (type(d_dict[k]) is str and is_json(d_dict[k]))
            ):
                d_dict[k] = createdBoolDottedDict(d_dict[k])
            elif type(d_dict[k]) is list:
                for i in range(len(d_dict[k])):
                    d_dict[k][i] = createdBoolDottedDict(d_dict[k][i])
    return d_dict
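# Usage sketch (illustrative; assumes DottedDict supports attribute access as it is used
# elsewhere in this codebase): nested dicts become DottedDicts, the string literals
# "True"/"False" become real booleans, and lists are converted element by element.
config = createdBoolDottedDict(
    {"tail_init": "False", "inner": {"freeze": "True"}, "embs": [{"cpu": "False"}]}
)
assert config.tail_init is False
assert config.inner.freeze is True
assert config.embs[0].cpu is False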
def test_forward(self):
    self.alias_entity_table = AliasEntityTable(
        DottedDict(self.config), self.entity_symbols
    )
    # idx 0 is multi word alias 2, idx 1 is alias 1
    actual_indices = self.alias_entity_table.forward(torch.tensor([[[0, 1]]]))
    # 0 is for non-candidate, -1 is for padded value
    expected_tensor = torch.tensor([[[[0, 2, 1, 4], [0, 1, 4, -1]]]])
    assert torch.equal(actual_indices, expected_tensor)
def _package(self, tensor, pos_in_sent, alias_indices, mask):
    packed_emb = DottedDict(
        tensor=tensor,
        pos_in_sent=pos_in_sent,
        alias_indices=alias_indices,
        mask=mask,
        normalize=self.normalize,
        key=self.get_key(),
        dim=self.get_dim(),
    )
    return packed_emb
def setUp(self):
    self.entity_symbols = EntitySymbolsSubclass()
    self.hidden_size = 30
    self.learned_embedding_size = 50
    self.args = parser_utils.get_full_config("test/run_args/test_embeddings.json")
    self.args.model_config.hidden_size = self.hidden_size
    emb_args = DottedDict({'learned_embedding_size': self.learned_embedding_size})
    self.learned_emb = LearnedEntityEmb(
        main_args=self.args,
        emb_args=emb_args,
        model_device='cpu',
        entity_symbols=self.entity_symbols,
        word_symbols=None,
        word_emb=None,
        key="learned",
    )
def test_avg_unk(self):
    entity_ids = torch.tensor([[[3]]])
    entity_package = DottedDict(
        tensor=entity_ids,
        pos_in_sent=get_pos_in_sent(),
        alias_indices=None,
        mask=None,
        normalize=False,
        key="key",
        dim=0,
    )
    actual_embs = self.average_titles(entity_package)
    expected_embs = torch.tensor([[[[0.5, 0.5] + [0] * 24]]])
    assert torch.equal(actual_embs, expected_embs)
def forward(self, sent_embedding, alias_idx_pair_sent, entity_embedding, entity_mask):
    batch_size = sent_embedding.tensor.shape[0]
    # Create list of all entity tensors
    alias_list = []
    alias_indices = None
    for embedding in entity_embedding:
        # Entity shape: batch_size x M x K x embedding_dim
        assert embedding.tensor.shape[0] == batch_size
        assert embedding.tensor.shape[1] == self.M
        assert embedding.tensor.shape[2] == self.K
        emb = embedding.tensor
        if alias_indices is not None:
            assert torch.equal(
                alias_indices, embedding.alias_indices
            ), "Alias indices should not be different between embeddings in embCombiner"
        alias_indices = embedding.alias_indices
        # Normalize input embeddings
        if embedding.normalize:
            emb = model_utils.normalize_matrix(emb, dim=3)
            assert not torch.isnan(emb).any()
            assert not torch.isinf(emb).any()
        alias_list.append(emb)
    alias_tensor = self.linear_layers['project_embedding'](alias_list)
    alias_tensor_first = self.position_enc['alias'](
        alias_tensor,
        alias_idx_pair_sent[0].transpose(0, 1).repeat(self.K, 1, 1).transpose(0, 2).long(),
    )
    alias_tensor_last = self.position_enc['alias'](
        alias_tensor,
        alias_idx_pair_sent[1].transpose(0, 1).repeat(self.K, 1, 1).transpose(0, 2).long(),
    )
    alias_tensor = self.position_enc['alias_position_cat'](
        [alias_tensor_first, alias_tensor_last]
    )
    proj_ent_embedding = DottedDict(
        tensor=alias_tensor,
        # Position of entities in sentence
        pos_in_sent=alias_idx_pair_sent,
        # Indexes of aliases
        alias_indices=alias_indices,
        # All entity embeddings currently share the same mask, so reuse the last one
        mask=embedding.mask,
        # Do not normalize this embedding if normalize is called
        normalize=False,
        dim=alias_tensor.shape[-1],
    )
    return sent_embedding, proj_ent_embedding
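# Illustrative sketch of the normalization step above, assuming model_utils.normalize_matrix
# performs an L2 normalization over the given dim (which is what test_kg_norm asserts):
# every candidate vector comes out with unit L2 norm along the embedding dimension.
# Shapes below are assumed for the example.
import torch
import torch.nn.functional as F

emb = torch.randn(3, 5, 30, 6)         # batch x M x K x dim
normed = F.normalize(emb, p=2, dim=3)  # unit L2 norm per candidate vector
assert torch.allclose(normed.norm(p=2, dim=3), torch.ones(3, 5, 30), atol=1e-5)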
def forward(self, word_package):
    attention_mask = word_package.mask
    word_vectors = word_package.tensor
    head_mask = [None] * self.num_layers
    if self.requires_grad:
        output = self.encoder(word_vectors, attention_mask, head_mask)[0]
    else:
        with torch.no_grad():
            output = self.encoder(word_vectors, attention_mask, head_mask)[0]
    emb = DottedDict(
        tensor=output,
        downstream_mask=word_package.downstream_mask,
        key=self.get_key(),
        dim=self.get_dim(),
    )
    return emb
def test_forward_dimension(self):
    entity_ids = torch.tensor([[[0, 1], [2, 2]]])
    entity_package = DottedDict(
        tensor=entity_ids,
        pos_in_sent=get_pos_in_sent(),
        alias_indices=None,
        mask=None,
        normalize=False,
        key="key",
        dim=0,
    )
    actual_out = self.learned_emb(
        entity_package, batch_prepped_data={}, batch_on_the_fly_data={}, sent_emb=None
    )
    assert [
        entity_ids.shape[0],
        entity_ids.shape[1],
        entity_ids.shape[2],
        self.learned_embedding_size,
    ] == list(actual_out.tensor.size())
def forward(self, word_indices, requires_grad=None, token_id=0):
    if requires_grad is None:
        requires_grad = self.requires_grad
    token_type_ids = torch.ones_like(word_indices) * token_id
    if requires_grad:
        out = self.embeddings(word_indices, token_type_ids=token_type_ids)
    else:
        with torch.no_grad():
            out = self.embeddings(word_indices, token_type_ids=token_type_ids)
    packed_emb = DottedDict(
        tensor=out,
        mask=self.get_bert_mask(word_indices),
        downstream_mask=self.get_downstream_mask(word_indices),
        key=self.get_key(),
        dim=self.get_dim(),
    )
    return packed_emb
def forward(
    self,
    alias_idx_pair_sent,
    word_indices,
    alias_indices,
    entity_indices,
    batch_prepped_data,
    batch_on_the_fly_data,
):
    # Mask out padded candidates (last row)
    mask = entity_indices == -1
    entity_indices = torch.where(
        entity_indices >= 0,
        entity_indices,
        (
            torch.ones_like(entity_indices, dtype=torch.long)
            * (self.num_entities_with_pad_and_nocand - 1)
        ),
    )
    entity_package = DottedDict(
        tensor=entity_indices,
        pos_in_sent=alias_idx_pair_sent,
        alias_indices=alias_indices,
        mask=mask,
        normalize=False,
        key=None,
        dim=None,
    )
    sent_emb, entity_embs = self.emb_layer(
        word_indices, entity_package, batch_prepped_data, batch_on_the_fly_data
    )
    res = self.attn_network(
        alias_idx_pair_sent, sent_emb, entity_embs, batch_prepped_data, batch_on_the_fly_data
    )
    return res, entity_package, None
def setUp(self):
    emmental.init(log_dir="test/temp_log")
    if not os.path.exists(emmental.Meta.log_path):
        os.makedirs(emmental.Meta.log_path)
    self.entity_symbols = EntitySymbolsSubclass()
    self.hidden_size = 30
    self.learned_embedding_size = 50
    self.args = parser_utils.parse_boot_and_emm_args(
        "test/run_args/test_embeddings.json"
    )
    self.regularization_csv = os.path.join(
        self.args.data_config.data_dir, "test_reg.csv"
    )
    self.static_emb = os.path.join(self.args.data_config.data_dir, "static_emb.pt")
    self.qid2topkeid = os.path.join(
        self.args.data_config.data_dir, "test_eid2topk.json"
    )
    self.args.model_config.hidden_size = self.hidden_size
    self.args.data_config.ent_embeddings[0]["args"] = DottedDict(
        {"learned_embedding_size": self.learned_embedding_size}
    )
def test_adding_title_nobert(self):
    self.args.data_config.ent_embeddings.append(
        DottedDict(
            {
                "key": "title1",
                "load_class": "TitleEmb",
                "normalize": False,
                "args": {"proj": 5},
            }
        )
    )
    (
        task_flows,
        module_pool,
        module_device_dict,
        extra_bert_embedding_layers,
        to_add_to_payload,
        total_sizes,
    ) = get_embedding_tasks(self.args, self.entity_symbols)
    gold_module_device_dict = {"learned5": -1}
    gold_extra_bert_embedding_layers = []
    gold_to_add_to_payload = [
        ("embedding_learned1", 0),
        ("embedding_learned2", 0),
        ("embedding_learned3", 0),
        ("embedding_learned4", 0),
        ("embedding_learned5", 0),
        ("embedding_title1", 0),
    ]
    gold_total_sizes = {
        "learned1": 5,
        "learned2": 5,
        "learned3": 5,
        "learned4": 5,
        "learned5": 5,
        "title1": 5,
    }
    gold_task_flows = [
        {
            "name": "embedding_learned1",
            "module": "learned1",
            "inputs": [
                ("_input_", "entity_cand_eid"),
                (
                    "_input_",
                    "batch_on_the_fly_kg_adj",
                ),  # special kg adjacency embedding prepped in dataloader
            ],
        },
        {
            "name": "embedding_learned2",
            "module": "learned2",
            "inputs": [
                ("_input_", "entity_cand_eid"),
                (
                    "_input_",
                    "batch_on_the_fly_kg_adj",
                ),  # special kg adjacency embedding prepped in dataloader
            ],
        },
        {
            "name": "embedding_learned3",
            "module": "learned3",
            "inputs": [
                ("_input_", "entity_cand_eid"),
                (
                    "_input_",
                    "batch_on_the_fly_kg_adj",
                ),  # special kg adjacency embedding prepped in dataloader
            ],
        },
        {
            "name": "embedding_learned4",
            "module": "learned4",
            "inputs": [
                ("_input_", "entity_cand_eid"),
                (
                    "_input_",
                    "batch_on_the_fly_kg_adj",
                ),  # special kg adjacency embedding prepped in dataloader
            ],
        },
        {
            "name": "embedding_learned5",
            "module": "learned5",
            "inputs": [
                ("_input_", "entity_cand_eid"),
                (
                    "_input_",
                    "batch_on_the_fly_kg_adj",
                ),  # special kg adjacency embedding prepped in dataloader
            ],
        },
        {
            "name": "embedding_title1",
            "module": "title1",
            "inputs": [
                ("_input_", "entity_cand_eid"),
                (
                    "_input_",
                    "batch_on_the_fly_kg_adj",
                ),  # special kg adjacency embedding prepped in dataloader
            ],
        },
    ]
    # Asserts that the order is the same
    self.assertEqual(to_add_to_payload, gold_to_add_to_payload)
    self.assertEqual(extra_bert_embedding_layers, gold_extra_bert_embedding_layers)
    self.assertDictEqual(module_device_dict, gold_module_device_dict)
    self.assertDictEqual(total_sizes, gold_total_sizes)
    assert len(task_flows) == len(gold_task_flows)
    for li, r in zip(task_flows, gold_task_flows):
        self.assertDictEqual(li, r)
def test_adding_title(self):
    self.args.data_config.ent_embeddings.append(
        DottedDict(
            {
                "key": "title1",
                "load_class": "TitleEmb",
                "send_through_bert": True,
                "normalize": False,
                "args": {"proj": 5},
            }
        )
    )
    (
        task_flows,
        module_pool,
        module_device_dict,
        extra_bert_embedding_layers,
        to_add_to_payload,
        total_sizes,
    ) = get_embedding_tasks(self.args, self.entity_symbols)
    gold_module_device_dict = {"learned5": -1}
    gold_extra_bert_embedding_layers = [TitleEmbMock()]
    gold_to_add_to_payload = [
        ("embedding_learned1", 0),
        ("embedding_learned2", 0),
        ("embedding_learned3", 0),
        ("embedding_learned4", 0),
        ("embedding_learned5", 0),
        ("bert", 2),
    ]
    gold_total_sizes = {
        "learned1": 5,
        "learned2": 5,
        "learned3": 5,
        "learned4": 5,
        "learned5": 5,
        "title1": 5,
    }
    gold_task_flows = [
        {
            "name": "embedding_learned1",
            "module": "learned1",
            "inputs": [
                ("_input_", "entity_cand_eid"),
                (
                    "_input_",
                    "batch_on_the_fly_kg_adj",
                ),  # special kg adjacency embedding prepped in dataloader
            ],
        },
        {
            "name": "embedding_learned2",
            "module": "learned2",
            "inputs": [
                ("_input_", "entity_cand_eid"),
                (
                    "_input_",
                    "batch_on_the_fly_kg_adj",
                ),  # special kg adjacency embedding prepped in dataloader
            ],
        },
        {
            "name": "embedding_learned3",
            "module": "learned3",
            "inputs": [
                ("_input_", "entity_cand_eid"),
                (
                    "_input_",
                    "batch_on_the_fly_kg_adj",
                ),  # special kg adjacency embedding prepped in dataloader
            ],
        },
        {
            "name": "embedding_learned4",
            "module": "learned4",
            "inputs": [
                ("_input_", "entity_cand_eid"),
                (
                    "_input_",
                    "batch_on_the_fly_kg_adj",
                ),  # special kg adjacency embedding prepped in dataloader
            ],
        },
        {
            "name": "embedding_learned5",
            "module": "learned5",
            "inputs": [
                ("_input_", "entity_cand_eid"),
                (
                    "_input_",
                    "batch_on_the_fly_kg_adj",
                ),  # special kg adjacency embedding prepped in dataloader
            ],
        },
    ]
    # Asserts that the order is the same
    self.assertEqual(to_add_to_payload, gold_to_add_to_payload)
    assert len(extra_bert_embedding_layers) == len(gold_extra_bert_embedding_layers)
    for i_l, i_r in zip(extra_bert_embedding_layers, gold_extra_bert_embedding_layers):
        # These are classes so we can't do == but we can check other properties are correct
        assert type(i_l) is TitleEmb
        self.assertEqual(i_l.key, i_r.key)
        self.assertEqual(i_l.cpu, i_r.cpu)
        self.assertEqual(i_l.normalize, i_r.normalize)
        self.assertEqual(i_l.dropout1d_perc, i_r.dropout1d_perc)
        self.assertEqual(i_l.dropout2d_perc, i_r.dropout2d_perc)
    self.assertDictEqual(module_device_dict, gold_module_device_dict)
    self.assertDictEqual(total_sizes, gold_total_sizes)
    assert len(task_flows) == len(gold_task_flows)
    for li, r in zip(task_flows, gold_task_flows):
        self.assertDictEqual(li, r)