Example #1
    def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]:
        d = dict(data)
        for (
                key,
                orig_key,
                meta_key,
                orig_meta_key,
                meta_key_postfix,
                nearest_interp,
                to_tensor,
                device,
                post_func,
        ) in self.key_iterator(
                d,
                self.orig_keys,
                self.meta_keys,
                self.orig_meta_keys,
                self.meta_key_postfix,
                self.nearest_interp,
                self.to_tensor,
                self.device,
                self.post_func,
        ):
            transform_key = f"{orig_key}{InverseKeys.KEY_SUFFIX}"
            if transform_key not in d:
                warnings.warn(
                    f"transform info of `{orig_key}` is not available or no InvertibleTransform applied."
                )
                continue

            transform_info = d[transform_key]
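            # optionally rewrite the recorded interpolation mode to nearest-neighbour for the
            # inverse, which is usually preferable when inverting discrete outputs such as labels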
            if nearest_interp:
                transform_info = convert_inverse_interp_mode(
                    trans_info=deepcopy(transform_info),
                    mode="nearest",
                    align_corners=None)

            input_data = d[key]
            if isinstance(input_data, torch.Tensor):
                input_data = input_data.detach()
            # construct the input dict data for the inverse transform
            input_dict = {orig_key: input_data, transform_key: transform_info}
            orig_meta_key = orig_meta_key or f"{orig_key}_{meta_key_postfix}"
            meta_key = meta_key or f"{key}_{meta_key_postfix}"
            if orig_meta_key in d:
                input_dict[orig_meta_key] = d[orig_meta_key]

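            # only `orig_key` is being inverted here, so temporarily allow the transform's
            # other keys to be missing from `input_dict`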
            with allow_missing_keys_mode(self.transform):  # type: ignore
                inverted = self.transform.inverse(input_dict)

            # save the inverted data
            d[key] = post_func(
                self._totensor(inverted[orig_key]).to(device) if to_tensor else inverted[orig_key]
            )
            # save the inverted meta dict
            if orig_meta_key in d:
                d[meta_key] = inverted.get(orig_meta_key)

        return d
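
# A minimal, hypothetical sketch (not part of the original source) of the dictionary layout
# this __call__ expects, assuming keys=["pred"], orig_keys=["image"] and the default
# meta_key_postfix="meta_dict". "image_transforms" stands in for the trace an
# InvertibleTransform records under f"{orig_key}{InverseKeys.KEY_SUFFIX}"; the tensor
# shape and metadata values are arbitrary placeholders.
import torch

example_data = {
    "pred": torch.zeros(1, 64, 64, 32),                         # network output to be inverted
    "image_transforms": [],                                      # per-transform trace (normally filled in by the pipeline)
    "image_meta_dict": {"filename_or_obj": "case_001.nii.gz"},   # original metadata, if available
}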

Example #2

    def __call__(
        self,
        data: Dict[str, Any],
        num_examples: int = 10
    ) -> Union[Tuple[np.ndarray, np.ndarray, np.ndarray, float], np.ndarray]:
        """
        Args:
            data: dictionary data to be processed.
            num_examples: number of realisations to be processed and results combined.

        Returns:
            - if `return_full_data==False`: mode, mean, std, vvc. The mode, mean and standard deviation are calculated across
                `num_examples` outputs at each voxel. The volume variation coefficient (VVC) is `std/mean` across the whole output,
                including `num_examples`. See original paper for clarification.
            - if `return_full_data==True`: data is returned as-is after applying the `inferrer_fn` and then concatenating across
                the first dimension containing `num_examples`. This allows the user to perform their own analysis if desired.
        """
        d = dict(data)

        # check that num_examples is a multiple of batch_size
        if num_examples % self.batch_size != 0:
            raise ValueError("num_examples should be a multiple of batch_size.")

        # generate batch of data of size == batch_size, dataset and dataloader
        data_in = [deepcopy(d) for _ in range(num_examples)]
        ds = Dataset(data_in, self.transform)
        dl = DataLoader(ds,
                        num_workers=self.num_workers,
                        batch_size=self.batch_size,
                        collate_fn=pad_list_data_collate)

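        # key under which the invertible preprocessing recorded its applied-transform trace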
        transform_key = self.orig_key + InverseKeys.KEY_SUFFIX

        # create inverter
        inverter = BatchInverseTransform(self.transform,
                                         dl,
                                         collate_fn=list_data_collate)

        outputs: List[np.ndarray] = []

        for batch_data in tqdm(dl) if has_tqdm and self.progress else dl:

            batch_images = batch_data[self.image_key].to(self.device)

            # do model forward pass
            batch_output = self.inferrer_fn(batch_images)
            if isinstance(batch_output, torch.Tensor):
                batch_output = batch_output.detach().cpu()
            if isinstance(batch_output, np.ndarray):
                batch_output = torch.Tensor(batch_output)

            transform_info = batch_data[transform_key]
            if self.nearest_interp:
                transform_info = convert_inverse_interp_mode(
                    trans_info=deepcopy(transform_info),
                    mode="nearest",
                    align_corners=None)

            # create a dictionary containing the inferred batch and their transforms
            inferred_dict = {
                self.orig_key: batch_output,
                transform_key: transform_info
            }
            # if meta dict is present, add that too (required for some inverse transforms)
            meta_dict_key = self.orig_meta_keys or f"{self.orig_key}_{self.meta_key_postfix}"
            if meta_dict_key in batch_data:
                inferred_dict[meta_dict_key] = batch_data[meta_dict_key]

            # do inverse transformation (allow missing keys as only inverting the orig_key)
            with allow_missing_keys_mode(self.transform):  # type: ignore
                inv_batch = inverter(inferred_dict)

            # append
            outputs.append(inv_batch[self.orig_key])

        # output
        output: np.ndarray = np.concatenate(outputs)

        if self.return_full_data:
            return output

        # calculate metrics
        mode = np.array(
            torch.mode(torch.Tensor(output.astype(np.int64)), dim=0).values)
        mean: np.ndarray = np.mean(output, axis=0)  # type: ignore
        std: np.ndarray = np.std(output, axis=0)  # type: ignore
        vvc: float = (np.std(output) / np.mean(output)).item()
        return mode, mean, std, vvc
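
# A small, self-contained sketch (not part of the original source) illustrating how the
# returned statistics are derived from the stacked realisations. `output` stands in for the
# concatenated inverse-transformed inferences with shape (num_examples, *spatial_shape);
# its shape and values here are arbitrary.
import numpy as np
import torch

output = np.random.rand(10, 4, 4).astype(np.float32)

mode = np.array(torch.mode(torch.Tensor(output.astype(np.int64)), dim=0).values)  # voxel-wise mode
mean = np.mean(output, axis=0)                    # voxel-wise mean across the realisations
std = np.std(output, axis=0)                      # voxel-wise standard deviation
vvc = (np.std(output) / np.mean(output)).item()   # volume variation coefficient: std/mean over the whole output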