Esempio n. 1
0
    def __init__(
        self,
        keys: KeysCollection,
        range_x=0.0,
        range_y=0.0,
        range_z=0.0,
        prob: float = 0.1,
        keep_size: bool = True,
        interp_order: str = "bilinear",
        mode: str = "border",
        align_corners: bool = False,
    ):
        """
        Args:
            keys: keys of the corresponding items to be transformed.
            range_x: range stored as ``self.range_x``. A single value ``v`` becomes the sorted
                pair ``(-v, v)``; any other length is kept as the tuple produced by ``ensure_tuple``.
            range_y: same convention as ``range_x``, stored as ``self.range_y``.
            range_z: same convention as ``range_x``, stored as ``self.range_z``.
            prob: stored as ``self.prob``.
            keep_size: stored as ``self.keep_size``.
            interp_order: repeated once per key with ``ensure_tuple_rep``.
            mode: repeated once per key with ``ensure_tuple_rep``.
            align_corners: stored as ``self.align_corners``.
        """
        super().__init__(keys)

        def _symmetric_range(value):
            # one bound -> interval mirrored around zero; sorting handles negative inputs
            bounds = ensure_tuple(value)
            if len(bounds) != 1:
                return bounds
            return tuple(sorted([-bounds[0], bounds[0]]))

        self.range_x = _symmetric_range(range_x)
        self.range_y = _symmetric_range(range_y)
        self.range_z = _symmetric_range(range_z)

        self.prob = prob
        self.keep_size = keep_size
        key_count = len(self.keys)
        self.interp_order = ensure_tuple_rep(interp_order, key_count)
        self.mode = ensure_tuple_rep(mode, key_count)
        self.align_corners = align_corners

        # per-call random state, initialised to "no transform"
        self._do_transform = False
        self.x, self.y, self.z = 0.0, 0.0, 0.0
Esempio n. 2
0
def select_labels(labels: Union[Sequence[NdarrayOrTensor], NdarrayOrTensor],
                  keep: NdarrayOrTensor) -> Union[Tuple, NdarrayOrTensor]:
    """
    Select the entries indexed by ``keep`` from every element of ``labels``.

    Args:
        labels: Sequence of array. Each element represents classification labels or scores
            corresponding to ``boxes``, sized (N,). A single array/tensor is also accepted.
        keep: the indices to keep, same length with each element in labels.

    Return:
        selected labels, does not share memory with original labels. A single array/tensor
        input yields a single result; a sequence input yields a tuple. Each result has the
        same container type as the corresponding input element.
    """
    # wrap_array=True so a lone array/tensor is treated as one element, not iterated
    label_items = ensure_tuple(labels, True)
    keep_index: torch.Tensor = convert_data_type(keep, torch.Tensor)[0]

    def _select(item):
        as_tensor: torch.Tensor = convert_data_type(item, torch.Tensor)[0]
        return convert_to_dst_type(src=as_tensor[keep_index, ...], dst=item)[0]

    selected = [_select(item) for item in label_items]

    if not isinstance(labels, (torch.Tensor, np.ndarray)):
        return tuple(selected)
    return selected[0]  # type: ignore
Esempio n. 3
0
    def __call__(self, filename):
        """
        Read one or more images with PIL and convert them to numpy arrays.

        Args:
            filename (str, list, tuple, file): path file or file-like object or a list of files.

        Returns:
            the image array when ``self.image_only`` is True, otherwise a tuple of the image
            array and the metadata dict built for the first file. When more than one file is
            given, the arrays are stacked along a new axis 0.

        Raises:
            AssertionError: when ``self.image_only`` is False and the images do not all share
                the same spatial shape.
        """
        filename = ensure_tuple(filename)
        img_array = list()
        compatible_meta = None
        for name in filename:
            # Use the image as a context manager so any file handle PIL opened is released as
            # soon as the pixels have been copied into ``data``; without it the handle lives
            # until garbage collection (notably for multi-frame formats). Per Pillow's
            # file-handling docs, caller-supplied file objects are not closed by this.
            with Image.open(name) as img:
                data = np.asarray(img)
                if self.dtype:
                    data = data.astype(self.dtype)
                img_array.append(data)
                meta = dict()
                meta['filename_or_obj'] = name
                meta['spatial_shape'] = data.shape[:2]
                meta['format'] = img.format
                meta['mode'] = img.mode
                meta['width'] = img.width
                meta['height'] = img.height
                meta['info'] = img.info

            if self.image_only:
                continue

            if not compatible_meta:
                compatible_meta = meta
            else:
                assert np.allclose(meta['spatial_shape'], compatible_meta['spatial_shape']), \
                    'all the images in the list should have same spatial shape.'

        img_array = np.stack(img_array, axis=0) if len(img_array) > 1 else img_array[0]
        return img_array if self.image_only else (img_array, compatible_meta)
Esempio n. 4
0
def flip_boxes(
    boxes: NdarrayOrTensor,
    spatial_size: Union[Sequence[int], int],
    flip_axes: Optional[Union[Sequence[int], int]] = None,
) -> NdarrayOrTensor:
    """
    Mirror bounding boxes to match an image that was flipped.

    Args:
        boxes: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``
        spatial_size: image spatial size; a single int is repeated for every spatial dim.
        flip_axes: spatial axes along which to flip over. Default is None.
            The default `axis=None` will flip over all of the axes of the input array.
            If axis is negative it counts from the last to the first axis.
            If axis is a tuple of ints, flipping is performed on all of the axes
            specified in the tuple.

    Returns:
        flipped boxes, with same data type as ``boxes``, does not share memory with ``boxes``
    """
    n_dims: int = get_spatial_dims(boxes=boxes)
    sizes = ensure_tuple_rep(spatial_size, n_dims)
    axes = ensure_tuple(tuple(range(0, n_dims)) if flip_axes is None else flip_axes)

    flipped = deepcopy(boxes)
    for ax in axes:
        low_col, high_col = ax, ax + n_dims
        # new min corner mirrors the old max corner and vice versa; always read the
        # untouched ``boxes`` so the two assignments cannot affect each other
        flipped[:, high_col] = sizes[ax] - boxes[:, low_col] - TO_REMOVE
        flipped[:, low_col] = sizes[ax] - boxes[:, high_col] - TO_REMOVE

    return flipped
Esempio n. 5
0
def create_shear(spatial_dims, coefs):
    """
    create a shearing matrix

    Args:
        spatial_dims (int): spatial rank, only 2 and 3 are supported.
        coefs (floats): shearing factors, defaults to 0. Two factors are read in 2D and six
            in 3D; missing ones are filled with 0.0 and extra ones are ignored.

    Returns:
        numpy array of shape (spatial_dims + 1, spatial_dims + 1) in homogeneous coordinates.

    Raises:
        NotImplementedError: when ``spatial_dims`` is neither 2 nor 3.
    """
    factors = list(ensure_tuple(coefs))
    if spatial_dims == 2:
        c = factors + [0.0] * (2 - len(factors))
        return np.array([
            [1, c[0], 0.],
            [c[1], 1., 0.],
            [0., 0., 1.],
        ])
    if spatial_dims == 3:
        c = factors + [0.0] * (6 - len(factors))
        return np.array([
            [1., c[0], c[1], 0.],
            [c[2], 1., c[3], 0.],
            [c[4], c[5], 1., 0.],
            [0., 0., 0., 1.],
        ])
    raise NotImplementedError
Esempio n. 6
0
 def __init__(self,
              keys,
              affine_key,
              pixdim,
              interp_order=2,
              keep_shape=False,
              output_key='spacing'):
     """
     Args:
         keys (hashable items): keys of the corresponding items to be transformed.
         affine_key (hashable): the key to the original affine.
             The affine will be used to compute input data's pixdim.
         pixdim (sequence of floats): output voxel spacing.
         interp_order (int or sequence of ints): int: the same interpolation order
             for all data indexed by `self.keys`; sequence of ints, should
             correspond to an interpolation order for each data item indexed
             by `self.keys` respectively.
         keep_shape (bool): whether to maintain the original spatial shape
             after resampling. Defaults to False.
         output_key (hashable): key to be added to the output dictionary to track
             the pixdim status.

     Raises:
         ValueError: when ``interp_order`` contains neither exactly one value nor exactly
             one value per key.
     """
     MapTransform.__init__(self, keys)
     self.affine_key = affine_key
     self.spacing_transform = Spacing(pixdim, keep_shape=keep_shape)
     interp_order = ensure_tuple(interp_order)
     if len(interp_order) == 1:
         # a single order is shared by every key
         interp_order = interp_order * len(self.keys)
     elif len(interp_order) != len(self.keys):
         # tiling a mismatched sequence (e.g. 2 orders for 3 keys -> 6 entries) would
         # silently pair keys with the wrong orders, so reject it explicitly
         raise ValueError(
             f"interp_order must be a single value or have one value per key "
             f"(expected {len(self.keys)}), got {len(interp_order)}.")
     self.interp_order = interp_order
     self.output_key = output_key
Esempio n. 7
0
    def __init__(self,
                 pixdim,
                 diagonal=False,
                 mode='constant',
                 cval=0,
                 dtype=None):
        """
        Args:
            pixdim (sequence of floats): output voxel spacing; stored as a float64 numpy array.
            diagonal (bool): whether to resample the input to have a diagonal affine matrix.
                If True, the input data is resampled to the following affine::

                    np.diag((pixdim_0, pixdim_1, ..., pixdim_n, 1))

                This effectively resets the volume to the world coordinate system (RAS+ in nibabel).
                The original orientation, rotation, shearing are not preserved.

                If False, this transform preserves the axes orientation, orthogonal rotation and
                translation components from the original affine. This option will not flip/swap axes
                of the original data.
            mode (`reflect|constant|nearest|mirror|wrap`):
                The mode parameter determines how the input array is extended beyond its boundaries.
            cval (scalar): Value to fill past edges of input if mode is "constant". Default is 0.0.
            dtype (None or np.dtype): output array data type, defaults to None to use input data's dtype.
        """
        spacing = ensure_tuple(pixdim)
        self.pixdim = np.array(spacing, dtype=np.float64)
        self.diagonal, self.mode, self.cval, self.dtype = diagonal, mode, cval, dtype
Esempio n. 8
0
 def randomize(self, img_size):
     # a scalar roi_size is expanded to one entry per spatial dim; lists/tuples are used as given
     n_dims = len(img_size)
     if isinstance(self.roi_size, (list, tuple)):
         self._size = self.roi_size
     else:
         self._size = [self.roi_size] * n_dims
     if self.random_size:
         # each side is drawn with self.R.randint between the ROI side and the image side
         # (the +1 assumes numpy-style exclusive `high` -- TODO confirm R's type)
         self._size = [
             self.R.randint(low=self._size[axis], high=extent + 1)
             for axis, extent in enumerate(img_size)
         ]
     if self.random_center:
         valid_size = get_valid_patch_size(img_size, self._size)
         # leading full slice leaves the first (presumably channel) axis uncropped
         self._slices = (slice(None),) + get_random_patch(img_size, valid_size, self.R)
Esempio n. 9
0
    def __init__(
        self, keys, pixdim, diagonal=False, mode="nearest", cval=0, interp_order=3, dtype=None, meta_key_format="{}.{}"
    ):
        """
        Args:
            keys (hashable items): keys of the corresponding items to be transformed.
            pixdim (sequence of floats): output voxel spacing.
            diagonal (bool): whether to resample the input to have a diagonal affine matrix.
                If True, the input data is resampled to the following affine::

                    np.diag((pixdim_0, pixdim_1, pixdim_2, 1))

                This effectively resets the volume to the world coordinate system (RAS+ in nibabel).
                The original orientation, rotation, shearing are not preserved.

                If False, the axes orientation, orthogonal rotation and
                translations components from the original affine will be
                preserved in the target affine. This option will not flip/swap
                axes against the original ones.
            mode (`reflect|constant|nearest|mirror|wrap`):
                The mode parameter determines how the input array is extended beyond its boundaries.
                Default is 'nearest'.
            cval (scalar): Value to fill past edges of input if mode is "constant". Default is 0.0.
            interp_order (int or sequence of ints): int: the same interpolation order
                for all data indexed by `self.keys`; sequence of ints, should
                correspond to an interpolation order for each data item indexed
                by `self.keys` respectively.
            dtype (None or np.dtype): output array data type, defaults to None to use input data's dtype.
            meta_key_format (str): key format to read/write affine matrices to the data dictionary.

        Raises:
            ValueError: when ``interp_order`` contains neither exactly one value nor exactly
                one value per key.
        """
        super().__init__(keys)
        self.spacing_transform = Spacing(pixdim, diagonal=diagonal, mode=mode, cval=cval, dtype=dtype)
        interp_order = ensure_tuple(interp_order)
        if len(interp_order) == 1:
            # a single order is shared by every key
            interp_order = interp_order * len(self.keys)
        elif len(interp_order) != len(self.keys):
            # tiling a mismatched sequence (e.g. 2 orders for 3 keys -> 6 entries) would
            # silently pair keys with the wrong orders, so reject it explicitly
            raise ValueError(
                f"interp_order must be a single value or have one value per key "
                f"(expected {len(self.keys)}), got {len(interp_order)}."
            )
        self.interp_order = interp_order
        self.meta_key_format = meta_key_format
Esempio n. 10
0
 def __init__(self, keys):
     """
     Normalise ``keys`` to a tuple and validate it.

     Raises:
         ValueError: when no key is given, or when any key is not hashable.
     """
     self.keys = ensure_tuple(keys)
     if len(self.keys) == 0:
         raise ValueError("keys unspecified")
     for candidate in self.keys:
         if isinstance(candidate, Hashable):
             continue
         raise ValueError(f"keys should be a hashable or a sequence of hashables, got {type(candidate)}")
Esempio n. 11
0
def generate_spatial_bounding_box(img,
                                  select_fn=lambda x: x > 0,
                                  channel_indexes=None,
                                  margin=0):
    """
    Compute start/end positions of the foreground bounding box over the spatial dims of ``img``.

    Axis 0 of ``img`` is collapsed with ``np.any``, so a foreground voxel in any selected
    channel counts. Each returned end is exclusive and both ends are clipped to the image.

    Args:
        img (ndarray): source image to generate bounding box from; axis 0 is reduced.
        select_fn (Callable): function to select expected foreground, default is to select values > 0.
        channel_indexes (int, tuple or list): if defined, select foreground only on the specified channels
            of image. if None, select foreground on the whole image.
        margin (int): add margin to all dims of the bounding box.

    Returns:
        (box_start, box_end): two lists with one entry per spatial dim.

    Raises:
        AssertionError: if ``margin`` is not an int or some spatial dim has no foreground.
    """
    assert isinstance(margin, int), "margin must be int type."
    if channel_indexes is None:
        selected = img
    else:
        selected = img[list(ensure_tuple(channel_indexes))]
    foreground = np.any(select_fn(selected), axis=0)
    coords = np.nonzero(foreground)

    starts, ends = [], []
    for dim, (idx, extent) in enumerate(zip(coords, foreground.shape)):
        assert len(idx) > 0, f"did not find nonzero index at spatial dim {dim}"
        starts.append(max(0, np.min(idx) - margin))
        ends.append(min(extent, np.max(idx) + margin + 1))
    return starts, ends
Esempio n. 12
0
 def randomize(self, img_size):
     # broadcast roi_size to one entry per spatial dim
     self._size = ensure_tuple_rep(self.roi_size, len(img_size))
     if self.random_size:
         # each side is drawn with self.R.randint between the ROI side and the image side
         # (the +1 assumes numpy-style exclusive `high` -- TODO confirm R's type)
         self._size = [
             self.R.randint(low=self._size[axis], high=extent + 1)
             for axis, extent in enumerate(img_size)
         ]
     if self.random_center:
         valid_size = get_valid_patch_size(img_size, self._size)
         # leading full slice leaves the first (presumably channel) axis uncropped
         self._slices = (slice(None),) + get_random_patch(img_size, valid_size, self.R)
Esempio n. 13
0
    def __init__(
        self,
        rotate_range=None,
        shear_range=None,
        translate_range=None,
        scale_range=None,
        as_tensor_output=True,
        device=None,
    ):
        """
        Args:
            rotate_range (a sequence of positive floats): rotate_range[0] will be used to generate the 1st rotation
                parameter from `uniform[-rotate_range[0], rotate_range[0])`. Similarly, `rotate_range[1]` and
                `rotate_range[2]` are used in 3D affine for the range of 2nd and 3rd axes.
            shear_range (a sequence of positive floats): shear_range[0] will be used to generate the 1st shearing
                parameter from `uniform[-shear_range[0], shear_range[0])`. Similarly, `shear_range[1]` to
                `shear_range[N]` controls the range of the uniform distribution used to generate the 2nd to
                N-th parameter.
            translate_range (a sequence of positive floats): translate_range[0] will be used to generate the 1st
                shift parameter from `uniform[-translate_range[0], translate_range[0])`. Similarly, `translate_range[1]`
                to `translate_range[N]` controls the range of the uniform distribution used to generate
                the 2nd to N-th parameter.
            scale_range (a sequence of positive floats): scale_range[0] will be used to generate the 1st scaling
                factor from `uniform[-scale_range[0], scale_range[0]) + 1.0`. Similarly, `scale_range[1]` to
                `scale_range[N]` controls the range of the uniform distribution used to generate the 2nd to
                N-th parameter.
            as_tensor_output: stored as ``self.as_tensor_output`` (presumably whether the generated grid is
                returned as a torch tensor -- confirm against the ``__call__`` implementation).
            device: stored as ``self.device`` (presumably the torch device for the output -- confirm).

        All four ranges are normalised with ``ensure_tuple``; the sampled parameters start as ``None``.

        See also:
            - :py:meth:`monai.transforms.utils.create_rotate`
            - :py:meth:`monai.transforms.utils.create_shear`
            - :py:meth:`monai.transforms.utils.create_translate`
            - :py:meth:`monai.transforms.utils.create_scale`
        """
        self.rotate_range = ensure_tuple(rotate_range)
        self.shear_range = ensure_tuple(shear_range)
        self.translate_range = ensure_tuple(translate_range)
        self.scale_range = ensure_tuple(scale_range)

        # filled in by the randomisation step; None until then
        self.rotate_params = None
        self.shear_params = None
        self.translate_params = None
        self.scale_params = None

        self.as_tensor_output = as_tensor_output
        self.device = device
Esempio n. 14
0
 def __init__(self,
              spatial_size,
              method: str = "symmetric",
              mode: str = "constant"):
     """
     Args:
         spatial_size: target spatial size, normalised with ``ensure_tuple``.
         method: either ``"symmetric"`` or ``"end"``.
         mode: padding mode name; must be a ``str``.

     Raises:
         AssertionError: for an unknown ``method`` or a non-string ``mode``.
     """
     self.spatial_size = ensure_tuple(spatial_size)
     supported_methods = ("symmetric", "end")
     assert method in supported_methods, "unsupported padding type."
     assert isinstance(mode, str), "mode must be str."
     self.method, self.mode = method, mode
Esempio n. 15
0
    def test_value(self, input, expected_value, wrap_array=False):
        # ensure_tuple must always hand back a tuple
        result = ensure_tuple(input, wrap_array)
        self.assertTrue(isinstance(result, tuple))

        if not isinstance(input, (np.ndarray, torch.Tensor)):
            self.assertTupleEqual(result, expected_value)
            return

        # array-like inputs are compared element by element with a tolerance
        for actual, expected in zip(result, expected_value):
            assert_allclose(actual, expected)
Esempio n. 16
0
def rot90_boxes(boxes: NdarrayOrTensor,
                spatial_size: Union[Sequence[int], int],
                k: int = 1,
                axes: Tuple[int, int] = (0, 1)):
    """
    Rotate boxes by 90 degrees in the plane specified by axes.
    Rotation direction is from the first towards the second axis.

    Args:
        boxes: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``
        spatial_size: image spatial size.
        k: number of times the array is rotated by 90 degrees; only ``k % 4`` matters.
        axes: (2,) array_like
            The array is rotated in the plane defined by the axes. Axes must be different
            and lie in ``[-spatial_dims, spatial_dims)``.

    Returns:
        the rotated boxes. When ``k % 4 == 0`` the input ``boxes`` object itself is returned;
        otherwise the result is produced by ``flip_boxes`` / ``swapaxes_boxes``.

    Raises:
        ValueError: if ``axes`` does not have two entries, names the same axis twice,
            or is out of range.

    Notes:
        ``rot90_boxes(boxes, spatial_size, k=1, axes=(1,0))``  is the reverse of
        ``rot90_boxes(boxes, spatial_size, k=1, axes=(0,1))``
        ``rot90_boxes(boxes, spatial_size, k=1, axes=(1,0))`` is equivalent to
        ``rot90_boxes(boxes, spatial_size, k=-1, axes=(0,1))``
    """
    n_dims: int = get_spatial_dims(boxes=boxes)
    sizes = list(ensure_tuple_rep(spatial_size, n_dims))

    axes = ensure_tuple(axes)  # type: ignore
    if len(axes) != 2:
        raise ValueError("len(axes) must be 2.")

    first, second = axes
    # a difference of n_dims means e.g. 0 and -n_dims, i.e. the same axis
    if first == second or abs(first - second) == n_dims:
        raise ValueError("Axes must be different.")
    if not (-n_dims <= first < n_dims and -n_dims <= second < n_dims):
        raise ValueError(
            f"Axes={axes} out of range for array of ndim={n_dims}.")

    turns = k % 4
    if turns == 0:
        return boxes
    if turns == 2:
        return flip_boxes(flip_boxes(boxes, sizes, first), sizes, second)
    if turns == 1:
        # mirrors numpy.rot90: flip along the second axis, then transpose
        return swapaxes_boxes(flip_boxes(boxes, sizes, second), first, second)

    # turns == 3: transpose first; the image extent along the two axes swaps with it
    swapped = swapaxes_boxes(boxes, first, second)
    sizes[first], sizes[second] = sizes[second], sizes[first]
    return flip_boxes(swapped, sizes, second)
Esempio n. 17
0
 def __init__(self, select_fn=lambda x: x > 0, channel_indexes=None, margin=0):
     """
     Args:
         select_fn (Callable): function to select expected foreground, default is to select values > 0.
         channel_indexes (int, tuple or list): if defined, select foreground only on the specified channels
             of image. if None, select foreground on the whole image.
         margin (int): add margin to all dims of the bounding box.
     """
     self.select_fn = select_fn
     self.margin = margin
     if channel_indexes is None:
         self.channel_indexes = None
     else:
         # a single index becomes a 1-tuple
         self.channel_indexes = ensure_tuple(channel_indexes)
Esempio n. 18
0
 def __init__(self, data, transform=None):
     """
     Args:
         data (Iterable): input data to load and transform to generate dataset for model.
         transform (Callable, optional): transforms to execute operations on input data.
             A ``Compose`` instance is used as-is; anything else is wrapped as
             ``Compose(ensure_tuple(transform))``.
     """
     self.data = data
     self.transform = transform if isinstance(transform, Compose) else Compose(ensure_tuple(transform))
Esempio n. 19
0
def create_scale(spatial_dims, scaling_factor):
    """
    create a scaling matrix

    Args:
        spatial_dims (int): spatial rank
        scaling_factor (floats): scaling factors, defaults to 1. Missing factors are
            filled with 1.0 and factors beyond ``spatial_dims`` are ignored.

    Returns:
        a (spatial_dims + 1) x (spatial_dims + 1) diagonal numpy array whose last entry is 1.0.
    """
    diagonal = list(ensure_tuple(scaling_factor))
    diagonal.extend(1.0 for _ in range(len(diagonal), spatial_dims))
    return np.diag(diagonal[:spatial_dims] + [1.0])
Esempio n. 20
0
def create_translate(spatial_dims, shift):
    """
    create a translation matrix

    Args:
        spatial_dims (int): spatial rank
        shift (floats): translate factors, defaults to 0. Only the first ``spatial_dims``
            values are used; missing ones leave a zero offset.

    Returns:
        a (spatial_dims + 1) x (spatial_dims + 1) numpy array: identity plus the offsets
        in the last column.
    """
    offsets = ensure_tuple(shift)
    matrix = np.eye(spatial_dims + 1)
    for row, value in zip(range(spatial_dims), offsets):
        matrix[row, spatial_dims] = value
    return matrix
Esempio n. 21
0
 def test_numpy_values(self, keys, times, names):
     # two independent arrays with identical content
     input_data = {key: np.array([[0, 1], [1, 2]]) for key in ("img", "seg")}
     copier = CopyItemsd(keys=keys, times=times, names=names)
     result = copier(input_data)
     incremented = np.array([[1, 2], [2, 3]])
     for new_key in ensure_tuple(names):
         self.assertTrue(new_key in result)
         # modify each copy in place
         result[new_key] += 1
         np.testing.assert_allclose(result[new_key], incremented)
     # "img" must still hold its original values after the in-place edits above
     np.testing.assert_allclose(result["img"], np.array([[0, 1], [1, 2]]))
Esempio n. 22
0
    def __call__(self, filename):
        """
        Load one or more image files with nibabel.

        Args:
            filename (str, list, tuple, file): path file or file-like object or a list of files.

        Returns:
            the image array when ``self.image_only`` is True, otherwise a tuple of the image
            array and a metadata dict taken from the first file's header (numpy string/object
            arrays are skipped). When more than one file is given, the arrays are stacked
            along a new axis 0.

        Raises:
            AssertionError: when ``self.image_only`` is False and a later file's affine
                differs from the first one's.
        """
        filename = ensure_tuple(filename)
        img_array = list()
        compatible_meta = dict()
        for name in filename:
            img = correct_nifti_header_if_necessary(nib.load(name))
            header = dict(img.header)
            dims = img.header["dim"]
            # dims[0] holds the number of dimensions; at most 3 are treated as spatial
            spatial_rank = min(dims[0], 3)
            header.update({
                "filename_or_obj": name,
                "affine": img.affine,
                "original_affine": img.affine.copy(),
                "as_closest_canonical": self.as_closest_canonical,
                "spatial_shape": dims[1:spatial_rank + 1],
            })

            if self.as_closest_canonical:
                img = nib.as_closest_canonical(img)
                header["affine"] = img.affine

            img_array.append(np.array(img.get_fdata(dtype=self.dtype)))
            img.uncache()

            if self.image_only:
                continue

            if compatible_meta:
                assert np.allclose(
                    header["affine"], compatible_meta["affine"]
                ), "affine data of all images should be same."
                continue

            for meta_key, meta_datum in header.items():
                # pytype: disable=attribute-error
                is_str_array = (type(meta_datum).__name__ == "ndarray"
                                and np_str_obj_array_pattern.search(
                                    meta_datum.dtype.str) is not None)
                # pytype: enable=attribute-error
                if not is_str_array:
                    compatible_meta[meta_key] = meta_datum

        stacked = np.stack(img_array, axis=0) if len(img_array) > 1 else img_array[0]
        return stacked if self.image_only else (stacked, compatible_meta)
Esempio n. 23
0
def create_rotate(spatial_dims, radians):
    """
    create a 2D or 3D rotation matrix

    Args:
        spatial_dims (2|3): spatial rank
        radians (float or a sequence of floats): rotation radians
            when spatial_dims == 3, the `radians` sequence corresponds to
            rotation in the 1st, 2nd, and 3rd dim respectively.
            Missing angles mean "no rotation", so an empty sequence yields the identity
            (matching how create_shear/create_scale/create_translate pad missing factors).

    Returns:
        a (spatial_dims + 1) x (spatial_dims + 1) homogeneous rotation matrix.

    Raises:
        ValueError: when ``spatial_dims`` is neither 2 nor 3.
    """
    radians = ensure_tuple(radians)
    if spatial_dims == 2:
        if len(radians) < 1:
            # previously fell through to the ValueError below
            return np.eye(3)
        sin_, cos_ = np.sin(radians[0]), np.cos(radians[0])
        return np.array([[cos_, -sin_, 0.], [sin_, cos_, 0.], [0., 0., 1.]])

    if spatial_dims == 3:
        # start from the identity so an empty `radians` no longer returns None
        affine = np.eye(4)
        if len(radians) >= 1:
            sin_, cos_ = np.sin(radians[0]), np.cos(radians[0])
            affine = np.array([
                [1., 0., 0., 0.],
                [0., cos_, -sin_, 0.],
                [0., sin_, cos_, 0.],
                [0., 0., 0., 1.],
            ])
        if len(radians) >= 2:
            sin_, cos_ = np.sin(radians[1]), np.cos(radians[1])
            affine = affine @ np.array([
                [cos_, 0.0, sin_, 0.],
                [0., 1., 0., 0.],
                [-sin_, 0., cos_, 0.],
                [0., 0., 0., 1.],
            ])
        if len(radians) >= 3:
            sin_, cos_ = np.sin(radians[2]), np.cos(radians[2])
            affine = affine @ np.array([
                [cos_, -sin_, 0., 0.],
                [sin_, cos_, 0., 0.],
                [0., 0., 1., 0.],
                [0., 0., 0., 1.],
            ])
        return affine

    raise ValueError('create_rotate got spatial_dims={}, radians={}.'.format(
        spatial_dims, radians))
Esempio n. 24
0
    def __call__(self, img, mode: Optional[str] = None):
        """
        Pad every axis of ``img`` except the first by ``self.spatial_border``.

        ``self.spatial_border`` may hold one value (same border on both sides of every spatial
        dim), one value per spatial dim (symmetric per dim), or two values per spatial dim
        (before/after pairs).

        Args:
            img: numpy array; axis 0 (presumably channels) is never padded.
            mode: ``np.pad`` mode; falls back to ``self.mode`` when None or empty.

        Raises:
            ValueError: when a border value is not a non-negative int, or the number of
                border values is not 1, len(spatial_shape) or 2 * len(spatial_shape).
        """
        spatial_shape = img.shape[1:]
        spatial_border = ensure_tuple(self.spatial_border)
        for b in spatial_border:
            # check the type first: comparing a non-number with 0 would raise TypeError
            # before the intended ValueError could be reached
            if not isinstance(b, int) or b < 0:
                raise ValueError("spatial_border must be int number and can not be less than 0.")

        if len(spatial_border) == 1:
            data_pad_width = [(spatial_border[0], spatial_border[0]) for _ in range(len(spatial_shape))]
        elif len(spatial_border) == len(spatial_shape):
            data_pad_width = [(spatial_border[i], spatial_border[i]) for i in range(len(spatial_shape))]
        elif len(spatial_border) == len(spatial_shape) * 2:
            data_pad_width = [(spatial_border[2 * i], spatial_border[2 * i + 1]) for i in range(len(spatial_shape))]
        else:
            raise ValueError("unsupported length of spatial_border definition.")

        return np.pad(img, [(0, 0)] + data_pad_width, mode=mode or self.mode)
Esempio n. 25
0
 def __init__(
     self,
     select_fn: Callable = lambda x: x > 0,
     channel_indexes: Optional[IndexSelection] = None,
     margin: int = 0,
 ):
     """
     Args:
         select_fn: function to select expected foreground, default is to select values > 0.
         channel_indexes: if defined, select foreground only on the specified channels
             of image. if None, select foreground on the whole image.
         margin: add margin to all dims of the bounding box.
     """
     self.select_fn = select_fn
     self.margin = margin
     # keep None as-is; otherwise normalise (a single index becomes a 1-tuple)
     self.channel_indexes = None if channel_indexes is None else ensure_tuple(channel_indexes)
Esempio n. 26
0
 def __init__(self, keys: KeysCollection, times: int, names):
     """
     Args:
         keys: keys of the corresponding items to be transformed.
             See also: :py:class:`monai.transforms.compose.MapTransform`
         times: expected copy times, for example, if keys is "img", times is 3,
             it will add 3 copies of "img" data to the dictionary.
         names(str, list or tuple of str): the names corresponding to the newly copied data,
             the length should match `len(keys) x times`. for example, if keys is ["img", "seg"]
             and times is 2, names can be: ["img_1", "seg_1", "img_2", "seg_2"].

     Raises:
         ValueError: when ``times`` is below 1 or ``names`` has the wrong length.
     """
     super().__init__(keys)
     if not times >= 1:
         raise ValueError("times must be greater than 0.")
     self.times = times
     names = ensure_tuple(names)
     expected_count = len(self.keys) * times
     if len(names) != expected_count:
         raise ValueError(
             "length of names does not match `len(keys) x times`.")
     self.names = names
Esempio n. 27
0
 def __init__(self,
              applied_labels,
              independent: bool = True,
              connectivity: Optional[int] = None):
     """
     Args:
         applied_labels (int, list or tuple of int): Labels for applying the connected component on.
             If only one channel. The pixel whose value is not in this list will remain unchanged.
             If the data is in one-hot format, this is used to determine what channels to apply.
             Stored as a tuple via ``ensure_tuple``.
         independent (bool): consider several labels as a whole or independent, default is `True`.
             Example use case would be segment label 1 is liver and label 2 is liver tumor, in that case
             you want this "independent" to be specified as False.
         connectivity: Maximum number of orthogonal hops to consider a pixel/voxel as a neighbor.
             Accepted values are ranging from  1 to input.ndim. If ``None``, a full
             connectivity of ``input.ndim`` is used.
     """
     super().__init__()
     labels = ensure_tuple(applied_labels)
     self.applied_labels = labels
     self.independent, self.connectivity = independent, connectivity
Esempio n. 28
0
 def __init__(self,
              keys,
              source_key,
              select_fn=lambda x: x > 0,
              channel_indexes=None,
              margin=0):
     """
     Args:
         keys (hashable items): keys of the corresponding items to be transformed.
             See also: :py:class:`monai.transforms.compose.MapTransform`
         source_key (str): data source to generate the bounding box of foreground, can be image or label, etc.
         select_fn (Callable): function to select expected foreground, default is to select values > 0.
         channel_indexes (int, tuple or list): if defined, select foreground only on the specified channels
             of image. if None, select foreground on the whole image.
         margin (int): add margin to all dims of the bounding box.
     """
     super().__init__(keys)
     self.source_key = source_key
     self.select_fn = select_fn
     if channel_indexes is not None:
         channel_indexes = ensure_tuple(channel_indexes)
     self.channel_indexes = channel_indexes
     self.margin = margin
Esempio n. 29
0
 def __init__(
     self,
     keys: KeysCollection,
     source_key: str,
     select_fn: Callable = lambda x: x > 0,
     channel_indexes: Optional[IndexSelection] = None,
     margin: int = 0,
 ):
     """
     Args:
         keys: keys of the corresponding items to be transformed.
             See also: :py:class:`monai.transforms.compose.MapTransform`
         source_key: data source to generate the bounding box of foreground, can be image or label, etc.
         select_fn: function to select expected foreground, default is to select values > 0.
         channel_indexes: if defined, select foreground only on the specified channels
             of image. if None, select foreground on the whole image.
         margin: add margin to all dims of the bounding box.
     """
     super().__init__(keys)
     self.source_key, self.select_fn, self.margin = source_key, select_fn, margin
     # None means "all channels"; anything else is normalised to a tuple
     if channel_indexes is None:
         self.channel_indexes = None
     else:
         self.channel_indexes = ensure_tuple(channel_indexes)
Esempio n. 30
0
def ckpt_export(
    net_id: Optional[str] = None,
    filepath: Optional[PathLike] = None,
    ckpt_file: Optional[str] = None,
    meta_file: Optional[Union[str, Sequence[str]]] = None,
    config_file: Optional[Union[str, Sequence[str]]] = None,
    key_in_ckpt: Optional[str] = None,
    args_file: Optional[str] = None,
    **override,
):
    """
    Export the model checkpoint to the given filepath with metadata and config included as JSON files.

    Typical usage examples:

    .. code-block:: bash

        python -m monai.bundle ckpt_export network --filepath <export path> --ckpt_file <checkpoint path> ...

    Args:
        net_id: ID name of the network component in the config, it must be `torch.nn.Module`.
        filepath: filepath to export, if filename has no extension it becomes `.ts`.
        ckpt_file: filepath of the model checkpoint to load.
        meta_file: filepath of the metadata file, if it is a list of file paths, the content of them will be merged.
        config_file: filepath of the config file to save in TorchScript model and extract network information,
            the saved key in the TorchScript model is the config filename without extension, and the saved config
            value is always serialized in JSON format no matter the original file format is JSON or YAML.
            it can be a single file or a list of files. if `None`, must be provided in `args_file`.
        key_in_ckpt: for nested checkpoint like `{"model": XXX, "optimizer": XXX, ...}`, specify the key of model
            weights. if not nested checkpoint, no need to set.
        args_file: a JSON or YAML file to provide default values for `meta_file`, `config_file`,
            `net_id` and override pairs. so that the command line inputs can be simplified.
        override: id-value pairs to override or add the corresponding config content.
            e.g. ``--_meta#network_data_format#inputs#image#num_channels 3``.

    Raises:
        ValueError: when two config files share the same filename (without extension).

    """
    _args = _update_args(
        args=args_file,
        net_id=net_id,
        filepath=filepath,
        meta_file=meta_file,
        config_file=config_file,
        ckpt_file=ckpt_file,
        key_in_ckpt=key_in_ckpt,
        **override,
    )
    _log_input_summary(tag="ckpt_export", args=_args)
    filepath_, ckpt_file_, config_file_, net_id_, meta_file_, key_in_ckpt_ = _pop_args(
        _args,
        "filepath",
        "ckpt_file",
        "config_file",
        net_id="",
        meta_file=None,
        key_in_ckpt="")

    parser = ConfigParser()

    parser.read_config(f=config_file_)
    if meta_file_ is not None:
        parser.read_meta(f=meta_file_)

    # the rest key-values in the _args are to override config content
    for k, v in _args.items():
        parser[k] = v

    net = parser.get_parsed_content(net_id_)
    if has_ignite:
        # here we use ignite Checkpoint to support nested weights and be compatible with MONAI CheckpointSaver
        Checkpoint.load_objects(to_load={key_in_ckpt_: net},
                                checkpoint=ckpt_file_)
    else:
        # `ckpt_file_` is a path: deserialize it before copying weights. Previously the path
        # string itself was passed/indexed, which can never yield a state dict.
        # NOTE(review): torch.load unpickles the file; only export checkpoints from trusted sources.
        import torch

        ckpt = torch.load(ckpt_file_)
        copy_model_state(
            dst=net,
            src=ckpt if key_in_ckpt_ == "" else ckpt[key_in_ckpt_])

    # convert to TorchScript model and save with meta data, config content
    net = convert_to_torchscript(model=net)

    extra_files: Dict = {}
    for i in ensure_tuple(config_file_):
        # split the filename and directory
        filename = os.path.basename(i)
        # remove extension
        filename, _ = os.path.splitext(filename)
        if filename in extra_files:
            raise ValueError(
                f"filename '{filename}' is given multiple times in config file list."
            )
        extra_files[filename] = json.dumps(
            ConfigParser.load_config_file(i)).encode()

    save_net_with_metadata(
        jit_obj=net,
        filename_prefix_or_stream=filepath_,
        include_config_vals=False,
        append_timestamp=False,
        meta_values=parser.get().pop("_meta_", None),
        more_extra_files=extra_files,
    )
    logger.info(f"exported to TorchScript file: {filepath_}.")