Ejemplo n.º 1
0
    def __call__(
        self,
        data: Dict[str, Any],
        num_examples: int = 10
    ) -> Union[Tuple[np.ndarray, np.ndarray, np.ndarray, float], np.ndarray]:
        """
        Apply `self.transform` and `self.inferrer_fn` to `num_examples` copies of `data`,
        invert the transforms on every inferred output and combine the results.

        Args:
            data: dictionary data to be processed.
            num_examples: number of realisations to be processed and results combined.
                Must be a positive multiple of `self.batch_size`.

        Returns:
            - if `return_full_data==False`: mode, mean, std, vvc. The mode, mean and standard deviation are calculated
                across `num_examples` outputs at each voxel. The volume variation coefficient (VVC) is `std/mean`
                across the whole output, including `num_examples`. See original paper for clarification.
            - if `return_full_data==True`: data is returned as-is after applying the `inferrer_fn` and then
                concatenating across the first dimension containing `num_examples`. This allows the user to perform
                their own analysis if desired.

        Raises:
            ValueError: if `num_examples` is not a multiple of the batch size, or is smaller than 1.
        """
        d = dict(data)

        # check num examples is multiple of batch size
        if num_examples % self.batch_size != 0:
            raise ValueError("num_examples should be multiple of batch size.")
        # zero / negative multiples would produce no batches and fail later in
        # `np.concatenate` with an unrelated message, so reject them up front
        if num_examples < 1:
            raise ValueError("num_examples should be a positive integer.")

        # generate batch of data of size == batch_size, dataset and dataloader
        data_in = [d] * num_examples
        ds = Dataset(data_in, self.transform)
        dl = DataLoader(ds,
                        self.num_workers,
                        batch_size=self.batch_size,
                        collate_fn=pad_list_data_collate)

        # keys that do not depend on the batch are resolved once, outside the loop
        label_transform_key = self.label_key + InverseKeys.KEY_SUFFIX
        label_meta_dict_key = self.meta_keys or f"{self.label_key}_{self.meta_key_postfix}"

        # create inverter
        inverter = BatchInverseTransform(self.transform,
                                         dl,
                                         collate_fn=list_data_collate)

        outputs: List[np.ndarray] = []

        for batch_data in tqdm(dl) if self.progress else dl:
            batch_images = batch_data[self.image_key].to(self.device)

            # do model forward pass
            batch_output = self.inferrer_fn(batch_images)
            if isinstance(batch_output, torch.Tensor):
                batch_output = batch_output.detach().cpu()
            elif isinstance(batch_output, np.ndarray):
                batch_output = torch.Tensor(batch_output)

            # create a dictionary containing the inferred batch and their transforms
            inferred_dict = {
                self.label_key: batch_output,
                label_transform_key: batch_data[label_transform_key]
            }
            # if meta dict is present, add that too (required for some inverse transforms)
            if label_meta_dict_key in batch_data:
                inferred_dict[label_meta_dict_key] = batch_data[label_meta_dict_key]

            # do inverse transformation (allow missing keys as only inverting label)
            with allow_missing_keys_mode(self.transform):  # type: ignore
                inv_batch = inverter(inferred_dict)

            # append
            outputs.append(inv_batch[self.label_key])

        # output: all inverted batches joined along the first dimension
        output: np.ndarray = np.concatenate(outputs)

        if self.return_full_data:
            return output

        # calculate metrics
        # mode is taken over the first dimension after casting the values to int64
        mode = np.array(
            torch.mode(torch.Tensor(output.astype(np.int64)), dim=0).values)
        mean: np.ndarray = np.mean(output, axis=0)  # type: ignore
        std: np.ndarray = np.std(output, axis=0)  # type: ignore
        # VVC uses the std and mean of the whole array, not the per-voxel values
        vvc: float = (np.std(output) / np.mean(output)).item()
        return mode, mean, std, vvc
Ejemplo n.º 2
0
    def __init__(
        self,
        keys: KeysCollection,
        transform: InvertibleTransform,
        loader: TorchDataLoader,
        orig_keys: Union[str, Sequence[str]],
        meta_key_postfix: str = "meta_dict",
        collate_fn: Optional[Callable] = no_collation,
        postfix: str = "inverted",
        nearest_interp: Union[bool, Sequence[bool]] = True,
        to_tensor: Union[bool, Sequence[bool]] = True,
        device: Union[Union[str, torch.device],
                      Sequence[Union[str, torch.device]]] = "cpu",
        post_func: Union[Callable, Sequence[Callable]] = lambda x: x,
        num_workers: Optional[int] = 0,
        allow_missing_keys: bool = False,
    ) -> None:
        """
        Store the configuration needed to invert `transform` on the data found under `keys`.

        Args:
            keys: key (or list of keys) of the data in the dict whose transforms should be inverted,
                for example ``["pred", "pred_class2"]``.
            transform: the callable transform that was previously applied to the input data.
            loader: data loader that was used to run the transforms and generate the batch of data.
            orig_keys: key of the original input data in the dict. The transform information recorded
                for this key is used to invert the data stored under `keys`.
                May also be a list of keys, each one matching an entry of `keys`.
            meta_key_postfix: the meta data is fetched from the dict with `{orig_key}_{meta_key_postfix}`,
                default is `meta_dict`; the meta data is a dictionary object.
                For example, for orig_key `image` the `affine` matrices are read from / written to the
                `affine` field of the `image_meta_dict` dictionary.
            collate_fn: how to collate the data after the inverse transformations. The default does no
                collation, so the output is a list of PyTorch Tensors or numpy arrays without batch dim.
            postfix: the inverted result is saved into the dict under the key `{key}_{postfix}`.
            nearest_interp: whether to use `nearest` interpolation mode when inverting the spatial
                transforms, default to `True`. If `False`, the interpolation mode of the original
                transform is used. May also be a list of bool, each one matching an entry of `keys`.
            to_tensor: whether to convert the inverted data into a PyTorch Tensor first, default to `True`.
                May also be a list of bool, each one matching an entry of `keys`.
            device: if converted to Tensor, the device the inverted results are moved to before
                `post_func`, default to "cpu". May also be a list of string or `torch.device`,
                each one matching an entry of `keys`.
            post_func: callable used to post-process the inverted data.
                May also be a list of callables, each one matching an entry of `keys`.
            num_workers: number of workers used when running the data loader for the inverse transforms,
                default to 0 as only one iteration is run and multi-processing may be even slower.
                Set to `None` to use the `num_workers` of the input transform data loader.
            allow_missing_keys: don't raise exception if key is missing.

        """
        super().__init__(keys, allow_missing_keys)

        # plain settings, stored unchanged
        self.transform = transform
        self.meta_key_postfix = meta_key_postfix
        self.postfix = postfix

        # helper that performs the batched inverse of `transform`
        self.inverter = BatchInverseTransform(
            transform=transform,
            loader=loader,
            collate_fn=collate_fn,
            num_workers=num_workers,
        )

        # per-key settings: each one is expanded to one entry per key
        num_keys = len(self.keys)
        self.orig_keys = ensure_tuple_rep(orig_keys, num_keys)
        self.nearest_interp = ensure_tuple_rep(nearest_interp, num_keys)
        self.to_tensor = ensure_tuple_rep(to_tensor, num_keys)
        self.device = ensure_tuple_rep(device, num_keys)
        self.post_func = ensure_tuple_rep(post_func, num_keys)

        self._totensor = ToTensor()