def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]:
    """
    Invert the previously applied transforms for every configured key.

    For each tuple produced by ``self.key_iterator`` the transform record stored under
    ``f"{orig_key}{InverseKeys.KEY_SUFFIX}"`` is handed, together with ``d[key]``, to
    ``self.transform.inverse``. The inverted value (optionally converted with
    ``self._totensor`` and moved to ``device``, then passed through ``post_func``)
    replaces ``d[key]``. When the original meta dict is present it is inverted too and
    stored under ``meta_key``.

    Args:
        data: input mapping; it is shallow-copied (``dict(data)``) before being updated.

    Returns:
        The updated copy. Keys whose transform record is missing are left untouched and
        a warning is emitted for them.
    """
    d = dict(data)
    per_key_settings = self.key_iterator(
        d,
        self.orig_keys,
        self.meta_keys,
        self.orig_meta_keys,
        self.meta_key_postfix,
        self.nearest_interp,
        self.to_tensor,
        self.device,
        self.post_func,
    )
    # NOTE: ``per_key_settings`` is consumed lazily while ``d`` is being updated below.
    for key, orig_key, meta_key, orig_meta_key, meta_key_postfix, nearest_interp, to_tensor, device, post_func in per_key_settings:
        transform_key = f"{orig_key}{InverseKeys.KEY_SUFFIX}"
        if transform_key not in d:
            warnings.warn(f"transform info of `{orig_key}` is not available or no InvertibleTransform applied.")
            continue

        applied = d[transform_key]
        if nearest_interp:
            # the record is deep-copied first, presumably so the one kept in ``d`` is not modified
            applied = convert_inverse_interp_mode(trans_info=deepcopy(applied), mode="nearest", align_corners=None)

        img = d[key]
        if isinstance(img, torch.Tensor):
            img = img.detach()

        # fall back to "<key>_<postfix>" names when no explicit meta keys were configured
        if not orig_meta_key:
            orig_meta_key = f"{orig_key}_{meta_key_postfix}"
        if not meta_key:
            meta_key = f"{key}_{meta_key_postfix}"

        # input for the inverse: the data, its transform record and, if present, the original meta dict
        inverse_input = {orig_key: img, transform_key: applied}
        if orig_meta_key in d:
            inverse_input[orig_meta_key] = d[orig_meta_key]

        # missing keys are allowed because ``inverse_input`` only carries the entries above
        with allow_missing_keys_mode(self.transform):  # type: ignore
            inverted = self.transform.inverse(inverse_input)

        restored = inverted[orig_key]
        if to_tensor:
            restored = self._totensor(restored).to(device)
        d[key] = post_func(restored)

        # store the inverted meta dict alongside the inverted data
        if orig_meta_key in d:
            d[meta_key] = inverted.get(orig_meta_key)

    return d
def __call__(
    self, data: Dict[str, Any], num_examples: int = 10
) -> Union[Tuple[np.ndarray, np.ndarray, np.ndarray, float], np.ndarray]:
    """
    Run test-time augmentation on a single data dictionary.

    ``num_examples`` deep copies of ``data`` are passed through ``self.transform``,
    batched, fed to ``self.inferrer_fn``, and the predictions are mapped back through
    the inverse of the applied transforms (``BatchInverseTransform``) before being
    combined.

    Args:
        data: dictionary data to be processed.
        num_examples: number of realisations to be processed and results combined.
            Must be a positive multiple of ``self.batch_size``.

    Returns:
        - if `return_full_data==False`: mode, mean, std, vvc. The mode, mean and standard deviation are
          calculated across `num_examples` outputs at each voxel. The volume variation coefficient (VVC)
          is `std/mean` across the whole output, including `num_examples`. See original paper for clarification.
        - if `return_full_data==True`: data is returned as-is after applying the `inferrer_fn` and then
          concatenating across the first dimension containing `num_examples`. This allows the user to perform
          their own analysis if desired.

    Raises:
        ValueError: if ``num_examples`` is not positive, or is not a multiple of ``self.batch_size``.
    """
    d = dict(data)

    # With no realisations there is nothing to concatenate; without this guard the
    # failure would only surface later inside ``np.concatenate``.
    if num_examples < 1:
        raise ValueError("num_examples should be a positive integer.")
    # check num examples is multiple of batch size
    if num_examples % self.batch_size != 0:
        raise ValueError("num_examples should be multiple of batch size.")

    # generate batch of data of size == batch_size, dataset and dataloader
    data_in = [deepcopy(d) for _ in range(num_examples)]
    ds = Dataset(data_in, self.transform)
    # ``num_workers`` is passed by keyword: in ``torch.utils.data.DataLoader`` the second
    # positional parameter is ``batch_size``, so a positional value would collide with it.
    dl = DataLoader(ds, num_workers=self.num_workers, batch_size=self.batch_size, collate_fn=pad_list_data_collate)

    transform_key = self.orig_key + InverseKeys.KEY_SUFFIX
    # the meta dict key is the same for every batch, so resolve it once
    meta_dict_key = self.orig_meta_keys or f"{self.orig_key}_{self.meta_key_postfix}"

    # create inverter
    inverter = BatchInverseTransform(self.transform, dl, collate_fn=list_data_collate)

    outputs: List[np.ndarray] = []

    for batch_data in tqdm(dl) if has_tqdm and self.progress else dl:
        batch_images = batch_data[self.image_key].to(self.device)

        # do model forward pass; tensors are detached and moved to CPU,
        # numpy arrays are wrapped with ``torch.Tensor`` (default float dtype)
        batch_output = self.inferrer_fn(batch_images)
        if isinstance(batch_output, torch.Tensor):
            batch_output = batch_output.detach().cpu()
        if isinstance(batch_output, np.ndarray):
            batch_output = torch.Tensor(batch_output)

        transform_info = batch_data[transform_key]
        if self.nearest_interp:
            transform_info = convert_inverse_interp_mode(
                trans_info=deepcopy(transform_info), mode="nearest", align_corners=None
            )

        # create a dictionary containing the inferred batch and their transforms
        inferred_dict = {self.orig_key: batch_output, transform_key: transform_info}
        # if meta dict is present, add that too (required for some inverse transforms)
        if meta_dict_key in batch_data:
            inferred_dict[meta_dict_key] = batch_data[meta_dict_key]

        # do inverse transformation (allow missing keys as only inverting the orig_key)
        with allow_missing_keys_mode(self.transform):  # type: ignore
            inv_batch = inverter(inferred_dict)

        outputs.append(inv_batch[self.orig_key])

    # join all realisations along the first axis
    output: np.ndarray = np.concatenate(outputs)

    if self.return_full_data:
        return output

    # calculate metrics
    mode = np.array(torch.mode(torch.Tensor(output.astype(np.int64)), dim=0).values)
    mean: np.ndarray = np.mean(output, axis=0)  # type: ignore
    std: np.ndarray = np.std(output, axis=0)  # type: ignore
    vvc: float = (np.std(output) / np.mean(output)).item()
    return mode, mean, std, vvc