def test_seg_no_basedir(self):
    """Absolute paths in the datalist must pass through unchanged when base_dir is None."""
    with tempfile.TemporaryDirectory() as tempdir:
        def _p(name):
            # shorthand: absolute path inside the temp dir
            return os.path.join(tempdir, name)

        test_data = {
            "name": "Spleen",
            "description": "Spleen Segmentation",
            "labels": {"0": "background", "1": "spleen"},
            "training": [
                {"image": _p("spleen_19.nii.gz"), "label": _p("spleen_19.nii.gz")},
                {"image": _p("spleen_31.nii.gz"), "label": _p("spleen_31.nii.gz")},
            ],
            "test": [_p("spleen_15.nii.gz"), _p("spleen_23.nii.gz")],
        }
        file_path = os.path.join(tempdir, "test_data.json")
        with open(file_path, "w") as json_file:
            json.dump(test_data, json_file)

        result = load_decathlon_datalist(file_path, True, "training", None)
        self.assertEqual(result[0]["image"], _p("spleen_19.nii.gz"))
        self.assertEqual(result[0]["label"], _p("spleen_19.nii.gz"))
        result = load_decathlon_datalist(file_path, True, "test", None)
        self.assertEqual(result[0]["image"], _p("spleen_15.nii.gz"))
# Beispiel #2
# 0
 def test_cls_values(self):
     """Classification datalist: relative image paths are joined to base_dir; integer labels stay as-is."""
     with tempfile.TemporaryDirectory() as tempdir:
         test_data = {
             "name": "ChestXRay",
             "description": "Chest X-ray classification",
             "labels": {"0": "background", "1": "chest"},
             "training": [
                 {"image": "chest_19.nii.gz", "label": 0},
                 {"image": "chest_31.nii.gz", "label": 1},
             ],
             "test": ["chest_15.nii.gz", "chest_23.nii.gz"],
         }
         file_path = os.path.join(tempdir, "test_data.json")
         with open(file_path, "w") as json_file:
             json.dump(test_data, json_file)

         # is_segmentation=False: non-path labels must not be treated as files
         result = load_decathlon_datalist(file_path, False, "training", tempdir)
         self.assertEqual(result[0]["image"], os.path.join(tempdir, "chest_19.nii.gz"))
         self.assertEqual(result[0]["label"], 0)
    def test_content(self):
        """create_cross_validation_datalist output should round-trip through load_decathlon_datalist."""
        with tempfile.TemporaryDirectory() as tempdir:
            datalist = []
            for idx in range(5):
                image_path = os.path.join(tempdir, f"test_image{idx}.nii.gz")
                label_path = os.path.join(tempdir, f"test_label{idx}.nii.gz")
                # check_missing=True below requires the files to actually exist
                Path(image_path).touch()
                Path(label_path).touch()
                datalist.append({"image": image_path, "label": label_path})

            filename = os.path.join(tempdir, "test_datalist.json")
            result = create_cross_validation_datalist(
                datalist=datalist,
                nfolds=5,
                train_folds=[0, 1, 2, 3],
                val_folds=4,
                train_key="test_train",
                val_key="test_val",
                filename=Path(filename),
                shuffle=True,
                seed=123,
                check_missing=True,
                keys=["image", "label"],
                root_dir=None,
                allow_missing_keys=False,
                raise_error=True,
            )

            # the JSON written to disk must match the in-memory result, entry by entry
            loaded = load_decathlon_datalist(filename, data_list_key="test_train")
            for expected, actual in zip(result["test_train"], loaded):
                self.assertEqual(expected["image"], actual["image"])
                self.assertEqual(expected["label"], actual["label"])
# Beispiel #4
# 0
 def _generate_data_list(self, dataset_dir: PathLike) -> List[Dict]:
     """Load the datalist for the current section from dataset.json and split it.

     "validation" shares the "training" entries of dataset.json; any other
     section falls back to "test".
     """
     # the types of the item in data list should be compatible with the dataloader
     dataset_dir = Path(dataset_dir)
     section = "training" if self.section in ("training", "validation") else "test"
     datalist = load_decathlon_datalist(dataset_dir / "dataset.json", True, section)
     return self._split_datalist(datalist)
def get_task_params(args):
    """
    This function is used to achieve the spacings of decathlon dataset.
    In addition, for CT images (task 03, 06, 07, 08, 09 and 10), this function
    also prints the mean and std values (used for normalization), and the min (0.5 percentile)
    and max(99.5 percentile) values (used for clip).

    """
    dataset_path = os.path.join(args.root_dir, task_name[args.task_id])
    datalist_file = os.path.join(args.datalist_path,
                                 f"dataset_task{args.task_id}.json")

    # all training pairs; image/label paths are resolved against the task folder
    datalist = load_decathlon_datalist(datalist_file, True, "training", dataset_path)

    # modality info decides whether intensity statistics are meaningful (CT only)
    properties = load_decathlon_properties(datalist_file, "modality")

    dataset = Dataset(
        data=datalist,
        transform=LoadImaged(keys=["image", "label"]),
    )
    calculator = DatasetSummary(dataset, num_workers=4)

    print("spacing: ", calculator.get_target_spacing())
    if properties["modality"]["0"] != "CT":
        print("non CT input, skip calculating.")
        return

    print("CT input, calculate statistics:")
    calculator.calculate_statistics()
    print("mean: ", calculator.data_mean, " std: ", calculator.data_std)
    calculator.calculate_percentiles(
        sampling_flag=True, interval=10, min_percentile=0.5, max_percentile=99.5
    )
    print(
        "min: ",
        calculator.data_min_percentile,
        " max: ",
        calculator.data_max_percentile,
    )
# Beispiel #6
# 0
 def _generate_data_list(self, dataset_dir: str) -> List[Dict]:
     """Load the section datalist and randomly split training/validation items.

     "training" keeps items whose random draw falls at or above ``val_frac``;
     "validation" keeps the complementary items; "test" is returned unsplit.
     """
     section = "training" if self.section in ("training", "validation") else "test"
     datalist = load_decathlon_datalist(
         os.path.join(dataset_dir, "dataset.json"), True, section)
     if section == "test":
         return datalist
     data = []
     for item in datalist:
         # one fresh random draw (self.rann) per item
         self.randomize()
         in_validation = self.rann < self.val_frac
         keep = (not in_validation) if self.section == "training" else in_validation
         if keep:
             data.append(item)
     return data
# Beispiel #7
# 0
    def test_additional_items(self):
        """Extra keys in entries: names of existing files get base_dir joined, other strings pass through."""
        with tempfile.TemporaryDirectory() as tempdir:
            # a real file on disk, so its relative name should be resolved to a full path
            with open(os.path.join(tempdir, "mask31.txt"), "w") as f:
                f.write("spleen31 mask")

            test_data = {
                "name": "Spleen",
                "description": "Spleen Segmentation",
                "labels": {"0": "background", "1": "spleen"},
                "training": [
                    {"image": "spleen_19.nii.gz", "label": "spleen_19.nii.gz", "mask": "spleen mask"},
                    {"image": "spleen_31.nii.gz", "label": "spleen_31.nii.gz", "mask": "mask31.txt"},
                ],
                "test": ["spleen_15.nii.gz", "spleen_23.nii.gz"],
            }
            file_path = os.path.join(tempdir, "test_data.json")
            with open(file_path, "w") as json_file:
                json.dump(test_data, json_file)

            result = load_decathlon_datalist(file_path, True, "training", Path(tempdir))
            self.assertEqual(result[0]["image"], os.path.join(tempdir, "spleen_19.nii.gz"))
            self.assertEqual(result[0]["label"], os.path.join(tempdir, "spleen_19.nii.gz"))
            self.assertEqual(result[1]["mask"], os.path.join(tempdir, "mask31.txt"))
            # "spleen mask" does not name an existing file, so it is left untouched
            self.assertEqual(result[0]["mask"], "spleen mask")
# Beispiel #8
# 0
 def _generate_data_list(self, dataset_dir: str) -> List[Dict]:
     """Read dataset.json for the current section and delegate splitting to ``_split_datalist``."""
     # "validation" items live under the "training" key of dataset.json
     section = "test" if self.section not in ("training", "validation") else "training"
     datalist = load_decathlon_datalist(
         os.path.join(dataset_dir, "dataset.json"), True, section)
     return self._split_datalist(datalist)
# Beispiel #9
# 0
def main():
    """Fine-tune a UNETR model for 3D segmentation, optionally from pretrained ViT weights.

    All paths below are placeholders that must be edited before running. The loop
    trains for ``max_iterations`` steps, validates with sliding-window inference
    every ``eval_num`` steps, keeps the best-mean-Dice checkpoint in ``logdir``,
    and writes loss/Dice progress plots there as well.
    """
    print_config()

    # Define paths for running the script
    data_dir = os.path.normpath('/to/be/defined')
    json_path = os.path.normpath('/to/be/defined')
    logdir = os.path.normpath('/to/be/defined')

    # If use_pretrained is set to 0, ViT weights will not be loaded and random initialization is used
    use_pretrained = 1
    pretrained_path = os.path.normpath('/to/be/defined')

    # Training Hyper-parameters
    lr = 1e-4
    max_iterations = 30000
    eval_num = 100

    # Create the log/checkpoint directory (including parents) if needed.
    # Replaces the race-prone `if os.path.exists(logdir) == False: os.mkdir(logdir)`,
    # which also failed when a parent directory was missing.
    os.makedirs(logdir, exist_ok=True)

    # Training & Validation Transform chain
    train_transforms = Compose([
        LoadImaged(keys=["image", "label"]),
        AddChanneld(keys=["image", "label"]),
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.5, 1.5, 2.0),
            mode=("bilinear", "nearest"),
        ),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        ScaleIntensityRanged(
            keys=["image"],
            a_min=-175,
            a_max=250,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        # sample 4 balanced pos/neg 96^3 patches per volume
        RandCropByPosNegLabeld(
            keys=["image", "label"],
            label_key="label",
            spatial_size=(96, 96, 96),
            pos=1,
            neg=1,
            num_samples=4,
            image_key="image",
            image_threshold=0,
        ),
        RandFlipd(
            keys=["image", "label"],
            spatial_axis=[0],
            prob=0.10,
        ),
        RandFlipd(
            keys=["image", "label"],
            spatial_axis=[1],
            prob=0.10,
        ),
        RandFlipd(
            keys=["image", "label"],
            spatial_axis=[2],
            prob=0.10,
        ),
        RandRotate90d(
            keys=["image", "label"],
            prob=0.10,
            max_k=3,
        ),
        RandShiftIntensityd(
            keys=["image"],
            offsets=0.10,
            prob=0.50,
        ),
        ToTensord(keys=["image", "label"]),
    ])
    # validation keeps whole volumes (no random cropping/augmentation)
    val_transforms = Compose([
        LoadImaged(keys=["image", "label"]),
        AddChanneld(keys=["image", "label"]),
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.5, 1.5, 2.0),
            mode=("bilinear", "nearest"),
        ),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        ScaleIntensityRanged(keys=["image"],
                             a_min=-175,
                             a_max=250,
                             b_min=0.0,
                             b_max=1.0,
                             clip=True),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        ToTensord(keys=["image", "label"]),
    ])

    datalist = load_decathlon_datalist(base_dir=data_dir,
                                       data_list_file_path=json_path,
                                       is_segmentation=True,
                                       data_list_key="training")

    val_files = load_decathlon_datalist(base_dir=data_dir,
                                        data_list_file_path=json_path,
                                        is_segmentation=True,
                                        data_list_key="validation")
    train_ds = CacheDataset(
        data=datalist,
        transform=train_transforms,
        cache_num=24,
        cache_rate=1.0,
        num_workers=4,
    )
    train_loader = DataLoader(train_ds,
                              batch_size=1,
                              shuffle=True,
                              num_workers=4,
                              pin_memory=True)
    val_ds = CacheDataset(data=val_files,
                          transform=val_transforms,
                          cache_num=6,
                          cache_rate=1.0,
                          num_workers=4)
    val_loader = DataLoader(val_ds,
                            batch_size=1,
                            shuffle=False,
                            num_workers=4,
                            pin_memory=True)

    # sanity check on one preprocessed validation sample
    case_num = 0
    img = val_ds[case_num]["image"]
    label = val_ds[case_num]["label"]
    img_shape = img.shape
    label_shape = label.shape
    print(f"image shape: {img_shape}, label shape: {label_shape}")

    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = UNETR(
        in_channels=1,
        out_channels=14,
        img_size=(96, 96, 96),
        feature_size=16,
        hidden_size=768,
        mlp_dim=3072,
        num_heads=12,
        pos_embed="conv",
        norm_name="instance",
        res_block=True,
        dropout_rate=0.0,
    )

    # Load ViT backbone weights into UNETR
    if use_pretrained == 1:
        print('Loading Weights from the Path {}'.format(pretrained_path))
        vit_dict = torch.load(pretrained_path)
        vit_weights = vit_dict['state_dict']

        #  Delete the following variable names conv3d_transpose.weight, conv3d_transpose.bias,
        #  conv3d_transpose_1.weight, conv3d_transpose_1.bias as they were used to match dimensions
        #  while pretraining with ViTAutoEnc and are not a part of ViT backbone (this is used in UNETR)
        vit_weights.pop('conv3d_transpose_1.bias')
        vit_weights.pop('conv3d_transpose_1.weight')
        vit_weights.pop('conv3d_transpose.bias')
        vit_weights.pop('conv3d_transpose.weight')

        model.vit.load_state_dict(vit_weights)
        print('Pretrained Weights Succesfully Loaded !')

    elif use_pretrained == 0:
        print(
            'No weights were loaded, all weights being used are randomly initialized!'
        )

    model.to(device)

    loss_function = DiceCELoss(to_onehot_y=True, softmax=True)
    torch.backends.cudnn.benchmark = True
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)

    # one-hot label / argmax+one-hot prediction for the Dice metric
    post_label = AsDiscrete(to_onehot=14)
    post_pred = AsDiscrete(argmax=True, to_onehot=14)
    dice_metric = DiceMetric(include_background=True,
                             reduction="mean",
                             get_not_nans=False)
    global_step = 0
    dice_val_best = 0.0
    global_step_best = 0
    epoch_loss_values = []
    metric_values = []

    def validation(epoch_iterator_val):
        """Run sliding-window inference over the validation loader; return mean Dice."""
        model.eval()
        dice_vals = list()

        with torch.no_grad():
            for step, batch in enumerate(epoch_iterator_val):
                val_inputs, val_labels = (batch["image"].cuda(),
                                          batch["label"].cuda())
                val_outputs = sliding_window_inference(val_inputs,
                                                       (96, 96, 96), 4, model)
                val_labels_list = decollate_batch(val_labels)
                val_labels_convert = [
                    post_label(val_label_tensor)
                    for val_label_tensor in val_labels_list
                ]
                val_outputs_list = decollate_batch(val_outputs)
                val_output_convert = [
                    post_pred(val_pred_tensor)
                    for val_pred_tensor in val_outputs_list
                ]
                dice_metric(y_pred=val_output_convert, y=val_labels_convert)
                dice = dice_metric.aggregate().item()
                dice_vals.append(dice)
                epoch_iterator_val.set_description(
                    "Validate (%d / %d Steps) (dice=%2.5f)" %
                    (global_step, 10.0, dice))

            # clear the metric buffer so the next validation starts fresh
            dice_metric.reset()

        mean_dice_val = np.mean(dice_vals)
        return mean_dice_val

    def train(global_step, train_loader, dice_val_best, global_step_best):
        """Run one pass over train_loader; validate/checkpoint every eval_num steps."""
        model.train()
        epoch_loss = 0
        step = 0
        epoch_iterator = tqdm(train_loader,
                              desc="Training (X / X Steps) (loss=X.X)",
                              dynamic_ncols=True)
        for step, batch in enumerate(epoch_iterator):
            step += 1
            x, y = (batch["image"].cuda(), batch["label"].cuda())
            logit_map = model(x)
            loss = loss_function(logit_map, y)
            loss.backward()
            epoch_loss += loss.item()
            optimizer.step()
            optimizer.zero_grad()
            epoch_iterator.set_description(
                "Training (%d / %d Steps) (loss=%2.5f)" %
                (global_step, max_iterations, loss))

            if (global_step % eval_num == 0
                    and global_step != 0) or global_step == max_iterations:
                epoch_iterator_val = tqdm(
                    val_loader,
                    desc="Validate (X / X Steps) (dice=X.X)",
                    dynamic_ncols=True)
                dice_val = validation(epoch_iterator_val)

                # NOTE: running average of the loss over the steps seen so far
                epoch_loss /= step
                epoch_loss_values.append(epoch_loss)
                metric_values.append(dice_val)
                if dice_val > dice_val_best:
                    dice_val_best = dice_val
                    global_step_best = global_step
                    torch.save(model.state_dict(),
                               os.path.join(logdir, "best_metric_model.pth"))
                    print(
                        "Model Was Saved ! Current Best Avg. Dice: {} Current Avg. Dice: {}"
                        .format(dice_val_best, dice_val))
                else:
                    print(
                        "Model Was Not Saved ! Current Best Avg. Dice: {} Current Avg. Dice: {}"
                        .format(dice_val_best, dice_val))

                # save a progress snapshot (loss + val Dice curves)
                plt.figure(1, (12, 6))
                plt.subplot(1, 2, 1)
                plt.title("Iteration Average Loss")
                x = [eval_num * (i + 1) for i in range(len(epoch_loss_values))]
                y = epoch_loss_values
                plt.xlabel("Iteration")
                plt.plot(x, y)
                plt.grid()
                plt.subplot(1, 2, 2)
                plt.title("Val Mean Dice")
                x = [eval_num * (i + 1) for i in range(len(metric_values))]
                y = metric_values
                plt.xlabel("Iteration")
                plt.plot(x, y)
                plt.grid()
                plt.savefig(
                    os.path.join(logdir, 'btcv_finetune_quick_update.png'))
                plt.clf()
                plt.close(1)

            global_step += 1
        return global_step, dice_val_best, global_step_best

    while global_step < max_iterations:
        global_step, dice_val_best, global_step_best = train(
            global_step, train_loader, dice_val_best, global_step_best)
    # restore the best checkpoint before reporting
    model.load_state_dict(
        torch.load(os.path.join(logdir, "best_metric_model.pth")))

    print(f"train completed, best_metric: {dice_val_best:.4f} "
          f"at iteration: {global_step_best}")
# Beispiel #10
# 0
    def configure(self):
        """Assemble network, data pipelines, metric, handlers, and the train/eval engines.

        Populates ``self.eval_engine`` and ``self.train_engine``; reads config from
        attributes such as ``self.multi_gpu``, ``self.data_list_file_path``,
        ``self.ckpt_dir``, ``self.learning_rate``, ``self.max_epochs``,
        ``self.val_interval``, ``self.amp`` and ``self.local_rank``.
        """
        self.set_device()
        network = UNet(
            dimensions=3,
            in_channels=1,
            out_channels=2,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
            norm=Norm.BATCH,
        ).to(self.device)
        if self.multi_gpu:
            # wrap for distributed training: one device per process
            network = DistributedDataParallel(
                module=network,
                device_ids=[self.device],
                find_unused_parameters=False,
            )

        # training pipeline: load, resample, normalize, crop foreground,
        # sample balanced 96^3 patches, light intensity augmentation
        train_transforms = Compose([
            LoadImaged(keys=("image", "label")),
            EnsureChannelFirstd(keys=("image", "label")),
            Spacingd(keys=("image", "label"),
                     pixdim=[1.0, 1.0, 1.0],
                     mode=["bilinear", "nearest"]),
            ScaleIntensityRanged(
                keys="image",
                a_min=-57,
                a_max=164,
                b_min=0.0,
                b_max=1.0,
                clip=True,
            ),
            CropForegroundd(keys=("image", "label"), source_key="image"),
            RandCropByPosNegLabeld(
                keys=("image", "label"),
                label_key="label",
                spatial_size=(96, 96, 96),
                pos=1,
                neg=1,
                num_samples=4,
                image_key="image",
                image_threshold=0,
            ),
            RandShiftIntensityd(keys="image", offsets=0.1, prob=0.5),
            ToTensord(keys=("image", "label")),
        ])
        train_datalist = load_decathlon_datalist(self.data_list_file_path,
                                                 True, "training")
        if self.multi_gpu:
            # shard the datalist so each rank caches and trains on a disjoint subset
            train_datalist = partition_dataset(
                data=train_datalist,
                shuffle=True,
                num_partitions=dist.get_world_size(),
                even_divisible=True,
            )[dist.get_rank()]
        train_ds = CacheDataset(
            data=train_datalist,
            transform=train_transforms,
            cache_num=32,
            cache_rate=1.0,
            num_workers=4,
        )
        train_data_loader = DataLoader(
            train_ds,
            batch_size=2,
            shuffle=True,
            num_workers=4,
        )
        # validation pipeline: whole volumes, no resampling or augmentation;
        # inference is done via SlidingWindowInferer below
        val_transforms = Compose([
            LoadImaged(keys=("image", "label")),
            EnsureChannelFirstd(keys=("image", "label")),
            ScaleIntensityRanged(
                keys="image",
                a_min=-57,
                a_max=164,
                b_min=0.0,
                b_max=1.0,
                clip=True,
            ),
            CropForegroundd(keys=("image", "label"), source_key="image"),
            ToTensord(keys=("image", "label")),
        ])

        val_datalist = load_decathlon_datalist(self.data_list_file_path, True,
                                               "validation")
        # positional args: data, transform, cache_num=9, cache_rate=0.0, num_workers=4
        val_ds = CacheDataset(val_datalist, val_transforms, 9, 0.0, 4)
        val_data_loader = DataLoader(
            val_ds,
            batch_size=1,
            shuffle=False,
            num_workers=4,
        )
        # softmax probs -> argmax pred, plus one-hot of both pred and label for MeanDice
        post_transform = Compose([
            Activationsd(keys="pred", softmax=True),
            AsDiscreted(
                keys=["pred", "label"],
                argmax=[True, False],
                to_onehot=True,
                n_classes=2,
            ),
        ])
        # metric
        key_val_metric = {
            "val_mean_dice":
            MeanDice(
                include_background=False,
                output_transform=lambda x: (x["pred"], x["label"]),
                device=self.device,
            )
        }
        val_handlers = [
            StatsHandler(output_transform=lambda x: None),
            # checkpoints the model whenever the key metric improves
            CheckpointSaver(
                save_dir=self.ckpt_dir,
                save_dict={"model": network},
                save_key_metric=True,
            ),
            TensorBoardStatsHandler(log_dir=self.ckpt_dir,
                                    output_transform=lambda x: None),
        ]
        self.eval_engine = SupervisedEvaluator(
            device=self.device,
            val_data_loader=val_data_loader,
            network=network,
            inferer=SlidingWindowInferer(
                roi_size=[160, 160, 160],
                sw_batch_size=4,
                overlap=0.5,
            ),
            post_transform=post_transform,
            key_val_metric=key_val_metric,
            val_handlers=val_handlers,
            amp=self.amp,
        )

        optimizer = torch.optim.Adam(network.parameters(), self.learning_rate)
        loss_function = DiceLoss(to_onehot_y=True, softmax=True)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                       step_size=5000,
                                                       gamma=0.1)
        train_handlers = [
            LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True),
            # runs self.eval_engine every val_interval epochs during training
            ValidationHandler(validator=self.eval_engine,
                              interval=self.val_interval,
                              epoch_level=True),
            StatsHandler(tag_name="train_loss",
                         output_transform=lambda x: x["loss"]),
            TensorBoardStatsHandler(
                log_dir=self.ckpt_dir,
                tag_name="train_loss",
                output_transform=lambda x: x["loss"],
            ),
        ]

        self.train_engine = SupervisedTrainer(
            device=self.device,
            max_epochs=self.max_epochs,
            train_data_loader=train_data_loader,
            network=network,
            optimizer=optimizer,
            loss_function=loss_function,
            inferer=SimpleInferer(),
            post_transform=post_transform,
            key_train_metric=None,
            train_handlers=train_handlers,
            amp=self.amp,
        )

        # quiet non-rank-0 workers to avoid duplicated log output in multi-GPU runs
        if self.local_rank > 0:
            self.train_engine.logger.setLevel(logging.WARNING)
            self.eval_engine.logger.setLevel(logging.WARNING)
def main_worker(gpu, args):
    """Per-process entry point for MIL whole-slide-image classification training.

    Args:
        gpu: local GPU index for this worker process.
        args: parsed command-line namespace (distributed settings, paths,
            batch size, tile parameters, optimizer settings, etc.).

    Either validates an existing checkpoint (``args.validate``) and exits, or
    runs the full training loop with periodic validation, TensorBoard logging
    and best-model checkpointing (rank 0 only).
    """

    args.gpu = gpu

    if args.distributed:
        # global rank = node rank * gpus-per-node + local gpu index
        args.rank = args.rank * torch.cuda.device_count() + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)

    print(args.rank, " gpu", args.gpu)

    torch.cuda.set_device(
        args.gpu
    )  # use this default device (same as args.device if not distributed)
    torch.backends.cudnn.benchmark = True

    if args.rank == 0:
        print("Batch size is:", args.batch_size, "epochs", args.epochs)

    #############
    # Create MONAI dataset
    training_list = load_decathlon_datalist(
        data_list_file_path=args.dataset_json,
        data_list_key="training",
        base_dir=args.data_root,
    )
    validation_list = load_decathlon_datalist(
        data_list_file_path=args.dataset_json,
        data_list_key="validation",
        base_dir=args.data_root,
    )

    if args.quick:  # for debugging on a small subset
        training_list = training_list[:16]
        validation_list = validation_list[:16]

    # training: random tile sampling + flip/rotate augmentation
    train_transform = Compose([
        LoadImageD(keys=["image"],
                   reader=WSIReader,
                   backend="TiffFile",
                   dtype=np.uint8,
                   level=1,
                   image_only=True),
        LabelEncodeIntegerGraded(keys=["label"], num_classes=args.num_classes),
        TileOnGridd(
            keys=["image"],
            tile_count=args.tile_count,
            tile_size=args.tile_size,
            random_offset=True,
            background_val=255,
            return_list_of_dicts=True,
        ),
        RandFlipd(keys=["image"], spatial_axis=0, prob=0.5),
        RandFlipd(keys=["image"], spatial_axis=1, prob=0.5),
        RandRotate90d(keys=["image"], prob=0.5),
        # a_min=255, a_max=0 inverts intensities (white background -> low values);
        # NOTE(review): relies on ScaleIntensityRange's default b_min/b_max — confirm
        ScaleIntensityRangeD(keys=["image"],
                             a_min=np.float32(255),
                             a_max=np.float32(0)),
        ToTensord(keys=["image", "label"]),
    ])

    # validation: deterministic tiling (all tiles, no random offset), no augmentation
    valid_transform = Compose([
        LoadImageD(keys=["image"],
                   reader=WSIReader,
                   backend="TiffFile",
                   dtype=np.uint8,
                   level=1,
                   image_only=True),
        LabelEncodeIntegerGraded(keys=["label"], num_classes=args.num_classes),
        TileOnGridd(
            keys=["image"],
            tile_count=None,
            tile_size=args.tile_size,
            random_offset=False,
            background_val=255,
            return_list_of_dicts=True,
        ),
        ScaleIntensityRangeD(keys=["image"],
                             a_min=np.float32(255),
                             a_max=np.float32(0)),
        ToTensord(keys=["image", "label"]),
    ])

    dataset_train = Dataset(data=training_list, transform=train_transform)
    dataset_valid = Dataset(data=validation_list, transform=valid_transform)

    train_sampler = DistributedSampler(
        dataset_train) if args.distributed else None
    val_sampler = DistributedSampler(
        dataset_valid, shuffle=False) if args.distributed else None

    train_loader = torch.utils.data.DataLoader(
        dataset_train,
        batch_size=args.batch_size,
        shuffle=(train_sampler is None),
        num_workers=args.workers,
        pin_memory=True,
        multiprocessing_context="spawn",
        sampler=train_sampler,
        collate_fn=list_data_collate,
    )
    valid_loader = torch.utils.data.DataLoader(
        dataset_valid,
        batch_size=1,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True,
        multiprocessing_context="spawn",
        sampler=val_sampler,
        collate_fn=list_data_collate,
    )

    if args.rank == 0:
        print("Dataset training:", len(dataset_train), "validation:",
              len(dataset_valid))

    model = milmodel.MILModel(num_classes=args.num_classes,
                              pretrained=True,
                              mil_mode=args.mil_mode)

    # optionally resume from a checkpoint (weights + epoch + best accuracy)
    best_acc = 0
    start_epoch = 0
    if args.checkpoint is not None:
        checkpoint = torch.load(args.checkpoint, map_location="cpu")
        model.load_state_dict(checkpoint["state_dict"])
        if "epoch" in checkpoint:
            start_epoch = checkpoint["epoch"]
        if "best_acc" in checkpoint:
            best_acc = checkpoint["best_acc"]
        print("=> loaded checkpoint '{}' (epoch {}) (bestacc {})".format(
            args.checkpoint, start_epoch, best_acc))

    model.cuda(args.gpu)

    if args.distributed:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu], output_device=args.gpu)

    if args.validate:
        # if we only want to validate existing checkpoint
        epoch_time = time.time()
        val_loss, val_acc, qwk = val_epoch(model,
                                           valid_loader,
                                           epoch=0,
                                           args=args,
                                           max_tiles=args.tile_count)
        if args.rank == 0:
            print(
                "Final validation loss: {:.4f}".format(val_loss),
                "acc: {:.4f}".format(val_acc),
                "qwk: {:.4f}".format(qwk),
                "time {:.2f}s".format(time.time() - epoch_time),
            )

        exit(0)

    params = model.parameters()

    # for transformer-based MIL modes, give the transformer its own
    # (smaller) learning rate and stronger weight decay
    if args.mil_mode in ["att_trans", "att_trans_pyramid"]:
        m = model if not args.distributed else model.module
        params = [
            {
                "params":
                list(m.attention.parameters()) + list(m.myfc.parameters()) +
                list(m.net.parameters())
            },
            {
                "params": list(m.transformer.parameters()),
                "lr": 6e-6,
                "weight_decay": 0.1
            },
        ]

    optimizer = torch.optim.AdamW(params,
                                  lr=args.optim_lr,
                                  weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                           T_max=args.epochs,
                                                           eta_min=0)

    # TensorBoard writer only on rank 0
    if args.logdir is not None and args.rank == 0:
        writer = SummaryWriter(log_dir=args.logdir)
        if args.rank == 0:
            print("Writing Tensorboard logs to ", writer.log_dir)
    else:
        writer = None

    ###RUN TRAINING
    n_epochs = args.epochs
    val_acc_max = 0.0

    scaler = None
    if args.amp:  # new native amp
        scaler = GradScaler()

    for epoch in range(start_epoch, n_epochs):

        if args.distributed:
            # reshuffle shards each epoch and keep ranks in lockstep
            train_sampler.set_epoch(epoch)
            torch.distributed.barrier()

        print(args.rank, time.ctime(), "Epoch:", epoch)

        epoch_time = time.time()
        train_loss, train_acc = train_epoch(model,
                                            train_loader,
                                            optimizer,
                                            scaler=scaler,
                                            epoch=epoch,
                                            args=args)

        if args.rank == 0:
            print(
                "Final training  {}/{}".format(epoch, n_epochs - 1),
                "loss: {:.4f}".format(train_loss),
                "acc: {:.4f}".format(train_acc),
                "time {:.2f}s".format(time.time() - epoch_time),
            )

        if args.rank == 0 and writer is not None:
            writer.add_scalar("train_loss", train_loss, epoch)
            writer.add_scalar("train_acc", train_acc, epoch)

        if args.distributed:
            torch.distributed.barrier()

        b_new_best = False
        val_acc = 0
        if (epoch + 1) % args.val_every == 0:

            epoch_time = time.time()
            val_loss, val_acc, qwk = val_epoch(model,
                                               valid_loader,
                                               epoch=epoch,
                                               args=args,
                                               max_tiles=args.tile_count)
            if args.rank == 0:
                print(
                    "Final validation  {}/{}".format(epoch, n_epochs - 1),
                    "loss: {:.4f}".format(val_loss),
                    "acc: {:.4f}".format(val_acc),
                    "qwk: {:.4f}".format(qwk),
                    "time {:.2f}s".format(time.time() - epoch_time),
                )
                if writer is not None:
                    writer.add_scalar("val_loss", val_loss, epoch)
                    writer.add_scalar("val_acc", val_acc, epoch)
                    writer.add_scalar("val_qwk", qwk, epoch)

                # quadratic weighted kappa is the model-selection metric
                val_acc = qwk

                if val_acc > val_acc_max:
                    print("qwk ({:.6f} --> {:.6f})".format(
                        val_acc_max, val_acc))
                    val_acc_max = val_acc
                    b_new_best = True

        # rank 0 always saves the latest model; best model is copied aside
        if args.rank == 0 and args.logdir is not None:
            save_checkpoint(model,
                            epoch,
                            args,
                            best_acc=val_acc,
                            filename="model_final.pt")
            if b_new_best:
                print("Copying to model.pt new best model!!!!")
                shutil.copyfile(os.path.join(args.logdir, "model_final.pt"),
                                os.path.join(args.logdir, "model.pt"))

        scheduler.step()

    print("ALL DONE")
# Beispiel #12
# 0
def main(cfg):
    """Run end-to-end training/validation of a WSI patch classifier.

    Sets up logging and TensorBoard, builds cuCIM- or NumPy-backed
    preprocessing pipelines, creates MONAI ``PatchWSIDataset`` loaders from a
    Decathlon-style datalist, trains a convolutionalized ResNet-18 with BCE
    loss, and writes per-epoch checkpoints plus best/final copies.

    Args:
        cfg: configuration dict; read keys include ``backend`` ("cucim" or
            "numpy"), ``dataset_json``, ``data_root``, ``n_epochs``, ``lr``,
            ``batch_size``, ``save``, ``validate``, etc.  Mutated in place:
            ``cfg["amp"]`` is forced to ``False`` on torch < 1.6.

    Raises:
        ValueError: if ``cfg["backend"]`` is neither "numpy" nor "cucim",
            or if the training dataloader yields no batch.
    """
    # -------------------------------------------------------------------------
    # Configs
    # -------------------------------------------------------------------------
    # Create log/model dir
    log_dir = create_log_dir(cfg)

    # Set the logger: console via basicConfig plus a file handler in log_dir
    logging.basicConfig(
        format="%(asctime)s %(levelname)2s %(message)s",
        level=logging.INFO,
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    log_name = os.path.join(log_dir, "logs.txt")
    logger = logging.getLogger()
    fh = logging.FileHandler(log_name)
    fh.setLevel(logging.INFO)
    logger.addHandler(fh)

    # Set TensorBoard summary writer
    writer = SummaryWriter(log_dir)

    # Save configs alongside the logs for reproducibility
    logging.info(json.dumps(cfg))
    with open(os.path.join(log_dir, "config.json"), "w") as fp:
        json.dump(cfg, fp, indent=4)

    # Set device cuda/cpu
    device = set_device(cfg)

    # Set cudnn benchmark/deterministic (mutually exclusive choices)
    if cfg["benchmark"]:
        torch.backends.cudnn.benchmark = True
    else:
        set_determinism(seed=0)
    # -------------------------------------------------------------------------
    # Transforms and Datasets
    # -------------------------------------------------------------------------
    # Pre-processing: CPU transforms run inside the DataLoader workers,
    # GPU transforms are applied per-batch in the train/valid loops.
    preprocess_cpu_train = None
    preprocess_gpu_train = None
    preprocess_cpu_valid = None
    preprocess_gpu_valid = None
    if cfg["backend"] == "cucim":
        preprocess_cpu_train = Compose([ToTensorD(keys="label")])
        preprocess_gpu_train = Compose([
            Range()(ToCupy()),
            Range("ColorJitter")(RandCuCIM(name="color_jitter",
                                           brightness=64.0 / 255.0,
                                           contrast=0.75,
                                           saturation=0.25,
                                           hue=0.04)),
            Range("RandomFlip")(RandCuCIM(name="image_flip",
                                          apply_prob=cfg["prob"],
                                          spatial_axis=-1)),
            Range("RandomRotate90")(RandCuCIM(name="rand_image_rotate_90",
                                              prob=cfg["prob"],
                                              max_k=3,
                                              spatial_axis=(-2, -1))),
            Range()(CastToType(dtype=np.float32)),
            Range("RandomZoom")(RandCuCIM(name="rand_zoom",
                                          min_zoom=0.9,
                                          max_zoom=1.1)),
            Range("ScaleIntensity")(CuCIM(name="scale_intensity_range",
                                          a_min=0.0,
                                          a_max=255.0,
                                          b_min=-1.0,
                                          b_max=1.0)),
            Range()(ToTensor(device=device)),
        ])
        preprocess_cpu_valid = Compose([ToTensorD(keys="label")])
        preprocess_gpu_valid = Compose([
            Range("ValidToCupyAndCast")(ToCupy(dtype=np.float32)),
            Range("ValidScaleIntensity")(CuCIM(name="scale_intensity_range",
                                               a_min=0.0,
                                               a_max=255.0,
                                               b_min=-1.0,
                                               b_max=1.0)),
            Range("ValidToTensor")(ToTensor(device=device)),
        ])
    elif cfg["backend"] == "numpy":
        preprocess_cpu_train = Compose([
            Range()(ToTensorD(keys=("image", "label"))),
            Range("ColorJitter")(TorchVisionD(
                keys="image",
                name="ColorJitter",
                brightness=64.0 / 255.0,
                contrast=0.75,
                saturation=0.25,
                hue=0.04,
            )),
            Range()(ToNumpyD(keys="image")),
            Range("RandomFlip")(RandFlipD(keys="image",
                                          prob=cfg["prob"],
                                          spatial_axis=-1)),
            Range("RandomRotate90")(RandRotate90D(keys="image",
                                                  prob=cfg["prob"])),
            Range()(CastToTypeD(keys="image", dtype=np.float32)),
            Range("RandomZoom")(RandZoomD(keys="image",
                                          prob=cfg["prob"],
                                          min_zoom=0.9,
                                          max_zoom=1.1)),
            Range("ScaleIntensity")(ScaleIntensityRangeD(keys="image",
                                                         a_min=0.0,
                                                         a_max=255.0,
                                                         b_min=-1.0,
                                                         b_max=1.0)),
            Range()(ToTensorD(keys="image")),
        ])
        preprocess_cpu_valid = Compose([
            Range("ValidCastType")(CastToTypeD(keys="image",
                                               dtype=np.float32)),
            Range("ValidScaleIntensity")(ScaleIntensityRangeD(keys="image",
                                                              a_min=0.0,
                                                              a_max=255.0,
                                                              b_min=-1.0,
                                                              b_max=1.0)),
            Range("ValidToTensor")(ToTensorD(keys=("image", "label"))),
        ])
    else:
        raise ValueError(
            f"Backend should be either numpy or cucim! ['{cfg['backend']}' is provided.]"
        )

    # Post-processing: sigmoid + 0.5 threshold for binary predictions
    postprocess = Compose([
        Activations(sigmoid=True),
        AsDiscrete(threshold=0.5),
    ])

    # Create MONAI dataset
    train_json_info_list = load_decathlon_datalist(
        data_list_file_path=cfg["dataset_json"],
        data_list_key="training",
        base_dir=cfg["data_root"],
    )
    valid_json_info_list = load_decathlon_datalist(
        data_list_file_path=cfg["dataset_json"],
        data_list_key="validation",
        base_dir=cfg["data_root"],
    )
    train_dataset = PatchWSIDataset(
        data=train_json_info_list,
        region_size=cfg["region_size"],
        grid_shape=cfg["grid_shape"],
        patch_size=cfg["patch_size"],
        transform=preprocess_cpu_train,
        image_reader_name="openslide" if cfg["use_openslide"] else "cuCIM",
    )
    valid_dataset = PatchWSIDataset(
        data=valid_json_info_list,
        region_size=cfg["region_size"],
        grid_shape=cfg["grid_shape"],
        patch_size=cfg["patch_size"],
        transform=preprocess_cpu_valid,
        image_reader_name="openslide" if cfg["use_openslide"] else "cuCIM",
    )

    # DataLoaders
    train_dataloader = DataLoader(train_dataset,
                                  num_workers=cfg["num_workers"],
                                  batch_size=cfg["batch_size"],
                                  pin_memory=cfg["pin"])
    valid_dataloader = DataLoader(valid_dataset,
                                  num_workers=cfg["num_workers"],
                                  batch_size=cfg["batch_size"],
                                  pin_memory=cfg["pin"])

    # Get sample batch and some info
    first_sample = first(train_dataloader)
    if first_sample is None:
        raise ValueError("First sample is None!")
    for d in ["image", "label"]:
        logging.info(f"[{d}] \n"
                     f"  {d} shape: {first_sample[d].shape}\n"
                     f"  {d} type:  {type(first_sample[d])}\n"
                     f"  {d} dtype: {first_sample[d].dtype}")
    logging.info(f"Batch size: {cfg['batch_size']}")
    logging.info(f"[Training] number of batches: {len(train_dataloader)}")
    logging.info(f"[Validation] number of batches: {len(valid_dataloader)}")
    # -------------------------------------------------------------------------
    # Deep Learning Model and Configurations
    # -------------------------------------------------------------------------
    # Initialize model: ResNet-18 with a 1-channel convolutional head
    model = TorchVisionFCModel("resnet18",
                               n_classes=1,
                               use_conv=True,
                               pretrained=cfg["pretrain"])
    model = model.to(device)

    # Loss function
    loss_func = torch.nn.BCEWithLogitsLoss()
    loss_func = loss_func.to(device)

    # Optimizer
    if cfg["novograd"] is True:
        optimizer = Novograd(model.parameters(), lr=cfg["lr"])
    else:
        optimizer = SGD(model.parameters(), lr=cfg["lr"], momentum=0.9)

    # AMP scaler: native AMP requires torch >= 1.6
    cfg["amp"] = cfg["amp"] and monai.utils.get_torch_version_tuple() >= (1, 6)
    if cfg["amp"] is True:
        scaler = GradScaler()
    else:
        scaler = None

    # Learning rate scheduler
    if cfg["cos"] is True:
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer,
                                                   T_max=cfg["n_epochs"])
    else:
        scheduler = None

    # -------------------------------------------------------------------------
    # Training/Evaluating
    # -------------------------------------------------------------------------
    train_counter = {"n_epochs": cfg["n_epochs"], "epoch": 1, "step": 1}

    total_valid_time, total_train_time = 0.0, 0.0
    t_start = time.perf_counter()
    # np.inf (not the removed-in-NumPy-2.0 np.Inf alias): any real loss is better
    metric_summary = {"loss": np.inf, "accuracy": 0, "best_epoch": 1}
    # Training/Validation Loop
    for _ in range(cfg["n_epochs"]):
        t_epoch = time.perf_counter()
        logging.info(
            f"[Training] learning rate: {optimizer.param_groups[0]['lr']}")

        # Training
        with Range("Training Epoch"):
            train_counter = training(
                train_counter,
                model,
                loss_func,
                optimizer,
                scaler,
                cfg["amp"],
                train_dataloader,
                preprocess_gpu_train,
                postprocess,
                device,
                writer,
                cfg["print_step"],
            )
        if scheduler is not None:
            scheduler.step()
        if cfg["save"]:
            torch.save(
                model.state_dict(),
                os.path.join(log_dir,
                             f"model_epoch_{train_counter['epoch']}.pt"))
        t_train = time.perf_counter()
        train_time = t_train - t_epoch
        total_train_time += train_time

        # Validation
        if cfg["validate"]:
            with Range("Validation"):
                valid_loss, valid_acc = validation(
                    model,
                    loss_func,
                    cfg["amp"],
                    valid_dataloader,
                    preprocess_gpu_valid,
                    postprocess,
                    device,
                    cfg["print_step"],
                )
            t_valid = time.perf_counter()
            valid_time = t_valid - t_train
            total_valid_time += valid_time
            # Track the best epoch by lowest validation loss
            if valid_loss < metric_summary["loss"]:
                metric_summary["loss"] = min(valid_loss,
                                             metric_summary["loss"])
                metric_summary["accuracy"] = max(valid_acc,
                                                 metric_summary["accuracy"])
                metric_summary["best_epoch"] = train_counter["epoch"]
            writer.add_scalar("valid/loss", valid_loss, train_counter["epoch"])
            writer.add_scalar("valid/accuracy", valid_acc,
                              train_counter["epoch"])

            logging.info(
                f"[Epoch: {train_counter['epoch']}/{cfg['n_epochs']}] loss: {valid_loss:.3f}, accuracy: {valid_acc:.2f}, "
                f"time: {t_valid - t_epoch:.1f}s (train: {train_time:.1f}s, valid: {valid_time:.1f}s)"
            )
        else:
            logging.info(
                f"[Epoch: {train_counter['epoch']}/{cfg['n_epochs']}] Train time: {train_time:.1f}s"
            )
        writer.flush()
    t_end = time.perf_counter()

    # Save final metrics
    metric_summary["train_time_per_epoch"] = total_train_time / cfg["n_epochs"]
    metric_summary["total_time"] = t_end - t_start
    writer.add_hparams(hparam_dict=cfg,
                       metric_dict=metric_summary,
                       run_name=log_dir)
    writer.close()
    logging.info(f"Metric Summary: {metric_summary}")

    # Save the best and final model.
    # Checkpoints above are written with a ".pt" suffix, so copies must use
    # the same suffix (the previous ".pth" source names never existed on
    # disk and raised FileNotFoundError).  Copying is only possible when
    # per-epoch saving was enabled.
    if cfg["validate"] is True and cfg["save"]:
        copyfile(
            os.path.join(log_dir,
                         f"model_epoch_{metric_summary['best_epoch']}.pt"),
            os.path.join(log_dir, "model_best.pt"),
        )
        copyfile(
            os.path.join(log_dir, f"model_epoch_{cfg['n_epochs']}.pt"),
            os.path.join(log_dir, "model_final.pt"),
        )

    # Final prints
    logging.info(
        f"[Completed] {train_counter['epoch']} epochs -- time: {t_end - t_start:.1f}s "
        f"(training: {total_train_time:.1f}s, validation: {total_valid_time:.1f}s)",
    )
    logging.info(f"Logs and model was saved at: {log_dir}")
Beispiel #13
0
def train(cfg):
    """Train a WSI patch classifier using MONAI Ignite workflows.

    Builds torchvision/NumPy preprocessing, ``PatchWSIDataset`` loaders, a
    convolutionalized ResNet-18 with BCE loss, then drives training with
    per-epoch validation via ``SupervisedTrainer``/``SupervisedEvaluator``
    plus checkpoint and TensorBoard handlers.

    Args:
        cfg: configuration dict; read keys include ``dataset_json``,
            ``data_root``, ``n_epochs``, ``lr``, ``batch_size``,
            ``logdir``, etc.  Mutated in place: ``cfg["amp"]`` is
            normalized to a bool and disabled on torch < 1.6.

    Raises:
        ValueError: if the training dataloader yields no batch.
    """
    log_dir = create_log_dir(cfg)
    device = set_device(cfg)
    # --------------------------------------------------------------------------
    # Data Loading and Preprocessing
    # --------------------------------------------------------------------------
    # __________________________________________________________________________
    # Build MONAI preprocessing
    train_preprocess = Compose([
        ToTensorD(keys="image"),
        TorchVisionD(keys="image",
                     name="ColorJitter",
                     brightness=64.0 / 255.0,
                     contrast=0.75,
                     saturation=0.25,
                     hue=0.04),
        ToNumpyD(keys="image"),
        RandFlipD(keys="image", prob=0.5),
        RandRotate90D(keys="image", prob=0.5),
        CastToTypeD(keys="image", dtype=np.float32),
        RandZoomD(keys="image", prob=0.5, min_zoom=0.9, max_zoom=1.1),
        ScaleIntensityRangeD(keys="image",
                             a_min=0.0,
                             a_max=255.0,
                             b_min=-1.0,
                             b_max=1.0),
        ToTensorD(keys=("image", "label")),
    ])
    valid_preprocess = Compose([
        CastToTypeD(keys="image", dtype=np.float32),
        ScaleIntensityRangeD(keys="image",
                             a_min=0.0,
                             a_max=255.0,
                             b_min=-1.0,
                             b_max=1.0),
        ToTensorD(keys=("image", "label")),
    ])
    # __________________________________________________________________________
    # Create MONAI dataset
    train_json_info_list = load_decathlon_datalist(
        data_list_file_path=cfg["dataset_json"],
        data_list_key="training",
        base_dir=cfg["data_root"],
    )
    valid_json_info_list = load_decathlon_datalist(
        data_list_file_path=cfg["dataset_json"],
        data_list_key="validation",
        base_dir=cfg["data_root"],
    )

    train_dataset = PatchWSIDataset(
        train_json_info_list,
        cfg["region_size"],
        cfg["grid_shape"],
        cfg["patch_size"],
        train_preprocess,
        image_reader_name="openslide" if cfg["use_openslide"] else "cuCIM",
    )
    valid_dataset = PatchWSIDataset(
        valid_json_info_list,
        cfg["region_size"],
        cfg["grid_shape"],
        cfg["patch_size"],
        valid_preprocess,
        image_reader_name="openslide" if cfg["use_openslide"] else "cuCIM",
    )

    # __________________________________________________________________________
    # DataLoaders
    train_dataloader = DataLoader(train_dataset,
                                  num_workers=cfg["num_workers"],
                                  batch_size=cfg["batch_size"],
                                  pin_memory=True)
    valid_dataloader = DataLoader(valid_dataset,
                                  num_workers=cfg["num_workers"],
                                  batch_size=cfg["batch_size"],
                                  pin_memory=True)

    # __________________________________________________________________________
    # Get sample batch and some info
    first_sample = first(train_dataloader)
    if first_sample is None:
        # fixed typo in the original message ("Fist sample is None!")
        raise ValueError("First sample is None!")

    print("image: ")
    print("    shape", first_sample["image"].shape)
    print("    type: ", type(first_sample["image"]))
    print("    dtype: ", first_sample["image"].dtype)
    print("labels: ")
    print("    shape", first_sample["label"].shape)
    print("    type: ", type(first_sample["label"]))
    print("    dtype: ", first_sample["label"].dtype)
    print(f"batch size: {cfg['batch_size']}")
    print(f"train number of batches: {len(train_dataloader)}")
    print(f"valid number of batches: {len(valid_dataloader)}")

    # --------------------------------------------------------------------------
    # Deep Learning Classification Model
    # --------------------------------------------------------------------------
    # __________________________________________________________________________
    # initialize model: ResNet-18 with a 1-channel convolutional head
    model = TorchVisionFCModel("resnet18",
                               num_classes=1,
                               use_conv=True,
                               pretrained=cfg["pretrain"])
    model = model.to(device)

    # loss function
    loss_func = torch.nn.BCEWithLogitsLoss()
    loss_func = loss_func.to(device)

    # optimizer
    if cfg["novograd"]:
        optimizer = Novograd(model.parameters(), cfg["lr"])
    else:
        optimizer = SGD(model.parameters(), lr=cfg["lr"], momentum=0.9)

    # AMP requires torch >= 1.6; normalize the flag to a plain bool and
    # silently disable it on older torch versions.
    cfg["amp"] = bool(
        cfg["amp"]) and monai.utils.get_torch_version_tuple() >= (1, 6)

    scheduler = lr_scheduler.CosineAnnealingLR(optimizer,
                                               T_max=cfg["n_epochs"])

    # --------------------------------------------
    # Ignite Trainer/Evaluator
    # --------------------------------------------
    # Evaluator
    # NOTE(review): evaluator handlers write to log_dir while trainer
    # handlers use cfg["logdir"]; presumably these point at the same
    # directory -- confirm against create_log_dir.
    val_handlers = [
        CheckpointSaver(save_dir=log_dir,
                        save_dict={"net": model},
                        save_key_metric=True),
        StatsHandler(output_transform=lambda x: None),
        TensorBoardStatsHandler(log_dir=log_dir,
                                output_transform=lambda x: None),
    ]
    val_postprocessing = Compose([
        ActivationsD(keys="pred", sigmoid=True),
        AsDiscreteD(keys="pred", threshold=0.5)
    ])
    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=valid_dataloader,
        network=model,
        postprocessing=val_postprocessing,
        key_val_metric={
            "val_acc":
            Accuracy(output_transform=from_engine(["pred", "label"]))
        },
        val_handlers=val_handlers,
        amp=cfg["amp"],
    )

    # Trainer: steps the LR scheduler, checkpoints every epoch, and runs
    # the evaluator after each training epoch.
    train_handlers = [
        LrScheduleHandler(lr_scheduler=scheduler, print_lr=True),
        CheckpointSaver(save_dir=cfg["logdir"],
                        save_dict={
                            "net": model,
                            "opt": optimizer
                        },
                        save_interval=1,
                        epoch_level=True),
        StatsHandler(tag_name="train_loss",
                     output_transform=from_engine(["loss"], first=True)),
        ValidationHandler(validator=evaluator, interval=1, epoch_level=True),
        TensorBoardStatsHandler(log_dir=cfg["logdir"],
                                tag_name="train_loss",
                                output_transform=from_engine(["loss"],
                                                             first=True)),
    ]
    train_postprocessing = Compose([
        ActivationsD(keys="pred", sigmoid=True),
        AsDiscreteD(keys="pred", threshold=0.5)
    ])

    trainer = SupervisedTrainer(
        device=device,
        max_epochs=cfg["n_epochs"],
        train_data_loader=train_dataloader,
        network=model,
        optimizer=optimizer,
        loss_function=loss_func,
        postprocessing=train_postprocessing,
        key_train_metric={
            "train_acc":
            Accuracy(output_transform=from_engine(["pred", "label"]))
        },
        train_handlers=train_handlers,
        amp=cfg["amp"],
    )
    trainer.run()
def get_data(args, batch_size=1, mode="train"):
    """Build the dataloader and dataset properties for one data split.

    Args:
        args: parsed arguments providing ``fold``, ``task_id``,
            ``root_dir``, ``datalist_path``, the positive/negative/total
            sampling counts, worker counts, ``cache_rate`` and
            ``multi_gpu``.
        batch_size: loader batch size (default 1).
        mode: one of ``"train"``, ``"validation"`` or ``"test"``.

    Returns:
        Tuple ``(properties, data_loader)``: dataset metadata loaded from
        the Decathlon JSON, and a MONAI ``DataLoader`` for the split.

    Raises:
        ValueError: if ``mode`` is not one of the supported splits.
    """
    # get necessary parameters:
    fold = args.fold
    task_id = args.task_id
    root_dir = args.root_dir
    datalist_path = args.datalist_path
    dataset_path = os.path.join(root_dir, task_name[task_id])
    transform_params = (args.pos_sample_num, args.neg_sample_num,
                        args.num_samples)
    multi_gpu_flag = args.multi_gpu

    transform = get_task_transforms(mode, task_id, *transform_params)
    if mode == "test":
        list_key = "test"
    else:
        # train/validation lists are stored per cross-validation fold
        list_key = "{}_fold{}".format(mode, fold)
    datalist_name = "dataset_task{}.json".format(task_id)

    property_keys = [
        "name",
        "description",
        "reference",
        "licence",
        "tensorImageSize",
        "modality",
        "labels",
        "numTraining",
        "numTest",
    ]

    datalist = load_decathlon_datalist(
        os.path.join(datalist_path, datalist_name), True, list_key,
        dataset_path)

    properties = load_decathlon_properties(
        os.path.join(datalist_path, datalist_name), property_keys)
    if mode in ["validation", "test"]:
        if multi_gpu_flag:
            # shard deterministically (no shuffle) so each rank evaluates
            # a disjoint subset; uneven shards are allowed for eval
            datalist = partition_dataset(
                data=datalist,
                shuffle=False,
                num_partitions=dist.get_world_size(),
                even_divisible=False,
            )[dist.get_rank()]

        val_ds = CacheDataset(
            data=datalist,
            transform=transform,
            num_workers=4,
        )

        data_loader = DataLoader(
            val_ds,
            batch_size=batch_size,
            shuffle=False,
            num_workers=args.val_num_workers,
        )
    elif mode == "train":
        if multi_gpu_flag:
            # training shards are shuffled and evenly divisible so every
            # rank sees the same number of batches
            datalist = partition_dataset(
                data=datalist,
                shuffle=True,
                num_partitions=dist.get_world_size(),
                even_divisible=True,
            )[dist.get_rank()]

        train_ds = CacheDataset(
            data=datalist,
            transform=transform,
            num_workers=8,
            cache_rate=args.cache_rate,
        )
        data_loader = DataLoader(
            train_ds,
            batch_size=batch_size,
            shuffle=True,
            num_workers=args.train_num_workers,
            drop_last=True,
        )
    else:
        # include the offending value (the original f-string had no placeholder)
        raise ValueError(
            f"mode should be train, validation or test, but got {mode}.")

    return properties, data_loader