Example #1
    def training_step(self, batch, batch_idx):
        """
        Lightning calls this inside the training loop
        :param batch:
        :return:
        """
        try:
            # forward pass
            inputs, gt_labels, paths = batch
            self.crt_batch_idx = batch_idx
            self.inputs = inputs
            if self.cfg.mixup.enable:
                inputs, gt_labels_a, gt_labels_b, lam = mixup_data(
                    inputs, gt_labels, self.cfg.mixup.alpha)
                mixup_y = [gt_labels_a, gt_labels_b, lam]
            predictions = self.forward(inputs)

            # calculate loss
            if self.cfg.mixup.enable:
                loss_val = mixup_loss_fn(self.loss, predictions, *mixup_y)
            else:
                loss_val = self.loss(predictions, gt_labels)

            # acc
            acc_results = topk_acc(predictions, gt_labels, self.cfg.topk)
            tqdm_dict = {}

            if self.on_gpu:
                acc_results = [
                    torch.tensor(x).to(loss_val.device.index)
                    for x in acc_results
                ]

            # in DP/DDP2 mode make sure a scalar result gets an extra leading dim so it can be gathered across devices
            if self.trainer.use_dp or self.trainer.use_ddp2:
                loss_val = loss_val.unsqueeze(0)
                acc_results = [x.unsqueeze(0) for x in acc_results]

            tqdm_dict['train_loss'] = loss_val
            for i, k in enumerate(self.cfg.topk):
                tqdm_dict[f'train_acc_{k}'] = acc_results[i]

            output = OrderedDict({
                'loss': loss_val,
                'progress_bar': tqdm_dict,
                'log': tqdm_dict
            })

            self.train_meters.update(
                {key: val.item()
                 for key, val in tqdm_dict.items()})

            # can also return just a scalar instead of a dict (return loss_val)
            return output
        except Exception as e:
            # Log the failing batch and skip it; training_step then implicitly returns None.
            print(str(e))
            print(batch_idx, paths)
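
Example #1 relies on mixup_data and mixup_loss_fn helpers that are not shown. A minimal sketch of what they typically look like, following the standard mixup formulation (the project's actual helpers may differ):

import numpy as np
import torch


def mixup_data(inputs, targets, alpha=1.0):
    # Blend the batch with a shuffled copy of itself; lam ~ Beta(alpha, alpha).
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1.0
    index = torch.randperm(inputs.size(0), device=inputs.device)
    mixed_inputs = lam * inputs + (1.0 - lam) * inputs[index]
    return mixed_inputs, targets, targets[index], lam


def mixup_loss_fn(loss_fn, predictions, targets_a, targets_b, lam):
    # Interpolate the loss between the two label sets with the same lam.
    return lam * loss_fn(predictions, targets_a) + (1.0 - lam) * loss_fn(predictions, targets_b)
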
Example #2
    def training_step(self, batch, batch_idx):
        """
        Lightning calls this inside the training loop
        :param batch:
        :return:
        """

        # forward pass
        inputs, gt_labels = batch
        predictions = self.forward(inputs)

        # calculate loss
        loss_val = self.loss(predictions, gt_labels)

        # acc
        acc_results = topk_acc(predictions, gt_labels, self.cfg.topk)
        tqdm_dict = {}

        if self.on_gpu:
            acc_results = [
                torch.tensor(x).to(loss_val.device.index) for x in acc_results
            ]

        # in DP/DDP2 mode make sure a scalar result gets an extra leading dim so it can be gathered across devices
        if self.trainer.use_dp or self.trainer.use_ddp2:
            loss_val = loss_val.unsqueeze(0)
            acc_results = [x.unsqueeze(0) for x in acc_results]

        tqdm_dict['train_loss'] = loss_val
        for i, k in enumerate(self.cfg.topk):
            tqdm_dict[f'train_acc_{k}'] = acc_results[i]

        output = OrderedDict({
            'loss': loss_val,
            'progress_bar': tqdm_dict,
            'log': tqdm_dict
        })

        self.train_meters.update(
            {key: val.item()
             for key, val in tqdm_dict.items()})
        # self.print_log(batch_idx, True, inputs, self.train_meters)

        # can also return just a scalar instead of a dict (return loss_val)
        return output
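
Both training steps call a topk_acc helper whose implementation is not shown. Assuming it mirrors the usual top-k accuracy utility, it could look roughly like this (a sketch under that assumption, not the project's actual code):

import torch


def topk_acc(predictions, gt_labels, topk=(1, 5)):
    # Fraction of samples whose true label is among the k highest-scoring classes, for each k.
    maxk = max(topk)
    _, pred = predictions.topk(maxk, dim=1, largest=True, sorted=True)   # (batch, maxk)
    correct = pred.t().eq(gt_labels.view(1, -1).expand_as(pred.t()))     # (maxk, batch) bool
    batch_size = gt_labels.size(0)
    return [correct[:k].reshape(-1).float().sum() / batch_size for k in topk]
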
Example #3
    def validation_step(self, batch, batch_idx, test=False):
        """
        Lightning calls this inside the validation loop
        :param batch:
        :return:
        """
        inputs, gt_labels, paths = batch
        self.inputs = inputs
        predictions = self.forward(inputs)
        _, pred = torch.max(predictions, dim=1)

        loss_val = self.loss(predictions, gt_labels)

        # acc
        val_acc_1, val_acc_k = topk_acc(predictions, gt_labels, self.cfg.topk)

        if self.on_gpu:
            val_acc_1 = val_acc_1.cuda(loss_val.device.index)
            val_acc_k = val_acc_k.cuda(loss_val.device.index)

        # in DP/DDP2 mode make sure a scalar result gets an extra leading dim so it can be gathered across devices
        if self.trainer.use_dp or self.trainer.use_ddp2:
            loss_val = loss_val.unsqueeze(0)
            val_acc_1 = val_acc_1.unsqueeze(0)
            val_acc_k = val_acc_k.unsqueeze(0)

        output = OrderedDict({
            'valid_loss': loss_val,
            'valid_acc_1': val_acc_1,
            f'valid_acc_{self.cfg.topk[-1]}': val_acc_k,
        })
        tqdm_dict = dict(output)
        self.valid_meters.update(
            {key: val.item()
             for key, val in tqdm_dict.items()})
        # self.print_log(batch_idx, False, inputs, self.valid_meters)

        if self.cfg.module.analyze_result:
            output.update({
                'predictions': predictions.detach(),
                'gt_labels': gt_labels.detach(),
            })
        # can also return just a scalar instead of a dict (return loss_val)
        return output
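
These steps assume an older PyTorch Lightning release (the use_dp/use_ddp2 trainer flags and the 'progress_bar'/'log' return keys were removed in later versions). A hypothetical wiring of such a module with the Trainer; MyClassifier and cfg are placeholder names, not identifiers from the examples:

import pytorch_lightning as pl

model = MyClassifier(cfg)                     # LightningModule defining the steps above
trainer = pl.Trainer(gpus=1, max_epochs=30)   # older-style Trainer arguments
trainer.fit(model)
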