Exemplo n.º 1
0
 def __init__(self):
     """Run the parent initialiser and create the ADF softmax layer.

     The layer is built over dimension 1 and receives the
     ``keep_variance_fn`` that is in scope where this method is defined.
     """
     super(SoftmaxHeteroscedasticLoss, self).__init__()
     softmax_layer = adf.Softmax(keep_variance_fn=keep_variance_fn, dim=1)
     self.adf_softmax = softmax_layer
Exemplo n.º 2
0
    def train_epoch(self,
                    train_loader,
                    model,
                    criterion,
                    optimizer,
                    epoch,
                    evaluator,
                    scheduler,
                    color_fn,
                    report=10,
                    show_scans=False):
        """Train ``model`` for one pass over ``train_loader``.

        For every batch this runs the forward pass, builds the loss, runs the
        backward pass and optimizer step, updates accuracy/IoU through
        ``evaluator``, optionally displays predictions, periodically prints
        and logs a progress line, and finally steps ``scheduler``.

        Args:
            train_loader: Iterable yielding 15-element batches; only the
                first three elements (input volume, projection mask and
                projected labels) are used here.
            model: Network to train; switched to train mode.
            criterion: Callable invoked as ``criterion(prediction, labels)``.
            optimizer: Optimizer that is zeroed and stepped on every batch.
                Its ``param_groups`` also provide the learning rate and the
                parameters used for the update-ratio statistic.
            epoch: Epoch index, used only for reporting.
            evaluator: Object providing ``reset``, ``addBatch``, ``getacc``
                and ``getIoU``.
            scheduler: Object whose ``step()`` is called once per batch.
            color_fn: Passed through to ``Trainer.make_log_img``.
            report: Unused; kept for interface compatibility. The reporting
                interval is read from ``self.ARCH["train"]["report_batch"]``.
            show_scans: When true, display up to two samples of each batch
                in an OpenCV window.

        Returns:
            Tuple ``(acc.avg, iou.avg, losses.avg, update_ratio_meter.avg,
            hetero_l.avg)``. The last meter is only updated when
            ``self.uncertainty`` is set.
        """
        losses = AverageMeter()
        acc = AverageMeter()
        iou = AverageMeter()
        hetero_l = AverageMeter()
        update_ratio_meter = AverageMeter()

        # empty the cache to train now
        if self.gpu:
            torch.cuda.empty_cache()

        # switch to train mode
        model.train()

        end = time.time()
        for i, (in_vol, proj_mask, proj_labels, _, _, _, _, _,
                _, _, _, _, _, _, _) in enumerate(train_loader):
            # measure data loading time
            self.data_time_t.update(time.time() - end)
            if not self.multi_gpu and self.gpu:
                in_vol = in_vol.cuda()
            if self.gpu:
                proj_labels = proj_labels.cuda()
            # Cast on every device so criterion, self.ls and the evaluator
            # all see the same integer labels (previously GPU-only).
            proj_labels = proj_labels.long()

            # compute output
            if self.uncertainty:
                output = model(in_vol)
                output_mean, _ = adf.Softmax(
                    dim=1, keep_variance_fn=keep_variance_fn)(*output)
                hetero = self.SoftmaxHeteroscedasticLoss(output, proj_labels)
                # NOTE(review): unlike the branch below, no torch.log is
                # applied before criterion here -- confirm this is intended.
                loss_m = (criterion(output_mean.clamp(min=1e-8), proj_labels)
                          + hetero
                          + self.ls(output_mean, proj_labels))

                hetero_l.update(hetero.mean().item(), in_vol.size(0))
                output = output_mean
            else:
                output = model(in_vol)
                loss_m = (criterion(torch.log(output.clamp(min=1e-8)),
                                    proj_labels)
                          + self.ls(output, proj_labels))

            optimizer.zero_grad()
            if self.n_gpus > 1:
                # one gradient seed per GPU for the non-scalar loss
                idx = torch.ones(self.n_gpus).cuda()
                loss_m.backward(idx)
            else:
                loss_m.backward()
            optimizer.step()

            # measure accuracy and record loss
            loss = loss_m.mean()
            with torch.no_grad():
                evaluator.reset()
                argmax = output.argmax(dim=1)
                evaluator.addBatch(argmax, proj_labels)
                accuracy = evaluator.getacc()
                jaccard, _ = evaluator.getIoU()

            batch_size = in_vol.size(0)
            losses.update(loss.item(), batch_size)
            acc.update(accuracy.item(), batch_size)
            iou.update(jaccard.item(), batch_size)

            # measure elapsed time
            self.batch_time_t.update(time.time() - end)
            end = time.time()

            # Ratio between the norm of the (lr-scaled) gradient step and the
            # norm of the weights, per parameter. Read from the optimizer that
            # was actually stepped above rather than from self.optimizer.
            update_ratios = []
            for g in optimizer.param_groups:
                lr = g["lr"]
                for value in g["params"]:
                    if value.grad is not None:
                        w = np.linalg.norm(
                            value.data.cpu().numpy().reshape((-1)))
                        update = np.linalg.norm(
                            -max(lr, 1e-10) *
                            value.grad.cpu().numpy().reshape((-1)))
                        update_ratios.append(update / max(w, 1e-10))
            update_ratios = np.array(update_ratios)
            update_mean = update_ratios.mean()
            update_std = update_ratios.std()
            update_ratio_meter.update(update_mean)  # over the epoch

            if show_scans:
                # Show the first two scans of the batch, or only the first
                # one when the batch holds a single sample.
                panels = [
                    Trainer.make_log_img(in_vol[b][0].cpu().numpy(),
                                         proj_mask[b].cpu().numpy(),
                                         argmax[b].cpu().numpy(),
                                         proj_labels[b].cpu().numpy(),
                                         color_fn)
                    for b in range(min(2, batch_size))
                ]
                out = np.concatenate(panels, axis=0)
                cv2.imshow("sample_training", out)
                cv2.waitKey(1)

            if i % self.ARCH["train"]["report_batch"] == 0:
                # The heteroscedastic term is reported only in uncertainty
                # mode; the rest of the line is the same in both modes.
                hetero_part = (
                    'Hetero {hetero_l.val:.4f} ({hetero_l.avg:.4f}) | '
                    if self.uncertainty else '')
                template = (
                    'Lr: {lr:.3e} | '
                    'Update: {umean:.3e} mean,{ustd:.3e} std | '
                    'Epoch: [{0}][{1}/{2}] | '
                    'Time {batch_time.val:.3f} ({batch_time.avg:.3f}) | '
                    'Data {data_time.val:.3f} ({data_time.avg:.3f}) | '
                    'Loss {loss.val:.4f} ({loss.avg:.4f}) | '
                    + hetero_part +
                    'acc {acc.val:.3f} ({acc.avg:.3f}) | '
                    'IoU {iou.val:.3f} ({iou.avg:.3f}) | [{estim}]')
                # Format once so the console and the log file receive the
                # very same line (including the same time estimate).
                message = template.format(
                    epoch,
                    i,
                    len(train_loader),
                    batch_time=self.batch_time_t,
                    data_time=self.data_time_t,
                    loss=losses,
                    hetero_l=hetero_l,
                    acc=acc,
                    iou=iou,
                    lr=lr,
                    umean=update_mean,
                    ustd=update_std,
                    estim=self.calculate_estimate(epoch, i))
                print(message)
                save_to_log(self.log, 'log.txt', message)

            # step scheduler
            scheduler.step()

        return (acc.avg, iou.avg, losses.avg, update_ratio_meter.avg,
                hetero_l.avg)