import torch

from pytorch_lightning import TrainResult


def test_result_map(tmpdir):
    result = TrainResult()
    result.log_dict({'x1': torch.tensor(1), 'x2': torch.tensor(2)})
    result.rename_keys({'x1': 'y1', 'x2': 'y2'})

    assert 'x1' not in result
    assert 'x2' not in result
    assert 'y1' in result
    assert 'y2' in result
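
For context on the API under test: in PyTorch Lightning 0.9.x a TrainResult behaves like a dict, which is why the membership asserts above work. A minimal usage sketch (illustrative only):

    import torch
    from pytorch_lightning import TrainResult

    loss = torch.tensor(0.5, requires_grad=True)
    result = TrainResult(minimize=loss)   # the tensor to backpropagate
    result.log('train_loss', loss.detach())
    assert 'train_loss' in result         # logged keys become dict entries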
Example #2
    def training_step(self, batch, batch_idx):
        input_ids, attention_mask, labels = batch
        outputs = self.forward(input_ids, attention_mask)
        loss = self.loss_function(predictions=outputs,
                                  targets=labels,
                                  weight=self.hparams.loss_weight)

        # The positional argument to TrainResult is the tensor to minimize.
        result = TrainResult(loss)
        metrics = {"train_loss": loss}

        # `logger` in the Result API is a boolean flag, not a logger object.
        result.log_dict(metrics, prog_bar=True, logger=True)
        return result
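
This example assumes a `loss_function` helper that is not shown in the snippet. A minimal sketch of what it might look like, assuming `loss_weight` is a per-class weight tensor (the name and semantics are guesses, not from the source):

    import torch.nn.functional as F

    def loss_function(self, predictions, targets, weight=None):
        # Hypothetical helper: cross-entropy with optional per-class weights.
        return F.cross_entropy(predictions, targets, weight=weight)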
Example #3
    def training_step(self, batch, batch_idx):
        if self.hparams.data.augment:  # Using only first augmentation
            batch = batch[0]
        logits = self(batch)
        labels = batch[3]

        if self.hparams.exp.tsa:
            # Training signal annealing: keep only examples that are
            # low-confidence or wrongly predicted; mask out the rest.
            scores = torch.softmax(logits.detach(), dim=-1)
            max_probs, preds = torch.max(scores, dim=-1)
            mask = (max_probs.le(self.tsa.threshold) |
                    (labels != preds)).float()
            loss = (F.cross_entropy(
                logits,
                labels,
                reduction='none',
                weight=self.cross_entropy_weights.type_as(logits)) *
                    mask).mean()
        else:
            loss = F.cross_entropy(
                logits,
                labels,
                weight=self.cross_entropy_weights.type_as(logits))

        result = TrainResult(minimize=loss)
        result.log('train_loss',
                   loss.detach(),
                   on_epoch=True,
                   on_step=False,
                   sync_dist=True)
        result.log_dict(self.calculate_metrics(logits.detach(),
                                               labels,
                                               prefix='train'),
                        on_epoch=True,
                        on_step=False,
                        sync_dist=True)
        if self.hparams.exp.tsa:
            result.log('train_l_mask',
                       mask.float().mean(),
                       on_epoch=True,
                       on_step=False,
                       sync_dist=True)
        return result
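
The TSA branch above reads `self.tsa.threshold` but never defines it. In training signal annealing (UDA, Xie et al. 2019) the threshold is annealed from 1/K to 1 over training; a minimal sketch of such a schedule, with hypothetical names:

    import math

    class TSASchedule:
        def __init__(self, num_classes, total_steps, mode='linear'):
            self.num_classes = num_classes
            self.total_steps = total_steps
            self.mode = mode
            self.step = 0  # advance this once per training step

        @property
        def threshold(self):
            t = min(self.step / self.total_steps, 1.0)
            if self.mode == 'exp':
                alpha = math.exp((t - 1) * 5)
            elif self.mode == 'log':
                alpha = 1 - math.exp(-t * 5)
            else:  # linear
                alpha = t
            # Anneal from 1/K (uniform confidence) up to 1.0.
            return alpha * (1 - 1 / self.num_classes) + 1 / self.num_classes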
Example #4
    def training_step(self, composite_batch, batch_idx):
        # Batch collation: composite_batch[0] wraps the labelled batch,
        # composite_batch[1] the augmented unlabelled branches.
        l_batch = composite_batch[0][0][0]
        ul_branches = [[torch.cat(sub_item) for sub_item in zip(*item)]
                       for item in zip(*composite_batch[1])]

        # Supervised loss
        l_logits = self(l_batch)
        l_labels = l_batch[3]
        if self.hparams.exp.tsa:
            l_scores = torch.softmax(l_logits.detach(), dim=-1)
            l_max_probs, l_preds = torch.max(l_scores, dim=-1)
            l_mask = (l_max_probs.le(self.tsa.threshold) |
                      (l_labels != l_preds)).float()
            l_loss = (F.cross_entropy(
                l_logits,
                l_labels,
                reduction='none',
                weight=self.cross_entropy_weights.type_as(l_logits)) *
                      l_mask).mean()
        else:
            l_loss = F.cross_entropy(
                l_logits,
                l_labels,
                reduction='mean',
                weight=self.cross_entropy_weights.type_as(l_logits))

        # Unsupervised loss
        # Choosing pseudo-labels and branches to back-propagate
        u_max_probs_2d = torch.empty(
            (len(ul_branches), len(ul_branches[0][0]))).type_as(l_loss)
        u_targets_2d = torch.empty(
            (len(ul_branches), len(ul_branches[0][0]))).type_as(l_loss)
        with torch.no_grad():
            for i, ul_branch in enumerate(ul_branches):
                u_logits = self(ul_branch)
                pseudo_labels = torch.softmax(u_logits.detach(), dim=-1)
                u_max_probs_2d[i], u_targets_2d[i] = torch.max(pseudo_labels,
                                                               dim=-1)

        if self.hparams.model.tsa_as_threshold:
            # Reuse the TSA threshold (raised to the power 1/3) as the
            # confidence cutoff for pseudo-labels.
            u_mask_2d = u_max_probs_2d.ge(
                torch.tensor(self.tsa.threshold)**(1 / 3)).int()
        else:
            u_mask_2d = u_max_probs_2d.ge(self.hparams.model.threshold).int()

        # Per-instance mask: at least one branch must pass the threshold
        u_mask = u_mask_2d.sum(dim=0) > 0
        u_max_probs, u_best_branches = torch.max(u_max_probs_2d, dim=0)

        u_loss = torch.tensor(0.0).type_as(l_loss)
        u_batch = []
        u_targets = []
        if u_mask.int().sum() > 0:
            # Creating one batch for unlabelled loss
            for i in range(len(ul_branches[0][0])):
                if u_mask[i]:
                    best = u_best_branches[i]
                    only_wrong = (self.hparams.model
                                  .choose_only_wrongly_predicted_branches)
                    # All branches except the most confident one; optionally
                    # only those whose pseudo-label disagrees with it.
                    nonmax_branches = [
                        ul_branch
                        for ind, ul_branch in enumerate(ul_branches)
                        if ind != best
                        and (u_targets_2d[ind, i] != u_targets_2d[best, i]
                             or not only_wrong)
                    ]
                    if len(nonmax_branches) > 0:
                        u_batch.extend([[item[i] for item in branch]
                                        for branch in nonmax_branches])
                        u_targets.extend(
                            u_targets_2d[u_best_branches[i]][i].repeat(
                                len(nonmax_branches)))

            # Cutting huge batches
            max_size = self.hparams.model.max_ul_batch_size_per_gpu
            if max_size is not None:
                u_batch = u_batch[:max_size]
                u_targets = u_targets[:max_size]

            if len(u_batch) > 0:
                u_batch = [torch.stack(item) for item in zip(*u_batch)]
                u_targets = torch.stack(u_targets).long()

                # Unlabelled loss
                u_logits = self(u_batch)
                u_loss = F.cross_entropy(
                    u_logits,
                    u_targets,
                    reduction='mean',
                    weight=self.cross_entropy_weights.type_as(u_logits))

        # Train loss / labelled accuracy
        loss = l_loss + self.hparams.model.lambda_u * u_loss

        result = TrainResult(minimize=loss)
        result.log('train_loss',
                   loss.detach(),
                   on_epoch=True,
                   on_step=False,
                   sync_dist=True)
        result.log('train_loss_l',
                   l_loss.detach(),
                   on_epoch=True,
                   on_step=False,
                   sync_dist=True)
        result.log('train_loss_ul',
                   u_loss.detach(),
                   on_epoch=True,
                   on_step=False,
                   sync_dist=True)
        result.log('train_u_mask',
                   u_mask.float().mean(),
                   on_epoch=True,
                   on_step=False,
                   sync_dist=True)
        result.log('train_u_batch_size',
                   torch.tensor(len(u_targets)).float(),
                   on_epoch=True,
                   on_step=False,
                   sync_dist=True)
        if self.hparams.exp.tsa:
            result.log('tsa_threshold',
                       torch.tensor(self.tsa.threshold),
                       on_epoch=True,
                       on_step=False,
                       sync_dist=True)
            result.log('train_l_mask',
                       l_mask.float().mean(),
                       on_epoch=True,
                       on_step=False,
                       sync_dist=True)
        result.log_dict(self.calculate_metrics(l_logits.detach(),
                                               l_labels,
                                               prefix='train'),
                        on_epoch=True,
                        on_step=False,
                        sync_dist=True)
        return result
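
Examples #3 and #4 also rely on a `calculate_metrics` helper that returns a dict suitable for `log_dict`. A minimal sketch, assuming only accuracy is tracked (the key name is hypothetical, not from the source):

    import torch

    def calculate_metrics(self, logits, labels, prefix):
        # Hypothetical helper: dict of scalar tensors keyed by prefix.
        preds = torch.argmax(logits, dim=-1)
        accuracy = (preds == labels).float().mean()
        return {f'{prefix}_accuracy': accuracy}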