Example 1
 def test_ill_opts(self):
     with self.assertRaisesRegex(ValueError, ""):
         DiceLoss(sigmoid=True, softmax=True)
     chn_input = torch.ones((1, 1, 3))
     chn_target = torch.ones((1, 1, 3))
     with self.assertRaisesRegex(ValueError, ""):
         DiceLoss(reduction="unknown")(chn_input, chn_target)
     with self.assertRaisesRegex(ValueError, ""):
         DiceLoss(reduction=None)(chn_input, chn_target)
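For contrast with the failure cases above, a minimal valid configuration (a sketch, assuming the same MONAI `DiceLoss`): exactly one activation flag is set and a supported `reduction` is chosen.

    import torch
    from monai.losses import DiceLoss

    pred = torch.rand((1, 1, 3))    # raw network output
    target = torch.ones((1, 1, 3))  # binary target, same shape
    loss = DiceLoss(sigmoid=True, reduction="mean")(pred, target)  # scalar tensor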
Example 2
 def test_result_onehot_target_include_bg(self):
     size = [3, 3, 5, 5]
     label = torch.randint(low=0, high=2, size=size)
     pred = torch.randn(size)
     for reduction in ["sum", "mean", "none"]:
         common_params = {
             "include_background": True,
             "to_onehot_y": False,
             "reduction": reduction
         }
         for focal_weight in [
                 None, torch.tensor([1.0, 1.0, 2.0]), (3, 2.0, 1)
         ]:
             for lambda_focal in [0.5, 1.0, 1.5]:
                 dice_focal = DiceFocalLoss(focal_weight=focal_weight,
                                            gamma=1.0,
                                            lambda_focal=lambda_focal,
                                            **common_params)
                 dice = DiceLoss(**common_params)
                 focal = FocalLoss(weight=focal_weight,
                                   gamma=1.0,
                                   **common_params)
                 result = dice_focal(pred, label)
                 expected_val = dice(
                     pred, label) + lambda_focal * focal(pred, label)
                 np.testing.assert_allclose(result, expected_val)
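The identity this test asserts, with `lambda_dice` left at its MONAI default of 1.0:

    dice_focal(pred, label) == dice(pred, label) + lambda_focal * focal(pred, label)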
Example 3
def run_test(batch_size=64, train_steps=100, device=torch.device("cuda:0")):
    class _TestBatch(Dataset):
        def __getitem__(self, _unused_id):
            im, seg = create_test_image_2d(128,
                                           128,
                                           noise_max=1,
                                           num_objs=4,
                                           num_seg_classes=1)
            return im[None], seg[None].astype(np.float32)

        def __len__(self):
            return train_steps

    net = UNet(
        dimensions=2,
        in_channels=1,
        out_channels=1,
        channels=(4, 8, 16, 32),
        strides=(2, 2, 2),
        num_res_units=2,
    ).to(device)

    loss = DiceLoss(sigmoid=True)
    opt = torch.optim.Adam(net.parameters(), 1e-4)
    src = DataLoader(_TestBatch(), batch_size=batch_size)

    trainer = create_supervised_trainer(net, opt, loss, device, False)

    trainer.run(src, 1)
    loss = trainer.state.output
    return loss
Example 4
    def test_epistemic_scoring(self):
        input_size = (20, 20, 20)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        keys = ["image", "label"]
        num_training_ims = 10
        train_data = self.get_data(num_training_ims, input_size)
        test_data = self.get_data(1, input_size)

        transforms = Compose([
            AddChanneld(keys),
            CropForegroundd(keys, source_key="image"),
            DivisiblePadd(keys, 4),
        ])

        infer_transforms = Compose([
            AddChannel(),
            CropForeground(),
            DivisiblePad(4),
        ])

        train_ds = CacheDataset(train_data, transforms)
        # output might be different size, so pad so that they match
        train_loader = DataLoader(train_ds,
                                  batch_size=2,
                                  collate_fn=pad_list_data_collate)

        model = UNet(3, 1, 1, channels=(6, 6), strides=(2, 2)).to(device)
        loss_function = DiceLoss(sigmoid=True)
        optimizer = torch.optim.Adam(model.parameters(), 1e-3)

        num_epochs = 10
        for _ in trange(num_epochs):
            epoch_loss = 0

            for batch_data in train_loader:
                inputs, labels = batch_data["image"].to(
                    device), batch_data["label"].to(device)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = loss_function(outputs, labels)
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()

            epoch_loss /= len(train_loader)

        entropy_score = EpistemicScoring(model=model,
                                         transforms=infer_transforms,
                                         roi_size=[20, 20, 20],
                                         num_samples=10)
        # Call Individual Infer from Epistemic Scoring
        ip_stack = [test_data["image"], test_data["image"], test_data["image"]]
        ip_stack = np.array(ip_stack)
        score_3d = entropy_score.entropy_3d_volume(ip_stack)
        score_3d_sum = np.sum(score_3d)
        # Call Entropy Metric from Epistemic Scoring
        self.assertEqual(score_3d.shape, input_size)
        self.assertIsInstance(score_3d_sum, np.float32)
        self.assertGreater(score_3d_sum, 3.0)
Example 5

def run_test(batch_size=64, train_steps=200, device=torch.device("cuda:0")):
    class _TestBatch(Dataset):
        def __init__(self, transforms):
            self.transforms = transforms

        def __getitem__(self, _unused_id):
            im, seg = create_test_image_2d(128,
                                           128,
                                           noise_max=1,
                                           num_objs=4,
                                           num_seg_classes=1)
            seed = np.random.randint(2147483647)
            self.transforms.set_random_state(seed=seed)
            im = self.transforms(im)
            self.transforms.set_random_state(seed=seed)
            seg = self.transforms(seg)
            return im, seg

        def __len__(self):
            return train_steps

    net = UNet(
        dimensions=2,
        in_channels=1,
        out_channels=1,
        channels=(4, 8, 16, 32),
        strides=(2, 2, 2),
        num_res_units=2,
    ).to(device)

    loss = DiceLoss(sigmoid=True)
    opt = torch.optim.Adam(net.parameters(), 1e-2)
    train_transforms = Compose([
        AddChannel(),
        ScaleIntensity(),
        RandSpatialCrop((96, 96), random_size=False),
        RandRotate90(),
        ToTensor()
    ])

    src = DataLoader(_TestBatch(train_transforms),
                     batch_size=batch_size,
                     shuffle=True)

    net.train()
    epoch_loss = 0
    step = 0
    for img, seg in src:
        step += 1
        opt.zero_grad()
        output = net(img.to(device))
        step_loss = loss(output, seg.to(device))
        step_loss.backward()
        opt.step()
        epoch_loss += step_loss.item()
    epoch_loss /= step

    return epoch_loss, step
Example 6

 def __init__(self, focal):
     super(Loss, self).__init__()
     self.dice = DiceLoss(include_background=False,
                          softmax=True,
                          to_onehot_y=True,
                          batch=True)
     self.focal = FocalLoss(gamma=2.0)
     self.cross_entropy = nn.CrossEntropyLoss()
     self.use_focal = focal
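Only the constructor is shown; a plausible `forward` for this wrapper (an assumption, not part of the snippet) adds either the focal or the cross entropy term to the dice term:

 # hypothetical forward for the wrapper above
 def forward(self, pred, target):
     loss = self.dice(pred, target)
     if self.use_focal:
         loss = loss + self.focal(pred, target)
     else:
         # nn.CrossEntropyLoss expects class indices, so drop the channel dim
         loss = loss + self.cross_entropy(pred, target.squeeze(1).long())
     return loss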
Example 7
    def __init__(self, model: nn.Module, lr: float = 1e-4):
        """Creates a simple UNet for segmenting nodules in chest CTs. Defines its training and validation logic.

        Args:
            model (nn.Module): Model to use for training.
            lr (float, optional): Initial learning rate. Defaults to 1e-4.
        """
        super().__init__()
        self.model = model
        self.loss = DiceLoss(to_onehot_y=True, softmax=True)
        self.lr = lr
        self.save_hyperparameters()
Example 8
    def training_step(self, batch, batch_idx):
        inputs, targets = self.prepare_batch(batch)
        pred = self(inputs)
        # diceloss = DiceLoss(include_background=True, to_onehot_y=True)
        # loss = diceloss.forward(input=probs, target=targets)
        # dice, iou, _, _ = get_score(batch_preds, batch_targets, include_background=True)
        # gdloss = GeneralizedDiceLoss(include_background=True, to_onehot_y=True)
        # loss = gdloss.forward(input=batch_preds, target=batch_targets)
        # if batch_idx != 0 and ((self.current_epoch >= 1 and dice.item() < 0.5) or batch_idx % 100 == 0):
        #     input = inputs.chunk(inputs.size()[0], 0)[0]  # split into 1 in the dimension 0
        #     target = targets.chunk(targets.size()[0], 0)[0]  # split into 1 in the dimension 0
        #     prob = probs.chunk(probs.size()[0], 0)[0]  # split into 1 in the dimension 0
        #     # there is a real problem here that needs fixing
        #     dice_score, _, _, _ = get_score(torch.unsqueeze(prob, 0), torch.unsqueeze(target, 0))
        #     log_all_info(self, input, target, prob, batch_idx, "training", dice_score.item())
        # loss = F.binary_cross_entropy_with_logits(logits, targets)
        diceloss = DiceLoss(include_background=self.hparams.include_background,
                            to_onehot_y=True)
        loss = diceloss.forward(input=pred, target=targets)
        # What loss do I need to set here when using this batch size?

        # gdloss = GeneralizedDiceLoss(include_background=True, to_onehot_y=True)
        # loss = gdloss.forward(input=batch_preds, target=batch_targets)

        # I cannot use this `TrainResult` right now
        # the loss for prog_bar is not correct; did I write something wrong?
        result = pl.TrainResult(minimize=loss)
        # logs metrics for each training_step, to the progress bar and logger
        result.log("train_loss",
                   loss,
                   prog_bar=True,
                   sync_dist=True,
                   logger=True,
                   reduce_fx=torch.mean,
                   on_step=True,
                   on_epoch=False)
        # we cannot compute the metrics on the patches, because they do not contain all the 138 segmentations
        # so they would return 0 on some of the classes, making the metrics inaccurate
        return result
Example 9
test_loader = DataLoader(test_ds, batch_size=1, num_workers=0)
"""## Create Model, Loss, Optimizer"""

# standard PyTorch program style: create UNet, DiceLoss and Adam optimizer
#device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device('cpu')
model = UNet(
    dimensions=3,
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
).to(device)
loss_function = DiceLoss(to_onehot_y=True, softmax=True)
optimizer = torch.optim.Adam(model.parameters(), 1e-4)
"""## Execute a typical PyTorch training process"""

epoch_num = 2
val_interval = 2
best_metric = -1
best_metric_epoch = -1
epoch_loss_values = list()
metric_values = list()
post_pred = AsDiscrete(argmax=True, to_onehot=True, n_classes=2)
post_label = AsDiscrete(to_onehot=True, n_classes=2)

for epoch in range(epoch_num):
    print("-" * 10)
    print(f"epoch {epoch + 1}/{epoch_num}")
Example 10

val_loader = DataLoader(val_ds, batch_size=1)

# Create Model, Loss, Optimizer
# standard PyTorch program style: create UNet, DiceLoss and Adam optimizer
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = UNet(
    dimensions=3,
    in_channels=1,
    out_channels=1,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
).to(device)

loss_function = DiceLoss(sigmoid=True)
optimizer = torch.optim.Adam(model.parameters(), 1e-3)

# Execute a typical PyTorch training process

max_epochs = 50
# max_epochs = 300
val_interval = 2
best_metric = -1
best_metric_epoch = -1
epoch_loss_values = []
metric_values = []
post_pred = AsDiscrete(threshold_values=True, n_classes=1)
post_label = AsDiscrete(n_classes=1)

for epoch in range(max_epochs):
Example 11
 def __init__(self, focal):
     super(LossBraTS, self).__init__()
     self.dice = DiceLoss(sigmoid=True, batch=True)
     self.ce = FocalLoss(
         gamma=2.0, to_onehot_y=False) if focal else nn.BCEWithLogitsLoss()
Example 12
 def __init__(
     self,
     include_background: bool = False,
     to_onehot_y: bool = False,
     sigmoid: bool = False,
     softmax: bool = False,
     other_act: Optional[Callable] = None,
     squared_pred: bool = False,
     jaccard: bool = False,
     reduction: str = "mean",
     smooth_nr: float = 1e-5,
     smooth_dr: float = 1e-5,
     batch: bool = False,
     topK: int = 10,
     ce_weight: Optional[torch.Tensor] = None,
     lambda_dice: float = 1.0,
     lambda_ce: float = 1.0,
 ) -> None:
     """
     Args:
          ``reduction`` is used for both losses; the other parameters are only used for the dice loss.
         include_background: if False channel index 0 (background category) is excluded from the calculation.
         to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
          sigmoid: if True, apply a sigmoid function to the prediction. Only used by the `DiceLoss`;
              there is no need to specify an activation function for `CrossEntropyLoss`.
          softmax: if True, apply a softmax function to the prediction. Only used by the `DiceLoss`;
              there is no need to specify an activation function for `CrossEntropyLoss`.
          other_act: a callable to use as the activation instead of `sigmoid` or `softmax`,
              for example `other_act = torch.tanh`. Defaults to ``None``. Only used by the
              `DiceLoss`; there is no need to specify an activation function for `CrossEntropyLoss`.
         squared_pred: use squared versions of targets and predictions in the denominator or not.
         jaccard: compute Jaccard Index (soft IoU) instead of dice or not.
          reduction: {``"mean"``, ``"sum"``}
              Specifies the reduction to apply to the output. Defaults to ``"mean"``. The dice loss must
              at least reduce the spatial dimensions, unlike cross entropy loss, so the ``none`` option
              cannot be used here.
             - ``"mean"``: the sum of the output will be divided by the number of elements in the output.
             - ``"sum"``: the output will be summed.
         smooth_nr: a small constant added to the numerator to avoid zero.
         smooth_dr: a small constant added to the denominator to avoid nan.
         batch: whether to sum the intersection and union areas over the batch dimension before the dividing.
             Defaults to False, a Dice loss value is computed independently from each item in the batch
             before any `reduction`.
          topK: the `k` used for top-k selection of the cross entropy values.
         ce_weight: a rescaling weight given to each class for cross entropy loss.
         lambda_dice: the trade-off weight value for dice loss. The value should be no less than 0.0.
             Defaults to 1.0.
         lambda_ce: the trade-off weight value for cross entropy loss. The value should be no less than 0.0.
             Defaults to 1.0.
     """
     super().__init__()
     self.dice = DiceLoss(
         include_background=include_background,
         to_onehot_y=to_onehot_y,
         sigmoid=sigmoid,
         softmax=softmax,
         other_act=other_act,
         squared_pred=squared_pred,
         jaccard=jaccard,
         reduction=reduction,
         smooth_nr=smooth_nr,
         smooth_dr=smooth_dr,
         batch=batch,
     )
     self.cross_entropy = nn.CrossEntropyLoss(
         weight=ce_weight,
         reduction='none',
     )
     if lambda_dice < 0.0:
         raise ValueError("lambda_dice should be no less than 0.0.")
     if lambda_ce < 0.0:
         raise ValueError("lambda_ce should be no less than 0.0.")
     self.lambda_dice = lambda_dice
     self.lambda_ce = lambda_ce
     self.k = topK
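Only the constructor is shown. Because `cross_entropy` is built with `reduction='none'` and `self.k = topK` is stored, a plausible `forward` (an assumption; the snippet does not say whether `k` is a count or a percentage, so the sketch treats it as a percentage) averages only the hardest voxels of the cross entropy term:

 # assumed forward for the combined loss above, not part of the original snippet
 def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
     dice_loss = self.dice(input, target)
     # per-voxel cross entropy values, available because reduction='none'
     ce = self.cross_entropy(input, target.squeeze(1).long())
     # keep the top-k percent hardest voxels (interpretation of self.k assumed)
     num_kept = max(1, int(ce.numel() * self.k / 100))
     ce_loss = torch.topk(ce.flatten(), num_kept).values.mean()
     return self.lambda_dice * dice_loss + self.lambda_ce * ce_loss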
Example 13

def main(cfg):
    """Runs main training procedure."""

    # fix random seeds for reproducibility
    seed_everything(seed=cfg['seed'])

    # neptune logging
    neptune.init(project_qualified_name=cfg['neptune_project_name'],
                 api_token=cfg['neptune_api_token'])

    neptune.create_experiment(name=cfg['neptune_experiment'], params=cfg)

    print('Preparing model and data...')
    print('Using SMP version:', smp.__version__)

    num_classes = 1 if len(cfg['classes']) == 1 else (len(cfg['classes']) + 1)
    activation = 'sigmoid' if num_classes == 1 else 'softmax2d'
    background = not cfg['ignore_channels']
    binary = num_classes == 1
    softmax = not binary
    sigmoid = binary

    aux_params = dict(
        pooling=cfg['pooling'],  # one of 'avg', 'max'
        dropout=cfg['dropout'],  # dropout ratio, default is None
        activation='sigmoid',  # activation function, default is None
        classes=num_classes)  # define number of output labels

    # configure model
    models = {
        'unet':
        Unet(encoder_name=cfg['encoder_name'],
             encoder_weights=cfg['encoder_weights'],
             decoder_use_batchnorm=cfg['use_batchnorm'],
             classes=num_classes,
             activation=activation,
             aux_params=aux_params),
        'pspnet':
        PSPNet(encoder_name=cfg['encoder_name'],
               encoder_weights=cfg['encoder_weights'],
               classes=num_classes,
               activation=activation,
               aux_params=aux_params),
        'pan':
        PAN(encoder_name=cfg['encoder_name'],
            encoder_weights=cfg['encoder_weights'],
            classes=num_classes,
            activation=activation,
            aux_params=aux_params),
        'deeplabv3plus':
        DeepLabV3Plus(encoder_name=cfg['encoder_name'],
                      encoder_weights=cfg['encoder_weights'],
                      classes=num_classes,
                      activation=activation,
                      aux_params=aux_params)
    }

    assert cfg['architecture'] in models.keys()
    model = models[cfg['architecture']]

    # configure loss
    losses = {
        'dice_loss':
        DiceLoss(include_background=background,
                 softmax=softmax,
                 sigmoid=sigmoid,
                 batch=cfg['combine']),
        'generalized_dice':
        GeneralizedDiceLoss(include_background=background,
                            softmax=softmax,
                            sigmoid=sigmoid,
                            batch=cfg['combine'])
    }

    assert cfg['loss'] in losses.keys()
    loss = losses[cfg['loss']]

    # configure optimizer
    optimizers = {
        'adam': Adam([dict(params=model.parameters(), lr=cfg['lr'])]),
        'adamw': AdamW([dict(params=model.parameters(), lr=cfg['lr'])]),
        'rmsprop': RMSprop([dict(params=model.parameters(), lr=cfg['lr'])])
    }

    assert cfg['optimizer'] in optimizers.keys()
    optimizer = optimizers[cfg['optimizer']]

    # configure metrics
    metrics = {
        'dice_score':
        DiceMetric(include_background=background, reduction='mean'),
        'dice_smp':
        Fscore(threshold=cfg['rounding'],
               ignore_channels=cfg['ignore_channels']),
        'iou_smp':
        IoU(threshold=cfg['rounding'], ignore_channels=cfg['ignore_channels']),
        'generalized_dice':
        GeneralizedDiceLoss(include_background=background,
                            softmax=softmax,
                            sigmoid=sigmoid,
                            batch=cfg['combine']),
        'dice_loss':
        DiceLoss(include_background=background,
                 softmax=softmax,
                 sigmoid=sigmoid,
                 batch=cfg['combine']),
        'cross_entropy':
        BCELoss(reduction='mean'),
        'accuracy':
        Accuracy(ignore_channels=cfg['ignore_channels'])
    }

    assert all(m['name'] in metrics.keys() for m in cfg['metrics'])
    metrics = [(metrics[m['name']], m['name'], m['type'])
               for m in cfg['metrics']]  # tuple of (metric, name, type)

    # TODO: Fix metric names

    # configure scheduler
    schedulers = {
        'steplr':
        StepLR(optimizer, step_size=cfg['step_size'], gamma=0.5),
        'cosine':
        CosineAnnealingLR(optimizer,
                          cfg['epochs'],
                          eta_min=cfg['eta_min'],
                          last_epoch=-1)
    }

    assert cfg['scheduler'] in schedulers.keys()
    scheduler = schedulers[cfg['scheduler']]

    # configure augmentations
    train_transform = load_train_transform(transform_type=cfg['transform'],
                                           patch_size=cfg['patch_size_train'])
    valid_transform = load_valid_transform(
        patch_size=cfg['patch_size_valid'])  # manually selected patch size

    train_dataset = ArtifactDataset(df_path=cfg['train_data'],
                                    classes=cfg['classes'],
                                    transform=train_transform,
                                    normalize=cfg['normalize'],
                                    ink_filters=cfg['ink_filters'])

    valid_dataset = ArtifactDataset(df_path=cfg['valid_data'],
                                    classes=cfg['classes'],
                                    transform=valid_transform,
                                    normalize=cfg['normalize'],
                                    ink_filters=cfg['ink_filters'])

    test_dataset = ArtifactDataset(df_path=cfg['test_data'],
                                   classes=cfg['classes'],
                                   transform=valid_transform,
                                   normalize=cfg['normalize'],
                                   ink_filters=cfg['ink_filters'])

    # load pre-sampled patch arrays
    train_image, train_mask = train_dataset[0]
    valid_image, valid_mask = valid_dataset[0]
    print('Shape of image patch', train_image.shape)
    print('Shape of mask patch', train_mask.shape)
    print('Train dataset shape:', len(train_dataset))
    print('Valid dataset shape:', len(valid_dataset))
    assert train_image.shape[1] == cfg['patch_size_train']
    assert train_image.shape[2] == cfg['patch_size_train']
    assert valid_image.shape[1] == cfg['patch_size_valid']
    assert valid_image.shape[2] == cfg['patch_size_valid']

    # save intermediate augmentations
    if cfg['eval_dir']:
        default_dataset = ArtifactDataset(df_path=cfg['train_data'],
                                          classes=cfg['classes'],
                                          transform=None,
                                          normalize=None,
                                          ink_filters=cfg['ink_filters'])

        transform_dataset = ArtifactDataset(df_path=cfg['train_data'],
                                            classes=cfg['classes'],
                                            transform=train_transform,
                                            normalize=None,
                                            ink_filters=cfg['ink_filters'])

        for idx in range(0, min(500, len(train_dataset)), 10):
            image_input, image_mask = default_dataset[idx]
            image_input = image_input.transpose((1, 2, 0)).astype(np.uint8)

            image_mask = image_mask.transpose(1, 2, 0)
            image_mask = np.argmax(
                image_mask, axis=2) if not binary else image_mask.squeeze()
            image_mask = image_mask.astype(np.uint8)

            image_transform, _ = transform_dataset[idx]
            image_transform = image_transform.transpose(
                (1, 2, 0)).astype(np.uint8)

            idx_str = str(idx).zfill(3)
            skimage.io.imsave(os.path.join(cfg['eval_dir'],
                                           f'{idx_str}a_image_input.png'),
                              image_input,
                              check_contrast=False)
            plt.imsave(os.path.join(cfg['eval_dir'],
                                    f'{idx_str}b_image_mask.png'),
                       image_mask,
                       vmin=0,
                       vmax=6,
                       cmap='Spectral')
            skimage.io.imsave(os.path.join(cfg['eval_dir'],
                                           f'{idx_str}c_image_transform.png'),
                              image_transform,
                              check_contrast=False)

        del transform_dataset

    # update process
    print('Starting training...')
    print('Available GPUs for training:', torch.cuda.device_count())

    # pytorch module wrapper
    class DataParallelModule(torch.nn.DataParallel):
        def __getattr__(self, name):
            try:
                return super().__getattr__(name)
            except AttributeError:
                return getattr(self.module, name)

    # data parallel training
    if torch.cuda.device_count() > 1:
        model = DataParallelModule(model)

    train_loader = DataLoader(train_dataset,
                              batch_size=cfg['batch_size'],
                              num_workers=cfg['workers'],
                              shuffle=True)

    valid_loader = DataLoader(valid_dataset,
                              batch_size=int(cfg['batch_size'] / 4),
                              num_workers=cfg['workers'],
                              shuffle=False)

    test_loader = DataLoader(test_dataset,
                             batch_size=int(cfg['batch_size'] / 4),
                             num_workers=cfg['workers'],
                             shuffle=False)

    trainer = Trainer(model=model,
                      device=cfg['device'],
                      save_checkpoints=cfg['save_checkpoints'],
                      checkpoint_dir=cfg['checkpoint_dir'],
                      checkpoint_name=cfg['checkpoint_name'])

    trainer.compile(optimizer=optimizer,
                    loss=loss,
                    metrics=metrics,
                    num_classes=num_classes)

    trainer.fit(train_loader,
                valid_loader,
                epochs=cfg['epochs'],
                scheduler=scheduler,
                verbose=cfg['verbose'],
                loss_weight=cfg['loss_weight'],
                test_loader=test_loader,
                binary=binary)

    # validation inference
    model.load_state_dict(
        torch.load(os.path.join(cfg['checkpoint_dir'],
                                cfg['checkpoint_name'])))
    model.to(cfg['device'])
    model.eval()

    # save best checkpoint to neptune
    neptune.log_artifact(
        os.path.join(cfg['checkpoint_dir'], cfg['checkpoint_name']))

    # setup directory to save plots
    if os.path.isdir(cfg['plot_dir_valid']):
        shutil.rmtree(cfg['plot_dir_valid'])
    os.makedirs(cfg['plot_dir_valid'], exist_ok=True)

    # valid dataset without transformations and normalization for image visualization
    valid_dataset_vis = ArtifactDataset(df_path=cfg['valid_data'],
                                        classes=cfg['classes'],
                                        ink_filters=cfg['ink_filters'])

    # keep track of valid masks
    valid_preds = []
    valid_masks = []

    if cfg['save_checkpoints']:
        print('Predicting valid patches...')
        for n in range(len(valid_dataset)):
            image_vis = valid_dataset_vis[n][0].astype('uint8')
            image_vis = image_vis.transpose(1, 2, 0)
            image, gt_mask = valid_dataset[n]
            gt_mask = gt_mask.transpose(1, 2, 0)
            gt_mask = np.argmax(gt_mask,
                                axis=2) if not binary else gt_mask.squeeze()
            gt_mask = gt_mask.astype(np.uint8)
            valid_masks.append(gt_mask)

            x_tensor = torch.from_numpy(image).to(cfg['device']).unsqueeze(0)
            pr_mask, _ = model.predict(x_tensor)
            pr_mask = pr_mask.squeeze(axis=0).cpu().numpy().round()
            pr_mask = pr_mask.transpose(1, 2, 0)
            pr_mask = np.argmax(pr_mask,
                                axis=2) if not binary else pr_mask.squeeze()
            pr_mask = pr_mask.astype(np.uint8)
            valid_preds.append(pr_mask)

            save_predictions(out_path=cfg['plot_dir_valid'],
                             index=n + 1,
                             image=image_vis,
                             ground_truth_mask=gt_mask,
                             predicted_mask=pr_mask)

    del train_dataset, valid_dataset
    del train_loader, valid_loader

    # calculate dice per class
    valid_masks = np.stack(valid_masks, axis=0)
    valid_masks = valid_masks.flatten()
    valid_preds = np.stack(valid_preds, axis=0)
    valid_preds = valid_preds.flatten()
    dice_score = f1_score(y_true=valid_masks, y_pred=valid_preds, average=None)
    neptune.log_text('valid_dice_class', str(dice_score))
    print('Valid dice score (class):', str(dice_score))

    if cfg['evaluate_test_set']:
        print('Predicting test patches...')

        # setup directory to save plots
        if os.path.isdir(cfg['plot_dir_test']):
            shutil.rmtree(cfg['plot_dir_test'])
        os.makedirs(cfg['plot_dir_test'], exist_ok=True)

        # test dataset without transformations and normalization for image visualization
        test_dataset_vis = ArtifactDataset(df_path=cfg['test_data'],
                                           classes=cfg['classes'],
                                           ink_filters=cfg['ink_filters'])

        # keep track of test masks
        test_masks = []
        test_preds = []

        for n in range(len(test_dataset)):
            image_vis = test_dataset_vis[n][0].astype('uint8')
            image_vis = image_vis.transpose(1, 2, 0)
            image, gt_mask = test_dataset[n]
            gt_mask = gt_mask.transpose(1, 2, 0)
            gt_mask = np.argmax(gt_mask,
                                axis=2) if not binary else gt_mask.squeeze()
            gt_mask = gt_mask.astype(np.uint8)
            test_masks.append(gt_mask)

            x_tensor = torch.from_numpy(image).to(cfg['device']).unsqueeze(0)
            pr_mask, _ = model.predict(x_tensor)
            pr_mask = pr_mask.squeeze(axis=0).cpu().numpy().round()
            pr_mask = pr_mask.transpose(1, 2, 0)
            pr_mask = np.argmax(pr_mask,
                                axis=2) if not binary else pr_mask.squeeze()
            pr_mask = pr_mask.astype(np.uint8)
            test_preds.append(pr_mask)

            save_predictions(out_path=cfg['plot_dir_test'],
                             index=n + 1,
                             image=image_vis,
                             ground_truth_mask=gt_mask,
                             predicted_mask=pr_mask)

        # calculate dice per class (once, after all test patches are predicted)
        test_masks = np.stack(test_masks, axis=0)
        test_masks = test_masks.flatten()
        test_preds = np.stack(test_preds, axis=0)
        test_preds = test_preds.flatten()
        dice_score = f1_score(y_true=test_masks,
                              y_pred=test_preds,
                              average=None)
        neptune.log_text('test_dice_class', str(dice_score))
        print('Test dice score (class):', str(dice_score))

    # end of training process
    print('Finished training!')
Example 14
 def test_shape(self, input_param, input_data, expected_val):
     result = DiceLoss(**input_param).forward(**input_data)
     np.testing.assert_allclose(result.detach().cpu().numpy(),
                                expected_val,
                                rtol=1e-5)
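An illustrative parameter set such a test could be driven by (assumed values, not from the source); perfect overlap with no activation gives a dice loss of exactly 0:

 input_param = {"include_background": True, "sigmoid": False}
 input_data = {"input": torch.ones((1, 1, 3)), "target": torch.ones((1, 1, 3))}
 expected_val = 0.0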
Example 15

    def test_test_time_augmentation(self):
        input_size = (20, 40)  # test different input data shape to pad list collate
        keys = ["image", "label"]
        num_training_ims = 10

        train_data = self.get_data(num_training_ims, input_size)
        test_data = self.get_data(1, input_size)
        device = "cuda" if torch.cuda.is_available() else "cpu"

        transforms = Compose(
            [
                AddChanneld(keys),
                RandAffined(
                    keys,
                    prob=1.0,
                    spatial_size=(30, 30),
                    rotate_range=(np.pi / 3, np.pi / 3),
                    translate_range=(3, 3),
                    scale_range=((0.8, 1), (0.8, 1)),
                    padding_mode="zeros",
                    mode=("bilinear", "nearest"),
                    as_tensor_output=False,
                ),
                CropForegroundd(keys, source_key="image"),
                DivisiblePadd(keys, 4),
            ]
        )

        train_ds = CacheDataset(train_data, transforms)
        # output might be different size, so pad so that they match
        train_loader = DataLoader(train_ds, batch_size=2, collate_fn=pad_list_data_collate)

        model = UNet(2, 1, 1, channels=(6, 6), strides=(2, 2)).to(device)
        loss_function = DiceLoss(sigmoid=True)
        optimizer = torch.optim.Adam(model.parameters(), 1e-3)

        num_epochs = 10
        for _ in trange(num_epochs):
            epoch_loss = 0

            for batch_data in train_loader:
                inputs, labels = batch_data["image"].to(device), batch_data["label"].to(device)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = loss_function(outputs, labels)
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()

            epoch_loss /= len(train_loader)

        post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])

        tt_aug = TestTimeAugmentation(
            transform=transforms,
            batch_size=5,
            num_workers=0,
            inferrer_fn=model,
            device=device,
            to_tensor=True,
            output_device="cpu",
            post_func=post_trans,
        )
        mode, mean, std, vvc = tt_aug(test_data)
        self.assertEqual(mode.shape, (1,) + input_size)
        self.assertEqual(mean.shape, (1,) + input_size)
        self.assertTrue(all(np.unique(mode) == (0, 1)))
        self.assertGreaterEqual(mean.min(), 0.0)
        self.assertLessEqual(mean.max(), 1.0)
        self.assertEqual(std.shape, (1,) + input_size)
        self.assertIsInstance(vvc, float)
Example 16
    def configure(self):
        self.set_device()
        network = UNet(
            dimensions=3,
            in_channels=1,
            out_channels=2,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
            norm=Norm.BATCH,
        ).to(self.device)
        if self.multi_gpu:
            network = DistributedDataParallel(
                module=network,
                device_ids=[self.device],
                find_unused_parameters=False,
            )

        train_transforms = Compose([
            LoadImaged(keys=("image", "label")),
            EnsureChannelFirstd(keys=("image", "label")),
            Spacingd(keys=("image", "label"),
                     pixdim=[1.0, 1.0, 1.0],
                     mode=["bilinear", "nearest"]),
            ScaleIntensityRanged(
                keys="image",
                a_min=-57,
                a_max=164,
                b_min=0.0,
                b_max=1.0,
                clip=True,
            ),
            CropForegroundd(keys=("image", "label"), source_key="image"),
            RandCropByPosNegLabeld(
                keys=("image", "label"),
                label_key="label",
                spatial_size=(96, 96, 96),
                pos=1,
                neg=1,
                num_samples=4,
                image_key="image",
                image_threshold=0,
            ),
            RandShiftIntensityd(keys="image", offsets=0.1, prob=0.5),
            ToTensord(keys=("image", "label")),
        ])
        train_datalist = load_decathlon_datalist(self.data_list_file_path,
                                                 True, "training")
        if self.multi_gpu:
            train_datalist = partition_dataset(
                data=train_datalist,
                shuffle=True,
                num_partitions=dist.get_world_size(),
                even_divisible=True,
            )[dist.get_rank()]
        train_ds = CacheDataset(
            data=train_datalist,
            transform=train_transforms,
            cache_num=32,
            cache_rate=1.0,
            num_workers=4,
        )
        train_data_loader = DataLoader(
            train_ds,
            batch_size=2,
            shuffle=True,
            num_workers=4,
        )
        val_transforms = Compose([
            LoadImaged(keys=("image", "label")),
            EnsureChannelFirstd(keys=("image", "label")),
            ScaleIntensityRanged(
                keys="image",
                a_min=-57,
                a_max=164,
                b_min=0.0,
                b_max=1.0,
                clip=True,
            ),
            CropForegroundd(keys=("image", "label"), source_key="image"),
            ToTensord(keys=("image", "label")),
        ])

        val_datalist = load_decathlon_datalist(self.data_list_file_path, True,
                                               "validation")
        val_ds = CacheDataset(val_datalist, val_transforms, 9, 0.0, 4)
        val_data_loader = DataLoader(
            val_ds,
            batch_size=1,
            shuffle=False,
            num_workers=4,
        )
        post_transform = Compose([
            Activationsd(keys="pred", softmax=True),
            AsDiscreted(
                keys=["pred", "label"],
                argmax=[True, False],
                to_onehot=True,
                n_classes=2,
            ),
        ])
        # metric
        key_val_metric = {
            "val_mean_dice":
            MeanDice(
                include_background=False,
                output_transform=lambda x: (x["pred"], x["label"]),
                device=self.device,
            )
        }
        val_handlers = [
            StatsHandler(output_transform=lambda x: None),
            CheckpointSaver(
                save_dir=self.ckpt_dir,
                save_dict={"model": network},
                save_key_metric=True,
            ),
            TensorBoardStatsHandler(log_dir=self.ckpt_dir,
                                    output_transform=lambda x: None),
        ]
        self.eval_engine = SupervisedEvaluator(
            device=self.device,
            val_data_loader=val_data_loader,
            network=network,
            inferer=SlidingWindowInferer(
                roi_size=[160, 160, 160],
                sw_batch_size=4,
                overlap=0.5,
            ),
            post_transform=post_transform,
            key_val_metric=key_val_metric,
            val_handlers=val_handlers,
            amp=self.amp,
        )

        optimizer = torch.optim.Adam(network.parameters(), self.learning_rate)
        loss_function = DiceLoss(to_onehot_y=True, softmax=True)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                       step_size=5000,
                                                       gamma=0.1)
        train_handlers = [
            LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True),
            ValidationHandler(validator=self.eval_engine,
                              interval=self.val_interval,
                              epoch_level=True),
            StatsHandler(tag_name="train_loss",
                         output_transform=lambda x: x["loss"]),
            TensorBoardStatsHandler(
                log_dir=self.ckpt_dir,
                tag_name="train_loss",
                output_transform=lambda x: x["loss"],
            ),
        ]

        self.train_engine = SupervisedTrainer(
            device=self.device,
            max_epochs=self.max_epochs,
            train_data_loader=train_data_loader,
            network=network,
            optimizer=optimizer,
            loss_function=loss_function,
            inferer=SimpleInferer(),
            post_transform=post_transform,
            key_train_metric=None,
            train_handlers=train_handlers,
            amp=self.amp,
        )

        if self.local_rank > 0:
            self.train_engine.logger.setLevel(logging.WARNING)
            self.eval_engine.logger.setLevel(logging.WARNING)
Example 17
 def __init__(self, loss):
     super().__init__()
     self.loss = loss
     self.focal = FocalLoss(gamma=2.0)
     self.dice_bg = DiceLoss(include_background=True, softmax=True, to_onehot_y=True, batch=True)
     self.dice_nbg = DiceLoss(include_background=False, softmax=True, to_onehot_y=True, batch=True)
Example 18
def create_trainer(args):
    set_determinism(seed=args.seed)

    multi_gpu = args.multi_gpu
    local_rank = args.local_rank
    if multi_gpu:
        dist.init_process_group(backend="nccl", init_method="env://")
        device = torch.device("cuda:{}".format(local_rank))
        torch.cuda.set_device(device)
    else:
        device = torch.device("cuda" if args.use_gpu else "cpu")

    pre_transforms = get_pre_transforms(args.roi_size, args.model_size,
                                        args.dimensions)
    click_transforms = get_click_transforms()
    post_transform = get_post_transforms()

    train_loader, val_loader = get_loaders(args, pre_transforms)

    # define training components
    network = get_network(args.network, args.channels,
                          args.dimensions).to(device)
    if multi_gpu:
        network = torch.nn.parallel.DistributedDataParallel(
            network, device_ids=[local_rank], output_device=local_rank)

    if args.resume:
        logging.info('{}:: Loading Network...'.format(local_rank))
        map_location = {"cuda:0": "cuda:{}".format(local_rank)}
        network.load_state_dict(
            torch.load(args.model_filepath, map_location=map_location))

    # define event-handlers for engine
    val_handlers = [
        StatsHandler(output_transform=lambda x: None),
        TensorBoardStatsHandler(log_dir=args.output,
                                output_transform=lambda x: None),
        DeepgrowStatsHandler(log_dir=args.output,
                             tag_name='val_dice',
                             image_interval=args.image_interval),
        CheckpointSaver(save_dir=args.output,
                        save_dict={"net": network},
                        save_key_metric=True,
                        save_final=True,
                        save_interval=args.save_interval,
                        final_filename='model.pt')
    ]
    val_handlers = val_handlers if local_rank == 0 else None

    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=network,
        iteration_update=Interaction(
            transforms=click_transforms,
            max_interactions=args.max_val_interactions,
            key_probability='probability',
            train=False),
        inferer=SimpleInferer(),
        post_transform=post_transform,
        key_val_metric={
            "val_dice":
            MeanDice(include_background=False,
                     output_transform=lambda x: (x["pred"], x["label"]))
        },
        val_handlers=val_handlers)

    loss_function = DiceLoss(sigmoid=True, squared_pred=True)
    optimizer = torch.optim.Adam(network.parameters(), args.learning_rate)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                   step_size=5000,
                                                   gamma=0.1)

    train_handlers = [
        LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True),
        ValidationHandler(validator=evaluator,
                          interval=args.val_freq,
                          epoch_level=True),
        StatsHandler(tag_name="train_loss",
                     output_transform=lambda x: x["loss"]),
        TensorBoardStatsHandler(log_dir=args.output,
                                tag_name="train_loss",
                                output_transform=lambda x: x["loss"]),
        CheckpointSaver(save_dir=args.output,
                        save_dict={
                            "net": network,
                            "opt": optimizer,
                            "lr": lr_scheduler
                        },
                        save_interval=args.save_interval * 2,
                        save_final=True,
                        final_filename='checkpoint.pt'),
    ]
    train_handlers = train_handlers if local_rank == 0 else train_handlers[:2]

    trainer = SupervisedTrainer(
        device=device,
        max_epochs=args.epochs,
        train_data_loader=train_loader,
        network=network,
        iteration_update=Interaction(
            transforms=click_transforms,
            max_interactions=args.max_train_interactions,
            key_probability='probability',
            train=True),
        optimizer=optimizer,
        loss_function=loss_function,
        inferer=SimpleInferer(),
        post_transform=post_transform,
        amp=args.amp,
        key_train_metric={
            "train_dice":
            MeanDice(include_background=False,
                     output_transform=lambda x: (x["pred"], x["label"]))
        },
        train_handlers=train_handlers,
    )
    return trainer
Example 19
 def loss_function(self, context: Context):
     return DiceLoss(sigmoid=True, squared_pred=True)
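With `squared_pred=True`, the denominator of the dice ratio uses squared sums; for this binary, sigmoid-activated configuration the per-item loss is

    dice_loss = 1 - (2 * sum(p * g) + smooth_nr) / (sum(p ** 2) + sum(g ** 2) + smooth_dr)

where p is the sigmoid of the prediction, g the ground truth, and smooth_nr/smooth_dr default to 1e-5.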
Example 20
def train(n_feat,
          crop_size,
          bs,
          ep,
          optimizer="rmsprop",
          lr=5e-4,
          pretrain=None):
    model_name = f"./HaN_{n_feat}_{bs}_{ep}_{crop_size}_{lr}_"
    print(f"save the best model as '{model_name}' during training.")

    crop_size = [int(cz) for cz in crop_size.split(",")]
    print(f"input image crop_size: {crop_size}")

    # starting training set loader
    train_images = ImageLabelDataset(path=TRAIN_PATH, n_class=N_CLASSES)
    if np.any([cz == -1 for cz in crop_size]):  # using full image
        train_transform = Compose([
            AddChannelDict(keys="image"),
            Rand3DElasticd(
                keys=("image", "label"),
                spatial_size=crop_size,
                sigma_range=(10, 50),  # 30
                magnitude_range=(600, 1200),  # 1000
                prob=0.8,
                rotate_range=(np.pi / 12, np.pi / 12, np.pi / 12),
                shear_range=(np.pi / 18, np.pi / 18, np.pi / 18),
                translate_range=tuple(sz * 0.05 for sz in crop_size),
                scale_range=(0.2, 0.2, 0.2),
                mode=("bilinear", "nearest"),
                padding_mode=("border", "zeros"),
            ),
        ])
        train_dataset = Dataset(train_images, transform=train_transform)
        # when bs > 1, the loader assumes that the full image sizes are the same across the dataset
        train_dataloader = torch.utils.data.DataLoader(train_dataset,
                                                       num_workers=4,
                                                       batch_size=bs,
                                                       shuffle=True)
    else:
        # draw balanced foreground/background window samples according to the ground truth label
        train_transform = Compose([
            AddChannelDict(keys="image"),
            SpatialPadd(
                keys=("image", "label"),
                spatial_size=crop_size),  # ensure image size >= crop_size
            RandCropByPosNegLabeld(keys=("image", "label"),
                                   label_key="label",
                                   spatial_size=crop_size,
                                   num_samples=bs),
            Rand3DElasticd(
                keys=("image", "label"),
                spatial_size=crop_size,
                sigma_range=(10, 50),  # 30
                magnitude_range=(600, 1200),  # 1000
                prob=0.8,
                rotate_range=(np.pi / 12, np.pi / 12, np.pi / 12),
                shear_range=(np.pi / 18, np.pi / 18, np.pi / 18),
                translate_range=tuple(sz * 0.05 for sz in crop_size),
                scale_range=(0.2, 0.2, 0.2),
                mode=("bilinear", "nearest"),
                padding_mode=("border", "zeros"),
            ),
        ])
        # each dataset item is a list of windows
        train_dataset = Dataset(train_images, transform=train_transform)
        train_dataloader = torch.utils.data.DataLoader(  # stack each dataset item into a single tensor
            train_dataset,
            num_workers=4,
            batch_size=1,
            shuffle=True,
            collate_fn=list_data_collate)
    first_sample = first(train_dataloader)
    print(first_sample["image"].shape)

    # starting validation set loader
    val_transform = Compose([AddChannelDict(keys="image")])
    val_dataset = Dataset(ImageLabelDataset(VAL_PATH, n_class=N_CLASSES),
                          transform=val_transform)
    val_dataloader = torch.utils.data.DataLoader(val_dataset,
                                                 num_workers=1,
                                                 batch_size=1)
    print(val_dataset[0]["image"].shape)
    print(
        f"training images: {len(train_dataloader)}, validation images: {len(val_dataloader)}"
    )

    model = UNetPipe(spatial_dims=3,
                     in_channels=1,
                     out_channels=N_CLASSES,
                     n_feat=n_feat)
    model = flatten_sequential(model)
    lossweight = torch.from_numpy(
        np.array([2.22, 1.31, 1.99, 1.13, 1.93, 1.93, 1.0, 1.0, 1.90, 1.98],
                 np.float32))

    if optimizer.lower() == "rmsprop":
        optimizer = torch.optim.RMSprop(model.parameters(), lr=lr)  # lr = 5e-4
    elif optimizer.lower() == "momentum":
        optimizer = torch.optim.SGD(model.parameters(), lr=lr,
                                    momentum=0.9)  # lr = 1e-4 for finetuning
    else:
        raise ValueError(
            f"Unknown optimizer type {optimizer}. (options are 'rmsprop' and 'momentum')."
        )

    # config GPipe
    x = first_sample["image"].float()
    x = x.cuda()
    partitions = torch.cuda.device_count()
    print(f"partition: {partitions}, input: {x.size()}")
    balance = balance_by_size(partitions, model, x)
    model = GPipe(model, balance, chunks=4, checkpoint="always")

    # config loss functions
    dice_loss_func = DiceLoss(softmax=True, reduction="none")
    # use the same pipeline and loss in
    # AnatomyNet: Deep learning for fast and fully automated whole‐volume segmentation of head and neck anatomy,
    # Medical Physics, 2018.
    focal_loss_func = FocalLoss(reduction="none")

    if pretrain:
        print(f"loading from {pretrain}.")
        pretrained_dict = torch.load(pretrain)["weight"]
        model_dict = model.state_dict()
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items() if k in model_dict
        }
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)

    b_time = time.time()
    best_val_loss = [0] * (N_CLASSES - 1)  # foreground
    for epoch in range(ep):
        model.train()
        trainloss = 0
        for b_idx, data_dict in enumerate(train_dataloader):
            x_train = data_dict["image"]
            y_train = data_dict["label"]
            flagvec = data_dict["with_complete_groundtruth"]

            x_train = x_train.cuda()
            y_train = y_train.cuda().float()
            optimizer.zero_grad()
            o = model(x_train).to(0, non_blocking=True).float()

            loss = (dice_loss_func(o, y_train.to(o)) * flagvec.to(o) *
                    lossweight.to(o)).mean()
            loss += 0.5 * (focal_loss_func(o, y_train.to(o)) * flagvec.to(o) *
                           lossweight.to(o)).mean()
            loss.backward()
            optimizer.step()
            trainloss += loss.item()

            if b_idx % 20 == 0:
                print(
                    f"Train Epoch: {epoch} [{b_idx}/{len(train_dataloader)}] \tLoss: {loss.item()}"
                )
        print(f"epoch {epoch} TRAIN loss {trainloss / len(train_dataloader)}")

        if epoch % 10 == 0:
            model.eval()
            # check validation dice
            val_loss = [0] * (N_CLASSES - 1)
            n_val = [0] * (N_CLASSES - 1)
            for data_dict in val_dataloader:
                x_val = data_dict["image"]
                y_val = data_dict["label"]
                with torch.no_grad():
                    x_val = x_val.cuda()
                    o = model(x_val).to(0, non_blocking=True)
                loss = compute_meandice(o,
                                        y_val.to(o),
                                        mutually_exclusive=True,
                                        include_background=False)
                # l == l is False for NaN, so classes absent from this case are skipped
                val_loss = [
                    l.item() + tl if l == l else tl
                    for l, tl in zip(loss[0], val_loss)
                ]
                n_val = [
                    n + 1 if l == l else n for l, n in zip(loss[0], n_val)
                ]
            val_loss = [l / n for l, n in zip(val_loss, n_val)]
            print(
                "validation scores %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f"
                % tuple(val_loss))
            for c in range(1, 10):
                if best_val_loss[c - 1] < val_loss[c - 1]:
                    best_val_loss[c - 1] = val_loss[c - 1]
                    state = {
                        "epoch": epoch,
                        "weight": model.state_dict(),
                        "score_" + str(c): best_val_loss[c - 1]
                    }
                    torch.save(state, f"{model_name}" + str(c))
            print(
                "best validation scores %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f"
                % tuple(best_val_loss))

    print("total time", time.time() - b_time)
Example 21
def train_process(fast=False):
    epoch_num = 10
    val_interval = 1
    train_trans, val_trans = transformations()
    train_ds = Dataset(data=train_files, transform=train_trans)
    val_ds = Dataset(data=val_files, transform=val_trans)

    train_loader = DataLoader(train_ds, batch_size=2, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=1)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n1 = 16
    model = UNet(dimensions=3,
                 in_channels=1,
                 out_channels=2,
                 channels=(n1 * 1, n1 * 2, n1 * 4, n1 * 8, n1 * 16),
                 strides=(2, 2, 2, 2)).to(device)
    loss_function = DiceLoss(to_onehot_y=True, softmax=True)
    post_pred = AsDiscrete(argmax=True, to_onehot=True, n_classes=2)
    post_label = AsDiscrete(to_onehot=True, n_classes=2)
    optimizer = torch.optim.Adam(model.parameters(), 1e-4, weight_decay=1e-5)

    best_metric = -1
    best_metric_epoch = -1
    epochs_no_improve = 0
    best_metrics_epochs_and_time = [[], [], []]
    epoch_loss_values = list()
    metric_values = list()

    for epoch in range(epoch_num):
        print(f"epoch {epoch + 1}/{epoch_num}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data['image'].to(
                device), batch_data['label'].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            epoch_len = math.ceil(len(train_ds) / train_loader.batch_size)
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                metric_sum = 0.
                metric_count = 0
                for val_data in val_loader:
                    val_inputs, val_labels = val_data['image'].to(
                        device), val_data['label'].to(device)
                    val_outputs = model(val_inputs)
                    val_outputs = post_pred(val_outputs)
                    val_labels = post_label(val_labels)
                    value = compute_meandice(y_pred=val_outputs,
                                             y=val_labels,
                                             include_background=False)
                    metric_count += len(value)
                    metric_sum += value.sum().item()
                metric = metric_sum / metric_count
                metric_values.append(metric)
                if metric > best_metric:
                    best_metric = metric
                    epochs_no_improve = 0
                    best_metric_epoch = epoch + 1
                    best_metrics_epochs_and_time[0].append(best_metric)
                    best_metrics_epochs_and_time[1].append(best_metric_epoch)
                    torch.save(model.state_dict(), 'sLUMRTL644.pth')
                else:
                    epochs_no_improve += 1

            print(
                f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
                f" best mean dice: {best_metric:.4f} at epoch: {best_metric_epoch}"
            )

    print(
        f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}"
    )
    return epoch_num, epoch_loss_values, metric_values, best_metrics_epochs_and_time
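The function tracks epochs_no_improve but never acts on it. A minimal early-stopping hook, with a hypothetical patience setting, could close the validation block:

    patience = 3  # hypothetical: stop after 3 validations without improvement
    if epochs_no_improve >= patience:
        print(f"early stopping at epoch {epoch + 1}")
        break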
Example 22
 def test_shape(self, input_param, input_data, expected_val):
     result = DiceLoss(**input_param).forward(**input_data)
     self.assertAlmostEqual(result.item(), expected_val, places=5)
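The parameterized cases feeding this test follow an (input_param, input_data, expected_val) layout. A minimal sketch of one such case (hypothetical, not from the original suite): with MONAI's default smoothing terms, an identical raw prediction and target give a Dice loss of 0.

    TEST_CASE_IDENTICAL = [
        {"include_background": True, "sigmoid": False},       # input_param
        {"input": torch.ones((1, 1, 2, 2)),
         "target": torch.ones((1, 1, 2, 2))},                 # input_data
        0.0,                                                  # expected_val
    ]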
Example 23
    def compute_from_aggregating(self,
                                 input,
                                 target,
                                 if_path: bool,
                                 type_as_tensor=None,
                                 whether_to_return_img=False,
                                 result: pl.EvalResult = None):
        transform = get_val_transform()
        if if_path:
            cur_img_subject = torchio.Subject(
                img=torchio.Image(input, type=torchio.INTENSITY))
            cur_label_subject = torchio.Subject(
                img=torchio.Image(target, type=torchio.LABEL))

            preprocessed_img = transform(cur_img_subject)
            preprocessed_label = transform(cur_label_subject)

            patch_overlap = self.hparams.patch_overlap  # torchio's GridSampler expects even overlap values
            grid_sampler = torchio.inference.GridSampler(
                preprocessed_img,
                self.patch_size,
                patch_overlap,
            )

            patch_loader = torch.utils.data.DataLoader(grid_sampler)
            aggregator = torchio.inference.GridAggregator(grid_sampler)

            for patches_batch in patch_loader:
                input_tensor = patches_batch['img'][torchio.DATA]
                # move the patch onto the same device/dtype as the reference tensor
                input_tensor = input_tensor.type_as(type_as_tensor['val_dice'])
                locations = patches_batch[torchio.LOCATION]
                preds = self(input_tensor)  # use cuda
                labels = preds.argmax(dim=torchio.CHANNELS_DIMENSION,
                                      keepdim=True)  # use cuda
                aggregator.add_batch(labels, locations)
            output_tensor = aggregator.get_output_tensor()  # aggregated on CPU

            # note: if_path is always True in this branch, so the else below
            # is unreachable and the image is always returned
            if if_path or whether_to_return_img:
                return preprocessed_img.img.data, output_tensor, preprocessed_label.img.data
            else:
                return output_tensor, preprocessed_label.img.data

        else:
            cur_subject = torchio.Subject(
                img=torchio.Image(tensor=input.squeeze(),
                                  type=torchio.INTENSITY),
                label=torchio.Image(tensor=target.squeeze(),
                                    type=torchio.LABEL))
            preprocessed_subject = transform(cur_subject)

            patch_overlap = self.hparams.patch_overlap  # torchio's GridSampler expects even overlap values
            grid_sampler = torchio.inference.GridSampler(
                preprocessed_subject,
                self.patch_size,
                patch_overlap,
            )

            patch_loader = torch.utils.data.DataLoader(grid_sampler)
            aggregator = torchio.inference.GridAggregator(grid_sampler)

            dice_loss = []

            for patches_batch in patch_loader:
                input_tensor, target_tensor = patches_batch['img'][
                    torchio.DATA], patches_batch['label'][torchio.DATA]
                # move the patch onto the same device/dtype as the input
                input_tensor = input_tensor.type_as(input)
                locations = patches_batch[torchio.LOCATION]
                preds_tensor = self(input_tensor)  # use cuda
                # Compute the loss here
                diceloss = DiceLoss(
                    include_background=self.hparams.include_background,
                    to_onehot_y=True)
                loss = diceloss.forward(input=preds_tensor,
                                        target=target_tensor)
                dice_loss.append(loss)
                labels = preds_tensor.argmax(dim=torchio.CHANNELS_DIMENSION,
                                             keepdim=True)  # use cuda
                aggregator.add_batch(labels, locations)
            output_tensor = aggregator.get_output_tensor()  # aggregated on CPU

            if whether_to_return_img:
                return cur_subject['img'].data, output_tensor, cur_subject[
                    'label'].data
            else:
                return output_tensor, cur_subject['label'].data, torch.stack(
                    dice_loss)
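Both branches implement the same TorchIO pattern: sample patches on a grid, run the network per patch, and stitch the predictions back together. A condensed sketch of that round trip, assuming a callable model and a preprocessed torchio.Subject named subject with an 'img' image:

    import torch
    import torchio

    sampler = torchio.inference.GridSampler(subject,
                                            patch_size=(64, 64, 64),
                                            patch_overlap=(4, 4, 4))
    loader = torch.utils.data.DataLoader(sampler, batch_size=4)
    aggregator = torchio.inference.GridAggregator(sampler)

    with torch.no_grad():
        for batch in loader:
            logits = model(batch['img'][torchio.DATA])
            labels = logits.argmax(dim=torchio.CHANNELS_DIMENSION, keepdim=True)
            aggregator.add_batch(labels, batch[torchio.LOCATION])
    prediction = aggregator.get_output_tensor()  # full-volume label map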
Example 24
def main(cfg):
    """Runs main training procedure."""

    print('Starting training...')
    print('Current working directory is:', os.getcwd())

    # fix random seeds for reproducibility
    seed_everything(seed=cfg['seed'])

    # neptune logging
    neptune.init(project_qualified_name=cfg['neptune_project_name'],
                 api_token=cfg['neptune_api_token'])

    neptune.create_experiment(name=cfg['neptune_experiment'], params=cfg)

    # a single foreground class gets one sigmoid channel; otherwise add a
    # background channel and score with softmax2d
    num_classes = 1 if len(cfg['classes']) == 1 else (len(cfg['classes']) + 1)
    activation = 'sigmoid' if num_classes == 1 else 'softmax2d'
    background = not cfg['ignore_channels']

    aux_params = dict(
        pooling=cfg['pooling'],  # one of 'avg', 'max'
        dropout=cfg['dropout'],  # dropout ratio, default is None
        activation='sigmoid',  # activation function, default is None
        classes=num_classes)  # define number of output labels

    # configure model
    models = {
        'unet':
        Unet(encoder_name=cfg['encoder_name'],
             encoder_weights=cfg['encoder_weights'],
             decoder_use_batchnorm=cfg['use_batchnorm'],
             classes=num_classes,
             activation=activation,
             aux_params=aux_params),
        'unetplusplus':
        UnetPlusPlus(encoder_name=cfg['encoder_name'],
                     encoder_weights=cfg['encoder_weights'],
                     decoder_use_batchnorm=cfg['use_batchnorm'],
                     classes=num_classes,
                     activation=activation,
                     aux_params=aux_params),
        'deeplabv3plus':
        DeepLabV3Plus(encoder_name=cfg['encoder_name'],
                      encoder_weights=cfg['encoder_weights'],
                      classes=num_classes,
                      activation=activation,
                      aux_params=aux_params)
    }

    assert cfg['architecture'] in models.keys()
    model = models[cfg['architecture']]

    # configure loss
    losses = {
        'dice_loss':
        DiceLoss(include_background=background,
                 softmax=False,
                 batch=cfg['combine']),
        'generalized_dice':
        GeneralizedDiceLoss(include_background=background,
                            softmax=False,
                            batch=cfg['combine'])
    }

    assert cfg['loss'] in losses.keys()
    loss = losses[cfg['loss']]

    # configure optimizer
    optimizers = {
        'adam': Adam([dict(params=model.parameters(), lr=cfg['lr'])]),
        'adamw': AdamW([dict(params=model.parameters(), lr=cfg['lr'])]),
        'rmsprop': RMSprop([dict(params=model.parameters(), lr=cfg['lr'])])
    }

    assert cfg['optimizer'] in optimizers.keys()
    optimizer = optimizers[cfg['optimizer']]

    # configure metrics
    metrics = {
        'dice_score':
        DiceMetric(include_background=background, reduction='mean'),
        'dice_smp':
        Fscore(threshold=cfg['rounding'],
               ignore_channels=cfg['ignore_channels']),
        'iou_smp':
        IoU(threshold=cfg['rounding'], ignore_channels=cfg['ignore_channels']),
        'generalized_dice':
        GeneralizedDiceLoss(include_background=background,
                            softmax=False,
                            batch=cfg['combine']),
        'dice_loss':
        DiceLoss(include_background=background,
                 softmax=False,
                 batch=cfg['combine']),
        'cross_entropy':
        BCELoss(reduction='mean'),
        'accuracy':
        Accuracy(ignore_channels=cfg['ignore_channels'])
    }

    # each entry of cfg['metrics'] is expected to carry a 'name' key matching
    # the dict above plus a 'type' tag, e.g. {'name': 'dice_score', 'type': 'metric'}
    assert all(m['name'] in metrics.keys() for m in cfg['metrics'])
    metrics = [(metrics[m['name']], m['name'], m['type'])
               for m in cfg['metrics']]  # tuples of (metric, name, type)

    # configure scheduler
    schedulers = {
        'steplr':
        StepLR(optimizer, step_size=cfg['step_size'], gamma=0.5),
        'cosine':
        CosineAnnealingLR(optimizer,
                          cfg['epochs'],
                          eta_min=cfg['eta_min'],
                          last_epoch=-1)
    }

    assert cfg['scheduler'] in schedulers.keys()
    scheduler = schedulers[cfg['scheduler']]

    # configure augmentations
    train_transform = load_train_transform(transform_type=cfg['transform'],
                                           patch_size=cfg['patch_size'])
    valid_transform = load_valid_transform(patch_size=cfg['patch_size'])

    train_dataset = SegmentationDataset(df_path=cfg['train_data'],
                                        transform=train_transform,
                                        normalize=cfg['normalize'],
                                        tissuemix=cfg['tissuemix'],
                                        probability=cfg['probability'],
                                        blending=cfg['blending'],
                                        warping=cfg['warping'],
                                        color=cfg['color'])

    valid_dataset = SegmentationDataset(df_path=cfg['valid_data'],
                                        transform=valid_transform,
                                        normalize=cfg['normalize'])

    # save intermediate augmentations
    if cfg['eval_dir']:
        default_dataset = SegmentationDataset(df_path=cfg['train_data'],
                                              transform=None,
                                              normalize=None)

        transform_dataset = SegmentationDataset(df_path=cfg['train_data'],
                                                transform=None,
                                                normalize=None,
                                                tissuemix=cfg['tissuemix'],
                                                probability=cfg['probability'],
                                                blending=cfg['blending'],
                                                warping=cfg['warping'],
                                                color=cfg['color'])

        for idx in range(0, min(500, len(default_dataset)), 10):
            image_input, image_mask = default_dataset[idx]
            image_input = image_input.transpose((1, 2, 0))
            image_input = image_input.astype(np.uint8)

            image_mask = image_mask.transpose(1, 2, 0)  # CHW -> HWC for saving
            image_mask = image_mask.astype(np.uint8)
            image_mask = image_mask.squeeze()
            image_mask = image_mask * 255

            image_transform, _ = transform_dataset[idx]
            image_transform = image_transform.transpose(
                (1, 2, 0)).astype(np.uint8)

            idx_str = str(idx).zfill(3)
            skimage.io.imsave(os.path.join(cfg['eval_dir'],
                                           f'{idx_str}a_image_input.png'),
                              image_input,
                              check_contrast=False)
            plt.imsave(os.path.join(cfg['eval_dir'],
                                    f'{idx_str}b_image_mask.png'),
                       image_mask,
                       vmin=0,
                       vmax=1)
            skimage.io.imsave(os.path.join(cfg['eval_dir'],
                                           f'{idx_str}c_image_transform.png'),
                              image_transform,
                              check_contrast=False)

        del transform_dataset

    train_loader = DataLoader(train_dataset,
                              batch_size=cfg['batch_size'],
                              num_workers=cfg['workers'],
                              shuffle=True)

    valid_loader = DataLoader(valid_dataset,
                              batch_size=cfg['batch_size'],
                              num_workers=cfg['workers'],
                              shuffle=False)

    trainer = Trainer(model=model,
                      device=cfg['device'],
                      save_checkpoints=cfg['save_checkpoints'],
                      checkpoint_dir=cfg['checkpoint_dir'],
                      checkpoint_name=cfg['checkpoint_name'])

    trainer.compile(optimizer=optimizer,
                    loss=loss,
                    metrics=metrics,
                    num_classes=num_classes)

    trainer.fit(train_loader,
                valid_loader,
                epochs=cfg['epochs'],
                scheduler=scheduler,
                verbose=cfg['verbose'],
                loss_weight=cfg['loss_weight'])

    # validation inference
    best_model = model
    best_model.load_state_dict(
        torch.load(os.path.join(cfg['checkpoint_dir'],
                                cfg['checkpoint_name'])))
    best_model.to(cfg['device'])
    best_model.eval()

    # setup directory to save plots
    if os.path.isdir(cfg['plot_dir']):
        # remove existing dir and content
        shutil.rmtree(cfg['plot_dir'])
    # create absolute destination
    os.makedirs(cfg['plot_dir'])

    # valid dataset without transformations and normalization for image visualization
    valid_dataset_vis = SegmentationDataset(df_path=cfg['valid_data'],
                                            transform=valid_transform,
                                            normalize=None)

    if cfg['save_checkpoints']:
        for n in range(len(valid_dataset)):
            image_vis = valid_dataset_vis[n][0].astype('uint8')
            image_vis = image_vis.transpose((1, 2, 0))

            image, gt_mask = valid_dataset[n]
            gt_mask = gt_mask.transpose((1, 2, 0))
            gt_mask = gt_mask.squeeze()

            x_tensor = torch.from_numpy(image).to(cfg['device']).unsqueeze(0)
            pr_mask, _ = best_model.predict(x_tensor)
            pr_mask = pr_mask.cpu().numpy().round()
            pr_mask = pr_mask.squeeze()

            save_predictions(out_path=cfg['plot_dir'],
                             index=n,
                             image=image_vis,
                             ground_truth_mask=gt_mask,
                             predicted_mask=pr_mask,
                             average='macro')
Example 25
def main_worker(args):
    # disable logging for processes except 0 on every node
    if args.local_rank != 0:
        f = open(os.devnull, "w")
        sys.stdout = sys.stderr = f
    if not os.path.exists(args.dir):
        raise FileNotFoundError(f"Missing directory {args.dir}")

    # initialize the distributed training process, every GPU runs in a process
    dist.init_process_group(backend="nccl", init_method="env://")

    total_start = time.time()
    train_transforms = Compose([
        # load 4 Nifti images and stack them together
        LoadNiftid(keys=["image", "label"]),
        AsChannelFirstd(keys="image"),
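        # map BraTS labels to three overlapping channels:
        # tumor core (TC), whole tumor (WT), enhancing tumor (ET)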
        ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
        Spacingd(keys=["image", "label"],
                 pixdim=(1.5, 1.5, 2.0),
                 mode=("bilinear", "nearest")),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        RandSpatialCropd(keys=["image", "label"],
                         roi_size=[128, 128, 64],
                         random_size=False),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
        RandScaleIntensityd(keys="image", factors=0.1, prob=0.5),
        RandShiftIntensityd(keys="image", offsets=0.1, prob=0.5),
        ToTensord(keys=["image", "label"]),
    ])

    # create a training data loader
    train_ds = BratsCacheDataset(
        root_dir=args.dir,
        transform=train_transforms,
        section="training",
        num_workers=4,
        cache_rate=args.cache_rate,
        shuffle=True,
    )
    train_loader = DataLoader(train_ds,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.workers,
                              pin_memory=True)

    # validation transforms and dataset
    val_transforms = Compose([
        LoadNiftid(keys=["image", "label"]),
        AsChannelFirstd(keys="image"),
        ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
        Spacingd(keys=["image", "label"],
                 pixdim=(1.5, 1.5, 2.0),
                 mode=("bilinear", "nearest")),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        CenterSpatialCropd(keys=["image", "label"], roi_size=[128, 128, 64]),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        ToTensord(keys=["image", "label"]),
    ])
    val_ds = BratsCacheDataset(
        root_dir=args.dir,
        transform=val_transforms,
        section="validation",
        num_workers=4,
        cache_rate=args.cache_rate,
        shuffle=False,
    )
    val_loader = DataLoader(val_ds,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.workers,
                            pin_memory=True)

    if dist.get_rank() == 0:
        # Logging for TensorBoard
        writer = SummaryWriter(log_dir=args.log_dir)

    # create UNet, DiceLoss and Adam optimizer
    device = torch.device(f"cuda:{args.local_rank}")
    if args.network == "UNet":
        model = UNet(
            dimensions=3,
            in_channels=4,
            out_channels=3,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
        ).to(device)
    else:
        model = SegResNet(in_channels=4,
                          out_channels=3,
                          init_filters=16,
                          dropout_prob=0.2).to(device)
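    # the BraTS target is already multi-channel (TC/WT/ET), so no one-hot
    # conversion; per-channel sigmoid with squared terms in the denominator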
    loss_function = DiceLoss(to_onehot_y=False,
                             sigmoid=True,
                             squared_pred=True)
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=1e-5,
                                 amsgrad=True)
    # wrap the model with DistributedDataParallel module
    model = DistributedDataParallel(model, device_ids=[args.local_rank])

    # start a typical PyTorch training
    total_epoch = args.epochs
    best_metric = -1000000
    best_metric_epoch = -1
    epoch_time = AverageMeter("Time", ":6.3f")
    progress = ProgressMeter(total_epoch, [epoch_time], prefix="Epoch: ")
    end = time.time()
    print(f"Time elapsed before training: {end-total_start}")
    for epoch in range(total_epoch):

        train_loss = train(train_loader, model, loss_function, optimizer,
                           epoch, args, device)
        epoch_time.update(time.time() - end)

        if epoch % args.print_freq == 0:
            progress.display(epoch)

        if dist.get_rank() == 0:
            writer.add_scalar("Loss/train", train_loss, epoch)

        if (epoch + 1) % args.val_interval == 0:
            metric, metric_tc, metric_wt, metric_et = evaluate(
                model, val_loader, device)

            if dist.get_rank() == 0:
                writer.add_scalar("Mean Dice/val", metric, epoch)
                writer.add_scalar("Mean Dice TC/val", metric_tc, epoch)
                writer.add_scalar("Mean Dice WT/val", metric_wt, epoch)
                writer.add_scalar("Mean Dice ET/val", metric_et, epoch)
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                print(
                    f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
                    f" tc: {metric_tc:.4f} wt: {metric_wt:.4f} et: {metric_et:.4f}"
                    f"\nbest mean dice: {best_metric:.4f} at epoch: {best_metric_epoch}"
                )
        end = time.time()
        print(f"Time elapsed after epoch {epoch + 1} is {end - total_start}")

    if dist.get_rank() == 0:
        print(
            f"train completed, best_metric: {best_metric:.4f}  at epoch: {best_metric_epoch}"
        )
        # all processes should see same parameters as they all start from same
        # random parameters and gradients are synchronized in backward passes,
        # therefore, saving it in one process is sufficient
        torch.save(model.state_dict(), "final_model.pth")
        writer.flush()
    dist.destroy_process_group()
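The env:// initialization and the local_rank argument follow the torch.distributed.launch convention, so a typical launch line (script name and flags are placeholders) would be:

    python -m torch.distributed.launch --nproc_per_node=8 train_ddp.py --dir /data/brats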
Example 26
 def test_script(self):
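     # test_script_save TorchScript-compiles the loss, saves and reloads it,
     # and checks the reloaded module reproduces the same output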
     loss = DiceLoss()
     test_input = torch.ones(2, 1, 8, 8)
     test_script_save(loss, test_input, test_input)
Example 27
 def test_ill_shape(self):
     loss = DiceLoss()
     with self.assertRaisesRegex(AssertionError, ""):
         loss.forward(torch.ones((1, 2, 3)), torch.ones((4, 5, 6)))
Example 28
import unittest

import numpy as np
import torch
from parameterized import parameterized

from monai.losses import DiceLoss
from monai.losses.multi_scale import MultiScaleLoss

dice_loss = DiceLoss(include_background=True, sigmoid=True, smooth_nr=1e-5, smooth_dr=1e-5)
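# MultiScaleLoss evaluates the wrapped loss on copies of y_pred/y_true
# smoothed at the given scales (gaussian or cauchy kernels) and averages
# the per-scale values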

TEST_CASES = [
    [
        {"loss": dice_loss, "scales": None, "kernel": "gaussian"},
        {"y_pred": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), "y_true": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]])},
        0.307576,
    ],
    [
        {"loss": dice_loss, "scales": [0, 1], "kernel": "gaussian"},
        {"y_pred": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), "y_true": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]])},
        0.463116,
    ],
    [
        {"loss": dice_loss, "scales": [0, 1, 2], "kernel": "cauchy"},
        {
Example 29
 def test_input_warnings(self):
     chn_input = torch.ones((1, 1, 3))
     chn_target = torch.ones((1, 1, 3))
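     # on single-channel data each of these options is a no-op, so MONAI warns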
     with self.assertWarns(Warning):
         loss = DiceLoss(include_background=False)
         loss.forward(chn_input, chn_target)
     with self.assertWarns(Warning):
         loss = DiceLoss(softmax=True)
         loss.forward(chn_input, chn_target)
     with self.assertWarns(Warning):
         loss = DiceLoss(to_onehot_y=True)
         loss.forward(chn_input, chn_target)

 def loss_function(self, context: Context):
     return DiceLoss(to_onehot_y=True, softmax=True, squared_pred=True)