Example #1
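A classification wrapper: the model output is cast to FP32 before the loss is computed, mixup/cutmix label mixing and label smoothing are both supported, and the combined loss is marked with poptorch.identity_loss.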
    def forward(self, input_data, labels=None):
        if self.use_mixup or self.use_cutmix:
            output, coeffs = self.model(input_data)
        else:
            output = self.model(input_data)
        # Calculate loss in full precision.
        output = output.float()
        if labels is None:
            return output

        loss_items = {}
        log_preds = torch.nn.functional.log_softmax(output, dim=1)

        if self.use_mixup or self.use_cutmix:
            all_labels, weights = self.model.mix_labels(labels, coeffs)
            classification_loss = self.loss(log_preds, all_labels, weights)
        else:
            classification_loss = self.loss(log_preds, labels)
        loss_items["classification_loss"] = (
            1.0 - self.label_smoothing) * classification_loss

        if self.label_smoothing > 0.0:
            # Cross entropy between uniform distribution and output distribution.
            loss_items["smoothing_loss"] = -torch.mean(
                log_preds) * self.label_smoothing
            final_loss = loss_items["smoothing_loss"] + loss_items[
                "classification_loss"]
        else:
            final_loss = loss_items["classification_loss"]

        with torch.no_grad():
            acc = utils.accuracy(output, labels)
        return acc, poptorch.identity_loss(
            final_loss, reduction='none'), tuple(loss_items.values())
Example #2
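A label-smoothed, KL-divergence-style loss for sequence models, masked so that padded target positions contribute nothing.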
    def forward(self, logits, target, target_mask):
        # One-hot encode the targets, then apply label smoothing through
        # the precomputed self.offset / self.other_value scaling.
        y_true = torch.nn.functional.one_hot(target.long(),
                                             self.vocab_size).to(logits.dtype)
        y_true = y_true * self.offset + self.other_value
        y_pred = torch.log_softmax(logits, -1)
        # KL divergence between the smoothed targets and the predictions,
        # masked so that padded positions contribute nothing.
        loss_pre = y_true * (torch.log(y_true) -
                             y_pred) * target_mask.unsqueeze(-1)
        loss = torch.sum(loss_pre)
        # Normalise by the number of unmasked target tokens.
        loss = loss / target_mask.int().sum()
        # "pt" is presumably an import alias for poptorch (import poptorch as pt).
        return pt.identity_loss(loss, reduction='mean')
Example #3
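A label-smoothing wrapper that scales the backpropagated loss by the replica count while returning an unscaled copy for logging.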
    def forward(self, args, loss_inputs=None):
        output = self.model(args)
        # Calculate loss in full precision.
        output = output.float()
        if loss_inputs is None:
            return output
        else:
            log_preds = torch.nn.functional.log_softmax(output, dim=1)
            # NOTE: once PopART supports batch + replica reduction at dispatch,
            # the returning_loss / final_loss duplication can be removed.
            returning_loss = final_loss = (1.0 -
                                           self.label_smoothing) * self.loss(
                                               log_preds, loss_inputs)
            final_loss = final_loss * self.replicas
            if self.label_smoothing > 0.0:
                # Cross entropy between uniform distribution and output distribution.
                smoothing_loss = -torch.mean(log_preds) * self.label_smoothing
                final_loss = final_loss + smoothing_loss
                returning_loss = returning_loss + smoothing_loss
            # Marking final_loss is sufficient; identity_loss's return value is unused.
            poptorch.identity_loss(final_loss, reduction='mean')
            return output, returning_loss
Example #4
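A CSPDarknet-style backbone that can optionally attach a dummy loss over its three feature maps, allowing the graph to be compiled and run as a training model.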
    def forward(self, x: torch.Tensor, target: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, ...]:
        x = self.conv1(x)
        x = self.cspdark1(x)
        x = self.cspdark2(x)
        p3 = self.cspdark3(x)
        p4 = self.cspdark4(p3)
        p5 = self.cspdark5(p4)

        if self.calculate_loss:
            loss = poptorch.identity_loss(self.dummy_loss(p5) + self.dummy_loss(p4) + self.dummy_loss(p3), 'sum')
            return p5, p4, p3, loss

        return p5, p4, p3
Example #5
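A BERT pre-training wrapper that combines the masked-LM and next-sentence-prediction cross-entropy losses through poptorch.identity_loss and also reports both accuracies.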
    def forward(self,
                input_ids,
                attention_mask,
                token_type_ids,
                masked_lm_positions,
                masked_lm_labels=None,
                next_sentence_label=None):
        inputs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
        }

        outputs = self.model.bert(**inputs)
        sequence_output, pooled_output = outputs[:2]

        # Select only the masked tokens for the classifier
        masked_output = gather_indices(sequence_output, masked_lm_positions)

        prediction_scores, sequential_relationship_score = self.model.cls(
            masked_output, pooled_output)
        outputs = (
            prediction_scores,
            sequential_relationship_score,
        ) + outputs[2:]

        if masked_lm_labels is not None and next_sentence_label is not None:
            masked_lm_loss = F.cross_entropy(prediction_scores.view(
                -1, self.config.vocab_size),
                                             masked_lm_labels.view(-1),
                                             ignore_index=0).float()
            next_sentence_loss = F.cross_entropy(
                sequential_relationship_score.view(-1, 2),
                next_sentence_label.view(-1)).float()
            total_loss = poptorch.identity_loss(masked_lm_loss +
                                                next_sentence_loss,
                                                reduction="none")

            next_sentence_acc = accuracy(
                sequential_relationship_score.view([-1, 2]),
                next_sentence_label.view(-1))
            # masked_lm_labels: 0 if corresponding token not masked, original value otherwise
            masked_lm_acc = accuracy_masked(
                prediction_scores.view(
                    [-1, self.config.mask_tokens, self.config.vocab_size]),
                masked_lm_labels, 0)
            outputs = (total_loss, masked_lm_loss, next_sentence_loss,
                       masked_lm_acc, next_sentence_acc)

        return outputs
Example #6
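The matching detection neck (SPP plus CSP up/down paths) for the backbone in Example #4, reusing the same optional dummy-loss pattern.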
    def forward(self, x: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], target: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, ...]:
        p5, p4, p3 = x

        p5, x = self.SPP(p5)
        p4, x = self.cspUp1(p4, x)
        x, p3 = self.cspUp2(p3, x)

        p4, x = self.cspDown1(p4, x)
        p5, x = self.cspDown2(p5, x)

        if self.calculate_loss:
            loss = poptorch.identity_loss(self.dummy_loss(p5) + self.dummy_loss(p4) + self.dummy_loss(p3), 'sum')
            return p5, p4, p3, loss
        return p5, p4, p3
Example #7
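Another label-smoothing wrapper, this one applying a configurable reduction op to the smoothing term and multiplying the final loss by a loss-scaling factor.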
    def forward(self, args, loss_inputs=None):
        output = self.model(args)
        if loss_inputs is None:
            return output
        else:
            # Calculate loss in full precision.
            output = output.float()
            if self.label_smoothing > 0.0:
                # Cross entropy between uniform distribution and output distribution.
                log_preds = torch.nn.functional.log_softmax(output, dim=1)
                smoothing_loss = self.reduction_op(
                    -log_preds.mean(dim=1), dim=0) * self.label_smoothing
                classification_loss = (
                    1.0 - self.label_smoothing) * self.label_smoothing_loss(
                        log_preds, loss_inputs)
                final_loss = smoothing_loss + classification_loss
            else:
                final_loss = self.loss(output, loss_inputs)
            return output, poptorch.identity_loss(final_loss *
                                                  self.loss_scaling,
                                                  reduction='none')
Example #8
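The minimal case: any custom tensor expression can act as a loss, as long as the result is passed through poptorch.identity_loss.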
def custom_loss(output, target):
    # Mean squared error with a scale
    loss = output - target
    loss = loss * loss * 5
    return poptorch.identity_loss(loss, reduction="mean")
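For context on how the wrapped forward methods above are actually driven, here is a minimal sketch of training with Example #8's custom_loss. The TrainingWrapper class, the Linear stand-in model, and the SGD settings are illustrative assumptions, not taken from any of the snippets above.

import torch
import poptorch

class TrainingWrapper(torch.nn.Module):
    # Hypothetical wrapper combining a model with custom_loss from Example #8.
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x, target):
        output = self.model(x)
        # identity_loss inside custom_loss marks this value as the graph loss.
        return output, custom_loss(output, target)

model = torch.nn.Linear(8, 8)  # stand-in model for illustration
training_model = poptorch.trainingModel(
    TrainingWrapper(model),
    options=poptorch.Options(),
    optimizer=torch.optim.SGD(model.parameters(), lr=0.01))

x = torch.randn(4, 8)
target = torch.randn(4, 8)
output, loss = training_model(x, target)  # compiles once, then one weight update per call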
Example #9
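A DALL-E-style text-to-image transformer. At training time it builds shifted text+image labels, masks the logits so text positions predict only text tokens and image positions only image tokens, and returns a weighted sum of the two cross-entropy losses.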
    def forward(self, text, image=None, mask=None):
        if exists(image) and not is_empty(image):
            is_raw_image = len(image.shape) == 4

            if is_raw_image:
                image_size = self.vae.image_size
                image = self.vae.get_codebook_indices(image)

            image_len = image.shape[1]
            image_emb = self.image_emb(image)

            image_emb += self.image_pos_emb(image_emb)

        # Make sure padding in the text tokens gets a unique padding token id.

        text_range = torch.arange(
            self.text_seq_len) + (self.num_text_tokens - self.text_seq_len)
        text = torch.where(text == 0, text_range, text)

        # add <bos>

        text = F.pad(text, (1, 0), value=0)

        tokens = self.text_emb(text)
        tokens += self.text_pos_emb(torch.arange(text.shape[1]))

        seq_len = tokens.shape[1]

        if exists(image) and not is_empty(image):
            tokens = torch.cat((tokens, image_emb), dim=1)
            seq_len += image_len

        # When training, the sequence is one token longer than the prediction
        # targets; remove the last token, since it does not need to be trained.

        if self.training:
            seq_len -= 1
            tokens = tokens[:, :-1]

        out = self.transformer(tokens)

        logits = self.to_logits(out)

        # mask logits to make sure text predicts text (except last token), and image predicts image

        logits_mask = self.logits_mask[:, :seq_len]
        if self.fp16:
            max_neg_value = -torch.finfo(torch.float16).max
        else:
            max_neg_value = -torch.finfo(torch.float32).max
        logits.masked_fill_(logits_mask, max_neg_value)

        if not self.training:
            return logits

        assert exists(image), 'when training, image must be supplied'

        offsetted_image = image + self.num_text_tokens
        labels = torch.cat((text[:, 1:], offsetted_image), dim=1)

        # Flatten logits to (positions, classes) and compute separate losses
        # for the text and image segments of the sequence.
        loss_text = F.cross_entropy(
            logits[:, :self.text_seq_len].reshape([-1, self.total_tokens]),
            labels[:, :self.text_seq_len].reshape(-1))
        loss_img = F.cross_entropy(
            logits[:, self.text_seq_len:].reshape([-1, self.total_tokens]),
            labels[:, self.text_seq_len:].reshape(-1))
        loss = (loss_text +
                self.loss_img_weight * loss_img) / (self.loss_img_weight + 1)

        return poptorch.identity_loss(loss, reduction='none')