Ejemplo n.º 1
0
 def update(self, y_pred, y, batched=True):
     """Append the mean foreground Dice of a prediction/label pair to ``self.data``.

     Args:
         y_pred: prediction tensor, already batched unless ``batched`` is False.
         y: ground-truth tensor paired with ``y_pred``.
         batched: when False, a leading axis of size 1 is inserted into both
             inputs so a single sample is scored like a batch of one.
     """
     pred_batch, label_batch = (y_pred, y) if batched else (y_pred[None], y[None])
     per_item = compute_meandice(y_pred=pred_batch, y=label_batch,
                                 include_background=False)
     # Collapse everything compute_meandice returned into one Python float.
     self.data.append(per_item.mean().item())
Ejemplo n.º 2
0
def run_inference_test(root_dir, device=torch.device("cuda:0")):
    """Run sliding-window UNet inference over ``root_dir`` and return the mean Dice.

    Images matching ``im*.nii.gz`` are paired, in sorted order, with labels
    matching ``seg*.nii.gz``.  Binarised predictions (sigmoid >= 0.5) are
    written as NIfTI files under ``<root_dir>/output``.

    Args:
        root_dir: directory containing the images, the labels and
            ``best_metric_model.pth``.
        device: device the model and the data are moved to.  The checkpoint is
            mapped onto this device when loaded.

    Returns:
        Sum of the Dice values returned by ``compute_meandice`` divided by the
        number of evaluated samples.

    Raises:
        ValueError: if no image/label pair is found in ``root_dir``.
    """
    images = sorted(glob(os.path.join(root_dir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(root_dir, "seg*.nii.gz")))
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]
    if not val_files:
        # Fail early with a clear message instead of a ZeroDivisionError below.
        raise ValueError(
            f"no im*.nii.gz / seg*.nii.gz pairs found in {root_dir!r}")

    # define transforms for image and segmentation
    val_transforms = Compose([
        LoadNiftid(keys=["img", "seg"]),
        AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
        ScaleIntensityd(keys=["img", "seg"]),
        ToTensord(keys=["img", "seg"]),
    ])
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    # sliding window inference needs to input 1 image in every iteration
    val_loader = DataLoader(val_ds,
                            batch_size=1,
                            num_workers=4,
                            collate_fn=list_data_collate,
                            pin_memory=torch.cuda.is_available())

    model = UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    model_filename = os.path.join(root_dir, "best_metric_model.pth")
    # map_location lets a checkpoint saved on another device (e.g. a GPU)
    # load onto the requested ``device``.
    model.load_state_dict(torch.load(model_filename, map_location=device))
    model.eval()

    # define sliding window size and batch size for windows inference
    sw_batch_size, roi_size = 4, (96, 96, 96)
    with torch.no_grad():
        metric_sum = 0.0
        metric_count = 0
        saver = NiftiSaver(output_dir=os.path.join(root_dir, "output"),
                           dtype=int)
        for val_data in val_loader:
            val_images = val_data["img"].to(device)
            val_labels = val_data["seg"].to(device)
            val_outputs = sliding_window_inference(val_images, roi_size,
                                                   sw_batch_size, model)
            value = compute_meandice(y_pred=val_outputs,
                                     y=val_labels,
                                     include_background=True,
                                     to_onehot_y=False,
                                     add_sigmoid=True)
            # len(value) is the number of samples in this batch.
            metric_count += len(value)
            metric_sum += value.sum().item()
            val_outputs = (val_outputs.sigmoid() >= 0.5).float()
            saver.save_batch(
                val_outputs, {
                    "filename_or_obj": val_data["img.filename_or_obj"],
                    "affine": val_data["img.affine"]
                })
        metric = metric_sum / metric_count
    return metric
Ejemplo n.º 3
0
    def update(self, y_pred, y):
        """Record the mean Dice (background included) of one prediction/label pair.

        Inputs that are not tensors are converted with ``torch.from_numpy``.
        A NaN score is discarded rather than appended to ``self.data``.
        """

        def _as_tensor(value):
            # Tensors pass through unchanged; anything else is wrapped.
            return value if torch.is_tensor(value) else torch.from_numpy(value)

        dice = compute_meandice(y_pred=_as_tensor(y_pred), y=_as_tensor(y),
                                include_background=True)
        score = dice.mean().item()
        if math.isnan(score):
            return
        self.data.append(score)
def main():
    """Generate synthetic volumes, run sliding-window inference and print mean Dice.

    Five random 128x128x128 image/segmentation pairs are written as NIfTI files
    to a temporary directory.  They are evaluated with the UNet weights stored
    in ``best_metric_model.pth``, and the binarised predictions are saved under
    ``./output``.  The temporary directory is removed even if any step fails.
    """
    config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    tempdir = tempfile.mkdtemp()
    try:
        print('generating synthetic data to {} (this may take a while)'.format(tempdir))
        for i in range(5):
            im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1)

            n = nib.Nifti1Image(im, np.eye(4))
            nib.save(n, os.path.join(tempdir, 'im%i.nii.gz' % i))

            n = nib.Nifti1Image(seg, np.eye(4))
            nib.save(n, os.path.join(tempdir, 'seg%i.nii.gz' % i))

        images = sorted(glob(os.path.join(tempdir, 'im*.nii.gz')))
        segs = sorted(glob(os.path.join(tempdir, 'seg*.nii.gz')))

        # define transforms for image and segmentation
        imtrans = Compose([ScaleIntensity(), AddChannel(), ToTensor()])
        segtrans = Compose([AddChannel(), ToTensor()])
        val_ds = NiftiDataset(images, segs, transform=imtrans, seg_transform=segtrans, image_only=False)
        # sliding window inference for one image at every iteration
        val_loader = DataLoader(val_ds, batch_size=1, num_workers=1, pin_memory=torch.cuda.is_available())

        device = torch.device('cuda:0')
        model = UNet(
            dimensions=3,
            in_channels=1,
            out_channels=1,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
        ).to(device)

        # map_location makes the checkpoint load onto ``device`` regardless of
        # the device it was saved from.
        model.load_state_dict(torch.load('best_metric_model.pth', map_location=device))
        model.eval()

        # define sliding window size and batch size for windows inference
        roi_size = (96, 96, 96)
        sw_batch_size = 4
        with torch.no_grad():
            metric_sum = 0.
            metric_count = 0
            saver = NiftiSaver(output_dir='./output')
            for val_data in val_loader:
                val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
                val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
                value = compute_meandice(y_pred=val_outputs, y=val_labels, include_background=True,
                                         to_onehot_y=False, add_sigmoid=True)
                metric_count += len(value)
                metric_sum += value.sum().item()
                val_outputs = (val_outputs.sigmoid() >= 0.5).float()
                # NOTE(review): val_data[2] is presumably the image metadata
                # returned because image_only=False; confirm against NiftiDataset.
                saver.save_batch(val_outputs, val_data[2])
            metric = metric_sum / metric_count
            print('evaluation metric:', metric)
    finally:
        # Always remove the synthetic data, even when generation/inference raised.
        shutil.rmtree(tempdir)
Ejemplo n.º 5
0
 def validation_step(self, batch, batch_idx):
     """Evaluate one validation batch with sliding-window inference.

     Returns a dict with the loss on the raw network outputs under
     ``"val_loss"``. Under ``"val_dice"`` it holds the foreground Dice computed
     after the ``self.post_pred`` / ``self.post_label`` transforms.
     """
     images = batch["image"]
     labels = batch["label"]
     # Windows of size PATCH_SIZE, one window per forward pass.
     raw_outputs = sliding_window_inference(images, PATCH_SIZE, 1, self.forward)
     val_loss = self.loss_function(raw_outputs, labels)
     val_dice = compute_meandice(
         y_pred=self.post_pred(raw_outputs),
         y=self.post_label(labels),
         include_background=False,
     )
     return {"val_loss": val_loss, "val_dice": val_dice}
Ejemplo n.º 6
0
def MeanDice(model, data_loader, device):
    """Return the average Dice score of ``model`` over ``data_loader``.

    Each item of ``data_loader`` must be a dict with ``"image"`` and
    ``"label"`` entries.  The model output is scored against the label with
    ``compute_meandice(..., sigmoid=True, logit_thresh=0.5)``.  No gradients
    are tracked during the pass.

    Args:
        model: network mapping an image batch to raw predictions (presumably
            logits, given ``sigmoid=True``).
        data_loader: iterable of ``{"image": ..., "label": ...}`` batches.
        device: device the images and labels are moved to.

    Returns:
        Sum of all Dice values divided by the number of samples seen.
    """
    metric_sum = 0.0
    metric_count = 0
    with torch.no_grad():
        for data in data_loader:
            inputs = data["image"].to(device)
            labels = data["label"].to(device)
            outputs = model(inputs)
            # Score the prediction against the ground-truth labels, not the
            # input image.
            value = compute_meandice(outputs, labels, sigmoid=True, logit_thresh=0.5)
            # len(value) is the number of samples in this batch.
            metric_count += len(value)
            metric_sum += value.sum().item()
    return metric_sum / metric_count
Ejemplo n.º 7
0
    def update(self, output: Sequence[Union[torch.Tensor, dict]]):
        """Accumulate per-sample Dice averages from one ``(y_pred, y)`` pair.

        For each sample, the non-NaN scores returned by ``compute_meandice`` are
        averaged. The average is added to ``self._sum`` and
        ``self._num_examples`` is incremented. Samples whose scores are all NaN
        are skipped.
        """
        assert len(output) == 2, 'MeanDice metric can only support y_pred and y.'
        y_pred, y = output
        scores = compute_meandice(
            y_pred,
            y,
            self.include_background,
            self.to_onehot_y,
            self.mutually_exclusive,
            self.add_sigmoid,
            self.logit_thresh,
        )

        for sample_scores in scores:
            valid = sample_scores[~torch.isnan(sample_scores)]
            if valid.numel() == 0:
                # Nothing defined for this sample; it does not count.
                continue
            self._sum += valid.mean().item()
            self._num_examples += 1
Ejemplo n.º 8
0
 def test_nans(self, input_data, expected_value):
     """Check that the NaN positions of compute_meandice's result match ``expected_value``."""
     nan_mask = np.isnan(compute_meandice(**input_data).cpu().numpy())
     self.assertTrue(np.allclose(nan_mask, expected_value))
Ejemplo n.º 9
0
 def test_value(self, input_data, expected_value):
     """Compare compute_meandice's result with ``expected_value`` (atol=1e-4)."""
     actual = compute_meandice(**input_data).cpu().numpy()
     np.testing.assert_allclose(actual, expected_value, atol=1e-4)
Ejemplo n.º 10
0
 # Validation pass: sliding-window inference over val_loader, mean Dice
 # accumulated over all samples, and a checkpoint saved when Dice improves.
 # NOTE(review): excerpt of a training loop; epoch, val_loader, model, device,
 # best_metric, best_metric_epoch, metric_values and writer come from outside.
 with torch.no_grad():
     metric_sum = 0.
     metric_count = 0
     # Pre-bound so these names exist after the loop even if val_loader is empty.
     val_images = None
     val_labels = None
     val_outputs = None
     for val_data in val_loader:
         val_images, val_labels = val_data['img'].to(
             device), val_data['seg'].to(device)
         # Window size and number of windows per forward pass.
         roi_size = (96, 96, 96)
         sw_batch_size = 4
         val_outputs = sliding_window_inference(val_images, roi_size,
                                                sw_batch_size, model)
         # add_sigmoid=True: model outputs are passed un-activated and
         # presumably sigmoid-activated inside compute_meandice.
         value = compute_meandice(y_pred=val_outputs,
                                  y=val_labels,
                                  include_background=True,
                                  to_onehot_y=False,
                                  add_sigmoid=True)
         # len(value) is the number of samples in this batch.
         metric_count += len(value)
         metric_sum += value.sum().item()
     # Raises ZeroDivisionError if val_loader yielded nothing.
     metric = metric_sum / metric_count
     metric_values.append(metric)
     # Higher Dice is better: keep the weights of the best epoch so far.
     if metric > best_metric:
         best_metric = metric
         best_metric_epoch = epoch + 1
         torch.save(model.state_dict(), 'best_metric_model.pth')
         print('saved new best metric model')
     print(
         "current epoch %d current mean dice: %0.4f best mean dice: %0.4f at epoch %d"
         % (epoch + 1, metric, best_metric, best_metric_epoch))
     writer.add_scalar('val_mean_dice', metric, epoch + 1)
Ejemplo n.º 11
0
            # Validation: accumulate Dice over val_loader using sliding-window
            # inference, then checkpoint the model if the mean Dice improved.
            # NOTE(review): excerpt; device, model, epoch, best_metric,
            # metric_values and output_path are defined outside this fragment.
            metric_sum = 0.0
            metric_count = 0

            for val_data in val_loader:
                val_inputs, val_labels = (
                    val_data['image'].to(device),
                    val_data['label'].to(device),
                )
                # Window size and number of windows per forward pass.
                roi_size = PATCH_SIZE
                sw_batch_size = 1
                val_outputs = sliding_window_inference(val_inputs, roi_size,
                                                       sw_batch_size, model)
                # Foreground-only Dice; labels are one-hot encoded by
                # compute_meandice (to_onehot_y=True).
                value = compute_meandice(
                    y_pred=val_outputs,
                    y=val_labels,
                    include_background=False,
                    to_onehot_y=True,
                    mutually_exclusive=True,
                )
                # len(value) is the number of samples in this batch.
                metric_count += len(value)
                metric_sum += value.sum().item()
            # Raises ZeroDivisionError if val_loader yielded nothing.
            metric = metric_sum / metric_count
            metric_values.append(metric)

            # Higher Dice is better; output_path is presumably a pathlib.Path
            # (joined with ``/``).
            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(),
                           output_path / 'best_metric_model.pth')
                print('saved new best metric model')
            print(
Ejemplo n.º 12
0
def inference(data_loader, model, criterion, model_name):
    """Evaluate ``model`` on ``data_loader`` and export per-case results.

    For each ``(data, y)`` batch this does the following:

    * computes the loss, the Dice score and the Hausdorff distance of the
      post-processed prediction;
    * plots the MRI, the ground-truth overlay and the predicted-mask overlay
      for the middle index of the last axis;
    * writes the volumes as VTK files under ``test_results/<model_name>/``.

    Per-case metrics plus an ``average`` summary row are saved to
    ``test_results/<model_name>/metrics_vals.csv``.

    The whole pass runs under ``torch.no_grad()``, so evaluation neither
    builds an autograd graph nor accumulates gradients into the model.

    Args:
        data_loader: iterable yielding ``(data, y)`` pairs (image, label).
        model: network to evaluate; it is switched to eval mode.
        criterion: loss called as ``criterion(out, y)``.
        model_name: sub-directory of ``test_results/`` receiving the outputs.

    NOTE(review): relies on a module-level ``device`` and on CUDA being
    available (``torch.cuda.FloatTensor``).
    """
    total_batch = len(data_loader)

    test_loss = average_metrics()
    dice, hausdorff = average_metrics(), average_metrics()

    # One-hot encode the labels for the metric functions (2 classes).
    post_label = AsDiscrete(to_onehot=True, n_classes=2)

    model.eval()
    results = []
    dice_collection, hausdorff_collection = [], []
    start2 = time.time()
    with torch.no_grad():
        for batch_idx, (data, y) in enumerate(data_loader):
            start = time.time()
            data = data.to(device)
            data = data.type(torch.cuda.FloatTensor)
            y = y.to(device)

            # Prediction and loss (evaluation only: no backward pass).
            out = model(data.to(device))
            loss = criterion(out.to(device), y.to(device))
            test_loss.update(loss.item())

            labels = post_label(y.to(device))

            # Post-processing: class map (argmax over dim 1) of the first batch
            # element, cleaned by remove_small_objects.
            post_processing = remove_small_objects(
                torch.argmax(out, dim=1).detach().cpu()[0, :, :, :])

            # Metrics
            dice.update(
                compute_meandice(
                    y_pred=post_processing.to(device),
                    y=labels,
                    include_background=False,
                ).item())
            hausdorff.update(
                compute_hausdorff_distance(y_pred=post_processing.to(device),
                                           y=labels,
                                           distance_metric='euclidean').item())
            dice_collection.append(dice.val)
            hausdorff_collection.append(hausdorff.val)

            print(
                f'Iteration {(batch_idx + 1)}/{total_batch} - Loss: {test_loss.val}  -Dice: {dice.val} - Hausdorff: {hausdorff.val} '
            )
            end = time.time()

            results.append([
                batch_idx + 1, test_loss.val, dice.val, hausdorff.val,
                end - start
            ])

            _plot_case(batch_idx, data, y, out, post_processing, dice.val,
                       hausdorff.val)
            _export_vtk_case(batch_idx, data, y, post_processing, model_name)

    print(
        f'Test loss:{test_loss.avg} - Dice: {dice.avg} +/- {np.std(dice_collection)} - Hausdorff:{hausdorff.avg} +/- {np.std(hausdorff_collection)}'
    )
    end2 = time.time()

    # Save metric results into csv file; the summary row reports averages.
    results.append(
        ['average', test_loss.avg, dice.avg, hausdorff.avg, end2 - start2])
    results_df = pd.DataFrame(
        results, columns=["test num", "Loss", "Dice", "Hausdorff", 'Time'])
    results_df.to_csv('test_results/' + model_name + '/metrics_vals.csv',
                      index=False)


def _plot_case(batch_idx, data, y, out, post_processing, dice_val,
               hausdorff_val):
    """Show MRI, ground-truth overlay and predicted overlay for one case."""
    plt.figure("check", (10, 5))
    mid_slice = out.shape[4] // 2
    ground_truth = rotate(y.detach().cpu()[0, 0, :, :, mid_slice], 90)
    mri_image = rotate(data.detach().cpu()[0, 0, :, :, mid_slice], 90)
    # NOTE(review): post_processing is already an argmax class map of one
    # sample; taking argmax over dim=1 again looks suspicious. Confirm the
    # shape returned by remove_small_objects.
    predicted = rotate(
        torch.argmax(post_processing, dim=1).detach().cpu()[0, :, :,
                                                            mid_slice], 90)
    custom = matplotlib.colors.ListedColormap(['gray', 'red'])

    # LGE - MRI
    plt.subplot(1, 3, 1)
    plt.title(f'MRI test {batch_idx+1}')
    plt.axis('off')
    plt.imshow(mri_image, cmap="gray")

    # Ground truth mask + LGE-MRI
    plt.subplot(1, 3, 2)
    plt.axis('off')
    plt.title('Ground Truth')
    plt.imshow(mri_image, cmap="gray")
    plt.imshow(ground_truth, alpha=0.4, cmap=custom)

    # Predicted mask + LGE-MRI
    plt.subplot(1, 3, 3)
    plt.title(f'DSC:{round(dice_val,3)} - HD:{round(hausdorff_val,2)}')
    plt.axis('off')
    plt.imshow(mri_image, cmap="gray")
    plt.imshow(predicted, alpha=0.4, cmap=custom)
    plt.show()


def _export_vtk_case(batch_idx, data, y, post_processing, model_name):
    """Write the predicted mask, ground truth and MRI of one case as VTK files."""
    y_predicted_array = torch.argmax(
        post_processing,
        dim=1).detach().cpu().numpy().astype('float')[0, :, :, :]
    y_true_array = y.detach().cpu().numpy()[0, 0, :, :, :]
    mri_array = data.detach().cpu().numpy()[0, 0, :, :, :]

    file_name = f'patient_{batch_idx+1}.vtk'
    out_dir = 'test_results/' + model_name
    generate_vtk_from_numpy(y_predicted_array,
                            out_dir + '/predicted_' + file_name)
    generate_vtk_from_numpy(y_true_array, out_dir + '/true_' + file_name)
    generate_vtk_from_numpy(mri_array, out_dir + '/mri_' + file_name)
Ejemplo n.º 13
0
    def __call__(self, engine: Engine):
        """Save image, label and prediction of every sample in the current batch.

        File names embed these fields:

        * the distributed rank (when ``torch.distributed`` is initialised);
        * the sample's ``region``;
        * the engine iteration;
        * the sample index;
        * the sample's foreground mean Dice.

        With ``self.save_np`` an ``.npz`` archive is written.  With
        ``self.images`` the output depends on the rank of ``image`` (which
        carries one leading axis): rank 3 gives a PNG grid, and rank 4 gives
        one NIfTI file per array.
        """
        state = engine.state
        batch_data = state.batch
        output_data = state.output
        device = state.device
        if torch.distributed.is_initialized():
            tag = "r{}-".format(torch.distributed.get_rank())
        else:
            tag = ""
        step = state.iteration

        def unit_range(array):
            # rescale_array(..., 0, 1) then drop the leading axis
            return rescale_array(array, 0, 1)[0]

        for bidx in range(len(batch_data.get("image"))):
            region = batch_data.get("region")[bidx]
            if torch.is_tensor(region):
                region = region.item()

            pred_t = output_data["pred"][bidx]
            label_t = batch_data["label"][bidx]
            image = batch_data["image"][bidx][0].detach().cpu().numpy()[np.newaxis]
            label = label_t.detach().cpu().numpy()
            pred = pred_t.detach().cpu().numpy()
            dice = compute_meandice(
                y_pred=pred_t[None].to(device),
                y=label_t[None].to(device),
                include_background=False,
            ).mean()

            if self.save_np:
                stem = "{}img_label_pred_{}_{:0>4d}_{:0>2d}_{:.4f}".format(
                    tag, region, step, bidx, dice)
                np.savez(os.path.join(self.output_dir, stem), image, label, pred)

            if not self.images:
                continue

            rank = len(image.shape)
            if rank == 3:
                img = make_grid(torch.from_numpy(unit_range(image)))
                lab = make_grid(torch.from_numpy(unit_range(label)))
                extra = output_data["image"][bidx]
                pos = unit_range(extra[1].detach().cpu().numpy()[np.newaxis])
                neg = unit_range(extra[2].detach().cpu().numpy()[np.newaxis])
                pre = make_grid(
                    torch.from_numpy(np.array([unit_range(pred), pos, neg])))
                png_name = "{}img_label_pred_{}_{:0>4d}_{:0>2d}_{:.4f}.png".format(
                    tag, region, step, bidx, dice)
                torchvision.utils.save_image(
                    tensor=[img, lab, pre],
                    nrow=3,
                    pad_value=2,
                    fp=os.path.join(self.output_dir, png_name),
                )
            elif rank == 4:
                volumes = {"image": image[0], "label": label[0], "pred": pred[0]}
                for name, volume in volumes.items():
                    nifti = nib.Nifti1Image(np.moveaxis(volume, -3, -1), np.eye(4))
                    file_name = "{}{}_{:0>4d}_{:0>2d}_{:.4f}.nii.gz".format(
                        tag, name, step, bidx, dice)
                    nib.save(nifti, os.path.join(self.output_dir, file_name))
Ejemplo n.º 14
0
def train(n_feat,
          crop_size,
          bs,
          ep,
          optimizer="rmsprop",
          lr=5e-4,
          pretrain=None):
    """Train a GPipe-partitioned UNetPipe and checkpoint the best per-class models.

    Args:
        n_feat: feature width passed to ``UNetPipe``.
        crop_size: comma-separated crop size string, e.g. ``"64,64,64"``.  Any
            ``-1`` entry selects full-image training.
        bs: batch size for full-image training, or number of windows sampled
            per image (with a loader batch size of 1) when cropping.
        ep: number of training epochs.
        optimizer: ``"rmsprop"`` or ``"momentum"`` (case-insensitive).
        lr: learning rate.
        pretrain: optional checkpoint path.  Its ``"weight"`` entries whose
            keys exist in the model are loaded over the model's own state.

    Every 10 epochs, the foreground Dice per class is computed on the
    validation set.  For each class whose score improves, a checkpoint is saved
    as ``<model_name><class index>``.

    Raises:
        ValueError: for an unknown optimizer type.
    """
    model_name = f"./HaN_{n_feat}_{bs}_{ep}_{crop_size}_{lr}_"
    print(f"save the best model as '{model_name}' during training.")

    crop_size = [int(cz) for cz in crop_size.split(",")]
    print(f"input image crop_size: {crop_size}")

    # starting training set loader
    train_images = ImageLabelDataset(path=TRAIN_PATH, n_class=N_CLASSES)
    if any(cz == -1 for cz in crop_size):  # using full image
        train_transform = Compose([
            AddChannelDict(keys="image"),
            _rand_elastic_transform(crop_size),
        ])
        train_dataset = Dataset(train_images, transform=train_transform)
        # when bs > 1, the loader assumes that the full image sizes are the same across the dataset
        train_dataloader = torch.utils.data.DataLoader(train_dataset,
                                                       num_workers=4,
                                                       batch_size=bs,
                                                       shuffle=True)
    else:
        # draw balanced foreground/background window samples according to the ground truth label
        train_transform = Compose([
            AddChannelDict(keys="image"),
            SpatialPadd(
                keys=("image", "label"),
                spatial_size=crop_size),  # ensure image size >= crop_size
            RandCropByPosNegLabeld(keys=("image", "label"),
                                   label_key="label",
                                   spatial_size=crop_size,
                                   num_samples=bs),
            _rand_elastic_transform(crop_size),
        ])
        train_dataset = Dataset(train_images, transform=train_transform
                                )  # each dataset item is a list of windows
        train_dataloader = torch.utils.data.DataLoader(  # stack each dataset item into a single tensor
            train_dataset,
            num_workers=4,
            batch_size=1,
            shuffle=True,
            collate_fn=list_data_collate)
    first_sample = first(train_dataloader)
    print(first_sample["image"].shape)

    # starting validation set loader
    val_transform = Compose([AddChannelDict(keys="image")])
    val_dataset = Dataset(ImageLabelDataset(VAL_PATH, n_class=N_CLASSES),
                          transform=val_transform)
    val_dataloader = torch.utils.data.DataLoader(val_dataset,
                                                 num_workers=1,
                                                 batch_size=1)
    print(val_dataset[0]["image"].shape)
    print(
        f"training images: {len(train_dataloader)}, validation images: {len(val_dataloader)}"
    )

    model = UNetPipe(spatial_dims=3,
                     in_channels=1,
                     out_channels=N_CLASSES,
                     n_feat=n_feat)
    model = flatten_sequential(model)
    lossweight = torch.from_numpy(
        np.array([2.22, 1.31, 1.99, 1.13, 1.93, 1.93, 1.0, 1.0, 1.90, 1.98],
                 np.float32))

    if optimizer.lower() == "rmsprop":
        optimizer = torch.optim.RMSprop(model.parameters(), lr=lr)  # lr = 5e-4
    elif optimizer.lower() == "momentum":
        optimizer = torch.optim.SGD(model.parameters(), lr=lr,
                                    momentum=0.9)  # lr = 1e-4 for finetuning
    else:
        raise ValueError(
            f"Unknown optimizer type {optimizer}. (options are 'rmsprop' and 'momentum')."
        )

    # config GPipe
    x = first_sample["image"].float()
    x = torch.autograd.Variable(x.cuda())
    partitions = torch.cuda.device_count()
    print(f"partition: {partitions}, input: {x.size()}")
    balance = balance_by_size(partitions, model, x)
    model = GPipe(model, balance, chunks=4, checkpoint="always")

    # config loss functions
    dice_loss_func = DiceLoss(softmax=True, reduction="none")
    # use the same pipeline and loss in
    # AnatomyNet: Deep learning for fast and fully automated whole‐volume segmentation of head and neck anatomy,
    # Medical Physics, 2018.
    focal_loss_func = FocalLoss(reduction="none")

    if pretrain:
        print(f"loading from {pretrain}.")
        pretrained_dict = torch.load(pretrain)["weight"]
        model_dict = model.state_dict()
        # Overlay the matching checkpoint entries on the model's own state and
        # load the merged dict. Loading only the filtered subset would fail the
        # strict key check whenever the checkpoint lacks some model keys.
        model_dict.update({
            k: v
            for k, v in pretrained_dict.items() if k in model_dict
        })
        model.load_state_dict(model_dict)

    b_time = time.time()
    best_val_loss = [0] * (N_CLASSES - 1)  # foreground
    for epoch in range(ep):
        model.train()
        trainloss = 0
        for b_idx, data_dict in enumerate(train_dataloader):
            x_train = data_dict["image"]
            y_train = data_dict["label"]
            flagvec = data_dict["with_complete_groundtruth"]

            x_train = torch.autograd.Variable(x_train.cuda())
            y_train = torch.autograd.Variable(y_train.cuda().float())
            optimizer.zero_grad()
            o = model(x_train).to(0, non_blocking=True).float()

            loss = (dice_loss_func(o, y_train.to(o)) * flagvec.to(o) *
                    lossweight.to(o)).mean()
            loss += 0.5 * (focal_loss_func(o, y_train.to(o)) * flagvec.to(o) *
                           lossweight.to(o)).mean()
            loss.backward()
            optimizer.step()
            trainloss += loss.item()

            if b_idx % 20 == 0:
                print(
                    f"Train Epoch: {epoch} [{b_idx}/{len(train_dataloader)}] \tLoss: {loss.item()}"
                )
        print(f"epoch {epoch} TRAIN loss {trainloss / len(train_dataloader)}")

        if epoch % 10 == 0:
            model.eval()
            # Per-class running sums of foreground Dice and counts of non-NaN
            # scores ("loss" names kept for compatibility; higher is better).
            val_loss = [0] * (N_CLASSES - 1)
            n_val = [0] * (N_CLASSES - 1)
            # The forward pass itself must be inside no_grad to avoid building
            # an autograd graph during validation.
            with torch.no_grad():
                for data_dict in val_dataloader:
                    x_val = data_dict["image"]
                    y_val = data_dict["label"]
                    x_val = torch.autograd.Variable(x_val.cuda())
                    o = model(x_val).to(0, non_blocking=True)
                    loss = compute_meandice(o,
                                            y_val.to(o),
                                            mutually_exclusive=True,
                                            include_background=False)
                    # ``l == l`` is False only for NaN, so NaN scores are skipped.
                    val_loss = [
                        l.item() + tl if l == l else tl
                        for l, tl in zip(loss[0], val_loss)
                    ]
                    n_val = [
                        n + 1 if l == l else n for l, n in zip(loss[0], n_val)
                    ]
            # A class never scored would divide by zero; report it as NaN instead.
            val_loss = [
                l / n if n else float("nan") for l, n in zip(val_loss, n_val)
            ]
            print("validation scores " + ", ".join("%.4f" % v for v in val_loss))
            for c in range(1, N_CLASSES):
                if best_val_loss[c - 1] < val_loss[c - 1]:
                    best_val_loss[c - 1] = val_loss[c - 1]
                    state = {
                        "epoch": epoch,
                        "weight": model.state_dict(),
                        "score_" + str(c): best_val_loss[c - 1]
                    }
                    torch.save(state, f"{model_name}" + str(c))
            print("best validation scores " +
                  ", ".join("%.4f" % v for v in best_val_loss))

    print("total time", time.time() - b_time)


def _rand_elastic_transform(crop_size):
    """Build the random 3D elastic augmentation shared by both training pipelines."""
    return Rand3DElasticd(
        keys=("image", "label"),
        spatial_size=crop_size,
        sigma_range=(10, 50),  # 30
        magnitude_range=(600, 1200),  # 1000
        prob=0.8,
        rotate_range=(np.pi / 12, np.pi / 12, np.pi / 12),
        shear_range=(np.pi / 18, np.pi / 18, np.pi / 18),
        translate_range=tuple(sz * 0.05 for sz in crop_size),
        scale_range=(0.2, 0.2, 0.2),
        mode=("bilinear", "nearest"),
        padding_mode=("border", "zeros"),
    )
def main():
    """Train a 2D MONAI UNet on slices of 3D NIfTI volumes, driven by a YAML config.

    Reads ``--config`` (a YAML file with ``device``, ``training``, ``data`` and
    ``output`` sections), builds training/validation loaders from
    ``create_data_list``, trains with ``DiceLoss`` + Adam, and every
    ``validation_every_n_epochs`` epochs evaluates either full volumes with
    sliding-window inference (``sliding_window_validation: true``) or random
    2D slices scored as ``1 - DiceLoss``.  The best model by mean validation
    Dice is saved to ``<out_model_dir>/best_metric_model.pth``; losses and
    metrics are logged to TensorBoard under ``<out_model_dir>/train`` and
    ``<out_model_dir>/valid``.
    """
    # --- Read input and configuration parameters ---
    parser = argparse.ArgumentParser(description='Run basic UNet with MONAI.')
    parser.add_argument('--config', dest='config', metavar='config', type=str,
                        help='config file')
    args = parser.parse_args()

    # NOTE(review): FullLoader can construct some Python objects and has had
    # security advisories; if configs are plain YAML, yaml.safe_load suffices.
    with open(args.config) as f:
        config_info = yaml.load(f, Loader=yaml.FullLoader)

    # print to log the parameter setups
    print(yaml.dump(config_info))

    # GPU params
    cuda_device = config_info['device']['cuda_device']
    num_workers = config_info['device']['num_workers']
    # training and validation params
    loss_type = config_info['training']['loss_type']  # NOTE(review): read but unused below
    batch_size_train = config_info['training']['batch_size_train']
    batch_size_valid = config_info['training']['batch_size_valid']
    lr = float(config_info['training']['lr'])
    nr_train_epochs = config_info['training']['nr_train_epochs']
    validation_every_n_epochs = config_info['training']['validation_every_n_epochs']
    sliding_window_validation = config_info['training']['sliding_window_validation']
    # data params
    data_root = config_info['data']['data_root']
    training_list = config_info['data']['training_list']
    validation_list = config_info['data']['validation_list']
    # model saving: one timestamped run directory per invocation
    out_model_dir = os.path.join(config_info['output']['out_model_dir'],
                                 datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '_' +
                                 config_info['output']['output_subfix'])
    print("Saving to directory ", out_model_dir)
    max_nr_models_saved = config_info['output']['max_nr_models_saved']  # NOTE(review): unused below
    # Create the run directory explicitly so torch.save() below does not rely
    # on SummaryWriter having created it as a side effect.
    os.makedirs(out_model_dir, exist_ok=True)

    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    torch.cuda.set_device(cuda_device)

    # --- Data preparation ---
    # create training and validation data lists
    train_files = create_data_list(data_folder_list=data_root,
                                   subject_list=training_list,
                                   img_postfix='_Image',
                                   label_postfix='_Label')

    print(len(train_files))
    print(train_files[0])
    print(train_files[-1])

    val_files = create_data_list(data_folder_list=data_root,
                                 subject_list=validation_list,
                                 img_postfix='_Image',
                                 label_postfix='_Label')
    print(len(val_files))
    print(val_files[0])
    print(val_files[-1])

    # data preprocessing for training:
    # - convert data to right format [batch, channel, dim, dim, dim]
    # - apply whitening
    # - resize to (96, 96) in-plane (preserve z-direction)
    # - define 2D patches to be extracted
    # - add data augmentation (random rotation and random flip)
    # - squeeze to 2D
    train_transforms = Compose([
        LoadNiftid(keys=['img', 'seg']),
        AddChanneld(keys=['img', 'seg']),
        NormalizeIntensityd(keys=['img']),
        Resized(keys=['img', 'seg'], spatial_size=[96, 96], interp_order=[1, 0], anti_aliasing=[True, False]),
        RandSpatialCropd(keys=['img', 'seg'], roi_size=[96, 96, 1], random_size=False),
        RandRotated(keys=['img', 'seg'], degrees=90, prob=0.2, spatial_axes=[0, 1], interp_order=[1, 0], reshape=False),
        RandFlipd(keys=['img', 'seg'], spatial_axis=[0, 1]),
        SqueezeDimd(keys=['img', 'seg'], dim=-1),
        ToTensord(keys=['img', 'seg'])
    ])
    # create a training data loader
    train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    train_loader = DataLoader(train_ds,
                              batch_size=batch_size_train,
                              shuffle=True, num_workers=num_workers,
                              collate_fn=list_data_collate,
                              pin_memory=torch.cuda.is_available())
    check_train_data = monai.utils.misc.first(train_loader)
    print("Training data tensor shapes")
    print(check_train_data['img'].shape, check_train_data['seg'].shape)

    # data preprocessing for validation:
    # - convert data to right format [batch, channel, dim, dim, dim]
    # - apply whitening
    # - resize to (96, 96) in-plane (preserve z-direction)
    if sliding_window_validation:
        val_transforms = Compose([
            LoadNiftid(keys=['img', 'seg']),
            AddChanneld(keys=['img', 'seg']),
            NormalizeIntensityd(keys=['img']),
            Resized(keys=['img', 'seg'], spatial_size=[96, 96], interp_order=[1, 0], anti_aliasing=[True, False]),
            ToTensord(keys=['img', 'seg'])
        ])
        do_shuffle = False
        collate_fn_to_use = None
    else:
        # - add extraction of 2D slices from validation set to emulate how loss is computed at training
        val_transforms = Compose([
            LoadNiftid(keys=['img', 'seg']),
            AddChanneld(keys=['img', 'seg']),
            NormalizeIntensityd(keys=['img']),
            Resized(keys=['img', 'seg'], spatial_size=[96, 96], interp_order=[1, 0], anti_aliasing=[True, False]),
            RandSpatialCropd(keys=['img', 'seg'], roi_size=[96, 96, 1], random_size=False),
            SqueezeDimd(keys=['img', 'seg'], dim=-1),
            ToTensord(keys=['img', 'seg'])
        ])
        do_shuffle = True
        collate_fn_to_use = list_data_collate
    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(val_ds,
                            batch_size=batch_size_valid,
                            shuffle=do_shuffle,
                            collate_fn=collate_fn_to_use,
                            num_workers=num_workers)
    check_valid_data = monai.utils.misc.first(val_loader)
    print("Validation data tensor shapes")
    print(check_valid_data['img'].shape, check_valid_data['seg'].shape)

    # --- Network preparation ---
    # Create UNet, DiceLoss and Adam optimizer.
    net = monai.networks.nets.UNet(
        dimensions=2,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    )

    loss_function = monai.losses.DiceLoss(do_sigmoid=True)
    opt = torch.optim.Adam(net.parameters(), lr)
    device = torch.cuda.current_device()

    def _predict_slices(patch):
        # sliding_window_inference with roi_size (96, 96, 1) passes patches
        # that still carry a unit depth axis (5D), while `net` is a 2D UNet
        # that needs 4D input: drop that axis for the forward pass and
        # restore it so the output matches the patch layout.
        return net(patch.squeeze(-1)).unsqueeze(-1)

    # --- Training loop ---
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    metric_values = list()
    # Number of batches per epoch, including a final partial batch.  Using the
    # loader length (not len(train_ds) // batch_size, which floors) keeps the
    # TensorBoard step `epoch_len * epoch + step` unique across epochs.
    epoch_len = len(train_loader)
    writer_train = SummaryWriter(log_dir=os.path.join(out_model_dir, "train"))
    writer_valid = SummaryWriter(log_dir=os.path.join(out_model_dir, "valid"))
    try:
        net.to(device)
        for epoch in range(nr_train_epochs):
            print('-' * 10)
            print('Epoch {}/{}'.format(epoch + 1, nr_train_epochs))
            net.train()
            epoch_loss = 0
            step = 0
            for batch_data in train_loader:
                step += 1
                inputs, labels = batch_data['img'].to(device), batch_data['seg'].to(device)
                opt.zero_grad()
                outputs = net(inputs)
                loss = loss_function(outputs, labels)
                loss.backward()
                opt.step()
                loss_value = loss.item()
                epoch_loss += loss_value
                print("%d/%d, train_loss:%0.4f" % (step, epoch_len, loss_value))
                writer_train.add_scalar('loss', loss_value, epoch_len * epoch + step)
            epoch_loss /= step
            epoch_loss_values.append(epoch_loss)
            print("epoch %d average loss:%0.4f" % (epoch + 1, epoch_loss))

            if (epoch + 1) % validation_every_n_epochs == 0:
                net.eval()
                with torch.no_grad():
                    metric_sum = 0.
                    metric_count = 0
                    val_images = None
                    val_labels = None
                    val_outputs = None
                    check_tot_validation = 0
                    if sliding_window_validation:
                        print('Running sliding window validation')
                    else:
                        print('Running 2D validation')
                    for val_data in val_loader:
                        check_tot_validation += 1
                        val_images, val_labels = val_data['img'].to(device), val_data['seg'].to(device)
                        if sliding_window_validation:
                            roi_size = (96, 96, 1)
                            val_outputs = sliding_window_inference(val_images, roi_size, batch_size_valid,
                                                                   _predict_slices)
                            value = compute_meandice(y_pred=val_outputs, y=val_labels, include_background=True,
                                                     to_onehot_y=False, add_sigmoid=True)
                            metric_count += len(value)
                            metric_sum += value.sum().item()
                        else:
                            # 2D slices: score as 1 - Dice loss, one value per batch
                            val_outputs = net(val_images)
                            value = 1.0 - loss_function(val_outputs, val_labels)
                            metric_count += 1
                            metric_sum += value.item()
                    print("Total number of data in validation: %d" % check_tot_validation)
                    metric = metric_sum / metric_count
                    metric_values.append(metric)
                    if metric > best_metric:
                        best_metric = metric
                        best_metric_epoch = epoch + 1
                        torch.save(net.state_dict(), os.path.join(out_model_dir, 'best_metric_model.pth'))
                        print('saved new best metric model')
                    print("current epoch %d current mean dice: %0.4f best mean dice: %0.4f at epoch %d"
                          % (epoch + 1, metric, best_metric, best_metric_epoch))
                    writer_valid.add_scalar('loss', 1.0 - metric, epoch_len * epoch + step)
                    writer_valid.add_scalar('val_mean_dice', metric, epoch + 1)
                    # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                    plot_2d_or_3d_image(val_images, epoch + 1, writer_valid, index=0, tag='image')
                    plot_2d_or_3d_image(val_labels, epoch + 1, writer_valid, index=0, tag='label')
                    plot_2d_or_3d_image(val_outputs, epoch + 1, writer_valid, index=0, tag='output')

        print('train completed, best_metric: %0.4f  at epoch: %d' % (best_metric, best_metric_epoch))
    finally:
        # Flush and release the event files even if training raises.
        writer_train.close()
        writer_valid.close()
Ejemplo n.º 16
0
def train_process(fast=False):
    """Train a two-class 3D UNet for a fixed number of epochs, tracking validation Dice.

    Uses the module-level ``train_files``, ``val_files`` and
    ``transformations()`` defined elsewhere in this file; batches are read via
    their ``'image'`` and ``'label'`` keys.  Validation runs every
    ``val_interval`` epochs, computing mean Dice with the background channel
    excluded; whenever it improves, the weights are saved to
    ``sLUMRTL644.pth`` in the current working directory.

    Args:
        fast: currently unused.  NOTE(review): presumably meant to toggle a
            faster training configuration — confirm intent.

    Returns:
        ``(epoch_num, epoch_loss_values, metric_values,
        best_metrics_epochs_and_time)`` where the last item holds three lists,
        appended together on each improvement: the best metric, the 1-based
        epoch it was reached at, and the seconds elapsed since training began.
    """
    import time

    epoch_num = 10
    val_interval = 1
    train_trans, val_trans = transformations()
    train_ds = Dataset(data=train_files, transform=train_trans)
    val_ds = Dataset(data=val_files, transform=val_trans)

    train_loader = DataLoader(train_ds, batch_size=2, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=1)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n1 = 16
    model = UNet(dimensions=3,
                 in_channels=1,
                 out_channels=2,
                 channels=(n1 * 1, n1 * 2, n1 * 4, n1 * 8, n1 * 16),
                 strides=(2, 2, 2, 2)).to(device)
    loss_function = DiceLoss(to_onehot_y=True, softmax=True)
    post_pred = AsDiscrete(argmax=True, to_onehot=True, n_classes=2)
    post_label = AsDiscrete(to_onehot=True, n_classes=2)
    optimizer = torch.optim.Adam(model.parameters(), 1e-4, weight_decay=1e-5)

    best_metric = -1
    best_metric_epoch = -1
    best_metrics_epochs_and_time = [[], [], []]
    epoch_loss_values = list()
    metric_values = list()
    # Initialised up front: if the first validation metric is NaN (Dice is
    # undefined for an empty ground truth), `metric > best_metric` is False
    # and the else-branch would otherwise hit an unbound local.
    epochs_no_improve = 0
    # Batches per epoch (ceil of len(train_ds) / batch_size); loop-invariant.
    epoch_len = len(train_loader)
    total_start = time.time()

    for epoch in range(epoch_num):
        print(f"epoch {epoch + 1}/{epoch_num}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data['image'].to(
                device), batch_data['label'].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                metric_sum = 0.
                metric_count = 0
                for val_data in val_loader:
                    val_inputs, val_labels = val_data['image'].to(
                        device), val_data['label'].to(device)
                    # argmax + one-hot both prediction and label so Dice is
                    # computed on discrete two-channel masks
                    val_outputs = post_pred(model(val_inputs))
                    val_labels = post_label(val_labels)
                    value = compute_meandice(y_pred=val_outputs,
                                             y=val_labels,
                                             include_background=False)
                    metric_count += len(value)
                    metric_sum += value.sum().item()
                metric = metric_sum / metric_count
                metric_values.append(metric)
                if metric > best_metric:
                    best_metric = metric
                    epochs_no_improve = 0
                    best_metric_epoch = epoch + 1
                    best_metrics_epochs_and_time[0].append(best_metric)
                    best_metrics_epochs_and_time[1].append(best_metric_epoch)
                    best_metrics_epochs_and_time[2].append(time.time() - total_start)
                    torch.save(model.state_dict(), 'sLUMRTL644.pth')
                else:
                    # NOTE(review): counted but not used for early stopping here.
                    epochs_no_improve += 1

            print(
                f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
                f" best mean dice: {best_metric:.4f} at epoch: {best_metric_epoch}"
            )

    print(
        f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}"
    )
    return epoch_num, epoch_loss_values, metric_values, best_metrics_epochs_and_time
Ejemplo n.º 17
0
def main():
    parser = argparse.ArgumentParser(description="training")
    parser.add_argument(
        "--checkpoint",
        type=str,
        default=None,
        help="checkpoint full path",
    )
    parser.add_argument(
        "--factor_ram_cost",
        default=0.0,
        type=float,
        help="factor to determine RAM cost in the searched architecture",
    )
    parser.add_argument(
        "--fold",
        action="store",
        required=True,
        help="fold index in N-fold cross-validation",
    )
    parser.add_argument(
        "--json",
        action="store",
        required=True,
        help="full path of .json file",
    )
    parser.add_argument(
        "--json_key",
        action="store",
        required=True,
        help="selected key in .json data list",
    )
    parser.add_argument(
        "--local_rank",
        required=int,
        help="local process rank",
    )
    parser.add_argument(
        "--num_folds",
        action="store",
        required=True,
        help="number of folds in cross-validation",
    )
    parser.add_argument(
        "--output_root",
        action="store",
        required=True,
        help="output root",
    )
    parser.add_argument(
        "--root",
        action="store",
        required=True,
        help="data root",
    )
    args = parser.parse_args()

    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    if not os.path.exists(args.output_root):
        os.makedirs(args.output_root, exist_ok=True)

    amp = True
    determ = True
    factor_ram_cost = args.factor_ram_cost
    fold = int(args.fold)
    input_channels = 1
    learning_rate = 0.025
    learning_rate_arch = 0.001
    learning_rate_milestones = np.array([0.4, 0.8])
    num_images_per_batch = 1
    num_epochs = 1430  # around 20k iteration
    num_epochs_per_validation = 100
    num_epochs_warmup = 715
    num_folds = int(args.num_folds)
    num_patches_per_image = 1
    num_sw_batch_size = 6
    output_classes = 3
    overlap_ratio = 0.625
    patch_size = (96, 96, 96)
    patch_size_valid = (96, 96, 96)
    spacing = [1.0, 1.0, 1.0]

    print("factor_ram_cost", factor_ram_cost)

    # deterministic training
    if determ:
        set_determinism(seed=0)

    # initialize the distributed training process, every GPU runs in a process
    dist.init_process_group(backend="nccl", init_method="env://")

    # dist.barrier()
    world_size = dist.get_world_size()

    with open(args.json, "r") as f:
        json_data = json.load(f)

    split = len(json_data[args.json_key]) // num_folds
    list_train = json_data[args.json_key][:(
        split * fold)] + json_data[args.json_key][(split * (fold + 1)):]
    list_valid = json_data[args.json_key][(split * fold):(split * (fold + 1))]

    # training data
    files = []
    for _i in range(len(list_train)):
        str_img = os.path.join(args.root, list_train[_i]["image"])
        str_seg = os.path.join(args.root, list_train[_i]["label"])

        if (not os.path.exists(str_img)) or (not os.path.exists(str_seg)):
            continue

        files.append({"image": str_img, "label": str_seg})
    train_files = files

    random.shuffle(train_files)

    train_files_w = train_files[:len(train_files) // 2]
    train_files_w = partition_dataset(data=train_files_w,
                                      shuffle=True,
                                      num_partitions=world_size,
                                      even_divisible=True)[dist.get_rank()]
    print("train_files_w:", len(train_files_w))

    train_files_a = train_files[len(train_files) // 2:]
    train_files_a = partition_dataset(data=train_files_a,
                                      shuffle=True,
                                      num_partitions=world_size,
                                      even_divisible=True)[dist.get_rank()]
    print("train_files_a:", len(train_files_a))

    # validation data
    files = []
    for _i in range(len(list_valid)):
        str_img = os.path.join(args.root, list_valid[_i]["image"])
        str_seg = os.path.join(args.root, list_valid[_i]["label"])

        if (not os.path.exists(str_img)) or (not os.path.exists(str_seg)):
            continue

        files.append({"image": str_img, "label": str_seg})
    val_files = files
    val_files = partition_dataset(data=val_files,
                                  shuffle=False,
                                  num_partitions=world_size,
                                  even_divisible=False)[dist.get_rank()]
    print("val_files:", len(val_files))

    # network architecture
    device = torch.device(f"cuda:{args.local_rank}")
    torch.cuda.set_device(device)

    train_transforms = Compose([
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(keys=["image", "label"],
                 pixdim=spacing,
                 mode=("bilinear", "nearest"),
                 align_corners=(True, True)),
        CastToTyped(keys=["image"], dtype=(torch.float32)),
        ScaleIntensityRanged(keys=["image"],
                             a_min=-87.0,
                             a_max=199.0,
                             b_min=0.0,
                             b_max=1.0,
                             clip=True),
        CastToTyped(keys=["image", "label"], dtype=(np.float16, np.uint8)),
        CopyItemsd(keys=["label"], times=1, names=["label4crop"]),
        Lambdad(
            keys=["label4crop"],
            func=lambda x: np.concatenate(tuple([
                ndimage.binary_dilation(
                    (x == _k).astype(x.dtype), iterations=48).astype(x.dtype)
                for _k in range(output_classes)
            ]),
                                          axis=0),
            overwrite=True,
        ),
        EnsureTyped(keys=["image", "label"]),
        CastToTyped(keys=["image"], dtype=(torch.float32)),
        SpatialPadd(keys=["image", "label", "label4crop"],
                    spatial_size=patch_size,
                    mode=["reflect", "constant", "constant"]),
        RandCropByLabelClassesd(keys=["image", "label"],
                                label_key="label4crop",
                                num_classes=output_classes,
                                ratios=[
                                    1,
                                ] * output_classes,
                                spatial_size=patch_size,
                                num_samples=num_patches_per_image),
        Lambdad(keys=["label4crop"], func=lambda x: 0),
        RandRotated(keys=["image", "label"],
                    range_x=0.3,
                    range_y=0.3,
                    range_z=0.3,
                    mode=["bilinear", "nearest"],
                    prob=0.2),
        RandZoomd(keys=["image", "label"],
                  min_zoom=0.8,
                  max_zoom=1.2,
                  mode=["trilinear", "nearest"],
                  align_corners=[True, None],
                  prob=0.16),
        RandGaussianSmoothd(keys=["image"],
                            sigma_x=(0.5, 1.15),
                            sigma_y=(0.5, 1.15),
                            sigma_z=(0.5, 1.15),
                            prob=0.15),
        RandScaleIntensityd(keys=["image"], factors=0.3, prob=0.5),
        RandShiftIntensityd(keys=["image"], offsets=0.1, prob=0.5),
        RandGaussianNoised(keys=["image"], std=0.01, prob=0.15),
        RandFlipd(keys=["image", "label"], spatial_axis=0, prob=0.5),
        RandFlipd(keys=["image", "label"], spatial_axis=1, prob=0.5),
        RandFlipd(keys=["image", "label"], spatial_axis=2, prob=0.5),
        CastToTyped(keys=["image", "label"],
                    dtype=(torch.float32, torch.uint8)),
        ToTensord(keys=["image", "label"]),
    ])

    val_transforms = Compose([
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(keys=["image", "label"],
                 pixdim=spacing,
                 mode=("bilinear", "nearest"),
                 align_corners=(True, True)),
        CastToTyped(keys=["image"], dtype=(torch.float32)),
        ScaleIntensityRanged(keys=["image"],
                             a_min=-87.0,
                             a_max=199.0,
                             b_min=0.0,
                             b_max=1.0,
                             clip=True),
        CastToTyped(keys=["image", "label"], dtype=(np.float32, np.uint8)),
        EnsureTyped(keys=["image", "label"]),
        ToTensord(keys=["image", "label"])
    ])

    train_ds_a = monai.data.CacheDataset(data=train_files_a,
                                         transform=train_transforms,
                                         cache_rate=1.0,
                                         num_workers=8)
    train_ds_w = monai.data.CacheDataset(data=train_files_w,
                                         transform=train_transforms,
                                         cache_rate=1.0,
                                         num_workers=8)
    val_ds = monai.data.CacheDataset(data=val_files,
                                     transform=val_transforms,
                                     cache_rate=1.0,
                                     num_workers=2)

    # monai.data.Dataset can be used as alternatives when debugging or RAM space is limited.
    # train_ds_a = monai.data.Dataset(data=train_files_a, transform=train_transforms)
    # train_ds_w = monai.data.Dataset(data=train_files_w, transform=train_transforms)
    # val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)

    train_loader_a = ThreadDataLoader(train_ds_a,
                                      num_workers=0,
                                      batch_size=num_images_per_batch,
                                      shuffle=True)
    train_loader_w = ThreadDataLoader(train_ds_w,
                                      num_workers=0,
                                      batch_size=num_images_per_batch,
                                      shuffle=True)
    val_loader = ThreadDataLoader(val_ds,
                                  num_workers=0,
                                  batch_size=1,
                                  shuffle=False)

    # DataLoader can be used as alternatives when ThreadDataLoader is less efficient.
    # train_loader_a = DataLoader(train_ds_a, batch_size=num_images_per_batch, shuffle=True, num_workers=2, pin_memory=torch.cuda.is_available())
    # train_loader_w = DataLoader(train_ds_w, batch_size=num_images_per_batch, shuffle=True, num_workers=2, pin_memory=torch.cuda.is_available())
    # val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=2, pin_memory=torch.cuda.is_available())

    dints_space = monai.networks.nets.TopologySearch(
        channel_mul=0.5,
        num_blocks=12,
        num_depths=4,
        use_downsample=True,
        device=device,
    )

    model = monai.networks.nets.DiNTS(
        dints_space=dints_space,
        in_channels=input_channels,
        num_classes=output_classes,
        use_downsample=True,
    )

    model = model.to(device)

    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    post_pred = Compose(
        [EnsureType(),
         AsDiscrete(argmax=True, to_onehot=output_classes)])
    post_label = Compose([EnsureType(), AsDiscrete(to_onehot=output_classes)])

    # loss function
    loss_func = monai.losses.DiceCELoss(
        include_background=False,
        to_onehot_y=True,
        softmax=True,
        squared_pred=True,
        batch=True,
        smooth_nr=0.00001,
        smooth_dr=0.00001,
    )

    # optimizer
    optimizer = torch.optim.SGD(model.weight_parameters(),
                                lr=learning_rate * world_size,
                                momentum=0.9,
                                weight_decay=0.00004)
    arch_optimizer_a = torch.optim.Adam([dints_space.log_alpha_a],
                                        lr=learning_rate_arch * world_size,
                                        betas=(0.5, 0.999),
                                        weight_decay=0.0)
    arch_optimizer_c = torch.optim.Adam([dints_space.log_alpha_c],
                                        lr=learning_rate_arch * world_size,
                                        betas=(0.5, 0.999),
                                        weight_decay=0.0)

    print()

    if torch.cuda.device_count() > 1:
        if dist.get_rank() == 0:
            print("Let's use", torch.cuda.device_count(), "GPUs!")

        model = DistributedDataParallel(model,
                                        device_ids=[device],
                                        find_unused_parameters=True)

    if args.checkpoint != None and os.path.isfile(args.checkpoint):
        print("[info] fine-tuning pre-trained checkpoint {0:s}".format(
            args.checkpoint))
        model.load_state_dict(torch.load(args.checkpoint, map_location=device))
        torch.cuda.empty_cache()
    else:
        print("[info] training from scratch")

    # amp
    if amp:
        from torch.cuda.amp import autocast, GradScaler
        scaler = GradScaler()
        if dist.get_rank() == 0:
            print("[info] amp enabled")

    # start a typical PyTorch training
    val_interval = num_epochs_per_validation
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    idx_iter = 0
    metric_values = list()

    if dist.get_rank() == 0:
        writer = SummaryWriter(
            log_dir=os.path.join(args.output_root, "Events"))

        with open(os.path.join(args.output_root, "accuracy_history.csv"),
                  "a") as f:
            f.write("epoch\tmetric\tloss\tlr\ttime\titer\n")

    dataloader_a_iterator = iter(train_loader_a)

    start_time = time.time()
    for epoch in range(num_epochs):
        decay = 0.5**np.sum([
            (epoch - num_epochs_warmup) /
            (num_epochs - num_epochs_warmup) > learning_rate_milestones
        ])
        lr = learning_rate * decay
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr

        if dist.get_rank() == 0:
            print("-" * 10)
            print(f"epoch {epoch + 1}/{num_epochs}")
            print("learning rate is set to {}".format(lr))

        model.train()
        epoch_loss = 0
        loss_torch = torch.zeros(2, dtype=torch.float, device=device)
        epoch_loss_arch = 0
        loss_torch_arch = torch.zeros(2, dtype=torch.float, device=device)
        step = 0

        for batch_data in train_loader_w:
            step += 1
            inputs, labels = batch_data["image"].to(
                device), batch_data["label"].to(device)
            if world_size == 1:
                for _ in model.weight_parameters():
                    _.requires_grad = True
            else:
                for _ in model.module.weight_parameters():
                    _.requires_grad = True
            dints_space.log_alpha_a.requires_grad = False
            dints_space.log_alpha_c.requires_grad = False

            optimizer.zero_grad()

            if amp:
                with autocast():
                    outputs = model(inputs)
                    if output_classes == 2:
                        loss = loss_func(torch.flip(outputs, dims=[1]),
                                         1 - labels)
                    else:
                        loss = loss_func(outputs, labels)

                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                outputs = model(inputs)
                if output_classes == 2:
                    loss = loss_func(torch.flip(outputs, dims=[1]), 1 - labels)
                else:
                    loss = loss_func(outputs, labels)
                loss.backward()
                optimizer.step()

            epoch_loss += loss.item()
            loss_torch[0] += loss.item()
            loss_torch[1] += 1.0
            epoch_len = len(train_loader_w)
            idx_iter += 1

            if dist.get_rank() == 0:
                print("[{0}] ".format(str(datetime.now())[:19]) +
                      f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
                writer.add_scalar("train_loss", loss.item(),
                                  epoch_len * epoch + step)

            if epoch < num_epochs_warmup:
                continue

            try:
                sample_a = next(dataloader_a_iterator)
            except StopIteration:
                dataloader_a_iterator = iter(train_loader_a)
                sample_a = next(dataloader_a_iterator)
            inputs_search, labels_search = sample_a["image"].to(
                device), sample_a["label"].to(device)
            if world_size == 1:
                for _ in model.weight_parameters():
                    _.requires_grad = False
            else:
                for _ in model.module.weight_parameters():
                    _.requires_grad = False
            dints_space.log_alpha_a.requires_grad = True
            dints_space.log_alpha_c.requires_grad = True

            # linear increase topology and RAM loss
            entropy_alpha_c = torch.tensor(0.).to(device)
            entropy_alpha_a = torch.tensor(0.).to(device)
            ram_cost_full = torch.tensor(0.).to(device)
            ram_cost_usage = torch.tensor(0.).to(device)
            ram_cost_loss = torch.tensor(0.).to(device)
            topology_loss = torch.tensor(0.).to(device)

            probs_a, arch_code_prob_a = dints_space.get_prob_a(child=True)
            entropy_alpha_a = -((probs_a) * torch.log(probs_a + 1e-5)).mean()
            entropy_alpha_c = -(F.softmax(dints_space.log_alpha_c, dim=-1) * \
                F.log_softmax(dints_space.log_alpha_c, dim=-1)).mean()
            topology_loss = dints_space.get_topology_entropy(probs_a)

            ram_cost_full = dints_space.get_ram_cost_usage(inputs.shape,
                                                           full=True)
            ram_cost_usage = dints_space.get_ram_cost_usage(inputs.shape)
            ram_cost_loss = torch.abs(factor_ram_cost -
                                      ram_cost_usage / ram_cost_full)

            arch_optimizer_a.zero_grad()
            arch_optimizer_c.zero_grad()

            combination_weights = (epoch - num_epochs_warmup) / (
                num_epochs - num_epochs_warmup)

            if amp:
                with autocast():
                    outputs_search = model(inputs_search)
                    if output_classes == 2:
                        loss = loss_func(torch.flip(outputs_search, dims=[1]),
                                         1 - labels_search)
                    else:
                        loss = loss_func(outputs_search, labels_search)

                    loss += combination_weights * ((entropy_alpha_a + entropy_alpha_c) + ram_cost_loss \
                                                    + 0.001 * topology_loss)

                scaler.scale(loss).backward()
                scaler.step(arch_optimizer_a)
                scaler.step(arch_optimizer_c)
                scaler.update()
            else:
                outputs_search = model(inputs_search)
                if output_classes == 2:
                    loss = loss_func(torch.flip(outputs_search, dims=[1]),
                                     1 - labels_search)
                else:
                    loss = loss_func(outputs_search, labels_search)

                loss += 1.0 * (combination_weights * (entropy_alpha_a + entropy_alpha_c) + ram_cost_loss \
                                + 0.001 * topology_loss)

                loss.backward()
                arch_optimizer_a.step()
                arch_optimizer_c.step()

            epoch_loss_arch += loss.item()
            loss_torch_arch[0] += loss.item()
            loss_torch_arch[1] += 1.0

            if dist.get_rank() == 0:
                print(
                    "[{0}] ".format(str(datetime.now())[:19]) +
                    f"{step}/{epoch_len}, train_loss_arch: {loss.item():.4f}")
                writer.add_scalar("train_loss_arch", loss.item(),
                                  epoch_len * epoch + step)

        # synchronizes all processes and reduce results
        dist.barrier()
        dist.all_reduce(loss_torch, op=torch.distributed.ReduceOp.SUM)
        loss_torch = loss_torch.tolist()
        loss_torch_arch = loss_torch_arch.tolist()
        if dist.get_rank() == 0:
            loss_torch_epoch = loss_torch[0] / loss_torch[1]
            print(
                f"epoch {epoch + 1} average loss: {loss_torch_epoch:.4f}, best mean dice: {best_metric:.4f} at epoch {best_metric_epoch}"
            )

            if epoch >= num_epochs_warmup:
                loss_torch_arch_epoch = loss_torch_arch[0] / loss_torch_arch[1]
                print(
                    f"epoch {epoch + 1} average arch loss: {loss_torch_arch_epoch:.4f}, best mean dice: {best_metric:.4f} at epoch {best_metric_epoch}"
                )

        if (epoch + 1) % val_interval == 0:
            torch.cuda.empty_cache()
            model.eval()
            with torch.no_grad():
                metric = torch.zeros((output_classes - 1) * 2,
                                     dtype=torch.float,
                                     device=device)
                metric_sum = 0.0
                metric_count = 0
                metric_mat = []
                val_images = None
                val_labels = None
                val_outputs = None

                _index = 0
                for val_data in val_loader:
                    val_images = val_data["image"].to(device)
                    val_labels = val_data["label"].to(device)

                    roi_size = patch_size_valid
                    sw_batch_size = num_sw_batch_size

                    if amp:
                        with torch.cuda.amp.autocast():
                            pred = sliding_window_inference(
                                val_images,
                                roi_size,
                                sw_batch_size,
                                lambda x: model(x),
                                mode="gaussian",
                                overlap=overlap_ratio,
                            )
                    else:
                        pred = sliding_window_inference(
                            val_images,
                            roi_size,
                            sw_batch_size,
                            lambda x: model(x),
                            mode="gaussian",
                            overlap=overlap_ratio,
                        )
                    val_outputs = pred

                    val_outputs = post_pred(val_outputs[0, ...])
                    val_outputs = val_outputs[None, ...]
                    val_labels = post_label(val_labels[0, ...])
                    val_labels = val_labels[None, ...]

                    value = compute_meandice(y_pred=val_outputs,
                                             y=val_labels,
                                             include_background=False)

                    print(_index + 1, "/", len(val_loader), value)

                    metric_count += len(value)
                    metric_sum += value.sum().item()
                    metric_vals = value.cpu().numpy()
                    if len(metric_mat) == 0:
                        metric_mat = metric_vals
                    else:
                        metric_mat = np.concatenate((metric_mat, metric_vals),
                                                    axis=0)

                    for _c in range(output_classes - 1):
                        val0 = torch.nan_to_num(value[0, _c], nan=0.0)
                        val1 = 1.0 - torch.isnan(value[0, 0]).float()
                        metric[2 * _c] += val0 * val1
                        metric[2 * _c + 1] += val1

                    _index += 1

                # synchronizes all processes and reduce results
                dist.barrier()
                dist.all_reduce(metric, op=torch.distributed.ReduceOp.SUM)
                metric = metric.tolist()
                if dist.get_rank() == 0:
                    for _c in range(output_classes - 1):
                        print(
                            "evaluation metric - class {0:d}:".format(_c + 1),
                            metric[2 * _c] / metric[2 * _c + 1])
                    avg_metric = 0
                    for _c in range(output_classes - 1):
                        avg_metric += metric[2 * _c] / metric[2 * _c + 1]
                    avg_metric = avg_metric / float(output_classes - 1)
                    print("avg_metric", avg_metric)

                    if avg_metric > best_metric:
                        best_metric = avg_metric
                        best_metric_epoch = epoch + 1
                        best_metric_iterations = idx_iter

                    node_a_d, arch_code_a_d, arch_code_c_d, arch_code_a_max_d = dints_space.decode(
                    )
                    torch.save(
                        {
                            "node_a": node_a_d,
                            "arch_code_a": arch_code_a_d,
                            "arch_code_a_max": arch_code_a_max_d,
                            "arch_code_c": arch_code_c_d,
                            "iter_num": idx_iter,
                            "epochs": epoch + 1,
                            "best_dsc": best_metric,
                            "best_path": best_metric_iterations,
                        },
                        os.path.join(args.output_root,
                                     "search_code_" + str(idx_iter) + ".pth"),
                    )
                    print("saved new best metric model")

                    dict_file = {}
                    dict_file["best_avg_dice_score"] = float(best_metric)
                    dict_file["best_avg_dice_score_epoch"] = int(
                        best_metric_epoch)
                    dict_file["best_avg_dice_score_iteration"] = int(idx_iter)
                    with open(os.path.join(args.output_root, "progress.yaml"),
                              "w") as out_file:
                        documents = yaml.dump(dict_file, stream=out_file)

                    print(
                        "current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}"
                        .format(epoch + 1, avg_metric, best_metric,
                                best_metric_epoch))

                    current_time = time.time()
                    elapsed_time = (current_time - start_time) / 60.0
                    with open(
                            os.path.join(args.output_root,
                                         "accuracy_history.csv"), "a") as f:
                        f.write(
                            "{0:d}\t{1:.5f}\t{2:.5f}\t{3:.5f}\t{4:.1f}\t{5:d}\n"
                            .format(epoch + 1, avg_metric, loss_torch_epoch,
                                    lr, elapsed_time, idx_iter))

                dist.barrier()

            torch.cuda.empty_cache()

    print(
        f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}"
    )

    if dist.get_rank() == 0:
        writer.close()

    dist.destroy_process_group()

    return
Ejemplo n.º 18
0
def main():
    """Evaluate a trained 3D UNet on freshly generated synthetic volumes.

    Writes five synthetic 128x128x128 image/segmentation pairs to a temporary
    directory and restores weights from ``best_metric_model.pth`` (relative to
    the current working directory). It then runs sliding-window inference on
    each volume, prints the mean Dice, and saves the thresholded
    (sigmoid >= 0.5) predictions under ``./output``.

    The temporary directory is always removed, including when model loading
    or inference raises.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    tempdir = tempfile.mkdtemp()
    # Everything that touches tempdir sits inside try/finally so the synthetic
    # data is cleaned up even if model loading or inference fails.
    try:
        print(f"generating synthetic data to {tempdir} (this may take a while)")
        for i in range(5):
            im, seg = create_test_image_3d(128,
                                           128,
                                           128,
                                           num_seg_classes=1,
                                           channel_dim=-1)

            n = nib.Nifti1Image(im, np.eye(4))
            nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))

            n = nib.Nifti1Image(seg, np.eye(4))
            nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

        images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
        segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))
        val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

        # define transforms for image and segmentation
        val_transforms = Compose([
            LoadNiftid(keys=["img", "seg"]),
            AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
            ScaleIntensityd(keys=["img", "seg"]),
            ToTensord(keys=["img", "seg"]),
        ])
        val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
        # sliding window inference needs exactly 1 image per iteration
        val_loader = DataLoader(val_ds,
                                batch_size=1,
                                num_workers=4,
                                collate_fn=list_data_collate,
                                pin_memory=torch.cuda.is_available())

        device = torch.device("cuda:0")
        model = UNet(
            dimensions=3,
            in_channels=1,
            out_channels=1,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
        ).to(device)

        model.load_state_dict(torch.load("best_metric_model.pth"))
        model.eval()
        with torch.no_grad():
            metric_sum = 0.0
            metric_count = 0
            saver = NiftiSaver(output_dir="./output")
            for val_data in val_loader:
                val_images, val_labels = val_data["img"].to(
                    device), val_data["seg"].to(device)
                # define sliding window size and batch size for windows inference
                roi_size = (96, 96, 96)
                sw_batch_size = 4
                val_outputs = sliding_window_inference(val_images, roi_size,
                                                       sw_batch_size, model)
                value = compute_meandice(y_pred=val_outputs,
                                         y=val_labels,
                                         include_background=True,
                                         to_onehot_y=False,
                                         add_sigmoid=True)
                metric_count += len(value)
                metric_sum += value.sum().item()
                # Binarise the logits before writing them to disk.
                val_outputs = (val_outputs.sigmoid() >= 0.5).float()
                saver.save_batch(
                    val_outputs, {
                        "filename_or_obj": val_data["img.filename_or_obj"],
                        "affine": val_data["img.affine"]
                    })
            metric = metric_sum / metric_count
            print("evaluation metric:", metric)
    finally:
        shutil.rmtree(tempdir)
Ejemplo n.º 19
0
def train():
    """Train the module-level ``model``, validating every ``val_interval`` epochs.

    Relies on objects defined elsewhere in the module: ``params`` (keys
    ``'nb_epoch'`` and ``'f_name'``), ``model``, ``device``, ``train_ds``,
    ``train_loader``, ``val_loader``, ``optimizer`` and ``loss_function``.

    Side effects:
        * saves the whole model object to ``saved_models/best_metric_model.pth``
          whenever the validation mean Dice improves;
        * saves a periodic checkpoint ``saved_models/<f_name>_<epoch>.pth``
          every 20 epochs;
        * logs the epoch loss, validation Dice and a snapshot of the last
          validation subject (image | label | argmax prediction) to TensorBoard.

    :return: None
    """
    print('Model training started')
    set_determinism(seed=0)

    # torch.save does not create parent directories, so make sure the
    # checkpoint directory exists before the first save.
    os.makedirs('saved_models', exist_ok=True)

    epoch_num = params['nb_epoch']
    val_interval = 2
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    metric_values = list()
    # Batches per epoch, used only for progress output.  It is loop-invariant.
    # len(train_loader) also counts a trailing partial batch, which
    # floor-dividing len(train_ds) by the batch size did not.
    epoch_len = len(train_loader)
    writer = SummaryWriter()
    try:
        for epoch in range(epoch_num):
            print("-" * 10)
            print(f"epoch {epoch + 1}/{epoch_num}")
            model.train()
            epoch_loss = 0
            step = 0
            for batch_data in train_loader:
                step += 1
                inputs, labels = (
                    batch_data["image"].to(device),
                    batch_data["label"].to(device),
                )
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = loss_function(outputs, labels)
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()
                print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
            epoch_loss /= step
            epoch_loss_values.append(epoch_loss)
            print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
            writer.add_scalar("epoch_loss", epoch_loss, epoch + 1)

            if (epoch + 1) % val_interval == 0:
                model.eval()
                with torch.no_grad():
                    metric_sum = 0.0
                    metric_count = 0
                    for val_data in val_loader:
                        val_inputs, val_labels = (
                            val_data["image"].to(device),
                            val_data["label"].to(device),
                        )
                        val_outputs = model(val_inputs)
                        value = compute_meandice(
                            y_pred=val_outputs,
                            y=val_labels,
                            include_background=False,
                            to_onehot_y=True,
                            mutually_exclusive=True,
                        )
                        metric_count += len(value)
                        metric_sum += value.sum().item()
                    metric = metric_sum / metric_count
                    metric_values.append(metric)
                    if metric > best_metric:
                        best_metric = metric
                        best_metric_epoch = epoch + 1
                        torch.save(
                            model,
                            os.path.join('saved_models', "best_metric_model.pth"))
                        print("saved new best metric model")
                    print(
                        f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
                        f"\nbest mean dice: {best_metric:.4f} at epoch: {best_metric_epoch}"
                    )
                    writer.add_scalar("val_mean_dice", metric, epoch + 1)
                    # plot the last validation subject as GIF in TensorBoard with the corresponding image, label and pred
                    val_pred = torch.argmax(val_outputs, dim=1, keepdim=True)
                    summary_img = torch.cat((val_inputs, val_labels, val_pred),
                                            dim=3)
                    plot_2d_or_3d_image(summary_img,
                                        epoch + 1,
                                        writer,
                                        tag='last_val_subject')

            # Model checkpointing
            if (epoch + 1) % 20 == 0:
                torch.save(
                    model,
                    os.path.join('saved_models',
                                 params['f_name'] + '_' + str(epoch + 1) + '.pth'))
    finally:
        # Flush and release the TensorBoard event file even if training fails.
        writer.close()

    print(
        f"train completed, best_metric: {best_metric:.4f}  at epoch: {best_metric_epoch}"
    )
Ejemplo n.º 20
0
                val_inputs, val_labels = (
                    val_data["image"].to(device),
                    val_data["label"].to(device),
                )
                roi_size = (160, 160, 160)
                sw_batch_size = 4
                val_outputs = sliding_window_inference(val_inputs, roi_size,
                                                       sw_batch_size, model)
                val_outputs = post_pred(val_outputs)

                largest = KeepLargestConnectedComponent(applied_labels=[1])

                val_labels = post_label(val_labels)
                value = compute_meandice(
                    y_pred=val_outputs,
                    y=val_labels,
                    include_background=False,
                )
                metric_count += len(value)
                metric_sum += value.sum().item()
            metric = metric_sum / metric_count
            metric_values.append(metric)
            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(),
                           os.path.join(out_dir, "best_metric_model.pth"))
                print("saved new best metric model")
            print(
                f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
                f"\nbest mean dice: {best_metric:.4f} at epoch: {best_metric_epoch}"
Ejemplo n.º 21
0
def main():
    """Train a 3D UNet on synthetic NIfTI data with sliding-window validation.

    Generates 40 synthetic 128x128x128 image/segmentation pairs in a temporary
    directory. It trains on the first 20 and validates on the last 20 for
    ``max_epochs`` epochs, using Dice loss with sigmoid activation and Adam.
    The best state dict is saved to ``best_metric_model.pth`` in the current
    working directory, and losses, Dice, and image/label/output snapshots are
    logged to TensorBoard.

    The TensorBoard writer and the temporary directory are released even if
    training fails.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # Single source of truth for the epoch count (used by the loop and the log).
    max_epochs = 5

    # create a temporary directory and 40 random image, mask pairs
    tempdir = tempfile.mkdtemp()
    try:
        print('generating synthetic data to {} (this may take a while)'.format(
            tempdir))
        for i in range(40):
            im, seg = create_test_image_3d(128,
                                           128,
                                           128,
                                           num_seg_classes=1,
                                           channel_dim=-1)

            n = nib.Nifti1Image(im, np.eye(4))
            nib.save(n, os.path.join(tempdir, 'img%i.nii.gz' % i))

            n = nib.Nifti1Image(seg, np.eye(4))
            nib.save(n, os.path.join(tempdir, 'seg%i.nii.gz' % i))

        images = sorted(glob(os.path.join(tempdir, 'img*.nii.gz')))
        segs = sorted(glob(os.path.join(tempdir, 'seg*.nii.gz')))
        train_files = [{
            'img': img,
            'seg': seg
        } for img, seg in zip(images[:20], segs[:20])]
        val_files = [{
            'img': img,
            'seg': seg
        } for img, seg in zip(images[-20:], segs[-20:])]

        # define transforms for image and segmentation
        train_transforms = Compose([
            LoadNiftid(keys=['img', 'seg']),
            AsChannelFirstd(keys=['img', 'seg'], channel_dim=-1),
            ScaleIntensityd(keys=['img', 'seg']),
            RandCropByPosNegLabeld(keys=['img', 'seg'],
                                   label_key='seg',
                                   size=[96, 96, 96],
                                   pos=1,
                                   neg=1,
                                   num_samples=4),
            RandRotate90d(keys=['img', 'seg'], prob=0.5, spatial_axes=[0, 2]),
            ToTensord(keys=['img', 'seg'])
        ])
        val_transforms = Compose([
            LoadNiftid(keys=['img', 'seg']),
            AsChannelFirstd(keys=['img', 'seg'], channel_dim=-1),
            ScaleIntensityd(keys=['img', 'seg']),
            ToTensord(keys=['img', 'seg'])
        ])

        # define dataset, data loader
        check_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
        # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
        check_loader = DataLoader(check_ds,
                                  batch_size=2,
                                  num_workers=4,
                                  collate_fn=list_data_collate,
                                  pin_memory=torch.cuda.is_available())
        check_data = monai.utils.misc.first(check_loader)
        print(check_data['img'].shape, check_data['seg'].shape)

        # create a training data loader
        train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
        # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
        train_loader = DataLoader(train_ds,
                                  batch_size=2,
                                  shuffle=True,
                                  num_workers=4,
                                  collate_fn=list_data_collate,
                                  pin_memory=torch.cuda.is_available())
        # create a validation data loader
        val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
        val_loader = DataLoader(val_ds,
                                batch_size=1,
                                num_workers=4,
                                collate_fn=list_data_collate,
                                pin_memory=torch.cuda.is_available())

        # create UNet, DiceLoss and Adam optimizer
        device = torch.device('cuda:0')
        model = monai.networks.nets.UNet(
            dimensions=3,
            in_channels=1,
            out_channels=1,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
        ).to(device)
        loss_function = monai.losses.DiceLoss(do_sigmoid=True)
        optimizer = torch.optim.Adam(model.parameters(), 1e-3)

        # start a typical PyTorch training
        val_interval = 2
        best_metric = -1
        best_metric_epoch = -1
        epoch_loss_values = list()
        metric_values = list()
        # Batches per epoch (loop-invariant); used for logging and global steps.
        epoch_len = len(train_loader)
        writer = SummaryWriter()
        try:
            for epoch in range(max_epochs):
                print('-' * 10)
                print('epoch {}/{}'.format(epoch + 1, max_epochs))
                model.train()
                epoch_loss = 0
                step = 0
                for batch_data in train_loader:
                    step += 1
                    inputs, labels = batch_data['img'].to(
                        device), batch_data['seg'].to(device)
                    optimizer.zero_grad()
                    outputs = model(inputs)
                    loss = loss_function(outputs, labels)
                    loss.backward()
                    optimizer.step()
                    epoch_loss += loss.item()
                    print('{}/{}, train_loss: {:.4f}'.format(step, epoch_len,
                                                             loss.item()))
                    writer.add_scalar('train_loss', loss.item(),
                                      epoch_len * epoch + step)
                epoch_loss /= step
                epoch_loss_values.append(epoch_loss)
                print('epoch {} average loss: {:.4f}'.format(epoch + 1, epoch_loss))

                if (epoch + 1) % val_interval == 0:
                    model.eval()
                    with torch.no_grad():
                        metric_sum = 0.
                        metric_count = 0
                        val_images = None
                        val_labels = None
                        val_outputs = None
                        for val_data in val_loader:
                            val_images, val_labels = val_data['img'].to(
                                device), val_data['seg'].to(device)
                            roi_size = (96, 96, 96)
                            sw_batch_size = 4
                            val_outputs = sliding_window_inference(
                                val_images, roi_size, sw_batch_size, model)
                            value = compute_meandice(y_pred=val_outputs,
                                                     y=val_labels,
                                                     include_background=True,
                                                     to_onehot_y=False,
                                                     add_sigmoid=True)
                            metric_count += len(value)
                            metric_sum += value.sum().item()
                        metric = metric_sum / metric_count
                        metric_values.append(metric)
                        if metric > best_metric:
                            best_metric = metric
                            best_metric_epoch = epoch + 1
                            torch.save(model.state_dict(), 'best_metric_model.pth')
                            print('saved new best metric model')
                        print(
                            'current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}'
                            .format(epoch + 1, metric, best_metric, best_metric_epoch))
                        writer.add_scalar('val_mean_dice', metric, epoch + 1)
                        # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                        plot_2d_or_3d_image(val_images,
                                            epoch + 1,
                                            writer,
                                            index=0,
                                            tag='image')
                        plot_2d_or_3d_image(val_labels,
                                            epoch + 1,
                                            writer,
                                            index=0,
                                            tag='label')
                        plot_2d_or_3d_image(val_outputs,
                                            epoch + 1,
                                            writer,
                                            index=0,
                                            tag='output')
        finally:
            # Flush and release the TensorBoard event file even on failure.
            writer.close()
    finally:
        shutil.rmtree(tempdir)
    print('train completed, best_metric: {:.4f} at epoch: {}'.format(
        best_metric, best_metric_epoch))
Ejemplo n.º 22
0
def fit(model, train_ds, val_ds, batch_size, epoch_num,
        loss_function, optimizer, device, root_dir,
        callbacks=None, verbose=1):
    """Train ``model`` on ``train_ds`` and validate on ``val_ds`` after every epoch.

    Batches are read from dict samples with ``"image"`` and ``"label"`` keys.
    Validation computes the mean Dice of the sigmoid-thresholded predictions
    against the labels. The best state dict is saved to
    ``<root_dir>/best_metric_model.pth``. Per-step training loss and
    per-epoch validation Dice are logged to TensorBoard, with the global step
    measured in training batches.

    Args:
        model: network to optimize (presumably already on ``device`` — TODO
            confirm with callers; this function does not move it).
        train_ds: training dataset.
        val_ds: validation dataset.
        batch_size: batch size for both data loaders.
        epoch_num: number of epochs to train.
        loss_function: callable ``(outputs, labels) -> loss tensor``.
        optimizer: optimizer over ``model``'s parameters.
        device: device the batches are moved to.
        root_dir: directory where the best checkpoint is written.
        callbacks: currently unused; kept for interface compatibility.
        verbose: currently unused; kept for interface compatibility.

    Returns:
        Tuple ``(epoch_num, total_seconds, epoch_loss_values, metric_values,
        epoch_times)``.
    """
    train_loader = monai.data.DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=2)
    val_loader = monai.data.DataLoader(val_ds, batch_size=batch_size, num_workers=2)
    # Number of training batches per epoch.  It is used for progress output
    # and as the unit of the TensorBoard global step, so train and validation
    # curves share the same x-axis.
    steps_per_epoch = len(train_loader)

    # tensorboard
    writer = SummaryWriter()

    val_interval = 1  # do validation for every epoch
    best_metric = float("-inf")
    best_metric_epoch = -1
    epoch_loss_values = list()
    metric_values = list()
    epoch_times = list()
    total_start = time.time()
    try:
        for epoch in range(epoch_num):
            epoch_start = time.time()
            print("-" * 10)
            print(f"epoch {epoch + 1}/{epoch_num}")
            model.train()
            epoch_loss = 0
            step = 0
            for batch_data in train_loader:
                step_start = time.time()
                step += 1
                inputs, labels = (
                    batch_data["image"].to(device),
                    batch_data["label"].to(device),
                )
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = loss_function(outputs, labels)
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()
                writer.add_scalar('Loss/train', loss.item(), epoch * steps_per_epoch + step - 1)
                print(
                    f"Epoch: [{epoch + 1}], [{step}/{steps_per_epoch}], train_loss: {loss.item():.4f} step time: {(time.time() - step_start):.4f} "
                )
            epoch_loss /= step
            epoch_loss_values.append(epoch_loss)
            print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

            if (epoch + 1) % val_interval == 0:
                model.eval()
                with torch.no_grad():
                    metric_sum = 0.0
                    metric_count = 0
                    for val_data in val_loader:
                        val_inputs, val_labels = (
                            val_data["image"].to(device),
                            val_data["label"].to(device),
                        )
                        val_outputs = model(val_inputs)
                        # Dice must compare predictions with the ground-truth
                        # labels, not with the input images.
                        value = compute_meandice(val_outputs, val_labels, sigmoid=True, logit_thresh=0.5)
                        metric_count += len(value)
                        metric_sum += value.sum().item()
                    metric = metric_sum / metric_count
                    metric_values.append(metric)
                    writer.add_scalar('DiceMetric/val', metric, (epoch + 1) * steps_per_epoch)
                    if metric > best_metric:
                        best_metric = metric
                        best_metric_epoch = epoch + 1
                        torch.save(
                            model.state_dict(), os.path.join(root_dir, "best_metric_model.pth"),
                        )
                        print("saved new best metric model")

                    print(
                        f"current epoch: {epoch + 1} current mean dice: {metric:.4f}",
                        f" best mean dice: {best_metric:.4f} at epoch: {best_metric_epoch}",
                    )
            print(f"time consuming of epoch {epoch + 1} is: {(time.time() - epoch_start):.4f}")
            epoch_times.append(time.time() - epoch_start)
    finally:
        # Flush and release the TensorBoard event file even if training fails.
        writer.close()

    print(
        f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}",
        f" total time: {(time.time() - total_start):.4f}"
    )
    return (
        epoch_num,
        time.time() - total_start,
        epoch_loss_values,
        metric_values,
        epoch_times,
    )
Ejemplo n.º 23
0
def main():
    """Train a 3D UNet on synthetic NIfTI volumes, validating with sliding-window inference.

    Writes 40 random image/segmentation pairs into a fresh temporary directory,
    trains for 5 epochs on the first 20 pairs and, every ``val_interval``
    epochs, computes mean Dice on the last 20 pairs. The best model is saved to
    ``best_metric_model.pth`` in the current working directory; losses, metrics
    and sample volumes are logged to TensorBoard. The temporary directory is
    removed and the TensorBoard writer closed even if training raises.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # create a temporary directory and 40 random image, mask pairs
    tempdir = tempfile.mkdtemp()
    writer = None
    # Everything that uses ``tempdir`` or ``writer`` runs inside try/finally so
    # both are released even when data generation or training fails.
    try:
        print(f"generating synthetic data to {tempdir} (this may take a while)")
        for i in range(40):
            im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1)

            n = nib.Nifti1Image(im, np.eye(4))
            nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))

            n = nib.Nifti1Image(seg, np.eye(4))
            nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

        # Both lists are sorted the same way, so images[k] pairs with segs[k].
        images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
        segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))

        # define transforms for image and segmentation
        train_imtrans = Compose(
            [
                ScaleIntensity(),
                AddChannel(),
                RandSpatialCrop((96, 96, 96), random_size=False),
                RandRotate90(prob=0.5, spatial_axes=(0, 2)),
                ToTensor(),
            ]
        )
        train_segtrans = Compose(
            [
                AddChannel(),
                RandSpatialCrop((96, 96, 96), random_size=False),
                RandRotate90(prob=0.5, spatial_axes=(0, 2)),
                ToTensor(),
            ]
        )
        val_imtrans = Compose([ScaleIntensity(), AddChannel(), ToTensor()])
        val_segtrans = Compose([AddChannel(), ToTensor()])

        pin_memory = torch.cuda.is_available()

        # define nifti dataset, data loader; print one batch's shapes as a sanity check
        check_ds = NiftiDataset(images, segs, transform=train_imtrans, seg_transform=train_segtrans)
        check_loader = DataLoader(check_ds, batch_size=10, num_workers=2, pin_memory=pin_memory)
        im, seg = monai.utils.misc.first(check_loader)
        print(im.shape, seg.shape)

        # create a training data loader (first 20 pairs)
        train_ds = NiftiDataset(images[:20], segs[:20], transform=train_imtrans, seg_transform=train_segtrans)
        train_loader = DataLoader(train_ds, batch_size=4, shuffle=True, num_workers=8, pin_memory=pin_memory)
        # create a validation data loader (last 20 pairs)
        val_ds = NiftiDataset(images[-20:], segs[-20:], transform=val_imtrans, seg_transform=val_segtrans)
        val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, pin_memory=pin_memory)

        # create UNet, DiceLoss and Adam optimizer; use the CPU when no GPU is available
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        model = monai.networks.nets.UNet(
            dimensions=3,
            in_channels=1,
            out_channels=1,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
        ).to(device)
        loss_function = monai.losses.DiceLoss(do_sigmoid=True)
        optimizer = torch.optim.Adam(model.parameters(), 1e-3)

        # start a typical PyTorch training
        max_epochs = 5
        val_interval = 2
        # sliding-window patch size (same as the training crop) and windows per forward pass
        roi_size = (96, 96, 96)
        sw_batch_size = 4
        best_metric = -1
        best_metric_epoch = -1
        epoch_loss_values = list()
        metric_values = list()
        # steps per epoch; used for progress output and the TensorBoard x-axis
        epoch_len = len(train_ds) // train_loader.batch_size
        writer = SummaryWriter()
        for epoch in range(max_epochs):
            print("-" * 10)
            print(f"epoch {epoch + 1}/{max_epochs}")
            model.train()
            epoch_loss = 0
            step = 0
            for batch_data in train_loader:
                step += 1
                inputs, labels = batch_data[0].to(device), batch_data[1].to(device)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = loss_function(outputs, labels)
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()
                print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
                writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
            epoch_loss /= step
            epoch_loss_values.append(epoch_loss)
            print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

            if (epoch + 1) % val_interval == 0:
                model.eval()
                with torch.no_grad():
                    metric_sum = 0.0
                    metric_count = 0
                    val_images = None
                    val_labels = None
                    val_outputs = None
                    for val_data in val_loader:
                        val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
                        val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
                        value = compute_meandice(
                            y_pred=val_outputs, y=val_labels, include_background=True, to_onehot_y=False, add_sigmoid=True
                        )
                        metric_count += len(value)
                        metric_sum += value.sum().item()
                    metric = metric_sum / metric_count
                    metric_values.append(metric)
                    if metric > best_metric:
                        best_metric = metric
                        best_metric_epoch = epoch + 1
                        torch.save(model.state_dict(), "best_metric_model.pth")
                        print("saved new best metric model")
                    print(
                        "current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}".format(
                            epoch + 1, metric, best_metric, best_metric_epoch
                        )
                    )
                    writer.add_scalar("val_mean_dice", metric, epoch + 1)
                    # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                    plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag="image")
                    plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="label")
                    plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="output")
        print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
    finally:
        shutil.rmtree(tempdir)
        if writer is not None:
            writer.close()
Ejemplo n.º 24
0
def train(model,
          train_loader,
          val_loader,
          loss_function,
          optimizer,
          output_dir,
          device,
          epoch_num=600,
          val_interval=2):
    """Train a segmentation model, validating periodically with sliding-window inference.

    Args:
        model: network to optimise (its parameters are what ``optimizer`` updates).
        train_loader: DataLoader yielding dicts with ``'image'`` and ``'label'`` tensors.
        val_loader: DataLoader yielding dicts with ``'image'`` and ``'label'`` tensors.
        loss_function: callable ``loss_function(outputs, labels)`` returning a scalar tensor.
        optimizer: torch optimizer stepping ``model``'s parameters.
        output_dir: directory that receives ``best_metric_model.pth`` (best
            validation mean Dice so far) and ``training_metrics.png``.
        device: device every batch is moved to before the forward pass.
        epoch_num: number of training epochs.
        val_interval: run validation every ``val_interval`` epochs.
    """
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    metric_values = list()
    # Derived from the loader itself (not a module-level dataset) so the
    # function works for whatever loader it is given.
    steps_per_epoch = len(train_loader.dataset) // train_loader.batch_size
    # Sliding-window patch size and number of windows per forward pass.
    roi_size = (160, 160, 160)
    sw_batch_size = 4
    for epoch in range(epoch_num):
        print('-' * 10)
        print(f"epoch {epoch + 1}/{epoch_num}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data['image'].to(
                device), batch_data['label'].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            print(
                f"{step}/{steps_per_epoch}, train_loss: {loss.item():.4f}"
            )
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                metric_sum = 0.
                metric_count = 0
                for val_data in val_loader:
                    val_inputs, val_labels = val_data['image'].to(
                        device), val_data['label'].to(device)
                    val_outputs = sliding_window_inference(
                        val_inputs, roi_size, sw_batch_size, model)
                    # Labels are one-hot encoded here; background channel excluded.
                    value = compute_meandice(y_pred=val_outputs,
                                             y=val_labels,
                                             include_background=False,
                                             to_onehot_y=True,
                                             mutually_exclusive=True)
                    metric_count += len(value)
                    metric_sum += value.sum().item()
                metric = metric_sum / metric_count
                metric_values.append(metric)
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(
                        model.state_dict(),
                        os.path.join(output_dir, 'best_metric_model.pth'))
                    print('saved new best metric model')
                print(
                    f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
                    f"\nbest mean dice: {best_metric:.4f} at epoch: {best_metric_epoch}"
                )
    metric_output = os.path.join(output_dir, "training_metrics.png")
    plot_metrics(epoch_loss_values,
                 metric_values,
                 val_interval,
                 output_path=metric_output)
Ejemplo n.º 25
0
def run_training_test(root_dir,
                      device=torch.device("cuda:0"),
                      cachedataset=False):
    """Run a short, seeded 3D UNet training on the NIfTI pairs in ``root_dir``.

    ``img*.nii.gz`` / ``seg*.nii.gz`` files are paired in sorted order. The first
    20 pairs are used for training and the last 20 for validation. The model trains
    for 6 epochs with validation every 2 epochs. The best checkpoint is written to
    ``root_dir/best_metric_model.pth`` and TensorBoard logs go to ``root_dir/runs``.

    Args:
        root_dir: directory holding the input volumes; also receives outputs.
        device: device for the model and every batch.
        cachedataset: wrap the training data in ``monai.data.CacheDataset``
            (``cache_rate=0.8``) instead of a plain ``monai.data.Dataset``.

    Returns:
        ``(epoch_loss_values, best_metric, best_metric_epoch)``.
    """
    monai.config.print_config()
    image_paths = sorted(glob(os.path.join(root_dir, "img*.nii.gz")))
    label_paths = sorted(glob(os.path.join(root_dir, "seg*.nii.gz")))

    def _pairs(img_list, seg_list):
        # One {"img", "seg"} record per matched image/label path.
        return [dict(img=img_path, seg=seg_path) for img_path, seg_path in zip(img_list, seg_list)]

    train_files = _pairs(image_paths[:20], label_paths[:20])
    val_files = _pairs(image_paths[-20:], label_paths[-20:])

    def _preprocessing():
        # Load -> channel-first -> resample -> intensity scale; shared by both pipelines.
        return [
            LoadNiftid(keys=["img", "seg"]),
            AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
            Spacingd(keys=["img", "seg"],
                     pixdim=[1.2, 0.8, 0.7],
                     interp_order=["bilinear", "nearest"]),
            ScaleIntensityd(keys=["img", "seg"]),
        ]

    train_transforms = Compose(_preprocessing() + [
        RandCropByPosNegLabeld(keys=["img", "seg"],
                               label_key="seg",
                               size=[96, 96, 96],
                               pos=1,
                               neg=1,
                               num_samples=4),
        RandRotate90d(keys=["img", "seg"], prob=0.8, spatial_axes=[0, 2]),
        ToTensord(keys=["img", "seg"]),
    ])
    # Fixed seed so repeated runs produce the same random crops/rotations.
    train_transforms.set_random_state(1234)
    val_transforms = Compose(_preprocessing() + [ToTensord(keys=["img", "seg"])])

    if cachedataset:
        dataset_cls, dataset_extra = monai.data.CacheDataset, {"cache_rate": 0.8}
    else:
        dataset_cls, dataset_extra = monai.data.Dataset, {}
    train_ds = dataset_cls(data=train_files, transform=train_transforms, **dataset_extra)

    shared_loader_args = dict(
        num_workers=4,
        collate_fn=list_data_collate,
        pin_memory=torch.cuda.is_available(),
    )
    # Each loaded volume yields 4 random crops, so a batch of 2 volumes gives
    # 2 x 4 training patches after list_data_collate.
    train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, **shared_loader_args)
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(val_ds, batch_size=1, **shared_loader_args)

    # UNet with sigmoid DiceLoss and Adam.
    model = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss_function = monai.losses.DiceLoss(sigmoid=True)
    optimizer = torch.optim.Adam(model.parameters(), 5e-4)

    val_interval = 2
    max_epochs = 6
    best_metric, best_metric_epoch = -1, -1
    epoch_loss_values = []
    metric_values = []
    writer = SummaryWriter(log_dir=os.path.join(root_dir, "runs"))
    model_filename = os.path.join(root_dir, "best_metric_model.pth")
    steps_per_epoch = len(train_ds) // train_loader.batch_size
    # Sliding-window inference settings for validation.
    sw_batch_size, roi_size = 4, (96, 96, 96)
    for epoch in range(max_epochs):
        print("-" * 10)
        print(f"Epoch {epoch + 1}/{max_epochs}")
        model.train()
        epoch_loss = 0
        step = 0
        for step, batch_data in enumerate(train_loader, start=1):
            inputs = batch_data["img"].to(device)
            labels = batch_data["seg"].to(device)
            optimizer.zero_grad()
            loss = loss_function(model(inputs), labels)
            loss.backward()
            optimizer.step()
            step_loss = loss.item()
            epoch_loss += step_loss
            print(f"{step}/{steps_per_epoch}, train_loss:{step_loss:0.4f}")
            writer.add_scalar("train_loss", step_loss, steps_per_epoch * epoch + step)
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss:{epoch_loss:0.4f}")

        if (epoch + 1) % val_interval != 0:
            continue

        model.eval()
        with torch.no_grad():
            dice_total, dice_count = 0.0, 0
            val_images = val_labels = val_outputs = None
            for val_data in val_loader:
                val_images = val_data["img"].to(device)
                val_labels = val_data["seg"].to(device)
                val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
                dice = compute_meandice(y_pred=val_outputs,
                                        y=val_labels,
                                        include_background=True,
                                        to_onehot_y=False,
                                        sigmoid=True)
                dice_count += len(dice)
                dice_total += dice.sum().item()
            metric = dice_total / dice_count
            metric_values.append(metric)
            if metric > best_metric:
                best_metric, best_metric_epoch = metric, epoch + 1
                torch.save(model.state_dict(), model_filename)
                print("saved new best metric model")
            print(
                f"current epoch {epoch + 1} current mean dice: {metric:0.4f} "
                f"best mean dice: {best_metric:0.4f} at epoch {best_metric_epoch}"
            )
            writer.add_scalar("val_mean_dice", metric, epoch + 1)
            # Log the last validation image, label and prediction to TensorBoard.
            for volume, tag in ((val_images, "image"),
                                (val_labels, "label"),
                                (val_outputs, "output")):
                plot_2d_or_3d_image(volume, epoch + 1, writer, index=0, tag=tag)
    print(
        f"train completed, best_metric: {best_metric:0.4f}  at epoch: {best_metric_epoch}"
    )
    writer.close()
    return epoch_loss_values, best_metric, best_metric_epoch
Ejemplo n.º 26
0
def train_epoch(data_loader,
                model,
                alpha,
                criterion1,
                criterion2,
                optimizer=None,
                mode_train=True):
    """Run one training or validation pass over ``data_loader``.

    Inputs:
    : data_loader: training or validation sets in dataset pytorch format,
      yielding ``(data, y)`` pairs
    : model (class): unet architecture
    : alpha (int): weight of the boundary loss when both losses are used
    : criterion1: loss function 1 (in our case Generalized Dice Loss); pass None to disable
    : criterion2: loss function 2 (in our case Boundary loss); pass None to disable
    : optimizer (class): define optimizer (ie. adam); required when mode_train is True
    : mode_train (bool): True is train, False is validation. In validation no
      autograd graph is built and no backward/optimizer step is performed.

    Returns:
    : the ``.avg`` of the loss, Dice and Hausdorff ``average_metrics`` trackers.
      Batches whose Dice or Hausdorff value is NaN are not added to the metric trackers.

    NOTE: relies on module-level ``device`` and ``opt`` (``opt.verbose``).
    """
    total_batch = len(data_loader)
    batch_loss = average_metrics()
    batch_dice = average_metrics()
    batch_hausdorff = average_metrics()
    post_pred = AsDiscrete(argmax=True, to_onehot=True, n_classes=2)
    post_label = AsDiscrete(to_onehot=True, n_classes=2)

    region_only = criterion1 is not None and criterion2 is None
    contour_only = criterion1 is None and criterion2 is not None

    if mode_train:
        model.train()
    else:
        model.eval()

    for batch_idx, (data, y) in enumerate(data_loader):
        # .float() keeps the tensor on ``device`` (forcing torch.cuda.FloatTensor
        # would move it to the GPU regardless of the chosen device).
        data = data.to(device).float()
        y = y.to(device)

        # Reset gradients
        if mode_train:
            optimizer.zero_grad()

        # Gradients are tracked only while training.
        with torch.set_grad_enabled(mode_train):
            # Get prediction
            out = model(data)

            # Losses
            if region_only:
                # REGION loss
                loss = criterion1(out, y)
            elif contour_only:
                # CONTOUR loss on class probabilities
                loss = criterion2(softmax(out, dim=1), y)
            else:
                # Weighted combination of REGION and CONTOUR losses
                region_loss = criterion1(out, y)
                contour_loss = criterion2(softmax(out, dim=1), y)
                loss = region_loss + alpha * contour_loss

        # Evaluation on discretised one-hot predictions and labels
        with torch.no_grad():
            outputs = post_pred(out)
            labels = post_label(y)
            dice_val = compute_meandice(y_pred=outputs,
                                        y=labels,
                                        include_background=False)
            hausdorff_val = compute_hausdorff_distance(y_pred=outputs,
                                                       y=labels,
                                                       distance_metric='euclidean')

        dice_item = dice_val.item()
        hausdorff_item = hausdorff_val.item()
        # Skip NaN metric values so they do not contaminate the trackers.
        if not (math.isnan(dice_item) or math.isnan(hausdorff_item)):
            batch_dice.update(dice_item)
            batch_hausdorff.update(hausdorff_item)

        # Update loss
        batch_loss.update(loss.item())

        # Backpropagate error and optimize (training only)
        if mode_train:
            loss.backward()
            optimizer.step()

        # Log
        if mode_train and (batch_idx + 1) % opt.verbose == 0:
            if region_only:
                print(
                    f'Iteration {(batch_idx + 1)}/{total_batch} - GD Loss: {batch_loss.val} - Dice: {batch_dice.val} - Hausdorff:{batch_hausdorff.val}'
                )
            elif contour_only:
                print(
                    f'Iteration {(batch_idx + 1)}/{total_batch} - B Loss: {batch_loss.val} - Dice: {batch_dice.val} -  Hausdorff:{batch_hausdorff.val}'
                )
            else:
                print(
                    f'Iteration {(batch_idx + 1)}/{total_batch} - GD & B Loss: {batch_loss.val} - Dice: {batch_dice.val} - Hausdorff:{batch_hausdorff.val}'
                )

    return batch_loss.avg, batch_dice.avg, batch_hausdorff.avg