示例#1
0
    def forward(self, src, tgt, seg):
        """Classify a pair of inputs with a two-tower (siamese) encoder.

        The encoder output is unpacked into two feature tensors. Each is
        pooled with the matching entry of ``seg``. The pooled vectors u and v
        are combined as [u, v, |u - v|, u * v] along dim 1 and passed to
        ``self.classifier``.

        Args:
            src: input forwarded to ``self.embedding`` together with ``seg``.
                NOTE(review): an earlier docstring gave [batch_size x
                seq_length], but ``seg`` is indexed as a pair below, so
                src/seg presumably carry one entry per tower. Confirm.
            tgt: gold labels (flattened with ``view(-1)``), or None at
                inference time.
            seg: segment/mask input. ``seg[0]`` pools the first encoder
                output and ``seg[1]`` pools the second.

        Returns:
            ``(loss, logits)`` when ``tgt`` is given, otherwise
            ``(None, logits)``.
        """
        encoded = self.encoder(self.embedding(src, seg), seg)
        left_states, right_states = encoded
        u = pooling(left_states, seg[0], self.pooling_type)
        v = pooling(right_states, seg[1], self.pooling_type)

        # Concatenate u, v, their absolute difference and their elementwise
        # product into one feature vector per example.
        pair_features = torch.cat([u, v, torch.abs(u - v), u * v], 1)
        logits = self.classifier(pair_features)

        if tgt is None:
            return None, logits
        log_probs = nn.LogSoftmax(dim=-1)(logits)
        loss = nn.NLLLoss()(log_probs, tgt.view(-1))
        return loss, logits
示例#2
0
    def forward(self, memory_bank, tgt, seg):
        """
        Pool encoder states, classify them and score the prediction.

        Args:
            memory_bank: [batch_size x seq_length x hidden_size]
            tgt: [batch_size]
            seg: segment/mask input passed to ``pooling`` together with
                ``memory_bank`` (presumably [batch_size x seq_length];
                confirm against callers).

        Returns:
            loss: Classification loss.
            correct: Number of sentences that are predicted correctly.
        """
        output = pooling(memory_bank, seg, self.pooling_type)
        output = torch.tanh(self.linear_1(output))
        logits = self.linear_2(output)

        # Evaluate self.softmax once and reuse the result for both the loss
        # and the accuracy count instead of recomputing it.
        # NOTE(review): presumably self.softmax is LogSoftmax and
        # self.criterion is NLLLoss (the pairing used elsewhere in this file).
        # Confirm in __init__.
        scores = self.softmax(logits)
        loss = self.criterion(scores, tgt)
        correct = scores.argmax(dim=-1).eq(tgt).sum()

        return loss, correct
示例#3
0
 def forward(self, src, tgt, seg, soft_tgt=None):
     """Encode the input and classify it with the head chosen by dataset_id.

     The two output layers are taken from ``self.output_layers_1`` and
     ``self.output_layers_2`` at index ``self.dataset_id``.

     Args:
         src: [batch_size x seq_length]
         tgt: [batch_size] gold labels, or None at inference time.
         seg: [batch_size x seq_length]
         soft_tgt: accepted for signature compatibility; not used here.

     Returns:
         ``(loss, logits)`` when ``tgt`` is given, otherwise ``(None, logits)``.
     """
     hidden = self.encoder(self.embedding(src, seg), seg)
     pooled = pooling(hidden, seg, self.pooling_type)

     head = self.dataset_id
     projected = torch.tanh(self.output_layers_1[head](pooled))
     logits = self.output_layers_2[head](projected)

     if tgt is None:
         return None, logits
     log_probs = nn.LogSoftmax(dim=-1)(logits)
     return nn.NLLLoss()(log_probs, tgt.view(-1)), logits
示例#4
0
 def forward(self, src, tgt, seg):
     """
     Classification forward pass trained with a per-logit sigmoid BCE loss.

     Args:
         src: [batch_size x seq_length]
         tgt: binary targets, or None at inference time. PyTorch's BCE losses
             require a floating-point target with the same shape as
             ``logits``. NOTE(review): an earlier docstring said [batch_size];
             presumably [batch_size x labels_num]. Confirm against callers.
         seg: [batch_size x seq_length]

     Returns:
         ``(loss, logits)`` when ``tgt`` is given, otherwise ``(None, logits)``.
         ``logits`` are raw scores; no sigmoid is applied to them.
     """
     # Embedding.
     emb = self.embedding(src, seg)
     # Encoder.
     output = self.encoder(emb, seg)
     # Target.
     output = pooling(output, seg, self.pooling_type)
     output = torch.tanh(self.output_layer_1(output))
     logits = self.output_layer_2(output)
     if tgt is None:
         return None, logits
     # BCEWithLogitsLoss folds the sigmoid into the loss via the log-sum-exp
     # trick. This is numerically stabler than Sigmoid followed by BCELoss,
     # which saturates (and clamps log terms) for large |logits|.
     loss = nn.BCEWithLogitsLoss()(logits, tgt)
     return loss, logits
示例#5
0
 def forward(self, src, tgt, seg, soft_tgt=None):
     """Classify the input, optionally blending in a soft-target MSE loss.

     The hard loss is NLL over log-softmax of the logits. When
     ``self.soft_targets`` is truthy and ``soft_tgt`` is provided, the
     returned loss is ``soft_alpha * MSE(logits, soft_tgt) +
     (1 - soft_alpha) * hard_loss``.

     Args:
         src: [batch_size x seq_length]
         tgt: [batch_size] gold labels, or None at inference time.
         seg: [batch_size x seq_length]
         soft_tgt: optional target compared to ``logits`` with MSE
             (presumably the same shape as ``logits``; confirm).

     Returns:
         ``(loss, logits)`` when ``tgt`` is given, otherwise ``(None, logits)``.
     """
     hidden = self.encoder(self.embedding(src, seg), seg)
     pooled = pooling(hidden, seg, self.pooling_type)
     logits = self.output_layer_2(torch.tanh(self.output_layer_1(pooled)))

     if tgt is None:
         return None, logits

     hard_loss = nn.NLLLoss()(nn.LogSoftmax(dim=-1)(logits), tgt.view(-1))
     blend_soft = self.soft_targets and soft_tgt is not None
     if not blend_soft:
         return hard_loss, logits

     soft_loss = nn.MSELoss()(logits, soft_tgt)
     alpha = self.soft_alpha
     return alpha * soft_loss + (1 - alpha) * hard_loss, logits
示例#6
0
    def forward(self, src, seg):
        """Encode ``src`` and return the pooled encoder representation.

        Args:
            src: input forwarded to ``self.embedding`` together with ``seg``.
            seg: segment/mask input used by the embedding, the encoder and
                ``pooling``.

        Returns:
            The result of ``pooling`` over the encoder output; no
            classification head is applied.
        """
        hidden = self.encoder(self.embedding(src, seg), seg)
        return pooling(hidden, seg, self.pooling_type)