def forward(self, idx_c, idx_q, y1s, y2s):
    """
    idx_c: [batch_size, context_len]
    idx_q: [batch_size, question_len]
    y1s: [batch_size], start idxs of true answers
    y2s: [batch_size], end idxs of true answers
    """
    mask_c = torch.zeros_like(idx_c) != idx_c
    mask_q = torch.zeros_like(idx_q) != idx_q

    embed_c = self.embedding(idx_c)  # [batch_size, c_len, embed_dim]
    embed_q = self.embedding(idx_q)  # [batch_size, q_len, embed_dim]
    embed_fc_c = self.embed_fc(embed_c)  # [batch_size, c_len, hidden_dim]
    embed_fc_q = self.embed_fc(embed_q)  # [batch_size, q_len, hidden_dim]
    highway_c = self.highway(embed_fc_c)  # [batch_size, c_len, hidden_dim]
    highway_q = self.highway(embed_fc_q)  # [batch_size, q_len, hidden_dim]

    enc_c, _ = self.enc(highway_c)  # [batch_size, c_len, 2*hidden_dim]
    enc_c = self.dropout(enc_c)
    enc_q, _ = self.enc(highway_q)  # [batch_size, q_len, 2*hidden_dim]
    enc_q = self.dropout(enc_q)

    # Attention flow layer
    G = self.attn_flow(enc_c, enc_q, mask_c, mask_q)  # [batch_size, c_len, 8*hidden_dim]

    # Modeling layer
    M, _ = self.mod(G)  # [batch_size, c_len, 2*hidden_dim]
    M_new, _ = self.enc_M(M)  # [batch_size, c_len, 2*hidden_dim]

    # Output
    logits_1 = self.attn_fc1(G) + self.mod_fc1(M)  # [batch_size, c_len, 1]
    logits_2 = self.attn_fc2(G) + self.mod_fc2(M_new)  # [batch_size, c_len, 1]
    logits_1 = logits_1.squeeze(2)  # [batch_size, c_len]
    logits_2 = logits_2.squeeze(2)  # [batch_size, c_len]

    # log_p1 = masked_softmax(logits_1, mask_c, dim=1, log_softmax=True)  # [batch_size, c_len]
    # log_p2 = masked_softmax(logits_2, mask_c, dim=1, log_softmax=True)  # [batch_size, c_len]
    p1 = masked_softmax(logits_1, mask_c, dim=1, log_softmax=False)  # [batch_size, c_len]
    p2 = masked_softmax(logits_2, mask_c, dim=1, log_softmax=False)  # [batch_size, c_len]

    ### Loss ###
    # y1s/y2s can fall outside the model inputs (e.g. -999); such targets must be ignored.
    ignored_idx = p1.shape[1]
    y1s_clamp = torch.clamp(y1s, min=0, max=ignored_idx)  # limit values to [0, c_len]; '-999' becomes 0
    y2s_clamp = torch.clamp(y2s, min=0, max=ignored_idx)
    loss_fn = nn.CrossEntropyLoss(ignore_index=ignored_idx)
    loss = (loss_fn(logits_1, y1s_clamp) + loss_fn(logits_2, y2s_clamp)) / 2
    return loss, p1, p2
def forward(self, c, q, mask_c, mask_q):
    """
    c: [batch_size, c_len, hidden_dim]
    q: [batch_size, q_len, hidden_dim]
    mask_c: [batch_size, c_len]
    mask_q: [batch_size, q_len]
    """
    S = self.get_similarity_matrix(c, q)  # [batch_size, c_len, q_len]
    # S1 = F.softmax(S, dim=2)  # row-wise softmax. [batch_size, c_len, q_len]
    # S2 = F.softmax(S, dim=1)  # column-wise softmax. [batch_size, c_len, q_len]
    mask_c = mask_c.unsqueeze(2)  # [batch_size, c_len, 1]
    mask_q = mask_q.unsqueeze(1)  # [batch_size, 1, q_len]
    S1 = masked_softmax(S, mask_q, dim=2)  # row-wise softmax. [batch_size, c_len, q_len]
    S2 = masked_softmax(S, mask_c, dim=1)  # column-wise softmax. [batch_size, c_len, q_len]

    # C2Q attention
    A = torch.bmm(S1, q)  # [batch_size, c_len, hidden_dim]
    # Q2C attention
    S_temp = torch.bmm(S1, S2.transpose(1, 2))  # [batch_size, c_len, c_len]
    B = torch.bmm(S_temp, c)  # [batch_size, c_len, hidden_dim]

    # c: [batch_size, c_len, hidden_dim]
    # A: [batch_size, c_len, hidden_dim]
    # torch.mul(c, A): [batch_size, c_len, hidden_dim]
    # torch.mul(c, B): [batch_size, c_len, hidden_dim]
    G = torch.cat((c, A, torch.mul(c, A), torch.mul(c, B)), dim=2)  # [batch_size, c_len, 4*hidden_dim]
    return G
def forward(self, premise_batch, premise_mask, hypothesis_batch, hypothesis_mask):
    """
    Args:
        premise_batch: A batch of sequences of vectors representing the premises
            in some NLI task. The batch is assumed to have the size
            (batch, sequences, vector_dim).
        premise_mask: A mask for the sequences in the premise batch, to ignore
            padding data in the sequences during the computation of the attention.
        hypothesis_batch: A batch of sequences of vectors representing the
            hypotheses in some NLI task. The batch is assumed to have the size
            (batch, sequences, vector_dim).
        hypothesis_mask: A mask for the sequences in the hypotheses batch, to
            ignore padding data in the sequences during the computation of the
            attention.

    Returns:
        attended_premises: The sequences of attention vectors for the premises
            in the input batch.
        attended_hypotheses: The sequences of attention vectors for the
            hypotheses in the input batch.
    """
    # Dot product between premises and hypotheses in each sequence of the batch.
    similarity_matrix = premise_batch.bmm(hypothesis_batch.transpose(2, 1).contiguous())

    # Softmax attention weights.
    prem_hyp_attn = masked_softmax(similarity_matrix, hypothesis_mask)
    hyp_prem_attn = masked_softmax(similarity_matrix.transpose(1, 2).contiguous(), premise_mask)

    # Weighted sums of the hypotheses for the premises attention,
    # and vice versa for the attention of the hypotheses.
    attended_premises = weighted_sum(hypothesis_batch, prem_hyp_attn, premise_mask)
    attended_hypotheses = weighted_sum(premise_batch, hyp_prem_attn, hypothesis_mask)

    return attended_premises, attended_hypotheses
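# ---------------------------------------------------------------------------
# Several snippets in this section call a `weighted_sum` helper that is never
# shown. The following is a minimal sketch, NOT the original implementation,
# assuming the call convention `weighted_sum(tensor, weights, mask)` used in
# the snippet above (note that one snippet further down instead calls
# `weighted_sum(weights, tensor)`, i.e. a different signature).
import torch

def weighted_sum(tensor, weights, mask):
    """Hypothetical helper: attention-weighted sum of `tensor`.

    tensor:  (batch, len_kv, dim)    vectors to attend over
    weights: (batch, len_q, len_kv)  attention weights (rows sum to 1)
    mask:    (batch, len_q)          1 for real query positions, 0 for padding

    Returns (batch, len_q, dim), with padded query rows zeroed out.
    """
    out = weights.bmm(tensor)                     # (batch, len_q, dim)
    return out * mask.unsqueeze(-1).type_as(out)  # zero padded rows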
def forward(self, sent_a, sent_a_mask, sent_b, sent_b_mask):
    """
    Inputs:
        sent_a: [batch_size, seq_a_len, vec_dim]
        sent_a_mask: [batch_size, seq_a_len]
        sent_b: [batch_size, seq_b_len, vec_dim]
        sent_b_mask: [batch_size, seq_b_len]
    Outputs:
        sent_a_att: [batch_size, seq_a_len, vec_dim]
        sent_b_att: [batch_size, seq_b_len, vec_dim]
    """
    # Similarity matrix
    similarity_matrix = torch.matmul(sent_a, sent_b.transpose(1, 2).contiguous())  # [batch_size, seq_a, seq_b]
    sent_a_b_attn = masked_softmax(similarity_matrix, sent_b_mask)  # [batch_size, seq_a, seq_b]
    sent_b_a_attn = masked_softmax(similarity_matrix.transpose(1, 2).contiguous(), sent_a_mask)  # [batch_size, seq_b, seq_a]
    sent_a_att = weighted_sum(sent_b, sent_a_b_attn, sent_a_mask)  # [batch_size, seq_a, vec_dim]
    sent_b_att = weighted_sum(sent_a, sent_b_a_attn, sent_b_mask)  # [batch_size, seq_b, vec_dim]
    return sent_a_att, sent_b_att
def forward(self, att, mod, mask):
    # Shapes: (batch_size, seq_len, 1)
    logits_1 = self.att_linear_1(att) + self.mod_linear_1(mod)
    mod_2 = self.rnn(mod, mask.sum(-1))
    logits_2 = self.att_linear_2(att) + self.mod_linear_2(mod_2)

    # Shapes: (batch_size, seq_len)
    # squeeze(-1) rather than a bare squeeze(): squeeze() would also drop the
    # batch dimension when batch_size == 1.
    log_p1 = masked_softmax(logits_1.squeeze(-1), mask, log_softmax=True)
    log_p2 = masked_softmax(logits_2.squeeze(-1), mask, log_softmax=True)

    return log_p1, log_p2
def _correct_policy_distribution(scores_t, p_oracle):
    # Build a renormalized distribution from the policy's probabilities over the
    # correct set of actions, setting the probability of all other actions to zero.
    scores_t = scores_t.clamp(-40, 40)  # for numerical stability
    correct_actions_mask = p_oracle.gt(0).float()  # gt = greater-than; 1 wherever the oracle puts mass
    p_correct_policy = util.masked_softmax(scores_t, correct_actions_mask, dim=1)
    return p_correct_policy
def pred_ans(self, pth_path):
    # Tokenization
    inputs = tokenizer(self.context, self.ques, truncation=True, max_length=512, return_tensors="pt")
    # Note: passing do_lower_case here raises
    # "_batch_encode_plus() got an unexpected keyword argument 'do_lower_case'".

    # Load checkpoint
    checkpoint = torch.load(pth_path, map_location=device)
    state_dict = checkpoint['state_dict']
    model.load_state_dict(state_dict, strict=False)
    model.cpu()

    ## Run model
    model.eval()
    outputs = model(**inputs)
    p1s, p2s = outputs[0], outputs[1]  # [1, c_len+q_len]

    ## Get start/end idxs
    p1s = utils.masked_softmax(p1s, 1 - inputs['token_type_ids'], dim=1, log_softmax=True)  # [1, c_len+q_len]
    p2s = utils.masked_softmax(p2s, 1 - inputs['token_type_ids'], dim=1, log_softmax=True)  # [1, c_len+q_len]
    p1s, p2s = p1s.exp(), p2s.exp()
    s_idxs, e_idxs, top_probs = utils.get_ans_list_idx(p1s, p2s, max_len=5, num_answer=30)  # [1, num_answer]

    ### Get answer candidates ###
    all_tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
    ans_cand = []
    for j in range(30):  # iterate over candidates
        ans_jth_tokens = all_tokens[s_idxs[0, j]:(e_idxs[0, j] + 1)]  # token list of one answer candidate
        ans_ids = tokenizer.convert_tokens_to_ids(ans_jth_tokens)
        ans = tokenizer.decode(ans_ids)
        ans_cand.append(ans)
    self.ans_cand = ans_cand
def forward(self, c, q, c_mask, q_mask):
    batch_size, c_len, _ = c.size()
    q_len = q.size(1)
    s = self.get_similarity_matrix(c, q)        # (batch_size, c_len, q_len)
    c_mask = c_mask.view(batch_size, c_len, 1)  # (batch_size, c_len, 1)
    q_mask = q_mask.view(batch_size, 1, q_len)  # (batch_size, 1, q_len)
    s1 = masked_softmax(s, q_mask, dim=2)       # (batch_size, c_len, q_len)
    s2 = masked_softmax(s, c_mask, dim=1)       # (batch_size, c_len, q_len)

    # (bs, c_len, q_len) x (bs, q_len, hid_size) => (bs, c_len, hid_size)
    a = torch.bmm(s1, q)
    # (bs, c_len, c_len) x (bs, c_len, hid_size) => (bs, c_len, hid_size)
    b = torch.bmm(torch.bmm(s1, s2.transpose(1, 2)), c)

    x = torch.cat([c, a, c * a, c * b], dim=2)  # (bs, c_len, 4 * hid_size)
    return x
def test_masked_softmax():
    batched_input = torch.tensor([[1, 2, 3], [1, 1, 2], [3, 2, 1]]).float()
    batched_mask = torch.tensor([[1, 1, 0], [0, 1, 1], [1, 1, 1]]).float()
    batched_output = masked_softmax(batched_input, batched_mask, dim=1)

    # Compare masked_softmax against a regular softmax over the unmasked values only.
    # (`inp` instead of `input` to avoid shadowing the builtin.)
    for inp, mask, output in zip(batched_input, batched_mask, batched_output):
        assert output[output != 0].equal(F.softmax(inp[mask == 1], dim=0))
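# ---------------------------------------------------------------------------
# `masked_softmax` itself is assumed throughout this section but never defined.
# Below is a minimal sketch consistent with the test above (masked positions
# get probability exactly 0, the remaining entries renormalize). The actual
# helpers in these repos vary — some take two masks, some lack the
# `log_softmax` flag — so treat this as an assumption, not the original code.
import torch
import torch.nn.functional as F

def masked_softmax(logits, mask, dim=-1, log_softmax=False):
    """Hypothetical helper: softmax over `dim`, ignoring positions where mask == 0.

    Filling masked logits with a very large negative value drives their
    probability to exactly 0 (the exp underflows), which is what the test
    above checks.
    """
    mask = mask.type_as(logits)
    masked_logits = logits.masked_fill(mask == 0, -1e30)
    fn = F.log_softmax if log_softmax else F.softmax
    return fn(masked_logits, dim=dim)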
def forward(self, encoded_state, candidate_action_reps, act_mask, eval_mode=None):
    '''
    Score each admissible action and sample one.

    encoded_state: batch x hidden; encoded text description from TextWorld
    candidate_action_reps: batch x num_admissible_actions x max_action_len;
        list (size batch) of numpy arrays, each num_admissible_actions x max_action_len
    act_mask: torch tensor, batch x num_actions; binary mask where 1 marks an
        admissible action and 0 marks a padded action
    '''
    act_rep_tensor = torch.LongTensor(candidate_action_reps).to(self.device)
    batch, num_actions, _ = act_rep_tensor.size()
    # If using BERT:
    # batch, num_actions, _ = encoded_act.size()
    encoded_act = self.command_encoder(act_rep_tensor)
    encoded_act = encoded_act.unsqueeze(1).view(batch, num_actions, -1)  # batch x num_admissible_actions x hidden

    # Concat admissible commands with the state
    encoded_state = encoded_state.unsqueeze(1).expand(-1, num_actions, -1)  # batch x num_actions x hidden
    encoding = torch.cat([encoded_state, encoded_act], dim=2)  # batch x num_actions x 2*hidden

    # Zero out all actions that are fillers
    encoding = act_mask.unsqueeze(-1).expand(-1, -1, encoding.shape[-1]) * encoding  # batch x num_actions x 2*hidden

    # Score each action and sample
    logits = self.action_scorer(encoding).squeeze(-1)  # batch x num_actions
    # Masked softmax
    probs = masked_softmax(logits, act_mask, dim=1)  # batch x num_actions
    if self.use_gumbel_softmax:
        # Gumbel-max trick for action sampling: the trick requires adding Gumbel
        # noise to *log*-probabilities, not raw probabilities.
        u = torch.rand(probs.size(), dtype=probs.dtype)
        gumbel_noise = -torch.log(-torch.log(u))
        idxs = torch.argmax(torch.log(probs.cpu() + 1e-20) + gumbel_noise, dim=-1)
    else:
        idxs = probs.multinomial(num_samples=1)  # batch x 1
    return idxs, logits
def ce_loss(pred, gt, ground_msk, s_msks):
    s_num = s_msks.sum(-1).clamp(min=1)
    pred = masked_softmax(pred, ground_msk)
    m_loss = F.binary_cross_entropy(pred, gt, reduction='none')  # [B x sN x cN*hN]
    m_loss = m_loss * ground_msk  # CE loss over valid positions only
    m_loss = m_loss.mean(-1)  # [B x sN]
    m_loss = (m_loss * s_msks).sum(-1) / s_num  # [B]
    m_loss = m_loss.mean()  # scalar
    return m_loss
def forward(self, premises, premises_mask, hypotheses, hypotheses_mask):
    """
    Params:
        premises: (S, N, H)
        hypotheses: (T, N, H)
        premises_mask: (N, S)
        hypotheses_mask: (N, T)
    Returns:
        new_premises: (S, N, H)
        new_hypotheses: (T, N, H)
    """
    premises = premises.transpose(0, 1)  # (N, S, H)
    logging.debug(f"premises shape: {premises.shape}")
    hypotheses = hypotheses.transpose(0, 1)  # (N, T, H)
    logging.debug(f"hypotheses shape: {hypotheses.shape}")

    attn_premises = torch.bmm(premises, hypotheses.transpose(1, 2))  # (N, S, T)
    attn_hypotheses = attn_premises.transpose(1, 2)  # (N, T, S)
    attn_premises = masked_softmax(attn_premises, premises_mask, hypotheses_mask)  # (N, S, T)
    attn_hypotheses = masked_softmax(attn_hypotheses, hypotheses_mask, premises_mask)  # (N, T, S)
    logging.debug(f"weight: {attn_premises.shape}, tensor: {hypotheses.shape}")

    new_premises = weighted_sum(attn_premises, hypotheses)  # (N, S, H)
    new_hypotheses = weighted_sum(attn_hypotheses, premises)  # (N, T, H)
    new_premises = new_premises.transpose(0, 1)  # (S, N, H)
    new_hypotheses = new_hypotheses.transpose(0, 1)  # (T, N, H)
    return new_premises, new_hypotheses, attn_premises, attn_hypotheses
def __call__(self, scores, oracle, **kwargs):
    with torch.no_grad():  # no gradients needed for sampling
        if self.aux_end:
            token_scores, stop_probs = scores
            p_oracle = oracle.distribution()
            correct_actions_mask = p_oracle.gt(0).unsqueeze(1).float()
            stop_probs = correct_actions_mask[:, :, self.end_idx]  # prevent an invalid <end>
            token_scores = torch.clamp(token_scores, -40, 40)
            ps = util.masked_softmax(token_scores, correct_actions_mask, dim=2)
            samples = self.base_sampler((ps, stop_probs))
        else:
            p_oracle = oracle.distribution()
            correct_actions_mask = p_oracle.gt(0).unsqueeze(1).float()
            scores = torch.clamp(scores, -40, 40)
            ps = util.masked_softmax(scores, correct_actions_mask, dim=2)
            samples = self.base_sampler(ps)
    return samples
def forward(
    self,
    ctx: torch.Tensor,
    query: torch.Tensor,
    ctx_mask: torch.Tensor,
    query_mask: torch.Tensor,
) -> torch.Tensor:
    """
    ctx: (batch, ctx_seq_len, hidden_dim)
    query: (batch, query_seq_len, hidden_dim)
    ctx_mask: (batch, ctx_seq_len)
    query_mask: (batch, query_seq_len)
    output: (batch, ctx_seq_len, 4 * hidden_dim)
    """
    ctx_seq_len = ctx.size(1)
    query_seq_len = query.size(1)

    # (batch, ctx_seq_len, query_seq_len)
    similarity = self.trilinear_for_attention(ctx, query)
    # (batch, ctx_seq_len, query_seq_len)
    s_ctx = masked_softmax(similarity, ctx_mask.unsqueeze(2).expand(-1, -1, query_seq_len), dim=1)
    # (batch, ctx_seq_len, query_seq_len)
    s_query = masked_softmax(similarity, query_mask.unsqueeze(1).expand(-1, ctx_seq_len, -1), dim=2)
    # (batch, ctx_seq_len, hidden_dim)
    P = torch.bmm(s_query, query)
    # (batch, ctx_seq_len, hidden_dim)
    Q = torch.bmm(torch.bmm(s_query, s_ctx.transpose(1, 2)), ctx)
    # (batch, ctx_seq_len, 4 * hidden_dim)
    return torch.cat([ctx, P, ctx * P, ctx * Q], dim=2)
def forward(self, user_review_inputs, item_review_inputs, user_review_masks, item_review_masks):
    """
    Args:
        user_review_inputs: [bz, ur_num, in_features]
        item_review_inputs: [bz, ir_num, in_features]
        user_review_masks: [bz, ur_num]
        item_review_masks: [bz, ir_num]
    """
    # Aggregate item_review_inputs
    item_outputs, item_review_weights = self.item_aggregator(item_review_inputs, item_review_masks)

    # Interaction score
    review_similarity_scores = self.bilinear(user_review_inputs, item_review_inputs)
    user_review_scores, _ = torch.max(review_similarity_scores, dim=2)  # [bz, ur_num]
    user_review_weights = masked_softmax(user_review_scores, user_review_masks)
    user_outputs = attention_weighted_sum(user_review_weights, user_review_inputs)
    return user_outputs, item_outputs, user_review_weights, item_review_weights
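# ---------------------------------------------------------------------------
# `attention_weighted_sum(weights, inputs)` above collapses a set of review
# vectors into one vector per batch element. A plausible minimal sketch under
# the shapes documented in the snippet (weights [bz, n], inputs
# [bz, n, in_features]); hypothetical, not the repo's own helper.
import torch

def attention_weighted_sum(weights, inputs):
    """Hypothetical helper: sum `inputs` along dim 1, weighted by attention.

    weights: [bz, n]               attention weights (rows sum to 1)
    inputs:  [bz, n, in_features]  vectors to aggregate

    Returns [bz, in_features].
    """
    return torch.sum(weights.unsqueeze(-1) * inputs, dim=1)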
S_cq = torch.matmul(c * W_cq, q.transpose(1, 2))  # matmul([batch_size, c_len, hidden_dim], [batch_size, hidden_dim, q_len])
S = S_c + S_q + S_cq + b  # [batch_size, c_len, q_len]
print(S.shape)

### Forward
idx_c = torch.rand((batch_size, c_len))
idx_q = torch.rand((batch_size, q_len))
mask_c = torch.zeros_like(idx_c) != idx_c
mask_q = torch.zeros_like(idx_q) != idx_q
mask_c = mask_c.unsqueeze(2)  # [batch_size, c_len, 1]
mask_q = mask_q.unsqueeze(1)  # [batch_size, 1, q_len]
S1 = masked_softmax(S, mask_q, dim=2)  # row-wise softmax. [batch_size, c_len, q_len]
S2 = masked_softmax(S, mask_c, dim=1)  # column-wise softmax. [batch_size, c_len, q_len]

# C2Q attention
A = torch.bmm(S1, q)  # [batch_size, c_len, hidden_dim]
# Q2C attention
S_temp = torch.bmm(S1, S2.transpose(1, 2))  # [batch_size, c_len, c_len]
B = torch.bmm(S_temp, c)  # [batch_size, c_len, hidden_dim]

# c: [batch_size, c_len, hidden_dim]
# A: [batch_size, c_len, hidden_dim]
# torch.mul(c, A): [batch_size, c_len, hidden_dim]
# torch.mul(c, B): [batch_size, c_len, hidden_dim]
G = torch.cat((c, A, torch.mul(c, A), torch.mul(c, B)), dim=2)  # [batch_size, c_len, 4*hidden_dim]
def forward(self, idx_c, idx_q, y1s, y2s):
    """
    idx_c: [batch_size, c_len]
    idx_q: [batch_size, q_len]
    y1s: [batch_size], start idxs of true answers
    y2s: [batch_size], end idxs of true answers
    """
    mask_c = torch.zeros_like(idx_c) != idx_c
    mask_q = torch.zeros_like(idx_q) != idx_q

    # 1. Input embed layer
    emb_c = self.embed_inp(idx_c)  # [batch_size, c_len, embed_dim]
    emb_q = self.embed_inp(idx_q)  # [batch_size, q_len, embed_dim]
    # Connection between embed layer and embed encoder
    emb_c = self.embed_conv(emb_c)  # [batch_size, c_len, embed_dim]
    emb_q = self.embed_conv(emb_q)  # [batch_size, q_len, embed_dim]

    # 2. Embed encoder
    emb_enc_c = self.embed_enc_c(emb_c, mask_c)  # [batch_size, c_len, hidden_dim]
    emb_enc_q = self.embed_enc_q(emb_q, mask_q)  # [batch_size, q_len, hidden_dim]

    # 3. CQ attention
    G = self.attn(emb_enc_c, emb_enc_q, mask_c, mask_q)  # [batch_size, c_len, 4*hidden_dim]
    # Connection between attention and model encoder
    G = self.mod_conv(G)  # [batch_size, c_len, hidden_dim]

    # 4. Model encoders (the same stack is applied three times)
    M0 = G
    for enc in self.mod_enc:
        M0 = enc(M0, mask_c)  # [batch_size, c_len, hidden_dim]
    M1 = M0
    for enc in self.mod_enc:
        M1 = enc(M1, mask_c)  # [batch_size, c_len, hidden_dim]
    M2 = M1
    for enc in self.mod_enc:
        M2 = enc(M2, mask_c)  # [batch_size, c_len, hidden_dim]

    # 5. Output
    x1 = torch.cat([M0, M1], dim=2)  # [batch_size, c_len, 2*hidden_dim]
    x2 = torch.cat([M0, M2], dim=2)  # [batch_size, c_len, 2*hidden_dim]
    logits1 = self.fc1(x1).squeeze(2)  # [batch_size, c_len]
    logits2 = self.fc2(x2).squeeze(2)  # [batch_size, c_len]

    # log_p1 = masked_softmax(logits1, mask_c, dim=1, log_softmax=True)  # [batch_size, c_len]
    # log_p2 = masked_softmax(logits2, mask_c, dim=1, log_softmax=True)  # [batch_size, c_len]
    p1 = masked_softmax(logits1, mask_c, dim=1, log_softmax=False)  # [batch_size, c_len]
    p2 = masked_softmax(logits2, mask_c, dim=1, log_softmax=False)  # [batch_size, c_len]

    ### Loss ###
    # y1s/y2s can fall outside the model inputs (e.g. -999); such targets must be ignored.
    ignored_idx = p1.shape[1]
    y1s_clamp = torch.clamp(y1s, min=0, max=ignored_idx)  # limit values to [0, c_len]; '-999' becomes 0
    y2s_clamp = torch.clamp(y2s, min=0, max=ignored_idx)
    loss_fn = nn.CrossEntropyLoss(ignore_index=ignored_idx)
    loss = (loss_fn(logits1, y1s_clamp) + loss_fn(logits2, y2s_clamp)) / 2
    return loss, p1, p2
def forward(self, embs, toks, msks, s_msks, c_msks, dep_root_msk, imgs, i_msks,
            poses, pose_msks, i3d_rgb=None, face=None, face_msks=None, bbox_meta=None):
    '''
    Params:
        embs: [B x 1 x L]
        toks: [B x 1 x L]
        msks: [B x 1 x L]
        s_msks: [B x sN x L]
        c_msks: [B x sN x cN]
        dep_root_msk: [B x sN x L]
        imgs: [B x cN x hN x 3 x 224 x 224]
        i_msks: [B x cN x hN]
        poses: [B x cN x hN x 17 x 2]
        pose_msks: [B x cN x hN x 17]
        i3d_rgb: [B x cN x hN x 1024]
        face: [B x cN x hN x 512]
        face_msks: [B x cN x hN]
        bbox_meta: [B x cN x hN x 4]
    '''
    B, sN, L = s_msks.shape
    cN, hN = imgs.size(1), imgs.size(2)

    # Text embedding
    embs, toks, msks = embs.view(-1, L), toks.view(-1, L), msks.view(-1, L)  # [B*sN x L]
    bert_x, _ = self.bert(embs, toks, msks, output_all_encoded_layers=False)  # [B x L x 768]
    bert_x = gelu(bert_x).view(B, 1, L, 768)
    if 'someone' in self.t_feats:
        ts_x = bert_x * s_msks.view(B, sN, L, 1)  # [B x sN x L x 768]
        ts_x = torch.sum(ts_x, 2)  # [B x sN x 768]
    if 'action' in self.t_feats:
        ta_x = bert_x * dep_root_msk.view(B, sN, L, 1)  # [B x sN x L x 768]
        ta_x = torch.sum(ta_x, 2)  # [B x sN x 768]
    tce_x = torch.cat([ta_x, ts_x], -1)  # [B x sN x sdim]

    # Text projection
    s_a_x = self.s_a_proj(tce_x)  # [B x sN x H]
    s_s_x = self.s_s_proj(tce_x)  # [B x sN x H]

    # Auxiliary gender classifier
    if self.use_gender:
        g_x = self.gender_fc(ts_x)  # [B x sN x 1]
        self.gender_result = g_x.squeeze(-1)  # [B x sN]

    # Visual embedding
    i_x = torch.zeros((B * cN * hN, 0), device=bbox_meta.device)
    if 'img' in self.v_feats:
        imgs = imgs.view(-1, 3, 224, 224)
        img_x = self.act_conv(imgs)  # [B*cN*hN x 2048]
        i_x = torch.cat((i_x, img_x), -1)  # [B*cN*hN x 2048]
    if 'i3d_rgb' in self.v_feats:
        i3d_x = i3d_rgb.view(-1, 1024)  # [B*cN*hN x 1024]
        i_x = torch.cat((i_x, i3d_x), -1)  # [B*cN*hN x num_ftrs]
    if 'face' in self.v_feats:
        face_x = face.view(-1, 512)  # [B*cN*hN x 512]

    # Image projection
    i_a_x = self.i_a_proj(i_x)  # [B*cN*hN x H]
    i_a_x = i_a_x.view((B, cN * hN, -1))  # [B x cN*hN x H]
    i_s_x = self.i_s_proj(face_x)  # [B*cN*hN x H]
    i_s_x = i_s_x.view((B, cN * hN, -1))  # [B x cN*hN x H]
    if 'meta' in self.v_feats:
        meta_x = self.meta_proj(bbox_meta.view(-1, 4))  # [B*cN*hN x 50]

    # Character grounding
    s_a_x = s_a_x.unsqueeze(2).repeat(1, 1, cN * hN, 1).view(-1, self.hidden_dim)
    i_a_x = i_a_x.unsqueeze(1).repeat(1, sN, 1, 1).view(-1, self.hidden_dim)
    s_s_x = s_s_x.unsqueeze(2).repeat(1, 1, cN * hN, 1).view(-1, self.hidden_dim)
    i_s_x = i_s_x.unsqueeze(1).repeat(1, sN, 1, 1).view(-1, self.hidden_dim)
    if 'meta' in self.v_feats:
        meta_x = meta_x.view(B, 1, cN * hN, 50).repeat(1, sN, 1, 1)  # [B x sN x cN*hN x 50]
    f_a_x = s_a_x * i_a_x
    f_a_x = f_a_x.view(B, sN, cN * hN, self.hidden_dim)  # [B x sN x cN*hN x H]
    f_s_x = s_s_x * i_s_x
    f_s_x = f_s_x.view(B, sN, cN * hN, self.hidden_dim)  # [B x sN x cN*hN x H]
    f_t_x = torch.cat((meta_x, f_a_x, f_s_x), -1)  # [B x sN x cN*hN x fH]
    f_x = self.fuse_fc(f_t_x)  # [B x sN x cN*hN x 1]
    f_x = f_x.squeeze(-1)  # [B x sN x cN*hN]

    # Masking
    pred_ground = f_x * i_msks.view(B, 1, cN * hN) * s_msks.sum(-1, keepdim=True)

    # Character re-identification
    pred_chreid = 0
    to_divide = 1
    self.text_mat, self.vid_mat, self.vgv_mat = None, None, None

    # Text re-id
    if 'text' in self.char_reid:
        text_reid_x = ts_x.unsqueeze(1) * ts_x.unsqueeze(2)  # [B x sN x sN x 768]
        text_reid_x = self.text_reid_fc(text_reid_x)
        text_mat = text_reid_x.squeeze(-1)  # [B x sN x sN]
        self.text_mat = text_mat
        pred_chreid += self.text_mat
        to_divide += 1

    # Person identity representation
    if 'visual' in self.char_reid:
        vid_mat = self.vid_model(imgs.view(B, -1, 3, 224, 224), i_msks.view(B, -1),
                                 poses.view(B, -1, 17, 2), pose_msks.view(B, -1, 17),
                                 face.view(B, -1, 512), face_msks.view(B, -1))  # [B x cN*hN x cN*hN]
        self.vid_mat = vid_mat
        p_msks = c_msks.view(B, sN, cN, 1).repeat(1, 1, 1, hN).view(B, sN, cN * hN)
        i_mm = i_msks.view(B, 1, cN * hN)  # [B x 1 x cN*hN]
        p_msks = p_msks * i_mm  # [B x sN x cN*hN] (1 if exists, 0 if not)
        p_att = masked_softmax(f_x * 5, p_msks)  # [B x sN x cN*hN]; sharpen f_x a little
        vgv_mat = torch.matmul(p_att, vid_mat)
        vgv_mat = vgv_mat * (1 - p_msks) - 0.1 * p_msks  # mask for someone in the same clip
        vgv_mat = torch.matmul(vgv_mat, p_att.transpose(1, 2))  # [B x sN x sN]
        vgv_mat = torch.sigmoid(vgv_mat)  # [B x sN x sN], 0 ~ 1
        self.vgv_mat = vgv_mat
        pred_chreid = (pred_chreid + vgv_mat) / to_divide
    return pred_ground, pred_chreid
def valid_fn_list(model, data_loader, tokenizer, device, num_answer, ans_thres):
    scores = {'loss': 0, 'f1': 0, 'prec': 0, 'rec': 0}
    len_iter = len(data_loader)
    n_samples = 0
    model.eval()
    epoch_dic = []
    with torch.no_grad():
        with tqdm(total=len_iter) as progress_bar:
            for _, batch in enumerate(data_loader):
                input_ids = batch['input_ids'].to(device)  # [batch_size, c_len]
                attn_mask = batch['attention_mask'].to(device)  # [batch_size, c_len]
                y1s = batch['token_starts'].to(device)  # [batch_size]
                y2s = batch['token_ends'].to(device)  # [batch_size]
                outputs = model(input_ids, attention_mask=attn_mask,
                                start_positions=y1s, end_positions=y2s)
                loss = outputs[0]
                p1s, p2s = outputs[1], outputs[2]  # [batch_size, c_len]
                scores['loss'] += loss.item()

                # Get start/end idxs
                mask_c = attn_mask != torch.zeros_like(attn_mask)
                p1s = utils.masked_softmax(p1s, mask_c, dim=1, log_softmax=True)  # [batch_size, c_len]
                p2s = utils.masked_softmax(p2s, mask_c, dim=1, log_softmax=True)  # [batch_size, c_len]
                p1s, p2s = p1s.exp(), p2s.exp()
                # Each record has [num_answer] candidates
                s_idxs, e_idxs, top_probs = utils.get_ans_list_idx(
                    p1s, p2s, num_answer=num_answer)  # [batch_size, num_answer]

                for i in range(p1s.shape[0]):
                    all_tokens = tokenizer.convert_ids_to_tokens(batch['input_ids'][i])
                    if (y1s[i] >= 0) and (y2s[i] >= 0):  # answer passage not truncated
                        # True answer tokens
                        ans_piece_tokens = all_tokens[y1s[i]:(y2s[i] + 1)]
                        answer_ids = tokenizer.convert_tokens_to_ids(ans_piece_tokens)
                        answer = tokenizer.decode(answer_ids)
                        ans_tokens_true = answer.lower().split()

                        ## Predicted answer tokens
                        # Convert pred idxs to tokens for each record
                        record_preds_cand = []
                        for j in range(num_answer):  # iterate over candidates
                            ans_jth_tokens = all_tokens[s_idxs[i, j]:(e_idxs[i, j] + 1)]  # token list of one candidate
                            answer_ids = tokenizer.convert_tokens_to_ids(ans_jth_tokens)
                            answer = tokenizer.decode(answer_ids)
                            ans_tokens_pred = answer.lower().split()
                            record_preds_cand.append(ans_tokens_pred)  # token lists of [num_answer] candidates

                        record = {
                            'pid': batch['pids'][i],         # 1234
                            'trues': ans_tokens_true,        # ["ab", "ef"]
                            'preds': record_preds_cand,      # [["af", "c"], ["b"], ..., ["fg", "ab"]]
                            'probs': top_probs[i].tolist(),  # [0.3, 0.6, ..., 0.1]
                        }
                        epoch_dic.append(record)
                        n_samples += 1
                progress_bar.update(1)  # update progress bar

    # Group epoch_dic by pid
    key = lambda d: d['pid']
    epoch_dic.sort(key=key)
    epoch_dic_gp = []
    for key, group in itertools.groupby(epoch_dic, key=key):
        trues, preds, probs = [], [], []
        for g in group:
            trues.append(g['trues'])    # [["ab", "ef"], ["cd"]]
            preds = preds + g['preds']  # g['preds']: [["ab","ef"], ["cd"], ["fg","b"], ["b","c"], ["gh"]]
            probs = probs + g['probs']  # g['probs']: [0.3, 0.4, 0.2, 0.7, 0.1]
        # Sort candidates by probs
        probs, preds = (list(t) for t in zip(
            *sorted(zip(probs, preds), key=lambda x: x[0], reverse=True)))
        # Keep candidates with probs > thres
        preds = [preds[i] for i in range(len(probs)) if probs[i] > ans_thres]
        # Remove duplicate candidates
        preds_ndup = []
        for p in preds:
            if p not in preds_ndup:
                preds_ndup.append(p)
        epoch_dic_gp.append({'pid': key, 'trues': trues, 'preds': preds_ndup})

    for ep in epoch_dic_gp:
        if len(ep['preds']) > 0:
            f1, pre, rec = utils.metric_ave_fpr(ep['preds'], ep['trues'])
            scores['f1'] += f1
            scores['prec'] += pre
            scores['rec'] += rec

    scores['loss'] = scores['loss'] / len_iter
    scores['f1'] = scores['f1'] / n_samples
    scores['prec'] = scores['prec'] / n_samples
    scores['rec'] = scores['rec'] / n_samples
    return scores
def valid_fn(model, data_loader, tokenizer, device):
    scores = {'loss': 0, 'em': 0, 'f1': 0, 'prec': 0, 'rec': 0}
    len_iter = len(data_loader)
    n_samples = 0
    model.eval()
    with torch.no_grad():
        with tqdm(total=len_iter) as progress_bar:
            for j, batch in enumerate(data_loader):
                input_ids = batch['input_ids'].to(device)  # [batch_size, c_len]
                attn_mask = batch['attention_mask'].to(device)  # [batch_size, c_len]
                y1s = batch['token_starts'].to(device)  # [batch_size]
                y2s = batch['token_ends'].to(device)  # [batch_size]
                outputs = model(input_ids, attention_mask=attn_mask,
                                start_positions=y1s, end_positions=y2s)
                loss = outputs[0]
                p1s, p2s = outputs[1], outputs[2]  # [batch_size, c_len]

                # Get start/end idxs
                mask_c = attn_mask != torch.zeros_like(attn_mask)
                p1s = utils.masked_softmax(p1s, mask_c, dim=1, log_softmax=True)  # [batch_size, c_len]
                p2s = utils.masked_softmax(p2s, mask_c, dim=1, log_softmax=True)  # [batch_size, c_len]
                p1s, p2s = p1s.exp(), p2s.exp()
                s_idxs, e_idxs = utils.get_ans_idx(p1s, p2s)  # [batch_size]

                ans_tokens_pred, ans_tokens_true = [], []
                for i in range(p1s.shape[0]):
                    all_tokens = tokenizer.convert_ids_to_tokens(batch['input_ids'][i])
                    if (y1s[i] >= 0) and (y2s[i] >= 0):  # answer passage not truncated
                        # Predicted answer tokens
                        ans_piece_tokens = all_tokens[s_idxs[i]:(e_idxs[i] + 1)]
                        answer_ids = tokenizer.convert_tokens_to_ids(ans_piece_tokens)
                        answer = tokenizer.decode(answer_ids)
                        ans_tokens_pred = answer.lower().split()

                        # True answer tokens
                        ans_piece_tokens = all_tokens[y1s[i]:(y2s[i] + 1)]
                        answer_ids = tokenizer.convert_tokens_to_ids(ans_piece_tokens)
                        answer = tokenizer.decode(answer_ids)
                        ans_tokens_true = answer.lower().split()

                        scores['em'] += utils.metric_em(ans_tokens_pred, ans_tokens_true)
                        f1, prec, rec = utils.metric_f1_pr(ans_tokens_pred, ans_tokens_true)
                        scores['f1'] += f1
                        scores['prec'] += prec
                        scores['rec'] += rec
                        n_samples += 1
                scores['loss'] += loss.item()
                progress_bar.update(1)  # update progress bar

    scores['loss'] = scores['loss'] / len_iter
    scores['em'] = scores['em'] / n_samples
    scores['f1'] = scores['f1'] / n_samples
    scores['prec'] = scores['prec'] / n_samples
    scores['rec'] = scores['rec'] / n_samples
    return scores
def train_fn(model, data_loader, optimizer, scheduler, tokenizer, clip, accum_step, device):
    scores = {'loss': 0, 'em': 0, 'f1': 0, 'prec': 0, 'rec': 0}
    len_iter = len(data_loader)
    n_samples = 0
    model.train()
    optimizer.zero_grad()
    with tqdm(total=len_iter) as progress_bar:
        for j, batch in enumerate(data_loader):
            input_ids = batch['input_ids'].to(device)  # [batch_size, c_len]
            attn_mask = batch['attention_mask'].to(device)  # [batch_size, c_len]
            y1s = batch['token_starts'].to(device)  # [batch_size]
            y2s = batch['token_ends'].to(device)  # [batch_size]
            outputs = model(input_ids, attention_mask=attn_mask,
                            start_positions=y1s, end_positions=y2s)
            loss = outputs[0]
            p1s, p2s = outputs[1], outputs[2]  # [batch_size, c_len]
            scores['loss'] += loss.item()

            # Gradients are accumulated by loss.backward(), so average the loss over the accumulation steps
            loss = loss / accum_step
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), clip)  # prevent exploding gradients

            # Gradient accumulation
            if (j + 1) % accum_step == 0:
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

            # Get start/end idxs
            p1s = utils.masked_softmax(p1s, attn_mask, dim=1, log_softmax=True)  # [batch_size, c_len]
            p2s = utils.masked_softmax(p2s, attn_mask, dim=1, log_softmax=True)  # [batch_size, c_len]
            p1s, p2s = p1s.exp(), p2s.exp()
            s_idxs, e_idxs = utils.get_ans_idx(p1s, p2s)  # [batch_size]

            for i in range(p1s.shape[0]):
                all_tokens = tokenizer.convert_ids_to_tokens(batch['input_ids'][i])
                if (y1s[i] >= 0) and (y2s[i] >= 0):  # answer passage not truncated
                    # Predicted answer tokens
                    ans_piece_tokens = all_tokens[s_idxs[i]:(e_idxs[i] + 1)]
                    answer_ids = tokenizer.convert_tokens_to_ids(ans_piece_tokens)
                    answer = tokenizer.decode(answer_ids)
                    ans_tokens_pred = answer.lower().split()

                    # True answer tokens
                    ans_piece_tokens = all_tokens[y1s[i]:(y2s[i] + 1)]
                    answer_ids = tokenizer.convert_tokens_to_ids(ans_piece_tokens)
                    answer = tokenizer.decode(answer_ids)
                    ans_tokens_true = answer.lower().split()

                    scores['em'] += utils.metric_em(ans_tokens_pred, ans_tokens_true)
                    f1, prec, rec = utils.metric_f1_pr(ans_tokens_pred, ans_tokens_true)
                    scores['f1'] += f1
                    scores['prec'] += prec
                    scores['rec'] += rec
                    n_samples += 1
            progress_bar.update(1)  # update progress bar

    scores['loss'] = scores['loss'] / len_iter
    scores['em'] = scores['em'] / n_samples
    scores['f1'] = scores['f1'] / n_samples
    scores['prec'] = scores['prec'] / n_samples
    scores['rec'] = scores['rec'] / n_samples
    return scores
def forward(self, sentences1, sentences2):
    """
    sentences1: [batch, max_len]
    sentences2: [batch, max_len]
    """
    # Get masks
    sentences1_mask = (sentences1 != self.padding_idx).long().to(self.device)  # [batch_size, max_len]
    sentences2_mask = (sentences2 != self.padding_idx).long().to(self.device)  # [batch_size, max_len]

    # Input encoding
    sentences1_emb = self.emb(sentences1)  # [batch_size, max_len, dim]
    sentences2_emb = self.emb(sentences2)  # [batch_size, max_len, dim]
    sentences1_len = torch.sum(sentences1_mask, dim=-1).view(-1)  # [batch_size]
    sentences2_len = torch.sum(sentences2_mask, dim=-1).view(-1)  # [batch_size]

    # Encoder
    s1_encoded = self.encoder_layer(sentences1_emb, sentences1_len)  # [batch_size, max_len_q1, dim]
    s2_encoded = self.encoder_layer(sentences2_emb, sentences2_len)  # [batch_size, max_len_q2, dim]

    # Local inference: e_ij = a_i^T b_j  (Eq. 11)
    similarity_matrix = s1_encoded.bmm(
        s2_encoded.transpose(2, 1).contiguous())  # [batch_size, max_len_q1, max_len_q2]
    s1_s2_atten = masked_softmax(similarity_matrix, sentences2_mask)  # [batch_size, max_len_q1, max_len_q2]
    s2_s1_atten = masked_softmax(
        similarity_matrix.transpose(2, 1).contiguous(), sentences1_mask)  # [batch_size, max_len_q2, max_len_q1]

    # a~_i = sum_j softmax(e_ij) * b_j: each side attends over the *other* sentence
    a_hat = weighted_sum(s2_encoded, s1_s2_atten, sentences1_mask)  # [batch_size, max_len_q1, dim]
    b_hat = weighted_sum(s1_encoded, s2_s1_atten, sentences2_mask)  # [batch_size, max_len_q2, dim]

    # Enhancement of local inference information:
    # m_a = [a_bar; a_tilde; a_bar - a_tilde; a_bar * a_tilde]
    # m_b = [b_bar; b_tilde; b_bar - b_tilde; b_bar * b_tilde]
    m_a = torch.cat([s1_encoded, a_hat, s1_encoded - a_hat, s1_encoded * a_hat],
                    dim=-1)  # [batch_size, max_len_q1, 4 * dim]
    m_b = torch.cat([s2_encoded, b_hat, s2_encoded - b_hat, s2_encoded * b_hat], dim=-1)

    # 3.3 Inference composition
    s1_projected = self.projection(m_a)  # [batch_size, max_len_q1, dim]
    s2_projected = self.projection(m_b)  # [batch_size, max_len_q2, dim]
    v_a = self.composition_layer(s1_projected, sentences1_len)  # [batch_size, max_len_q1, dim]
    v_b = self.composition_layer(s2_projected, sentences2_len)  # [batch_size, max_len_q2, dim]

    # Average and max pooling over the unmasked positions
    v_a_avg = torch.sum(v_a * sentences1_mask.unsqueeze(1).transpose(2, 1), dim=1) \
        / torch.sum(sentences1_mask, dim=1, keepdim=True)
    v_b_avg = torch.sum(v_b * sentences2_mask.unsqueeze(1).transpose(2, 1), dim=1) \
        / torch.sum(sentences2_mask, dim=1, keepdim=True)
    v_a_max, _ = replace_masked(v_a, sentences1_mask, -1e7).max(dim=1)  # [batch_size, dim]
    v_b_max, _ = replace_masked(v_b, sentences2_mask, -1e7).max(dim=1)
    v = torch.cat([v_a_avg, v_a_max, v_b_avg, v_b_max], dim=1)  # [batch_size, dim * 4]

    # Predict
    logits = self.predict_fc(v)
    return logits
def forward(self, user_sent_inputs, item_sent_inputs, user_sent_masks, item_sent_masks, debug=False):
    """
    Args:
        user_sent_inputs: [bz, ur_num, us_num, in_features]
        item_sent_inputs: [bz, ir_num, is_num, in_features]
        user_sent_masks: [bz, ur_num, us_num]
        item_sent_masks: [bz, ir_num, is_num]
    Returns:
        user_review_outputs: [bz, ur_num, out_features]
        item_review_outputs: [bz, ir_num, out_features]
    """
    # Aggregate item_sent_inputs
    bz, ir_num, is_num, in_features = list(item_sent_inputs.size())
    item_sent_inputs = item_sent_inputs.view(bz * ir_num, is_num, in_features)
    item_sent_masks = item_sent_masks.view(bz * ir_num, is_num)
    item_review_outputs, item_sent_weights, item_all_sent_weights = self.item_aggregator(
        item_sent_inputs, item_sent_masks)

    # Interaction score
    bz, ur_num, us_num, in_features = list(user_sent_inputs.size())
    item_all_sent_inputs = item_sent_inputs.view(bz, ir_num * is_num, in_features)
    item_all_sent_inputs = item_all_sent_inputs * item_all_sent_weights.unsqueeze(-1)

    chunks_of_user_sent_inputs = torch.chunk(user_sent_inputs, dim=1, chunks=ur_num)
    chunks_of_user_sent_masks = torch.chunk(user_sent_masks, dim=1, chunks=ur_num)
    user_review_outputs = []
    user_sent_weights = []
    for user_sent_inputs_pr, user_sent_masks_pr in zip(
            chunks_of_user_sent_inputs, chunks_of_user_sent_masks):
        user_sent_inputs_pr = user_sent_inputs_pr.squeeze(1)  # [bz, us_num, in_features]
        user_sent_masks_pr = user_sent_masks_pr.squeeze(1)  # [bz, us_num]
        ui_similarity_score_pr = self.bilinear(
            user_sent_inputs_pr, item_all_sent_inputs)  # [bz, us_num, ir_num*is_num]
        user_sent_scores_pr, _ = torch.max(ui_similarity_score_pr, dim=2)  # [bz, us_num]
        user_sent_weights_pr = masked_softmax(user_sent_scores_pr, user_sent_masks_pr)
        user_review_output_pr = attention_weighted_sum(user_sent_weights_pr, user_sent_inputs_pr)
        user_review_output_pr = user_review_output_pr.unsqueeze(1)  # NOTE: dirty implementation, [bz, 1, out_features]
        user_sent_weights_pr = user_sent_weights_pr.unsqueeze(1)  # [bz, 1, us_num]
        user_review_outputs.append(user_review_output_pr)  # list of [bz, 1, in_features]
        user_sent_weights.append(user_sent_weights_pr)

    user_review_outputs = torch.cat(user_review_outputs, dim=1)
    user_sent_weights = torch.cat(user_sent_weights, dim=1)

    if debug:
        print("UnbalancedCoAttentionAggregator:")
        print("user_review_outputs", user_review_outputs)
        print("item_review_outputs", item_review_outputs)
        print("user_sent_weights", user_sent_weights)
        print("item_sent_weights", item_sent_weights)
        print("item_all_sent_weights", item_all_sent_weights)
    return user_review_outputs, item_review_outputs, user_sent_weights, item_sent_weights, item_all_sent_weights