def test(model, loader, device):
    """Evaluate a weakly-supervised model on the test set.

    Computes classification metrics from the pooling head's logits and
    segmentation metrics by thresholding/argmaxing the pooling head's CAM
    against the ground-truth masks.

    Args:
        model: classifier with a `pooling` head exposing `predictions`,
            `loss`, `probabilities` and a `cam` attribute (the class
            activation map of the last forward pass).
        loader: DataLoader yielding (image, segmentation, label);
            batch size 1 is assumed (`.item()` on pred/label/loss).
        device: torch device to run inference on.

    Returns:
        dict of classification metrics (from `metric_report`) extended with
        per-image and aggregate dice/IoU; 'seg_preds' (per-image boolean
        masks) is included only for split 0 / fold 0 to limit output size.
    """
    model.eval()
    all_labels = []
    all_logits = []
    all_predictions = []
    all_losses = []
    all_seg_preds_interp = []
    all_dices = []
    all_ious = []
    # `evaluator` accumulates the confusion over the whole set (dataset-level
    # dice/IoU); `image_evaluator` is reset per image for per-image scores.
    evaluator = Evaluator(ex.current_run.config['model']['num_classes'])
    image_evaluator = Evaluator(ex.current_run.config['model']['num_classes'])
    pbar = tqdm(loader, ncols=80, desc='Test')
    with torch.no_grad():
        for image, segmentation, label in pbar:
            image = image.to(device)
            logits = model(image).cpu()
            pred = model.pooling.predictions(logits=logits).item()
            loss = model.pooling.loss(logits=logits, labels=label)
            # Binarize the ground-truth mask; caltech_birds masks are soft
            # values in [0, 1], other datasets use nonzero = foreground.
            if ex.current_run.config['dataset']['name'] == 'caltech_birds':
                segmentation_classes = (segmentation.squeeze() > 0.5)
            else:
                segmentation_classes = (segmentation.squeeze() != 0)
            # CAM produced by the pooling head during the forward pass above.
            seg_logits = model.pooling.cam
            # Upsample the CAM to ground-truth mask resolution.
            seg_logits_interp = F.interpolate(seg_logits,
                                              size=segmentation_classes.shape,
                                              mode='bilinear',
                                              align_corners=True).squeeze(0)
            label = label.item()
            all_labels.append(label)
            all_logits.append(logits)
            all_predictions.append(pred)
            all_losses.append(loss.item())
            # Turn the interpolated CAM into a hard foreground mask. The
            # deepmil variants threshold at uniform attention (1 / number of
            # CAM cells); other poolings take the argmax over class maps.
            if ex.current_run.config['dataset']['name'] == 'glas':
                if ex.current_run.config['model'][
                        'pooling'] == 'deepmil_multi':
                    seg_preds_interp = (seg_logits_interp[label] >
                                        (1 / seg_logits.numel())).cpu()
                else:
                    seg_preds_interp = (
                        seg_logits_interp.argmax(0) == label).cpu()
            else:
                if ex.current_run.config['model']['pooling'] == 'deepmil':
                    seg_preds_interp = (seg_logits_interp.squeeze(0) >
                                        (1 / seg_logits.numel())).cpu()
                elif ex.current_run.config['model'][
                        'pooling'] == 'deepmil_multi':
                    seg_preds_interp = (seg_logits_interp[label] >
                                        (1 / seg_logits.numel())).cpu()
                else:
                    seg_preds_interp = seg_logits_interp.argmax(0).cpu()
            # all_seg_probs_interp.append(seg_probs_interp.numpy())
            all_seg_preds_interp.append(
                seg_preds_interp.numpy().astype('bool'))
            evaluator.add_batch(segmentation_classes, seg_preds_interp)
            image_evaluator.add_batch(segmentation_classes, seg_preds_interp)
            # Index [1] is assumed to be the foreground class score —
            # TODO(review) confirm against Evaluator.dice().
            all_dices.append(image_evaluator.dice()[1].item())
            all_ious.append(
                image_evaluator.intersection_over_union()[1].item())
            image_evaluator.reset()
    all_logits = torch.cat(all_logits, 0)
    all_probabilities = model.pooling.probabilities(all_logits)
    metrics = metric_report(np.array(all_labels), all_probabilities.numpy(),
                            np.array(all_predictions))
    metrics['images_path'] = loader.dataset.samples
    metrics['labels'] = np.array(all_labels)
    metrics['logits'] = all_logits.numpy()
    metrics['probabilities'] = all_probabilities.numpy()
    metrics['predictions'] = np.array(all_predictions)
    metrics['losses'] = np.array(all_losses)
    metrics['dice_per_image'] = np.array(all_dices)
    metrics['mean_dice'] = metrics['dice_per_image'].mean()
    metrics['dice'] = evaluator.dice()[1].item()
    metrics['iou_per_image'] = np.array(all_ious)
    metrics['mean_iou'] = metrics['iou_per_image'].mean()
    metrics['iou'] = evaluator.intersection_over_union()[1].item()
    # Per-image masks are large; only persist them for the first split/fold.
    if ex.current_run.config['dataset'][
            'split'] == 0 and ex.current_run.config['dataset']['fold'] == 0:
        metrics['seg_preds'] = all_seg_preds_interp
    return metrics
def test(model, loader, device):
    """Evaluate a weakly-supervised model and save CAM visualizations.

    Like the plain test loop, but additionally: enables autograd for
    gradient-based CAM poolings, supports autoencoder backbones (saving
    reconstructions), writes per-image CAM visualizations to disk, pickles
    the predicted masks and saves full results via `save_results`.

    Args:
        model: wrapper exposing `backbone` and a `pooling` head with
            `predictions`, `loss`, `probabilities`, a `cam` attribute and,
            for CAM-evaluating poolings, an `eval_cams` flag.
        loader: DataLoader yielding (image, segmentation, label);
            batch size 1 is assumed (`.item()` calls).
        device: torch device to run inference on.

    Returns:
        dict of classification metrics extended with per-image and aggregate
        dice/IoU and the confusion matrix; 'seg_preds' is included only for
        split 0 / fold 0.
    """
    model.eval()
    all_labels = []
    all_logits = []
    all_predictions = []
    all_losses = []
    all_seg_logits_interp = []
    all_seg_preds_interp = []
    all_dices = []
    all_ious = []
    # Whole-set evaluator vs. per-image evaluator (reset each iteration).
    evaluator = Evaluator(ex.current_run.config['model']['num_classes'])
    image_evaluator = Evaluator(ex.current_run.config['model']['num_classes'])
    pbar = tqdm(loader, ncols=80, desc='Test')
    pooling = ex.current_run.config['model']['pooling']
    # Gradient-based CAM methods need autograd even at test time.
    # NOTE(review): torch.set_grad_enabled(True) also flips grad mode on at
    # construction time, but the context is entered immediately below, so
    # the effective behavior is the intended one.
    if pooling in requires_gradients:
        grad_policy = torch.set_grad_enabled(True)
    else:
        grad_policy = torch.no_grad()
    is_ae = isinstance(model.backbone, ResNet_AE)
    with grad_policy:
        for i, (image, segmentation, label) in enumerate(pbar):
            image, label = image.to(device), label.to(device)
            if pooling in requires_gradients or pooling == 'ablation':
                model.pooling.eval_cams = True
            if is_ae:
                # Autoencoder backbone returns latent code + reconstruction.
                z, x_reconst = model.backbone(image)
                logits = model.pooling(z)
            else:
                logits = model(image)
            pred = model.pooling.predictions(logits=logits).item()
            loss = model.pooling.loss(logits=logits, labels=label)
            # Binarize the ground-truth mask; caltech_birds masks are soft.
            if ex.current_run.config['dataset']['name'] == 'caltech_birds':
                segmentation_classes = (segmentation.squeeze() > 0.5)
            else:
                segmentation_classes = (segmentation.squeeze() != 0)
            seg_logits = model.pooling.cam.detach().cpu()
            # Upsample the CAM to ground-truth mask resolution.
            seg_logits_interp = F.interpolate(seg_logits,
                                              size=segmentation_classes.shape,
                                              mode='bilinear',
                                              align_corners=True).squeeze(0)
            label = label.item()
            all_labels.append(label)
            all_logits.append(logits.cpu())
            all_predictions.append(pred)
            all_losses.append(loss.item())
            # Hard foreground mask from the CAM: deepmil variants threshold
            # at uniform attention (1 / number of CAM cells), others argmax.
            # Uses the hoisted `pooling` value instead of re-reading the
            # config each time (same value, one lookup).
            if ex.current_run.config['dataset']['name'] == 'glas':
                if pooling == 'deepmil_multi':
                    seg_preds_interp = (seg_logits_interp[label] >
                                        (1 / seg_logits.numel())).cpu()
                else:
                    seg_preds_interp = (
                        seg_logits_interp.argmax(0) == label).cpu()
            else:
                if pooling == 'deepmil':
                    seg_preds_interp = (seg_logits_interp.squeeze(0) >
                                        (1 / seg_logits.numel())).cpu()
                elif pooling == 'deepmil_multi':
                    seg_preds_interp = (seg_logits_interp[label] >
                                        (1 / seg_logits.numel())).cpu()
                else:
                    seg_preds_interp = seg_logits_interp.argmax(0).cpu()
            # Save CAMs visualization
            save_dir = 'cams/{}/{}'.format(
                ex.current_run.config['model']['arch'] +
                str(ex.current_run.config['balance']),
                ex.current_run.config['model']['pooling'])
            os.makedirs(save_dir, exist_ok=True)
            file_path = os.path.join(save_dir, 'cam_{}.png'.format(i))
            # Normalize CAM to [0, 1]-ish range for visualization.
            seg_logits_interp_norm = seg_logits_interp / seg_logits_interp.max(
            )
            saliency_map_0, overlay_0 = visualize_cam(
                seg_logits_interp_norm[0], image)
            saliency_map_1, overlay_1 = visualize_cam(
                seg_logits_interp_norm[1], image)
            # Keep the overlay for the ground-truth class.
            overlay = [overlay_0, overlay_1][label]
            save_visualization(image.squeeze().cpu(),
                               segmentation_classes.numpy(), saliency_map_0,
                               saliency_map_1, overlay,
                               seg_preds_interp.numpy() * 255, label,
                               file_path)
            if is_ae:
                # Also persist the autoencoder reconstruction for inspection.
                x_reconst = x_reconst.detach()
                save_dir = 'reconst/{}/{}'.format(
                    ex.current_run.config['model']['arch'],
                    ex.current_run.config['model']['pooling'])
                os.makedirs(save_dir, exist_ok=True)
                file_path = os.path.join(save_dir, 'reconst_{}.png'.format(i))
                save_reconst(
                    image.squeeze(0).cpu(), x_reconst.squeeze(0).cpu(),
                    file_path)
            all_seg_logits_interp.append(seg_logits_interp.numpy())
            all_seg_preds_interp.append(
                seg_preds_interp.numpy().astype('bool'))
            evaluator.add_batch(segmentation_classes, seg_preds_interp)
            image_evaluator.add_batch(segmentation_classes, seg_preds_interp)
            all_dices.append(image_evaluator.dice()[1].item())
            all_ious.append(
                image_evaluator.intersection_over_union()[1].item())
            image_evaluator.reset()
    # Restore the pooling head's default (non-CAM) eval behavior.
    if pooling in requires_gradients or pooling == 'ablation':
        model.pooling.eval_cams = False
    all_logits = torch.cat(all_logits, 0)
    all_logits = all_logits.detach()
    all_probabilities = model.pooling.probabilities(all_logits)
    # FIX: ensure the output directory exists before dumping; previously this
    # raised FileNotFoundError when 'test/' was missing.
    os.makedirs('test', exist_ok=True)
    with open('test/gradcampp_seg_preds.pkl', 'wb') as f:
        pkl.dump(all_seg_preds_interp, f)
    results_dir = 'out/{}/{}'.format(
        ex.current_run.config['model']['arch'] +
        str(ex.current_run.config['balance']),
        ex.current_run.config['model']['pooling'])
    save_results(results_dir, loader.dataset.samples, np.array(all_labels),
                 np.array(all_predictions), all_seg_logits_interp,
                 all_seg_preds_interp, np.array(all_dices))
    metrics = metric_report(np.array(all_labels), all_probabilities.numpy(),
                            np.array(all_predictions))
    metrics['images_path'] = loader.dataset.samples
    metrics['labels'] = np.array(all_labels)
    metrics['logits'] = all_logits.numpy()
    metrics['probabilities'] = all_probabilities.numpy()
    metrics['predictions'] = np.array(all_predictions)
    metrics['losses'] = np.array(all_losses)
    metrics['dice_per_image'] = np.array(all_dices)
    metrics['mean_dice'] = metrics['dice_per_image'].mean()
    metrics['dice'] = evaluator.dice()[1].item()
    metrics['iou_per_image'] = np.array(all_ious)
    metrics['mean_iou'] = metrics['iou_per_image'].mean()
    metrics['iou'] = evaluator.intersection_over_union()[1].item()
    metrics['conf_mat'] = evaluator.cm.numpy()
    # Per-image masks are large; only persist them for the first split/fold.
    if ex.current_run.config['dataset'][
            'split'] == 0 and ex.current_run.config['dataset']['fold'] == 0:
        metrics['seg_preds'] = all_seg_preds_interp
    return metrics
def test_bma(self, dataset, n_predictions):
    """Evaluate with Bayesian model averaging over MC forward passes.

    Runs `n_predictions` stochastic forward passes per image (e.g. with
    MC dropout active in the model), averages the logits, and scores the
    averaged prediction. With a single prediction it falls back to the
    deterministic `evaluate`.

    Args:
        dataset: dataset yielding (image, mask, label, name); wrapped in a
            batch-size-1, unshuffled DataLoader.
        n_predictions: number of MC forward passes to average.

    Returns:
        dict of per-image and aggregate dice/IoU (background and
        foreground), losses and labels.
    """
    self.model.eval()
    loader = DataLoader(dataset, batch_size=1, shuffle=False)
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    with torch.no_grad():
        if n_predictions == 1:
            # Single pass == plain deterministic evaluation.
            return self.evaluate(loader=loader, test=True)
        else:
            all_labels = []
            all_losses = []
            all_dices = []
            all_ious = []
            # Whole-set evaluator vs. per-image evaluator (reset per image).
            evaluator = Evaluator(2)
            image_evaluator = Evaluator(2)
            for image, mask, lbl, name in tqdm(loader,
                                               ncols=80,
                                               desc='Test MC predictions'):
                image, mask = image.to(device), mask.squeeze(1).to(
                    device, non_blocking=True)
                class_masks = (mask != 0).long()
                # MC predictions: average logits over stochastic passes.
                preds = [self.model(image) for _ in range(n_predictions)]
                stack = torch.stack(preds, dim=-1)
                seg_logits = stack.mean(dim=-1)
                loss = self.get_loss(seg_logits, class_masks).item()
                seg_preds = seg_logits.argmax(1)
                evaluator.add_batch(class_masks, seg_preds)
                image_evaluator.add_batch(class_masks, seg_preds)
                dices = image_evaluator.dice()
                ious = image_evaluator.intersection_over_union()
                image_evaluator.reset()
                all_labels.append(lbl[0])
                all_losses.append(loss)
                all_dices.append(dices.cpu())
                all_ious.append(ious.cpu())
            all_labels = np.array(all_labels)
            all_losses = np.array(all_losses)
            # Rows: images; columns: [background, foreground] scores.
            all_dices = torch.stack(all_dices, 0)
            all_ious = torch.stack(all_ious, 0)
            dices = evaluator.dice()
            ious = evaluator.intersection_over_union()
            metrics = {
                'images_path': loader.dataset.rows,
                'labels': all_labels,
                'losses': all_losses,
                'dice_background_per_image': all_dices[:, 0].numpy(),
                'mean_dice_background': all_dices[:, 0].numpy().mean(),
                'dice_background': dices[0].item(),
                'dice_per_image': all_dices[:, 1].numpy(),
                'mean_dice': all_dices[:, 1].numpy().mean(),
                'dice': dices[1].item(),
                'iou_background_per_image': all_ious[:, 0].numpy(),
                'mean_iou_background': all_ious[:, 0].numpy().mean(),
                'iou_background': ious[0].item(),
                'iou_per_image': all_ious[:, 1].numpy(),
                'mean_iou': all_ious[:, 1].numpy().mean(),
                'iou': ious[1].item(),
            }
            return metrics
def evaluate(self, loader, test=False):
    """Deterministic evaluation pass over `loader`.

    Args:
        loader: DataLoader yielding (image, mask, label, file_name);
            batch size 1 is assumed (label[0], squeeze(0) below).
        test: only affects the progress-bar caption ('Test' vs
            'Validation').

    Returns:
        dict of per-image and aggregate dice/IoU (background and
        foreground), losses and labels.
    """
    self.model.eval()
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    all_labels = []
    all_losses = []
    all_dices = []
    all_ious = []
    # Whole-set evaluator vs. per-image evaluator (reset per image).
    evaluator = Evaluator(2)
    image_evaluator = Evaluator(2)
    pbar = tqdm(loader, ncols=80, desc='Test' if test else 'Validation')
    with torch.no_grad():
        for image, mask, label, f_name in pbar:
            image, mask = image.to(device), mask.squeeze(1).to(
                device, non_blocking=True)
            class_masks = (mask != 0).long()
            seg_logits = self.model(image)
            loss = self.get_loss(seg_logits, class_masks).item()
            seg_preds = seg_logits.argmax(1)
            evaluator.add_batch(class_masks, seg_preds)
            image_evaluator.add_batch(class_masks, seg_preds)
            dices = image_evaluator.dice()
            ious = image_evaluator.intersection_over_union()
            image_evaluator.reset()
            all_labels.append(label[0])
            all_losses.append(loss)
            all_dices.append(dices.cpu())
            all_ious.append(ious.cpu())
    all_labels = np.array(all_labels)
    all_losses = np.array(all_losses)
    # Rows: images; columns: [background, foreground] scores.
    all_dices = torch.stack(all_dices, 0)
    all_ious = torch.stack(all_ious, 0)
    dices = evaluator.dice()
    ious = evaluator.intersection_over_union()
    metrics = {
        'images_path': loader.dataset.rows,
        'labels': all_labels,
        'losses': all_losses,
        'dice_background_per_image': all_dices[:, 0].numpy(),
        'mean_dice_background': all_dices[:, 0].numpy().mean(),
        'dice_background': dices[0].item(),
        'dice_per_image': all_dices[:, 1].numpy(),
        'mean_dice': all_dices[:, 1].numpy().mean(),
        'dice': dices[1].item(),
        'iou_background_per_image': all_ious[:, 0].numpy(),
        'mean_iou_background': all_ious[:, 0].numpy().mean(),
        'iou_background': ious[0].item(),
        'iou_per_image': all_ious[:, 1].numpy(),
        'mean_iou': all_ious[:, 1].numpy().mean(),
        'iou': ious[1].item(),
    }
    return metrics
def evaluate(model, loader, device, test=False):
    """Evaluate a segmentation model (e.g. U-Net) on `loader`.

    Args:
        model: segmentation network producing per-pixel class logits.
        loader: DataLoader yielding (image, mask, label); batch size 1
            is assumed (`label.item()`, `squeeze(0)` below).
        device: torch device to run inference on.
        test: affects the progress-bar caption, and whether the per-image
            predicted masks are attached to the returned metrics.

    Returns:
        dict of per-image and aggregate dice/IoU (background and
        foreground), losses and labels; 'seg_preds' is included only on
        the test pass for split 0 / fold 0.
    """
    model.eval()
    all_labels = []
    all_losses = []
    all_seg_preds = []
    all_dices = []
    all_ious = []
    # Whole-set evaluator vs. per-image evaluator (reset per image).
    evaluator = Evaluator(2)
    image_evaluator = Evaluator(2)
    pbar = tqdm(loader, ncols=80, desc='Test' if test else 'Validation')
    with torch.no_grad():
        for image, mask, label in pbar:
            image = image.to(device, non_blocking=True)
            # Nonzero mask pixels are foreground.
            segmentation = (mask != 0).squeeze(1)
            t_segmentation = segmentation.to(device,
                                             non_blocking=True).long()
            seg_logits = model(image)
            loss = F.cross_entropy(seg_logits, t_segmentation).item()
            # FIX: removed unused `seg_probs = torch.softmax(seg_logits, 1)`
            # — it was computed every iteration and never read.
            seg_preds = seg_logits.argmax(1)
            evaluator.add_batch(t_segmentation, seg_preds)
            image_evaluator.add_batch(t_segmentation, seg_preds)
            dices = image_evaluator.dice()
            ious = image_evaluator.intersection_over_union()
            image_evaluator.reset()
            all_labels.append(label.item())
            all_losses.append(loss)
            all_dices.append(dices.cpu())
            all_ious.append(ious.cpu())
            all_seg_preds.append(
                seg_preds.squeeze(0).byte().cpu().numpy().astype('bool'))
    all_labels = np.array(all_labels)
    all_losses = np.array(all_losses)
    # Rows: images; columns: [background, foreground] scores.
    all_dices = torch.stack(all_dices, 0)
    all_ious = torch.stack(all_ious, 0)
    dices = evaluator.dice()
    ious = evaluator.intersection_over_union()
    metrics = {
        'images_path': loader.dataset.samples,
        'labels': all_labels,
        'losses': all_losses,
        'dice_background_per_image': all_dices[:, 0].numpy(),
        'mean_dice_background': all_dices[:, 0].numpy().mean(),
        'dice_background': dices[0].item(),
        'dice_per_image': all_dices[:, 1].numpy(),
        'mean_dice': all_dices[:, 1].numpy().mean(),
        'dice': dices[1].item(),
        'iou_background_per_image': all_ious[:, 0].numpy(),
        'mean_iou_background': all_ious[:, 0].numpy().mean(),
        'iou_background': ious[0].item(),
        'iou_per_image': all_ious[:, 1].numpy(),
        'mean_iou': all_ious[:, 1].numpy().mean(),
        'iou': ious[1].item(),
    }
    # Per-image masks are large; only persist on test for the first
    # split/fold.
    if test and ex.current_run.config['dataset'][
            'split'] == 0 and ex.current_run.config['dataset']['fold'] == 0:
        metrics['seg_preds'] = all_seg_preds
    return metrics
def main(epochs, seed):
    """Train a U-Net, select the best checkpoint on validation loss, and
    evaluate it on the test set.

    Args:
        epochs: number of training epochs.
        seed: torch manual seed (cudnn is also set deterministic).

    Returns:
        float: mean per-image test dice of the best (lowest validation
        loss) model.
    """
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    cudnn.deterministic = True
    torch.manual_seed(seed)
    train_loader, valid_loader, test_loader = load_dataset()
    model = load_unet()
    model = torch.nn.DataParallel(model)
    model.to(device)
    optimizer, scheduler = get_optimizer_scheduler(
        parameters=model.parameters())
    train_losses = AverageMeter()
    batch_evaluator = Evaluator(2)
    train_dices_background = AverageMeter()
    train_dices = AverageMeter()
    # Model selection is on validation loss; removed the dead
    # `best_valid_dice` variable (assigned but never read).
    best_valid_loss = float('inf')
    best_model_dict = deepcopy(model.module.state_dict())
    for epoch in range(epochs):
        model.train()
        # FIX: plain statements instead of a throwaway tuple expression.
        train_losses.reset()
        train_dices_background.reset()
        train_dices.reset()
        loader_length = len(train_loader)
        pbar = tqdm(train_loader, ncols=80, desc='Training')
        start = time.time()
        for i, (images, mask, label) in enumerate(pbar):
            images, mask = images.to(device), mask.squeeze(1).to(
                device, non_blocking=True)
            seg_logits = model(images)
            # caltech_birds masks are soft values in [0, 1]; others are
            # nonzero = foreground.
            if ex.current_run.config['dataset']['name'] == 'caltech_birds':
                class_mask = (mask > 0.5).long()
            else:
                class_mask = (mask != 0).long()
            loss = F.cross_entropy(seg_logits, class_mask)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Per-batch dice for logging only.
            batch_evaluator.add_batch(class_mask, seg_logits.argmax(1))
            dices = batch_evaluator.dice()
            dice_background = dices[0].item()
            dice = dices[1].item()
            batch_evaluator.reset()
            loss = loss.item()
            # Fractional epoch as the x-axis of training curves.
            step = epoch + i / loader_length
            ex.log_scalar('training.loss', loss, step)
            ex.log_scalar('training.mean_dice_background', dice_background,
                          step)
            ex.log_scalar('training.mean_dice', dice, step)
            train_losses.append(loss)
            train_dices_background.append(dice_background)
            train_dices.append(dice)
        scheduler.step()
        duration = time.time() - start
        # evaluate on validation set
        valid_metrics = evaluate(model=model,
                                 loader=valid_loader,
                                 device=device)
        # Keep the checkpoint with the lowest mean validation loss.
        if valid_metrics['losses'].mean() <= best_valid_loss:
            best_valid_loss = valid_metrics['losses'].mean()
            best_model_dict = deepcopy(model.module.state_dict())
        ex.log_scalar('validation.loss', np.mean(valid_metrics['losses']),
                      epoch + 1)
        ex.log_scalar('validation.mean_dice', valid_metrics['mean_dice'],
                      epoch + 1)
        print('Epoch {:02d} | Duration: {:.1f}s - per batch ({}): {:.3f}s'.
              format(epoch, duration, loader_length,
                     duration / loader_length))
        print(
            ' ' * 8,
            '| Train loss: {:.4f} dice(b): {:.3f} dice: {:.3f}'.format(
                train_losses.avg, train_dices_background.avg,
                train_dices.avg))
        print(
            ' ' * 8,
            '| Valid loss: {:.4f} dice(b): {:.3f} dice: {:.3f}'.format(
                valid_metrics['losses'].mean(),
                valid_metrics['mean_dice_background'],
                valid_metrics['mean_dice']))
    # load best model based on validation loss
    model = load_unet()
    model.load_state_dict(best_model_dict)
    model.to(device)
    # evaluate on test set
    test_metrics = evaluate(model=model,
                            loader=test_loader,
                            device=device,
                            test=True)
    ex.log_scalar('test.loss', test_metrics['losses'].mean(), epochs)
    ex.log_scalar('test.mean_dice_background',
                  test_metrics['mean_dice_background'], epochs)
    ex.log_scalar('test.mean_dice', test_metrics['mean_dice'], epochs)
    # save model
    save_name = get_save_name() + '.pickle'
    torch.save(state_dict_to_cpu(best_model_dict), save_name)
    ex.add_artifact(os.path.abspath(save_name))
    # save test metrics
    if len(ex.current_run.observers) > 0:
        dataset = ex.current_run.config['dataset']['name']
        split = ex.current_run.config['dataset']['split']
        fold = ex.current_run.config['dataset']['fold']
        torch.save(
            test_metrics,
            os.path.join(
                ex.current_run.observers[0].dir,
                '{}_unet_split-{}_fold-{}.pkl'.format(dataset, split, fold)))
    # metrics to info.json
    info_to_save = [
        'labels', 'losses', 'dice_per_image', 'mean_dice', 'dice',
        'iou_per_image', 'mean_iou', 'iou'
    ]
    for k in info_to_save:
        ex.info[k] = test_metrics[k]
    return test_metrics['mean_dice']