Example #1
def plot_col(fig, col, img_idx, device, models, edge_width, color_gt, color_model):
    datamodule = util.load_datamodule_for_model(models[0], batch_size=1)
    x, ys = load_sample(datamodule, idx=img_idx, device=device)
    y_mean = torch.stack(ys).float().mean(dim=0).round().long()

    # plot image
    fig.plot_img(0, col, add_edge(x[0], width=edge_width), vmin=0, vmax=1)
    # plot gt seg outline
    fig.plot_contour(0, col, y_mean[0],
                     contour_class=1, width=2, rgba=color_gt)

    # plot model predictions
    for row, model in enumerate(models):
        pl.seed_everything(42)
        with torch.no_grad():
            p = model.pixel_wise_probabaility(x, sample_cnt=16)
            _, y_pred_mean = p.max(dim=1, keepdim=True)
            uncertainty = util.entropy(p)

        # plot uncertainty heatmap
        fig.plot_overlay(row + 1, col,
                         add_edge(uncertainty[0], c=1, width=edge_width),
                         alpha=1, vmin=0, vmax=1, cmap='Greys')

        # plot gt seg outline
        fig.plot_contour(
            row + 1, col, y_mean[0], contour_class=1, width=2, rgba=color_gt)

        # plot model prediction outline
        fig.plot_contour(row + 1, col, y_pred_mean[0],
                         contour_class=1, width=2, rgba=color_model)
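
Note: add_edge is a helper defined elsewhere in this codebase. A minimal sketch of its assumed behavior, namely painting a constant-valued border of the given width around a C x H x W tensor so adjacent panels stay visually separated; the default border value here is an assumption:

import torch

def add_edge(img: torch.Tensor, width: int = 1, c: float = 0.0) -> torch.Tensor:
    """Return a copy of img (C x H x W) with a border of value c."""
    out = img.clone()
    out[:, :width, :] = c   # top rows
    out[:, -width:, :] = c  # bottom rows
    out[:, :, :width] = c   # left columns
    out[:, :, -width:] = c  # right columns
    return out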
Example #2
def pixel_wise_uncertainty(self, x, sample_cnt=16):
    """Return the pixel-wise entropy.

    Args:
        x: the input
        sample_cnt (optional): number of samples to draw for the internal approximation

    Returns:
        tensor: B x 1 x H x W
    """
    p = self.pixel_wise_probabaility(x, sample_cnt=sample_cnt)
    return util.entropy(p)
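
All of these examples lean on util.entropy. For the tensor path used here (B x C x H x W probabilities in, B x 1 x H x W entropy out, per the docstring above), a minimal sketch; the normalization by log(C), which would map a uniform distribution to 1 and match the vmin=0/vmax=1 plots in Example #1, is an assumption:

import math
import torch

def entropy(p: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """p: B x C x H x W class probabilities -> B x 1 x H x W entropy."""
    h = -(p * (p + eps).log()).sum(dim=1, keepdim=True)
    return h / math.log(p.shape[1])  # assumed normalization to [0, 1]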
Example #3
def validate(val_loader,
             model,
             criterion,
             avu_criterion,
             epoch,
             tb_writer=None):
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    avg_unc = AverageMeter()
    global opt_th

    # switch to evaluate mode
    model.eval()

    end = time.time()
    preds_list = []
    labels_list = []
    unc_list = []
    th_list = np.linspace(0, 1, 21)
    with torch.no_grad():
        for i, (input, target) in enumerate(val_loader):
            if torch.cuda.is_available():
                target = target.cuda()
                input_var = input.cuda()
            else:
                input_var = input
            target_var = target

            if args.half:
                input_var = input_var.half()

            output, kl = model(input_var)
            probs_ = torch.nn.functional.softmax(output, dim=1)
            probs = probs_.data.cpu().numpy()

            pred_entropy = util.entropy(probs)
            unc = np.mean(pred_entropy, axis=0)
            preds = np.argmax(probs, axis=-1)
            preds_list.append(preds)
            labels_list.append(target.cpu().data.numpy())
            unc_list.append(pred_entropy)

            cross_entropy_loss = criterion(output, target_var)
            scaled_kl = kl.data / len_trainset
            elbo_loss = cross_entropy_loss + scaled_kl
            avu_loss, auc_avu = avu_criterion(output, target_var, type=0)
            avu_loss = beta * avu_loss
            loss = cross_entropy_loss + scaled_kl + avu_loss

            output = output.float()
            loss = loss.float()

            # measure accuracy and record loss
            prec1 = accuracy(output.data, target)[0]
            losses.update(loss.item(), input.size(0))
            top1.update(prec1.item(), input.size(0))
            avg_unc.update(unc, input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Avg_Unc {avg_unc.val:.3f} ({avg_unc.avg:.3f})'.format(
                          i,
                          len(val_loader),
                          batch_time=batch_time,
                          loss=losses,
                          top1=top1,
                          avg_unc=avg_unc))

            if tb_writer is not None:
                tb_writer.add_scalar('val/cross_entropy_loss',
                                     cross_entropy_loss.item(), epoch)
                tb_writer.add_scalar('val/kl_div', scaled_kl.item(), epoch)
                tb_writer.add_scalar('val/elbo_loss', elbo_loss.item(), epoch)
                tb_writer.add_scalar('val/avu_loss', avu_loss, epoch)
                tb_writer.add_scalar('val/loss', loss.item(), epoch)
                tb_writer.add_scalar('val/AUC-AvU', auc_avu, epoch)
                tb_writer.add_scalar('val/accuracy', prec1.item(), epoch)
                tb_writer.flush()

        # hstack the batch lists directly; wrapping them in np.asarray first
        # fails on a ragged final batch
        preds = np.hstack(preds_list)
        labels = np.hstack(labels_list)
        unc_ = np.hstack(unc_list)
        avu_th, unc_th = util.eval_avu(preds, labels, unc_)
        print('max AvU: ', np.amax(avu_th))
        unc_correct = np.take(unc_, np.where(preds == labels))
        unc_incorrect = np.take(unc_, np.where(preds != labels))
        print('avg unc correct preds: ', np.mean(unc_correct, axis=1))
        print('avg unc incorrect preds: ', np.mean(unc_incorrect, axis=1))
        '''
        print('unc @max AvU: ', unc_th[np.argmax(avu_th)])
        print('avg unc: ', np.mean(unc_, axis=0))
        print('avg unc: ', np.mean(unc_th, axis=0))
        print('min unc: ', np.amin(unc_))
        print('max unc: ', np.amax(unc_))
        '''
        if epoch <= 5:
            opt_th = (np.mean(unc_correct, axis=1) +
                      np.mean(unc_incorrect, axis=1)) / 2

    print('opt_th: ', opt_th)
    print(' * Prec@1 {top1.avg:.3f}'.format(top1=top1))
    return top1.avg
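
Examples #3 and #4 call util.entropy on NumPy softmax outputs rather than tensors, reducing B x C probabilities to one entropy value per sample. A minimal sketch under that assumption:

import numpy as np

def entropy(probs: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """probs: (B, C) class probabilities -> (B,) per-sample entropy."""
    return -np.sum(probs * np.log(probs + eps), axis=-1)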
Example #4
def train(train_loader,
          model,
          criterion,
          avu_criterion,
          optimizer,
          epoch,
          tb_writer=None):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    avg_unc = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()
    for i, (input, target) in enumerate(train_loader):

        # measure data loading time
        data_time.update(time.time() - end)

        if torch.cuda.is_available():
            target = target.cuda()
            input_var = input.cuda()
        else:
            input_var = input
        target_var = target
        if args.half:
            input_var = input_var.half()

        optimizer.zero_grad()

        output, kl = model(input_var)
        probs_ = torch.nn.functional.softmax(output, dim=1)
        probs = probs_.data.cpu().numpy()

        pred_entropy = util.entropy(probs)
        unc = np.mean(pred_entropy, axis=0)
        preds = np.argmax(probs, axis=-1)

        cross_entropy_loss = criterion(output, target_var)
        scaled_kl = kl.data / len_trainset
        elbo_loss = cross_entropy_loss + scaled_kl
        avu_loss, auc_avu = avu_criterion(output, target_var, type=0)
        avu_loss = beta * avu_loss
        loss = cross_entropy_loss + scaled_kl + avu_loss

        # compute gradient and do SGD step
        loss.backward()
        optimizer.step()

        output = output.float()
        loss = loss.float()
        # measure accuracy and record loss
        prec1 = accuracy(output.data, target)[0]
        losses.update(loss.item(), input.size(0))
        top1.update(prec1.item(), input.size(0))
        avg_unc.update(unc, input.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            #print('opt_th: ', opt_th)
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Avg_Unc {avg_unc.val:.3f} ({avg_unc.avg:.3f})'.format(
                      epoch,
                      i,
                      len(train_loader),
                      batch_time=batch_time,
                      data_time=data_time,
                      loss=losses,
                      top1=top1,
                      avg_unc=avg_unc))

        if tb_writer is not None:
            tb_writer.add_scalar('train/cross_entropy_loss',
                                 cross_entropy_loss.item(), epoch)
            tb_writer.add_scalar('train/kl_div', scaled_kl.item(), epoch)
            tb_writer.add_scalar('train/elbo_loss', elbo_loss.item(), epoch)
            tb_writer.add_scalar('train/avu_loss', avu_loss, epoch)
            tb_writer.add_scalar('train/loss', loss.item(), epoch)
            tb_writer.add_scalar('train/AUC-AvU', auc_avu, epoch)
            tb_writer.add_scalar('train/accuracy', prec1.item(), epoch)
            tb_writer.flush()
Example #5
def make_fig(args):
    # set up
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = util.load_model_from_checkpoint(args.model_path).to(device)
    datamodule = util.load_datamodule_for_model(model, batch_size=1)

    for idx in tqdm(range(100), desc='Generating Images'):
        x, y = load_sample(datamodule, idx=idx, device=device)
        fig = Fig(
            rows=1,
            cols=2,
            title=None,
            figsize=None,
            background=True,
        )
        colors = np.array(sns.color_palette("Paired")) * 255
        color_gt = colors[1]
        color_model = colors[7]

        # draw samples
        pl.seed_everything(42)
        with torch.no_grad():
            p = model.pixel_wise_probabaility(x, sample_cnt=args.samples)
            _, y_pred = p.max(dim=1, keepdim=True)
            uncertainty = util.entropy(p)

        # plot image
        fig.plot_img(0, 0, x[0], vmin=0, vmax=1)

        # plot uncertainty heatmap
        fig.plot_overlay(0, 1, uncertainty[0], alpha=1, vmin=None, vmax=None,
                         cmap='Greys', colorbar=True,
                         colorbar_label="Model Uncertainty")

        # plot model prediction outline
        fig.plot_contour(0, 0, y_pred[0], contour_class=1, width=2,
                         rgba=color_model)
        fig.plot_contour(0, 1, y_pred[0], contour_class=1, width=2,
                         rgba=color_model)

        # plot gt seg outline
        fig.plot_contour(0, 0, y[0], contour_class=1, width=2, rgba=color_gt)
        fig.plot_contour(0, 1, y[0], contour_class=1, width=2, rgba=color_gt)

        # add legend
        from matplotlib import pyplot as plt
        from matplotlib.patches import Rectangle
        legend_data = [
            [0, color_gt, "GT Annotation"],
            [1, color_model, "Model Prediction"], ]
        handles = [
            Rectangle((0, 0), 1, 1, color=[v/255 for v in c]) for k, c, n in legend_data
        ]
        labels = [n for k, c, n in legend_data]

        plt.legend(handles, labels, ncol=len(legend_data))

        os.makedirs("./plots/", exist_ok=True)
        # fig.save(args.output_file)
        os.makedirs(args.output_folder, exist_ok=True)
        fig.save(os.path.join(args.output_folder,
                              f'test_{idx}.png'), close=False)
        fig.save(os.path.join(args.output_folder, f'test_{idx}.pdf'))
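
A minimal sketch of how make_fig might be invoked; the flag names mirror the attributes the function reads (args.model_path, args.samples, args.output_folder), and the defaults are assumptions:

import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', required=True,
                        help='checkpoint for util.load_model_from_checkpoint')
    parser.add_argument('--samples', type=int, default=16,
                        help='sample count for pixel_wise_probabaility')
    parser.add_argument('--output_folder', default='./plots/',
                        help='where test_<idx>.png / .pdf files are written')
    make_fig(parser.parse_args())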
Example #6
def train(train_loader, model, criterion, avu_criterion, optimizer, epoch,
          args, tb_writer):
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    global opt_th
    progress = ProgressMeter(len(train_loader),
                             [batch_time, data_time, losses, top1, top5],
                             prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    # per-epoch buffers for the opt_th update at the end of the loop
    preds_list = []
    labels_list = []
    unc_list = []

    end = time.time()
    for i, (images, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        if torch.cuda.is_available():
            images = images.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)
        # otherwise the loader tensors are already on the CPU, and
        # Tensor.cpu() takes no non_blocking argument, so no else branch

        # compute output
        output, kl = model(images)
        probs_ = torch.nn.functional.softmax(output, dim=1)
        probs = probs_.data.cpu().numpy()

        pred_entropy = util.entropy(probs)
        preds = np.argmax(probs, axis=-1)
        AvU = util.accuracy_vs_uncertainty(np.array(preds),
                                           np.array(target.cpu().data.numpy()),
                                           np.array(pred_entropy), opt_th)

        preds_list.append(preds)
        labels_list.append(target.cpu().data.numpy())
        unc_list.append(pred_entropy)

        cross_entropy_loss = criterion(output, target)
        scaled_kl = (kl.data[0] / len_trainset)
        elbo_loss = cross_entropy_loss + scaled_kl
        avu_loss = beta * avu_criterion(output, target, opt_th, type=0)
        loss = cross_entropy_loss + scaled_kl + avu_loss

        output = output.float()
        loss = loss.float()

        # measure accuracy and record loss
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), images.size(0))
        top1.update(acc1[0], images.size(0))
        top5.update(acc5[0], images.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.mean().backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)

        if tb_writer is not None:
            tb_writer.add_scalar('train/cross_entropy_loss',
                                 cross_entropy_loss.item(), epoch)
            tb_writer.add_scalar('train/kl_div', scaled_kl.item(), epoch)
            tb_writer.add_scalar('train/elbo_loss', elbo_loss.item(), epoch)
            tb_writer.add_scalar('train/avu_loss', avu_loss.item(), epoch)
            tb_writer.add_scalar('train/loss', loss.item(), epoch)
            tb_writer.add_scalar('train/AvU', AvU, epoch)
            tb_writer.add_scalar('train/accuracy', acc1.item(), epoch)
            tb_writer.flush()

    # hstack the batch lists directly; wrapping them in np.asarray first
    # fails on a ragged final batch
    preds = np.hstack(preds_list)
    labels = np.hstack(labels_list)
    unc_ = np.hstack(unc_list)
    unc_correct = np.take(unc_, np.where(preds == labels))
    unc_incorrect = np.take(unc_, np.where(preds != labels))
    #print('avg unc correct preds: ', np.mean(np.take(unc_,np.where(preds == labels)), axis=1))
    #print('avg unc incorrect preds: ', np.mean(np.take(unc_,np.where(preds != labels)), axis=1))
    if epoch <= 1:
        opt_th = (np.mean(unc_correct, axis=1) +
                  np.mean(unc_incorrect, axis=1)) / 2

    print('opt_th: ', opt_th)
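
This training loop logs an AvU score via util.accuracy_vs_uncertainty using the opt_th threshold computed above. The accuracy-vs-uncertainty metric counts each prediction as accurate/inaccurate and certain/uncertain at the threshold, then takes AvU = (n_AC + n_IU) / (n_AC + n_AU + n_IC + n_IU). A minimal sketch, assuming array inputs as in the call above:

import numpy as np

def accuracy_vs_uncertainty(preds, labels, unc, threshold):
    accurate = preds == labels
    uncertain = unc >= threshold
    n_ac = np.sum(accurate & ~uncertain)   # accurate and certain
    n_au = np.sum(accurate & uncertain)    # accurate but uncertain
    n_ic = np.sum(~accurate & ~uncertain)  # inaccurate but certain
    n_iu = np.sum(~accurate & uncertain)   # inaccurate and uncertain
    return (n_ac + n_iu) / (n_ac + n_au + n_ic + n_iu)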