def average_titles(
    self,
    entity_package,
    subset_idx,
    batch_title_ids,
    mask_subset,
    batch_title_emb,
    sent_emb,
):
    """Average the title embeddings through ``selective_avg``.

    Args:
        entity_package: unused in this method
        subset_idx: unused in this method
        batch_title_ids: title ids, forwarded as ``ids``
        mask_subset: boolean mask; it is inverted before being forwarded
        batch_title_emb: title embeddings, forwarded as ``embeds``
        sent_emb: unused in this method

    Returns:
        whatever ``selective_avg`` returns for the inverted mask
    """
    # The mask handed to selective_avg marks valid entries with True, so
    # mask_subset is flipped first.
    # NOTE(review): this implies mask_subset uses True for entries to drop
    # -- confirm against the caller.
    valid_mask = ~mask_subset
    return selective_avg(
        ids=batch_title_ids,
        mask=valid_mask,
        embeds=batch_title_emb,
    )
def _selective_avg_types(self, type_ids, embeds):
    """Average type embeddings, with a fallback when every type id is 0.

    Args:
        type_ids: type ids; indexed along dim 3, so at least 4-dimensional
        embeds: type embeddings; sliced as ``embeds[:, :, :, 0]``

    Returns:
        the ``selective_avg`` result plus, for rows whose type ids are all 0,
        the embedding found at position 0 of dim 3
    """
    # Ids that take part in the average lie strictly between 0 and the
    # largest id (presumably 0 is the unk id and the largest id is padding
    # -- TODO confirm).
    upper_bound = self.num_types_with_pad_and_unk - 1
    keep_mask = (type_ids < upper_bound) & (type_ids > 0)
    averaged = selective_avg(type_ids, keep_mask, embeds)

    # A row is "all unk" when the number of zero ids equals the size of the
    # last dimension.
    zero_id_count = (type_ids == 0).sum(3)
    all_unk = zero_id_count == type_ids.shape[-1]

    # For all-unk rows use the embedding at position 0; elsewhere add zeros.
    first_slot_emb = embeds[:, :, :, 0]
    unk_fallback = torch.where(
        all_unk.unsqueeze(3),
        first_slot_emb,
        torch.zeros_like(averaged),
    )
    return averaged + unk_fallback
def average_titles(self, subset_mask, title_emb):
    """Compute the mean title embedding while skipping masked-out entries.

    Args:
        subset_mask: mask of unk embeddings (True means the entry is removed)
        title_emb: title embedding

    Returns:
        average title embedding
    """
    # subset_mask follows the downstream PyTorch convention (True == remove),
    # whereas the averaging helper wants True == keep, hence the inversion.
    keep_mask = ~subset_mask
    return model_utils.selective_avg(mask=keep_mask, embeds=title_emb)
def _selective_avg_types(self, type_ids, embeds):
    """Average type embeddings, ignoring padded types.

    Rows whose type ids are all 0 additionally receive the embedding found at
    position 0 of dim 3.

    Args:
        type_ids: type ids; indexed along dim 3, so at least 4-dimensional
        embeds: embeddings; sliced as ``embeds[:, :, :, 0]``

    Returns:
        average embedding
    """
    # True in keep_mask means the entry is kept in the average: ids strictly
    # between 0 and the largest id.
    upper_bound = self.num_types_with_pad_and_unk - 1
    keep_mask = (type_ids < upper_bound) & (type_ids > 0)
    averaged = selective_avg(keep_mask, embeds)

    # A row is "all unk" when the number of zero ids equals the size of the
    # last dimension.
    zero_id_count = (type_ids == 0).sum(3)
    all_unk = zero_id_count == type_ids.shape[-1]

    # For all-unk rows use the embedding at position 0; elsewhere add zeros.
    first_slot_emb = embeds[:, :, :, 0]
    unk_fallback = torch.where(
        all_unk.unsqueeze(3),
        first_slot_emb,
        torch.zeros_like(averaged),
    )
    return averaged + unk_fallback
def average_rels(self, entity_package, rel_ids, batch_rel_emb, sent_emb):
    """Average, per candidate, the relation embeddings shared with other candidates.

    Args:
        entity_package: unused in this method
        rel_ids: relation ids; only entries greater than 0 are kept
        batch_rel_emb: relation embeddings, forwarded as ``embeds``
        sent_emb: unused in this method

    Returns:
        whatever ``selective_avg`` returns for the kept relations
    """
    # Ids that are not strictly positive are masked out of the average
    # (presumably 0 marks padding / no relation -- TODO confirm).
    has_relation = rel_ids > 0
    return selective_avg(ids=rel_ids, mask=has_relation, embeds=batch_rel_emb)