def forward(self, context_ids, context_token_ct, lf_ids, lf_token_ct, target_lf_ids, num_outputs):
    device_str = 'cuda' if torch.cuda.is_available() else 'cpu'
    batch_size, num_context_ids = context_ids.size()
    _, max_output_size, max_lf_len = lf_ids.size()

    # Build token-level masks that zero out padded positions in the context and in each candidate LF
    context_mask = np.ones([batch_size, num_context_ids], dtype=int)
    lf_mask = np.ones([batch_size, max_output_size, max_lf_len], dtype=int)
    for i in range(batch_size):
        context_mask[i, context_token_ct[i]:] = 0
        for lf_idx in range(max_output_size):
            lf_mask[i, lf_idx, lf_token_ct[i, lf_idx]:] = 0
    context_mask = torch.FloatTensor(context_mask).to(device_str)
    lf_mask = torch.FloatTensor(lf_mask).to(device_str)

    output_dim = torch.Size([batch_size, max_output_size])
    output_mask = mask_2D(output_dim, num_outputs).to(device_str)

    # Encode the SF context
    # context_position_ids = torch.arange(0, num_context_ids).unsqueeze(0).repeat(batch_size, 1).to(device_str).long()
    encoded_context = self.encode(context_ids, context_mask, context_token_ct)

    # Encode every candidate LF (flattened across the batch)
    lf_ids_flat = lf_ids.view(batch_size * max_output_size, max_lf_len)
    lf_mask_flat = lf_mask.view(batch_size * max_output_size, max_lf_len)
    lf_token_ct_flat = lf_token_ct.view(batch_size * max_output_size)
    # lf_position_ids_flat = torch.arange(0, max_lf_len).unsqueeze(0).repeat(
    #     batch_size * max_output_size, 1).to(device_str).long()
    encoded_lfs_flat = self.encode(lf_ids_flat, lf_mask_flat, lf_token_ct_flat)
    encoded_lfs = encoded_lfs_flat.view(batch_size, max_output_size, -1)

    # Tile the context encoding across candidate LFs and score with cosine similarity
    encoded_context_tiled = encoded_context.unsqueeze(1).repeat(1, max_output_size, 1)
    score = nn.CosineSimilarity(dim=-1)(encoded_context_tiled, encoded_lfs)
    score.masked_fill_(output_mask, float('-inf'))
    return score, target_lf_ids
def encode_context(self, sf_ids, context_ids, num_contexts):
    batch_size, num_context_ids = context_ids.size()

    # Mask padded context ids
    mask_size = torch.Size([batch_size, num_context_ids])
    mask = mask_2D(mask_size, num_contexts).to(self.device)

    sf_mu, sf_sigma = self.encoder(sf_ids, context_ids, mask, token_mask_p=None)
    return sf_mu, sf_sigma
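# The functions in this section lean on a shared helper, mask_2D(size, lengths), whose definition is not
# shown here. Judging from its call sites, it returns a boolean batch_size x seq_len tensor that is True
# at padded positions so it can be passed to masked_fill_. The sketch below is an illustrative
# reconstruction under that assumption; the name mask_2D_sketch and the exact semantics are not taken from
# the original code.
import torch


def mask_2D_sketch(size, lengths):
    """Return a BoolTensor of shape `size` that is True wherever position index >= the row's true length."""
    batch_size, seq_len = size
    positions = torch.arange(seq_len).unsqueeze(0).expand(batch_size, seq_len)
    lengths = torch.as_tensor(lengths).unsqueeze(1)
    return positions >= lengths


# Example: two rows with 3 and 1 valid tokens out of 4
# mask_2D_sketch(torch.Size([2, 4]), [3, 1])
# -> tensor([[False, False, False,  True],
#            [False,  True,  True,  True]])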
def forward(self, center_ids, center_metadata_ids, context_ids, context_metadata_ids, neg_ids,
            neg_metadata_ids, num_contexts, num_metadata_samples=None):
    """
    :param center_ids: batch_size
    :param center_metadata_ids: batch_size
    :param context_ids: batch_size, 2 * context_window
    :param context_metadata_ids: batch_size, 2 * context_window, metadata_samples
    :param neg_ids: batch_size, 2 * context_window
    :param neg_metadata_ids: batch_size, 2 * context_window, metadata_samples
    :param num_contexts: batch_size (how many context words for each center id - necessary for masking padding)
    :param num_metadata_samples: number of Monte Carlo metadata samples per context / negative word
    :return: cost components: the max-margin (hinge) loss over context KL-divergences and the reconstruction
        KL-divergence between the encoder posterior q(z|w,c,d) and the decoded center distribution p(z|w,d)
    """
    # Mask padded context ids
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    batch_size, num_context_ids = context_ids.size()
    mask_size = torch.Size([batch_size, num_context_ids])
    mask = mask_2D(mask_size, num_contexts).to(device)

    m_samples = context_metadata_ids.size()[-1]
    assert m_samples == num_metadata_samples

    # Compute posterior for center words given their context
    mu_center_q, sigma_center_q, _ = self.encoder(center_ids, center_metadata_ids, context_ids, mask)
    mu_center_tiled_q = mu_center_q.unsqueeze(1).repeat(1, num_context_ids * m_samples, 1)
    sigma_center_tiled_q = sigma_center_q.unsqueeze(1).repeat(1, num_context_ids * m_samples, 1)
    mu_center_flat_q = mu_center_tiled_q.view(batch_size * num_context_ids * m_samples, -1)
    sigma_center_flat_q = sigma_center_tiled_q.view(batch_size * num_context_ids * m_samples, -1)

    # Compute decoded representations of (w, d), E(c), E(n)
    mu_center, sigma_center = self.decoder(center_ids, center_metadata_ids)
    mu_pos, sigma_pos = self._compute_marginal(context_ids, context_metadata_ids)
    mu_neg, sigma_neg = self._compute_marginal(neg_ids, neg_metadata_ids)

    # Flatten positive context
    mu_pos_flat = mu_pos.view(batch_size * num_context_ids * m_samples, -1)
    sigma_pos_flat = sigma_pos.view(batch_size * num_context_ids * m_samples, -1)

    # Flatten negative context
    mu_neg_flat = mu_neg.view(batch_size * num_context_ids * m_samples, -1)
    sigma_neg_flat = sigma_neg.view(batch_size * num_context_ids * m_samples, -1)

    # Compute KL-divergence between center words and positive / negative contexts, then reshape
    kl_pos_flat = compute_kl(mu_center_flat_q, sigma_center_flat_q, mu_pos_flat, sigma_pos_flat)
    kl_neg_flat = compute_kl(mu_center_flat_q, sigma_center_flat_q, mu_neg_flat, sigma_neg_flat)
    kl_pos = kl_pos_flat.view(batch_size, num_context_ids, m_samples).mean(-1)
    kl_neg = kl_neg_flat.view(batch_size, num_context_ids, m_samples).mean(-1)

    hinge_loss = (kl_pos - kl_neg + 1.0).clamp_min_(0)
    hinge_loss.masked_fill_(mask, 0)
    hinge_loss = hinge_loss.sum(1)

    recon_loss = compute_kl(mu_center_q, sigma_center_q, mu_center, sigma_center).squeeze(-1)
    return hinge_loss.mean(), recon_loss.mean()
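# The hinge term above encourages each center word's posterior to be closer (in KL) to its decoded context
# words than to negatively sampled words, by a margin of 1, with padded context slots zeroed out before the
# sum. A toy illustration with made-up KL values (not taken from the model):
import torch

kl_pos = torch.tensor([[0.5, 2.0, 0.0]])          # KL(q || decoded positive contexts)
kl_neg = torch.tensor([[3.0, 1.5, 0.0]])          # KL(q || decoded negative samples)
pad_mask = torch.tensor([[False, False, True]])   # third context slot is padding

hinge = (kl_pos - kl_neg + 1.0).clamp_min(0)      # -> [[0.0, 1.5, 1.0]]
hinge = hinge.masked_fill(pad_mask, 0).sum(1)     # -> [1.5]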
def forward(self, sf_ids, context_ids, lf_ids, target_lf_ids, lf_token_ct, num_outputs):
    """
    :param sf_ids: batch_size
    :param context_ids: batch_size, num_context_ids
    :param lf_ids: batch_size, max_output_size, max_lf_len
    :param target_lf_ids: batch_size
    :param lf_token_ct: batch_size, max_output_size - normalizer for lf_ids
    :param num_outputs: batch_size
    :return: scores for each candidate LF, target_lf_ids
    """
    batch_size, num_context_ids = context_ids.size()
    max_output_size = lf_ids.size()[1]

    output_dim = torch.Size([batch_size, max_output_size])
    output_mask = mask_2D(output_dim, num_outputs)

    # First, pass the SF along with its context to the encoder (mask is all False, so no tokens are hidden)
    mask = torch.zeros([batch_size, num_context_ids], dtype=torch.bool)
    sf_mu, sf_sigma = self.encoder(sf_ids, context_ids, mask)

    # Next, get prior representations for each LF in lf_ids
    lf_mu, lf_sigma = self._compute_priors(lf_ids)

    # Summarize each LF by averaging over its tokens
    normalizer = lf_token_ct.unsqueeze(-1).clamp_min(1.0)
    lf_mu_sum, lf_sigma_sum = lf_mu.sum(-2) / normalizer, lf_sigma.sum(-2) / normalizer

    # Tile SFs across each LF and flatten both SFs and LFs
    sf_mu_flat = sf_mu.unsqueeze(1).repeat(1, max_output_size, 1).view(batch_size * max_output_size, -1)
    sf_sigma_flat = sf_sigma.unsqueeze(1).repeat(1, max_output_size, 1).view(batch_size * max_output_size, -1)
    lf_mu_flat = lf_mu_sum.view(batch_size * max_output_size, -1)
    lf_sigma_flat = lf_sigma_sum.view(batch_size * max_output_size, -1)

    # Score each candidate LF by the negative KL-divergence from the SF posterior
    kl = compute_kl(sf_mu_flat, sf_sigma_flat, lf_mu_flat, lf_sigma_flat).view(batch_size, max_output_size)
    # min_kl, max_kl = kl.min(-1)[0], kl.max(-1)[0]
    # normalizer = (max_kl - min_kl).unsqueeze(-1)
    # numerator = max_kl.unsqueeze(-1).repeat(1, kl.size()[-1]) - kl
    # score = numerator  # / normalizer
    score = -kl
    score.masked_fill_(output_mask, float('-inf'))
    return score, target_lf_ids
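# The (score, target_lf_ids) pair returned above behaves like masked logits over candidate LFs plus the
# index of the correct candidate, so a downstream consumer could treat LF ranking as classification. This
# toy usage is hypothetical and not part of the original training / evaluation code:
import torch
import torch.nn as nn

score = torch.tensor([[0.9, 0.1, -0.3],
                      [0.2, 0.7, float('-inf')]])  # second row has only 2 valid candidates
target_lf_ids = torch.tensor([0, 1])

loss = nn.CrossEntropyLoss()(score, target_lf_ids)            # -inf entries receive zero probability
top1_acc = score.argmax(-1).eq(target_lf_ids).float().mean()  # -> 1.0 for this toy batch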
def forward(self, context_ids, context_token_ct, lf_ids, lf_token_ct, target_lf_ids, num_outputs):
    """
    :param context_ids: batch_size, num_context_ids, 50
    :param context_token_ct: batch_size - number of non-padded tokens in context_ids
    :param lf_ids: batch_size, max_output_size, max_lf_len, 50
    :param lf_token_ct: batch_size, max_output_size - normalizer for lf_ids
    :param target_lf_ids: batch_size
    :param num_outputs: batch_size
    :return: LF predictions, LF ground truths
    """
    device_str = 'cuda' if torch.cuda.is_available() else 'cpu'
    batch_size, num_context_ids, _ = context_ids.size()
    _, max_output_size, max_lf_len, _ = lf_ids.size()

    # Build token-level masks that zero out padded positions
    context_mask = np.ones([batch_size, num_context_ids], dtype=int)
    lf_mask = np.ones([batch_size, max_output_size, max_lf_len], dtype=int)
    for i in range(batch_size):
        context_mask[i, context_token_ct[i]:] = 0
        for lf_idx in range(max_output_size):
            lf_mask[i, lf_idx, lf_token_ct[i, lf_idx]:] = 0
    context_mask = torch.FloatTensor(context_mask).to(device_str)
    lf_mask = torch.FloatTensor(lf_mask).to(device_str)

    output_dim = torch.Size([batch_size, max_output_size])
    output_mask = mask_2D(output_dim, num_outputs).to(device_str)

    # Mean-pool ELMo embeddings over non-padded tokens for the context and each candidate LF
    encoded_context = self.elmo(context_ids)
    encoded_context = (encoded_context * context_mask.unsqueeze(-1)).sum(axis=1) / context_token_ct.unsqueeze(-1)
    encoded_lfs = self.elmo(lf_ids.view(batch_size * max_output_size, max_lf_len, 50)).view(
        batch_size, max_output_size, max_lf_len, -1)
    encoded_lfs = (encoded_lfs * lf_mask.unsqueeze(-1)).sum(2) / lf_token_ct.unsqueeze(-1)

    # Tile the context encoding across candidate LFs and score with cosine similarity
    encoded_context_tiled = encoded_context.unsqueeze(1).repeat(1, max_output_size, 1)
    score = nn.CosineSimilarity(dim=-1)(encoded_context_tiled, encoded_lfs)
    score.masked_fill_(output_mask, float('-inf'))
    return score, target_lf_ids
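# The ELMo baseline above collapses a padded sequence of token embeddings into a single vector by masked
# mean-pooling: multiply by a 0/1 mask, sum over the token axis, and divide by the true token count. A
# self-contained toy version of that step (shapes and values are illustrative only):
import torch

embeddings = torch.randn(2, 4, 8)             # batch of 2, up to 4 tokens, embedding dim 8
token_ct = torch.tensor([3.0, 2.0])           # true (non-padded) token counts
token_mask = torch.tensor([[1., 1., 1., 0.],
                           [1., 1., 0., 0.]])

pooled = (embeddings * token_mask.unsqueeze(-1)).sum(1) / token_ct.unsqueeze(-1)  # shape: (2, 8)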
def forward(self, token_ids, sec_ids, cat_ids, context_ids, neg_context_ids, num_contexts):
    """
    :param token_ids: batch_size (center word ids)
    :param sec_ids: batch_size (section ids)
    :param cat_ids: batch_size (note category ids)
    :param context_ids: batch_size, 2 * context_window
    :param neg_context_ids: batch_size, 2 * context_window
    :param num_contexts: batch_size (how many context words for each center id - necessary for masking padding)
    :return: cost components: KL-Divergence (q(z|w,c) || p(z|w)) and max margin (reconstruction error)
    """
    # Mask padded context ids
    batch_size, num_context_ids = context_ids.size()
    mask_size = torch.Size([batch_size, num_context_ids])
    mask = mask_2D(mask_size, num_contexts).to(self.device)

    center_ids = token_ids
    if self.multi_bsg:
        # Sample the pseudo-center id from the token, section, and category ids according to input_weights
        center_id_candidates = torch.cat([
            token_ids.unsqueeze(0), sec_ids.unsqueeze(0), cat_ids.unsqueeze(0)])
        input_sample = torch.multinomial(self.input_weights, batch_size, replacement=True).to(self.device)
        center_ids = center_id_candidates.gather(0, input_sample.unsqueeze(0)).squeeze(0)

    mu_q, sigma_q = self.encoder(center_ids, context_ids, mask, token_mask_p=self.mask_p)

    mu_p, sigma_p = self._compute_priors(token_ids)
    pos_mu_p, pos_sigma_p = self._compute_priors(context_ids)
    neg_mu_p, neg_sigma_p = self._compute_priors(neg_context_ids)

    kl = compute_kl(mu_q, sigma_q, mu_p, sigma_p).mean()
    max_margin = self._max_margin(mu_q, sigma_q, pos_mu_p, pos_sigma_p, neg_mu_p, neg_sigma_p, mask).mean()
    return kl, max_margin
def forward(self, center_ids, context_ids, neg_context_ids, num_contexts):
    """
    :param center_ids: batch_size
    :param context_ids: batch_size, 2 * context_window
    :param neg_context_ids: batch_size, 2 * context_window
    :param num_contexts: batch_size (how many context words for each center id - necessary for masking padding)
    :return: cost components: KL-Divergence (q(z|w,c) || p(z|w)) and max margin (reconstruction error)
    """
    # Mask padded context ids
    batch_size, num_context_ids = context_ids.size()
    mask_size = torch.Size([batch_size, num_context_ids])
    mask = mask_2D(mask_size, num_contexts).to(self.device)

    mu_q, sigma_q = self.encoder(center_ids, context_ids, mask)

    mu_p, sigma_p = self._compute_priors(center_ids)
    pos_mu_p, pos_sigma_p = self._compute_priors(context_ids)
    neg_mu_p, neg_sigma_p = self._compute_priors(neg_context_ids)

    kl = compute_kl(mu_q, sigma_q, mu_p, sigma_p).mean()
    max_margin = self._max_margin(mu_q, sigma_q, pos_mu_p, pos_sigma_p, neg_mu_p, neg_sigma_p, mask).mean()
    return kl, max_margin
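# compute_kl is defined elsewhere in the repository; the losses here use it as the analytic KL-divergence
# between two Gaussians parameterized by (mu, sigma). A minimal sketch for the diagonal case is given below.
# The name compute_kl_sketch, the exact parameterization of sigma, and the keepdim reduction are assumptions;
# the project's actual implementation may differ.
import torch


def compute_kl_sketch(mu_q, sigma_q, mu_p, sigma_p):
    """KL( N(mu_q, diag(sigma_q^2)) || N(mu_p, diag(sigma_p^2)) ), summed over the feature dimension."""
    var_q, var_p = sigma_q ** 2, sigma_p ** 2
    kl = torch.log(sigma_p / sigma_q) + (var_q + (mu_q - mu_p) ** 2) / (2.0 * var_p) - 0.5
    return kl.sum(-1, keepdim=True)  # trailing dim kept so callers can .squeeze(-1) or .mean()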
def forward(self, sf_ids, section_ids, category_ids, context_ids, lf_ids, target_lf_ids, lf_token_ct,
            lf_metadata_ids, num_outputs, num_contexts):
    """
    :param sf_ids: LongTensor of batch_size
    :param section_ids: LongTensor of batch_size
    :param category_ids: LongTensor of batch_size
    :param context_ids: LongTensor of batch_size x 2 * context_window
    :param lf_ids: LongTensor of batch_size x max_output_size x max_lf_len
    :param target_lf_ids: LongTensor of batch_size representing at which index in lf_ids the target LF lies
    :param lf_token_ct: batch_size, max_output_size - normalizer for lf_ids
    :param lf_metadata_ids: not used by this model (accepted for interface compatibility)
    :param num_outputs: list representing the number of target LFs for each row in batch.
    :param num_contexts: LongTensor of batch_size. The actual window size of the SF context.
    :return: scores for each candidate LF, target_lf_ids, None (this model produces no gating weights)
    """
    batch_size, max_output_size, _ = lf_ids.size()

    # Get prior representations for each LF in lf_ids
    lf_mu, lf_sigma = self._compute_priors(lf_ids)

    # Summarize LFs by averaging over their tokens
    normalizer = lf_token_ct.unsqueeze(-1).clamp_min(1.0)
    lf_mu_sum, lf_sigma_sum = lf_mu.sum(-2) / normalizer, lf_sigma.sum(-2) / normalizer

    combined_mu = []
    combined_sigma = []

    # Encode SFs in context
    sf_mu_tokens, sf_sigma_tokens = self.encode_context(sf_ids, context_ids, num_contexts)
    combined_mu.append(sf_mu_tokens)
    combined_sigma.append(sf_sigma_tokens)

    # For the MBSGE ensemble method, we also leverage section ids and note category ids
    if len(section_ids.nonzero()) > 0:
        sf_mu_sec, sf_sigma_sec = self.encode_context(section_ids, context_ids, num_contexts)
        combined_mu.append(sf_mu_sec)
        combined_sigma.append(sf_sigma_sec)
    if len(category_ids.nonzero()) > 0:
        sf_mu_cat, sf_sigma_cat = self.encode_context(category_ids, context_ids, num_contexts)
        combined_mu.append(sf_mu_cat)
        combined_sigma.append(sf_sigma_cat)

    # Average the ensemble of SF encodings
    combined_mu = torch.cat(list(map(lambda x: x.unsqueeze(1), combined_mu)), axis=1)
    combined_sigma = torch.cat(list(map(lambda x: x.unsqueeze(1), combined_sigma)), axis=1)
    sf_mu = combined_mu.mean(1)
    sf_sigma = combined_sigma.mean(1)

    # Tile SFs across each LF and flatten both SFs and LFs
    sf_mu_flat = sf_mu.unsqueeze(1).repeat(1, max_output_size, 1).view(batch_size * max_output_size, -1)
    sf_sigma_flat = sf_sigma.unsqueeze(1).repeat(1, max_output_size, 1).view(batch_size * max_output_size, -1)
    lf_mu_flat = lf_mu_sum.view(batch_size * max_output_size, -1)
    lf_sigma_flat = lf_sigma_sum.view(batch_size * max_output_size, -1)

    output_dim = torch.Size([batch_size, max_output_size])
    output_mask = mask_2D(output_dim, num_outputs).to(self.device)

    kl = compute_kl(sf_mu_flat, sf_sigma_flat, lf_mu_flat, lf_sigma_flat).view(batch_size, max_output_size)
    score = -kl
    score.masked_fill_(output_mask, float('-inf'))
    return score, target_lf_ids, None
center_word_tens = torch.LongTensor([token_vocab.get_id(center_word[0])]).to(device_str)
header_ids = list(map(lambda x: metadata_vocab.get_id(x), headers))
compare_ids = list(map(lambda x: token_vocab.get_id(x), compare_words))
pad_context = torch.zeros([1, ]).long().to(device_str)
context_ids = list(map(lambda x: token_vocab.get_id(x), context_words))
compare_tens = torch.LongTensor(compare_ids).to(device_str)
context_tens = torch.LongTensor(context_ids).unsqueeze(0).to(device_str)
mu_compare, sigma_compare = model.decoder(compare_tens, pad_context.repeat(len(compare_words)))
mask = mask_2D(torch.Size([1, len(context_ids)]), [len(context_ids)]).to(device_str)

print('Interpolation between word={} and headers={}'.format(center_word, ', '.join(headers)))
for i, header_id in enumerate(header_ids):
    header_tens = torch.LongTensor([header_id]).to(device_str)
    print(headers[i])
    for p in np.arange(0, 1.25, 0.25):
        # Sweep the relative weight between metadata and word from 0 to 1
        rw = [p, 1.0 - p]
        rel_weights = torch.FloatTensor([rw]).to(device_str)
        with torch.no_grad():
            mu_q, sigma_q, weights = model.encoder(center_word_tens, header_tens, context_tens, mask,
                                                   center_mask_p=None, context_mask_p=None,
                                                   rel_weights=rel_weights)  # closing kwarg assumed; original snippet is truncated here
def forward(self, context_ids, center_ids, pos_ids, neg_ids, context_token_type_ids, num_contexts,
            context_mask, center_mask, pos_mask, neg_mask, num_metadata_samples=None):
    """
    :param context_ids: batch_size x context_len (wp ids for BERT-style context word sequence with center word
        and metadata special token)
    :param center_ids: batch_size x num_context_ids (wp ids for each center word and its metadata)
    :param pos_ids: batch_size x window_size * 2 x num_context_ids (wp ids for context words)
    :param neg_ids: batch_size x window_size * 2 x num_context_ids (wp ids for negatively sampled words
        and MC metadata sample)
    :param context_token_type_ids: batch_size x context_len (demarcating tokens from metadata in context_ids
        and MC metadata sample)
    :param num_contexts: batch_size (number of non-padded pos_ids / neg_ids in each row in batch)
    :param context_mask: batch_size x context_len (BERT attention mask for context_ids)
    :param center_mask: batch_size x num_context_ids
    :param pos_mask: batch_size x window_size * 2 x num_context_ids
    :param neg_mask: batch_size x window_size * 2 x num_context_ids
    :param num_metadata_samples: number of MC samples approximating the marginal distribution over metadata.
        Determines where to segment the BERT-style sequence into distinct token_type_id segments.
    :return: a tuple of scalar tensors: the max-margin (hinge) loss and the reconstruction KL-loss
    """
    batch_size, num_context_ids, max_decoder_len = pos_ids.size()
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Mask padded context ids
    mask_size = torch.Size([batch_size, num_context_ids])
    mask = mask_2D(mask_size, num_contexts).to(device)

    # Compute posterior for center words given their context
    mu_center_q, sigma_center_q, _ = self.encoder(
        input_ids=context_ids, attention_mask=context_mask, token_type_ids=context_token_type_ids)
    mu_center_tiled_q = mu_center_q.unsqueeze(1).repeat(1, num_context_ids, 1)
    sigma_center_tiled_q = sigma_center_q.unsqueeze(1).repeat(1, num_context_ids, 1)
    mu_center_flat_q = mu_center_tiled_q.view(batch_size * num_context_ids, -1)
    sigma_center_flat_q = sigma_center_tiled_q.view(batch_size * num_context_ids, -1)

    # Compute decoded representations of (w, d), E(c), E(n) in a single decoder pass
    n = batch_size * num_context_ids
    pos_ids_flat, pos_mask_flat = pos_ids.view(n, -1), pos_mask.view(n, -1)
    neg_ids_flat, neg_mask_flat = neg_ids.view(n, -1), neg_mask.view(n, -1)
    joint_ids = torch.cat([center_ids, pos_ids_flat, neg_ids_flat], axis=0)
    joint_mask = torch.cat([center_mask, pos_mask_flat, neg_mask_flat], axis=0)
    center_sep_idx = 1 + 2
    other_sep_idx = num_metadata_samples + 2
    decoder_type_ids = torch.zeros([joint_ids.size()[0], max_decoder_len]).long().to(device)
    decoder_type_ids[:batch_size, -center_sep_idx:] = 1
    decoder_type_ids[batch_size:, -other_sep_idx:] = 1
    mu_joint, sigma_joint = self.decoder(
        input_ids=joint_ids, attention_mask=joint_mask, token_type_ids=decoder_type_ids)

    # Split the joint decoder output back into center, positive, and negative pieces
    mu_center, sigma_center = mu_joint[:batch_size], sigma_joint[:batch_size]
    s = batch_size * (num_context_ids + 1)
    mu_pos_flat, sigma_pos_flat = mu_joint[batch_size:s], sigma_joint[batch_size:s]
    mu_neg_flat, sigma_neg_flat = mu_joint[s:], sigma_joint[s:]

    # Compute KL-divergence between center words and positive / negative contexts, then reshape
    kl_pos_flat = compute_kl(mu_center_flat_q, sigma_center_flat_q, mu_pos_flat, sigma_pos_flat)
    kl_neg_flat = compute_kl(mu_center_flat_q, sigma_center_flat_q, mu_neg_flat, sigma_neg_flat)
    kl_pos = kl_pos_flat.view(batch_size, num_context_ids)
    kl_neg = kl_neg_flat.view(batch_size, num_context_ids)

    hinge_loss = (kl_pos - kl_neg + 1.0).clamp_min_(0)
    hinge_loss.masked_fill_(mask, 0)
    hinge_loss = hinge_loss.sum(1)

    recon_loss = compute_kl(mu_center_q, sigma_center_q, mu_center, sigma_center).squeeze(-1)
    return hinge_loss.mean(), recon_loss.mean()
def forward(self, sf_ids, section_ids, category_ids, context_ids, lf_ids, target_lf_ids, lf_token_ct,
            lf_metadata_ids, lf_metadata_p, num_outputs, num_contexts):
    """
    :param sf_ids: LongTensor of batch_size
    :param section_ids: LongTensor of batch_size
    :param category_ids: LongTensor of batch_size
    :param context_ids: LongTensor of batch_size x 2 * context_window
    :param lf_ids: LongTensor of batch_size x max_output_size x max_lf_len
    :param target_lf_ids: LongTensor of batch_size representing at which index in lf_ids the target LF lies
    :param lf_token_ct: LongTensor of batch_size x max_output_size. N-gram count for each LF (used for masking)
    :param lf_metadata_ids: batch_size x max_output_size x max_num_metadata. Ids for every metadata an LF appears in
    :param lf_metadata_p: batch_size x max_output_size x max_num_metadata.
        Empirical probability for lf_metadata_ids ~ p(metadata|LF)
    :param num_outputs: list representing the number of target LFs for each row in batch.
        Used for masking to avoid returning invalid predictions.
    :param num_contexts: LongTensor of batch_size. The actual window size of the SF context.
        Many are shorter than the target of 2 * context_window.
    :return: scores for each candidate LF, target_lf_ids, rel_weights (output of the encoder gating function)
    """
    batch_size, max_output_size, max_lf_ngram = lf_ids.size()
    _, num_context_ids = context_ids.size()

    # Compute SF contexts
    # Mask padded context ids
    mask_size = torch.Size([batch_size, num_context_ids])
    mask = mask_2D(mask_size, num_contexts).to(self.device)
    sf_mu, sf_sigma, rel_weights = self.encoder(sf_ids, section_ids, context_ids, mask,
                                                center_mask_p=None, context_mask_p=None)

    num_metadata = lf_metadata_ids.size()[-1]

    # Tile SFs across each LF and flatten both SFs and LFs
    sf_mu_flat = sf_mu.unsqueeze(1).repeat(1, max_output_size * num_metadata, 1).view(
        batch_size * max_output_size * num_metadata, -1)
    sf_sigma_flat = sf_sigma.unsqueeze(1).repeat(1, max_output_size * num_metadata, 1).view(
        batch_size * max_output_size * num_metadata, -1)

    # Compute E[LF]
    normalizer = lf_token_ct.unsqueeze(-1).clamp_min(1.0)
    lf_mu, lf_sigma = self._compute_marginal(lf_ids, lf_metadata_ids, normalizer=normalizer)
    lf_mu_flat = lf_mu.view(batch_size * max_output_size * num_metadata, -1)
    lf_sigma_flat = lf_sigma.view(batch_size * max_output_size * num_metadata, 1)

    output_dim = torch.Size([batch_size, max_output_size])
    output_mask = mask_2D(output_dim, num_outputs).to(self.device)

    kl_marginal = compute_kl(sf_mu_flat, sf_sigma_flat, lf_mu_flat, lf_sigma_flat).view(
        batch_size, max_output_size, num_metadata)
    # Expected KL under the empirical metadata distribution p(metadata|LF)
    kl = (kl_marginal * lf_metadata_p).sum(2)
    score = -kl
    score.masked_fill_(output_mask, float('-inf'))
    return score, target_lf_ids, rel_weights
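# The score above is (the negative of) an expectation of KL terms under the empirical metadata distribution
# p(metadata|LF): each candidate's num_metadata KL values are weighted by lf_metadata_p and summed. A toy
# illustration with one SF, two candidate LFs, and three metadata samples each (values are made up):
import torch

kl_marginal = torch.tensor([[[1.0, 2.0, 4.0],
                             [0.5, 0.5, 0.5]]])    # batch_size=1 x max_output_size=2 x num_metadata=3
lf_metadata_p = torch.tensor([[[0.5, 0.25, 0.25],
                               [1.0, 0.0, 0.0]]])  # each row sums to 1

kl = (kl_marginal * lf_metadata_p).sum(2)          # -> [[2.0, 0.5]]
score = -kl                                        # the second LF scores highest (least divergent from the SF posterior)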