def test_ill_opts(self):
    """Invalid DiceLoss configurations must raise ``ValueError``."""
    # sigmoid and softmax are mutually exclusive activations
    with self.assertRaisesRegex(ValueError, ""):
        DiceLoss(sigmoid=True, softmax=True)
    dummy_pred = torch.ones((1, 1, 3))
    dummy_target = torch.ones((1, 1, 3))
    # reduction must be one of the supported modes
    with self.assertRaisesRegex(ValueError, ""):
        DiceLoss(reduction="unknown")(dummy_pred, dummy_target)
    with self.assertRaisesRegex(ValueError, ""):
        DiceLoss(reduction=None)(dummy_pred, dummy_target)
def test_result_onehot_target_include_bg(self):
    """DiceFocalLoss must equal DiceLoss + lambda_focal * FocalLoss on one-hot targets."""
    shape = [3, 3, 5, 5]
    target = torch.randint(low=0, high=2, size=shape)
    prediction = torch.randn(shape)
    for reduction in ["sum", "mean", "none"]:
        shared_kwargs = {
            "include_background": True,
            "to_onehot_y": False,
            "reduction": reduction,
        }
        # exercise tensor, tuple and None focal weights
        for weight in [None, torch.tensor([1.0, 1.0, 2.0]), (3, 2.0, 1)]:
            for lam in [0.5, 1.0, 1.5]:
                combined = DiceFocalLoss(
                    focal_weight=weight, gamma=1.0, lambda_focal=lam, **shared_kwargs
                )
                dice_only = DiceLoss(**shared_kwargs)
                focal_only = FocalLoss(weight=weight, gamma=1.0, **shared_kwargs)
                actual = combined(prediction, target)
                expected = dice_only(prediction, target) + lam * focal_only(prediction, target)
                np.testing.assert_allclose(actual, expected)
def run_test(batch_size=64, train_steps=100, device=torch.device("cuda:0")):
    """Train a small 2D UNet on synthetic images for one epoch.

    Args:
        batch_size: batch size for the synthetic data loader.
        train_steps: number of synthetic samples (dataset length).
        device: torch device used for the network and training.

    Returns:
        The loss value from the trainer's final iteration.
    """
    class _TestBatch(Dataset):
        # generates a fresh random image/segmentation pair per access
        def __getitem__(self, _unused_id):
            im, seg = create_test_image_2d(128, 128, noise_max=1, num_objs=4, num_seg_classes=1)
            return im[None], seg[None].astype(np.float32)

        def __len__(self):
            return train_steps

    net = UNet(
        dimensions=2,
        in_channels=1,
        out_channels=1,
        channels=(4, 8, 16, 32),
        strides=(2, 2, 2),
        num_res_units=2,
    ).to(device)

    # fix: the DiceLoss keyword is ``sigmoid``; the old ``do_sigmoid`` name was removed
    loss = DiceLoss(sigmoid=True)
    opt = torch.optim.Adam(net.parameters(), 1e-4)
    src = DataLoader(_TestBatch(), batch_size=batch_size)

    trainer = create_supervised_trainer(net, opt, loss, device, False)
    trainer.run(src, 1)
    loss = trainer.state.output
    return loss
def test_epistemic_scoring(self):
    """Train a tiny 3D UNet, then check EpistemicScoring's entropy volume shape/type/magnitude."""
    input_size = (20, 20, 20)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    keys = ["image", "label"]
    num_training_ims = 10
    train_data = self.get_data(num_training_ims, input_size)
    test_data = self.get_data(1, input_size)
    # dictionary transforms for training; array transforms for inference below
    transforms = Compose([
        AddChanneld(keys),
        CropForegroundd(keys, source_key="image"),
        DivisiblePadd(keys, 4),
    ])
    infer_transforms = Compose([
        AddChannel(),
        CropForeground(),
        DivisiblePad(4),
    ])
    train_ds = CacheDataset(train_data, transforms)
    # output might be different size, so pad so that they match
    train_loader = DataLoader(train_ds, batch_size=2, collate_fn=pad_list_data_collate)
    model = UNet(3, 1, 1, channels=(6, 6), strides=(2, 2)).to(device)
    loss_function = DiceLoss(sigmoid=True)
    optimizer = torch.optim.Adam(model.parameters(), 1e-3)
    num_epochs = 10
    # short training loop so the model produces non-trivial predictions
    for _ in trange(num_epochs):
        epoch_loss = 0
        for batch_data in train_loader:
            inputs, labels = batch_data["image"].to(
                device), batch_data["label"].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        epoch_loss /= len(train_loader)
    entropy_score = EpistemicScoring(model=model,
                                     transforms=infer_transforms,
                                     roi_size=[20, 20, 20],
                                     num_samples=10)
    # Call Individual Infer from Epistemic Scoring
    ip_stack = [test_data["image"], test_data["image"], test_data["image"]]
    ip_stack = np.array(ip_stack)
    score_3d = entropy_score.entropy_3d_volume(ip_stack)
    score_3d_sum = np.sum(score_3d)
    # Call Entropy Metric from Epistemic Scoring
    self.assertEqual(score_3d.shape, input_size)
    self.assertIsInstance(score_3d_sum, np.float32)
    self.assertGreater(score_3d_sum, 3.0)
def run_test(batch_size=64, train_steps=200, device=torch.device("cuda:0")):
    """Train a 2D UNet with paired random augmentations on synthetic data.

    Args:
        batch_size: batch size for the data loader.
        train_steps: number of synthetic samples (dataset length).
        device: torch device used for the network and training.

    Returns:
        Tuple of (mean epoch loss, number of optimizer steps taken).
    """
    class _TestBatch(Dataset):
        def __init__(self, transforms):
            self.transforms = transforms

        def __getitem__(self, _unused_id):
            im, seg = create_test_image_2d(128, 128, noise_max=1, num_objs=4, num_seg_classes=1)
            # re-seed before each call so the random spatial transforms applied
            # to image and segmentation are identical and stay aligned
            seed = np.random.randint(2147483647)
            self.transforms.set_random_state(seed=seed)
            im = self.transforms(im)
            self.transforms.set_random_state(seed=seed)
            seg = self.transforms(seg)
            return im, seg

        def __len__(self):
            return train_steps

    net = UNet(
        dimensions=2,
        in_channels=1,
        out_channels=1,
        channels=(4, 8, 16, 32),
        strides=(2, 2, 2),
        num_res_units=2,
    ).to(device)

    # fix: the DiceLoss keyword is ``sigmoid``; the old ``do_sigmoid`` name was removed
    loss = DiceLoss(sigmoid=True)
    opt = torch.optim.Adam(net.parameters(), 1e-2)
    train_transforms = Compose([
        AddChannel(),
        ScaleIntensity(),
        RandSpatialCrop((96, 96), random_size=False),
        RandRotate90(),
        ToTensor()
    ])
    src = DataLoader(_TestBatch(train_transforms), batch_size=batch_size, shuffle=True)

    net.train()
    epoch_loss = 0
    step = 0
    for img, seg in src:
        step += 1
        opt.zero_grad()
        output = net(img.to(device))
        step_loss = loss(output, seg.to(device))
        step_loss.backward()
        opt.step()
        epoch_loss += step_loss.item()
    epoch_loss /= step
    return epoch_loss, step
def __init__(self, focal):
    """Collect the segmentation criteria; ``focal`` toggles focal-loss usage downstream."""
    super(Loss, self).__init__()
    self.use_focal = focal
    self.cross_entropy = nn.CrossEntropyLoss()
    self.focal = FocalLoss(gamma=2.0)
    # batched dice over foreground classes only, applied to softmax outputs
    self.dice = DiceLoss(include_background=False,
                         softmax=True,
                         to_onehot_y=True,
                         batch=True)
def __init__(self, model: nn.Module, lr: float = 1e-4):
    """Lightning wrapper around a UNet for segmenting nodules in chest CTs.

    Args:
        model: network to train and validate.
        lr: initial learning rate. Defaults to 1e-4.
    """
    super().__init__()
    self.model = model
    self.lr = lr
    self.loss = DiceLoss(to_onehot_y=True, softmax=True)
    self.save_hyperparameters()
def training_step(self, batch, batch_idx):
    """One Lightning training step: forward pass, dice loss, and per-step train-loss logging."""
    inputs, targets = self.prepare_batch(batch)
    pred = self(inputs)
    # diceloss = DiceLoss(include_background=True, to_onehot_y=True)
    # loss = diceloss.forward(input=probs, target=targets)
    # dice, iou, _, _ = get_score(batch_preds, batch_targets, include_background=True)
    # gdloss = GeneralizedDiceLoss(include_background=True, to_onehot_y=True)
    # loss = gdloss.forward(input=batch_preds, target=batch_targets)
    # if batch_idx != 0 and ((self.current_epoch >= 1 and dice.item() < 0.5) or batch_idx % 100 == 0):
    #     input = inputs.chunk(inputs.size()[0], 0)[0]  # split into 1 in the dimension 0
    #     target = targets.chunk(targets.size()[0], 0)[0]  # split into 1 in the dimension 0
    #     prob = probs.chunk(probs.size()[0], 0)[0]  # split into 1 in the dimension 0
    #     # really have problem in there, need to fix it
    #     dice_score, _, _, _ = get_score(torch.unsqueeze(prob, 0), torch.unsqueeze(target, 0))
    #     log_all_info(self, input, target, prob, batch_idx, "training", dice_score.item())
    # loss = F.binary_cross_entropy_with_logits(logits, targets)
    # active loss: dice on raw predictions with one-hot-encoded targets
    diceloss = DiceLoss(include_background=self.hparams.include_background, to_onehot_y=True)
    loss = diceloss.forward(input=pred, target=targets)
    # What is the loos I need to set here? when I am using the batch size?
    # gdloss = GeneralizedDiceLoss(include_background=True, to_onehot_y=True)
    # loss = gdloss.forward(input=batch_preds, target=batch_targets)
    # I cannot use this `TrainResult` right now
    # the loss for prog_bar is not corrected, is there anything I write wrong?
    # NOTE(review): pl.TrainResult is a legacy Lightning API — confirm the installed version supports it
    result = pl.TrainResult(minimize=loss)
    # logs metrics for each training_step, to the progress bar and logger
    result.log("train_loss",
               loss,
               prog_bar=True,
               sync_dist=True,
               logger=True,
               reduce_fx=torch.mean,
               on_step=True,
               on_epoch=False)
    # we cannot compute the matrixs on the patches, because they do not contain all the 138 segmentations
    # So they would return 0 on some of the classes, making the matrixs not accurate
    return result
test_loader = DataLoader(test_ds, batch_size=1, num_workers=0) """## Create Model, Loss, Optimizer""" # standard PyTorch program style: create UNet, DiceLoss and Adam optimizer #device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') device = torch.device('cpu') model = UNet( dimensions=3, in_channels=1, out_channels=2, channels=(16, 32, 64, 128, 256), strides=(2, 2, 2, 2), num_res_units=2, norm=Norm.BATCH, ).to(device) loss_function = DiceLoss(to_onehot_y=True, softmax=True) optimizer = torch.optim.Adam(model.parameters(), 1e-4) """## Execute a typical PyTorch training process""" epoch_num = 2 val_interval = 2 best_metric = -1 best_metric_epoch = -1 epoch_loss_values = list() metric_values = list() post_pred = AsDiscrete(argmax=True, to_onehot=True, n_classes=2) post_label = AsDiscrete(to_onehot=True, n_classes=2) for epoch in range(epoch_num): print("-" * 10) print(f"epoch {epoch + 1}/{epoch_num}")
val_loader = DataLoader(val_ds, batch_size=1) # Create Model, Loss, Optimizer # standard PyTorch program style: create UNet, DiceLoss and Adam optimizer device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") model = UNet( dimensions=3, # in_channels=1, # out_channels=1, # channels=(16, 32, 64, 128, 256), # strides=(2, 2, 2, 2), # num_res_units=2, norm=Norm.BATCH, ).to(device) # loss_function = DiceLoss(sigmoid=True) optimizer = torch.optim.Adam(model.parameters(), 1e-3) # Execute a typical PyTorch training process max_epochs = 50 # max_epochs = 300 val_interval = 2 best_metric = -1 best_metric_epoch = -1 epoch_loss_values = [] metric_values = [] post_pred = AsDiscrete(threshold_values=True, n_classes=1) post_label = AsDiscrete(n_classes=1) for epoch in range(max_epochs):
def __init__(self, focal):
    """BraTS criterion: batched sigmoid dice plus either focal loss or BCE-with-logits."""
    super(LossBraTS, self).__init__()
    self.dice = DiceLoss(sigmoid=True, batch=True)
    # ternary replaced with an explicit branch: pick the auxiliary term by flag
    if focal:
        self.ce = FocalLoss(gamma=2.0, to_onehot_y=False)
    else:
        self.ce = nn.BCEWithLogitsLoss()
def __init__(
    self,
    include_background: bool = False,
    to_onehot_y: bool = False,
    sigmoid: bool = False,
    softmax: bool = False,
    other_act: Optional[Callable] = None,
    squared_pred: bool = False,
    jaccard: bool = False,
    reduction: str = "mean",
    smooth_nr: float = 1e-5,
    smooth_dr: float = 1e-5,
    batch: bool = False,
    topK: int = 10,
    ce_weight: Optional[torch.Tensor] = None,
    lambda_dice: float = 1.0,
    lambda_ce: float = 1.0,
) -> None:
    """
    Args:
        ``reduction`` is used for both losses and other parameters are only used for dice loss.
        include_background: if False channel index 0 (background category) is excluded from the calculation.
        to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
        sigmoid: if True, apply a sigmoid function to the prediction, only used by the `DiceLoss`,
            don't need to specify activation function for `CrossEntropyLoss`.
        softmax: if True, apply a softmax function to the prediction, only used by the `DiceLoss`,
            don't need to specify activation function for `CrossEntropyLoss`.
        other_act: if don't want to use `sigmoid` or `softmax`, use other callable function to execute
            other activation layers, Defaults to ``None``. for example: `other_act = torch.tanh`.
            only used by the `DiceLoss`, don't need to specify activation function for `CrossEntropyLoss`.
        squared_pred: use squared versions of targets and predictions in the denominator or not.
        jaccard: compute Jaccard Index (soft IoU) instead of dice or not.
        reduction: {``"mean"``, ``"sum"``}
            Specifies the reduction to apply to the output. Defaults to ``"mean"``. The dice loss should
            as least reduce the spatial dimensions, which is different from cross entropy loss, thus here
            the ``none`` option cannot be used.

            - ``"mean"``: the sum of the output will be divided by the number of elements in the output.
            - ``"sum"``: the output will be summed.

        smooth_nr: a small constant added to the numerator to avoid zero.
        smooth_dr: a small constant added to the denominator to avoid nan.
        batch: whether to sum the intersection and union areas over the batch dimension before the dividing.
            Defaults to False, a Dice loss value is computed independently from each item in the batch
            before any `reduction`.
        topK: k of top-k.
        ce_weight: a rescaling weight given to each class for cross entropy loss.
        lambda_dice: the trade-off weight value for dice loss. The value should be no less than 0.0.
            Defaults to 1.0.
        lambda_ce: the trade-off weight value for cross entropy loss. The value should be no less than 0.0.
            Defaults to 1.0.
    """
    super().__init__()
    # dice component: forwards every dice-specific option unchanged
    self.dice = DiceLoss(
        include_background=include_background,
        to_onehot_y=to_onehot_y,
        sigmoid=sigmoid,
        softmax=softmax,
        other_act=other_act,
        squared_pred=squared_pred,
        jaccard=jaccard,
        reduction=reduction,
        smooth_nr=smooth_nr,
        smooth_dr=smooth_dr,
        batch=batch,
    )
    # per-element cross entropy (reduction='none') so a top-k selection can be applied later
    self.cross_entropy = nn.CrossEntropyLoss(
        weight=ce_weight,
        reduction='none',
    )
    if lambda_dice < 0.0:
        raise ValueError("lambda_dice should be no less than 0.0.")
    if lambda_ce < 0.0:
        raise ValueError("lambda_ce should be no less than 0.0.")
    self.lambda_dice = lambda_dice
    self.lambda_ce = lambda_ce
    self.k = topK
def main(cfg):
    """Runs main training procedure.

    Builds the model/loss/optimizer from ``cfg``, trains with ``Trainer``,
    then runs validation (and optionally test) inference, logging dice
    scores and artifacts to Neptune.
    """
    # fix random seeds for reproducibility
    seed_everything(seed=cfg['seed'])

    # neptune logging
    neptune.init(project_qualified_name=cfg['neptune_project_name'],
                 api_token=cfg['neptune_api_token'])
    neptune.create_experiment(name=cfg['neptune_experiment'], params=cfg)

    print('Preparing model and data...')
    print('Using SMP version:', smp.__version__)

    # binary segmentation: one sigmoid channel; multiclass: extra background
    # channel with softmax activation
    num_classes = 1 if len(cfg['classes']) == 1 else (len(cfg['classes']) + 1)
    activation = 'sigmoid' if num_classes == 1 else 'softmax2d'
    background = False if cfg['ignore_channels'] else True
    binary = True if num_classes == 1 else False
    softmax = False if num_classes == 1 else True
    sigmoid = True if num_classes == 1 else False

    aux_params = dict(
        pooling=cfg['pooling'],  # one of 'avg', 'max'
        dropout=cfg['dropout'],  # dropout ratio, default is None
        activation='sigmoid',  # activation function, default is None
        classes=num_classes)  # define number of output labels

    # configure model
    models = {
        'unet':
        Unet(encoder_name=cfg['encoder_name'],
             encoder_weights=cfg['encoder_weights'],
             decoder_use_batchnorm=cfg['use_batchnorm'],
             classes=num_classes,
             activation=activation,
             aux_params=aux_params),
        'pspnet':
        PSPNet(encoder_name=cfg['encoder_name'],
               encoder_weights=cfg['encoder_weights'],
               classes=num_classes,
               activation=activation,
               aux_params=aux_params),
        'pan':
        PAN(encoder_name=cfg['encoder_name'],
            encoder_weights=cfg['encoder_weights'],
            classes=num_classes,
            activation=activation,
            aux_params=aux_params),
        'deeplabv3plus':
        DeepLabV3Plus(encoder_name=cfg['encoder_name'],
                      encoder_weights=cfg['encoder_weights'],
                      classes=num_classes,
                      activation=activation,
                      aux_params=aux_params)
    }
    assert cfg['architecture'] in models.keys()
    model = models[cfg['architecture']]

    # configure loss
    losses = {
        'dice_loss':
        DiceLoss(include_background=background,
                 softmax=softmax,
                 sigmoid=sigmoid,
                 batch=cfg['combine']),
        'generalized_dice':
        GeneralizedDiceLoss(include_background=background,
                            softmax=softmax,
                            sigmoid=sigmoid,
                            batch=cfg['combine'])
    }
    assert cfg['loss'] in losses.keys()
    loss = losses[cfg['loss']]

    # configure optimizer
    optimizers = {
        'adam': Adam([dict(params=model.parameters(), lr=cfg['lr'])]),
        'adamw': AdamW([dict(params=model.parameters(), lr=cfg['lr'])]),
        'rmsprop': RMSprop([dict(params=model.parameters(), lr=cfg['lr'])])
    }
    assert cfg['optimizer'] in optimizers.keys()
    optimizer = optimizers[cfg['optimizer']]

    # configure metrics
    metrics = {
        'dice_score':
        DiceMetric(include_background=background, reduction='mean'),
        'dice_smp':
        Fscore(threshold=cfg['rounding'],
               ignore_channels=cfg['ignore_channels']),
        'iou_smp':
        IoU(threshold=cfg['rounding'],
            ignore_channels=cfg['ignore_channels']),
        'generalized_dice':
        GeneralizedDiceLoss(include_background=background,
                            softmax=softmax,
                            sigmoid=sigmoid,
                            batch=cfg['combine']),
        'dice_loss':
        DiceLoss(include_background=background,
                 softmax=softmax,
                 sigmoid=sigmoid,
                 batch=cfg['combine']),
        'cross_entropy':
        BCELoss(reduction='mean'),
        'accuracy':
        Accuracy(ignore_channels=cfg['ignore_channels'])
    }
    assert all(m['name'] in metrics.keys() for m in cfg['metrics'])
    metrics = [(metrics[m['name']], m['name'], m['type'])
               for m in cfg['metrics']]  # tuple of (metric, name, type)
    # TODO: Fix metric names

    # configure scheduler
    schedulers = {
        'steplr': StepLR(optimizer, step_size=cfg['step_size'], gamma=0.5),
        'cosine': CosineAnnealingLR(optimizer,
                                    cfg['epochs'],
                                    eta_min=cfg['eta_min'],
                                    last_epoch=-1)
    }
    assert cfg['scheduler'] in schedulers.keys()
    scheduler = schedulers[cfg['scheduler']]

    # configure augmentations
    train_transform = load_train_transform(transform_type=cfg['transform'],
                                           patch_size=cfg['patch_size_train'])
    valid_transform = load_valid_transform(
        patch_size=cfg['patch_size_valid'])  # manually selected patch size

    train_dataset = ArtifactDataset(df_path=cfg['train_data'],
                                    classes=cfg['classes'],
                                    transform=train_transform,
                                    normalize=cfg['normalize'],
                                    ink_filters=cfg['ink_filters'])
    valid_dataset = ArtifactDataset(df_path=cfg['valid_data'],
                                    classes=cfg['classes'],
                                    transform=valid_transform,
                                    normalize=cfg['normalize'],
                                    ink_filters=cfg['ink_filters'])
    test_dataset = ArtifactDataset(df_path=cfg['test_data'],
                                   classes=cfg['classes'],
                                   transform=valid_transform,
                                   normalize=cfg['normalize'],
                                   ink_filters=cfg['ink_filters'])

    # load pre-sampled patch arrays
    train_image, train_mask = train_dataset[0]
    valid_image, valid_mask = valid_dataset[0]
    print('Shape of image patch', train_image.shape)
    print('Shape of mask patch', train_mask.shape)
    print('Train dataset shape:', len(train_dataset))
    print('Valid dataset shape:', len(valid_dataset))
    assert train_image.shape[1] == cfg[
        'patch_size_train'] and train_image.shape[2] == cfg['patch_size_train']
    assert valid_image.shape[1] == cfg[
        'patch_size_valid'] and valid_image.shape[2] == cfg['patch_size_valid']

    # save intermediate augmentations
    if cfg['eval_dir']:
        default_dataset = ArtifactDataset(df_path=cfg['train_data'],
                                          classes=cfg['classes'],
                                          transform=None,
                                          normalize=None,
                                          ink_filters=cfg['ink_filters'])
        transform_dataset = ArtifactDataset(df_path=cfg['train_data'],
                                            classes=cfg['classes'],
                                            transform=train_transform,
                                            normalize=None,
                                            ink_filters=cfg['ink_filters'])
        for idx in range(0, min(500, len(train_dataset)), 10):
            image_input, image_mask = default_dataset[idx]
            image_input = image_input.transpose((1, 2, 0)).astype(np.uint8)
            image_mask = image_mask.transpose(1, 2, 0)
            image_mask = np.argmax(
                image_mask, axis=2) if not binary else image_mask.squeeze()
            image_mask = image_mask.astype(np.uint8)
            image_transform, _ = transform_dataset[idx]
            image_transform = image_transform.transpose(
                (1, 2, 0)).astype(np.uint8)
            idx_str = str(idx).zfill(3)
            skimage.io.imsave(os.path.join(cfg['eval_dir'],
                                           f'{idx_str}a_image_input.png'),
                              image_input,
                              check_contrast=False)
            plt.imsave(os.path.join(cfg['eval_dir'],
                                    f'{idx_str}b_image_mask.png'),
                       image_mask,
                       vmin=0,
                       vmax=6,
                       cmap='Spectral')
            skimage.io.imsave(os.path.join(cfg['eval_dir'],
                                           f'{idx_str}c_image_transform.png'),
                              image_transform,
                              check_contrast=False)
        del transform_dataset

    # update process
    print('Starting training...')
    print('Available GPUs for training:', torch.cuda.device_count())

    # pytorch module wrapper: forwards unknown attributes to the wrapped module
    class DataParallelModule(torch.nn.DataParallel):
        def __getattr__(self, name):
            try:
                return super().__getattr__(name)
            except AttributeError:
                return getattr(self.module, name)

    # data parallel training
    if torch.cuda.device_count() > 1:
        model = DataParallelModule(model)

    train_loader = DataLoader(train_dataset,
                              batch_size=cfg['batch_size'],
                              num_workers=cfg['workers'],
                              shuffle=True)
    valid_loader = DataLoader(valid_dataset,
                              batch_size=int(cfg['batch_size'] / 4),
                              num_workers=cfg['workers'],
                              shuffle=False)
    test_loader = DataLoader(test_dataset,
                             batch_size=int(cfg['batch_size'] / 4),
                             num_workers=cfg['workers'],
                             shuffle=False)

    trainer = Trainer(model=model,
                      device=cfg['device'],
                      save_checkpoints=cfg['save_checkpoints'],
                      checkpoint_dir=cfg['checkpoint_dir'],
                      checkpoint_name=cfg['checkpoint_name'])
    trainer.compile(optimizer=optimizer,
                    loss=loss,
                    metrics=metrics,
                    num_classes=num_classes)
    trainer.fit(train_loader,
                valid_loader,
                epochs=cfg['epochs'],
                scheduler=scheduler,
                verbose=cfg['verbose'],
                loss_weight=cfg['loss_weight'],
                test_loader=test_loader,
                binary=binary)

    # validation inference: reload the best checkpoint saved during fit
    model.load_state_dict(
        torch.load(os.path.join(cfg['checkpoint_dir'],
                                cfg['checkpoint_name'])))
    model.to(cfg['device'])
    model.eval()

    # save best checkpoint to neptune
    neptune.log_artifact(
        os.path.join(cfg['checkpoint_dir'], cfg['checkpoint_name']))

    # setup directory to save plots
    if os.path.isdir(cfg['plot_dir_valid']):
        shutil.rmtree(cfg['plot_dir_valid'])
    os.makedirs(cfg['plot_dir_valid'], exist_ok=True)

    # valid dataset without transformations and normalization for image visualization
    valid_dataset_vis = ArtifactDataset(df_path=cfg['valid_data'],
                                        classes=cfg['classes'],
                                        ink_filters=cfg['ink_filters'])

    # keep track of valid masks
    valid_preds = []
    valid_masks = []
    if cfg['save_checkpoints']:
        print('Predicting valid patches...')
        for n in range(len(valid_dataset)):
            image_vis = valid_dataset_vis[n][0].astype('uint8')
            image_vis = image_vis.transpose(1, 2, 0)
            image, gt_mask = valid_dataset[n]
            gt_mask = gt_mask.transpose(1, 2, 0)
            gt_mask = np.argmax(gt_mask,
                                axis=2) if not binary else gt_mask.squeeze()
            gt_mask = gt_mask.astype(np.uint8)
            valid_masks.append(gt_mask)
            x_tensor = torch.from_numpy(image).to(cfg['device']).unsqueeze(0)
            pr_mask, _ = model.predict(x_tensor)
            pr_mask = pr_mask.squeeze(axis=0).cpu().numpy().round()
            pr_mask = pr_mask.transpose(1, 2, 0)
            pr_mask = np.argmax(pr_mask,
                                axis=2) if not binary else pr_mask.squeeze()
            pr_mask = pr_mask.astype(np.uint8)
            valid_preds.append(pr_mask)
            save_predictions(out_path=cfg['plot_dir_valid'],
                             index=n + 1,
                             image=image_vis,
                             ground_truth_mask=gt_mask,
                             predicted_mask=pr_mask)
        del train_dataset, valid_dataset
        del train_loader, valid_loader

        # calculate dice per class
        valid_masks = np.stack(valid_masks, axis=0)
        valid_masks = valid_masks.flatten()
        valid_preds = np.stack(valid_preds, axis=0)
        valid_preds = valid_preds.flatten()
        dice_score = f1_score(y_true=valid_masks,
                              y_pred=valid_preds,
                              average=None)
        neptune.log_text('valid_dice_class', str(dice_score))
        print('Valid dice score (class):', str(dice_score))

    if cfg['evaluate_test_set']:
        print('Predicting test patches...')
        # setup directory to save plots
        if os.path.isdir(cfg['plot_dir_test']):
            shutil.rmtree(cfg['plot_dir_test'])
        os.makedirs(cfg['plot_dir_test'], exist_ok=True)

        # test dataset without transformations and normalization for image visualization
        test_dataset_vis = ArtifactDataset(df_path=cfg['test_data'],
                                           classes=cfg['classes'],
                                           ink_filters=cfg['ink_filters'])
        # keep track of test masks
        test_masks = []
        test_preds = []
        for n in range(len(test_dataset)):
            image_vis = test_dataset_vis[n][0].astype('uint8')
            image_vis = image_vis.transpose(1, 2, 0)
            image, gt_mask = test_dataset[n]
            gt_mask = gt_mask.transpose(1, 2, 0)
            gt_mask = np.argmax(gt_mask,
                                axis=2) if not binary else gt_mask.squeeze()
            gt_mask = gt_mask.astype(np.uint8)
            test_masks.append(gt_mask)
            x_tensor = torch.from_numpy(image).to(cfg['device']).unsqueeze(0)
            pr_mask, _ = model.predict(x_tensor)
            pr_mask = pr_mask.squeeze(axis=0).cpu().numpy().round()
            pr_mask = pr_mask.transpose(1, 2, 0)
            pr_mask = np.argmax(pr_mask,
                                axis=2) if not binary else pr_mask.squeeze()
            pr_mask = pr_mask.astype(np.uint8)
            test_preds.append(pr_mask)
            save_predictions(out_path=cfg['plot_dir_test'],
                             index=n + 1,
                             image=image_vis,
                             ground_truth_mask=gt_mask,
                             predicted_mask=pr_mask)

        # calculate dice per class
        test_masks = np.stack(test_masks, axis=0)
        test_masks = test_masks.flatten()
        test_preds = np.stack(test_preds, axis=0)
        test_preds = test_preds.flatten()
        dice_score = f1_score(y_true=test_masks,
                              y_pred=test_preds,
                              average=None)
        # fix: was str({dice_score}), which wrapped the array in a one-element
        # set before stringifying; log the score itself like the valid branch
        neptune.log_text('test_dice_class', str(dice_score))
        print('Test dice score (class):', str(dice_score))

    # end of training process
    print('Finished training!')
def test_shape(self, input_param, input_data, expected_val):
    """Parameterized check that DiceLoss output matches the expected value (rtol 1e-5)."""
    criterion = DiceLoss(**input_param)
    computed = criterion.forward(**input_data)
    np.testing.assert_allclose(computed.detach().cpu().numpy(),
                               expected_val,
                               rtol=1e-5)
def test_test_time_augmentation(self): input_size = (20, 40) # test different input data shape to pad list collate keys = ["image", "label"] num_training_ims = 10 train_data = self.get_data(num_training_ims, input_size) test_data = self.get_data(1, input_size) device = "cuda" if torch.cuda.is_available() else "cpu" transforms = Compose( [ AddChanneld(keys), RandAffined( keys, prob=1.0, spatial_size=(30, 30), rotate_range=(np.pi / 3, np.pi / 3), translate_range=(3, 3), scale_range=((0.8, 1), (0.8, 1)), padding_mode="zeros", mode=("bilinear", "nearest"), as_tensor_output=False, ), CropForegroundd(keys, source_key="image"), DivisiblePadd(keys, 4), ] ) train_ds = CacheDataset(train_data, transforms) # output might be different size, so pad so that they match train_loader = DataLoader(train_ds, batch_size=2, collate_fn=pad_list_data_collate) model = UNet(2, 1, 1, channels=(6, 6), strides=(2, 2)).to(device) loss_function = DiceLoss(sigmoid=True) optimizer = torch.optim.Adam(model.parameters(), 1e-3) num_epochs = 10 for _ in trange(num_epochs): epoch_loss = 0 for batch_data in train_loader: inputs, labels = batch_data["image"].to(device), batch_data["label"].to(device) optimizer.zero_grad() outputs = model(inputs) loss = loss_function(outputs, labels) loss.backward() optimizer.step() epoch_loss += loss.item() epoch_loss /= len(train_loader) post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)]) tt_aug = TestTimeAugmentation( transform=transforms, batch_size=5, num_workers=0, inferrer_fn=model, device=device, to_tensor=True, output_device="cpu", post_func=post_trans, ) mode, mean, std, vvc = tt_aug(test_data) self.assertEqual(mode.shape, (1,) + input_size) self.assertEqual(mean.shape, (1,) + input_size) self.assertTrue(all(np.unique(mode) == (0, 1))) self.assertGreaterEqual(mean.min(), 0.0) self.assertLessEqual(mean.max(), 1.0) self.assertEqual(std.shape, (1,) + input_size) self.assertIsInstance(vvc, float)
def configure(self):
    """Build the network, data pipelines, and MONAI train/eval engines for 3D segmentation."""
    self.set_device()
    network = UNet(
        dimensions=3,
        in_channels=1,
        out_channels=2,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
        norm=Norm.BATCH,
    ).to(self.device)
    if self.multi_gpu:
        network = DistributedDataParallel(
            module=network,
            device_ids=[self.device],
            find_unused_parameters=False,
        )

    # training pipeline: intensity scaling plus random pos/neg patch sampling
    train_transforms = Compose([
        LoadImaged(keys=("image", "label")),
        EnsureChannelFirstd(keys=("image", "label")),
        Spacingd(keys=("image", "label"),
                 pixdim=[1.0, 1.0, 1.0],
                 mode=["bilinear", "nearest"]),
        ScaleIntensityRanged(
            keys="image",
            a_min=-57,
            a_max=164,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        CropForegroundd(keys=("image", "label"), source_key="image"),
        RandCropByPosNegLabeld(
            keys=("image", "label"),
            label_key="label",
            spatial_size=(96, 96, 96),
            pos=1,
            neg=1,
            num_samples=4,
            image_key="image",
            image_threshold=0,
        ),
        RandShiftIntensityd(keys="image", offsets=0.1, prob=0.5),
        ToTensord(keys=("image", "label")),
    ])
    train_datalist = load_decathlon_datalist(self.data_list_file_path, True,
                                             "training")
    if self.multi_gpu:
        # shard the data list so each rank sees a distinct, even partition
        train_datalist = partition_dataset(
            data=train_datalist,
            shuffle=True,
            num_partitions=dist.get_world_size(),
            even_divisible=True,
        )[dist.get_rank()]
    train_ds = CacheDataset(
        data=train_datalist,
        transform=train_transforms,
        cache_num=32,
        cache_rate=1.0,
        num_workers=4,
    )
    train_data_loader = DataLoader(
        train_ds,
        batch_size=2,
        shuffle=True,
        num_workers=4,
    )

    # validation pipeline: deterministic, no random cropping
    val_transforms = Compose([
        LoadImaged(keys=("image", "label")),
        EnsureChannelFirstd(keys=("image", "label")),
        ScaleIntensityRanged(
            keys="image",
            a_min=-57,
            a_max=164,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        CropForegroundd(keys=("image", "label"), source_key="image"),
        ToTensord(keys=("image", "label")),
    ])
    val_datalist = load_decathlon_datalist(self.data_list_file_path, True,
                                           "validation")
    val_ds = CacheDataset(val_datalist, val_transforms, 9, 0.0, 4)
    val_data_loader = DataLoader(
        val_ds,
        batch_size=1,
        shuffle=False,
        num_workers=4,
    )
    post_transform = Compose([
        Activationsd(keys="pred", softmax=True),
        AsDiscreted(
            keys=["pred", "label"],
            argmax=[True, False],
            to_onehot=True,
            n_classes=2,
        ),
    ])

    # metric
    key_val_metric = {
        "val_mean_dice":
        MeanDice(
            include_background=False,
            output_transform=lambda x: (x["pred"], x["label"]),
            device=self.device,
        )
    }
    val_handlers = [
        StatsHandler(output_transform=lambda x: None),
        CheckpointSaver(
            save_dir=self.ckpt_dir,
            save_dict={"model": network},
            save_key_metric=True,
        ),
        TensorBoardStatsHandler(log_dir=self.ckpt_dir,
                                output_transform=lambda x: None),
    ]
    self.eval_engine = SupervisedEvaluator(
        device=self.device,
        val_data_loader=val_data_loader,
        network=network,
        inferer=SlidingWindowInferer(
            roi_size=[160, 160, 160],
            sw_batch_size=4,
            overlap=0.5,
        ),
        post_transform=post_transform,
        key_val_metric=key_val_metric,
        val_handlers=val_handlers,
        amp=self.amp,
    )

    optimizer = torch.optim.Adam(network.parameters(), self.learning_rate)
    loss_function = DiceLoss(to_onehot_y=True, softmax=True)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                   step_size=5000,
                                                   gamma=0.1)
    train_handlers = [
        LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True),
        ValidationHandler(validator=self.eval_engine,
                          interval=self.val_interval,
                          epoch_level=True),
        StatsHandler(tag_name="train_loss",
                     output_transform=lambda x: x["loss"]),
        TensorBoardStatsHandler(
            log_dir=self.ckpt_dir,
            tag_name="train_loss",
            output_transform=lambda x: x["loss"],
        ),
    ]
    self.train_engine = SupervisedTrainer(
        device=self.device,
        max_epochs=self.max_epochs,
        train_data_loader=train_data_loader,
        network=network,
        optimizer=optimizer,
        loss_function=loss_function,
        inferer=SimpleInferer(),
        post_transform=post_transform,
        key_train_metric=None,
        train_handlers=train_handlers,
        amp=self.amp,
    )

    # quiet the non-zero ranks so only rank 0 logs at INFO level
    if self.local_rank > 0:
        self.train_engine.logger.setLevel(logging.WARNING)
        self.eval_engine.logger.setLevel(logging.WARNING)
def __init__(self, loss):
    """Wrap a base loss together with a focal term and two dice variants."""
    super().__init__()
    self.loss = loss
    self.focal = FocalLoss(gamma=2.0)
    # batched one-hot dice, computed both with and without the background channel
    shared = dict(softmax=True, to_onehot_y=True, batch=True)
    self.dice_bg = DiceLoss(include_background=True, **shared)
    self.dice_nbg = DiceLoss(include_background=False, **shared)
def create_trainer(args):
    """Assemble a MONAI SupervisedTrainer (with interaction-based iterations) from CLI args."""
    set_determinism(seed=args.seed)

    multi_gpu = args.multi_gpu
    local_rank = args.local_rank
    if multi_gpu:
        # one process per GPU; NCCL backend configured via environment variables
        dist.init_process_group(backend="nccl", init_method="env://")
        device = torch.device("cuda:{}".format(local_rank))
        torch.cuda.set_device(device)
    else:
        device = torch.device("cuda" if args.use_gpu else "cpu")

    pre_transforms = get_pre_transforms(args.roi_size, args.model_size,
                                        args.dimensions)
    click_transforms = get_click_transforms()
    post_transform = get_post_transforms()

    train_loader, val_loader = get_loaders(args, pre_transforms)

    # define training components
    network = get_network(args.network, args.channels,
                          args.dimensions).to(device)
    if multi_gpu:
        network = torch.nn.parallel.DistributedDataParallel(
            network, device_ids=[local_rank], output_device=local_rank)

    if args.resume:
        logging.info('{}:: Loading Network...'.format(local_rank))
        # remap checkpoint tensors saved on cuda:0 onto this rank's device
        map_location = {"cuda:0": "cuda:{}".format(local_rank)}
        network.load_state_dict(
            torch.load(args.model_filepath, map_location=map_location))

    # define event-handlers for engine
    val_handlers = [
        StatsHandler(output_transform=lambda x: None),
        TensorBoardStatsHandler(log_dir=args.output,
                                output_transform=lambda x: None),
        DeepgrowStatsHandler(log_dir=args.output,
                             tag_name='val_dice',
                             image_interval=args.image_interval),
        CheckpointSaver(save_dir=args.output,
                        save_dict={"net": network},
                        save_key_metric=True,
                        save_final=True,
                        save_interval=args.save_interval,
                        final_filename='model.pt')
    ]
    # only rank 0 runs validation handlers
    val_handlers = val_handlers if local_rank == 0 else None

    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=network,
        iteration_update=Interaction(
            transforms=click_transforms,
            max_interactions=args.max_val_interactions,
            key_probability='probability',
            train=False),
        inferer=SimpleInferer(),
        post_transform=post_transform,
        key_val_metric={
            "val_dice":
            MeanDice(include_background=False,
                     output_transform=lambda x: (x["pred"], x["label"]))
        },
        val_handlers=val_handlers)

    loss_function = DiceLoss(sigmoid=True, squared_pred=True)
    optimizer = torch.optim.Adam(network.parameters(), args.learning_rate)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                   step_size=5000,
                                                   gamma=0.1)

    train_handlers = [
        LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True),
        ValidationHandler(validator=evaluator,
                          interval=args.val_freq,
                          epoch_level=True),
        StatsHandler(tag_name="train_loss",
                     output_transform=lambda x: x["loss"]),
        TensorBoardStatsHandler(log_dir=args.output,
                                tag_name="train_loss",
                                output_transform=lambda x: x["loss"]),
        CheckpointSaver(save_dir=args.output,
                        save_dict={
                            "net": network,
                            "opt": optimizer,
                            "lr": lr_scheduler
                        },
                        save_interval=args.save_interval * 2,
                        save_final=True,
                        final_filename='checkpoint.pt'),
    ]
    # non-zero ranks keep only the scheduler/validation handlers (no logging/saving)
    train_handlers = train_handlers if local_rank == 0 else train_handlers[:2]

    trainer = SupervisedTrainer(
        device=device,
        max_epochs=args.epochs,
        train_data_loader=train_loader,
        network=network,
        iteration_update=Interaction(
            transforms=click_transforms,
            max_interactions=args.max_train_interactions,
            key_probability='probability',
            train=True),
        optimizer=optimizer,
        loss_function=loss_function,
        inferer=SimpleInferer(),
        post_transform=post_transform,
        amp=args.amp,
        key_train_metric={
            "train_dice":
            MeanDice(include_background=False,
                     output_transform=lambda x: (x["pred"], x["label"]))
        },
        train_handlers=train_handlers,
    )
    return trainer
def loss_function(self, context: Context):
    """Return a sigmoid-activated Dice loss with squared predictions in the denominator."""
    options = {"sigmoid": True, "squared_pred": True}
    return DiceLoss(**options)
def train(n_feat, crop_size, bs, ep, optimizer="rmsprop", lr=5e-4, pretrain=None):
    """Train the GPipe-partitioned UNet on the HaN dataset.

    Args:
        n_feat: base number of feature channels for ``UNetPipe``.
        crop_size: comma-separated spatial crop, e.g. ``"96,96,96"``; any ``-1``
            component switches to full-image training.
        bs: batch size (window samples per batch in patch mode).
        ep: number of training epochs.
        optimizer: ``"rmsprop"`` or ``"momentum"``.
        lr: learning rate.
        pretrain: optional checkpoint path; matching weights are loaded.

    Side effects: prints progress, saves per-class best checkpoints named
    ``{model_name}{c}``. Requires CUDA (GPipe partitions across visible GPUs).
    """
    model_name = f"./HaN_{n_feat}_{bs}_{ep}_{crop_size}_{lr}_"
    print(f"save the best model as '{model_name}' during training.")

    crop_size = [int(cz) for cz in crop_size.split(",")]
    print(f"input image crop_size: {crop_size}")

    # starting training set loader
    train_images = ImageLabelDataset(path=TRAIN_PATH, n_class=N_CLASSES)
    if np.any([cz == -1 for cz in crop_size]):  # using full image
        train_transform = Compose([
            AddChannelDict(keys="image"),
            Rand3DElasticd(
                keys=("image", "label"),
                spatial_size=crop_size,
                sigma_range=(10, 50),  # 30
                magnitude_range=(600, 1200),  # 1000
                prob=0.8,
                rotate_range=(np.pi / 12, np.pi / 12, np.pi / 12),
                shear_range=(np.pi / 18, np.pi / 18, np.pi / 18),
                translate_range=tuple(sz * 0.05 for sz in crop_size),
                scale_range=(0.2, 0.2, 0.2),
                mode=("bilinear", "nearest"),
                padding_mode=("border", "zeros"),
            ),
        ])
        train_dataset = Dataset(train_images, transform=train_transform)
        # when bs > 1, the loader assumes that the full image sizes are the same across the dataset
        train_dataloader = torch.utils.data.DataLoader(train_dataset,
                                                       num_workers=4,
                                                       batch_size=bs,
                                                       shuffle=True)
    else:
        # draw balanced foreground/background window samples according to the ground truth label
        train_transform = Compose([
            AddChannelDict(keys="image"),
            SpatialPadd(keys=("image", "label"),
                        spatial_size=crop_size),  # ensure image size >= crop_size
            RandCropByPosNegLabeld(keys=("image", "label"),
                                   label_key="label",
                                   spatial_size=crop_size,
                                   num_samples=bs),
            Rand3DElasticd(
                keys=("image", "label"),
                spatial_size=crop_size,
                sigma_range=(10, 50),  # 30
                magnitude_range=(600, 1200),  # 1000
                prob=0.8,
                rotate_range=(np.pi / 12, np.pi / 12, np.pi / 12),
                shear_range=(np.pi / 18, np.pi / 18, np.pi / 18),
                translate_range=tuple(sz * 0.05 for sz in crop_size),
                scale_range=(0.2, 0.2, 0.2),
                mode=("bilinear", "nearest"),
                padding_mode=("border", "zeros"),
            ),
        ])
        train_dataset = Dataset(train_images, transform=train_transform)  # each dataset item is a list of windows
        train_dataloader = torch.utils.data.DataLoader(  # stack each dataset item into a single tensor
            train_dataset,
            num_workers=4,
            batch_size=1,
            shuffle=True,
            collate_fn=list_data_collate)
    first_sample = first(train_dataloader)
    print(first_sample["image"].shape)

    # starting validation set loader
    val_transform = Compose([AddChannelDict(keys="image")])
    val_dataset = Dataset(ImageLabelDataset(VAL_PATH, n_class=N_CLASSES),
                          transform=val_transform)
    val_dataloader = torch.utils.data.DataLoader(val_dataset,
                                                 num_workers=1,
                                                 batch_size=1)
    print(val_dataset[0]["image"].shape)
    print(
        f"training images: {len(train_dataloader)}, validation images: {len(val_dataloader)}"
    )

    model = UNetPipe(spatial_dims=3, in_channels=1, out_channels=N_CLASSES, n_feat=n_feat)
    model = flatten_sequential(model)
    # per-class loss weights (foreground organs weighted higher)
    lossweight = torch.from_numpy(
        np.array([2.22, 1.31, 1.99, 1.13, 1.93, 1.93, 1.0, 1.0, 1.90, 1.98],
                 np.float32))

    if optimizer.lower() == "rmsprop":
        optimizer = torch.optim.RMSprop(model.parameters(), lr=lr)  # lr = 5e-4
    elif optimizer.lower() == "momentum":
        optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)  # lr = 1e-4 for finetuning
    else:
        raise ValueError(
            f"Unknown optimizer type {optimizer}. (options are 'rmsprop' and 'momentum')."
        )

    # config GPipe: balance the flattened model across all visible GPUs
    x = first_sample["image"].float()
    x = torch.autograd.Variable(x.cuda())
    partitions = torch.cuda.device_count()
    print(f"partition: {partitions}, input: {x.size()}")
    balance = balance_by_size(partitions, model, x)
    model = GPipe(model, balance, chunks=4, checkpoint="always")

    # config loss functions
    dice_loss_func = DiceLoss(softmax=True, reduction="none")
    # use the same pipeline and loss in
    # AnatomyNet: Deep learning for fast and fully automated whole-volume segmentation of head and neck anatomy,
    # Medical Physics, 2018.
    focal_loss_func = FocalLoss(reduction="none")

    if pretrain:
        print(f"loading from {pretrain}.")
        pretrained_dict = torch.load(pretrain)["weight"]
        model_dict = model.state_dict()
        # keep only checkpoint entries that exist in the current model
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items() if k in model_dict
        }
        model_dict.update(pretrained_dict)
        # BUGFIX: load the merged state dict, not the filtered partial one;
        # loading `pretrained_dict` directly raises on any missing key and
        # makes the filtering above pointless.
        model.load_state_dict(model_dict)

    b_time = time.time()
    best_val_loss = [0] * (N_CLASSES - 1)  # foreground
    for epoch in range(ep):
        model.train()
        trainloss = 0
        for b_idx, data_dict in enumerate(train_dataloader):
            x_train = data_dict["image"]
            y_train = data_dict["label"]
            flagvec = data_dict["with_complete_groundtruth"]

            x_train = torch.autograd.Variable(x_train.cuda())
            y_train = torch.autograd.Variable(y_train.cuda().float())
            optimizer.zero_grad()
            # GPipe output lands on the last partition; gather onto device 0
            o = model(x_train).to(0, non_blocking=True).float()

            loss = (dice_loss_func(o, y_train.to(o)) * flagvec.to(o) *
                    lossweight.to(o)).mean()
            loss += 0.5 * (focal_loss_func(o, y_train.to(o)) * flagvec.to(o) *
                           lossweight.to(o)).mean()
            loss.backward()
            optimizer.step()
            trainloss += loss.item()

            if b_idx % 20 == 0:
                print(
                    f"Train Epoch: {epoch} [{b_idx}/{len(train_dataloader)}] \tLoss: {loss.item()}"
                )
        print(f"epoch {epoch} TRAIN loss {trainloss / len(train_dataloader)}")

        if epoch % 10 == 0:
            model.eval()
            # check validation dice; NaN scores (l != l) are skipped per class
            val_loss = [0] * (N_CLASSES - 1)
            n_val = [0] * (N_CLASSES - 1)
            for data_dict in val_dataloader:
                x_val = data_dict["image"]
                y_val = data_dict["label"]
                with torch.no_grad():
                    x_val = torch.autograd.Variable(x_val.cuda())
                    o = model(x_val).to(0, non_blocking=True)
                loss = compute_meandice(o,
                                        y_val.to(o),
                                        mutually_exclusive=True,
                                        include_background=False)
                val_loss = [
                    l.item() + tl if l == l else tl
                    for l, tl in zip(loss[0], val_loss)
                ]
                n_val = [
                    n + 1 if l == l else n for l, n in zip(loss[0], n_val)
                ]
            val_loss = [l / n for l, n in zip(val_loss, n_val)]
            print(
                "validation scores %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f"
                % tuple(val_loss))
            # save a checkpoint per class whenever that class improves
            for c in range(1, 10):
                if best_val_loss[c - 1] < val_loss[c - 1]:
                    best_val_loss[c - 1] = val_loss[c - 1]
                    state = {
                        "epoch": epoch,
                        "weight": model.state_dict(),
                        "score_" + str(c): best_val_loss[c - 1]
                    }
                    torch.save(state, f"{model_name}" + str(c))
            print(
                "best validation scores %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f"
                % tuple(best_val_loss))

    print("total time", time.time() - b_time)
def train_process(fast=False):
    """Run a 10-epoch UNet training/validation loop on the module-level file lists.

    Uses global ``train_files``/``val_files`` and ``transformations()``. Saves the
    best model to 'sLUMRTL644.pth' and returns
    ``(epoch_num, epoch_loss_values, metric_values, best_metrics_epochs_and_time)``.

    NOTE(review): the ``fast`` parameter is unused in the visible body, and the
    third sub-list of ``best_metrics_epochs_and_time`` is never appended to —
    presumably timing was intended; confirm against callers.
    """
    epoch_num = 10
    val_interval = 1

    train_trans, val_trans = transformations()
    train_ds = Dataset(data=train_files, transform=train_trans)
    val_ds = Dataset(data=val_files, transform=val_trans)

    train_loader = DataLoader(train_ds, batch_size=2, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=1)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n1 = 16
    model = UNet(dimensions=3,
                 in_channels=1,
                 out_channels=2,
                 channels=(n1 * 1, n1 * 2, n1 * 4, n1 * 8, n1 * 16),
                 strides=(2, 2, 2, 2)).to(device)

    loss_function = DiceLoss(to_onehot_y=True, softmax=True)
    # post-processing: argmax prediction / label into 2-class one-hot for mean dice
    post_pred = AsDiscrete(argmax=True, to_onehot=True, n_classes=2)
    post_label = AsDiscrete(to_onehot=True, n_classes=2)
    optimizer = torch.optim.Adam(model.parameters(), 1e-4, weight_decay=1e-5)

    best_metric = -1
    best_metric_epoch = -1
    best_metrics_epochs_and_time = [[], [], []]
    epoch_loss_values = list()
    metric_values = list()

    for epoch in range(epoch_num):
        print(f"epoch {epoch + 1}/{epoch_num}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data['image'].to(
                device), batch_data['label'].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            epoch_len = math.ceil(len(train_ds) / train_loader.batch_size)
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                metric_sum = 0.
                metric_count = 0
                for val_data in val_loader:
                    val_inputs, val_labels = val_data['image'].to(
                        device), val_data['label'].to(device)
                    val_outputs = model(val_inputs)
                    val_outputs = post_pred(val_outputs)
                    val_labels = post_label(val_labels)
                    value = compute_meandice(y_pred=val_outputs,
                                             y=val_labels,
                                             include_background=False)
                    metric_count += len(value)
                    metric_sum += value.sum().item()
                metric = metric_sum / metric_count
                metric_values.append(metric)
                if metric > best_metric:
                    best_metric = metric
                    epochs_no_improve = 0
                    best_metric_epoch = epoch + 1
                    best_metrics_epochs_and_time[0].append(best_metric)
                    best_metrics_epochs_and_time[1].append(best_metric_epoch)
                    torch.save(model.state_dict(), 'sLUMRTL644.pth')
                else:
                    # NOTE(review): ``epochs_no_improve`` is first assigned in the
                    # improvement branch above; this relies on the first validation
                    # always beating best_metric = -1 (true for non-negative dice,
                    # fragile if the metric could be NaN) — confirm.
                    epochs_no_improve += 1
            print(
                f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
                f" best mean dice: {best_metric:.4f} at epoch: {best_metric_epoch}"
            )

    print(
        f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}"
    )
    return epoch_num, epoch_loss_values, metric_values, best_metrics_epochs_and_time
def test_shape(self, input_param, input_data, expected_val):
    """Parameterized check: DiceLoss output matches the expected scalar to 5 places."""
    loss = DiceLoss(**input_param)
    computed = loss.forward(**input_data)
    self.assertAlmostEqual(computed.item(), expected_val, places=5)
def compute_from_aggregating(self,
                             input,
                             target,
                             if_path: bool,
                             type_as_tensor=None,
                             whether_to_return_img=False,
                             result: pl.EvalResult = None):
    """Run patch-based inference over a volume and aggregate a full prediction.

    When ``if_path`` is True, ``input``/``target`` are treated as file paths and
    only the aggregated prediction is produced; otherwise they are tensors and a
    per-patch Dice loss is also collected and returned (stacked).

    NOTE(review): ``result`` is unused in the visible body, and in the path
    branch ``if if_path or whether_to_return_img`` is always True (we are inside
    ``if if_path``), so the two-value return there is dead code — confirm intent.
    """
    transform = get_val_transform()
    if if_path:
        cur_img_subject = torchio.Subject(
            img=torchio.Image(input, type=torchio.INTENSITY))
        cur_label_subject = torchio.Subject(
            img=torchio.Image(target, type=torchio.LABEL))

        preprocessed_img = transform(cur_img_subject)
        preprocessed_label = transform(cur_label_subject)

        patch_overlap = self.hparams.patch_overlap  # is there any constrain?
        grid_sampler = torchio.inference.GridSampler(
            preprocessed_img,
            self.patch_size,
            patch_overlap,
        )
        patch_loader = torch.utils.data.DataLoader(grid_sampler)
        aggregator = torchio.inference.GridAggregator(grid_sampler)

        for patches_batch in patch_loader:
            input_tensor = patches_batch['img'][torchio.DATA]
            # used to convert tensor to CUDA
            input_tensor = input_tensor.type_as(type_as_tensor['val_dice'])
            locations = patches_batch[torchio.LOCATION]
            preds = self(input_tensor)  # use cuda
            labels = preds.argmax(dim=torchio.CHANNELS_DIMENSION,
                                  keepdim=True)  # use cuda
            aggregator.add_batch(labels, locations)

        output_tensor = aggregator.get_output_tensor()  # not using cuda!
        if if_path or whether_to_return_img:
            return preprocessed_img.img.data, output_tensor, preprocessed_label.img.data
        else:
            return output_tensor, preprocessed_label.img.data
    else:
        cur_subject = torchio.Subject(
            img=torchio.Image(tensor=input.squeeze(), type=torchio.INTENSITY),
            label=torchio.Image(tensor=target.squeeze(), type=torchio.LABEL))

        preprocessed_subject = transform(cur_subject)
        patch_overlap = self.hparams.patch_overlap  # is there any constrain?
        grid_sampler = torchio.inference.GridSampler(
            preprocessed_subject,
            self.patch_size,
            patch_overlap,
        )
        patch_loader = torch.utils.data.DataLoader(grid_sampler)
        aggregator = torchio.inference.GridAggregator(grid_sampler)

        dice_loss = []
        for patches_batch in patch_loader:
            input_tensor, target_tensor = patches_batch['img'][
                torchio.DATA], patches_batch['label'][torchio.DATA]
            # used to convert tensor to CUDA
            input_tensor = input_tensor.type_as(input)
            locations = patches_batch[torchio.LOCATION]
            preds_tensor = self(input_tensor)  # use cuda
            # Compute the loss here
            diceloss = DiceLoss(
                include_background=self.hparams.include_background,
                to_onehot_y=True)
            loss = diceloss.forward(input=preds_tensor, target=target_tensor)
            dice_loss.append(loss)
            labels = preds_tensor.argmax(dim=torchio.CHANNELS_DIMENSION,
                                         keepdim=True)  # use cuda
            aggregator.add_batch(labels, locations)

        output_tensor = aggregator.get_output_tensor()  # not using cuda!!!!
        if whether_to_return_img:
            return cur_subject['img'].data, output_tensor, cur_subject[
                'label'].data
        else:
            return output_tensor, cur_subject['label'].data, torch.stack(
                dice_loss)
def main(cfg):
    """Runs main training procedure.

    Builds the model/loss/optimizer/metrics/scheduler from the ``cfg`` dict,
    trains via ``Trainer``, then reloads the best checkpoint and saves
    per-sample prediction plots for the validation set.
    """
    print('Starting training...')
    print('Current working directory is:', os.getcwd())

    # fix random seeds for reproducibility
    seed_everything(seed=cfg['seed'])

    # neptune logging
    neptune.init(project_qualified_name=cfg['neptune_project_name'],
                 api_token=cfg['neptune_api_token'])
    neptune.create_experiment(name=cfg['neptune_experiment'], params=cfg)

    # single-class → sigmoid; multi-class adds a background channel and uses softmax
    num_classes = 1 if len(cfg['classes']) == 1 else (len(cfg['classes']) + 1)
    activation = 'sigmoid' if num_classes == 1 else 'softmax2d'
    background = False if cfg['ignore_channels'] else True

    aux_params = dict(
        pooling=cfg['pooling'],  # one of 'avg', 'max'
        dropout=cfg['dropout'],  # dropout ratio, default is None
        activation='sigmoid',  # activation function, default is None
        classes=num_classes)  # define number of output labels

    # configure model
    models = {
        'unet':
        Unet(encoder_name=cfg['encoder_name'],
             encoder_weights=cfg['encoder_weights'],
             decoder_use_batchnorm=cfg['use_batchnorm'],
             classes=num_classes,
             activation=activation,
             aux_params=aux_params),
        'unetplusplus':
        UnetPlusPlus(encoder_name=cfg['encoder_name'],
                     encoder_weights=cfg['encoder_weights'],
                     decoder_use_batchnorm=cfg['use_batchnorm'],
                     classes=num_classes,
                     activation=activation,
                     aux_params=aux_params),
        'deeplabv3plus':
        DeepLabV3Plus(encoder_name=cfg['encoder_name'],
                      encoder_weights=cfg['encoder_weights'],
                      classes=num_classes,
                      activation=activation,
                      aux_params=aux_params)
    }
    assert cfg['architecture'] in models.keys()
    model = models[cfg['architecture']]

    # configure loss
    losses = {
        'dice_loss':
        DiceLoss(include_background=background,
                 softmax=False,
                 batch=cfg['combine']),
        'generalized_dice':
        GeneralizedDiceLoss(include_background=background,
                            softmax=False,
                            batch=cfg['combine'])
    }
    assert cfg['loss'] in losses.keys()
    loss = losses[cfg['loss']]

    # configure optimizer
    optimizers = {
        'adam': Adam([dict(params=model.parameters(), lr=cfg['lr'])]),
        'adamw': AdamW([dict(params=model.parameters(), lr=cfg['lr'])]),
        'rmsprop': RMSprop([dict(params=model.parameters(), lr=cfg['lr'])])
    }
    assert cfg['optimizer'] in optimizers.keys()
    optimizer = optimizers[cfg['optimizer']]

    # configure metrics
    metrics = {
        'dice_score':
        DiceMetric(include_background=background, reduction='mean'),
        'dice_smp':
        Fscore(threshold=cfg['rounding'],
               ignore_channels=cfg['ignore_channels']),
        'iou_smp':
        IoU(threshold=cfg['rounding'],
            ignore_channels=cfg['ignore_channels']),
        'generalized_dice':
        GeneralizedDiceLoss(include_background=background,
                            softmax=False,
                            batch=cfg['combine']),
        'dice_loss':
        DiceLoss(include_background=background,
                 softmax=False,
                 batch=cfg['combine']),
        'cross_entropy':
        BCELoss(reduction='mean'),
        'accuracy':
        Accuracy(ignore_channels=cfg['ignore_channels'])
    }
    assert all(m['name'] in metrics.keys() for m in cfg['metrics'])
    metrics = [(metrics[m['name']], m['name'], m['type'])
               for m in cfg['metrics']]  # tuple of (metric, name, type)

    # configure scheduler
    schedulers = {
        'steplr': StepLR(optimizer, step_size=cfg['step_size'], gamma=0.5),
        'cosine': CosineAnnealingLR(optimizer,
                                    cfg['epochs'],
                                    eta_min=cfg['eta_min'],
                                    last_epoch=-1)
    }
    assert cfg['scheduler'] in schedulers.keys()
    scheduler = schedulers[cfg['scheduler']]

    # configure augmentations
    train_transform = load_train_transform(transform_type=cfg['transform'],
                                           patch_size=cfg['patch_size'])
    valid_transform = load_valid_transform(patch_size=cfg['patch_size'])

    train_dataset = SegmentationDataset(df_path=cfg['train_data'],
                                        transform=train_transform,
                                        normalize=cfg['normalize'],
                                        tissuemix=cfg['tissuemix'],
                                        probability=cfg['probability'],
                                        blending=cfg['blending'],
                                        warping=cfg['warping'],
                                        color=cfg['color'])
    valid_dataset = SegmentationDataset(df_path=cfg['valid_data'],
                                        transform=valid_transform,
                                        normalize=cfg['normalize'])

    # save intermediate augmentations
    if cfg['eval_dir']:
        default_dataset = SegmentationDataset(df_path=cfg['train_data'],
                                              transform=None,
                                              normalize=None)
        transform_dataset = SegmentationDataset(df_path=cfg['train_data'],
                                                transform=None,
                                                normalize=None,
                                                tissuemix=cfg['tissuemix'],
                                                probability=cfg['probability'],
                                                blending=cfg['blending'],
                                                warping=cfg['warping'],
                                                color=cfg['color'])
        # dump every 10th of the first 500 samples: raw input, mask, augmented input
        for idx in range(0, min(500, len(default_dataset)), 10):
            image_input, image_mask = default_dataset[idx]
            image_input = image_input.transpose((1, 2, 0))
            image_input = image_input.astype(np.uint8)

            image_mask = image_mask.transpose(
                1, 2, 0)  # Why do we need transpose here?
            image_mask = image_mask.astype(np.uint8)
            image_mask = image_mask.squeeze()
            image_mask = image_mask * 255

            image_transform, _ = transform_dataset[idx]
            image_transform = image_transform.transpose(
                (1, 2, 0)).astype(np.uint8)

            idx_str = str(idx).zfill(3)
            skimage.io.imsave(os.path.join(cfg['eval_dir'],
                                           f'{idx_str}a_image_input.png'),
                              image_input,
                              check_contrast=False)
            plt.imsave(os.path.join(cfg['eval_dir'],
                                    f'{idx_str}b_image_mask.png'),
                       image_mask,
                       vmin=0,
                       vmax=1)
            skimage.io.imsave(os.path.join(cfg['eval_dir'],
                                           f'{idx_str}c_image_transform.png'),
                              image_transform,
                              check_contrast=False)
        del transform_dataset

    train_loader = DataLoader(train_dataset,
                              batch_size=cfg['batch_size'],
                              num_workers=cfg['workers'],
                              shuffle=True)
    valid_loader = DataLoader(valid_dataset,
                              batch_size=cfg['batch_size'],
                              num_workers=cfg['workers'],
                              shuffle=False)

    trainer = Trainer(model=model,
                      device=cfg['device'],
                      save_checkpoints=cfg['save_checkpoints'],
                      checkpoint_dir=cfg['checkpoint_dir'],
                      checkpoint_name=cfg['checkpoint_name'])
    trainer.compile(optimizer=optimizer,
                    loss=loss,
                    metrics=metrics,
                    num_classes=num_classes)
    trainer.fit(train_loader,
                valid_loader,
                epochs=cfg['epochs'],
                scheduler=scheduler,
                verbose=cfg['verbose'],
                loss_weight=cfg['loss_weight'])

    # validation inference: reload the best checkpoint written by the trainer
    best_model = model
    best_model.load_state_dict(
        torch.load(os.path.join(cfg['checkpoint_dir'],
                                cfg['checkpoint_name'])))
    best_model.to(cfg['device'])
    best_model.eval()

    # setup directory to save plots
    if os.path.isdir(cfg['plot_dir']):
        # remove existing dir and content
        shutil.rmtree(cfg['plot_dir'])
    # create absolute destination
    os.makedirs(cfg['plot_dir'])

    # valid dataset without transformations and normalization for image visualization
    valid_dataset_vis = SegmentationDataset(df_path=cfg['valid_data'],
                                            transform=valid_transform,
                                            normalize=None)

    if cfg['save_checkpoints']:
        for n in range(len(valid_dataset)):
            image_vis = valid_dataset_vis[n][0].astype('uint8')
            image_vis = image_vis.transpose((1, 2, 0))
            image, gt_mask = valid_dataset[n]
            gt_mask = gt_mask.transpose((1, 2, 0))
            gt_mask = gt_mask.squeeze()

            x_tensor = torch.from_numpy(image).to(cfg['device']).unsqueeze(0)
            pr_mask, _ = best_model.predict(x_tensor)
            pr_mask = pr_mask.cpu().numpy().round()
            pr_mask = pr_mask.squeeze()

            save_predictions(out_path=cfg['plot_dir'],
                             index=n,
                             image=image_vis,
                             ground_truth_mask=gt_mask,
                             predicted_mask=pr_mask,
                             average='macro')
def main_worker(args):
    """Per-process entry point for distributed BraTS segmentation training.

    Each GPU runs one copy; NCCL process group is initialized from env vars.
    Rank 0 handles TensorBoard logging and final model saving.
    """
    # disable logging for processes except 0 on every node
    if args.local_rank != 0:
        f = open(os.devnull, "w")
        sys.stdout = sys.stderr = f
    if not os.path.exists(args.dir):
        raise FileNotFoundError(f"Missing directory {args.dir}")

    # initialize the distributed training process, every GPU runs in a process
    dist.init_process_group(backend="nccl", init_method="env://")

    total_start = time.time()
    train_transforms = Compose([
        # load 4 Nifti images and stack them together
        LoadNiftid(keys=["image", "label"]),
        AsChannelFirstd(keys="image"),
        ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
        Spacingd(keys=["image", "label"],
                 pixdim=(1.5, 1.5, 2.0),
                 mode=("bilinear", "nearest")),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        RandSpatialCropd(keys=["image", "label"],
                         roi_size=[128, 128, 64],
                         random_size=False),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
        RandScaleIntensityd(keys="image", factors=0.1, prob=0.5),
        RandShiftIntensityd(keys="image", offsets=0.1, prob=0.5),
        ToTensord(keys=["image", "label"]),
    ])

    # create a training data loader
    train_ds = BratsCacheDataset(
        root_dir=args.dir,
        transform=train_transforms,
        section="training",
        num_workers=4,
        cache_rate=args.cache_rate,
        shuffle=True,
    )
    train_loader = DataLoader(train_ds,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.workers,
                              pin_memory=True)

    # validation transforms and dataset
    val_transforms = Compose([
        LoadNiftid(keys=["image", "label"]),
        AsChannelFirstd(keys="image"),
        ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
        Spacingd(keys=["image", "label"],
                 pixdim=(1.5, 1.5, 2.0),
                 mode=("bilinear", "nearest")),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        CenterSpatialCropd(keys=["image", "label"], roi_size=[128, 128, 64]),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        ToTensord(keys=["image", "label"]),
    ])
    val_ds = BratsCacheDataset(
        root_dir=args.dir,
        transform=val_transforms,
        section="validation",
        num_workers=4,
        cache_rate=args.cache_rate,
        shuffle=False,
    )
    val_loader = DataLoader(val_ds,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.workers,
                            pin_memory=True)

    if dist.get_rank() == 0:
        # Logging for TensorBoard
        writer = SummaryWriter(log_dir=args.log_dir)

    # create UNet, DiceLoss and Adam optimizer
    device = torch.device(f"cuda:{args.local_rank}")
    if args.network == "UNet":
        model = UNet(
            dimensions=3,
            in_channels=4,
            out_channels=3,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
        ).to(device)
    else:
        model = SegResNet(in_channels=4,
                          out_channels=3,
                          init_filters=16,
                          dropout_prob=0.2).to(device)
    loss_function = DiceLoss(to_onehot_y=False,
                             sigmoid=True,
                             squared_pred=True)
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=1e-5,
                                 amsgrad=True)
    # wrap the model with DistributedDataParallel module
    model = DistributedDataParallel(model, device_ids=[args.local_rank])

    # start a typical PyTorch training
    total_epoch = args.epochs
    best_metric = -1000000
    best_metric_epoch = -1
    epoch_time = AverageMeter("Time", ":6.3f")
    progress = ProgressMeter(total_epoch, [epoch_time], prefix="Epoch: ")
    end = time.time()
    print(f"Time elapsed before training: {end-total_start}")
    for epoch in range(total_epoch):
        train_loss = train(train_loader, model, loss_function, optimizer,
                           epoch, args, device)
        epoch_time.update(time.time() - end)
        if epoch % args.print_freq == 0:
            progress.display(epoch)
        if dist.get_rank() == 0:
            writer.add_scalar("Loss/train", train_loss, epoch)
        if (epoch + 1) % args.val_interval == 0:
            metric, metric_tc, metric_wt, metric_et = evaluate(
                model, val_loader, device)
            if dist.get_rank() == 0:
                writer.add_scalar("Mean Dice/val", metric, epoch)
                writer.add_scalar("Mean Dice TC/val", metric_tc, epoch)
                writer.add_scalar("Mean Dice WT/val", metric_wt, epoch)
                writer.add_scalar("Mean Dice ET/val", metric_et, epoch)
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                print(
                    f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
                    f" tc: {metric_tc:.4f} wt: {metric_wt:.4f} et: {metric_et:.4f}"
                    f"\nbest mean dice: {best_metric:.4f} at epoch: {best_metric_epoch}"
                )
        end = time.time()
        print(f"Time elapsed after epoch {epoch + 1} is {end - total_start}")

    if dist.get_rank() == 0:
        print(
            f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}"
        )
        # all processes should see same parameters as they all start from same
        # random parameters and gradients are synchronized in backward passes,
        # therefore, saving it in one process is sufficient
        torch.save(model.state_dict(), "final_model.pth")
        writer.flush()
    dist.destroy_process_group()
def test_script(self):
    """DiceLoss should survive a TorchScript save/load round-trip."""
    data = torch.ones(2, 1, 8, 8)
    test_script_save(DiceLoss(), data, data)
def test_ill_shape(self):
    """Mismatched prediction/target shapes must raise an AssertionError."""
    pred, target = torch.ones((1, 2, 3)), torch.ones((4, 5, 6))
    with self.assertRaisesRegex(AssertionError, ""):
        DiceLoss().forward(pred, target)
# http://www.apache.org/licenses/LICENSE-2.0 # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import unittest import numpy as np import torch from parameterized import parameterized from monai.losses import DiceLoss from monai.losses.multi_scale import MultiScaleLoss dice_loss = DiceLoss(include_background=True, sigmoid=True, smooth_nr=1e-5, smooth_dr=1e-5) TEST_CASES = [ [ {"loss": dice_loss, "scales": None, "kernel": "gaussian"}, {"y_pred": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), "y_true": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]])}, 0.307576, ], [ {"loss": dice_loss, "scales": [0, 1], "kernel": "gaussian"}, {"y_pred": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), "y_true": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]])}, 0.463116, ], [ {"loss": dice_loss, "scales": [0, 1, 2], "kernel": "cauchy"}, {
def test_input_warnings(self):
    """Single-channel input should trigger a warning for each of these configurations."""
    pred = torch.ones((1, 1, 3))
    target = torch.ones((1, 1, 3))
    for kwargs in ({"include_background": False},
                   {"softmax": True},
                   {"to_onehot_y": True}):
        with self.assertWarns(Warning):
            DiceLoss(**kwargs).forward(pred, target)
def loss_function(self, context: Context):
    """Return a Dice loss that one-hot encodes targets, applies softmax, and squares predictions."""
    options = {"to_onehot_y": True, "softmax": True, "squared_pred": True}
    return DiceLoss(**options)