Example #1
    def forward(
        self,  # type: ignore
        encoded_boxes: torch.Tensor,
        encoded_boxes_mask: torch.Tensor,
        encoded_boxes_pooled: torch.Tensor,
        encoded_text: torch.Tensor,
        encoded_text_mask: torch.Tensor,
        encoded_text_pooled: torch.Tensor,
        pooled_boxes_and_text: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        label_weights: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        logits = self.classifier(pooled_boxes_and_text)

        output = {
            "logits": logits,
            "probs": torch.sigmoid(logits),
        }

        if labels is not None and label_weights is not None:
            label_mask = labels > 1  # 0 is padding, 1 is OOV, which we want to ignore

            from allennlp.nn import util

            weighted_labels = util.masked_index_replace(
                logits.new_zeros(logits.size() + (1,)),
                labels.clamp(min=0),
                label_mask,
                label_weights.unsqueeze(-1),
            ).squeeze(-1)

            # weighted_labels now has shape (batch_size, num_labels).  We need to ignore the first
            # two columns of this in our loss function and accuracy metric.  The first column is a
            # padding label, and the second column is an OOV label.  We want the loss to be
            # computed on all of the remaining labels.
            binary_label_mask = weighted_labels.new_ones(logits.size())
            binary_label_mask[:, 0] = 0
            binary_label_mask[:, 1] = 0

            output["loss"] = (
                torch.nn.functional.binary_cross_entropy_with_logits(
                    logits,
                    weighted_labels,
                    weight=binary_label_mask,
                    reduction="sum",
                )
                / logits.size(0)
            )

            self.f1_metric(logits, weighted_labels, binary_label_mask.bool())
            self.vqa_metric(logits, labels, label_weights)

        return output
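
For readers unfamiliar with util.masked_index_replace from allennlp.nn, here is a minimal standalone sketch of what the weighted-label construction above produces. The call pattern is copied from the snippet; the tensor sizes, labels, and weights are invented purely for illustration, and it assumes allennlp is installed.

    import torch
    from allennlp.nn import util

    # Two instances, six answer classes (index 0 = padding, index 1 = OOV).
    logits = torch.zeros(2, 6)

    # Up to three annotated answers per instance; index 0 pads shorter lists.
    labels = torch.tensor([[2, 3, 0], [4, 1, 0]])
    label_weights = torch.tensor([[1.0, 0.3, 0.0], [0.9, 0.6, 0.0]])

    label_mask = labels > 1  # drop padding (0) and OOV (1) entries

    # Scatter each answer's weight into its label's column of a zero tensor,
    # skipping masked-out entries, exactly as in the forward() above.
    weighted_labels = util.masked_index_replace(
        logits.new_zeros(logits.size() + (1,)),
        labels.clamp(min=0),
        label_mask,
        label_weights.unsqueeze(-1),
    ).squeeze(-1)

    print(weighted_labels)
    # tensor([[0.0000, 0.0000, 1.0000, 0.3000, 0.0000, 0.0000],
    #         [0.0000, 0.0000, 0.0000, 0.0000, 0.9000, 0.0000]])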
Example #2
    def _compute_loss_and_metrics(
        self,
        batch_size: int,
        outputs: Dict[str, torch.Tensor],
        label: Optional[torch.Tensor],
        label_weights: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        if label is not None and label_weights is not None:
            logits = outputs["logits"]
            label_mask = label > 1  # 0 is padding, 1 is OOV, which we want to ignore

            weighted_labels = util.masked_index_replace(
                logits.new_zeros(logits.size() + (1,)),
                label.clamp(min=0),
                label_mask,
                label_weights.unsqueeze(-1),
            ).squeeze(-1)

            # weighted_labels now has shape (batch_size, num_labels).  We need to ignore the first
            # two columns of this in our loss function and accuracy metric.  The first column is a
            # padding label, and the second column is an OOV label.  We want the loss to be
            # computed on all of the remaining labels.
            binary_label_mask = weighted_labels.new_ones(logits.size())
            binary_label_mask[:, 0] = 0
            binary_label_mask[:, 1] = 0

            outputs["loss"] = (
                torch.nn.functional.binary_cross_entropy_with_logits(
                    logits,
                    weighted_labels,
                    weight=binary_label_mask,
                    reduction="sum") / batch_size)

            self.f1_metric(logits, weighted_labels, binary_label_mask.bool())
            self.vqa_metric(logits, label, label_weights)

        return outputs
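
In both examples, the weight argument of binary_cross_entropy_with_logits multiplies the per-element loss before the sum reduction, so zeroing columns 0 and 1 removes the padding and OOV labels from the loss entirely. A minimal standalone check of that equivalence, using plain torch with shapes and values invented for illustration:

    import torch
    import torch.nn.functional as F

    batch_size, num_labels = 2, 6
    logits = torch.randn(batch_size, num_labels)
    weighted_labels = torch.zeros(batch_size, num_labels)
    weighted_labels[0, 2] = 1.0  # one positive answer at full weight

    # Zero out the padding (0) and OOV (1) columns, as in the snippets above.
    binary_label_mask = torch.ones_like(logits)
    binary_label_mask[:, 0] = 0
    binary_label_mask[:, 1] = 0

    loss = (
        F.binary_cross_entropy_with_logits(
            logits, weighted_labels, weight=binary_label_mask, reduction="sum"
        )
        / batch_size
    )

    # Equivalent: take the unmasked per-element loss and drop the first two columns.
    per_element = F.binary_cross_entropy_with_logits(
        logits, weighted_labels, reduction="none"
    )
    assert torch.isclose(loss, per_element[:, 2:].sum() / batch_size)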