Example #1
    def forward(self, idx_c, idx_q, y1s, y2s):
        """
            idx_c: [batch_size, context_len]
            idx_q: [batch_size, question_len]  
            y1s: [batch_size], start idxs of true answers
            y2s: [batch_size], end idxs of true answers
            
        """
        mask_c = torch.zeros_like(idx_c) != idx_c
        mask_q = torch.zeros_like(idx_q) != idx_q

        embed_c = self.embedding(idx_c)  # [batch_size, c_len, embed_dim]
        embed_q = self.embedding(idx_q)  # [batch_size, q_len, embed_dim]

        embed_fc_c = self.embed_fc(embed_c)  # [batch_size, c_len, hidden_dim]
        embed_fc_q = self.embed_fc(embed_q)  # [batch_size, q_len, hidden_dim]

        highway_c = self.highway(embed_fc_c)  # [batch_size, c_len, hidden_dim]
        highway_q = self.highway(embed_fc_q)  # [batch_size, q_len, hidden_dim]

        enc_c, _ = self.enc(highway_c)  # [batch_size, c_len, 2*hidden_dim]
        enc_c = self.dropout(enc_c)

        enc_q, _ = self.enc(highway_q)  # [batch_size, q_len, 2*hidden_dim]
        enc_q = self.dropout(enc_q)

        # Attention flow layer
        G = self.attn_flow(enc_c, enc_q, mask_c,
                           mask_q)  # [batch_size, c_len, 8*hidden_dim]

        # Modeling layer
        M, _ = self.mod(G)  # [batch_size, c_len, 2*hidden_dim]
        M_new, _ = self.enc_M(M)  # [batch_size, c_len, 2*hidden_dim]

        # Output
        logits_1 = self.attn_fc1(G) + self.mod_fc1(M)  # [batch_size, c_len, 1]
        logits_2 = self.attn_fc2(G) + self.mod_fc2(
            M_new)  # [batch_size, c_len, 1]
        logits_1 = logits_1.squeeze(2)  # [batch_size, c_len]
        logits_2 = logits_2.squeeze(2)  # [batch_size, c_len]

        # log_p1 = masked_softmax(logits_1, mask_c, dim=1, log_softmax=True)  # [batch_size, c_len]
        # log_p2 = masked_softmax(logits_2, mask_c, dim=1, log_softmax=True)  # [batch_size, c_len]
        p1 = masked_softmax(logits_1, mask_c, dim=1,
                            log_softmax=False)  # [batch_size, c_len]
        p2 = masked_softmax(logits_2, mask_c, dim=1,
                            log_softmax=False)  # [batch_size, c_len]

        ### Loss ###
        # Sometimes y1s/y2s fall outside the model inputs (e.g. -999); those positions need to be ignored
        ignored_idx = p1.shape[1]
        y1s_clamp = torch.clamp(
            y1s, min=0, max=ignored_idx
        )  # limit value to [0, max_c_len]. '-999' converted to 0
        y2s_clamp = torch.clamp(y2s, min=0, max=ignored_idx)
        loss_fn = nn.CrossEntropyLoss(ignore_index=ignored_idx)
        loss = (loss_fn(logits_1, y1s_clamp) +
                loss_fn(logits_2, y2s_clamp)) / 2

        return loss, p1, p2
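
Each snippet on this page calls a masked_softmax helper from its own project, whose implementation is not shown. As a reference, here is a minimal sketch of the behaviour these call sites assume (mask entries equal to 1 are kept, masked entries receive a large negative logit before the softmax); the actual signature, argument order, and fill constant in each project's utils module may differ.

import torch
import torch.nn.functional as F

def masked_softmax(logits, mask, dim=-1, log_softmax=False):
    """Softmax over `dim` that ignores positions where mask == 0."""
    mask = mask.type_as(logits)
    # masked positions get a very negative logit -> (near-)zero probability
    masked_logits = logits.masked_fill(mask == 0, -1e30)
    softmax_fn = F.log_softmax if log_softmax else F.softmax
    return softmax_fn(masked_logits, dim=dim)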
Example #2
 def forward(self, c, q, mask_c, mask_q):
     """
         c: [batch_size, c_len, hidden_dim]
         q: [batch_size, q_len, hidden_dim]  
         mask_c: [batch_size, c_len]
         mask_q: [batch_size, q_len]  
     """ 
     
     S = self.get_similarity_matrix(c, q)  # [batch_size, c_len, q_len]
         
     # S1 = F.softmax(S, dim=2)  # row-wise softmax. [batch_size, c_len, q_len]
     # S2 = F.softmax(S, dim=1)  # column-wise softmax. [batch_size, c_len, q_len]
     
     mask_c = mask_c.unsqueeze(2)  # [batch_size, c_len, 1]
     mask_q = mask_q.unsqueeze(1)  # [batch_size, 1, q_len]
     S1 = masked_softmax(S, mask_q, dim=2)  # row-wise softmax. [batch_size, c_len, q_len]
     S2 = masked_softmax(S, mask_c, dim=1)  # column-wise softmax. [batch_size, c_len, q_len]
     
     # C2Q attention
     A = torch.bmm(S1, q)  # [batch_size, c_len, hidden_dim]
     
     # Q2C attention
     S_temp = torch.bmm(S1, S2.transpose(1,2))  # [batch_size, c_len, c_len]
     B = torch.bmm(S_temp, c)    # [batch_size, c_len, hidden_dim]
           
     # c: [batch_size, c_len, hidden_dim]
     # A: [batch_size, c_len, hidden_dim]
     # torch.mul(c, A): [batch_size, c_len, hidden_dim]
     # torch.mul(c, B): [batch_size, c_len, hidden_dim]
     G = torch.cat((c, A, torch.mul(c, A), torch.mul(c, B)), dim=2)  # [batch_size, c_len, 4*hidden_dim]
     
     return G
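
This example and Example #8 call a get_similarity_matrix that is not shown here; Example #16 below computes the same quantity term by term (S = S_c + S_q + S_cq + b). For reference, a minimal sketch of that BiDAF-style trilinear similarity could look like the following (W_cq appears in Example #16; the class name and the W_c, W_q parameter names are assumptions, not the original module):

import torch
import torch.nn as nn

class SimilarityMatrix(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.W_c = nn.Parameter(torch.empty(hidden_dim, 1))
        self.W_q = nn.Parameter(torch.empty(hidden_dim, 1))
        self.W_cq = nn.Parameter(torch.empty(1, 1, hidden_dim))
        self.b = nn.Parameter(torch.zeros(1))
        for w in (self.W_c, self.W_q, self.W_cq):
            nn.init.xavier_uniform_(w)

    def forward(self, c, q):
        # c: [batch_size, c_len, hidden_dim], q: [batch_size, q_len, hidden_dim]
        S_c = torch.matmul(c, self.W_c)                        # [batch_size, c_len, 1]
        S_q = torch.matmul(q, self.W_q).transpose(1, 2)        # [batch_size, 1, q_len]
        S_cq = torch.matmul(c * self.W_cq, q.transpose(1, 2))  # [batch_size, c_len, q_len]
        return S_c + S_q + S_cq + self.b                       # broadcasts to [batch_size, c_len, q_len]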
Example #3
 def forward(self, premise_batch, premise_mask, hypothesis_batch, hypothesis_mask):
     """
     Args:
         premise_batch: A batch of sequences of vectors representing the
             premises in some NLI task. The batch is assumed to have the
             size (batch, sequences, vector_dim).
         premise_mask: A mask for the sequences in the premise batch, to
             ignore padding data in the sequences during the computation of
             the attention.
         hypothesis_batch: A batch of sequences of vectors representing the
             hypotheses in some NLI task. The batch is assumed to have the
             size (batch, sequences, vector_dim).
         hypothesis_mask: A mask for the sequences in the hypotheses batch,
             to ignore padding data in the sequences during the computation
             of the attention.
     Returns:
         attended_premises: The sequences of attention vectors for the
             premises in the input batch.
         attended_hypotheses: The sequences of attention vectors for the
             hypotheses in the input batch.
     """
     # Dot product between premises and hypotheses in each sequence of
     # the batch.
     similarity_matrix = premise_batch.bmm(hypothesis_batch.transpose(2, 1).contiguous())
     # Softmax attention weights
     prem_hyp_attn = masked_softmax(similarity_matrix, hypothesis_mask)
     hyp_prem_attn = masked_softmax(similarity_matrix.transpose(1, 2).contiguous(), premise_mask)
      # Weighted sums of the hypotheses for the premises attention,
     # and vice-versa for the attention of the hypotheses.
     attended_premises = weighted_sum(hypothesis_batch, prem_hyp_attn, premise_mask)
     attended_hypotheses = weighted_sum(premise_batch, hyp_prem_attn, hypothesis_mask)
     return attended_premises, attended_hypotheses
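
The weighted_sum helper used in this example (and again in Examples #4 and #22) is likewise assumed rather than shown. A minimal sketch consistent with how it is called here, i.e. weighted_sum(tensor, weights, mask), is below; note that Example #12 passes its arguments in a different order.

import torch

def weighted_sum(tensor, weights, mask):
    # tensor:  [batch_size, kv_len, dim]    sequence being attended over
    # weights: [batch_size, q_len, kv_len]  attention weights (already normalized)
    # mask:    [batch_size, q_len]          query mask, 1 = keep
    weighted = weights.bmm(tensor)              # [batch_size, q_len, dim]
    mask = mask.unsqueeze(2).type_as(weighted)  # [batch_size, q_len, 1]
    return weighted * mask                      # zero out padded query rows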
Example #4
 def forward(self, sent_a, sent_a_mask, sent_b, sent_b_mask):
     """
      Inputs:
      sent_a: [batch_size, seq_a_len, vec_dim]
      sent_a_mask: [batch_size, seq_a_len]
      sent_b: [batch_size, seq_b_len, vec_dim]
      sent_b_mask: [batch_size, seq_b_len]
      Outputs:
      sent_a_att: [batch_size, seq_a_len, vec_dim]
      sent_b_att: [batch_size, seq_b_len, vec_dim]
     """
     # similarity matrix
     similarity_matrix = torch.matmul(
         sent_a,
         sent_b.transpose(1, 2).contiguous())  # [batch_size, seq_a, seq_b]
     sent_a_b_attn = masked_softmax(
         similarity_matrix, sent_b_mask)  # [batch_size, seq_a, seq_b]
     sent_b_a_attn = masked_softmax(
         similarity_matrix.transpose(1, 2).contiguous(),
         sent_a_mask)  # [batch_size, seq_b, seq_a]
     sent_a_att = weighted_sum(sent_b, sent_a_b_attn,
                               sent_a_mask)  # [batch_size, seq_a, vec_dim]
     sent_b_att = weighted_sum(sent_a, sent_b_a_attn,
                               sent_b_mask)  # [batch_size, seq_b, vec_dim]
     return sent_a_att, sent_b_att
Example #5
    def forward(self, att, mod, mask):
        # Shapes: (batch_size, seq_len, 1)
        logits_1 = self.att_linear_1(att) + self.mod_linear_1(mod)
        mod_2 = self.rnn(mod, mask.sum(-1))
        logits_2 = self.att_linear_2(att) + self.mod_linear_2(mod_2)

        # Shapes: (batch_size, seq_len)
        log_p1 = masked_softmax(logits_1.squeeze(), mask, log_softmax=True)
        log_p2 = masked_softmax(logits_2.squeeze(), mask, log_softmax=True)

        return log_p1, log_p2
Example #6
def _correct_policy_distribution(scores_t, p_oracle):
    # make a renormalized distribution using the policy's probabilities over the correct set of actions
    # (and setting the probability of other actions to zero)
    scores_t = scores_t.clamp(-40, 40)  # for numerical stability
    correct_actions_mask = p_oracle.gt(0).float()   # gt = greater-than: positions with oracle probability > 0 become 1
    p_correct_policy = util.masked_softmax(scores_t, correct_actions_mask, dim=1)
    return p_correct_policy
Example #7
    def pred_ans(self, pth_path):

        # tokenization
        inputs = tokenizer(self.context,
                           self.ques,
                           truncation=True,
                           max_length=512,
                           return_tensors="pt")
        # inputs = tokenizer(context, question, truncation=True, max_length=512, do_lower_case=False, return_tensors="pt")
        # _batch_encode_plus() got an unexpected keyword argument 'do_lower_case'

        # Load checkpoint
        checkpoint = torch.load(pth_path, map_location=device)
        state_dict = checkpoint['state_dict']
        model.load_state_dict(state_dict, strict=False)
        model.cpu()

        ## Run model
        model.eval()
        outputs = model(**inputs)
        p1s, p2s = outputs[0], outputs[1]  # [1, clen+qlen]

        ## Get start/end idxs
        p1s = utils.masked_softmax(p1s,
                                   1 - inputs['token_type_ids'],
                                   dim=1,
                                   log_softmax=True)  # [1, clen+qlen]
        p2s = utils.masked_softmax(p2s,
                                   1 - inputs['token_type_ids'],
                                   dim=1,
                                   log_softmax=True)  # [1, clen+qlen]
        p1s, p2s = p1s.exp(), p2s.exp()
        s_idxs, e_idxs, top_probs = utils.get_ans_list_idx(
            p1s, p2s, max_len=5, num_answer=30)  # [1, num_answer]

        ### Get ans candidates ###
        all_tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
        ans_cand = []
        for j in range(30):  # iter candidates
            ans_jth_tokens = all_tokens[s_idxs[0, j]:(
                e_idxs[0, j] + 1)]  # Token list of one answer candidate
            ans_ids = tokenizer.convert_tokens_to_ids(ans_jth_tokens)
            ans = tokenizer.decode(ans_ids)
            ans_cand.append(ans)

        self.ans_cand = ans_cand
Example #8
    def forward(self, c, q, c_mask, q_mask):
        batch_size, c_len, _ = c.size()
        q_len = q.size(1)
        s = self.get_similarity_matrix(c, q)  # (batch_size, c_len, q_len)
        c_mask = c_mask.view(batch_size, c_len, 1)  # (batch_size, c_len, 1)
        q_mask = q_mask.view(batch_size, 1, q_len)  # (batch_size, 1, q_len)
        s1 = masked_softmax(s, q_mask, dim=2)  # (batch_size, c_len, q_len)
        s2 = masked_softmax(s, c_mask, dim=1)  # (batch_size, c_len, q_len)

        # (bs, c_len, q_len) x (bs, q_len, hid_size) => (bs, c_len, hid_size)
        a = torch.bmm(s1, q)
        # (bs, c_len, c_len) x (bs, c_len, hid_size) => (bs, c_len, hid_size)
        b = torch.bmm(torch.bmm(s1, s2.transpose(1, 2)), c)

        x = torch.cat([c, a, c * a, c * b], dim=2)  # (bs, c_len, 4 * hid_size)

        return x
Example #9
def test_masked_softmax():
    batched_input = torch.tensor([[1, 2, 3], [1, 1, 2], [3, 2, 1]]).float()
    batched_mask = torch.tensor([[1, 1, 0], [0, 1, 1], [1, 1, 1]]).float()
    batched_output = masked_softmax(batched_input, batched_mask, dim=1)

    # compare the result from masked_softmax with regular softmax with filtered values
    for input, mask, output in zip(batched_input, batched_mask,
                                   batched_output):
        assert output[output != 0].equal(F.softmax(input[mask == 1], dim=0))
Example #10
    def forward(self,
                encoded_state,
                candidate_action_reps,
                act_mask,
                eval_mode=None):
        '''
        
        Score each admissible action and sample one

        encoded_state:              batch x hidden
                                    encoded text description from textworld

        candidate_action_reps:      batch x num_admissible_actions x max_action_len
                                    list (length batch) of numpy arrays; each array is num_admissible_actions x max_action_len

        act_mask:                   Torch Tensor, batch x num_actions
                                    binary mask: 1 marks an admissible action, 0 marks a padded action
        '''

        act_rep_tensor = torch.LongTensor(candidate_action_reps).to(
            self.device)
        batch, num_actions, _ = act_rep_tensor.size()

        #if using BERT
        #batch, num_actions, _ = encoded_act.size()

        encoded_act = self.command_encoder(act_rep_tensor)
        encoded_act = encoded_act.unsqueeze(1).view(
            batch, num_actions, -1)  # batch x num_admisisble_actions x hidden

        # concatenate admissible commands with state
        encoded_state = encoded_state.unsqueeze(1).expand(
            -1, num_actions, -1)  #batch x num_actions x hidden
        #assert encoded_state.sum() == 0

        encoding = torch.cat([encoded_state, encoded_act],
                             dim=2)  #batch x num_actions x 2*hidden
        #encoding = torch.cat([encoded_state, encoded_act ],dim=2)

        #zero out all actions that are fillers
        encoding = (act_mask.unsqueeze(-1).expand(-1, -1, encoding.shape[-1]) *
                    encoding)  #batch x num_actions x 2*hidden

        #score each action and sample
        logits = self.action_scorer(encoding).squeeze(-1)  #batch x num_actions

        #masked softmax
        probs = masked_softmax(logits, act_mask, dim=1)  #batch x num_actions

        if self.use_gumbel_softmax:
            #Gumbel-max for action sampling
            u = torch.rand(probs.size(), dtype=probs.dtype)
            idxs = torch.argmax(probs.cpu() - torch.log(-torch.log(u)), dim=-1)
        else:
            idxs = probs.multinomial(num_samples=1)  #batch x 1

        return idxs, logits
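
For comparison, the textbook Gumbel-max trick adds Gumbel(0, 1) noise to logits (or log-probabilities) before the argmax, whereas the branch above perturbs the already-normalized probs. A minimal sketch of the standard formulation over masked logits (function name and fill constant are illustrative, not part of the original code):

import torch

def gumbel_max_sample(logits, mask):
    # logits: [batch, num_actions]; mask: [batch, num_actions], 1 = admissible
    logits = logits.masked_fill(mask == 0, -1e30)        # exclude padded actions
    u = torch.rand_like(logits)
    gumbel = -torch.log(-torch.log(u + 1e-20) + 1e-20)   # Gumbel(0, 1) noise
    return torch.argmax(logits + gumbel, dim=-1)         # one categorical sample per row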
Example #11
def ce_loss(pred, gt, ground_msk, s_msks):
    s_num = s_msks.sum(-1).clamp(min=1)
    pred = masked_softmax(pred, ground_msk)
    m_loss = F.binary_cross_entropy(pred, gt,
                                    reduction='none')  #[B x sN x cN*hN]
    m_loss = m_loss * ground_msk  # CE loss
    m_loss = m_loss.mean(-1)  #[B x sN]
    m_loss = (m_loss * s_msks).sum(-1) / s_num  # [B]
    m_loss = m_loss.mean()  # [1]
    return m_loss
Example #12
    def forward(self, premises, premises_mask, hypotheses, hypotheses_mask):
        """
        params
            premises: (S, N, H)
            hypotheses: (T, N, H)
            premises_mask: (N, S)
            hypotheses_mask: (N, T)
        return
            new_premises: (S, N, H)
            new_hypotheses: (T, N, H)
        """
        premises = premises.transpose(0, 1)
        logging.debug(f"premises shape: {premises.shape}")
        # (N, S, H)
        hypotheses = hypotheses.transpose(0, 1)
        logging.debug(f"hypotheses shape: {hypotheses.shape}")
        # (N, T, H)

        attn_premises = torch.bmm(premises, hypotheses.transpose(1, 2))
        # (N, S, T)
        attn_hypotheses = attn_premises.transpose(1, 2)
        # (N, T, S)

        attn_premises = masked_softmax(attn_premises, premises_mask,
                                       hypotheses_mask)
        # (N, S, T)
        attn_hypotheses = masked_softmax(attn_hypotheses, hypotheses_mask,
                                         premises_mask)
        # (N, T, S)

        logging.debug(
            f"weight: {attn_premises.shape}, tensor: {hypotheses.shape}")
        new_premises = weighted_sum(attn_premises, hypotheses)
        # (N, S, H)
        new_hypotheses = weighted_sum(attn_hypotheses, premises)
        # (N, T, H)

        new_premises = new_premises.transpose(0, 1)
        # (S, N, H)
        new_hypotheses = new_hypotheses.transpose(0, 1)
        # (T, N, H)

        return new_premises, new_hypotheses, attn_premises, attn_hypotheses
Example #13
 def __call__(self, scores, oracle, **kwargs):
     with torch.no_grad():  # do not calculate gradient
         if self.aux_end:
             token_scores, stop_probs = scores
             p_oracle = oracle.distribution()
             correct_actions_mask = p_oracle.gt(0).unsqueeze(1).float()
             stop_probs = correct_actions_mask[:, :, self.
                                               end_idx]  # prevent invalid <end>
             token_scores = torch.clamp(token_scores, -40, 40)
             ps = util.masked_softmax(token_scores,
                                      correct_actions_mask,
                                      dim=2)
             samples = self.base_sampler((ps, stop_probs))
         else:
             p_oracle = oracle.distribution()
             correct_actions_mask = p_oracle.gt(0).unsqueeze(1).float()
             scores = torch.clamp(scores, -40, 40)
             ps = util.masked_softmax(scores, correct_actions_mask, dim=2)
             samples = self.base_sampler(ps)
     return samples
Example #14
    def forward(
        self,
        ctx: torch.Tensor,
        query: torch.Tensor,
        ctx_mask: torch.Tensor,
        query_mask: torch.Tensor,
    ) -> torch.Tensor:
        """
        ctx: (batch, ctx_seq_len, hidden_dim)
        query: (batch, query_seq_len, hidden_dim)
        ctx_mask: (batch, ctx_seq_len)
        query_mask: (batch, query_seq_len)

        output: (batch, ctx_seq_len, 4 * hidden_dim)
        """
        ctx_seq_len = ctx.size(1)
        query_seq_len = query.size(1)

        # (batch, ctx_seq_len, query_seq_len)
        similarity = self.trilinear_for_attention(ctx, query)
        # (batch, ctx_seq_len, query_seq_len)
        s_ctx = masked_softmax(similarity,
                               ctx_mask.unsqueeze(2).expand(
                                   -1, -1, query_seq_len),
                               dim=1)
        # (batch, ctx_seq_len, query_seq_len)
        s_query = masked_softmax(similarity,
                                 query_mask.unsqueeze(1).expand(
                                     -1, ctx_seq_len, -1),
                                 dim=2)
        # (batch, ctx_seq_len, hidden_dim)
        P = torch.bmm(s_query, query)
        # (batch, ctx_seq_len, hidden_dim)
        Q = torch.bmm(torch.bmm(s_query, s_ctx.transpose(1, 2)), ctx)

        # (batch, ctx_seq_len, 4 * hidden_dim)
        return torch.cat([ctx, P, ctx * P, ctx * Q], dim=2)
Example #15
    def forward(self, user_review_inputs, item_review_inputs,
                user_review_masks, item_review_masks):
        """
        Args:
            user_review_inputs: [bz, ur_num, in_features]
            item_review_inputs: [bz, ir_num, in_features]
            user_review_masks: [bz, ur_num]
            item_review_masks: [bz, ir_num]
        """
        # aggregate item_review_inputs
        item_outputs, item_review_weights = self.item_aggregator(
            item_review_inputs, item_review_masks)

        # interaction score
        review_similarity_scores = self.bilinear(user_review_inputs,
                                                 item_review_inputs)
        user_review_scores, _ = torch.max(review_similarity_scores,
                                          dim=2)  #[bz, ur_num]
        user_review_weights = masked_softmax(user_review_scores,
                                             user_review_masks)
        user_outputs = attention_weighted_sum(user_review_weights,
                                              user_review_inputs)

        return user_outputs, item_outputs, user_review_weights, item_review_weights
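
This example and Example #22 below also rely on an attention_weighted_sum helper that is not shown. Judging from the shapes at the call sites ([bz, n] weights applied to [bz, n, features] inputs), a minimal sketch might be:

import torch

def attention_weighted_sum(weights, inputs):
    # weights: [bz, n]            already normalized, e.g. by masked_softmax
    # inputs:  [bz, n, features]
    # returns: [bz, features]     attention-weighted sum over dim 1
    return torch.sum(weights.unsqueeze(-1) * inputs, dim=1)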
Example #16
S_cq = torch.matmul(
    c * W_cq, q.transpose(1, 2)
)  # matmul([batch_size, c_len, hidden_dim], [batch_size, hidden_dim, q_len])
S = S_c + S_q + S_cq + b  # [batch_size, c_len, q_len]

print(S.shape)

### Forward
idx_c = torch.rand((batch_size, c_len))
idx_q = torch.rand((batch_size, q_len))

mask_c = torch.zeros_like(idx_c) != idx_c
mask_q = torch.zeros_like(idx_q) != idx_q

mask_c = mask_c.unsqueeze(2)  # [batch_size, c_len, 1]
mask_q = mask_q.unsqueeze(1)  # [batch_size, 1, q_len]
S1 = masked_softmax(S, mask_q,
                    dim=2)  # row-wise softmax. [batch_size, c_len, q_len]
S2 = masked_softmax(S, mask_c,
                    dim=1)  # column-wise softmax. [batch_size, c_len, q_len]

# C2Q attention
A = torch.bmm(S1, q)  # [batch_size, c_len, hidden_dim]

# Q2C attention
S_temp = torch.bmm(S1, S2.transpose(1, 2))  # [batch_size, c_len, c_len]
B = torch.bmm(S_temp, c)  # [batch_size, c_len, hidden_dim]

# c: [batch_size, c_len, hidden_dim]
# A: [batch_size, c_len, hidden_dim]
# torch.mul(c, A): [batch_size, c_len, hidden_dim]
# torch.mul(c, B): [batch_size, c_len, hidden_dim]
G = torch.cat((c, A, torch.mul(c, A), torch.mul(c, B)),
              dim=2)  # [batch_size, c_len, 4*hidden_dim]
Example #17
    def forward(self, idx_c, idx_q, y1s, y2s):
        """
            idx_c: [batch_size, c_len]
            idx_q: [batch_size, q_len]  
            y1s: [batch_size], start idxs of true answers
            y2s: [batch_size], end idxs of true answers  
        """ 
        mask_c = torch.zeros_like(idx_c) != idx_c
        mask_q = torch.zeros_like(idx_q) != idx_q
        
        # 1. Input embedding layer
        emb_c = self.embed_inp(idx_c)  # [batch_size, c_len, embed_dim]
        emb_q = self.embed_inp(idx_q)  # [batch_size, q_len, embed_dim]
        # connection between embed layer and embed encoder
        emb_c = self.embed_conv(emb_c)  # [batch_size, c_len, embed_dim]
        emb_q = self.embed_conv(emb_q)  # [batch_size, q_len, embed_dim]
        
        # 2. Embed encoder
        emb_enc_c = self.embed_enc_c(emb_c, mask_c)  # [batch_size, c_len, hidden_dim]
        emb_enc_q = self.embed_enc_q(emb_q, mask_q)  # [batch_size, q_len, hidden_dim]
        
        # 3. CQ attention
        G = self.attn(emb_enc_c, emb_enc_q, mask_c, mask_q)  # [batch_size, c_len, 4*hidden_dim]
        # Connection between attn and model encoder
        G = self.mod_conv(G)  # [batch_size, c_len, hidden_dim]

        # 4. Model encoders
        M0 = G
        for enc in self.mod_enc:
            M0 = enc(M0, mask_c)  # [batch_size, c_len, hidden_dim]
            
        M1 = M0
        for enc in self.mod_enc:
            M1 = enc(M1, mask_c)  # [batch_size, c_len, hidden_dim]
            
        M2 = M1
        for enc in self.mod_enc:
            M2 = enc(M2, mask_c)  # [batch_size, c_len, hidden_dim]
            
        # 5. Output      
        x1 = torch.cat([M0, M1], dim=2)  # [batch_size, c_len, 2*hidden_dim]
        x2 = torch.cat([M0, M2], dim=2)  # [batch_size, c_len, 2*hidden_dim]
        
        logits1 = self.fc1(x1).squeeze(2)  # [batch_size, c_len]
        logits2 = self.fc2(x2).squeeze(2)  # [batch_size, c_len]
             
        # log_p1 = masked_softmax(logits1, mask_c, dim=1, log_softmax=True)  # [batch_size, c_len]
        # log_p2 = masked_softmax(logits2, mask_c, dim=1, log_softmax=True)  # [batch_size, c_len]
        p1 = masked_softmax(logits1, mask_c, dim=1, log_softmax=False)  # [batch_size, c_len]
        p2 = masked_softmax(logits2, mask_c, dim=1, log_softmax=False)  # [batch_size, c_len]
        
        ### Loss ###     
        # Sometimes y1s/y2s fall outside the model inputs (e.g. -999); those positions need to be ignored
        ignored_idx = p1.shape[1]
        y1s_clamp = torch.clamp(y1s, min=0, max=ignored_idx)  # limit value to [0, max_c_len]. '-999' converted to 0 
        y2s_clamp = torch.clamp(y2s, min=0, max=ignored_idx)
        loss_fn = nn.CrossEntropyLoss(ignore_index=ignored_idx)
        loss = (loss_fn(logits1, y1s_clamp) + loss_fn(logits2, y2s_clamp)) / 2 

        return loss, p1, p2
        
Example #18
    def forward(self,
                embs,
                toks,
                msks,
                s_msks,
                c_msks,
                dep_root_msk,
                imgs,
                i_msks,
                poses,
                pose_msks,
                i3d_rgb=None,
                face=None,
                face_msks=None,
                bbox_meta=None):
        '''
        params:
          embs: [B x 1 x L]
          toks: [B x 1 x L]
          msks: [B x 1 x L]
          s_msks: [B x sN x L]
          c_msks: [B x sN x cN]
          dep_root_msk: [B x sN x L]
          imgs: [B x cN x hN x 3 x 224 x 224]
          i_msks: [B x cN x hN]
          poses: [B x cN x hN x 17 x 2]
          pose_msks: [B x cN x hN x 17]
          i3d_rgb: [B x cN x hN x 1024]
          face: [B x cN x hN x 512]
          face_msks: [B x cN x hN]
          bbox_meta: [B x cN x hN x 4]
        '''
        B, sN, L = s_msks.shape
        cN, hN = imgs.size(1), imgs.size(2)
        # Text embedding
        embs, toks, msks = embs.view(-1, L), toks.view(-1, L), msks.view(
            -1, L)  # [B*sN x L]
        bert_x, _ = self.bert(embs,
                              toks,
                              msks,
                              output_all_encoded_layers=False)  # [B x L x 768]
        bert_x = gelu(bert_x).view(B, 1, L, 768)
        if 'someone' in self.t_feats:
            ts_x = bert_x * s_msks.view(B, sN, L, 1)  # [B x sN x L x 768]
            ts_x = torch.sum(ts_x, 2)  # [B x sN x 768]
        if 'action' in self.t_feats:
            ta_x = bert_x * dep_root_msk.view(B, sN, L,
                                              1)  # [B x sN x L x 768]
            ta_x = torch.sum(ta_x, 2)  # [B x sN x 768]
        tce_x = torch.cat([ta_x, ts_x], -1)  # [B x sN x sdim]

        # Text Projection
        s_a_x = self.s_a_proj(tce_x)  # [B x sN x H]
        s_s_x = self.s_s_proj(tce_x)  # [B x sN x H]

        # Auxiliary Gender Classifier
        if self.use_gender:
            g_x = self.gender_fc(ts_x)  # [B x sN x 1]
            self.gender_result = g_x.squeeze(-1)  # [B x sN]

        # Visual embedding
        i_x = torch.zeros((B * cN * hN, 0), device=bbox_meta.device)
        if 'img' in self.v_feats:
            imgs = imgs.view(-1, 3, 224, 224)
            img_x = self.act_conv(imgs)  # [B*cN*hN x 2048]
            i_x = torch.cat((i_x, img_x), -1)  #[B*cN*hN x (2048 + 4)]
        if "i3d_rgb" in self.v_feats:
            i3d_x = i3d_rgb.view(-1, 1024)  # [B*cN*hN x 1024]
            i_x = torch.cat((i_x, i3d_x), -1)  # [B*cN*hN x num_ftrs]
        if 'face' in self.v_feats:
            face_x = face.view(-1, 512)  # [B*cN*hN x 512]

        # Image projection
        i_a_x = self.i_a_proj(i_x)  # [B*cN*hN x H]
        i_a_x = i_a_x.view((B, cN * hN, -1))  # [B x cN*hN x H]
        i_s_x = self.i_s_proj(face_x)  # [B*cN*hN x H]
        i_s_x = i_s_x.view((B, cN * hN, -1))  # [B x cN*hN x H]
        if 'meta' in self.v_feats:
            meta_x = self.meta_proj(bbox_meta.view(-1, 4))  # [B*cN*hN x 50]

        # Character Grounding
        s_a_x = s_a_x.unsqueeze(2).repeat(1, 1, cN * hN,
                                          1).view(-1, self.hidden_dim)
        i_a_x = i_a_x.unsqueeze(1).repeat(1, sN, 1,
                                          1).view(-1, self.hidden_dim)
        s_s_x = s_s_x.unsqueeze(2).repeat(1, 1, cN * hN,
                                          1).view(-1, self.hidden_dim)
        i_s_x = i_s_x.unsqueeze(1).repeat(1, sN, 1,
                                          1).view(-1, self.hidden_dim)
        if 'meta' in self.v_feats:
            meta_x = meta_x.view(B, 1, cN * hN,
                                 50).repeat(1, sN, 1, 1)  # B x sN x cN*hN x 50

        f_a_x = s_a_x * i_a_x
        f_a_x = f_a_x.view(B, sN, cN * hN,
                           self.hidden_dim)  # [B x sN x cN*hN x H]
        f_s_x = s_s_x * i_s_x
        f_s_x = f_s_x.view(B, sN, cN * hN,
                           self.hidden_dim)  # [B x sN x cN*hN x H]
        f_t_x = torch.cat((meta_x, f_a_x, f_s_x), -1)  # [B x sN x cN*hN x fH]
        f_x = self.fuse_fc(f_t_x)  # [B x sN x cN*hN x 1]
        f_x = f_x.squeeze(-1)  # [B x sN x cN*hN]

        # Masking
        pred_ground = f_x * i_msks.view(B, 1, cN * hN) * s_msks.sum(
            -1, keepdim=True)

        # Character Re-Identification
        pred_chreid = 0
        to_divide = 1
        self.text_mat, self.vid_mat, self.vgv_mat = None, None, None
        # Text Re-Id
        if 'text' in self.char_reid:
            text_reid_x = ts_x.unsqueeze(1) * ts_x.unsqueeze(
                2)  # [B x sN x sN x 768]
            text_reid_x = self.text_reid_fc(text_reid_x)
            text_mat = text_reid_x.squeeze(-1)  # [B x sN x sN]
            self.text_mat = text_mat
            pred_chreid += self.text_mat
            to_divide += 1
        # Person identity representation
        if 'visual' in self.char_reid:
            vid_mat = self.vid_model(imgs.view(B, -1, 3, 224, 224),
                                     i_msks.view(B, -1),
                                     poses.view(B, -1, 17, 2),
                                     pose_msks.view(B, -1, 17),
                                     face.view(B, -1, 512),
                                     face_msks.view(B,
                                                    -1))  # B x cN*hN x cN*hN

            self.vid_mat = vid_mat
            p_msks = c_msks.view(B, sN, cN, 1).repeat(1, 1, 1,
                                                      hN).view(B, sN, cN * hN)
            i_mm = i_msks.view(B, 1, cN * hN)  # [B x 1 x cN*hN]
            p_msks = p_msks * i_mm  # [B x sN x cN*hN] (1 if exist, 0 if not)
            p_att = masked_softmax(
                f_x * 5, p_msks)  # [B x sN x cN*hN], sharpen f_x a little

            vgv_mat = torch.matmul(p_att, vid_mat)
            vgv_mat = vgv_mat * (
                1 - p_msks) - 0.1 * p_msks  # Mask for someone in same clip
            vgv_mat = torch.matmul(vgv_mat,
                                   p_att.transpose(1, 2))  # [B x sN x sN]
            vgv_mat = torch.sigmoid(vgv_mat)  # [B x sN x sN]  0 ~ 1
            self.vgv_mat = vgv_mat

            pred_chreid = (pred_chreid + vgv_mat) / to_divide

        return pred_ground, pred_chreid
Example #19
def valid_fn_list(model, data_loader, tokenizer, device, num_answer,
                  ans_thres):

    scores = {'loss': 0, 'f1': 0, 'prec': 0, 'rec': 0}
    len_iter = len(data_loader)
    n_samples = 0

    model.eval()
    epoch_dic = []
    with torch.no_grad():
        with tqdm(total=len_iter) as progress_bar:
            for _, batch in enumerate(data_loader):

                input_ids = batch['input_ids'].to(
                    device)  # [batch_size, c_len]
                attn_mask = batch['attention_mask'].to(
                    device)  # [batch_size, c_len]
                y1s = batch['token_starts'].to(device)  # [batch_size]
                y2s = batch['token_ends'].to(device)  # [batch_size]

                outputs = model(input_ids,
                                attention_mask=attn_mask,
                                start_positions=y1s,
                                end_positions=y2s)

                loss = outputs[0]
                p1s, p2s = outputs[1], outputs[2]  # [batch_size, c_len]
                scores['loss'] += loss.item()

                # Get start/end idxs
                mask_c = attn_mask != torch.zeros_like(attn_mask)
                p1s = utils.masked_softmax(
                    p1s, mask_c, dim=1,
                    log_softmax=True)  # [batch_size, c_len]
                p2s = utils.masked_softmax(
                    p2s, mask_c, dim=1,
                    log_softmax=True)  # [batch_size, c_len]
                p1s, p2s = p1s.exp(), p2s.exp()

                # Each record has [num_answer] candidates
                s_idxs, e_idxs, top_probs = utils.get_ans_list_idx(
                    p1s, p2s,
                    num_answer=num_answer)  # [batch_size, num_answer]

                for i in range(p1s.shape[0]):
                    all_tokens = tokenizer.convert_ids_to_tokens(
                        batch['input_ids'][i])

                    if (y1s[i] >= 0) and (
                            y2s[i] >=
                            0):  # When the answer passage is not truncated
                        # True answer tokens
                        ans_piece_tokens = all_tokens[y1s[i]:(y2s[i] + 1)]
                        answer_ids = tokenizer.convert_tokens_to_ids(
                            ans_piece_tokens)
                        answer = tokenizer.decode(answer_ids)
                        ans_tokens_true = answer.lower().split()

                        ## Predicted answer tokens
                        # Convert pred idxs to tokens for each record
                        record_preds_cand = []
                        for j in range(num_answer):  # iter candidates
                            ans_jth_tokens = all_tokens[s_idxs[i, j]:(
                                e_idxs[i, j] +
                                1)]  # Token list of one answer candidate
                            answer_ids = tokenizer.convert_tokens_to_ids(
                                ans_jth_tokens)
                            answer = tokenizer.decode(answer_ids)
                            ans_tokens_pred = answer.lower().split()
                            record_preds_cand.append(
                                ans_tokens_pred
                            )  # Get list of token list of [num_answer] answer candidates

                        record = {
                            'pid': batch['pids'][i],  # 1234
                            'trues': ans_tokens_true,  # ["ab", "ef"]
                            'preds':
                            record_preds_cand,  # [["af", "c"], ["b"], ..., ["fg", "ab"]]
                            'probs': top_probs[i].tolist()
                        }  # [0.3, 0.6, ..., 0.1]

                        epoch_dic.append(record)
                    n_samples += 1
                progress_bar.update(1)  # update progress bar

    # Group epoch_dic by pid
    key = lambda d: d['pid']
    epoch_dic.sort(key=key)

    epoch_dic_gp = []
    for key, group in itertools.groupby(epoch_dic, key=key):
        trues, preds, probs = [], [], []
        for g in group:
            trues.append(g['trues'])  # [["ab", "ef"], ["cd"]]
            preds = preds + g[
                'preds']  # g['preds']: [["ab","ef"], ["cd"], ["fg","b"], ["b","c"], ["gh"]]
            probs = probs + g['probs']  # g['probs']: [0.3, 0.4, 0.2, 0.7, 0.1]
        # Sort candidates by probs
        probs, preds = (list(t) for t in zip(
            *sorted(zip(probs, preds), key=lambda x: x[0], reverse=True)))
        # Keep candidates with probs > thres
        preds = [preds[i] for i in range(len(probs)) if probs[i] > ans_thres]
        # Remove duplicate candidates
        preds_ndup = []
        for p in preds:
            if p not in preds_ndup:
                preds_ndup.append(p)

        epoch_dic_gp.append({'pid': key, 'trues': trues, 'preds': preds_ndup})

    for ep in epoch_dic_gp:
        if len(ep['preds']) > 0:
            f1, pre, rec = utils.metric_ave_fpr(ep['preds'], ep['trues'])
            scores['f1'] += f1
            scores['prec'] += pre
            scores['rec'] += rec

    scores['loss'] = scores['loss'] / len_iter
    scores['f1'] = scores['f1'] / n_samples
    scores['prec'] = scores['prec'] / n_samples
    scores['rec'] = scores['rec'] / n_samples

    return scores
Example #20
def valid_fn(model, data_loader, tokenizer, device):

    scores = {'loss': 0, 'em': 0, 'f1': 0, 'prec': 0, 'rec': 0}
    len_iter = len(data_loader)
    n_samples = 0

    model.eval()

    with torch.no_grad():
        with tqdm(total=len_iter) as progress_bar:
            for j, batch in enumerate(data_loader):

                input_ids = batch['input_ids'].to(
                    device)  # [batch_size, c_len]
                attn_mask = batch['attention_mask'].to(
                    device)  # [batch_size, c_len]
                y1s = batch['token_starts'].to(device)  # [batch_size]
                y2s = batch['token_ends'].to(device)  # [batch_size]

                outputs = model(input_ids,
                                attention_mask=attn_mask,
                                start_positions=y1s,
                                end_positions=y2s)

                loss = outputs[0]
                p1s, p2s = outputs[1], outputs[2]  # [batch_size, c_len]

                # Get start/end idxs
                mask_c = attn_mask != torch.zeros_like(attn_mask)
                p1s = utils.masked_softmax(
                    p1s, mask_c, dim=1,
                    log_softmax=True)  # [batch_size, c_len]
                p2s = utils.masked_softmax(
                    p2s, mask_c, dim=1,
                    log_softmax=True)  # [batch_size, c_len]
                p1s, p2s = p1s.exp(), p2s.exp()
                s_idxs, e_idxs = utils.get_ans_idx(p1s, p2s)  # [batch_size]

                ans_tokens_pred, ans_tokens_true = [], []
                for i in range(p1s.shape[0]):
                    all_tokens = tokenizer.convert_ids_to_tokens(
                        batch['input_ids'][i])

                    if (y1s[i] >= 0) and (
                            y2s[i] >=
                            0):  # When the answer passage is not truncated
                        # Predicted answer tokens
                        ans_piece_tokens = all_tokens[s_idxs[i]:(e_idxs[i] +
                                                                 1)]
                        answer_ids = tokenizer.convert_tokens_to_ids(
                            ans_piece_tokens)
                        answer = tokenizer.decode(answer_ids)
                        ans_tokens_pred = answer.lower().split()

                        # True answer tokens
                        ans_piece_tokens = all_tokens[y1s[i]:(y2s[i] + 1)]
                        answer_ids = tokenizer.convert_tokens_to_ids(
                            ans_piece_tokens)
                        answer = tokenizer.decode(answer_ids)
                        ans_tokens_true = answer.lower().split()

                        scores['em'] += utils.metric_em(
                            ans_tokens_pred, ans_tokens_true)
                        f1, prec, rec = utils.metric_f1_pr(
                            ans_tokens_pred, ans_tokens_true)
                        scores['f1'] += f1
                        scores['prec'] += prec
                        scores['rec'] += rec

                    n_samples += 1

                scores['loss'] += loss.item()
                progress_bar.update(1)  # update progress bar

    scores['loss'] = scores['loss'] / len_iter
    scores['em'] = scores['em'] / n_samples
    scores['f1'] = scores['f1'] / n_samples
    scores['prec'] = scores['prec'] / n_samples
    scores['rec'] = scores['rec'] / n_samples

    return scores
Example #21
def train_fn(model, data_loader, optimizer, scheduler, tokenizer, clip,
             accum_step, device):

    scores = {'loss': 0, 'em': 0, 'f1': 0, 'prec': 0, 'rec': 0}
    len_iter = len(data_loader)
    n_samples = 0

    model.train()
    optimizer.zero_grad()

    with tqdm(total=len_iter) as progress_bar:
        for j, batch in enumerate(data_loader):

            input_ids = batch['input_ids'].to(device)  # [batch_size, c_len]
            attn_mask = batch['attention_mask'].to(
                device)  # [batch_size, c_len]
            y1s = batch['token_starts'].to(device)  # [batch_size]
            y2s = batch['token_ends'].to(device)  # [batch_size]

            outputs = model(input_ids,
                            attention_mask=attn_mask,
                            start_positions=y1s,
                            end_positions=y2s)

            loss = outputs[0]
            p1s, p2s = outputs[1], outputs[2]  # [batch_size, c_len]

            scores['loss'] += loss.item()

            loss = loss / accum_step  # gradients accumulate across loss.backward() calls, so divide to average over accum_step
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(),
                                     clip)  # prevent exploding gradients

            # Gradient accumulation
            if (j + 1) % accum_step == 0:
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

            # Get start/end idxs
            p1s = utils.masked_softmax(p1s, attn_mask, dim=1,
                                       log_softmax=True)  # [batch_size, c_len]
            p2s = utils.masked_softmax(p2s, attn_mask, dim=1,
                                       log_softmax=True)  # [batch_size, c_len]
            p1s, p2s = p1s.exp(), p2s.exp()
            s_idxs, e_idxs = utils.get_ans_idx(p1s, p2s)  # [batch_size]

            # ans_tokens_pred, ans_tokens_true = [], []
            for i in range(p1s.shape[0]):
                all_tokens = tokenizer.convert_ids_to_tokens(
                    batch['input_ids'][i])

                if (y1s[i] >= 0) and (
                        y2s[i] >= 0):  # When the answer passage is not truncated
                    # Predicted answer tokens
                    ans_piece_tokens = all_tokens[s_idxs[i]:(e_idxs[i] + 1)]
                    answer_ids = tokenizer.convert_tokens_to_ids(
                        ans_piece_tokens)
                    answer = tokenizer.decode(answer_ids)
                    ans_tokens_pred = answer.lower().split()

                    # True answer tokens
                    ans_piece_tokens = all_tokens[y1s[i]:(y2s[i] + 1)]
                    answer_ids = tokenizer.convert_tokens_to_ids(
                        ans_piece_tokens)
                    answer = tokenizer.decode(answer_ids)
                    ans_tokens_true = answer.lower().split()

                    scores['em'] += utils.metric_em(ans_tokens_pred,
                                                    ans_tokens_true)
                    f1, prec, rec = utils.metric_f1_pr(ans_tokens_pred,
                                                       ans_tokens_true)
                    scores['f1'] += f1
                    scores['prec'] += prec
                    scores['rec'] += rec

                n_samples += 1

            progress_bar.update(1)  # update progress bar

    scores['loss'] = scores['loss'] / len_iter
    scores['em'] = scores['em'] / n_samples
    scores['f1'] = scores['f1'] / n_samples
    scores['prec'] = scores['prec'] / n_samples
    scores['rec'] = scores['rec'] / n_samples

    return scores
Example #22
    def forward(self, sentences1, sentences2):
        """
            sentences1 [batch, max_len]
            sentences2 [batch, max_len]
        """
        # get mask
        sentences1_mask = (sentences1 != self.padding_idx).long().to(
            self.device)  # [batch_size, max_len]
        sentences2_mask = (sentences2 != self.padding_idx).long().to(
            self.device)  # [batch_size, max_len]
        # input encoding
        sentences1_emb = self.emb(sentences1)  # [batch_size, max_len, dim]
        sentences2_emb = self.emb(sentences2)  # [batch_size, max_len, dim]
        sentences1_len = torch.sum(sentences1_mask,
                                   dim=-1).view(-1)  # [batch_size]
        sentences2_len = torch.sum(sentences2_mask,
                                   dim=-1).view(-1)  # [batch_size]
        #encoder
        s1_encoded = self.encoder_layer(
            sentences1_emb, sentences1_len)  # [batch_size, max_len_q1, dim]
        s2_encoded = self.encoder_layer(
            sentences2_emb, sentences2_len)  # [batch_size, max_len_q2, dim]
        # local inference
        # e_ij = a_i^Tb_j  (11)
        similarity_matrix = s1_encoded.bmm(
            s2_encoded.transpose(
                2, 1).contiguous())  # [batch_size, max_len_q1, max_len_q2]
        s1_s2_atten = masked_softmax(
            similarity_matrix,
            sentences2_mask)  # [batch_size, max_len_q1, max_len_q2]
        s2_s1_atten = masked_softmax(
            similarity_matrix.transpose(2, 1).contiguous(),
            sentences1_mask)  # [batch_size, max_len_q2, max_len_q1]

        # eij * bj
        a_hat = weighted_sum(s1_encoded, s1_s2_atten,
                             sentences1_mask)  # [batch_size, max_len_q1, dim]
        b_hat = weighted_sum(s2_encoded, s2_s1_atten,
                             sentences2_mask)  # [batch_size, max_len_q2, dim]

        # Enhancement of local inference information
        # m_a = [a_bar; a_tilde; a_bar - a_tilde; a_bar ⊙ a_tilde]
        # m_b = [b_bar; b_tilde; b_bar - b_tilde; b_bar ⊙ b_tilde]
        m_a = torch.cat(
            [s1_encoded, a_hat, s1_encoded - a_hat, s1_encoded * a_hat],
            dim=-1)  # [batch_size, max_len_q1, 4 * dim]
        m_b = torch.cat(
            [s2_encoded, b_hat, s2_encoded - b_hat, s2_encoded * b_hat],
            dim=-1)

        # 3.3 Inference Composition
        s1_projected = self.projection(m_a)  # [batch_size, max_len_q1, dim]
        s2_projected = self.projection(m_b)  # [batch_size, max_len_q2, dim]
        v_a = self.composition_layer(
            s1_projected, sentences1_len)  # [batch_size, max_len_q1, dim]
        v_b = self.composition_layer(
            s2_projected, sentences2_len)  # [batch_size, max_len_q2, dim]
        v_a_avg = torch.sum(v_a * sentences1_mask.unsqueeze(1).transpose(2, 1), dim=1)  \
                   / torch.sum(sentences1_mask, dim=1, keepdim = True) # q1_mask batch_size, 1, max_len_q1
        v_b_avg = torch.sum(v_b * sentences2_mask.unsqueeze(1).transpose(2, 1), dim=1) \
                   / torch.sum(sentences2_mask, dim=1, keepdim = True)
        v_a_max, _ = replace_masked(v_a, sentences1_mask,
                                    -1e7).max(dim=1)  # [batch_size, dim]
        v_b_max, _ = replace_masked(v_b, sentences2_mask, -1e7).max(dim=1)

        v = torch.cat([v_a_avg, v_a_max, v_b_avg, v_b_max],
                      dim=1)  # [batch_size, dim * 4]
        # predict
        logits = self.predict_fc(v)
        return logits
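
replace_masked above is another assumed utility: before the max-pooling it swaps padded positions for a very negative value so they can never win the max. A minimal sketch consistent with that call (the project's real helper may differ):

import torch

def replace_masked(tensor, mask, value):
    # tensor: [batch, seq_len, dim]; mask: [batch, seq_len] with 1 = keep
    # masked positions are replaced by `value` (e.g. -1e7 before a max-pool)
    mask = mask.unsqueeze(-1).type_as(tensor)
    return tensor * mask + (1.0 - mask) * value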
    def forward(self,
                user_sent_inputs,
                item_sent_inputs,
                user_sent_masks,
                item_sent_masks,
                debug=False):
        """
        Args:
            user_sent_inputs: [bz, ur_num, us_num, in_features]
            item_sent_inputs: [bz, ir_num, is_num, in_features]
            user_sent_masks: [bz, ur_num, us_num]
            item_sent_masks: [bz, ir_num, is_num]

        returns:
            user_review_outputs: [bz, ur_num, out_features]
            item_review_outputs: [bz, ir_num, out_features]
        """
        # aggregate item_sent_inputs
        bz, ir_num, is_num, in_features = list(item_sent_inputs.size())
        item_sent_inputs = item_sent_inputs.view(bz * ir_num, is_num,
                                                 in_features)
        item_sent_masks = item_sent_masks.view(bz * ir_num, is_num)
        item_review_outputs, item_sent_weights, item_all_sent_weights = self.item_aggregator(
            item_sent_inputs, item_sent_masks)

        # interaction score
        bz, ur_num, us_num, in_features = list(user_sent_inputs.size())
        item_all_sent_inputs = item_sent_inputs.view(bz, ir_num * is_num,
                                                     in_features)
        item_all_sent_inputs = item_all_sent_inputs * item_all_sent_weights.unsqueeze(
            -1)

        chunks_of_user_sent_inputs = torch.chunk(user_sent_inputs,
                                                 dim=1,
                                                 chunks=ur_num)
        chunks_of_user_sent_masks = torch.chunk(user_sent_masks,
                                                dim=1,
                                                chunks=ur_num)
        user_review_outputs = []
        user_sent_weights = []
        for user_sent_inputs_pr, user_sent_masks_pr in zip(
                chunks_of_user_sent_inputs, chunks_of_user_sent_masks):
            user_sent_inputs_pr = user_sent_inputs_pr.squeeze(
                1)  #[bz, us_num, in_features]
            user_sent_masks_pr = user_sent_masks_pr.squeeze(1)  #[bz, us_num]

            ui_similarity_score_pr = self.bilinear(
                user_sent_inputs_pr,
                item_all_sent_inputs)  #[bz, us_num, ir_num*is_num]
            user_sent_scores_pr, _ = torch.max(ui_similarity_score_pr,
                                               dim=2)  #[bz, us_num]
            user_sent_weights_pr = masked_softmax(user_sent_scores_pr,
                                                  user_sent_masks_pr)

            user_review_output_pr = attention_weighted_sum(
                user_sent_weights_pr, user_sent_inputs_pr)
            user_review_output_pr = user_review_output_pr.unsqueeze(
                1)  # NOTE: dirty implementation, [bz, 1, out_features]
            user_sent_weights_pr = user_sent_weights_pr.unsqueeze(
                1)  # [bz, 1, us_num]
            user_review_outputs.append(
                user_review_output_pr)  # list of [bz, 1, in_features]
            user_sent_weights.append(user_sent_weights_pr)
        #print("user review outputs element", user_review_outputs[0].shape)
        user_review_outputs = torch.cat(user_review_outputs, dim=1)
        user_sent_weights = torch.cat(user_sent_weights, dim=1)

        if debug:
            print("UnbalancedCoAttentionAggregator: ")
            print("user_review_outputs", user_review_outputs)
            print("item_review_outputs", item_review_outputs)
            print("user_sent_weights", user_sent_weights)
            print("item_sent_weights", item_sent_weights)
            print("item_all_sent_weights", item_all_sent_weights)

        return user_review_outputs, item_review_outputs, user_sent_weights, item_sent_weights, item_all_sent_weights