def __init__(
    self,
    keys: KeysCollection,
    range_x=0.0,
    range_y=0.0,
    range_z=0.0,
    prob: float = 0.1,
    keep_size: bool = True,
    interp_order: str = "bilinear",
    mode: str = "border",
    align_corners: bool = False,
):
    """
    Store the per-axis sampling ranges and the per-key resampling options.

    Args:
        keys: keys of the items to be transformed, forwarded to the parent ``__init__``.
        range_x: range for the x parameter. A single value ``r`` is turned into the
            sorted pair ``(-r, r)``; a longer sequence is kept as given.
        range_y: same as ``range_x`` for the y parameter.
        range_z: same as ``range_x`` for the z parameter.
        prob: probability of applying the transform, stored as ``self.prob``.
        keep_size: stored as ``self.keep_size``.
        interp_order: interpolation option, expanded to one entry per key with ``ensure_tuple_rep``.
        mode: padding option, expanded to one entry per key with ``ensure_tuple_rep``.
        align_corners: stored as ``self.align_corners``.
    """
    super().__init__(keys)

    def _symmetric(bounds):
        # a scalar range ``r`` becomes the sorted pair ``(-r, r)``
        bounds = ensure_tuple(bounds)
        if len(bounds) != 1:
            return bounds
        return tuple(sorted([-bounds[0], bounds[0]]))

    self.range_x = _symmetric(range_x)
    self.range_y = _symmetric(range_y)
    self.range_z = _symmetric(range_z)

    self.prob = prob
    self.keep_size = keep_size
    key_count = len(self.keys)
    self.interp_order = ensure_tuple_rep(interp_order, key_count)
    self.mode = ensure_tuple_rep(mode, key_count)
    self.align_corners = align_corners

    # per-call random state, initialised to "no transform / zero parameters"
    self._do_transform = False
    self.x = self.y = self.z = 0.0
def select_labels(
    labels: Union[Sequence[NdarrayOrTensor], NdarrayOrTensor], keep: NdarrayOrTensor
) -> Union[Tuple, NdarrayOrTensor]:
    """
    Index every label container in ``labels`` with ``keep`` along the first dimension.

    Args:
        labels: a single array/tensor, or a sequence of them. Each element holds the
            classification labels or scores corresponding to ``boxes``, sized (N,).
        keep: the indices to keep, applied to every element of ``labels``.

    Return:
        If ``labels`` is a single ``torch.Tensor`` or ``np.ndarray``, the selected data of
        the same type. Otherwise, a tuple with one selected entry per element. The selected
        labels do not share memory with the original labels.
    """
    items = ensure_tuple(labels, True)
    keep_t: torch.Tensor = convert_data_type(keep, torch.Tensor)[0]

    def _select_one(item):
        # go through torch for indexing, then convert back to the item's original type
        item_t: torch.Tensor = convert_data_type(item, torch.Tensor)[0]
        return convert_to_dst_type(src=item_t[keep_t, ...], dst=item)[0]

    selected = tuple(_select_one(item) for item in items)
    if isinstance(labels, (torch.Tensor, np.ndarray)):
        return selected[0]  # type: ignore
    return selected
def __call__(self, filename):
    """
    Load one image, or several images of the same spatial shape, with PIL.

    Args:
        filename (str, list, tuple, file): path file or file-like object or a list of files.

    Returns:
        ``img_array`` if ``self.image_only`` is True, otherwise ``(img_array, meta)``.
        ``img_array`` is the single loaded array, or the arrays stacked along a new axis 0
        when more than one file is given. ``meta`` describes the first image
        (filename, spatial shape, format, mode, width, height, info).

    Raises:
        AssertionError: if an image's spatial shape differs from the first image's
            (only checked when ``self.image_only`` is False).
    """
    filename = ensure_tuple(filename)
    img_array = list()
    compatible_meta = None
    for name in filename:
        # Use the context manager so the file handle PIL opens for a path is released
        # once the pixels are copied out. Without it, each call leaked one open file.
        # Per Pillow's file-handling docs, caller-supplied file objects are not closed.
        with Image.open(name) as img:
            data = np.asarray(img)
            if self.dtype:
                data = data.astype(self.dtype)
            img_array.append(data)
            if self.image_only:
                # metadata is not returned in this mode, so don't build it
                continue
            meta = {
                "filename_or_obj": name,
                "spatial_shape": data.shape[:2],
                "format": img.format,
                "mode": img.mode,
                "width": img.width,
                "height": img.height,
                "info": img.info,
            }
        if not compatible_meta:
            compatible_meta = meta
        else:
            assert np.allclose(meta["spatial_shape"], compatible_meta["spatial_shape"]), \
                "all the images in the list should have same spatial shape."

    img_array = np.stack(img_array, axis=0) if len(img_array) > 1 else img_array[0]
    return img_array if self.image_only else (img_array, compatible_meta)
def flip_boxes(
    boxes: NdarrayOrTensor,
    spatial_size: Union[Sequence[int], int],
    flip_axes: Optional[Union[Sequence[int], int]] = None,
) -> NdarrayOrTensor:
    """
    Mirror box coordinates so they match an image flipped along ``flip_axes``.

    Args:
        boxes: bounding boxes, Nx4 or Nx6 torch tensor or ndarray, in ``StandardMode``.
            Column ``d`` holds the minimum coordinate of spatial axis ``d``, and column
            ``d + spatial_dims`` holds its maximum.
        spatial_size: image spatial size, one int per spatial axis or a single int for all.
        flip_axes: spatial axes to flip. ``None`` (default) flips every spatial axis.
            Negative values count from the last axis; a sequence flips each listed axis.

    Returns:
        flipped boxes, with same data type as ``boxes``, does not share memory with ``boxes``
    """
    n_dims: int = get_spatial_dims(boxes=boxes)
    sizes = ensure_tuple_rep(spatial_size, n_dims)
    axes_to_flip = ensure_tuple(tuple(range(0, n_dims)) if flip_axes is None else flip_axes)

    flipped = deepcopy(boxes)
    for ax in axes_to_flip:
        lo_col, hi_col = ax, ax + n_dims
        # mirroring swaps the roles of the min and max edges along this axis
        flipped[:, hi_col] = sizes[ax] - boxes[:, lo_col] - TO_REMOVE
        flipped[:, lo_col] = sizes[ax] - boxes[:, hi_col] - TO_REMOVE
    return flipped
def create_shear(spatial_dims, coefs):
    """
    Build a homogeneous shearing matrix.

    Args:
        spatial_dims (int): spatial rank, 2 or 3.
        coefs (floats): shearing factors. Missing factors default to 0; factors beyond
            the 2 (2D) or 6 (3D) that are used are ignored.

    Raises:
        NotImplementedError: for any ``spatial_dims`` other than 2 or 3.
    """
    factors = list(ensure_tuple(coefs))

    if spatial_dims == 2:
        factors += [0.0] * (2 - len(factors))  # no-op when enough factors are given
        return np.array([
            [1, factors[0], 0.],
            [factors[1], 1., 0.],
            [0., 0., 1.],
        ])

    if spatial_dims == 3:
        factors += [0.0] * (6 - len(factors))
        return np.array([
            [1., factors[0], factors[1], 0.],
            [factors[2], 1., factors[3], 0.],
            [factors[4], factors[5], 1., 0.],
            [0., 0., 0., 1.],
        ])

    raise NotImplementedError
def __init__(self, keys, affine_key, pixdim, interp_order=2, keep_shape=False, output_key='spacing'):
    """
    Args:
        keys (hashable items): keys of the items to resample, forwarded to ``MapTransform.__init__``.
        affine_key (hashable): the key to the original affine.
            The affine will be used to compute input data's pixdim.
        pixdim (sequence of floats): output voxel spacing.
        interp_order (int or sequence of ints): int: the same interpolation order
            for all data indexed by `self.keys`; sequence of ints, should
            correspond to an interpolation order for each data item indexed
            by `self.keys` respectively.
        keep_shape (bool): whether to maintain the original spatial shape
            after resampling. Defaults to False.
        output_key (hashable): key to be added to the output dictionary to track
            the pixdim status.

    Raises:
        ValueError: when ``interp_order`` has neither 1 entry nor exactly ``len(keys)`` entries.
    """
    MapTransform.__init__(self, keys)
    self.affine_key = affine_key
    self.spacing_transform = Spacing(pixdim, keep_shape=keep_shape)
    interp_order = ensure_tuple(interp_order)
    n_keys = len(self.keys)
    if len(interp_order) == 1:
        # a single order is shared by every key
        interp_order = interp_order * n_keys
    elif len(interp_order) != n_keys:
        # The previous code tiled any length by len(keys), e.g. 2 orders x 3 keys -> 6
        # entries, which silently mispairs keys and orders. Reject it explicitly.
        raise ValueError(
            f"interp_order must have 1 or {n_keys} entries (one per key), got {len(interp_order)}."
        )
    self.interp_order = interp_order
    self.output_key = output_key
def __init__(self, pixdim, diagonal=False, mode='constant', cval=0, dtype=None):
    """
    Args:
        pixdim (sequence of floats): output voxel spacing, stored as a float64 numpy array.
        diagonal (bool): whether to resample the input to have a diagonal affine matrix.
            If True, the input data is resampled to the following affine::

                np.diag((pixdim_0, pixdim_1, ..., pixdim_n, 1))

            This effectively resets the volume to the world coordinate system (RAS+ in nibabel).
            The original orientation, rotation, shearing are not preserved.

            If False, this transform preserves the axes orientation, orthogonal rotation and
            translation components from the original affine. This option will not flip/swap axes
            of the original data.
        mode (`reflect|constant|nearest|mirror|wrap`):
            The mode parameter determines how the input array is extended beyond its boundaries.
        cval (scalar): Value to fill past edges of input if mode is "constant". Default is 0.0.
        dtype (None or np.dtype): output array data type, defaults to None to use input data's dtype.
    """
    spacing = ensure_tuple(pixdim)
    self.pixdim = np.array(spacing, dtype=np.float64)
    self.diagonal, self.mode, self.cval, self.dtype = diagonal, mode, cval, dtype
def randomize(self, img_size):
    """
    Draw the crop size and, when ``self.random_center`` is set, a random crop location.

    Args:
        img_size: spatial shape of the image to crop (sequence of ints).

    Sets ``self._size``. When ``self.random_center`` is True it also sets ``self._slices``:
    a leading ``slice(None)`` followed by the patch slices from ``get_random_patch``.
    """
    roi = self.roi_size
    if isinstance(roi, (list, tuple)):
        self._size = roi
    else:
        # a scalar roi_size is repeated for every spatial dim
        self._size = [roi] * len(img_size)

    if self.random_size:
        lows = self._size
        # ``high`` is img_size + 1 so img_size itself can be drawn
        # (assumes self.R.randint has an exclusive upper bound, as numpy's does — confirm)
        self._size = [self.R.randint(low=lows[i], high=upper + 1) for i, upper in enumerate(img_size)]

    if not self.random_center:
        return
    valid_size = get_valid_patch_size(img_size, self._size)
    self._slices = ensure_tuple(slice(None)) + get_random_patch(img_size, valid_size, self.R)
def __init__(
    self, keys, pixdim, diagonal=False, mode="nearest", cval=0, interp_order=3, dtype=None, meta_key_format="{}.{}"
):
    """
    Args:
        keys: keys of the items to resample, forwarded to the parent ``__init__``.
        pixdim (sequence of floats): output voxel spacing.
        diagonal (bool): whether to resample the input to have a diagonal affine matrix.
            If True, the input data is resampled to the following affine::

                np.diag((pixdim_0, pixdim_1, pixdim_2, 1))

            This effectively resets the volume to the world coordinate system (RAS+ in nibabel).
            The original orientation, rotation, shearing are not preserved.

            If False, the axes orientation, orthogonal rotation and
            translations components from the original affine will be
            preserved in the target affine. This option will not flip/swap
            axes against the original ones.
        mode (`reflect|constant|nearest|mirror|wrap`):
            The mode parameter determines how the input array is extended beyond its boundaries.
            Default is 'nearest'.
        cval (scalar): Value to fill past edges of input if mode is "constant". Default is 0.0.
        interp_order (int or sequence of ints): int: the same interpolation order
            for all data indexed by `self.keys`; sequence of ints, should
            correspond to an interpolation order for each data item indexed
            by `self.keys` respectively.
        dtype (None or np.dtype): output array data type, defaults to None to use input data's dtype.
        meta_key_format (str): key format to read/write affine matrices to the data dictionary.

    Raises:
        ValueError: when ``interp_order`` has neither 1 entry nor exactly ``len(keys)`` entries.
    """
    super().__init__(keys)
    self.spacing_transform = Spacing(pixdim, diagonal=diagonal, mode=mode, cval=cval, dtype=dtype)
    interp_order = ensure_tuple(interp_order)
    n_keys = len(self.keys)
    if len(interp_order) == 1:
        # a single order is shared by every key
        interp_order = interp_order * n_keys
    elif len(interp_order) != n_keys:
        # The previous code tiled any length by len(keys), e.g. 2 orders x 3 keys -> 6
        # entries, which silently mispairs keys and orders. Reject it explicitly.
        raise ValueError(
            f"interp_order must have 1 or {n_keys} entries (one per key), got {len(interp_order)}."
        )
    self.interp_order = interp_order
    self.meta_key_format = meta_key_format
def __init__(self, keys):
    """
    Normalize ``keys`` to a tuple and check that it is non-empty and hashable.

    Args:
        keys: a single hashable key or a sequence of hashable keys.

    Raises:
        ValueError: if no keys are given, or if any key is not hashable.
    """
    self.keys = ensure_tuple(keys)
    if not self.keys:
        raise ValueError("keys unspecified")
    for candidate in self.keys:
        if isinstance(candidate, Hashable):
            continue
        raise ValueError(f"keys should be a hashable or a sequence of hashables, got {type(candidate)}")
def generate_spatial_bounding_box(img, select_fn=lambda x: x > 0, channel_indexes=None, margin=0):
    """
    Compute the start/end positions of the foreground's spatial bounding box.

    Foreground is where ``select_fn`` is true in any of the considered channels
    (axis 0 of ``img`` is treated as the channel axis). Each dim is then widened by
    ``margin`` and clipped to the image extent.

    Args:
        img (ndarrary): source image to generate bounding box from; axis 0 is reduced as channels.
        select_fn (Callable): function to select expected foreground, default is to select values > 0.
        channel_indexes (int, tuple or list): if defined, select foreground only on the specified channels
            of image. if None, select foreground on the whole image.
        margin (int): add margin to all dims of the bounding box.

    Returns:
        ``(box_start, box_end)``: two lists with one entry per spatial dim; ``box_end`` is exclusive.

    Raises:
        AssertionError: if ``margin`` is not an int, or if no foreground is found.
    """
    assert isinstance(margin, int), "margin must be int type."
    if channel_indexes is None:
        data = img
    else:
        data = img[[*(ensure_tuple(channel_indexes))]]
    # collapse the channel axis: a voxel is foreground if any selected channel is
    foreground = np.any(select_fn(data), axis=0)
    nonzero_idx = np.nonzero(foreground)

    box_start, box_end = [], []
    for dim, (idx, dim_len) in enumerate(zip(nonzero_idx, foreground.shape)):
        assert len(idx) > 0, f"did not find nonzero index at spatial dim {dim}"
        box_start.append(max(0, np.min(idx) - margin))
        box_end.append(min(dim_len, np.max(idx) + margin + 1))
    return box_start, box_end
def randomize(self, img_size):
    """
    Draw the crop size and, when ``self.random_center`` is set, a random crop location.

    Args:
        img_size: spatial shape of the image to crop (sequence of ints).

    Sets ``self._size`` (``self.roi_size`` expanded with ``ensure_tuple_rep``, or a
    random size per dim when ``self.random_size``). When ``self.random_center`` is True,
    it also sets ``self._slices`` to a leading ``slice(None)`` plus the random patch slices.
    """
    n_dims = len(img_size)
    self._size = ensure_tuple_rep(self.roi_size, n_dims)

    if self.random_size:
        lows = self._size
        # ``high`` is img_size + 1 so img_size itself can be drawn
        # (assumes self.R.randint has an exclusive upper bound, as numpy's does — confirm)
        self._size = [self.R.randint(low=lows[i], high=upper + 1) for i, upper in enumerate(img_size)]

    if not self.random_center:
        return
    valid_size = get_valid_patch_size(img_size, self._size)
    self._slices = ensure_tuple(slice(None)) + get_random_patch(img_size, valid_size, self.R)
def __init__(
    self,
    rotate_range=None,
    shear_range=None,
    translate_range=None,
    scale_range=None,
    as_tensor_output=True,
    device=None,
):
    """
    Args:
        rotate_range (a sequence of positive floats): rotate_range[0] will be used to generate the 1st rotation
            parameter from `uniform[-rotate_range[0], rotate_range[0])`. Similarly, `rotate_range[1]` and
            `rotate_range[2]` are used in 3D affine for the range of 2nd and 3rd axes.
        shear_range (a sequence of positive floats): shear_range[0] will be used to generate the 1st shearing
            parameter from `uniform[-shear_range[0], shear_range[0])`. Similarly, `shear_range[1]` to
            `shear_range[N]` control the range of the uniform distribution used to generate the 2nd to
            N-th parameter.
        translate_range (a sequence of positive floats): translate_range[0] will be used to generate the 1st
            shift parameter from `uniform[-translate_range[0], translate_range[0])`. Similarly, `translate_range[1]`
            to `translate_range[N]` control the range of the uniform distribution used to generate
            the 2nd to N-th parameter.
        scale_range (a sequence of positive floats): scale_range[0] will be used to generate the 1st scaling
            factor from `uniform[-scale_range[0], scale_range[0]) + 1.0`. Similarly, `scale_range[1]` to
            `scale_range[N]` control the range of the uniform distribution used to generate the 2nd to
            N-th parameter.
        as_tensor_output: stored as ``self.as_tensor_output``.
        device: stored as ``self.device``.

    See also:
        - :py:meth:`monai.transforms.utils.create_rotate`
        - :py:meth:`monai.transforms.utils.create_shear`
        - :py:meth:`monai.transforms.utils.create_translate`
        - :py:meth:`monai.transforms.utils.create_scale`
    """
    ranges = (rotate_range, shear_range, translate_range, scale_range)
    self.rotate_range, self.shear_range, self.translate_range, self.scale_range = map(ensure_tuple, ranges)

    # sampled parameters; none have been drawn yet
    self.rotate_params = self.shear_params = self.translate_params = self.scale_params = None

    self.as_tensor_output = as_tensor_output
    self.device = device
def __init__(self, spatial_size, method: str = "symmetric", mode: str = "constant"):
    """
    Configure the padding target and strategy.

    Args:
        spatial_size: target spatial size, normalized to a tuple.
        method: padding method, must be ``"symmetric"`` or ``"end"``.
        mode: padding mode string, stored as ``self.mode``.

    Raises:
        AssertionError: if ``method`` is unsupported or ``mode`` is not a str.
    """
    self.spatial_size = ensure_tuple(spatial_size)
    assert method in ("symmetric", "end"), "unsupported padding type."
    assert isinstance(mode, str), "mode must be str."
    self.method, self.mode = method, mode
def test_value(self, input, expected_value, wrap_array=False):
    """Check that ``ensure_tuple`` returns a tuple whose items match ``expected_value``."""
    result = ensure_tuple(input, wrap_array)
    self.assertIsInstance(result, tuple)
    if not isinstance(input, (np.ndarray, torch.Tensor)):
        self.assertTupleEqual(result, expected_value)
        return
    # array-like inputs: compare element-wise numerically
    for got, want in zip(result, expected_value):
        assert_allclose(got, want)
def rot90_boxes(
    boxes: NdarrayOrTensor, spatial_size: Union[Sequence[int], int], k: int = 1, axes: Tuple[int, int] = (0, 1)
):
    """
    Rotate boxes by 90 degrees in the plane specified by axes.
    Rotation direction is from the first towards the second axis.

    Args:
        boxes: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``
        spatial_size: image spatial size.
        k : number of times the array is rotated by 90 degrees (taken modulo 4).
        axes: (2,) array_like
            The array is rotated in the plane defined by the axes. Axes must be different.

    Returns:
        The rotated boxes. When ``k % 4 == 0`` the input ``boxes`` object itself is returned;
        otherwise the result comes from ``flip_boxes`` / ``swapaxes_boxes``.

    Raises:
        ValueError: if ``axes`` does not have two entries, names the same axis twice,
            or is out of range for the spatial rank.

    Notes:
        ``rot90_boxes(boxes, spatial_size, k=1, axes=(1,0))`` is the reverse of
        ``rot90_boxes(boxes, spatial_size, k=1, axes=(0,1))``

        ``rot90_boxes(boxes, spatial_size, k=1, axes=(1,0))`` is equivalent to
        ``rot90_boxes(boxes, spatial_size, k=-1, axes=(0,1))``
    """
    spatial_dims: int = get_spatial_dims(boxes=boxes)
    size_list = list(ensure_tuple_rep(spatial_size, spatial_dims))

    axes = ensure_tuple(axes)  # type: ignore
    if len(axes) != 2:
        raise ValueError("len(axes) must be 2.")
    first, second = axes
    # equal axes, or a positive/negative alias of the same axis, are rejected
    if first == second or abs(first - second) == spatial_dims:
        raise ValueError("Axes must be different.")
    if not all(-spatial_dims <= a < spatial_dims for a in (first, second)):
        raise ValueError(f"Axes={axes} out of range for array of ndim={spatial_dims}.")

    k %= 4
    if k == 0:
        return boxes
    if k == 2:
        return flip_boxes(flip_boxes(boxes, size_list, first), size_list, second)
    if k == 1:
        return swapaxes_boxes(flip_boxes(boxes, size_list, second), first, second)

    # k == 3: swap first; the flip must then use the swapped image sizes
    swapped = swapaxes_boxes(boxes, first, second)
    size_list[first], size_list[second] = size_list[second], size_list[first]
    return flip_boxes(swapped, size_list, second)
def __init__(self, select_fn=lambda x: x > 0, channel_indexes=None, margin=0):
    """
    Args:
        select_fn (Callable): function to select expected foreground, default is to select values > 0.
        channel_indexes (int, tuple or list): if defined, select foreground only on the specified channels
            of image. if None, select foreground on the whole image.
        margin (int): add margin to all dims of the bounding box.
    """
    self.select_fn = select_fn
    # None is kept as-is to mean "all channels"; anything else is normalized to a tuple
    self.channel_indexes = None if channel_indexes is None else ensure_tuple(channel_indexes)
    self.margin = margin
def __init__(self, data, transform=None):
    """
    Args:
        data (Iterable): input data to load and transform to generate dataset for model.
        transform (Callable, optional): transforms to execute operations on input data.
            A ``Compose`` is used as-is; any other callable or sequence of callables is
            wrapped in a ``Compose``. ``None`` (the default) gives an empty ``Compose``.
    """
    self.data = data
    if isinstance(transform, Compose):
        self.transform = transform
    elif transform is None:
        # ensure_tuple wraps non-sequence values, so ensure_tuple(None) would be (None,):
        # a Compose that tries to call None as a transform. Use an empty pipeline instead.
        self.transform = Compose([])
    else:
        self.transform = Compose(ensure_tuple(transform))
def create_scale(spatial_dims, scaling_factor):
    """
    Build a homogeneous (``spatial_dims + 1`` square) diagonal scaling matrix.

    Args:
        spatial_dims (int): spatial rank
        scaling_factor (floats): scaling factors. Missing factors default to 1; extra factors are ignored.
    """
    factors = list(ensure_tuple(scaling_factor))
    factors.extend([1.0] * (spatial_dims - len(factors)))  # no-op when enough are given
    return np.diag(factors[:spatial_dims] + [1.0])
def create_translate(spatial_dims, shift):
    """
    Build a homogeneous (``spatial_dims + 1`` square) translation matrix.

    Args:
        spatial_dims (int): spatial rank
        shift (floats): translate factors. Missing ones stay 0; extra ones are ignored.
    """
    offsets = ensure_tuple(shift)[:spatial_dims]
    affine = np.eye(spatial_dims + 1)
    for row, offset in enumerate(offsets):
        # the last column (index spatial_dims) holds the translation
        affine[row, -1] = offset
    return affine
def test_numpy_values(self, keys, times, names):
    """Check that copies are independent: changing a copy leaves the source "img" unchanged."""
    input_data = {
        "img": np.array([[0, 1], [1, 2]]),
        "seg": np.array([[0, 1], [1, 2]]),
    }
    copier = CopyItemsd(keys=keys, times=times, names=names)
    result = copier(input_data)
    expected_after_increment = np.array([[1, 2], [2, 3]])
    for copied_name in ensure_tuple(names):
        self.assertIn(copied_name, result)
        result[copied_name] += 1
        np.testing.assert_allclose(result[copied_name], expected_after_increment)
    np.testing.assert_allclose(result["img"], np.array([[0, 1], [1, 2]]))
def __call__(self, filename):
    """
    Load one or several NIfTI files with nibabel.

    Args:
        filename (str, list, tuple, file): path file or file-like object or a list of files.

    Returns:
        The image array (stacked along a new axis 0 when several files are given) if
        ``self.image_only``; otherwise ``(array, meta)``, where ``meta`` is built from the
        first file's header plus filename, affine, original_affine, as_closest_canonical
        and spatial_shape entries.

    Raises:
        AssertionError: if a later file's affine differs from the first one
            (only checked when ``self.image_only`` is False).
    """
    filenames = ensure_tuple(filename)
    arrays = []
    compatible_meta = dict()
    for name in filenames:
        img = correct_nifti_header_if_necessary(nib.load(name))
        header = dict(img.header)
        header["filename_or_obj"] = name
        header["affine"] = img.affine
        header["original_affine"] = img.affine.copy()
        header["as_closest_canonical"] = self.as_closest_canonical
        # dim[0] is the number of dims; at most the first 3 are treated as spatial
        spatial_rank = min(img.header["dim"][0], 3)
        header["spatial_shape"] = img.header["dim"][1:spatial_rank + 1]

        if self.as_closest_canonical:
            img = nib.as_closest_canonical(img)
            header["affine"] = img.affine

        arrays.append(np.array(img.get_fdata(dtype=self.dtype)))
        img.uncache()

        if self.image_only:
            continue
        if compatible_meta:
            assert np.allclose(
                header["affine"], compatible_meta["affine"]
            ), "affine data of all images should be same."
            continue
        # First image: copy its header, skipping ndarray entries whose dtype string matches
        # np_str_obj_array_pattern (presumably string/object arrays that would not collate — confirm).
        for key, value in header.items():
            # pytype: disable=attribute-error
            is_str_obj_array = (type(value).__name__ == "ndarray"
                                and np_str_obj_array_pattern.search(value.dtype.str) is not None)
            # pytype: enable=attribute-error
            if not is_str_obj_array:
                compatible_meta[key] = value

    stacked = np.stack(arrays, axis=0) if len(arrays) > 1 else arrays[0]
    return stacked if self.image_only else (stacked, compatible_meta)
def create_rotate(spatial_dims, radians):
    """
    Create a homogeneous 2D or 3D rotation matrix.

    Args:
        spatial_dims (2|3): spatial rank
        radians (float or a sequence of floats): rotation radians.
            In 2D only the first value is used.
            In 3D, the values correspond to rotation about the 1st, 2nd and 3rd dim, and the
            result is ``R1 @ R2 @ R3`` for however many of the first three values are given
            (further values are ignored). With no values, 3D returns None.

    Raises:
        ValueError: for any ``spatial_dims`` other than 2 or 3, or 2D with no radians.
    """
    angles = ensure_tuple(radians)

    if spatial_dims == 2 and angles:
        s, c = np.sin(angles[0]), np.cos(angles[0])
        return np.array([[c, -s, 0.], [s, c, 0.], [0., 0., 1.]])

    if spatial_dims == 3:
        def _about_dim1(s, c):
            return np.array([
                [1., 0., 0., 0.],
                [0., c, -s, 0.],
                [0., s, c, 0.],
                [0., 0., 0., 1.],
            ])

        def _about_dim2(s, c):
            return np.array([
                [c, 0.0, s, 0.],
                [0., 1., 0., 0.],
                [-s, 0., c, 0.],
                [0., 0., 0., 1.],
            ])

        def _about_dim3(s, c):
            return np.array([
                [c, -s, 0., 0.],
                [s, c, 0., 0.],
                [0., 0., 1., 0.],
                [0., 0., 0., 1.],
            ])

        affine = None
        for angle, build in zip(angles, (_about_dim1, _about_dim2, _about_dim3)):
            rotation = build(np.sin(angle), np.cos(angle))
            affine = rotation if affine is None else affine @ rotation
        return affine

    raise ValueError('create_rotate got spatial_dims={}, radians={}.'.format(spatial_dims, angles))
def __call__(self, img, mode: Optional[str] = None):
    """
    Pad every spatial dim of ``img`` by ``self.spatial_border``; axis 0 is never padded.

    ``self.spatial_border`` may hold one value (applied to both sides of every spatial dim),
    one value per spatial dim (applied to both sides of that dim), or two values per spatial
    dim (before/after pairs, in dim order).

    Args:
        img: array whose first axis is left unpadded and whose remaining axes are padded.
        mode: ``np.pad`` mode; falls back to ``self.mode`` when None (or empty).

    Raises:
        ValueError: if a border value is not a non-negative integer, or if the number of
            border values fits none of the layouts above.
    """
    spatial_shape = img.shape[1:]
    spatial_border = ensure_tuple(self.spatial_border)
    for b in spatial_border:
        # Check the type before comparing: a non-numeric value (e.g. a str or None) used to
        # fail at ``b < 0`` with a TypeError instead of this ValueError. numpy integers
        # are accepted as well, since np.pad handles them.
        if not isinstance(b, (int, np.integer)) or b < 0:
            raise ValueError("spatial_border must be int number and can not be less than 0.")

    n_spatial = len(spatial_shape)
    if len(spatial_border) == 1:
        data_pad_width = [(spatial_border[0], spatial_border[0])] * n_spatial
    elif len(spatial_border) == n_spatial:
        data_pad_width = [(b, b) for b in spatial_border]
    elif len(spatial_border) == n_spatial * 2:
        data_pad_width = [(spatial_border[2 * i], spatial_border[2 * i + 1]) for i in range(n_spatial)]
    else:
        raise ValueError("unsupported length of spatial_border definition.")

    return np.pad(img, [(0, 0)] + data_pad_width, mode=mode or self.mode)
def __init__(
    self,
    select_fn: Callable = lambda x: x > 0,
    channel_indexes: Optional[IndexSelection] = None,
    margin: int = 0,
):
    """
    Args:
        select_fn: function to select expected foreground, default is to select values > 0.
        channel_indexes: if defined, select foreground only on the specified channels
            of image. if None, select foreground on the whole image.
        margin: add margin to all dims of the bounding box.
    """
    self.select_fn = select_fn
    # None is kept as-is to mean "all channels"; anything else is normalized to a tuple
    self.channel_indexes = None if channel_indexes is None else ensure_tuple(channel_indexes)
    self.margin = margin
def __init__(self, keys: KeysCollection, times: int, names):
    """
    Args:
        keys: keys of the corresponding items to be transformed.
            See also: :py:class:`monai.transforms.compose.MapTransform`
        times: expected copy times, for example, if keys is "img", times is 3,
            it will add 3 copies of "img" data to the dictionary.
        names(str, list or tuple of str): the names corresponding to the newly copied data,
            the length should match `len(keys) x times`. for example, if keys is ["img", "seg"]
            and times is 2, names can be: ["img_1", "seg_1", "img_2", "seg_2"].

    Raises:
        ValueError: if ``times`` is less than 1, or if the number of names is not ``len(keys) * times``.
    """
    super().__init__(keys)
    if times < 1:
        raise ValueError("times must be greater than 0.")
    self.times = times

    new_names = ensure_tuple(names)
    expected_count = len(self.keys) * times
    if len(new_names) != expected_count:
        raise ValueError("length of names does not match `len(keys) x times`.")
    self.names = new_names
def __init__(self, applied_labels, independent: bool = True, connectivity: Optional[int] = None):
    """
    Args:
        applied_labels (int, list or tuple of int): Labels to apply the connected component analysis to.
            For single-channel data, pixels whose value is not in this list remain unchanged.
            If the data is in one-hot format, this is used to determine which channels to apply.
        independent (bool): whether to treat several labels as independent (default ``True``)
            or as a whole. For example, if label 1 is the liver and label 2 is a liver tumor,
            set this to False.
        connectivity: Maximum number of orthogonal hops to consider a pixel/voxel as a neighbor.
            Accepted values range from 1 to input.ndim. If ``None``, a full
            connectivity of ``input.ndim`` is used.
    """
    super().__init__()
    self.applied_labels = ensure_tuple(applied_labels)
    self.independent, self.connectivity = independent, connectivity
def __init__(self, keys, source_key, select_fn=lambda x: x > 0, channel_indexes=None, margin=0):
    """
    Args:
        keys (hashable items): keys of the corresponding items to be transformed.
            See also: :py:class:`monai.transforms.compose.MapTransform`
        source_key (str): data source to generate the bounding box of foreground, can be image or label, etc.
        select_fn (Callable): function to select expected foreground, default is to select values > 0.
        channel_indexes (int, tuple or list): if defined, select foreground only on the specified channels
            of image. if None, select foreground on the whole image.
        margin (int): add margin to all dims of the bounding box.
    """
    super().__init__(keys)
    self.source_key = source_key
    self.select_fn = select_fn
    # None is kept as-is to mean "all channels"; anything else is normalized to a tuple
    self.channel_indexes = None if channel_indexes is None else ensure_tuple(channel_indexes)
    self.margin = margin
def __init__(
    self,
    keys: KeysCollection,
    source_key: str,
    select_fn: Callable = lambda x: x > 0,
    channel_indexes: Optional[IndexSelection] = None,
    margin: int = 0,
):
    """
    Args:
        keys: keys of the corresponding items to be transformed.
            See also: :py:class:`monai.transforms.compose.MapTransform`
        source_key: data source to generate the bounding box of foreground, can be image or label, etc.
        select_fn: function to select expected foreground, default is to select values > 0.
        channel_indexes: if defined, select foreground only on the specified channels
            of image. if None, select foreground on the whole image.
        margin: add margin to all dims of the bounding box.
    """
    super().__init__(keys)
    self.source_key = source_key
    self.select_fn = select_fn
    # None is kept as-is to mean "all channels"; anything else is normalized to a tuple
    self.channel_indexes = None if channel_indexes is None else ensure_tuple(channel_indexes)
    self.margin = margin
def ckpt_export(
    net_id: Optional[str] = None,
    filepath: Optional[PathLike] = None,
    ckpt_file: Optional[str] = None,
    meta_file: Optional[Union[str, Sequence[str]]] = None,
    config_file: Optional[Union[str, Sequence[str]]] = None,
    key_in_ckpt: Optional[str] = None,
    args_file: Optional[str] = None,
    **override,
):
    """
    Export the model checkpoint to the given filepath with metadata and config included as JSON files.

    Typical usage examples:

    .. code-block:: bash

        python -m monai.bundle ckpt_export network --filepath <export path> --ckpt_file <checkpoint path> ...

    Args:
        net_id: ID name of the network component in the config, it must be `torch.nn.Module`.
        filepath: filepath to export, if filename has no extension it becomes `.ts`.
        ckpt_file: filepath of the model checkpoint to load.
        meta_file: filepath of the metadata file, if it is a list of file paths, the content of them will be merged.
        config_file: filepath of the config file to save in TorchScript model and extract network information,
            the saved key in the TorchScript model is the config filename without extension, and the saved config
            value is always serialized in JSON format no matter the original file format is JSON or YAML.
            it can be a single file or a list of files. if `None`, must be provided in `args_file`.
        key_in_ckpt: for nested checkpoint like `{"model": XXX, "optimizer": XXX, ...}`, specify the key of model
            weights. if not nested checkpoint, no need to set.
        args_file: a JSON or YAML file to provide default values for `meta_file`, `config_file`,
            `net_id` and override pairs. so that the command line inputs can be simplified.
        override: id-value pairs to override or add the corresponding config content.
            e.g. ``--_meta#network_data_format#inputs#image#num_channels 3``.

    Raises:
        ValueError: when two config files share the same base filename (without extension).

    """
    _args = _update_args(
        args=args_file,
        net_id=net_id,
        filepath=filepath,
        meta_file=meta_file,
        config_file=config_file,
        ckpt_file=ckpt_file,
        key_in_ckpt=key_in_ckpt,
        **override,
    )
    _log_input_summary(tag="ckpt_export", args=_args)
    filepath_, ckpt_file_, config_file_, net_id_, meta_file_, key_in_ckpt_ = _pop_args(
        _args, "filepath", "ckpt_file", "config_file", net_id="", meta_file=None, key_in_ckpt=""
    )

    parser = ConfigParser()
    parser.read_config(f=config_file_)
    if meta_file_ is not None:
        parser.read_meta(f=meta_file_)

    # the rest key-values in the _args are to override config content
    for k, v in _args.items():
        parser[k] = v

    net = parser.get_parsed_content(net_id_)
    if has_ignite:
        # here we use ignite Checkpoint to support nested weights and be compatible with MONAI CheckpointSaver
        Checkpoint.load_objects(to_load={key_in_ckpt_: net}, checkpoint=ckpt_file_)
    else:
        # ``ckpt_file_`` is a path: load the checkpoint first. The previous code passed the
        # path string itself to copy_model_state and indexed that string with key_in_ckpt_.
        import torch

        ckpt = torch.load(ckpt_file_)
        copy_model_state(dst=net, src=ckpt if key_in_ckpt_ == "" else ckpt[key_in_ckpt_])

    # convert to TorchScript model and save with meta data, config content
    net = convert_to_torchscript(model=net)

    extra_files: Dict = {}
    for config_path in ensure_tuple(config_file_):
        # the extra-file key is the config's base filename without extension
        filename, _ = os.path.splitext(os.path.basename(config_path))
        if filename in extra_files:
            raise ValueError(f"filename '{filename}' is given multiple times in config file list.")
        extra_files[filename] = json.dumps(ConfigParser.load_config_file(config_path)).encode()

    save_net_with_metadata(
        jit_obj=net,
        filename_prefix_or_stream=filepath_,
        include_config_vals=False,
        append_timestamp=False,
        meta_values=parser.get().pop("_meta_", None),
        more_extra_files=extra_files,
    )
    logger.info(f"exported to TorchScript file: {filepath_}.")