Example #1
# Imports inferred for this excerpt (not shown in the original snippet).
import json
import logging
import os

import torch.distributed as dist
from monai.data import DataLoader, PersistentDataset, partition_dataset


def get_loaders(args, pre_transforms, train=True):
    multi_gpu = args.multi_gpu
    local_rank = args.local_rank

    dataset_json = os.path.join(args.input)
    with open(dataset_json) as f:
        datalist = json.load(f)

    total_d = len(datalist)
    datalist = datalist[0:args.limit] if args.limit else datalist
    total_l = len(datalist)

    if multi_gpu:
        datalist = partition_dataset(
            data=datalist,
            num_partitions=dist.get_world_size(),
            even_divisible=True,
            shuffle=True,
            seed=args.seed,
        )[local_rank]

    if train:
        train_datalist, val_datalist = partition_dataset(
            datalist,
            ratios=[args.split, (1 - args.split)],
            shuffle=True,
            seed=args.seed,
        )

        train_ds = PersistentDataset(train_datalist,
                                     pre_transforms,
                                     cache_dir=args.cache_dir)
        train_loader = DataLoader(train_ds,
                                  batch_size=args.batch,
                                  shuffle=True,
                                  num_workers=16)
        logging.info(
            "{}:: Total Records used for Training is: {}/{}/{}".format(
                local_rank, len(train_ds), total_l, total_d))
    else:
        train_loader = None
        val_datalist = datalist

    val_ds = PersistentDataset(val_datalist,
                               pre_transforms,
                               cache_dir=args.cache_dir)
    val_loader = DataLoader(val_ds, batch_size=args.batch, num_workers=8)
    logging.info("{}:: Total Records used for Validation is: {}/{}/{}".format(
        local_rank, len(val_ds), total_l, total_d))

    return train_loader, val_loader
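
A minimal usage sketch for get_loaders above. The argparse namespace, its field values, and the transform chain are illustrative assumptions, not part of the original snippet.

import argparse

from monai.transforms import Compose, EnsureChannelFirstd, LoadImaged

# Hypothetical arguments mirroring the fields that get_loaders reads.
args = argparse.Namespace(
    multi_gpu=False,
    local_rank=0,
    input="dataset.json",   # JSON file holding a list of {"image": ..., "label": ...} records
    limit=0,                # falsy -> use the full datalist
    seed=42,
    split=0.8,              # train/validation ratio
    cache_dir="./cache",
    batch=2,
)

pre_transforms = Compose([
    LoadImaged(keys=("image", "label")),
    EnsureChannelFirstd(keys=("image", "label")),
])

train_loader, val_loader = get_loaders(args, pre_transforms, train=True)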
Example #2
    def __init__(
        self,
        dataset: Dataset,
        image_key: Optional[str] = "image",
        label_key: Optional[str] = "label",
        meta_key_postfix: str = "meta_dict",
        num_workers: int = 0,
        **kwargs,
    ):
        """
        Args:
            dataset: dataset from which to load the data.
            image_key: key name of images (default: ``image``).
            label_key: key name of labels (default: ``label``).
            meta_key_postfix: use `{image_key}_{meta_key_postfix}` to fetch the meta data from dict,
                the meta data is a dictionary object (default: ``meta_dict``).
            num_workers: how many subprocesses to use for data loading.
                ``0`` means that the data will be loaded in the main process (default: ``0``).
            kwargs: other parameters (except batch_size) for DataLoader (this class forces to use ``batch_size=1``).

        """

        self.data_loader = DataLoader(dataset=dataset,
                                      batch_size=1,
                                      num_workers=num_workers,
                                      **kwargs)

        self.image_key = image_key
        self.label_key = label_key
        if image_key:
            self.meta_key = f"{image_key}_{meta_key_postfix}"
        self.all_meta_data: List = []
Example #3
    def __init__(
        self,
        dataset: Dataset,
        image_key: Optional[str] = "image",
        label_key: Optional[str] = "label",
        meta_key: Optional[KeysCollection] = None,
        meta_key_postfix: str = DEFAULT_POST_FIX,
        num_workers: int = 0,
        **kwargs,
    ):
        """
        Args:
            dataset: dataset from which to load the data.
            image_key: key name of images (default: ``image``).
            label_key: key name of labels (default: ``label``).
            meta_key: explicitly indicate the key of the corresponding meta data dictionary.
                for example, for data with key `image`, the metadata by default is in `image_meta_dict`.
                the meta data is a dictionary object which contains: filename, affine, original_shape, etc.
                if None, will try to construct meta_keys by `{image_key}_{meta_key_postfix}`.
            meta_key_postfix: use `{image_key}_{meta_key_postfix}` to fetch the meta data from dict,
                the meta data is a dictionary object (default: ``meta_dict``).
            num_workers: how many subprocesses to use for data loading.
                ``0`` means that the data will be loaded in the main process (default: ``0``).
            kwargs: other parameters (except `batch_size` and `num_workers`) for DataLoader,
                this class forces to use ``batch_size=1``.

        """

        self.data_loader = DataLoader(dataset=dataset, batch_size=1, num_workers=num_workers, **kwargs)

        self.image_key = image_key
        self.label_key = label_key
        self.meta_key = meta_key or f"{image_key}_{meta_key_postfix}"
        self.all_meta_data: List = []
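
The snippet above only shows the constructor body; judging by the signature it appears to be MONAI's DatasetSummary, so the following instantiation sketch assumes that class name. File names are placeholders.

from monai.data import Dataset, DatasetSummary
from monai.transforms import Compose, LoadImaged

files = [{"image": "img0.nii.gz", "label": "seg0.nii.gz"}]
ds = Dataset(files, transform=Compose([LoadImaged(keys=("image", "label"))]))
# batch_size is forced to 1 internally; extra DataLoader kwargs pass through **kwargs.
summary = DatasetSummary(ds, image_key="image", label_key="label", num_workers=0)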
Example #4
    def __call__(
        self,
        data: Dict[str, Any],
        num_examples: int = 10
    ) -> Union[Tuple[NdarrayOrTensor, NdarrayOrTensor, NdarrayOrTensor, float],
               NdarrayOrTensor]:
        """
        Args:
            data: dictionary data to be processed.
            num_examples: number of realisations to be processed and results combined.

        Returns:
            - if `return_full_data==False`: mode, mean, std, vvc. The mode, mean and standard deviation are
                calculated across `num_examples` outputs at each voxel. The volume variation coefficient (VVC)
                is `std/mean` across the whole output, including `num_examples`. See original paper for clarification.
            - if `return_full_data==True`: data is returned as-is after applying the `inferrer_fn` and then
                concatenating across the first dimension containing `num_examples`. This allows the user to perform
                their own analysis if desired.
        """
        d = dict(data)

        # check num examples is multiple of batch size
        if num_examples % self.batch_size != 0:
            raise ValueError("num_examples should be multiple of batch size.")

        # generate batch of data of size == batch_size, dataset and dataloader
        data_in = [deepcopy(d) for _ in range(num_examples)]
        ds = Dataset(data_in, self.transform)
        dl = DataLoader(ds,
                        num_workers=self.num_workers,
                        batch_size=self.batch_size,
                        collate_fn=pad_list_data_collate)

        outs: List = []

        for b in tqdm(dl) if has_tqdm and self.progress else dl:
            # do model forward pass
            b[self._pred_key] = self.inferrer_fn(b[self.image_key].to(
                self.device))
            outs.extend([
                self.inverter(PadListDataCollate.inverse(i))[self._pred_key]
                for i in decollate_batch(b)
            ])

        output: NdarrayOrTensor = stack(outs, 0)

        if self.return_full_data:
            return output

        # calculate metrics
        _mode = mode(output, dim=0)
        mean = output.mean(0)
        std = output.std(0)
        vvc = (output.std() / output.mean()).item()

        return _mode, mean, std, vvc
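
A usage sketch for the test-time augmentation call above, assuming the enclosing class is MONAI's TestTimeAugmentation (the import path and `model` are assumptions for illustration).

import torch
from monai.data import TestTimeAugmentation
from monai.transforms import Compose, EnsureChannelFirstd, LoadImaged, RandAffined

transforms = Compose([
    LoadImaged(keys=("image", "label")),
    EnsureChannelFirstd(keys=("image", "label")),
    RandAffined(keys=("image", "label"), prob=1.0, rotate_range=0.1),
])
tta = TestTimeAugmentation(
    transforms,
    batch_size=5,
    num_workers=0,
    inferrer_fn=lambda x: model(x),  # `model` is assumed to be a trained network
    device="cuda" if torch.cuda.is_available() else "cpu",
)
# num_examples must be a multiple of batch_size (10 % 5 == 0 here).
mode, mean, std, vvc = tta({"image": "img0.nii.gz", "label": "seg0.nii.gz"}, num_examples=10)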
Example #5
    def __call__(self, data: Dict[str, Any]) -> Any:
        decollated_data = decollate_batch(data, detach=self.detach, pad=self.pad_batch, fill_value=self.fill_value)
        inv_ds = _BatchInverseDataset(decollated_data, self.transform, self.pad_collation_used)
        inv_loader = DataLoader(
            inv_ds, batch_size=self.batch_size, num_workers=self.num_workers, collate_fn=self.collate_fn
        )
        try:
            return first(inv_loader)
        except RuntimeError as re:
            re_str = str(re)
            if "equal size" in re_str:
                re_str += "\nMONAI hint: try creating `BatchInverseTransform` with `collate_fn=lambda x: x`."
            raise RuntimeError(re_str) from re
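
A minimal sketch of how this inverse call might be used, assuming it belongs to MONAI's BatchInverseTransform; file names are placeholders.

from monai.data import DataLoader, Dataset
from monai.transforms import BatchInverseTransform, Compose, LoadImaged, RandAffined
from monai.utils import first

transform = Compose([LoadImaged(keys="image"), RandAffined(keys="image", prob=1.0)])
ds = Dataset([{"image": "img0.nii.gz"}, {"image": "img1.nii.gz"}], transform)
loader = DataLoader(ds, batch_size=2)

inverter = BatchInverseTransform(transform, loader)
batch = first(loader)          # one collated batch of transformed data
inverted = inverter(batch)     # per-item data mapped back through the invertible transforms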
Example #6
    def __call__(
        self,
        data: Dict[str, Any],
        num_examples: int = 10
    ) -> Union[Tuple[np.ndarray, np.ndarray, np.ndarray, float], np.ndarray]:
        """
        Args:
            data: dictionary data to be processed.
            num_examples: number of realisations to be processed and results combined.

        Returns:
            - if `return_full_data==False`: mode, mean, std, vvc. The mode, mean and standard deviation are calculated
                across `num_examples` outputs at each voxel. The volume variation coefficient (VVC) is `std/mean`
                across the whole output, including `num_examples`. See original paper for clarification.
            - if `return_full_data==True`: data is returned as-is after applying the `inferrer_fn` and then
                concatenating across the first dimension containing `num_examples`. This allows the user to perform
                their own analysis if desired.
        """
        d = dict(data)

        # check num examples is multiple of batch size
        if num_examples % self.batch_size != 0:
            raise ValueError("num_examples should be multiple of batch size.")

        # generate batch of data of size == batch_size, dataset and dataloader
        data_in = [d] * num_examples
        ds = Dataset(data_in, self.transform)
        dl = DataLoader(ds,
                        num_workers=self.num_workers,
                        batch_size=self.batch_size,
                        collate_fn=pad_list_data_collate)

        label_transform_key = self.label_key + InverseKeys.KEY_SUFFIX

        # create inverter
        inverter = BatchInverseTransform(self.transform,
                                         dl,
                                         collate_fn=list_data_collate)

        outputs: List[np.ndarray] = []

        for batch_data in tqdm(dl) if self.progress else dl:
            batch_images = batch_data[self.image_key].to(self.device)

            # do model forward pass
            batch_output = self.inferrer_fn(batch_images)
            if isinstance(batch_output, torch.Tensor):
                batch_output = batch_output.detach().cpu()
            if isinstance(batch_output, np.ndarray):
                batch_output = torch.Tensor(batch_output)

            # create a dictionary containing the inferred batch and their transforms
            inferred_dict = {
                self.label_key: batch_output,
                label_transform_key: batch_data[label_transform_key]
            }
            # if meta dict is present, add that too (required for some inverse transforms)
            label_meta_dict_key = self.meta_keys or f"{self.label_key}_{self.meta_key_postfix}"
            if label_meta_dict_key in batch_data:
                inferred_dict[label_meta_dict_key] = batch_data[
                    label_meta_dict_key]

            # do inverse transformation (allow missing keys as only inverting label)
            with allow_missing_keys_mode(self.transform):  # type: ignore
                inv_batch = inverter(inferred_dict)

            # append
            outputs.append(inv_batch[self.label_key])

        # output
        output: np.ndarray = np.concatenate(outputs)

        if self.return_full_data:
            return output

        # calculate metrics
        mode = np.array(
            torch.mode(torch.Tensor(output.astype(np.int64)), dim=0).values)
        mean: np.ndarray = np.mean(output, axis=0)  # type: ignore
        std: np.ndarray = np.std(output, axis=0)  # type: ignore
        vvc: float = (np.std(output) / np.mean(output)).item()
        return mode, mean, std, vvc
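
As a sanity check of what the final aggregation computes, a small self-contained numpy/torch sketch on toy data (my illustration, not part of the snippet):

import numpy as np
import torch

# Toy stand-in for `output` with shape (num_examples, H, W).
output = np.random.randint(0, 3, size=(10, 4, 4)).astype(np.float32)
mode = torch.mode(torch.as_tensor(output.astype(np.int64)), dim=0).values.numpy()
mean = output.mean(axis=0)                  # voxel-wise mean across the realisations
std = output.std(axis=0)                    # voxel-wise standard deviation
vvc = float(output.std() / output.mean())   # volume variation coefficient over the whole output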