コード例 #1
0
ファイル: test_matshow3d.py プロジェクト: juampatronics/MONAI
    def test_samples(self):
        """Render random 3D patch samples with ``matshow3d`` and compare with the stored reference image.

        The pipeline loads ``testing_data/anatomical.nii``, adds a channel,
        rescales the intensity and draws 10 random crops. The figure returned
        by ``matshow3d`` is saved to a temporary PNG and compared with
        ``testing_data/matshow3d_patch_test.png``.
        """
        testing_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "testing_data")
        keys = "image"
        xforms = Compose(
            [
                LoadImaged(keys=keys),
                AddChanneld(keys=keys),
                ScaleIntensityd(keys=keys),
                RandSpatialCropSamplesd(keys=keys, roi_size=(8, 8, 5), random_size=True, num_samples=10),
            ]
        )
        image_path = os.path.join(testing_dir, "anatomical.nii")
        # fixed seed so that the random crops, and therefore the figure, are reproducible
        xforms.set_random_state(0)
        ims = xforms({keys: image_path})
        fig, mat = matshow3d(
            [im[keys] for im in ims], title=f"testing {keys}", figsize=(2, 2), frames_per_row=5, every_n=2, show=False
        )
        # assertEqual reports both dtypes on failure, assertTrue(a == b) only reports "False is not true"
        self.assertEqual(mat.dtype, np.float32)

        with tempfile.TemporaryDirectory() as tempdir:
            tempimg = os.path.join(tempdir, "matshow3d_patch_test.png")
            fig.savefig(tempimg)
            expected_img = os.path.join(testing_dir, "matshow3d_patch_test.png")
            comp = compare_images(expected_img, tempimg, 5e-2, in_decorator=True)
            if comp:
                # a non-empty result means the images differ by more than the tolerance;
                # the known difference below is accepted (see the matplotlib version note)
                print("not none comp: ", comp)  # matplotlib 3.2.2
                np.testing.assert_allclose(comp["rms"], 30.786983, atol=1e-3, rtol=1e-3)
            else:
                self.assertIsNone(comp, f"value of comp={comp}")  # None indicates test passed
コード例 #2
0
class Lungs(Dataset, Randomizable):
    """Dataset that loads one DICOM folder per item and returns it in 2D and 3D form.

    ``__getitem__`` returns a tuple ``(images, names, img)``:

    * ``images`` - tensor built from the per-slice transformed/preprocessed array,
    * ``names`` - for every file, the last three ``/``-separated components of
      its path with the ``.dcm`` suffix removed,
    * ``img`` - the result of ``self.transform3d`` (intensity scaling, resize to
      160x160x160, tensor conversion) applied to the rearranged slice stack.

    If anything fails while loading an item, ``(None, None, None)`` is returned.
    """

    def __init__(self, dicom_folders):
        """Store the folder list and build the 2D and 3D transform pipelines."""
        self.dicom_folders = dicom_folders
        self.transforms = get_validation_augmentation()
        self.preprocessing = get_preprocessing(
            functools.partial(preprocess_input, **formatted_settings))
        self.transform3d = Compose(
            [ScaleIntensity(),
             Resize((160, 160, 160)),
             ToTensor()])
        # Default seed so that get() can also be called directly, i.e. before
        # __getitem__ has drawn a seed through randomize().
        self._seed = 0

    def __len__(self):
        return len(self.dicom_folders)

    def randomize(self) -> None:
        """Draw a new seed in ``[0, 2**32)`` from this object's random state ``self.R``."""
        MAX_SEED = np.iinfo(np.uint32).max + 1
        self._seed = self.R.randint(MAX_SEED, dtype="uint32")

    def get(self, i):
        """Load, window and transform the study stored in ``self.dicom_folders[i]``.

        Returns:
            ``(images, names, img)`` as described in the class docstring.
        """
        image, files = load_dicom_array(self.dicom_folders[i])
        # Three differently windowed copies of the volume, each with a new
        # trailing axis, concatenated along that axis.
        # NOTE(review): WL/WW look like CT window level/width - confirm in window().
        image_lung = np.expand_dims(window(image, WL=-600, WW=1500), axis=3)
        image_mediastinal = np.expand_dims(window(image, WL=40, WW=400),
                                           axis=3)
        image_pe_specific = np.expand_dims(window(image, WL=100, WW=700),
                                           axis=3)
        image = np.concatenate(
            [image_mediastinal, image_pe_specific, image_lung], axis=3)
        names = [row.split(".dcm")[0].split("/")[-3:] for row in files]

        # Apply the 2D pipelines to each element along the first axis.
        images = []
        for img in image:
            if self.transforms:
                img = self.transforms(image=img)['image']
            if self.preprocessing:
                img = self.preprocessing(image=img)['image']
            images.append(img)
        images = np.array(images)

        # Reverse axis 1 and move axis 0 to the end before the 3D transform.
        # NOTE(review): presumably axis 1 holds the channels after preprocessing - confirm.
        img = images[:, ::-1].transpose(1, 2, 3, 0)
        if self.transform3d is not None:
            if isinstance(self.transform3d, Randomizable):
                self.transform3d.set_random_state(seed=self._seed)
            img = apply_transform(self.transform3d, img)

        return torch.from_numpy(images), names, img

    def __getitem__(self, i):
        """Draw a new seed and return item ``i``; on any error print it and return three ``None``s."""
        self.randomize()
        try:
            return self.get(i)
        except Exception as e:
            # Deliberate best effort: a broken study must not abort iteration;
            # callers are expected to handle the (None, None, None) placeholder.
            print(e)
            return None, None, None
コード例 #3
0
    def test_random_compose(self):
        """Check that ``Compose`` passes seeding and randomization on to its random transforms."""

        class _Acc(Randomizable):
            """Adds a freshly drawn random number to its input on every call."""

            # Must be a plain class attribute: inside a class body nested in a
            # method, ``self`` resolves to the enclosing test case, so
            # ``self.rand = 0.0`` would set the attribute on the TestCase.
            rand = 0.0

            def randomize(self, data=None):
                self.rand = self.R.rand()

            def __call__(self, data):
                self.randomize()
                return self.rand + data

        c = Compose([_Acc(), _Acc()])
        # without a fixed seed two runs are expected to differ
        self.assertNotAlmostEqual(c(0), c(0))
        c.set_random_state(123)
        self.assertAlmostEqual(c(1), 1.61381597)
        c.set_random_state(223)
        c.randomize()
        self.assertAlmostEqual(c(1), 1.90734751)
コード例 #4
0
    def test_random_compose(self):
        """Check that ``Compose`` passes seeding and randomization on to its random transforms."""

        class _Acc(Randomizable):
            """Adds a freshly drawn random number to its input on every call."""

            # Must be a plain class attribute: inside a class body nested in a
            # method, ``self`` resolves to the enclosing test case, so
            # ``self.rand = 0.0`` would set the attribute on the TestCase.
            rand = 0.0

            def randomize(self):
                self.rand = self.R.rand()

            def __call__(self, data):
                self.randomize()
                return self.rand + data

        c = Compose([_Acc(), _Acc()])
        # without a fixed seed two runs are expected to differ
        self.assertNotAlmostEqual(c(0), c(0))
        c.set_random_state(123)
        self.assertAlmostEqual(c(1), 2.39293837)
        c.set_random_state(223)
        c.randomize()
        self.assertAlmostEqual(c(1), 2.57673391)
コード例 #5
0
    def test_data_loader(self):
        """Check that seeded random transforms give fixed outputs through ``Dataset`` and ``DataLoader``.

        Covers direct dataset indexing, a loader without workers and, except
        on win32, loaders with one and two worker processes.
        """
        xform_1 = Compose([_RandXform()])
        train_ds = Dataset([1], transform=xform_1)

        xform_1.set_random_state(123)
        out_1 = train_ds[0]
        self.assertAlmostEqual(out_1, 0.2045649)

        set_determinism(seed=123)
        try:
            train_loader = DataLoader(train_ds, num_workers=0)
            out_1 = next(iter(train_loader))
            self.assertAlmostEqual(out_1.cpu().item(), 0.84291356)

            if sys.platform != "win32":  # skip multi-worker tests on win32
                train_loader = DataLoader(train_ds, num_workers=1)
                out_1 = next(iter(train_loader))
                self.assertAlmostEqual(out_1.cpu().item(), 0.180814653)

                train_loader = DataLoader(train_ds, num_workers=2)
                out_1 = next(iter(train_loader))
                self.assertAlmostEqual(out_1.cpu().item(), 0.04293707)
        finally:
            # always undo the global seeding, even when an assertion above
            # fails, so that it cannot leak into other tests
            set_determinism(None)
コード例 #6
0
def run_training_test(root_dir,
                      train_x,
                      train_y,
                      val_x,
                      val_y,
                      device="cuda:0",
                      num_workers=10):
    """Train a 2D DenseNet121 classifier for 4 epochs and validate it after every epoch.

    Args:
        root_dir: directory in which ``best_metric_model.pth`` is written.
        train_x: training inputs handed to ``MedNISTDataset``.
        train_y: training labels; the number of unique values sets the number
            of network outputs.
        val_x: validation inputs handed to ``MedNISTDataset``.
        val_y: validation labels.
        device: device the model and the batches are moved to.
        num_workers: worker count for both data loaders.

    Returns:
        Tuple ``(epoch_loss_values, best_metric, best_metric_epoch)``: the mean
        training loss of every epoch, the best validation AUC and the 1-based
        epoch in which it was reached.
    """

    monai.config.print_config()
    # define transforms for image and classification
    train_transforms = Compose([
        LoadPNG(image_only=True),
        AddChannel(),
        ScaleIntensity(),
        RandRotate(range_x=np.pi / 12, prob=0.5, keep_size=True),
        RandFlip(spatial_axis=0, prob=0.5),
        RandZoom(min_zoom=0.9, max_zoom=1.1, prob=0.5),
        ToTensor(),
    ])
    # fixed seed for the random augmentations
    train_transforms.set_random_state(1234)
    # validation uses the same loading/scaling steps without random augmentation
    val_transforms = Compose(
        [LoadPNG(image_only=True),
         AddChannel(),
         ScaleIntensity(),
         ToTensor()])

    # create train, val data loaders
    train_ds = MedNISTDataset(train_x, train_y, train_transforms)
    train_loader = DataLoader(train_ds,
                              batch_size=300,
                              shuffle=True,
                              num_workers=num_workers)

    val_ds = MedNISTDataset(val_x, val_y, val_transforms)
    val_loader = DataLoader(val_ds, batch_size=300, num_workers=num_workers)

    # one output per distinct training label
    model = densenet121(spatial_dims=2,
                        in_channels=1,
                        out_channels=len(np.unique(train_y))).to(device)
    loss_function = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), 1e-5)
    epoch_num = 4
    val_interval = 1

    # start training validation
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    metric_values = list()
    model_filename = os.path.join(root_dir, "best_metric_model.pth")
    for epoch in range(epoch_num):
        print("-" * 10)
        print(f"Epoch {epoch + 1}/{epoch_num}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data[0].to(device), batch_data[1].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        # mean loss over the batches of this epoch
        # NOTE(review): raises ZeroDivisionError if the training loader yields no batch
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss:{epoch_loss:0.4f}")

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                # accumulate the predictions and labels of the whole validation set
                y_pred = torch.tensor([], dtype=torch.float32, device=device)
                y = torch.tensor([], dtype=torch.long, device=device)
                for val_data in val_loader:
                    val_images, val_labels = val_data[0].to(
                        device), val_data[1].to(device)
                    y_pred = torch.cat([y_pred, model(val_images)], dim=0)
                    y = torch.cat([y, val_labels], dim=0)
                auc_metric = compute_roc_auc(y_pred,
                                             y,
                                             to_onehot_y=True,
                                             softmax=True)
                metric_values.append(auc_metric)
                # accuracy: fraction of samples whose arg-max prediction equals the label
                acc_value = torch.eq(y_pred.argmax(dim=1), y)
                acc_metric = acc_value.sum().item() / len(acc_value)
                # keep the weights of the epoch with the best AUC so far
                if auc_metric > best_metric:
                    best_metric = auc_metric
                    best_metric_epoch = epoch + 1
                    torch.save(model.state_dict(), model_filename)
                    print("saved new best metric model")
                print(
                    f"current epoch {epoch +1} current AUC: {auc_metric:0.4f} "
                    f"current accuracy: {acc_metric:0.4f} best AUC: {best_metric:0.4f} at epoch {best_metric_epoch}"
                )
    print(
        f"train completed, best_metric: {best_metric:0.4f}  at epoch: {best_metric_epoch}"
    )
    return epoch_loss_values, best_metric, best_metric_epoch
コード例 #7
0
def run_training_test(root_dir, device="cuda:0", cachedataset=0):
    """Train a 3D UNet for 6 epochs on ``img*.nii.gz``/``seg*.nii.gz`` pairs and validate every 2nd epoch.

    Args:
        root_dir: directory containing the image/segmentation files; the
            TensorBoard logs (``runs``) and ``best_metric_model.pth`` are
            written there as well.
        device: device the model and the batches are moved to.
        cachedataset: selects the training dataset class: ``2`` uses
            ``CacheDataset`` (``cache_rate=0.8``), ``3`` uses ``LMDBDataset``,
            any other value uses the plain ``Dataset``.

    Returns:
        Tuple ``(epoch_loss_values, best_metric, best_metric_epoch)``: the mean
        training loss of every epoch, the best validation mean Dice and the
        1-based epoch in which it was reached.
    """
    monai.config.print_config()
    images = sorted(glob(os.path.join(root_dir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(root_dir, "seg*.nii.gz")))
    # first 20 sorted pairs for training, last 20 for validation
    train_files = [{"img": img, "seg": seg} for img, seg in zip(images[:20], segs[:20])]
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images[-20:], segs[-20:])]

    # define transforms for image and segmentation
    train_transforms = Compose(
        [
            LoadImaged(keys=["img", "seg"]),
            AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
            # resampling with align_corners=True or dtype=float64 will generate
            # slightly different results between PyTorch 1.5 and 1.6
            Spacingd(keys=["img", "seg"], pixdim=[1.2, 0.8, 0.7], mode=["bilinear", "nearest"], dtype=np.float32),
            ScaleIntensityd(keys="img"),
            RandCropByPosNegLabeld(
                keys=["img", "seg"], label_key="seg", spatial_size=[96, 96, 96], pos=1, neg=1, num_samples=4
            ),
            RandRotate90d(keys=["img", "seg"], prob=0.8, spatial_axes=[0, 2]),
            ToTensord(keys=["img", "seg"]),
        ]
    )
    # fixed seed for the random crop/rotation transforms
    train_transforms.set_random_state(1234)
    val_transforms = Compose(
        [
            LoadImaged(keys=["img", "seg"]),
            AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
            # resampling with align_corners=True or dtype=float64 will generate
            # slightly different results between PyTorch 1.5 and 1.6
            Spacingd(keys=["img", "seg"], pixdim=[1.2, 0.8, 0.7], mode=["bilinear", "nearest"], dtype=np.float32),
            ScaleIntensityd(keys="img"),
            ToTensord(keys=["img", "seg"]),
        ]
    )

    # create a training data loader
    if cachedataset == 2:
        train_ds = monai.data.CacheDataset(data=train_files, transform=train_transforms, cache_rate=0.8)
    elif cachedataset == 3:
        train_ds = monai.data.LMDBDataset(data=train_files, transform=train_transforms)
    else:
        train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
    train_loader = monai.data.DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4)
    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)
    # post-processing applied to the network output before the Dice computation
    val_post_tran = Compose([Activations(sigmoid=True), AsDiscrete(threshold_values=True)])
    dice_metric = DiceMetric(include_background=True, reduction="mean")

    # create UNet, DiceLoss and Adam optimizer
    model = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss_function = monai.losses.DiceLoss(sigmoid=True)
    optimizer = torch.optim.Adam(model.parameters(), 5e-4)

    # start a typical PyTorch training
    val_interval = 2
    best_metric, best_metric_epoch = -1, -1
    epoch_loss_values = list()
    metric_values = list()
    writer = SummaryWriter(log_dir=os.path.join(root_dir, "runs"))
    model_filename = os.path.join(root_dir, "best_metric_model.pth")
    for epoch in range(6):
        print("-" * 10)
        print(f"Epoch {epoch + 1}/{6}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data["img"].to(device), batch_data["seg"].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            # nominal number of batches per epoch; used for the progress
            # message and as the stride of the TensorBoard global step
            epoch_len = len(train_ds) // train_loader.batch_size
            print(f"{step}/{epoch_len}, train_loss:{loss.item():0.4f}")
            writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
        # mean loss over the batches of this epoch
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch +1} average loss:{epoch_loss:0.4f}")

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                metric_sum = 0.0
                metric_count = 0
                # kept after the loop so the last validation case can be plotted
                val_images = None
                val_labels = None
                val_outputs = None
                for val_data in val_loader:
                    val_images, val_labels = val_data["img"].to(device), val_data["seg"].to(device)
                    sw_batch_size, roi_size = 4, (96, 96, 96)
                    val_outputs = val_post_tran(sliding_window_inference(val_images, roi_size, sw_batch_size, model))
                    value, not_nans = dice_metric(y_pred=val_outputs, y=val_labels)
                    # weight every batch value by its count of non-NaN entries
                    metric_count += not_nans.item()
                    metric_sum += value.item() * not_nans.item()
                # NOTE(review): raises ZeroDivisionError if metric_count stays 0
                metric = metric_sum / metric_count
                metric_values.append(metric)
                # keep the weights of the epoch with the best mean Dice so far
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(model.state_dict(), model_filename)
                    print("saved new best metric model")
                print(
                    f"current epoch {epoch +1} current mean dice: {metric:0.4f} "
                    f"best mean dice: {best_metric:0.4f} at epoch {best_metric_epoch}"
                )
                writer.add_scalar("val_mean_dice", metric, epoch + 1)
                # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag="image")
                plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="label")
                plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="output")
    print(f"train completed, best_metric: {best_metric:0.4f}  at epoch: {best_metric_epoch}")
    writer.close()
    return epoch_loss_values, best_metric, best_metric_epoch
コード例 #8
0
    'rv_lv_ratio_gte_1',  # exam level
    "central_pe",
    "leftsided_pe",
    "rightsided_pe",
    "acute_and_chronic_pe",
    "chronic_pe"
]
# One output per entry of target_cols.
out_dim = len(target_cols)
# Edge length of the cubic volume produced by the Resize step below.
image_size = 100

# Validation pipeline: intensity scaling, resize to an image_size cube, tensor conversion.
val_transforms = Compose(
    [ScaleIntensity(), Resize((image_size,) * 3), ToTensor()]
)
# Seed the pipeline's random state with a fixed value.
val_transforms.set_random_state(seed=42)


def monai_preprocess(imgs512):
    """Crop, sub-select and transform a 4D array into a batched CUDA tensor.

    Steps: slice ``43:-55`` on the last two axes, keep the middle half along
    axis 0, move axis 0 to the end, run ``val_transforms``, prepend a
    length-1 axis and move the result to the GPU.

    NOTE(review): the name suggests 512x512 slices stacked along axis 0 with
    channels on axis 1 - confirm against the caller.
    """
    cropped = imgs512[:, :, 43:-55, 43:-55]
    depth = cropped.shape[0]
    # middle 50% along the first axis
    middle = cropped[int(depth * 0.25):int(depth * 0.75)]
    first_axis_last = np.transpose(middle, (1, 2, 3, 0))
    transformed = apply_transform(val_transforms, first_axis_last)
    batched = np.expand_dims(transformed, axis=0)
    return torch.from_numpy(batched).cuda()


class MonaiModelTest():
    def __init__(self, monai_model_file):
        self.monai_model = monai.networks.nets.densenet.densenet121(
コード例 #9
0
ファイル: utils.py プロジェクト: ckbr0/RIS
def transform_and_copy(data, cahce_dir):
    """Write randomly augmented copies of all label-1 samples and return their records.

    The copies are written to ``<cahce_dir>/copied_images`` together with a
    ``copied_images.npy`` index. If that index already exists nothing is
    generated and the stored index is returned.

    Args:
        data: iterable of dicts with the keys ``'image'``, ``'seg'``,
            ``'label'`` and ``'_label'``; only entries with
            ``int(entry['_label']) == 1`` are copied.
        cahce_dir: cache directory (the parameter name is kept as is for
            backward compatibility with keyword callers).

    Returns:
        Array of dicts (loaded with ``allow_pickle=True``), one per written
        copy, with the keys ``'image'``, ``'seg'``, ``'label'``, ``'_label'``.
    """
    copy_dir = os.path.join(cahce_dir, 'copied_images')
    # makedirs(exist_ok=True) also creates a missing cache directory and does
    # not fail when the folder appears between a check and the creation.
    os.makedirs(copy_dir, exist_ok=True)
    copy_list_path = os.path.join(copy_dir, 'copied_images.npy')
    if not os.path.exists(copy_list_path):
        print("transforming and copying images...")
        imageLoader = LoadImage()
        to_copy_list = [x for x in data if int(x['_label']) == 1]
        # number of augmented copies written per source image
        mul = 1

        rand_x_flip = RandFlip(spatial_axis=0, prob=0.50)
        rand_y_flip = RandFlip(spatial_axis=1, prob=0.50)
        rand_z_flip = RandFlip(spatial_axis=2, prob=0.50)
        rand_affine = RandAffine(prob=1.0,
                                 rotate_range=(0, 0, np.pi / 10),
                                 shear_range=(0.12, 0.12, 0.0),
                                 translate_range=(0, 0, 0),
                                 scale_range=(0.12, 0.12, 0.0),
                                 padding_mode="zeros")
        rand_gaussian_noise = RandGaussianNoise(prob=0.5, mean=0.0, std=0.05)
        transform = Compose([
            AddChannel(),
            rand_x_flip,
            rand_y_flip,
            rand_z_flip,
            rand_affine,
            SqueezeDim(),
        ])
        copy_list = []
        n = len(to_copy_list)
        for position, to_copy in enumerate(to_copy_list, start=1):
            print(f'Copying image {position}/{n}', end="\r")
            image_file = to_copy['image']
            image_stem = os.path.basename(
                replace_suffix(image_file, '.nii.gz', ''))
            label = to_copy['label']
            _label = to_copy['_label']
            image_data, _ = imageLoader(image_file)
            seg_file = to_copy['seg']
            seg_data, _ = nrrd.read(seg_file)

            # A separate index for the copies: the outer loop variable must
            # not be reused here.
            for copy_idx in range(mul):
                # Re-seeding with the same value before each call makes image
                # and segmentation receive the same spatial transform; the
                # noise is applied to the image only.
                rand_seed = np.random.randint(1e8)
                transform.set_random_state(seed=rand_seed)
                new_image_data = rand_gaussian_noise(
                    np.array(transform(image_data)))
                transform.set_random_state(seed=rand_seed)
                new_seg_data = np.array(transform(seg_data))

                new_image_file = os.path.join(
                    copy_dir, image_stem + f'_{copy_idx}.nii.gz')
                write_nifti(new_image_data, new_image_file, resample=False)
                new_seg_file = os.path.join(
                    copy_dir, image_stem + f'_seg_{copy_idx}.nrrd')
                nrrd.write(new_seg_file, new_seg_data)
                copy_list.append({
                    'image': new_image_file,
                    'seg': new_seg_file,
                    'label': label,
                    '_label': _label
                })

        np.save(copy_list_path, copy_list)
        print("done transforming and copying!")

    # Always read the index back so both paths return the same type.
    copy_list = np.load(copy_list_path, allow_pickle=True)
    return copy_list
コード例 #10
0
ファイル: utils.py プロジェクト: ckbr0/RIS
def large_image_splitter(data, cache_dir, num_splits, only_label_one=False):
    """Add interleaved sub-volumes of large images to ``data``.

    Images with more than 200 slices along axis 2 are split into
    ``z_len // 80`` interleaved parts. Each part is randomly affine
    transformed (same seed for image and segmentation) and written to
    ``<cache_dir>/split_images``. The split index is stored in
    ``split_images.npy`` and reused on later calls.

    Args:
        data: list of dicts with the keys ``'image'``, ``'seg'``, ``'label'``
            and ``'_label'``.
        cache_dir: directory that holds the ``split_images`` folder.
        num_splits: maximum number of splits added per source image.
        only_label_one: if ``True``, entries whose ``'_label'`` equals 0 get
            no splits added.

    Returns:
        New list with every original entry, each followed by up to
        ``num_splits`` of its split records.
    """
    print("Splitting large images...")
    len_old = len(data)
    print("original data len:", len_old)
    split_images_dir = os.path.join(cache_dir, 'split_images')
    split_images = os.path.join(split_images_dir, 'split_images.npy')

    def _replace_in_data(split_records, num_splits):
        """Return ``data`` with the matching split records inserted after each entry."""
        # Index the records by source path once instead of scanning all of
        # them for every image; setdefault keeps the first record per source,
        # which is the one a linear search would have found.
        by_source = {}
        for record in split_records:
            by_source.setdefault(record['source'], record)

        new_images = []
        for image in data:
            new_images.append(image)
            if only_label_one is True and image['_label'] == 0:
                continue
            record = by_source.get(image['image'])
            if record is not None:
                # max(..., 0): a negative num_splits adds nothing
                new_images.extend(record["splits"][:max(num_splits, 0)])
        return new_images

    if os.path.exists(split_images):
        new_images = np.load(split_images, allow_pickle=True)
        out_data = _replace_in_data(new_images, num_splits)
    else:
        # also creates a missing cache_dir and tolerates an existing folder
        os.makedirs(split_images_dir, exist_ok=True)
        new_images = []
        imageLoader = LoadImage()
        for image in data:
            image_data, _ = imageLoader(image["image"])
            seg_data, _ = nrrd.read(image['seg'])
            z_len = image_data.shape[2]
            if z_len > 200:
                count = z_len // 80
                print("splitting image:",
                      image["image"],
                      f"into {count} parts",
                      "shape:",
                      image_data.shape,
                      end='\r')
                # interleaved split: part k takes slices k, k+count, k+2*count, ...
                split_image_list = [
                    image_data[:, :, idz::count] for idz in range(count)
                ]
                split_seg_list = [
                    seg_data[:, :, idz::count] for idz in range(count)
                ]
                # file name stems do not depend on the part index
                image_stem = os.path.basename(
                    replace_suffix(image["image"], '.nii.gz', ''))
                seg_stem = os.path.basename(
                    replace_suffix(image["seg"], '.nrrd', ''))
                new_image = {'source': image["image"], 'splits': []}
                for i in range(count):
                    image_file = os.path.join(split_images_dir,
                                              image_stem + f'_{i}.nii.gz')
                    seg_file = os.path.join(split_images_dir,
                                            seg_stem + f'_seg_{i}.nrrd')
                    split_image = np.array(split_image_list[i])
                    split_seg = np.array(split_seg_list[i], dtype=np.uint8)

                    rand_affine = RandAffine(prob=1.0,
                                             rotate_range=(0, 0, np.pi / 16),
                                             shear_range=(0.07, 0.07, 0.0),
                                             translate_range=(0, 0, 0),
                                             scale_range=(0.07, 0.07, 0.0),
                                             padding_mode="zeros")
                    transform = Compose([
                        AddChannel(),
                        rand_affine,
                        SqueezeDim(),
                    ])
                    # Re-seeding with the same value before each call gives
                    # image and segmentation the same random transform.
                    rand_seed = np.random.randint(1e8)
                    transform.set_random_state(seed=rand_seed)
                    split_image = transform(split_image).detach().cpu().numpy()
                    transform.set_random_state(seed=rand_seed)
                    split_seg = transform(split_seg).detach().cpu().numpy()

                    write_nifti(split_image, image_file, resample=False)
                    nrrd.write(seg_file, split_seg)
                    new_image['splits'].append({
                        'image': image_file,
                        'label': image['label'],
                        '_label': image['_label'],
                        'seg': seg_file,
                        'w': False
                    })
                new_images.append(new_image)
        np.save(split_images, new_images)
        out_data = _replace_in_data(new_images, num_splits)

    print("new data len:", len(out_data))
    return out_data
コード例 #11
0
def run_training_test(root_dir,
                      device=torch.device("cuda:0"),
                      cachedataset=False):
    """Train a 3D UNet for 6 epochs on ``img*.nii.gz``/``seg*.nii.gz`` pairs and validate every 2nd epoch.

    Args:
        root_dir: directory containing the image/segmentation files; the
            TensorBoard logs (``runs``) and ``best_metric_model.pth`` are
            written there as well.
        device: device the model and the batches are moved to.
        cachedataset: if truthy, train from a ``CacheDataset``
            (``cache_rate=0.8``) instead of a plain ``Dataset``.

    Returns:
        Tuple ``(epoch_loss_values, best_metric, best_metric_epoch)``: the mean
        training loss of every epoch, the best validation mean Dice and the
        1-based epoch in which it was reached.
    """
    monai.config.print_config()
    images = sorted(glob(os.path.join(root_dir, 'img*.nii.gz')))
    segs = sorted(glob(os.path.join(root_dir, 'seg*.nii.gz')))
    # first 20 sorted pairs for training, last 20 for validation
    train_files = [{
        'img': img,
        'seg': seg
    } for img, seg in zip(images[:20], segs[:20])]
    val_files = [{
        'img': img,
        'seg': seg
    } for img, seg in zip(images[-20:], segs[-20:])]

    # define transforms for image and segmentation
    # NOTE(review): intensity scaling is applied to 'seg' as well as 'img' - confirm this is intended
    train_transforms = Compose([
        LoadNiftid(keys=['img', 'seg']),
        AsChannelFirstd(keys=['img', 'seg'], channel_dim=-1),
        ScaleIntensityd(keys=['img', 'seg']),
        RandCropByPosNegLabeld(keys=['img', 'seg'],
                               label_key='seg',
                               size=[96, 96, 96],
                               pos=1,
                               neg=1,
                               num_samples=4),
        RandRotate90d(keys=['img', 'seg'], prob=0.8, spatial_axes=[0, 2]),
        ToTensord(keys=['img', 'seg'])
    ])
    # fixed seed for the random crop/rotation transforms
    train_transforms.set_random_state(1234)
    val_transforms = Compose([
        LoadNiftid(keys=['img', 'seg']),
        AsChannelFirstd(keys=['img', 'seg'], channel_dim=-1),
        ScaleIntensityd(keys=['img', 'seg']),
        ToTensord(keys=['img', 'seg'])
    ])

    # create a training data loader
    if cachedataset:
        train_ds = monai.data.CacheDataset(data=train_files,
                                           transform=train_transforms,
                                           cache_rate=0.8)
    else:
        train_ds = monai.data.Dataset(data=train_files,
                                      transform=train_transforms)
    # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
    train_loader = DataLoader(train_ds,
                              batch_size=2,
                              shuffle=True,
                              num_workers=4,
                              collate_fn=list_data_collate,
                              pin_memory=torch.cuda.is_available())
    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(val_ds,
                            batch_size=1,
                            num_workers=4,
                            collate_fn=list_data_collate,
                            pin_memory=torch.cuda.is_available())

    # create UNet, DiceLoss and Adam optimizer
    model = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss_function = monai.losses.DiceLoss(do_sigmoid=True)
    optimizer = torch.optim.Adam(model.parameters(), 5e-4)

    # start a typical PyTorch training
    val_interval = 2
    best_metric, best_metric_epoch = -1, -1
    epoch_loss_values = list()
    metric_values = list()
    writer = SummaryWriter(log_dir=os.path.join(root_dir, 'runs'))
    model_filename = os.path.join(root_dir, 'best_metric_model.pth')
    for epoch in range(6):
        print('-' * 10)
        print('Epoch {}/{}'.format(epoch + 1, 6))
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data['img'].to(
                device), batch_data['seg'].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            # nominal number of batches per epoch; used for the progress
            # message and as the stride of the TensorBoard global step
            epoch_len = len(train_ds) // train_loader.batch_size
            print("%d/%d, train_loss:%0.4f" % (step, epoch_len, loss.item()))
            writer.add_scalar('train_loss', loss.item(),
                              epoch_len * epoch + step)
        # mean loss over the batches of this epoch
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print("epoch %d average loss:%0.4f" % (epoch + 1, epoch_loss))

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                metric_sum = 0.
                metric_count = 0
                # kept after the loop so the last validation case can be plotted
                val_images = None
                val_labels = None
                val_outputs = None
                for val_data in val_loader:
                    val_images, val_labels = val_data['img'].to(
                        device), val_data['seg'].to(device)
                    sw_batch_size, roi_size = 4, (96, 96, 96)
                    val_outputs = sliding_window_inference(
                        val_images, roi_size, sw_batch_size, model)
                    value = compute_meandice(y_pred=val_outputs,
                                             y=val_labels,
                                             include_background=True,
                                             to_onehot_y=False,
                                             add_sigmoid=True)
                    # accumulate the sum and count of the returned Dice values
                    metric_count += len(value)
                    metric_sum += value.sum().item()
                # NOTE(review): raises ZeroDivisionError if the validation loader yields no batch
                metric = metric_sum / metric_count
                metric_values.append(metric)
                # keep the weights of the epoch with the best mean Dice so far
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(model.state_dict(), model_filename)
                    print('saved new best metric model')
                print(
                    "current epoch %d current mean dice: %0.4f best mean dice: %0.4f at epoch %d"
                    % (epoch + 1, metric, best_metric, best_metric_epoch))
                writer.add_scalar('val_mean_dice', metric, epoch + 1)
                # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                plot_2d_or_3d_image(val_images,
                                    epoch + 1,
                                    writer,
                                    index=0,
                                    tag='image')
                plot_2d_or_3d_image(val_labels,
                                    epoch + 1,
                                    writer,
                                    index=0,
                                    tag='label')
                plot_2d_or_3d_image(val_outputs,
                                    epoch + 1,
                                    writer,
                                    index=0,
                                    tag='output')
    print('train completed, best_metric: %0.4f  at epoch: %d' %
          (best_metric, best_metric_epoch))
    writer.close()
    return epoch_loss_values, best_metric, best_metric_epoch