def test_ill_opts(self): with self.assertRaisesRegex(ValueError, ""): GeneralizedDiceLoss(sigmoid=True, softmax=True) chn_input = torch.ones((1, 1, 3)) chn_target = torch.ones((1, 1, 3)) with self.assertRaisesRegex(ValueError, ""): GeneralizedDiceLoss(reduction="unknown")(chn_input, chn_target) with self.assertRaisesRegex(ValueError, ""): GeneralizedDiceLoss(reduction=None)(chn_input, chn_target)
def test_result_onehot_target_include_bg(self): size = [3, 3, 5, 5] label = torch.randint(low=0, high=2, size=size) pred = torch.randn(size) for reduction in ["sum", "mean", "none"]: common_params = { "include_background": True, "to_onehot_y": False, "reduction": reduction } for focal_weight in [ None, torch.tensor([1.0, 1.0, 2.0]), (3, 2.0, 1) ]: for lambda_focal in [0.5, 1.0, 1.5]: generalized_dice_focal = GeneralizedDiceFocalLoss( focal_weight=focal_weight, gamma=1.0, lambda_focal=lambda_focal, **common_params) generalized_dice = GeneralizedDiceLoss(**common_params) focal = FocalLoss(weight=focal_weight, gamma=1.0, **common_params) result = generalized_dice_focal(pred, label) expected_val = generalized_dice( pred, label) + lambda_focal * focal(pred, label) np.testing.assert_allclose(result, expected_val)
def test(): from main import opt net = UNet3D(in_channels=1, out_channels=2) net.to(device) net.load_state_dict(torch.load(opt.load_path)) # Loss function criterion = GeneralizedDiceLoss(to_onehot_y=True, softmax=True) # Test data data_test = BalancedPatchGenerator(opt.path, None, positive_prob=None, shuffle_images=False, mode='test', transform=None, large=opt.large) data_loader_test = DataLoader(data_test, batch_size=opt.batch, num_workers=4) if not os.path.exists('test_results/' + opt.load_path[:-4]): os.makedirs('test_results/' + opt.load_path[:-4]) inference(data_loader_test, net, criterion, opt.load_path[:-4])
def _define_training_loss(self) -> None: """Define the loss function.""" self._training_loss_function = GeneralizedDiceLoss( include_background=True, to_onehot_y=False, softmax=True, batch=True, )
def test_batch(self): prediction = torch.zeros(2, 3, 3, 3) target = torch.zeros(2, 3, 3, 3) prediction.requires_grad = True target.requires_grad = True generalized_dice_loss = GeneralizedDiceLoss(batch=True) loss = generalized_dice_loss(prediction, target) self.assertNotEqual(loss.grad_fn, None)
def test_differentiability(self): prediction = torch.ones((1, 1, 1, 3)) target = torch.ones((1, 1, 1, 3)) prediction.requires_grad = True target.requires_grad = True generalized_dice_loss = GeneralizedDiceLoss() loss = generalized_dice_loss(prediction, target) self.assertNotEqual(loss.grad_fn, None)
def _define_training_loss(self) -> None: """Define the loss function.""" if self._option_loss == SegmentationLosses.GeneralizedDiceLoss: self._training_loss_function = GeneralizedDiceLoss( include_background=True, to_onehot_y=False, softmax=True, batch=True, ) else: raise ValueError(f"Unknown loss option {self._option_loss}! Aborting.")
def test_input_warnings(self): chn_input = torch.ones((1, 1, 3)) chn_target = torch.ones((1, 1, 3)) with self.assertWarns(Warning): loss = GeneralizedDiceLoss(include_background=False) loss.forward(chn_input, chn_target) with self.assertWarns(Warning): loss = GeneralizedDiceLoss(softmax=True) loss.forward(chn_input, chn_target) with self.assertWarns(Warning): loss = GeneralizedDiceLoss(to_onehot_y=True) loss.forward(chn_input, chn_target)
def test_ill_shape(self): loss = GeneralizedDiceLoss() with self.assertRaisesRegex(AssertionError, ""): loss.forward(torch.ones((1, 2, 3)), torch.ones((4, 5, 6)))
def test_shape(self, input_param, input_data, expected_val): result = GeneralizedDiceLoss(**input_param).forward(**input_data) np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-5)
def test_script(self): loss = GeneralizedDiceLoss() test_input = torch.ones(2, 1, 8, 8) test_script_save(loss, test_input, test_input)
def main(cfg): """Runs main training procedure.""" print('Starting training...') print('Current working directory is:', os.getcwd()) # fix random seeds for reproducibility seed_everything(seed=cfg['seed']) # neptune logging neptune.init(project_qualified_name=cfg['neptune_project_name'], api_token=cfg['neptune_api_token']) neptune.create_experiment(name=cfg['neptune_experiment'], params=cfg) num_classes = 1 if len(cfg['classes']) == 1 else (len(cfg['classes']) + 1) activation = 'sigmoid' if num_classes == 1 else 'softmax2d' background = False if cfg['ignore_channels'] else True aux_params = dict( pooling=cfg['pooling'], # one of 'avg', 'max' dropout=cfg['dropout'], # dropout ratio, default is None activation='sigmoid', # activation function, default is None classes=num_classes) # define number of output labels # configure model models = { 'unet': Unet(encoder_name=cfg['encoder_name'], encoder_weights=cfg['encoder_weights'], decoder_use_batchnorm=cfg['use_batchnorm'], classes=num_classes, activation=activation, aux_params=aux_params), 'unetplusplus': UnetPlusPlus(encoder_name=cfg['encoder_name'], encoder_weights=cfg['encoder_weights'], decoder_use_batchnorm=cfg['use_batchnorm'], classes=num_classes, activation=activation, aux_params=aux_params), 'deeplabv3plus': DeepLabV3Plus(encoder_name=cfg['encoder_name'], encoder_weights=cfg['encoder_weights'], classes=num_classes, activation=activation, aux_params=aux_params) } assert cfg['architecture'] in models.keys() model = models[cfg['architecture']] # configure loss losses = { 'dice_loss': DiceLoss(include_background=background, softmax=False, batch=cfg['combine']), 'generalized_dice': GeneralizedDiceLoss(include_background=background, softmax=False, batch=cfg['combine']) } assert cfg['loss'] in losses.keys() loss = losses[cfg['loss']] # configure optimizer optimizers = { 'adam': Adam([dict(params=model.parameters(), lr=cfg['lr'])]), 'adamw': AdamW([dict(params=model.parameters(), lr=cfg['lr'])]), 'rmsprop': RMSprop([dict(params=model.parameters(), lr=cfg['lr'])]) } assert cfg['optimizer'] in optimizers.keys() optimizer = optimizers[cfg['optimizer']] # configure metrics metrics = { 'dice_score': DiceMetric(include_background=background, reduction='mean'), 'dice_smp': Fscore(threshold=cfg['rounding'], ignore_channels=cfg['ignore_channels']), 'iou_smp': IoU(threshold=cfg['rounding'], ignore_channels=cfg['ignore_channels']), 'generalized_dice': GeneralizedDiceLoss(include_background=background, softmax=False, batch=cfg['combine']), 'dice_loss': DiceLoss(include_background=background, softmax=False, batch=cfg['combine']), 'cross_entropy': BCELoss(reduction='mean'), 'accuracy': Accuracy(ignore_channels=cfg['ignore_channels']) } assert all(m['name'] in metrics.keys() for m in cfg['metrics']) metrics = [(metrics[m['name']], m['name'], m['type']) for m in cfg['metrics']] # tuple of (metric, name, type) # configure scheduler schedulers = { 'steplr': StepLR(optimizer, step_size=cfg['step_size'], gamma=0.5), 'cosine': CosineAnnealingLR(optimizer, cfg['epochs'], eta_min=cfg['eta_min'], last_epoch=-1) } assert cfg['scheduler'] in schedulers.keys() scheduler = schedulers[cfg['scheduler']] # configure augmentations train_transform = load_train_transform(transform_type=cfg['transform'], patch_size=cfg['patch_size']) valid_transform = load_valid_transform(patch_size=cfg['patch_size']) train_dataset = SegmentationDataset(df_path=cfg['train_data'], transform=train_transform, normalize=cfg['normalize'], tissuemix=cfg['tissuemix'], probability=cfg['probability'], blending=cfg['blending'], warping=cfg['warping'], color=cfg['color']) valid_dataset = SegmentationDataset(df_path=cfg['valid_data'], transform=valid_transform, normalize=cfg['normalize']) # save intermediate augmentations if cfg['eval_dir']: default_dataset = SegmentationDataset(df_path=cfg['train_data'], transform=None, normalize=None) transform_dataset = SegmentationDataset(df_path=cfg['train_data'], transform=None, normalize=None, tissuemix=cfg['tissuemix'], probability=cfg['probability'], blending=cfg['blending'], warping=cfg['warping'], color=cfg['color']) for idx in range(0, min(500, len(default_dataset)), 10): image_input, image_mask = default_dataset[idx] image_input = image_input.transpose((1, 2, 0)) image_input = image_input.astype(np.uint8) image_mask = image_mask.transpose( 1, 2, 0) # Why do we need transpose here? image_mask = image_mask.astype(np.uint8) image_mask = image_mask.squeeze() image_mask = image_mask * 255 image_transform, _ = transform_dataset[idx] image_transform = image_transform.transpose( (1, 2, 0)).astype(np.uint8) idx_str = str(idx).zfill(3) skimage.io.imsave(os.path.join(cfg['eval_dir'], f'{idx_str}a_image_input.png'), image_input, check_contrast=False) plt.imsave(os.path.join(cfg['eval_dir'], f'{idx_str}b_image_mask.png'), image_mask, vmin=0, vmax=1) skimage.io.imsave(os.path.join(cfg['eval_dir'], f'{idx_str}c_image_transform.png'), image_transform, check_contrast=False) del transform_dataset train_loader = DataLoader(train_dataset, batch_size=cfg['batch_size'], num_workers=cfg['workers'], shuffle=True) valid_loader = DataLoader(valid_dataset, batch_size=cfg['batch_size'], num_workers=cfg['workers'], shuffle=False) trainer = Trainer(model=model, device=cfg['device'], save_checkpoints=cfg['save_checkpoints'], checkpoint_dir=cfg['checkpoint_dir'], checkpoint_name=cfg['checkpoint_name']) trainer.compile(optimizer=optimizer, loss=loss, metrics=metrics, num_classes=num_classes) trainer.fit(train_loader, valid_loader, epochs=cfg['epochs'], scheduler=scheduler, verbose=cfg['verbose'], loss_weight=cfg['loss_weight']) # validation inference best_model = model best_model.load_state_dict( torch.load(os.path.join(cfg['checkpoint_dir'], cfg['checkpoint_name']))) best_model.to(cfg['device']) best_model.eval() # setup directory to save plots if os.path.isdir(cfg['plot_dir']): # remove existing dir and content shutil.rmtree(cfg['plot_dir']) # create absolute destination os.makedirs(cfg['plot_dir']) # valid dataset without transformations and normalization for image visualization valid_dataset_vis = SegmentationDataset(df_path=cfg['valid_data'], transform=valid_transform, normalize=None) if cfg['save_checkpoints']: for n in range(len(valid_dataset)): image_vis = valid_dataset_vis[n][0].astype('uint8') image_vis = image_vis.transpose((1, 2, 0)) image, gt_mask = valid_dataset[n] gt_mask = gt_mask.transpose((1, 2, 0)) gt_mask = gt_mask.squeeze() x_tensor = torch.from_numpy(image).to(cfg['device']).unsqueeze(0) pr_mask, _ = best_model.predict(x_tensor) pr_mask = pr_mask.cpu().numpy().round() pr_mask = pr_mask.squeeze() save_predictions(out_path=cfg['plot_dir'], index=n, image=image_vis, ground_truth_mask=gt_mask, predicted_mask=pr_mask, average='macro')
def test_shape(self, input_param, input_data, expected_val): result = GeneralizedDiceLoss(**input_param).forward(**input_data) self.assertAlmostEqual(result.item(), expected_val, places=5)
def main(cfg): """Runs main training procedure.""" # fix random seeds for reproducibility seed_everything(seed=cfg['seed']) # neptune logging neptune.init(project_qualified_name=cfg['neptune_project_name'], api_token=cfg['neptune_api_token']) neptune.create_experiment(name=cfg['neptune_experiment'], params=cfg) print('Preparing model and data...') print('Using SMP version:', smp.__version__) num_classes = 1 if len(cfg['classes']) == 1 else (len(cfg['classes']) + 1) activation = 'sigmoid' if num_classes == 1 else 'softmax2d' background = False if cfg['ignore_channels'] else True binary = True if num_classes == 1 else False softmax = False if num_classes == 1 else True sigmoid = True if num_classes == 1 else False aux_params = dict( pooling=cfg['pooling'], # one of 'avg', 'max' dropout=cfg['dropout'], # dropout ratio, default is None activation='sigmoid', # activation function, default is None classes=num_classes) # define number of output labels # configure model models = { 'unet': Unet(encoder_name=cfg['encoder_name'], encoder_weights=cfg['encoder_weights'], decoder_use_batchnorm=cfg['use_batchnorm'], classes=num_classes, activation=activation, aux_params=aux_params), 'pspnet': PSPNet(encoder_name=cfg['encoder_name'], encoder_weights=cfg['encoder_weights'], classes=num_classes, activation=activation, aux_params=aux_params), 'pan': PAN(encoder_name=cfg['encoder_name'], encoder_weights=cfg['encoder_weights'], classes=num_classes, activation=activation, aux_params=aux_params), 'deeplabv3plus': DeepLabV3Plus(encoder_name=cfg['encoder_name'], encoder_weights=cfg['encoder_weights'], classes=num_classes, activation=activation, aux_params=aux_params) } assert cfg['architecture'] in models.keys() model = models[cfg['architecture']] # configure loss losses = { 'dice_loss': DiceLoss(include_background=background, softmax=softmax, sigmoid=sigmoid, batch=cfg['combine']), 'generalized_dice': GeneralizedDiceLoss(include_background=background, softmax=softmax, sigmoid=sigmoid, batch=cfg['combine']) } assert cfg['loss'] in losses.keys() loss = losses[cfg['loss']] # configure optimizer optimizers = { 'adam': Adam([dict(params=model.parameters(), lr=cfg['lr'])]), 'adamw': AdamW([dict(params=model.parameters(), lr=cfg['lr'])]), 'rmsprop': RMSprop([dict(params=model.parameters(), lr=cfg['lr'])]) } assert cfg['optimizer'] in optimizers.keys() optimizer = optimizers[cfg['optimizer']] # configure metrics metrics = { 'dice_score': DiceMetric(include_background=background, reduction='mean'), 'dice_smp': Fscore(threshold=cfg['rounding'], ignore_channels=cfg['ignore_channels']), 'iou_smp': IoU(threshold=cfg['rounding'], ignore_channels=cfg['ignore_channels']), 'generalized_dice': GeneralizedDiceLoss(include_background=background, softmax=softmax, sigmoid=sigmoid, batch=cfg['combine']), 'dice_loss': DiceLoss(include_background=background, softmax=softmax, sigmoid=sigmoid, batch=cfg['combine']), 'cross_entropy': BCELoss(reduction='mean'), 'accuracy': Accuracy(ignore_channels=cfg['ignore_channels']) } assert all(m['name'] in metrics.keys() for m in cfg['metrics']) metrics = [(metrics[m['name']], m['name'], m['type']) for m in cfg['metrics']] # tuple of (metric, name, type) # TODO: Fix metric names # configure scheduler schedulers = { 'steplr': StepLR(optimizer, step_size=cfg['step_size'], gamma=0.5), 'cosine': CosineAnnealingLR(optimizer, cfg['epochs'], eta_min=cfg['eta_min'], last_epoch=-1) } assert cfg['scheduler'] in schedulers.keys() scheduler = schedulers[cfg['scheduler']] # configure augmentations train_transform = load_train_transform(transform_type=cfg['transform'], patch_size=cfg['patch_size_train']) valid_transform = load_valid_transform( patch_size=cfg['patch_size_valid']) # manually selected patch size train_dataset = ArtifactDataset(df_path=cfg['train_data'], classes=cfg['classes'], transform=train_transform, normalize=cfg['normalize'], ink_filters=cfg['ink_filters']) valid_dataset = ArtifactDataset(df_path=cfg['valid_data'], classes=cfg['classes'], transform=valid_transform, normalize=cfg['normalize'], ink_filters=cfg['ink_filters']) test_dataset = ArtifactDataset(df_path=cfg['test_data'], classes=cfg['classes'], transform=valid_transform, normalize=cfg['normalize'], ink_filters=cfg['ink_filters']) # load pre-sampled patch arrays train_image, train_mask = train_dataset[0] valid_image, valid_mask = valid_dataset[0] print('Shape of image patch', train_image.shape) print('Shape of mask patch', train_mask.shape) print('Train dataset shape:', len(train_dataset)) print('Valid dataset shape:', len(valid_dataset)) assert train_image.shape[1] == cfg[ 'patch_size_train'] and train_image.shape[2] == cfg['patch_size_train'] assert valid_image.shape[1] == cfg[ 'patch_size_valid'] and valid_image.shape[2] == cfg['patch_size_valid'] # save intermediate augmentations if cfg['eval_dir']: default_dataset = ArtifactDataset(df_path=cfg['train_data'], classes=cfg['classes'], transform=None, normalize=None, ink_filters=cfg['ink_filters']) transform_dataset = ArtifactDataset(df_path=cfg['train_data'], classes=cfg['classes'], transform=train_transform, normalize=None, ink_filters=cfg['ink_filters']) for idx in range(0, min(500, len(train_dataset)), 10): image_input, image_mask = default_dataset[idx] image_input = image_input.transpose((1, 2, 0)).astype(np.uint8) image_mask = image_mask.transpose(1, 2, 0) image_mask = np.argmax( image_mask, axis=2) if not binary else image_mask.squeeze() image_mask = image_mask.astype(np.uint8) image_transform, _ = transform_dataset[idx] image_transform = image_transform.transpose( (1, 2, 0)).astype(np.uint8) idx_str = str(idx).zfill(3) skimage.io.imsave(os.path.join(cfg['eval_dir'], f'{idx_str}a_image_input.png'), image_input, check_contrast=False) plt.imsave(os.path.join(cfg['eval_dir'], f'{idx_str}b_image_mask.png'), image_mask, vmin=0, vmax=6, cmap='Spectral') skimage.io.imsave(os.path.join(cfg['eval_dir'], f'{idx_str}c_image_transform.png'), image_transform, check_contrast=False) del transform_dataset # update process print('Starting training...') print('Available GPUs for training:', torch.cuda.device_count()) # pytorch module wrapper class DataParallelModule(torch.nn.DataParallel): def __getattr__(self, name): try: return super().__getattr__(name) except AttributeError: return getattr(self.module, name) # data parallel training if torch.cuda.device_count() > 1: model = DataParallelModule(model) train_loader = DataLoader(train_dataset, batch_size=cfg['batch_size'], num_workers=cfg['workers'], shuffle=True) valid_loader = DataLoader(valid_dataset, batch_size=int(cfg['batch_size'] / 4), num_workers=cfg['workers'], shuffle=False) test_loader = DataLoader(test_dataset, batch_size=int(cfg['batch_size'] / 4), num_workers=cfg['workers'], shuffle=False) trainer = Trainer(model=model, device=cfg['device'], save_checkpoints=cfg['save_checkpoints'], checkpoint_dir=cfg['checkpoint_dir'], checkpoint_name=cfg['checkpoint_name']) trainer.compile(optimizer=optimizer, loss=loss, metrics=metrics, num_classes=num_classes) trainer.fit(train_loader, valid_loader, epochs=cfg['epochs'], scheduler=scheduler, verbose=cfg['verbose'], loss_weight=cfg['loss_weight'], test_loader=test_loader, binary=binary) # validation inference model.load_state_dict( torch.load(os.path.join(cfg['checkpoint_dir'], cfg['checkpoint_name']))) model.to(cfg['device']) model.eval() # save best checkpoint to neptune neptune.log_artifact( os.path.join(cfg['checkpoint_dir'], cfg['checkpoint_name'])) # setup directory to save plots if os.path.isdir(cfg['plot_dir_valid']): shutil.rmtree(cfg['plot_dir_valid']) os.makedirs(cfg['plot_dir_valid'], exist_ok=True) # valid dataset without transformations and normalization for image visualization valid_dataset_vis = ArtifactDataset(df_path=cfg['valid_data'], classes=cfg['classes'], ink_filters=cfg['ink_filters']) # keep track of valid masks valid_preds = [] valid_masks = [] if cfg['save_checkpoints']: print('Predicting valid patches...') for n in range(len(valid_dataset)): image_vis = valid_dataset_vis[n][0].astype('uint8') image_vis = image_vis.transpose(1, 2, 0) image, gt_mask = valid_dataset[n] gt_mask = gt_mask.transpose(1, 2, 0) gt_mask = np.argmax(gt_mask, axis=2) if not binary else gt_mask.squeeze() gt_mask = gt_mask.astype(np.uint8) valid_masks.append(gt_mask) x_tensor = torch.from_numpy(image).to(cfg['device']).unsqueeze(0) pr_mask, _ = model.predict(x_tensor) pr_mask = pr_mask.squeeze(axis=0).cpu().numpy().round() pr_mask = pr_mask.transpose(1, 2, 0) pr_mask = np.argmax(pr_mask, axis=2) if not binary else pr_mask.squeeze() pr_mask = pr_mask.astype(np.uint8) valid_preds.append(pr_mask) save_predictions(out_path=cfg['plot_dir_valid'], index=n + 1, image=image_vis, ground_truth_mask=gt_mask, predicted_mask=pr_mask) del train_dataset, valid_dataset del train_loader, valid_loader # calculate dice per class valid_masks = np.stack(valid_masks, axis=0) valid_masks = valid_masks.flatten() valid_preds = np.stack(valid_preds, axis=0) valid_preds = valid_preds.flatten() dice_score = f1_score(y_true=valid_masks, y_pred=valid_preds, average=None) neptune.log_text('valid_dice_class', str(dice_score)) print('Valid dice score (class):', str(dice_score)) if cfg['evaluate_test_set']: print('Predicting test patches...') # setup directory to save plots if os.path.isdir(cfg['plot_dir_test']): shutil.rmtree(cfg['plot_dir_test']) os.makedirs(cfg['plot_dir_test'], exist_ok=True) # test dataset without transformations and normalization for image visualization test_dataset_vis = ArtifactDataset(df_path=cfg['test_data'], classes=cfg['classes'], ink_filters=cfg['ink_filters']) # keep track of test masks test_masks = [] test_preds = [] for n in range(len(test_dataset)): image_vis = test_dataset_vis[n][0].astype('uint8') image_vis = image_vis.transpose(1, 2, 0) image, gt_mask = test_dataset[n] gt_mask = gt_mask.transpose(1, 2, 0) gt_mask = np.argmax(gt_mask, axis=2) if not binary else gt_mask.squeeze() gt_mask = gt_mask.astype(np.uint8) test_masks.append(gt_mask) x_tensor = torch.from_numpy(image).to(cfg['device']).unsqueeze(0) pr_mask, _ = model.predict(x_tensor) pr_mask = pr_mask.squeeze(axis=0).cpu().numpy().round() pr_mask = pr_mask.transpose(1, 2, 0) pr_mask = np.argmax(pr_mask, axis=2) if not binary else pr_mask.squeeze() pr_mask = pr_mask.astype(np.uint8) test_preds.append(pr_mask) save_predictions(out_path=cfg['plot_dir_test'], index=n + 1, image=image_vis, ground_truth_mask=gt_mask, predicted_mask=pr_mask) # calculate dice per class test_masks = np.stack(test_masks, axis=0) test_masks = test_masks.flatten() test_preds = np.stack(test_preds, axis=0) test_preds = test_preds.flatten() dice_score = f1_score(y_true=test_masks, y_pred=test_preds, average=None) neptune.log_text('test_dice_class', str({dice_score})) print('Test dice score (class):', str(dice_score)) # end of training process print('Finished training!')