Example #1
  def forward(self, input, training=False):
    x = input['content']
    x_mask = x.eq(0)  # padding mask: 1 where the token id is 0 (padding)
    batch_size = x.size(0)

    x = self.unk_aug(x, x_mask)

    if FLAGS.rnn_no_padding:
      # treat every position as valid so the RNN ignores padding entirely
      x_mask = torch.zeros_like(x, dtype=torch.uint8)

    x = self.encode(input, x_mask, training=training)

    if FLAGS.use_label_att:
      # attend from the content representation to the label embeddings
      label_emb = self.label_embedding.weight
      label_seq = lele.tile(label_emb.unsqueeze(0), 0, batch_size)
      # TODO label rnn
      for i in range(FLAGS.label_hop):
        label_mask = torch.zeros(batch_size, self.label_emb_height).byte().cuda()
        x = self.att_dot_attentions[i](x, label_seq, label_mask)
        x = self.att_encodes[i](x, x_mask)

    if FLAGS.use_self_match:
      # self matching: attend from the content representation to itself
      x = self.match_dot_attention(x, x, x_mask)
      x = self.match_encode(x, x_mask)

    x = self.pooling(x, x_mask)  # pool over the sequence dimension

    x = self.logits(x)  # project to output logits

    return x
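
All of the examples rely on lele.tile to repeat a tensor along one dimension, e.g. broadcasting the label embedding table across the batch. Below is a minimal sketch of an equivalent helper in plain PyTorch; the name tile and the (tensor, dim, n) argument order are inferred from the calls above, so treat it as an assumption rather than the actual lele implementation.

import torch

def tile(x, dim, n):
  # Repeat x n times along dimension dim, assuming lele.tile(x, dim, n) behaves this way.
  repeats = [1] * x.dim()
  repeats[dim] = n
  return x.repeat(*repeats)

label_emb = torch.randn(20, 128)                 # hypothetical label embedding table
label_seq = tile(label_emb.unsqueeze(0), 0, 4)   # (1, 20, 128) -> (4, 20, 128)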
Example #2
  def forward(self, input, training=False):
    x = input['content']
    x_mask = x.eq(0)  # padding mask: 1 where the token id is 0 (padding)
    batch_size = x.size(0)
    max_c_len = x.size(1)

    x = self.unk_aug(x, x_mask)

    if FLAGS.rnn_no_padding:
      x_mask = torch.zeros_like(x, dtype=torch.uint8)

    x = self.encode(input, x_mask, training=training)

    if FLAGS.use_label_att:
      label_emb = self.label_embedding.weight
      label_seq = lele.tile(label_emb.unsqueeze(0), 0, batch_size)
      x2_mask = torch.zeros(batch_size, self.label_emb_height).byte().cuda()
      if not FLAGS.use_label_rnn:
        label_seq = self.label_forward(label_seq)
      else:
        label_seq = self.label_forward(label_seq, x2_mask)
      # Align and aggregate
      c_check = x
      q = label_seq
    else:
      c_check = x
    
    for i in range(FLAGS.hop):
      if FLAGS.use_label_att:
        # align the content with the label sequence, then fuse via an SFU
        q_tilde = self.interactive_aligners[i].forward(c_check, q, x2_mask)
        fusion = torch.cat([q_tilde, c_check * q_tilde, c_check - q_tilde], 2)
        c_bar = self.interactive_SFUs[i].forward(c_check, fusion)
      else:
        c_bar = c_check
      if FLAGS.use_self_match:
        # self alignment, fusion, then aggregation over the sequence
        c_tilde = self.self_aligners[i].forward(c_bar, x_mask)
        fusion = torch.cat([c_tilde, c_bar * c_tilde, c_bar - c_tilde], 2)
        c_hat = self.self_SFUs[i].forward(c_bar, fusion)
        c_check = self.aggregate_rnns[i].forward(c_hat, x_mask)
      else:
        # carry the aligned representation forward even when self matching is off
        c_check = c_bar

    x = c_check

    x = self.pooling(x, x_mask)

    x = self.logits(x)  

    return x
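
The interactive_SFUs[i] call above fuses c_check with the concatenated alignment features [q_tilde, c_check * q_tilde, c_check - q_tilde]. As a rough illustration, here is a minimal semantic fusion unit (SFU) in the style of the Mnemonic Reader; it is only an assumption about what such a module might look like, not the actual implementation behind self.interactive_SFUs.

import torch
import torch.nn as nn

class SFU(nn.Module):
  # Hypothetical semantic fusion unit: gates between the original representation x
  # and a candidate computed from the concatenated fusion features.
  def __init__(self, input_size, fusion_size):
    super().__init__()
    self.linear_r = nn.Linear(input_size + fusion_size, input_size)
    self.linear_g = nn.Linear(input_size + fusion_size, input_size)

  def forward(self, x, fusions):
    r_f = torch.cat([x, fusions], dim=2)
    r = torch.tanh(self.linear_r(r_f))     # candidate representation
    g = torch.sigmoid(self.linear_g(r_f))  # fusion gate
    return g * r + (1 - g) * x

c_check = torch.randn(2, 7, 32)
q_tilde = torch.randn(2, 7, 32)
fusions = torch.cat([q_tilde, c_check * q_tilde, c_check - q_tilde], 2)  # batch * len * (3 * 32)
c_bar = SFU(32, 3 * 32)(c_check, fusions)                                # batch * len * 32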
Example #3
  def forward(self, x, mask=None, calc_word_scores=False):
    results = []
    self.word_scores = []
    for i, pooling in enumerate(self.poolings):
      result = pooling(x, mask)
      if not self.is_poolings_list[i]:
        # single-vector poolings are tiled so every result has shape batch * num_poolings * dim
        result = lele.tile(result.unsqueeze(1), 1, self.num_poolings)
      results.append(result)
      if calc_word_scores:
        # NOTE: `outputs` and `sequence_length` are not defined in this scope;
        # presumably they correspond to `x` and lengths derived from `mask`
        self.word_scores.append(melt.get_words_importance(outputs, sequence_length, top_k=self.top_k, method=self.names[i]))

    result = torch.cat(results, -1)
    self.encode = result
    return result
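
The is_poolings_list check exists because some poolings return a single vector per example (batch * dim) while multi-head attention poolings return batch * num_poolings * dim; single-vector results are tiled so everything can be concatenated on the last dimension. A shape-only sketch of that alignment, using a made-up max pooling and random tensors:

import torch

batch, length, dim, num_poolings = 2, 5, 16, 3
x = torch.randn(batch, length, dim)

att_result = torch.randn(batch, num_poolings, dim)               # e.g. a multi-head attention pooling
max_result = x.max(dim=1).values                                 # (batch, dim): single-vector pooling
max_result = max_result.unsqueeze(1).repeat(1, num_poolings, 1)  # tiled to (batch, num_poolings, dim)

result = torch.cat([att_result, max_result], dim=-1)             # (batch, num_poolings, 2 * dim)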
Example #4
 def forward(self, x, x_mask):
     """
     Args:
         x: batch * len * dim
         x_mask: batch * len (1 for padding, 0 for true)
     Output:
         pooled: batch * num_poolings * dim (attention-weighted sums of x)
     """
     scores = self.FFN(x)                                   # batch * len * num_poolings
     x_mask = lele.tile(x_mask.unsqueeze(-1), -1, self.num_poolings)
     scores.data.masked_fill_(x_mask.data, -float('inf'))   # ignore padded positions
     scores = scores.transpose(-2, -1)                      # batch * num_poolings * len
     alpha = F.softmax(scores, dim=-1)
     return alpha.bmm(x)                                    # batch * num_poolings * dim
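
For reference, the same masked attention-pooling pattern as Example #4, written as a self-contained module with a hypothetical two-layer FFN standing in for self.FFN (the layer sizes and FFN architecture are assumptions); it shows the shape flow end to end.

import torch
import torch.nn as nn
import torch.nn.functional as F

class AttentionPooling(nn.Module):
  # Hypothetical stand-in: score each time step with a small FFN,
  # mask out padding, and return num_poolings attention-weighted sums.
  def __init__(self, dim, num_poolings=1):
    super().__init__()
    self.num_poolings = num_poolings
    self.FFN = nn.Sequential(nn.Linear(dim, dim), nn.Tanh(), nn.Linear(dim, num_poolings))

  def forward(self, x, x_mask):
    scores = self.FFN(x)                                 # batch * len * num_poolings
    mask = x_mask.unsqueeze(-1).expand_as(scores)        # batch * len * num_poolings
    scores = scores.masked_fill(mask.bool(), -float('inf'))
    alpha = F.softmax(scores.transpose(-2, -1), dim=-1)  # batch * num_poolings * len
    return alpha.bmm(x)                                  # batch * num_poolings * dim

x = torch.randn(2, 5, 16)
x_mask = torch.tensor([[0, 0, 0, 1, 1], [0, 0, 1, 1, 1]], dtype=torch.uint8)  # 1 marks padding
out = AttentionPooling(16, num_poolings=2)(x, x_mask)                         # (2, 2, 16)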
Example #5
 def forward(self, x, x_mask):
     """
     Args:
         x: batch * len * hdim
         x_mask: batch * len (1 for padding, 0 for true)
     Output:
         pooled: batch * num_poolings * hdim (attention-weighted sums of x)
     """
     # .contiguous() is needed because .view requires contiguous memory
     x_mask = lele.tile(x_mask.unsqueeze(-1), -1, self.num_poolings)
     x = x.contiguous()
     x_flat = x.view(-1, x.size(-1))
     scores = self.linear(x_flat).view(x.size(0), x.size(1), self.num_poolings)
     # mask before transposing so the mask shape (batch * len * num_poolings) matches scores
     scores.data.masked_fill_(x_mask.data, -float('inf'))
     scores = scores.transpose(-2, -1)                      # batch * num_poolings * len
     alpha = F.softmax(scores, dim=-1)
     self.alpha = alpha
     return alpha.bmm(x)                                    # batch * num_poolings * hdim