예제 #1
0
def _plan_kernels_and_strides(sizes, spacings):
    """Plan the per-level convolution kernels and strides for a DynUNet.

    At every level an axis is downsampled (stride 2) when its spacing is at
    most twice the finest spacing and its current size is at least 8;
    otherwise it keeps stride 1. An axis whose spacing is more than twice the
    finest one gets kernel size 1 for that level instead of 3. After each
    level the sizes are divided and the spacings multiplied by the chosen
    strides. Planning stops once no axis would be downsampled.

    Args:
        sizes: patch size, one entry per spatial axis.
        spacings: voxel spacing, one entry per spatial axis.

    Returns:
        ``(kernels, strides)``, two lists holding one per-axis list for each
        network level. ``strides`` begins with an all-ones entry and
        ``kernels`` ends with an all-threes entry, so both lists have the
        same length.
    """
    strides, kernels = [], []

    while True:
        spacing_ratio = [sp / min(spacings) for sp in spacings]
        stride = [
            2 if ratio <= 2 and size >= 8 else 1
            for (ratio, size) in zip(spacing_ratio, sizes)
        ]
        kernel = [3 if ratio <= 2 else 1 for ratio in spacing_ratio]
        if all(s == 1 for s in stride):
            # Nothing left to downsample: the plan is complete.
            break
        sizes = [i / j for i, j in zip(sizes, stride)]
        spacings = [i * j for i, j in zip(spacings, stride)]
        kernels.append(kernel)
        strides.append(stride)

    # The first level keeps full resolution; the last level uses 3-wide kernels.
    strides.insert(0, len(spacings) * [1])
    kernels.append(len(spacings) * [3])
    return kernels, strides


def build_net():
    """Build a weight-initialised MONAI ``DynUNet`` from the parsed options.

    The options are obtained from ``Options().parse()``. The architecture is
    planned from ``opt.patch_size`` together with ``opt.resolution`` when it
    is set, or ``opt.spacing`` otherwise.

    Returns:
        The ``DynUNet`` instance after ``init_weights(..., init_type='normal')``
        has been applied to it.
    """
    from init import Options

    opt = Options().parse()

    # Plan on the resampling resolution when one is requested, otherwise on
    # the spacing given in the options.
    spacings = opt.spacing if opt.resolution is None else opt.resolution
    kernels, strides = _plan_kernels_and_strides(opt.patch_size, spacings)

    nn_Unet = monai.networks.nets.DynUNet(
        spatial_dims=3,
        in_channels=opt.in_channels,
        out_channels=opt.out_channels,
        kernel_size=kernels,
        strides=strides,
        # One upsampling step per downsampling level (the leading all-ones
        # stride entry has no matching upsample).
        upsample_kernel_size=strides[1:],
        res_block=True,
    )

    init_weights(nn_Unet, init_type='normal')

    return nn_Unet
def build_net():
    """Create the nnU-Net style ``DynUNet`` described by the parsed options.

    Kernel sizes and strides are derived level by level from the patch size
    and the voxel spacing (``opt.resolution`` when set, else ``opt.spacing``).
    The finished network is passed through ``init_weights`` with
    ``init_type='normal'`` and returned.
    """
    from init import Options
    opt = Options().parse()
    from monai.networks.layers import Norm  # not referenced below; import kept unchanged

    # Choose the spacing that drives the architecture plan.
    if opt.resolution is None:
        patch_dims, voxel_spacing = opt.patch_size, opt.spacing
    else:
        patch_dims, voxel_spacing = opt.patch_size, opt.resolution

    level_strides = []
    level_kernels = []

    planning = True
    while planning:
        ratios = [value / min(voxel_spacing) for value in voxel_spacing]

        # Downsample an axis only while it is not too anisotropic and still
        # at least 8 voxels wide.
        step = []
        for ratio, extent in zip(ratios, patch_dims):
            if ratio <= 2 and extent >= 8:
                step.append(2)
            else:
                step.append(1)

        width = [3 if ratio <= 2 else 1 for ratio in ratios]

        if all(s == 1 for s in step):
            planning = False
        else:
            patch_dims = [extent / s for extent, s in zip(patch_dims, step)]
            voxel_spacing = [value * s for value, s in zip(voxel_spacing, step)]
            level_kernels.append(width)
            level_strides.append(step)

    n_axes = len(voxel_spacing)
    level_strides.insert(0, [1] * n_axes)
    level_kernels.append([3] * n_axes)

    nn_Unet = monai.networks.nets.DynUNet(
        spatial_dims=3,
        in_channels=opt.in_channels,
        out_channels=opt.out_channels,
        kernel_size=level_kernels,
        strides=level_strides,
        upsample_kernel_size=level_strides[1:],
        res_block=True,
    )

    init_weights(nn_Unet, init_type='normal')

    return nn_Unet
예제 #3
0
def _paired_dicts(images, labels):
    """Zip image and label paths into the ``{'image', 'label'}`` dicts fed to the datasets."""
    return [{'image': image_name, 'label': label_name}
            for image_name, label_name in zip(images, labels)]


def _preprocessing_transforms(opt):
    """Return a fresh list of the deterministic transforms used by every split.

    ``Spacingd`` resampling is appended only when ``opt.resolution`` is set.
    """
    transforms = [
        LoadImaged(keys=['image', 'label']),
        AddChanneld(keys=['image', 'label']),
        NormalizeIntensityd(keys=['image']),
        ScaleIntensityd(keys=['image']),
    ]
    if opt.resolution is not None:
        transforms.append(
            Spacingd(keys=['image', 'label'], pixdim=opt.resolution, mode=('bilinear', 'nearest')))
    return transforms


def _augmentation_transforms(opt):
    """Return the random augmentations, the patch crop and the tensor conversion for training."""
    return [
        RandFlipd(keys=['image', 'label'], prob=0.1, spatial_axis=1),
        RandFlipd(keys=['image', 'label'], prob=0.1, spatial_axis=0),
        RandFlipd(keys=['image', 'label'], prob=0.1, spatial_axis=2),
        RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                    rotate_range=(np.pi / 36, np.pi / 36, np.pi * 2), padding_mode="zeros"),
        RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                    rotate_range=(np.pi / 36, np.pi / 2, np.pi / 36), padding_mode="zeros"),
        RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                    rotate_range=(np.pi / 2, np.pi / 36, np.pi / 36), padding_mode="zeros"),
        Rand3DElasticd(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                       sigma_range=(5, 8), magnitude_range=(100, 200), scale_range=(0.15, 0.15, 0.15),
                       padding_mode="zeros"),
        RandAdjustContrastd(keys=['image'], gamma=(0.5, 2.5), prob=0.1),
        # NOTE: the np.random.uniform(...) arguments below are evaluated once,
        # when this list is built, so the noise mean/std and the shift offset
        # are drawn a single time per run.
        RandGaussianNoised(keys=['image'], prob=0.1, mean=np.random.uniform(0, 0.5), std=np.random.uniform(0, 1)),
        RandShiftIntensityd(keys=['image'], offsets=np.random.uniform(0, 0.3), prob=0.1),
        RandSpatialCropd(keys=['image', 'label'], roi_size=opt.patch_size, random_size=False),
        ToTensord(keys=['image', 'label'])
    ]


def main():
    """Train the segmentation network and log Dice scores to TensorBoard.

    Steps:
      1. Parse the options and list the ``image*.nii`` / ``label*.nii`` files
         in the ``train``, ``val`` and ``test`` sub-folders.
      2. Build the transform pipelines and the data loaders.
      3. Build the network with ``build_net()``, move it to the GPU and
         optionally load the weights at ``opt.preload``.
      4. Train for ``opt.epochs`` epochs; after every epoch compute the mean
         Dice on the validation, training and test sets, write the scalars and
         the last validation case to TensorBoard, and save the weights with
         the best validation Dice to ``best_metric_model.pth``.
    """
    opt = Options().parse()
    # monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    if opt.gpu_ids != '-1':
        num_gpus = len(opt.gpu_ids.split(','))
    else:
        num_gpus = 0
    print('number of GPU:', num_gpus)

    # ---------------- File lists ----------------

    train_images = sorted(glob(os.path.join(opt.images_folder, 'train', 'image*.nii')))
    train_segs = sorted(glob(os.path.join(opt.labels_folder, 'train', 'label*.nii')))

    # Un-replicated copies of the training lists, used to score whole training
    # images once per epoch.
    train_images_for_dice = list(train_images)
    train_segs_for_dice = list(train_segs)

    val_images = sorted(glob(os.path.join(opt.images_folder, 'val', 'image*.nii')))
    val_segs = sorted(glob(os.path.join(opt.labels_folder, 'val', 'label*.nii')))

    test_images = sorted(glob(os.path.join(opt.images_folder, 'test', 'image*.nii')))
    test_segs = sorted(glob(os.path.join(opt.labels_folder, 'test', 'label*.nii')))

    # Replicate the training list so that more random patches are drawn per
    # epoch. NOTE: each pass appends the list to itself, so its length doubles
    # on every iteration (overall factor 2 ** increase_factor_data).
    for _ in range(int(opt.increase_factor_data)):
        train_images.extend(train_images)
        train_segs.extend(train_segs)

    print('Number of training patches per epoch:', len(train_images))
    print('Number of training images per epoch:', len(train_images_for_dice))
    print('Number of validation images per epoch:', len(val_images))
    print('Number of test images per epoch:', len(test_images))

    train_dicts = _paired_dicts(train_images, train_segs)
    train_dice_dicts = _paired_dicts(train_images_for_dice, train_segs_for_dice)
    val_dicts = _paired_dicts(val_images, val_segs)
    test_dicts = _paired_dicts(test_images, test_segs)

    # ---------------- Transforms ----------------
    # Need to concatenate multiple channels here if you want multichannel segmentation
    # Check other examples on Monai webpage.

    train_transforms = Compose(_preprocessing_transforms(opt) + _augmentation_transforms(opt))
    val_transforms = Compose(_preprocessing_transforms(opt) + [ToTensord(keys=['image', 'label'])])

    # ---------------- Data loaders ----------------

    pin_memory = torch.cuda.is_available()

    # Random training patches.
    train_ds = monai.data.Dataset(data=train_dicts, transform=train_transforms)
    train_loader = DataLoader(train_ds, batch_size=opt.batch_size, shuffle=True,
                              num_workers=opt.workers, pin_memory=pin_memory)

    # Whole training images, used only to report the training Dice.
    train_dice_ds = monai.data.Dataset(data=train_dice_dicts, transform=val_transforms)
    train_dice_loader = DataLoader(train_dice_ds, batch_size=1, num_workers=opt.workers, pin_memory=pin_memory)

    # Validation images.
    val_ds = monai.data.Dataset(data=val_dicts, transform=val_transforms)
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=opt.workers, pin_memory=pin_memory)

    # Test images.
    test_ds = monai.data.Dataset(data=test_dicts, transform=val_transforms)
    test_loader = DataLoader(test_ds, batch_size=1, num_workers=opt.workers, pin_memory=pin_memory)

    # ---------------- Network, loss, optimiser ----------------

    net = build_net()
    net.cuda()

    if num_gpus > 1:
        net = torch.nn.DataParallel(net)

    if opt.preload is not None:
        net.load_state_dict(torch.load(opt.preload))

    dice_metric = DiceMetric(include_background=True, reduction="mean")
    # Turns raw logits into a binary mask (sigmoid followed by thresholding).
    post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold_values=True)])

    # loss_function = monai.losses.DiceLoss(sigmoid=True)
    # loss_function = monai.losses.TverskyLoss(sigmoid=True, alpha=0.3, beta=0.7)
    loss_function = monai.losses.DiceCELoss(sigmoid=True)

    optim = torch.optim.Adam(net.parameters(), lr=opt.lr)
    net_scheduler = get_scheduler(optim, opt)

    def mean_dice(images_loader):
        """Return the mean Dice over ``images_loader`` plus the last image, label and prediction.

        Inference is done with a sliding window of ``opt.patch_size``. Must be
        called with the network in eval mode and gradients disabled.
        """
        metric_sum = 0.0
        metric_count = 0
        images = labels = outputs = None
        for data in images_loader:
            images, labels = data["image"].cuda(), data["label"].cuda()
            roi_size = opt.patch_size
            sw_batch_size = 4
            outputs = sliding_window_inference(images, roi_size, sw_batch_size, net)
            outputs = post_trans(outputs)
            value, _ = dice_metric(y_pred=outputs, y=labels)
            metric_count += len(value)
            metric_sum += value.item() * len(value)
        return metric_sum / metric_count, images, labels, outputs

    # ---------------- Training loop ----------------

    val_interval = 1
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    metric_values = list()  # validation Dice, one entry per validation pass
    writer = SummaryWriter()

    # Number of optimisation steps per epoch. Taking it from the loader keeps
    # the TensorBoard global step unique even when the dataset size is not a
    # multiple of the batch size.
    epoch_len = len(train_loader)

    try:
        for epoch in range(opt.epochs):
            print("-" * 10)
            print(f"epoch {epoch + 1}/{opt.epochs}")
            net.train()
            epoch_loss = 0
            step = 0
            for batch_data in train_loader:
                step += 1
                inputs, labels = batch_data["image"].cuda(), batch_data["label"].cuda()
                optim.zero_grad()
                outputs = net(inputs)
                loss = loss_function(outputs, labels)
                loss.backward()
                optim.step()
                epoch_loss += loss.item()
                print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
                writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
            epoch_loss /= step
            epoch_loss_values.append(epoch_loss)
            print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
            update_learning_rate(net_scheduler, optim)

            if (epoch + 1) % val_interval == 0:
                net.eval()
                with torch.no_grad():
                    metric, last_image, last_label, last_pred = mean_dice(val_loader)
                    metric_values.append(metric)

                    # Save best model
                    if metric > best_metric:
                        best_metric = metric
                        best_metric_epoch = epoch + 1
                        torch.save(net.state_dict(), "best_metric_model.pth")
                        print("saved new best metric model")

                    metric_train, _, _, _ = mean_dice(train_dice_loader)
                    metric_test, _, _, _ = mean_dice(test_loader)

                    # Logger bar
                    print(
                        "current epoch: {} Training dice: {:.4f} Validation dice: {:.4f} Testing dice: {:.4f} Best Validation dice: {:.4f} at epoch {}".format(
                            epoch + 1, metric_train, metric, metric_test, best_metric, best_metric_epoch
                        )
                    )

                    writer.add_scalar("Mean_epoch_loss", epoch_loss, epoch + 1)
                    writer.add_scalar("Testing_dice", metric_test, epoch + 1)
                    writer.add_scalar("Training_dice", metric_train, epoch + 1)
                    writer.add_scalar("Validation_dice", metric, epoch + 1)

                    # Plot the last validation case as a GIF in TensorBoard.
                    # `last_pred` has already been binarised by `post_trans`;
                    # applying sigmoid + 0.5 threshold a second time would map
                    # both 0 and 1 to 1 and produce an all-foreground image.
                    plot_2d_or_3d_image(last_image, epoch + 1, writer, index=0, tag="validation image")
                    plot_2d_or_3d_image(last_label, epoch + 1, writer, index=0, tag="validation label")
                    plot_2d_or_3d_image(last_pred, epoch + 1, writer, index=0, tag="validation inference")

        print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
    finally:
        # Flush and close the event file even if training is interrupted.
        writer.close()
        # norm=Norm.BATCH,
    )

    init_weights(nn_Unet, init_type='normal')

    return nn_Unet


# Smoke test, run only when this file is executed as a script: build the
# network, push one random batch through it on the GPU, export the graph to
# ONNX and print a layer summary.
if __name__ == '__main__':
    import time
    import torch
    from torch.autograd import Variable
    from torchsummaryX import summary
    from torch.nn import init

    # Command-line options (batch size, channel counts, patch size, ...).
    opt = Options().parse()

    # Run everything on CUDA device 0, with the network in evaluation mode.
    torch.cuda.set_device(0)
    network = build_net()
    net = network.cuda().eval()

    # Random input of shape (batch_size, in_channels, patch_size[0], patch_size[1], patch_size[2]).
    # NOTE(review): `Variable` looks like the legacy autograd wrapper; a plain
    # tensor may be sufficient on current PyTorch - confirm before changing.
    data = Variable(torch.randn(int(opt.batch_size), int(opt.in_channels), int(opt.patch_size[0]), int(opt.patch_size[1]), int(opt.patch_size[2]))).cuda()

    # Single forward pass to check that the network runs on this input size.
    out = net(data)

    # Write the traced graph to "Unet_model_graph.onnx" in the working directory.
    torch.onnx.export(net, data, "Unet_model_graph.onnx")

    # Print the torchsummaryX summary and the output tensor size.
    summary(net,data)
    print("out size: {}".format(out.size()))

def build_net():
    """Build the network selected by ``opt.net`` and initialise its weights.

    The options come from ``Options().parse()``. Supported values of
    ``opt.net`` are:

    * ``'Unet_Monai'`` - ``monai.networks.nets.UNet`` with residual units,
      batch norm and dropout 0.2;
    * ``'HighResNet'`` - ``monai.networks.nets.HighResNet``;
    * ``'Unet_David'`` - the hand-written 3D U-Net defined below.

    Only the selected network is constructed.

    Returns:
        The selected network after ``init_weights(..., init_type='normal')``.

    Raises:
        NotImplementedError: if ``opt.net`` is none of the values above.
    """
    from init import Options
    opt = Options().parse()
    from monai.networks.layers import Norm
    from monai.networks.layers.factories import split_args
    act_type, _ = split_args("RELU")

    class Conv3D_Block(Module):
        """Two Conv3d -> BatchNorm3d -> ReLU stages with an optional 1*1 conv residual path."""

        def __init__(self, inp_feat, out_feat, kernel=3, stride=1, padding=1, residual=None):
            super().__init__()

            self.conv1 = Sequential(
                Conv3d(inp_feat, out_feat, kernel_size=kernel,
                       stride=stride, padding=padding, bias=True),
                BatchNorm3d(out_feat),
                ReLU())

            self.conv2 = Sequential(
                Conv3d(out_feat, out_feat, kernel_size=kernel,
                       stride=stride, padding=padding, bias=True),
                BatchNorm3d(out_feat),
                ReLU())

            self.residual = residual

            if self.residual is not None:
                # 1*1 convolution that matches the input channels to out_feat
                # so the input can be added to the block output.
                self.residual_upsampler = Conv3d(inp_feat, out_feat, kernel_size=1, bias=False)

        def forward(self, x):
            features = self.conv2(self.conv1(x))
            # Same test as in __init__, so the residual branch is used exactly
            # when the 1*1 convolution exists.
            if self.residual is None:
                return features
            return features + self.residual_upsampler(x)

    class Deconv3D_Block(Module):
        """Transposed 3D convolution followed by ReLU, used as the decoder upsampler."""

        def __init__(self, inp_feat, out_feat, kernel=3, stride=2, padding=1):
            super().__init__()

            self.deconv = Sequential(
                ConvTranspose3d(inp_feat, out_feat, kernel_size=(kernel, kernel, kernel),
                                stride=(stride, stride, stride), padding=(padding, padding, padding),
                                output_padding=1, bias=True),
                ReLU())

        def forward(self, x):
            return self.deconv(x)

    class ChannelPool3d(AvgPool1d):
        """Average-pool a 5D tensor along its channel axis (not used by UNet_David)."""

        def __init__(self, kernel_size, stride, padding):
            super().__init__(kernel_size, stride, padding)
            self.pool_1d = AvgPool1d(self.kernel_size, self.stride, self.padding, self.ceil_mode)

        def forward(self, inp):
            n, c, d, w, h = inp.size()
            # Move the channels to the last axis so AvgPool1d pools over them.
            flat = inp.view(n, c, d * w * h).permute(0, 2, 1)
            pooled = self.pool_1d(flat)
            # Restore (n, channels, d, w, h) from the pooled tensor; the
            # channel count is whatever the pooling produced.
            return pooled.permute(0, 2, 1).reshape(n, pooled.size(2), d, w, h)

    class UNet_David(Module):
        # __                            __
        #  1|__   ________________   __|1
        #     2|__  ____________  __|2
        #        3|__  ______  __|3
        #           4|__ __ __|4
        # The convolution operations on either side are residual subject to 1*1 Convolution for channel homogeneity

        def __init__(self, feat_channels=(32, 64, 128, 256, 512), residual='conv'):
            """Create the 4-level encoder/decoder.

            Args:
                feat_channels: number of feature channels at the five depths.
                residual: any value other than ``None`` adds a 1*1 conv
                    residual path to every convolution block; ``None`` removes
                    the residuals.
            """
            super().__init__()

            # Encoder downsamplers
            self.pool1 = MaxPool3d((2, 2, 2))
            self.pool2 = MaxPool3d((2, 2, 2))
            self.pool3 = MaxPool3d((2, 2, 2))
            self.pool4 = MaxPool3d((2, 2, 2))

            # Encoder convolutions
            self.conv_blk1 = Conv3D_Block(opt.in_channels, feat_channels[0], residual=residual)
            self.conv_blk2 = Conv3D_Block(feat_channels[0], feat_channels[1], residual=residual)
            self.conv_blk3 = Conv3D_Block(feat_channels[1], feat_channels[2], residual=residual)
            self.conv_blk4 = Conv3D_Block(feat_channels[2], feat_channels[3], residual=residual)
            self.conv_blk5 = Conv3D_Block(feat_channels[3], feat_channels[4], residual=residual)

            # Decoder convolutions (inputs are upsampled features concatenated with the skip connection)
            self.dec_conv_blk4 = Conv3D_Block(2 * feat_channels[3], feat_channels[3], residual=residual)
            self.dec_conv_blk3 = Conv3D_Block(2 * feat_channels[2], feat_channels[2], residual=residual)
            self.dec_conv_blk2 = Conv3D_Block(2 * feat_channels[1], feat_channels[1], residual=residual)
            self.dec_conv_blk1 = Conv3D_Block(2 * feat_channels[0], feat_channels[0], residual=residual)

            # Decoder upsamplers
            self.deconv_blk4 = Deconv3D_Block(feat_channels[4], feat_channels[3])
            self.deconv_blk3 = Deconv3D_Block(feat_channels[3], feat_channels[2])
            self.deconv_blk2 = Deconv3D_Block(feat_channels[2], feat_channels[1])
            self.deconv_blk1 = Deconv3D_Block(feat_channels[1], feat_channels[0])

            # Final 1*1 Conv Segmentation map
            self.one_conv = Conv3d(feat_channels[0], opt.out_channels, kernel_size=1, stride=1, padding=0, bias=True)

            # Decoder dropout. Registered as sub-modules so that it follows
            # train()/eval(); a Dropout3d created inside forward() is always
            # in training mode and would keep dropping channels at inference.
            self.drop3 = Dropout3d(p=0.5)
            self.drop2 = Dropout3d(p=0.5)

        def forward(self, x):
            # Encoder part

            x1 = self.conv_blk1(x)

            x_low1 = self.pool1(x1)
            x2 = self.conv_blk2(x_low1)

            x_low2 = self.pool2(x2)
            x3 = self.conv_blk3(x_low2)

            x_low3 = self.pool3(x3)
            x4 = self.conv_blk4(x_low3)

            x_low4 = self.pool4(x4)
            base = self.conv_blk5(x_low4)

            # Decoder part

            d4 = torch.cat([self.deconv_blk4(base), x4], dim=1)
            d_high4 = self.dec_conv_blk4(d4)

            d3 = torch.cat([self.deconv_blk3(d_high4), x3], dim=1)
            d_high3 = self.dec_conv_blk3(d3)
            d_high3 = self.drop3(d_high3)

            d2 = torch.cat([self.deconv_blk2(d_high3), x2], dim=1)
            d_high2 = self.dec_conv_blk2(d2)
            d_high2 = self.drop2(d_high2)

            d1 = torch.cat([self.deconv_blk1(d_high2), x1], dim=1)
            d_high1 = self.dec_conv_blk1(d1)

            seg = self.one_conv(d_high1)

            return seg

    # Build only the requested network instead of instantiating every
    # candidate up front.
    if opt.net == 'Unet_Monai':
        network = monai.networks.nets.UNet(
            dimensions=3,
            in_channels=opt.in_channels,
            out_channels=opt.out_channels,
            channels=(32, 64, 128, 256, 512),
            strides=(2, 2, 2, 2),
            act=act_type,
            num_res_units=3,
            dropout=0.2,
            norm=Norm.BATCH,
        )
    elif opt.net == 'HighResNet':
        network = monai.networks.nets.HighResNet(
            spatial_dims=3,
            in_channels=opt.in_channels,
            out_channels=opt.out_channels,
        )
    elif opt.net == 'Unet_David':
        network = UNet_David(residual='pool')
    else:
        raise NotImplementedError(f"unsupported network type: {opt.net!r}")

    init_weights(network, init_type='normal')

    return network
예제 #6
0
def main():
    opt = Options().parse()
    # monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    set_determinism(seed=0)
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    device = torch.device(opt.gpu_id)

    # ------- Data loader creation ----------

    # images
    images = sorted(glob(os.path.join(opt.images_folder, 'image*.nii')))
    segs = sorted(glob(os.path.join(opt.labels_folder, 'label*.nii')))

    train_files = []
    val_files = []

    for i in range(opt.models_ensemble):
        train_files.append([{
            "image": img,
            "label": seg
        } for img, seg in zip(
            images[:(opt.split_val * i)] +
            images[(opt.split_val *
                    (i + 1)):(len(images) -
                              opt.split_val)], segs[:(opt.split_val * i)] +
            segs[(opt.split_val * (i + 1)):(len(images) - opt.split_val)])])
        val_files.append([{
            "image": img,
            "label": seg
        } for img, seg in zip(
            images[(opt.split_val * i):(opt.split_val *
                                        (i + 1))], segs[(opt.split_val *
                                                         i):(opt.split_val *
                                                             (i + 1))])])

    test_files = [{
        "image": img,
        "label": seg
    } for img, seg in zip(images[(len(images) -
                                  opt.split_test):len(images)], segs[(
                                      len(images) -
                                      opt.split_test):len(images)])]

    # ----------- Transforms list --------------

    if opt.resolution is not None:
        train_transforms = [
            LoadImaged(keys=['image', 'label']),
            AddChanneld(keys=['image', 'label']),
            NormalizeIntensityd(keys=['image']),
            ScaleIntensityd(keys=['image']),
            Spacingd(keys=['image', 'label'],
                     pixdim=opt.resolution,
                     mode=('bilinear', 'nearest')),
            RandFlipd(keys=['image', 'label'], prob=0.1, spatial_axis=1),
            RandFlipd(keys=['image', 'label'], prob=0.1, spatial_axis=0),
            RandFlipd(keys=['image', 'label'], prob=0.1, spatial_axis=2),
            RandAffined(keys=['image', 'label'],
                        mode=('bilinear', 'nearest'),
                        prob=0.1,
                        rotate_range=(np.pi / 36, np.pi / 36, np.pi * 2),
                        padding_mode="zeros"),
            RandAffined(keys=['image', 'label'],
                        mode=('bilinear', 'nearest'),
                        prob=0.1,
                        rotate_range=(np.pi / 36, np.pi / 2, np.pi / 36),
                        padding_mode="zeros"),
            RandAffined(keys=['image', 'label'],
                        mode=('bilinear', 'nearest'),
                        prob=0.1,
                        rotate_range=(np.pi / 2, np.pi / 36, np.pi / 36),
                        padding_mode="zeros"),
            Rand3DElasticd(keys=['image', 'label'],
                           mode=('bilinear', 'nearest'),
                           prob=0.1,
                           sigma_range=(5, 8),
                           magnitude_range=(100, 200),
                           scale_range=(0.15, 0.15, 0.15),
                           padding_mode="zeros"),
            RandAdjustContrastd(keys=['image'], gamma=(0.5, 2.5), prob=0.1),
            RandGaussianNoised(keys=['image'],
                               prob=0.1,
                               mean=np.random.uniform(0, 0.5),
                               std=np.random.uniform(0, 1)),
            RandShiftIntensityd(keys=['image'],
                                offsets=np.random.uniform(0, 0.3),
                                prob=0.1),
            RandSpatialCropd(keys=['image', 'label'],
                             roi_size=opt.patch_size,
                             random_size=False),
            ToTensord(keys=['image', 'label'])
        ]

        val_transforms = [
            LoadImaged(keys=['image', 'label']),
            AddChanneld(keys=['image', 'label']),
            NormalizeIntensityd(keys=['image']),
            ScaleIntensityd(keys=['image']),
            Spacingd(keys=['image', 'label'],
                     pixdim=opt.resolution,
                     mode=('bilinear', 'nearest')),
            ToTensord(keys=['image', 'label'])
        ]
    else:
        train_transforms = [
            LoadImaged(keys=['image', 'label']),
            AddChanneld(keys=['image', 'label']),
            NormalizeIntensityd(keys=['image']),
            ScaleIntensityd(keys=['image']),
            RandFlipd(keys=['image', 'label'], prob=0.1, spatial_axis=1),
            RandFlipd(keys=['image', 'label'], prob=0.1, spatial_axis=0),
            RandFlipd(keys=['image', 'label'], prob=0.1, spatial_axis=2),
            RandAffined(keys=['image', 'label'],
                        mode=('bilinear', 'nearest'),
                        prob=0.1,
                        rotate_range=(np.pi / 36, np.pi / 36, np.pi * 2),
                        padding_mode="zeros"),
            RandAffined(keys=['image', 'label'],
                        mode=('bilinear', 'nearest'),
                        prob=0.1,
                        rotate_range=(np.pi / 36, np.pi / 2, np.pi / 36),
                        padding_mode="zeros"),
            RandAffined(keys=['image', 'label'],
                        mode=('bilinear', 'nearest'),
                        prob=0.1,
                        rotate_range=(np.pi / 2, np.pi / 36, np.pi / 36),
                        padding_mode="zeros"),
            Rand3DElasticd(keys=['image', 'label'],
                           mode=('bilinear', 'nearest'),
                           prob=0.1,
                           sigma_range=(5, 8),
                           magnitude_range=(100, 200),
                           scale_range=(0.15, 0.15, 0.15),
                           padding_mode="zeros"),
            RandAdjustContrastd(keys=['image'], gamma=(0.5, 2.5), prob=0.1),
            RandGaussianNoised(keys=['image'],
                               prob=0.1,
                               mean=np.random.uniform(0, 0.5),
                               std=np.random.uniform(0, 1)),
            RandShiftIntensityd(keys=['image'],
                                offsets=np.random.uniform(0, 0.3),
                                prob=0.1),
            RandSpatialCropd(keys=['image', 'label'],
                             roi_size=opt.patch_size,
                             random_size=False),
            ToTensord(keys=['image', 'label'])
        ]

        val_transforms = [
            LoadImaged(keys=['image', 'label']),
            AddChanneld(keys=['image', 'label']),
            NormalizeIntensityd(keys=['image']),
            ScaleIntensityd(keys=['image']),
            ToTensord(keys=['image', 'label'])
        ]

    train_transforms = Compose(train_transforms)
    val_transforms = Compose(val_transforms)

    # ---------- Creation of DataLoaders -------------

    train_dss = [
        CacheDataset(data=train_files[i], transform=train_transforms)
        for i in range(opt.models_ensemble)
    ]
    train_loaders = [
        DataLoader(train_dss[i],
                   batch_size=opt.batch_size,
                   shuffle=True,
                   num_workers=opt.workers,
                   pin_memory=torch.cuda.is_available())
        for i in range(opt.models_ensemble)
    ]

    val_dss = [
        CacheDataset(data=val_files[i], transform=val_transforms)
        for i in range(opt.models_ensemble)
    ]
    val_loaders = [
        DataLoader(val_dss[i],
                   batch_size=1,
                   num_workers=opt.workers,
                   pin_memory=torch.cuda.is_available())
        for i in range(opt.models_ensemble)
    ]

    test_ds = CacheDataset(data=test_files, transform=val_transforms)
    test_loader = DataLoader(test_ds,
                             batch_size=1,
                             num_workers=opt.workers,
                             pin_memory=torch.cuda.is_available())

    def train(index):
        """Build, train and return the ``index``-th DynUNet of the ensemble.

        The network is created on ``device``, trained on
        ``train_loaders[index]`` for ``opt.epochs`` epochs and validated on
        ``val_loaders[index]`` every 5 epochs.

        Args:
            index: position of this ensemble member in ``train_loaders`` and
                ``val_loaders``.

        Returns:
            The trained network.
        """

        # ---------- Build the nn-Unet network ------------

        sizes = opt.patch_size
        # the resampling resolution, when given, replaces the native spacing
        spacings = opt.spacing if opt.resolution is None else opt.resolution

        strides, kernels = [], []

        # One iteration per resolution level: an axis is downsampled
        # (stride 2) while its spacing is at most twice the smallest spacing
        # and its current size is at least 8; the loop stops once no axis can
        # be downsampled any more.
        while True:
            spacing_ratio = [sp / min(spacings) for sp in spacings]
            stride = [
                2 if ratio <= 2 and size >= 8 else 1
                for (ratio, size) in zip(spacing_ratio, sizes)
            ]
            kernel = [3 if ratio <= 2 else 1 for ratio in spacing_ratio]
            if all(s == 1 for s in stride):
                break
            sizes = [i / j for i, j in zip(sizes, stride)]
            spacings = [i * j for i, j in zip(spacings, stride)]
            kernels.append(kernel)
            strides.append(stride)
        # the first level keeps full resolution, the last level uses 3x3x3
        strides.insert(0, len(spacings) * [1])
        kernels.append(len(spacings) * [3])

        net = monai.networks.nets.DynUNet(
            spatial_dims=3,
            in_channels=opt.in_channels,
            out_channels=opt.out_channels,
            kernel_size=kernels,
            strides=strides,
            upsample_kernel_size=strides[1:],
            res_block=True,
        ).to(device)

        from torchsummaryX import summary

        # Dummy batch used only to print the layer summary and output shape.
        # It is moved to the same device as the network (the previous
        # ``.cuda()`` broke whenever ``device`` was not the default GPU).
        data = torch.randn(int(opt.batch_size), int(opt.in_channels),
                           int(opt.patch_size[0]), int(opt.patch_size[1]),
                           int(opt.patch_size[2])).to(device)

        # shape check only, no gradients needed
        with torch.no_grad():
            out = net(data)
        summary(net, data)
        print("out size: {}".format(out.size()))

        # ---------- ------------------------ ------------

        optim = torch.optim.Adam(net.parameters(), lr=opt.lr)
        # polynomial decay of the learning rate towards 0 at opt.epochs
        lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            optim, lr_lambda=lambda epoch: (1 - epoch / opt.epochs)**0.9)

        loss_function = monai.losses.DiceCELoss(sigmoid=True)

        val_post_transforms = Compose([
            Activationsd(keys="pred", sigmoid=True),
            AsDiscreted(keys="pred", threshold_values=True),
        ])

        # NOTE(review): every ensemble member saves into the same "./runs/"
        # directory - confirm that checkpoints of different members do not
        # overwrite each other.
        val_handlers = [
            StatsHandler(output_transform=lambda x: None),
            CheckpointSaver(save_dir="./runs/",
                            save_dict={"net": net},
                            save_key_metric=True),
        ]

        evaluator = SupervisedEvaluator(
            device=device,
            val_data_loader=val_loaders[index],
            network=net,
            inferer=SlidingWindowInferer(roi_size=opt.patch_size,
                                         sw_batch_size=opt.batch_size,
                                         overlap=0.5),
            post_transform=val_post_transforms,
            key_val_metric={
                "val_mean_dice":
                MeanDice(
                    include_background=True,
                    output_transform=lambda x: (x["pred"], x["label"]),
                )
            },
            val_handlers=val_handlers)

        train_post_transforms = Compose([
            Activationsd(keys="pred", sigmoid=True),
            AsDiscreted(keys="pred", threshold_values=True),
        ])

        train_handlers = [
            # run the evaluator above every 5 training epochs
            ValidationHandler(validator=evaluator,
                              interval=5,
                              epoch_level=True),
            LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True),
            StatsHandler(tag_name="train_loss",
                         output_transform=lambda x: x["loss"]),
            CheckpointSaver(save_dir="./runs/",
                            save_dict={
                                "net": net,
                                "opt": optim
                            },
                            save_final=True,
                            epoch_level=True),
        ]

        trainer = SupervisedTrainer(
            device=device,
            max_epochs=opt.epochs,
            train_data_loader=train_loaders[index],
            network=net,
            optimizer=optim,
            loss_function=loss_function,
            inferer=SimpleInferer(),
            post_transform=train_post_transforms,
            amp=False,
            train_handlers=train_handlers,
        )
        trainer.run()
        return net

    models = [train(i) for i in range(opt.models_ensemble)]

    # -------- Test the models ---------

    def ensemble_evaluate(post_transforms, models):
        """Evaluate the ensemble of ``models`` on the test loader.

        Args:
            post_transforms: transform applied to the ensemble output before
                the metric is computed.
            models: networks whose predictions are stored under
                ``opt.pred_keys``.

        The result is reported through the ``test_mean_dice`` key metric;
        nothing is returned.
        """
        window_inferer = SlidingWindowInferer(roi_size=opt.patch_size,
                                              sw_batch_size=opt.batch_size,
                                              overlap=0.5)
        dice = MeanDice(
            include_background=True,
            output_transform=lambda x: (x["pred"], x["label"]),
        )

        ensemble = EnsembleEvaluator(
            device=device,
            val_data_loader=test_loader,
            pred_keys=opt.pred_keys,
            networks=models,
            inferer=window_inferer,
            post_transform=post_transforms,
            key_val_metric={"test_mean_dice": dice},
        )
        ensemble.run()

    mean_post_transforms = Compose([
        MeanEnsembled(
            keys=opt.pred_keys,
            output_key="pred",
            # in this particular example, we use validation metrics as weights
            weights=opt.weights_models,
        ),
        Activationsd(keys="pred", sigmoid=True),
        AsDiscreted(keys="pred", threshold_values=True),
        # KeepLargestConnectedComponentd(keys="pred", applied_labels=[1])
    ])

    print('Results from MeanEnsembled:')
    ensemble_evaluate(mean_post_transforms, models)

    vote_post_transforms = Compose([
        Activationsd(keys=opt.pred_keys, sigmoid=True),
        # transform data into discrete before voting
        AsDiscreted(keys=opt.pred_keys, threshold_values=True),
        # KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
        VoteEnsembled(keys=opt.pred_keys, output_key="pred"),
    ])

    print('Results from VoteEnsembled:')
    ensemble_evaluate(vote_post_transforms, models)
def _list_split_files(opt, split):
    """Return sorted ``(image_paths, label_paths)`` for one data split.

    ``split`` is the sub-folder name ('train', 'val' or 'test') searched
    under ``opt.images_folder`` and ``opt.labels_folder``.
    """
    images = sorted(glob(os.path.join(opt.images_folder, split, 'image*.nii')))
    labels = sorted(glob(os.path.join(opt.labels_folder, split, 'label*.nii')))
    return images, labels


def _to_dicts(images, labels):
    """Pair image and label paths into the dict records fed to the datasets."""
    return [{'image': image_name, 'label': label_name}
            for image_name, label_name in zip(images, labels)]


def _build_transform_lists(opt):
    """Return the ``(train, validation)`` transform lists (not yet composed).

    Both lists share the same loading / cropping / intensity steps; a
    ``Spacingd`` resampling step is inserted only when ``opt.resolution`` is
    set. The training list additionally holds the random augmentations and a
    random crop to ``opt.patch_size``.
    """
    both = ['image', 'label']
    resample = opt.resolution is not None

    def common_head():
        # The CT HU clipping (ThresholdIntensityd at -135 / 215) present as
        # commented-out code in the original pipeline is left disabled.
        head = [
            LoadImaged(keys=both),
            AddChanneld(keys=both),
            CropForegroundd(keys=both, source_key='image'),
            NormalizeIntensityd(keys=['image']),
            ScaleIntensityd(keys=['image']),
        ]
        if resample:
            head.append(Spacingd(keys=both, pixdim=opt.resolution,
                                 mode=('bilinear', 'nearest')))
        return head

    # NOTE(review): the upper bound of the noise std differs between the
    # resampled (15) and native-resolution (1) pipelines; kept as found,
    # confirm whether this is intended.
    noise_std_max = 15 if resample else 1

    # one RandAffined per axis, each allowing a large rotation on one axis
    rotate_ranges = [
        (np.pi / 36, np.pi / 36, np.pi * 2),
        (np.pi / 36, np.pi / 2, np.pi / 36),
        (np.pi / 2, np.pi / 36, np.pi / 36),
    ]

    augmentations = [RandFlipd(keys=both, prob=0.15, spatial_axis=axis)
                     for axis in (1, 0, 2)]
    augmentations += [
        RandAffined(keys=both, mode=('bilinear', 'nearest'), prob=0.1,
                    rotate_range=rotate_range, padding_mode="zeros")
        for rotate_range in rotate_ranges
    ]
    augmentations += [
        Rand3DElasticd(keys=both, mode=('bilinear', 'nearest'), prob=0.1,
                       sigma_range=(5, 8), magnitude_range=(100, 200),
                       scale_range=(0.15, 0.15, 0.15), padding_mode="zeros"),
        RandGaussianSmoothd(keys=["image"], sigma_x=(0.5, 1.15), sigma_y=(0.5, 1.15),
                            sigma_z=(0.5, 1.15), prob=0.1),
        RandAdjustContrastd(keys=['image'], gamma=(0.5, 2.5), prob=0.1),
        # The np.random.uniform(...) arguments below are drawn once, here, so
        # the noise mean/std and the intensity offset are fixed for the run.
        RandGaussianNoised(keys=['image'], prob=0.1, mean=np.random.uniform(0, 0.5),
                           std=np.random.uniform(0, noise_std_max)),
        RandShiftIntensityd(keys=['image'], offsets=np.random.uniform(0, 0.3), prob=0.1),
    ]

    train_list = common_head() + augmentations + [
        # pad if the image is smaller than the patch
        SpatialPadd(keys=both, spatial_size=opt.patch_size, method='end'),
        RandSpatialCropd(keys=both, roi_size=opt.patch_size, random_size=False),
        ToTensord(keys=both),
    ]
    val_list = common_head() + [
        # pad if the image is smaller than the patch
        SpatialPadd(keys=both, spatial_size=opt.patch_size, method='end'),
        ToTensord(keys=both),
    ]
    return train_list, val_list


def main():
    """Train a segmentation network and log train/val/test Dice per epoch.

    Reads all settings from ``Options().parse()``, builds the data loaders
    for the 'train', 'val' and 'test' sub-folders, creates the network
    selected by ``opt.network`` ('nnunet' or 'unetr'), trains it on the GPU
    and, after every epoch, evaluates the Dice score on the three splits.
    The weights with the best validation Dice are written to
    ``best_metric_model.pth`` and the curves/images go to TensorBoard.

    Raises:
        ValueError: if ``opt.network`` is neither 'nnunet' nor 'unetr'.
    """
    opt = Options().parse()
    # monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # check gpus
    num_gpus = len(opt.gpu_ids.split(',')) if opt.gpu_ids != '-1' else 0
    print('number of GPU:', num_gpus)

    # Data loader creation
    train_images, train_segs = _list_split_files(opt, 'train')
    # un-augmented copies, used to measure the Dice on whole training images
    train_images_for_dice = list(train_images)
    train_segs_for_dice = list(train_segs)
    val_image_paths, val_segs = _list_split_files(opt, 'val')
    test_image_paths, test_segs = _list_split_files(opt, 'test')

    # augment the data list for training
    # NOTE: each pass appends the list to itself, so the number of entries is
    # multiplied by 2 ** increase_factor_data, not by increase_factor_data.
    for _ in range(int(opt.increase_factor_data)):
        train_images.extend(train_images)
        train_segs.extend(train_segs)

    print('Number of training patches per epoch:', len(train_images))
    print('Number of training images per epoch:', len(train_images_for_dice))
    print('Number of validation images per epoch:', len(val_image_paths))
    print('Number of test images per epoch:', len(test_image_paths))

    # Creation of data directories for data_loader
    train_dicts = _to_dicts(train_images, train_segs)
    train_dice_dicts = _to_dicts(train_images_for_dice, train_segs_for_dice)
    val_dicts = _to_dicts(val_image_paths, val_segs)
    test_dicts = _to_dicts(test_image_paths, test_segs)

    # Transforms list
    train_transform_list, val_transform_list = _build_transform_lists(opt)
    train_transforms = Compose(train_transform_list)
    val_transforms = Compose(val_transform_list)

    # create a training data loader
    check_train = monai.data.Dataset(data=train_dicts, transform=train_transforms)
    train_loader = DataLoader(check_train, batch_size=opt.batch_size, shuffle=True,
                              collate_fn=list_data_collate, num_workers=opt.workers, pin_memory=False)

    # create a training_dice data loader (whole images, validation transforms)
    train_dice_ds = monai.data.Dataset(data=train_dice_dicts, transform=val_transforms)
    train_dice_loader = DataLoader(train_dice_ds, batch_size=1, num_workers=opt.workers,
                                   collate_fn=list_data_collate, pin_memory=False)

    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_dicts, transform=val_transforms)
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=opt.workers,
                            collate_fn=list_data_collate, pin_memory=False)

    # create a test data loader
    test_ds = monai.data.Dataset(data=test_dicts, transform=val_transforms)
    test_loader = DataLoader(test_ds, batch_size=1, num_workers=opt.workers,
                             collate_fn=list_data_collate, pin_memory=False)

    # build the network
    # Strings must be compared with ``==``: ``is`` tests object identity and
    # is normally False for a value parsed at runtime, which left ``net`` and
    # ``optim`` undefined.
    network_name = opt.network
    if network_name == 'nnunet':
        net = build_net()  # nn build_net
    elif network_name == 'unetr':
        net = build_UNETR()  # UneTR
    else:
        raise ValueError(
            "unsupported network {!r}: expected 'nnunet' or 'unetr'".format(network_name))
    net.cuda()

    if num_gpus > 1:
        net = torch.nn.DataParallel(net)

    # NOTE(review): the checkpoint is loaded after the optional DataParallel
    # wrapping, so its keys must match the wrapped/unwrapped model - confirm.
    if opt.preload is not None:
        net.load_state_dict(torch.load(opt.preload))

    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
    post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold_values=True)])

    loss_function = monai.losses.DiceCELoss(sigmoid=True)
    torch.backends.cudnn.benchmark = opt.benchmark

    # only the nn-Unet setup uses a learning-rate scheduler
    net_scheduler = None
    if network_name == 'nnunet':
        optim = torch.optim.SGD(net.parameters(), lr=opt.lr, momentum=0.99, weight_decay=3e-5, nesterov=True)
        net_scheduler = torch.optim.lr_scheduler.LambdaLR(
            optim, lr_lambda=lambda epoch: (1 - epoch / opt.epochs) ** 0.9)
    else:
        optim = torch.optim.AdamW(net.parameters(), lr=1e-4, weight_decay=1e-5)

    def plot_dice(images_loader):
        """Return the mean Dice over ``images_loader`` plus the last batch.

        The last image, label and post-processed prediction are returned so
        the caller can plot them. Must be called under ``torch.no_grad()``
        with the network in eval mode.
        """
        val_images = None
        val_labels = None
        val_outputs = None
        for data in images_loader:
            val_images, val_labels = data["image"].cuda(), data["label"].cuda()
            roi_size = opt.patch_size
            sw_batch_size = 4
            val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, net)
            val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
            dice_metric(y_pred=val_outputs, y=val_labels)

        # aggregate the final mean dice result
        metric = dice_metric.aggregate().item()
        # reset the status for next validation round
        dice_metric.reset()

        return metric, val_images, val_labels, val_outputs

    # start a typical PyTorch training
    val_interval = 1
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    # number of optimisation steps per epoch, used for logging only
    epoch_len = len(check_train) // train_loader.batch_size
    writer = SummaryWriter()
    try:
        for epoch in range(opt.epochs):
            print("-" * 10)
            print(f"epoch {epoch + 1}/{opt.epochs}")
            net.train()
            epoch_loss = 0
            step = 0
            for batch_data in train_loader:
                step += 1
                inputs, labels = batch_data["image"].cuda(), batch_data["label"].cuda()
                optim.zero_grad()
                outputs = net(inputs)
                loss = loss_function(outputs, labels)
                loss.backward()
                optim.step()
                epoch_loss += loss.item()
                print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
                writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
            epoch_loss /= step
            epoch_loss_values.append(epoch_loss)
            print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
            if net_scheduler is not None:
                update_learning_rate(net_scheduler, optim)

            if (epoch + 1) % val_interval == 0:
                net.eval()
                with torch.no_grad():
                    metric, val_images, val_labels, val_outputs = plot_dice(val_loader)

                    # Save best model
                    if metric > best_metric:
                        best_metric = metric
                        best_metric_epoch = epoch + 1
                        torch.save(net.state_dict(), "best_metric_model.pth")
                        print("saved new best metric model")

                    metric_train, train_images, train_labels, train_outputs = plot_dice(train_dice_loader)
                    metric_test, test_images, test_labels, test_outputs = plot_dice(test_loader)

                    # Logger bar
                    print(
                        "current epoch: {} Training dice: {:.4f} Validation dice: {:.4f} Testing dice: {:.4f} Best Validation dice: {:.4f} at epoch {}".format(
                            epoch + 1, metric_train, metric, metric_test, best_metric, best_metric_epoch
                        )
                    )

                    writer.add_scalar("Mean_epoch_loss", epoch_loss, epoch + 1)
                    writer.add_scalar("Testing_dice", metric_test, epoch + 1)
                    writer.add_scalar("Training_dice", metric_train, epoch + 1)
                    writer.add_scalar("Validation_dice", metric, epoch + 1)
                    # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                    plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag="validation image")
                    plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="validation label")
                    plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="validation inference")
                    plot_2d_or_3d_image(test_images, epoch + 1, writer, index=0, tag="test image")
                    plot_2d_or_3d_image(test_labels, epoch + 1, writer, index=0, tag="test label")
                    plot_2d_or_3d_image(test_outputs, epoch + 1, writer, index=0, tag="test inference")

        print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
    finally:
        # flush and close the TensorBoard log even if training is interrupted
        writer.close()