def forward(self, src, tgt, seg):
    """
    Args:
        src: pair of [batch_size x seq_length] tensors (one per sentence)
        tgt: [batch_size]
        seg: pair of [batch_size x seq_length] tensors (one per sentence)
    """
    # Embedding.
    emb = self.embedding(src, seg)
    # Encoder.
    output = self.encoder(emb, seg)
    # Target.
    features_0, features_1 = output
    features_0 = pooling(features_0, seg[0], self.pooling_type)
    features_1 = pooling(features_1, seg[1], self.pooling_type)

    vectors_concat = []
    # Concatenation.
    vectors_concat.append(features_0)
    vectors_concat.append(features_1)
    # Difference.
    vectors_concat.append(torch.abs(features_0 - features_1))
    # Multiplication.
    vectors_concat.append(features_0 * features_1)
    features = torch.cat(vectors_concat, 1)

    logits = self.classifier(features)

    if tgt is not None:
        loss = nn.NLLLoss()(nn.LogSoftmax(dim=-1)(logits), tgt.view(-1))
        return loss, logits
    else:
        return None, logits
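# All of the forward passes in this section call a pooling helper that is not
# shown here. The following is only a minimal sketch of what such a helper
# might look like, assuming seg is a [batch_size x seq_length] mask in which
# zero marks padding; the option names ("mean", "max", "last") are illustrative.
import torch

def pooling(memory_bank, seg, pooling_type):
    # memory_bank: [batch_size x seq_length x hidden_size]
    # seg: [batch_size x seq_length], zero marks padding positions.
    mask = (seg > 0).unsqueeze(-1).float()
    if pooling_type == "mean":
        # Average over non-padding positions only.
        return (memory_bank * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
    elif pooling_type == "max":
        # Push padding positions to -inf before taking the max.
        return memory_bank.masked_fill(mask == 0, float("-inf")).max(dim=1)[0]
    elif pooling_type == "last":
        # Hidden state of the last non-padding token.
        lengths = mask.squeeze(-1).sum(dim=1).long().clamp(min=1)
        return memory_bank[torch.arange(memory_bank.size(0)), lengths - 1]
    else:
        # Default: hidden state of the first token ([CLS]-style pooling).
        return memory_bank[:, 0, :]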
def forward(self, memory_bank, tgt, seg):
    """
    Args:
        memory_bank: [batch_size x seq_length x hidden_size]
        tgt: [batch_size]
        seg: [batch_size x seq_length]
    Returns:
        loss: Classification loss.
        correct: Number of sentences that are predicted correctly.
    """
    output = pooling(memory_bank, seg, self.pooling_type)
    output = torch.tanh(self.linear_1(output))
    logits = self.linear_2(output)
    loss = self.criterion(self.softmax(logits), tgt)
    correct = self.softmax(logits).argmax(dim=-1).eq(tgt).sum()
    return loss, correct
def forward(self, src, tgt, seg, soft_tgt=None):
    """
    Args:
        src: [batch_size x seq_length]
        tgt: [batch_size]
        seg: [batch_size x seq_length]
    """
    # Embedding.
    emb = self.embedding(src, seg)
    # Encoder.
    output = self.encoder(emb, seg)
    # Target.
    output = pooling(output, seg, self.pooling_type)
    output = torch.tanh(self.output_layers_1[self.dataset_id](output))
    logits = self.output_layers_2[self.dataset_id](output)

    if tgt is not None:
        loss = nn.NLLLoss()(nn.LogSoftmax(dim=-1)(logits), tgt.view(-1))
        return loss, logits
    else:
        return None, logits
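# The per-dataset heads above are indexed by self.dataset_id, which suggests
# they are stored as parallel ModuleLists, one pair of layers per task. A
# minimal sketch of how such heads might be declared; hidden_size and the
# label counts below are illustrative assumptions, not values from the source.
import torch.nn as nn

hidden_size = 768
labels_num_list = [2, 3, 5]     # hypothetical number of classes per dataset

output_layers_1 = nn.ModuleList(
    [nn.Linear(hidden_size, hidden_size) for _ in labels_num_list]
)
output_layers_2 = nn.ModuleList(
    [nn.Linear(hidden_size, labels_num) for labels_num in labels_num_list]
)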
def forward(self, src, tgt, seg):
    """
    Args:
        src: [batch_size x seq_length]
        tgt: [batch_size x labels_num] float multi-hot labels
        seg: [batch_size x seq_length]
    """
    # Embedding.
    emb = self.embedding(src, seg)
    # Encoder.
    output = self.encoder(emb, seg)
    # Target.
    output = pooling(output, seg, self.pooling_type)
    output = torch.tanh(self.output_layer_1(output))
    logits = self.output_layer_2(output)

    if tgt is not None:
        probs_batch = nn.Sigmoid()(logits)
        loss = nn.BCELoss()(probs_batch, tgt)
        return loss, logits
    else:
        return None, logits
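# Unlike the other classifiers here, this one treats each label as an
# independent binary decision: BCELoss expects tgt to be a float multi-hot
# tensor with the same shape as logits. A small self-contained sketch of the
# loss computed above (shapes and values are illustrative only):
import torch
import torch.nn as nn

logits = torch.randn(2, 5)                   # as produced by output_layer_2: 2 examples, 5 labels
tgt = torch.tensor([[1., 0., 1., 0., 0.],
                    [0., 1., 0., 0., 1.]])   # multi-hot float targets

probs = nn.Sigmoid()(logits)                 # per-label probabilities
loss = nn.BCELoss()(probs, tgt)              # same computation as the forward above
preds = (probs > 0.5).long()                 # threshold each label independently at inference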
def forward(self, src, tgt, seg, soft_tgt=None):
    """
    Args:
        src: [batch_size x seq_length]
        tgt: [batch_size]
        seg: [batch_size x seq_length]
        soft_tgt: [batch_size x labels_num] optional soft labels
    """
    # Embedding.
    emb = self.embedding(src, seg)
    # Encoder.
    output = self.encoder(emb, seg)
    # Target.
    output = pooling(output, seg, self.pooling_type)
    output = torch.tanh(self.output_layer_1(output))
    logits = self.output_layer_2(output)

    if tgt is not None:
        if self.soft_targets and soft_tgt is not None:
            loss = self.soft_alpha * nn.MSELoss()(logits, soft_tgt) + \
                   (1 - self.soft_alpha) * nn.NLLLoss()(nn.LogSoftmax(dim=-1)(logits), tgt.view(-1))
        else:
            loss = nn.NLLLoss()(nn.LogSoftmax(dim=-1)(logits), tgt.view(-1))
        return loss, logits
    else:
        return None, logits
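# The soft-target branch above mixes a distillation term against soft labels
# with the usual cross-entropy. A self-contained sketch of that combined loss,
# with soft_alpha as the mixing weight (shapes and values are illustrative only):
import torch
import torch.nn as nn

logits = torch.randn(4, 3)        # student logits: 4 examples, 3 classes
soft_tgt = torch.randn(4, 3)      # soft labels, e.g. logits from a teacher model
tgt = torch.tensor([0, 2, 1, 2])  # hard labels
soft_alpha = 0.5                  # weight of the soft-label term

hard_loss = nn.NLLLoss()(nn.LogSoftmax(dim=-1)(logits), tgt)
soft_loss = nn.MSELoss()(logits, soft_tgt)
loss = soft_alpha * soft_loss + (1 - soft_alpha) * hard_loss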
def forward(self, src, seg):
    # Embedding.
    emb = self.embedding(src, seg)
    # Encoder.
    output = self.encoder(emb, seg)
    # Pool the token representations into a single sentence embedding.
    output = pooling(output, seg, self.pooling_type)
    return output
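# The forward above returns the pooled sentence embedding, so it can serve as a
# feature extractor. A common follow-up is to compare two embeddings with
# cosine similarity; the tensors below are illustrative stand-ins for the
# outputs of two separate forward calls.
import torch
import torch.nn.functional as F

emb_a = torch.randn(2, 768)   # pooled embeddings for a batch of 2 sentences
emb_b = torch.randn(2, 768)   # pooled embeddings for the paired sentences

similarity = F.cosine_similarity(emb_a, emb_b, dim=-1)  # one score per pair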