Example #1
    def training_step(self, batch, batch_idx, hiddens):
        # `hiddens` must be the value returned from the previous tbptt step.
        assert hiddens == self.test_hidden, "Hidden state not persistent between tbptt steps"
        self.test_hidden = torch.rand(1)

        # The Trainer splits each batch along dim 1; `truncated_bptt_steps`
        # and `batch_size` are free variables from the enclosing test scope.
        x_tensor, y_list = batch
        assert x_tensor.shape[1] == truncated_bptt_steps, "tbptt split Tensor failed"

        y_tensor = torch.tensor(y_list, dtype=x_tensor.dtype)
        assert y_tensor.shape[1] == truncated_bptt_steps, "tbptt split list failed"

        pred = self(x_tensor.view(batch_size, truncated_bptt_steps))
        loss_val = torch.nn.functional.mse_loss(
            pred, y_tensor.view(batch_size, truncated_bptt_steps))

        # Hand the new hidden state back so Lightning passes it to the next step.
        return TrainResult(loss_val, hiddens=self.test_hidden)
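This step only fires when truncated back-propagation through time is enabled. A minimal sketch of the wiring, assuming the pytorch-lightning 0.9-era API (the same generation that provides TrainResult), where the Trainer takes a truncated_bptt_steps argument and splits each batch along dim 1; the model and dataloader are left out:

    import pytorch_lightning as pl

    truncated_bptt_steps = 2
    batch_size = 4  # the free variables referenced by the step above

    trainer = pl.Trainer(
        max_epochs=1,
        truncated_bptt_steps=truncated_bptt_steps,
    )
    # trainer.fit(model)  # each split arrives as `batch`, with the prior `hiddens`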
Example #2
    def training_step_result_log_epoch_and_step(self, batch, batch_idx):
        acc = self.step(batch, batch_idx)
        result = TrainResult(minimize=acc)

        # Deterministic values that depend on both the step and the epoch,
        # so the aggregated logs can be checked exactly.
        val_1 = (5 + batch_idx) * (self.current_epoch + 1)
        val_2 = (6 + batch_idx) * (self.current_epoch + 1)
        val_3 = (7 + batch_idx) * (self.current_epoch + 1)
        result.log('step_epoch_log_and_pbar_acc1',
                   torch.tensor(val_1).type_as(acc),
                   on_epoch=True,
                   prog_bar=True)
        result.log('step_epoch_log_acc2',
                   torch.tensor(val_2).type_as(acc),
                   on_epoch=True)
        result.log('step_epoch_pbar_acc3',
                   torch.tensor(val_3).type_as(acc),
                   on_epoch=True,
                   logger=False,
                   prog_bar=True)

        self.training_step_called = True
        return result
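The three calls above cover the flag combinations: logger plus progress bar, logger only, and progress bar only, each aggregated over the epoch in addition to the per-step value. TrainResult was removed in later pytorch-lightning releases; on 1.x and newer the same logging is written with self.log, roughly as in this sketch (not the original test code):

    def training_step(self, batch, batch_idx):
        acc = self.step(batch, batch_idx)
        val_1 = (5 + batch_idx) * (self.current_epoch + 1)
        self.log('step_epoch_log_and_pbar_acc1',
                 torch.tensor(val_1).type_as(acc),
                 on_epoch=True,
                 prog_bar=True)
        return acc  # the returned tensor is the loss to minimize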
Example #3
    def training_step_result_log_epoch_only(self, batch, batch_idx):
        acc = self.step(batch, batch_idx)
        result = TrainResult(minimize=acc)

        # Epoch-only logging: on_step=False suppresses the per-step values.
        result.log(f'epoch_log_and_pbar_acc1_e{self.current_epoch}',
                   torch.tensor(14).type_as(acc),
                   on_epoch=True,
                   prog_bar=True,
                   on_step=False)
        result.log(f'epoch_log_acc2_e{self.current_epoch}',
                   torch.tensor(15).type_as(acc),
                   on_epoch=True,
                   on_step=False)
        result.log(f'epoch_pbar_acc3_e{self.current_epoch}',
                   torch.tensor(16).type_as(acc),
                   on_epoch=True,
                   logger=False,
                   prog_bar=True,
                   on_step=False)

        self.training_step_called = True
        return result
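Because self.current_epoch is baked into each key, every epoch writes a fresh metric series ('epoch_log_acc2_e0', 'epoch_log_acc2_e1', ...), which makes per-epoch aggregation straightforward to assert on; with on_step=False and on_epoch=True, only the epoch-level mean is emitted.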
Example #4
    def training_step_result_obj(self, batch, batch_idx, optimizer_idx=None):
        # forward pass
        x, y = batch
        x = x.view(x.size(0), -1)
        y_hat = self(x)

        # calculate loss
        loss_val = self.loss(y, y_hat)
        log_val = loss_val

        # alternate between tensors and scalars for "log" and "progress_bar"
        if batch_idx % 2 == 0:
            log_val = log_val.item()

        result = TrainResult(loss_val)
        result.log('some_val', log_val * log_val, prog_bar=True, logger=False)
        result.log('train_some_val', log_val * log_val)
        return result
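The even/odd alternation exercises both input types: .item() turns a zero-dimensional tensor into a plain Python float, and result.log accepts either form. A small illustration of the duality (plain PyTorch, nothing Lightning-specific):

    import torch

    loss_val = torch.tensor(0.25)
    assert isinstance(loss_val.item(), float)    # scalar path
    assert torch.is_tensor(loss_val * loss_val)  # tensor path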
Example #5
    def training_step(self, composite_batch, batch_idx):
        # Batch collation: composite_batch packs one labelled batch together
        # with several unlabelled branches.
        l_batch = composite_batch[0][0][0]
        ul_branches = [[torch.cat(sub_item) for sub_item in zip(*item)]
                       for item in zip(*composite_batch[1])]

        # Supervised loss
        l_logits = self(l_batch)
        l_labels = l_batch[3]
        if self.hparams.exp.tsa:
            # Training Signal Annealing: mask out labelled examples the model
            # already predicts correctly with confidence above the threshold.
            l_scores = torch.softmax(l_logits.detach(), dim=-1)
            l_max_probs, l_preds = torch.max(l_scores, dim=-1)
            l_mask = (l_max_probs.le(self.tsa.threshold) |
                      (l_labels != l_preds)).float()
            per_example_loss = F.cross_entropy(
                l_logits,
                l_labels,
                reduction='none',
                weight=self.cross_entropy_weights.type_as(l_logits))
            l_loss = (per_example_loss * l_mask).mean()
        else:
            l_loss = F.cross_entropy(
                l_logits,
                l_labels,
                reduction='mean',
                weight=self.cross_entropy_weights.type_as(l_logits))

        # Unsupervised loss
        # Choosing pseudo-labels and branches to back-propagate
        u_max_probs_2d = torch.empty(
            (len(ul_branches), len(ul_branches[0][0]))).type_as(l_loss)
        u_targets_2d = torch.empty(
            (len(ul_branches), len(ul_branches[0][0]))).type_as(l_loss)
        with torch.no_grad():
            for i, ul_branch in enumerate(ul_branches):
                u_logits = self(ul_branch)
                pseudo_labels = torch.softmax(u_logits.detach(), dim=-1)
                u_max_probs_2d[i], u_targets_2d[i] = torch.max(pseudo_labels,
                                                               dim=-1)

        if self.hparams.model.tsa_as_threshold:
            # Reuse the TSA threshold (its cube root) as the confidence cut-off.
            u_mask_2d = u_max_probs_2d.ge(
                torch.tensor(self.tsa.threshold) ** (1 / 3)).int()
        else:
            u_mask_2d = u_max_probs_2d.ge(self.hparams.model.threshold).int()

        # Per-instance mask: at least one branch must pass the threshold.
        u_mask = u_mask_2d.sum(dim=0) > 0
        u_max_probs, u_best_branches = torch.max(u_max_probs_2d, dim=0)

        u_loss = torch.tensor(0.0).type_as(l_loss)
        u_batch = []
        u_targets = []
        if u_mask.int().sum() > 0:
            # Creating one batch for unlabelled loss
            for i in range(len(ul_branches[0][0])):
                if u_mask[i]:
                    # All branches other than the most confident one,
                    # optionally restricted to those that disagree with it.
                    best = u_best_branches[i]
                    nonmax_branches = [
                        ul_branch
                        for ind, ul_branch in enumerate(ul_branches)
                        if ind != best
                        and (u_targets_2d[ind, i] != u_targets_2d[best, i]
                             or not self.hparams.model.choose_only_wrongly_predicted_branches)
                    ]
                    if len(nonmax_branches) > 0:
                        u_batch.extend([[item[i] for item in branch]
                                        for branch in nonmax_branches])
                        u_targets.extend(
                            u_targets_2d[best][i].repeat(len(nonmax_branches)))

            # Cap oversized unlabelled batches.
            max_ul = self.hparams.model.max_ul_batch_size_per_gpu
            if max_ul is not None:
                u_batch = u_batch[:max_ul]
                u_targets = u_targets[:max_ul]

            if len(u_batch) > 0:
                u_batch = [torch.stack(item) for item in zip(*u_batch)]
                u_targets = torch.stack(u_targets).long()

                # Unlabelled loss
                u_logits = self(u_batch)
                u_loss = F.cross_entropy(
                    u_logits,
                    u_targets,
                    reduction='mean',
                    weight=self.cross_entropy_weights.type_as(u_logits))

        # Train loss / labelled accuracy
        loss = l_loss + self.hparams.model.lambda_u * u_loss

        result = TrainResult(minimize=loss)
        result.log('train_loss',
                   loss.detach(),
                   on_epoch=True,
                   on_step=False,
                   sync_dist=True)
        result.log('train_loss_l',
                   l_loss.detach(),
                   on_epoch=True,
                   on_step=False,
                   sync_dist=True)
        result.log('train_loss_ul',
                   u_loss.detach(),
                   on_epoch=True,
                   on_step=False,
                   sync_dist=True)
        result.log('train_u_mask',
                   u_mask.float().mean(),
                   on_epoch=True,
                   on_step=False,
                   sync_dist=True)
        result.log('train_u_batch_size',
                   torch.tensor(len(u_targets)).float(),
                   on_epoch=True,
                   on_step=False,
                   sync_dist=True)
        if self.hparams.exp.tsa:
            result.log('tsa_threshold',
                       torch.tensor(self.tsa.threshold),
                       on_epoch=True,
                       on_step=False,
                       sync_dist=True)
            result.log('train_l_mask',
                       l_mask.float().mean(),
                       on_epoch=True,
                       on_step=False,
                       sync_dist=True)
        result.log_dict(self.calculate_metrics(l_logits.detach(),
                                               l_labels,
                                               prefix='train'),
                        on_epoch=True,
                        on_step=False,
                        sync_dist=True)
        return result
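This example reads its annealed threshold off self.tsa, an object the snippet never defines. Below is a hypothetical sketch of such a schedule, in the spirit of Training Signal Annealing from the UDA paper; the class name and fields are illustrative, not the code this example was written against:

    class LinearTSASchedule:
        """Linearly anneal the confidence threshold from 1/K up to 1.0."""

        def __init__(self, num_classes: int, total_steps: int):
            self.num_classes = num_classes
            self.total_steps = total_steps
            self.step_count = 0

        def step(self) -> None:
            self.step_count = min(self.step_count + 1, self.total_steps)

        @property
        def threshold(self) -> float:
            progress = self.step_count / self.total_steps
            return progress * (1 - 1 / self.num_classes) + 1 / self.num_classes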