def forward(
    self,  # type: ignore
    encoded_boxes: torch.Tensor,
    encoded_boxes_mask: torch.Tensor,
    encoded_boxes_pooled: torch.Tensor,
    encoded_text: torch.Tensor,
    encoded_text_mask: torch.Tensor,
    encoded_text_pooled: torch.Tensor,
    pooled_boxes_and_text: torch.Tensor,
    labels: Optional[torch.Tensor] = None,
    label_weights: Optional[torch.Tensor] = None,
) -> Dict[str, torch.Tensor]:
    """Classify the fused box/text representation and optionally compute the loss.

    Only ``pooled_boxes_and_text`` is consumed here; the other encoded inputs
    are part of the backbone interface and are unused by this head.

    Parameters
    ----------
    pooled_boxes_and_text : torch.Tensor
        Fused multimodal representation fed to ``self.classifier``.
    labels : Optional[torch.Tensor]
        Gold label indices (0 = padding, 1 = OOV). When ``None``, no loss
        is computed.
    label_weights : Optional[torch.Tensor]
        Per-label weights matching ``labels``. When ``None``, no loss is
        computed.

    Returns
    -------
    Dict[str, torch.Tensor]
        Always contains ``"logits"`` and ``"probs"`` (elementwise sigmoid of
        the logits); additionally contains ``"loss"`` when both ``labels``
        and ``label_weights`` are provided.
    """
    logits = self.classifier(pooled_boxes_and_text)
    output = {
        "logits": logits,
        "probs": torch.sigmoid(logits),
    }
    # Delegate to the shared helper instead of duplicating the weighted-BCE
    # loss and metric bookkeeping inline. The helper no-ops (and returns the
    # dict unchanged) when labels or label_weights is None, matching the
    # previous inline guard.
    return self._compute_loss_and_metrics(logits.size(0), output, labels, label_weights)
def _compute_loss_and_metrics( self, batch_size: int, outputs: torch.Tensor, label: torch.Tensor, label_weights: Optional[torch.Tensor] = None, ): if label is not None and label_weights is not None: logits = outputs["logits"] label_mask = label > 1 # 0 is padding, 1 is OOV, which we want to ignore weighted_labels = util.masked_index_replace( logits.new_zeros(logits.size() + (1, )), label.clamp(min=0), label_mask, label_weights.unsqueeze(-1), ).squeeze(-1) # weighted_labels now has shape (batch_size, num_labels). We need to ignore the first # two columns of this in our loss function and accuracy metric. The first column is a # padding label, and the second column is an OOV label. We want the loss function to # be computed on every other label. binary_label_mask = weighted_labels.new_ones(logits.size()) binary_label_mask[:, 0] = 0 binary_label_mask[:, 1] = 0 outputs["loss"] = ( torch.nn.functional.binary_cross_entropy_with_logits( logits, weighted_labels, weight=binary_label_mask, reduction="sum") / batch_size) self.f1_metric(logits, weighted_labels, binary_label_mask.bool()) self.vqa_metric(logits, label, label_weights) return outputs