Example #1
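# Snippet context: this `__init__` matches MONAI's `monai.apps.DecathlonDataset`.
# The imports and class header below are reconstructed so the fragment parses;
# the real `resource`/`md5` tables and helpers such as `_generate_data_list`
# live in the upstream class.
import os
import sys
from typing import Callable, Sequence, Union

import numpy as np

from monai.apps import download_and_extract
from monai.data import CacheDataset, load_decathlon_properties
from monai.transforms import LoadImaged, Randomizable


class DecathlonDataset(Randomizable, CacheDataset):
    resource: dict = {}  # task name -> download URL (real values in MONAI source)
    md5: dict = {}       # task name -> archive md5 checksum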
    def __init__(
        self,
        root_dir: str,
        task: str,
        section: str,
        transform: Union[Sequence[Callable], Callable] = (),
        download: bool = False,
        seed: int = 0,
        val_frac: float = 0.2,
        cache_num: int = sys.maxsize,
        cache_rate: float = 1.0,
        num_workers: int = 0,
    ) -> None:
        if not os.path.isdir(root_dir):
            raise ValueError("Root directory root_dir must be a directory.")
        self.section = section
        self.val_frac = val_frac
        self.set_random_state(seed=seed)
        if task not in self.resource:
            raise ValueError(
                f"Unsupported task: {task}, available options are: {list(self.resource.keys())}."
            )
        dataset_dir = os.path.join(root_dir, task)
        tarfile_name = f"{dataset_dir}.tar"
        if download:
            download_and_extract(self.resource[task], tarfile_name, root_dir,
                                 self.md5[task])

        if not os.path.exists(dataset_dir):
            raise RuntimeError(
                f"Cannot find dataset directory: {dataset_dir}, please use download=True to download it."
            )
        self.indices: np.ndarray = np.array([])
        data = self._generate_data_list(dataset_dir)
        # the `release` key has a typo in the Task04 config file, so ignore it.
        property_keys = [
            "name",
            "description",
            "reference",
            "licence",
            "tensorImageSize",
            "modality",
            "labels",
            "numTraining",
            "numTest",
        ]
        self._properties = load_decathlon_properties(
            os.path.join(dataset_dir, "dataset.json"), property_keys)
        if transform == ():
            transform = LoadImaged(["image", "label"])
        CacheDataset.__init__(self,
                              data,
                              transform,
                              cache_num=cache_num,
                              cache_rate=cache_rate,
                              num_workers=num_workers)
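
A minimal usage sketch, assuming the class above is MONAI's DecathlonDataset and the Decathlon archive is reachable for download:

from monai.apps import DecathlonDataset

ds = DecathlonDataset(
    root_dir="./data",
    task="Task04_Hippocampus",
    section="training",  # or "validation"
    download=True,       # fetch and extract the task archive on first use
    cache_rate=0.5,      # cache half of the transformed items in memory
)
print(len(ds))
print(ds.get_properties(["modality", "labels"]))  # properties parsed from dataset.json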
Example #2
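# Imports reconstructed from MONAI; `task_name` is a dict defined elsewhere in
# the original script, mapping a task id to its Decathlon folder name
# (e.g. "04" -> "Task04_Hippocampus").
import os

from monai.data import (
    Dataset,
    DatasetSummary,
    load_decathlon_datalist,
    load_decathlon_properties,
)
from monai.transforms import LoadImaged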
def get_task_params(args):
    """
    Compute the target spacing of the Decathlon dataset for the given task.
    In addition, for CT tasks (03, 06, 07, 08, 09 and 10), this function
    also prints the mean and std intensity values (used for normalization),
    and the min (0.5 percentile) and max (99.5 percentile) values (used for clipping).

    """
    task_id = args.task_id
    root_dir = args.root_dir
    datalist_path = args.datalist_path
    dataset_path = os.path.join(root_dir, task_name[task_id])
    datalist_name = "dataset_task{}.json".format(task_id)

    # get all training data
    datalist = load_decathlon_datalist(
        os.path.join(datalist_path, datalist_name), True, "training", dataset_path
    )

    # get modality info.
    properties = load_decathlon_properties(
        os.path.join(datalist_path, datalist_name), "modality"
    )

    dataset = Dataset(
        data=datalist,
        transform=LoadImaged(keys=["image", "label"]),
    )

    calculator = DatasetSummary(dataset, num_workers=4)
    target_spacing = calculator.get_target_spacing()
    print("spacing: ", target_spacing)
    if properties["modality"]["0"] == "CT":
        print("CT input, calculate statistics:")
        calculator.calculate_statistics()
        print("mean: ", calculator.data_mean, " std: ", calculator.data_std)
        calculator.calculate_percentiles(
            sampling_flag=True, interval=10, min_percentile=0.5, max_percentile=99.5
        )
        print(
            "min: ",
            calculator.data_min_percentile,
            " max: ",
            calculator.data_max_percentile,
        )
    else:
        print("non CT input, skip calculating.")
Example #3
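# Imports reconstructed from the same pipeline; `get_task_transforms` and
# `task_name` come from the pipeline's own modules (exact paths assumed here),
# and `dist` is torch.distributed, expected to be initialized before any
# multi-GPU partitioning below.
import os

import torch.distributed as dist

from monai.data import (
    CacheDataset,
    DataLoader,
    load_decathlon_datalist,
    load_decathlon_properties,
    partition_dataset,
)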
def get_data(args, batch_size=1, mode="train"):
    # get necessary parameters:
    fold = args.fold
    task_id = args.task_id
    root_dir = args.root_dir
    datalist_path = args.datalist_path
    dataset_path = os.path.join(root_dir, task_name[task_id])
    transform_params = (args.pos_sample_num, args.neg_sample_num,
                        args.num_samples)
    multi_gpu_flag = args.multi_gpu

    transform = get_task_transforms(mode, task_id, *transform_params)
    if mode == "test":
        list_key = "test"
    else:
        list_key = "{}_fold{}".format(mode, fold)
    datalist_name = "dataset_task{}.json".format(task_id)

    property_keys = [
        "name",
        "description",
        "reference",
        "licence",
        "tensorImageSize",
        "modality",
        "labels",
        "numTraining",
        "numTest",
    ]

    datalist = load_decathlon_datalist(
        os.path.join(datalist_path, datalist_name), True, list_key,
        dataset_path)

    properties = load_decathlon_properties(
        os.path.join(datalist_path, datalist_name), property_keys)
    if mode in ["validation", "test"]:
        if multi_gpu_flag:
            datalist = partition_dataset(
                data=datalist,
                shuffle=False,
                num_partitions=dist.get_world_size(),
                even_divisible=False,
            )[dist.get_rank()]

        val_ds = CacheDataset(
            data=datalist,
            transform=transform,
            num_workers=4,
        )

        data_loader = DataLoader(
            val_ds,
            batch_size=batch_size,
            shuffle=False,
            num_workers=args.val_num_workers,
        )
    elif mode == "train":
        if multi_gpu_flag:
            datalist = partition_dataset(
                data=datalist,
                shuffle=True,
                num_partitions=dist.get_world_size(),
                even_divisible=True,
            )[dist.get_rank()]

        train_ds = CacheDataset(
            data=datalist,
            transform=transform,
            num_workers=8,
            cache_rate=args.cache_rate,
        )
        data_loader = DataLoader(
            train_ds,
            batch_size=batch_size,
            shuffle=True,
            num_workers=args.train_num_workers,
            drop_last=True,
        )
    else:
        raise ValueError(f"mode should be train, validation or test.")

    return properties, data_loader
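
A minimal consumption sketch, assuming args also carries the fold, sampling, caching and worker fields read above:

properties, train_loader = get_data(args, batch_size=2, mode="train")
print(properties["name"], properties["labels"])
for batch in train_loader:
    images, labels = batch["image"], batch["label"]
    break  # a training step would run here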