def convert_inverse_interp_mode(trans_info: List, mode: str = "nearest", align_corners: Optional[bool] = None):
    """
    Change the interpolation mode when inverting spatial transforms, default to "nearest".
    This function modifies trans_info's `InverseKeys.EXTRA_INFO` in place and returns `trans_info`.

    See also: :py:class:`monai.transform.inverse.InvertibleTransform`

    Args:
        trans_info: transforms inverse information list, contains context of every invertible transform.
        mode: target interpolation mode to convert, default to "nearest".
        align_corners: target align corner value in PyTorch interpolation API, need to align with the `mode`.
            `None` is stored as the string "none".

    Returns:
        The same `trans_info` object that was passed in, with the "mode" and "align_corners"
        entries of every `InverseKeys.EXTRA_INFO` dict rewritten where present.

    """
    interp_modes = [i.value for i in InterpolateMode] + [i.value for i in GridSampleMode]

    # set to string for DataLoader collation
    align_corners_ = "none" if align_corners is None else align_corners

    for item in ensure_tuple(trans_info):
        if InverseKeys.EXTRA_INFO not in item:
            continue
        extra_info = item[InverseKeys.EXTRA_INFO]

        orig_mode = extra_info.get("mode", None)
        if orig_mode is not None:
            if orig_mode[0] in interp_modes:
                # `orig_mode` is a sequence of modes: keep one entry per original entry.
                # (Sizing by `len(mode)` would use the character count of the mode string.)
                extra_info["mode"] = [mode] * len(orig_mode)
            elif orig_mode in interp_modes:
                extra_info["mode"] = mode

        if "align_corners" in extra_info:
            orig_align_corners = extra_info["align_corners"]
            if issequenceiterable(orig_align_corners):
                # preserve the length of the sequence being replaced
                extra_info["align_corners"] = [align_corners_] * len(orig_align_corners)
            else:
                extra_info["align_corners"] = align_corners_
    return trans_info
def allow_missing_keys_mode(transform: Union[MapTransform, Compose, Tuple[MapTransform], Tuple[Compose]]):
    """Temporarily make every contained MapTransform tolerate missing keys.

    While the body of the ``with`` statement runs, `allow_missing_keys` is `True` on all
    affected transforms; the previous values are restored afterwards, even on error.

    Args:
        transform: either MapTransform or a Compose, or a sequence of them
            (a sequence is wrapped into a single Compose).

    Raises:
        TypeError: when no MapTransform can be found in `transform`.

    Example:

    .. code-block:: python

        data = {"image": np.arange(16, dtype=float).reshape(1, 4, 4)}
        t = SpatialPadd(["image", "label"], 10, allow_missing_keys=False)
        _ = t(data)  # would raise exception
        with allow_missing_keys_mode(t):
            _ = t(data)  # OK!
    """
    # A sequence of transforms is wrapped so that only one container has to be handled
    if issequenceiterable(transform):
        transform = Compose(transform)

    # Collect the MapTransforms whose flag will be toggled
    if isinstance(transform, MapTransform):
        map_transforms = [transform]
    elif isinstance(transform, Compose):
        # Only the MapTransforms inside the flattened Compose are relevant
        map_transforms = [member for member in transform.flatten().transforms if isinstance(member, MapTransform)]
    else:
        map_transforms = []

    if not map_transforms:
        raise TypeError(
            "allow_missing_keys_mode expects either MapTransform(s) or Compose(s) containing MapTransform(s)"
        )

    # Remember each current `allow_missing_keys` value so it can be restored
    saved_flags = [member.allow_missing_keys for member in map_transforms]

    try:
        for member in map_transforms:
            member.allow_missing_keys = True
        yield
    finally:
        # Restore the original values
        for member, saved in zip(map_transforms, saved_flags):
            member.allow_missing_keys = saved
def _detect_batch_size(batch_data: Sequence): """ Detect the batch size from a list of data, some items in the list have batch dim, some not. """ for v in batch_data: if isinstance(v, torch.Tensor) and v.ndim > 0: return v.shape[0] for v in batch_data: if issequenceiterable(v): warnings.warn( "batch_data doesn't contain batched Tensor data, use the length of first sequence data." ) return len(v) raise RuntimeError("failed to automatically detect the batch size.")
def __call__(self, engine: Engine) -> None:
    """
    Collect filenames and outputs of the current iteration.

    This method assumes self.batch_transform will extract metadata from the input batch.
    The `Key.FILENAME_OR_OBJ` entry of that metadata is added to `self._filenames` when it
    is a sequence; the result of `self.output_transform` is appended to `self._outputs`
    unless it is `None` (tensors are detached first).

    Args:
        engine: Ignite Engine, it can be a trainer, validator or evaluator.

    """
    batch_meta = self.batch_transform(engine.state.batch)
    names = batch_meta.get(Key.FILENAME_OR_OBJ)
    if issequenceiterable(names):
        self._filenames.extend(names)

    result = self.output_transform(engine.state.output)
    if result is None:
        return
    # detach tensors before storing them
    self._outputs.append(result.detach() if isinstance(result, torch.Tensor) else result)
def partition_dataset_classes(
    data: Sequence,
    classes: Sequence[int],
    ratios: Optional[Sequence[float]] = None,
    num_partitions: Optional[int] = None,
    shuffle: bool = False,
    seed: int = 0,
    drop_last: bool = False,
    even_divisible: bool = False,
):
    """
    Split the dataset into N partitions based on the given class labels.
    The indices of every class are partitioned separately with the same settings and the
    per-class results are merged partition by partition, so every partition can keep the
    same ratio of classes. Others are same as :py:class:`monai.data.partition_dataset`.

    Args:
        data: input dataset to split, expect a list of data.
        classes: a list of labels to help split the data, the length must match the length of data.
        ratios: a list of ratio number to split the dataset, like [8, 1, 1].
        num_partitions: expected number of the partitions to evenly split, only works when no `ratios`.
        shuffle: whether to shuffle the original dataset before splitting.
        seed: random seed to shuffle the dataset, only works when `shuffle` is True.
        drop_last: only works when `even_divisible` is False and no ratios specified.
            if True, will drop the tail of the data to make it evenly divisible across partitions.
            if False, will add extra indices to make the data evenly divisible across partitions.
        even_divisible: if True, guarantee every partition has same length.

    Raises:
        ValueError: when `classes` is not a sequence or its length differs from `len(data)`.

    Examples::

        >>> data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]
        >>> classes = [2, 0, 2, 1, 3, 2, 2, 0, 2, 0, 3, 3, 1, 3]
        >>> partition_dataset_classes(data, classes, shuffle=False, ratios=[2, 1])
        [[2, 8, 4, 1, 3, 6, 5, 11, 12], [10, 13, 7, 9, 14]]

    """
    if not (issequenceiterable(classes) and len(classes) == len(data)):
        raise ValueError(f"length of classes {classes} must match the dataset length {len(data)}.")

    # group the positions of the data items by their class label
    positions_by_class = defaultdict(list)
    for position, label in enumerate(classes):
        positions_by_class[label].append(position)

    # partition every class on its own, then merge the results partition by partition
    merged_partitions: List[Sequence] = []
    for _, class_positions in sorted(positions_by_class.items()):
        class_partitions = partition_dataset(
            data=class_positions,
            ratios=ratios,
            num_partitions=num_partitions,
            shuffle=shuffle,
            seed=seed,
            drop_last=drop_last,
            even_divisible=even_divisible,
        )
        if not merged_partitions:
            merged_partitions = class_partitions
            continue
        for merged, addition in zip(merged_partitions, class_partitions):
            merged += addition

    # mix the classes within each partition when shuffling was requested
    rng = np.random.RandomState(seed)
    if shuffle:
        for merged in merged_partitions:
            rng.shuffle(merged)

    # map the collected indices back to the data items
    return [[data[k] for k in merged] for merged in merged_partitions]
def _get_filenames(self, engine: Engine) -> None:
    """
    Record the filenames of the current batch, but only when `self.metric_details` is set.

    The `Key.FILENAME_OR_OBJ` entry returned by `self.batch_transform` for the current
    batch is added to `self._filenames` when it is a sequence.

    Args:
        engine: Ignite Engine whose `state.batch` is inspected.

    """
    if self.metric_details is None:
        return
    batch_meta = self.batch_transform(engine.state.batch)
    names = batch_meta.get(Key.FILENAME_OR_OBJ)
    if issequenceiterable(names):
        self._filenames.extend(names)