def __init__(
    self,
    sizes: Sequence[Sequence[int]] = ((20, 30, 40),),
    aspect_ratios: Sequence = (((0.5, 1), (1, 0.5)),),
    indexing: str = "ij",
) -> None:
    super().__init__()

    if not issequenceiterable(sizes[0]):
        self.sizes = tuple((s,) for s in sizes)
    else:
        self.sizes = ensure_tuple(sizes)
    if not issequenceiterable(aspect_ratios[0]):
        aspect_ratios = (aspect_ratios,) * len(self.sizes)

    if len(self.sizes) != len(aspect_ratios):
        raise ValueError(
            "len(sizes) and len(aspect_ratios) should be equal. It represents the number of feature maps."
        )

    spatial_dims = len(ensure_tuple(aspect_ratios[0][0])) + 1
    spatial_dims = look_up_option(spatial_dims, [2, 3])
    self.spatial_dims = spatial_dims

    self.indexing = look_up_option(indexing, ["ij", "xy"])

    self.aspect_ratios = aspect_ratios
    self.cell_anchors = [
        self.generate_anchors(size, aspect_ratio) for size, aspect_ratio in zip(self.sizes, aspect_ratios)
    ]
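# A minimal standalone sketch of the sizes/aspect_ratios broadcasting above,
# using plain isinstance checks in place of issequenceiterable so it runs on its
# own (the `_normalize` helper is illustrative, not part of the class):

def _normalize(sizes, aspect_ratios):
    # scalar sizes -> one single-size tuple per feature map
    if not isinstance(sizes[0], (tuple, list)):
        sizes = tuple((s,) for s in sizes)
    # a single flat ratio tuple -> repeat it for every feature map
    if not isinstance(aspect_ratios[0], (tuple, list)):
        aspect_ratios = (aspect_ratios,) * len(sizes)
    return sizes, aspect_ratios

print(_normalize((20, 30, 40), (0.5, 1.0)))
# -> (((20,), (30,), (40,)), ((0.5, 1.0), (0.5, 1.0), (0.5, 1.0)))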
def __call__(
    self, img: NdarrayOrTensor, randomize: bool = True, device: Optional[torch.device] = None
) -> NdarrayOrTensor:
    img = convert_to_tensor(img, track_meta=get_track_meta())
    if randomize:
        self.randomize()

    if not self._do_transform:
        return img

    device = device if device is not None else self.device

    field = self.sfield()

    dgrid = self.grid + field.to(self.grid_dtype)
    dgrid = moveaxis(dgrid, 1, -1)  # type: ignore

    img_t = convert_to_tensor(img[None], torch.float32, device)

    out = grid_sample(
        input=img_t,
        grid=dgrid,
        mode=look_up_option(self.grid_mode, GridSampleMode),
        align_corners=self.grid_align_corners,
        padding_mode=look_up_option(self.grid_padding_mode, GridSamplePadMode),
    )

    out_t, *_ = convert_to_dst_type(out.squeeze(0), img)

    return out_t
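# Why the moveaxis above: torch's grid_sample expects the coordinate channel
# last, e.g. a grid of shape (N, H, W, 2) for 2-d inputs of shape (N, C, H, W).
# A self-contained identity-resampling check in plain PyTorch (independent of
# the transform's internal state):

import torch
import torch.nn.functional as F

img = torch.arange(16.0).reshape(1, 1, 4, 4)  # (N, C, H, W)
ys, xs = torch.meshgrid(torch.linspace(-1, 1, 4), torch.linspace(-1, 1, 4), indexing="ij")
grid = torch.stack([xs, ys], dim=-1)[None]    # (N, H, W, 2), x coordinate first
out = F.grid_sample(img, grid, mode="bilinear", align_corners=True)
assert torch.allclose(out, img)               # the identity grid reproduces the input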
def _load_state_dict(model: nn.Module, arch: str, progress: bool):
    """
    This function is used to load pretrained models.
    Adapted from PyTorch Hub 2D version: https://pytorch.org/vision/stable/models.html#id16.
    """
    model_urls = {
        "densenet121": "https://download.pytorch.org/models/densenet121-a639ec97.pth",
        "densenet169": "https://download.pytorch.org/models/densenet169-b2777c0a.pth",
        "densenet201": "https://download.pytorch.org/models/densenet201-c1103571.pth",
    }
    model_url = look_up_option(arch, model_urls, None)
    if model_url is None:
        raise ValueError(
            "only 'densenet121', 'densenet169' and 'densenet201' are supported to load pretrained weights."
        )

    # the checkpoints use keys such as 'denselayer1.norm.1.weight'; remap them
    # to this implementation's nested 'layers' submodule, e.g. 'layers.norm1.weight'
    pattern = re.compile(
        r"^(.*denselayer\d+)(\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$"
    )

    state_dict = load_state_dict_from_url(model_url, progress=progress)
    for key in list(state_dict.keys()):
        res = pattern.match(key)
        if res:
            new_key = res.group(1) + ".layers" + res.group(2) + res.group(3)
            state_dict[new_key] = state_dict[key]
            del state_dict[key]

    model_dict = model.state_dict()
    state_dict = {
        k: v for k, v in state_dict.items() if (k in model_dict) and (model_dict[k].shape == state_dict[k].shape)
    }
    model_dict.update(state_dict)
    model.load_state_dict(model_dict)
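# A quick illustration of the remapping above on one representative torchvision
# checkpoint key:

import re

pattern = re.compile(
    r"^(.*denselayer\d+)(\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$"
)
res = pattern.match("features.denseblock1.denselayer1.norm.1.weight")
print(res.group(1) + ".layers" + res.group(2) + res.group(3))
# -> features.denseblock1.denselayer1.layers.norm1.weight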
def __call__(self, randomize=False) -> torch.Tensor:
    if randomize:
        self.randomize()

    field = self.field.clone()

    if self.spatial_zoom is not None:
        resized_field = interpolate(
            input=field,
            scale_factor=self.spatial_zoom,
            mode=look_up_option(self.mode, InterpolateMode),
            align_corners=self.align_corners,
            recompute_scale_factor=False,
        )

        mina = resized_field.min()
        maxa = resized_field.max()
        minv = self.field.min()
        maxv = self.field.max()

        # faster than rescale_array: uses in-place operations and doesn't perform unneeded range checks
        norm_field = (resized_field.squeeze(0) - mina).div_(maxa - mina)
        field = norm_field.mul_(maxv - minv).add_(minv)

    return field
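# The in-place rescale above maps the resized field back to the original field's
# range: out = (x - min(x)) / (max(x) - min(x)) * (maxv - minv) + minv.
# A standalone numeric check in plain torch (stand-in values, no class state):

import torch

x = torch.tensor([0.0, 2.0, 4.0])  # stand-in for the resized field
minv, maxv = 10.0, 20.0            # stand-in for the original field's range
y = (x - x.min()).div_(x.max() - x.min()).mul_(maxv - minv).add_(minv)
print(y)  # tensor([10., 15., 20.])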
def _load_state_dict(model: nn.Module, arch: str, progress: bool):
    """
    This function is used to load pretrained models.
    """
    model_url = look_up_option(arch, SE_NET_MODELS, None)
    if model_url is None:
        raise ValueError(
            "only 'senet154', 'se_resnet50', 'se_resnet101', 'se_resnet152', 'se_resnext50_32x4d', "
            "and 'se_resnext101_32x4d' are supported to load pretrained weights."
        )

    pattern_conv = re.compile(r"^(layer[1-4]\.\d\.(?:conv)\d\.)(\w*)$")
    pattern_bn = re.compile(r"^(layer[1-4]\.\d\.)(?:bn)(\d\.)(\w*)$")
    pattern_se = re.compile(r"^(layer[1-4]\.\d\.)(?:se_module.fc1.)(\w*)$")
    pattern_se2 = re.compile(r"^(layer[1-4]\.\d\.)(?:se_module.fc2.)(\w*)$")
    pattern_down_conv = re.compile(r"^(layer[1-4]\.\d\.)(?:downsample.0.)(\w*)$")
    pattern_down_bn = re.compile(r"^(layer[1-4]\.\d\.)(?:downsample.1.)(\w*)$")

    if isinstance(model_url, dict):
        download_url(model_url["url"], filepath=model_url["filename"])
        state_dict = torch.load(model_url["filename"], map_location=None)
    else:
        state_dict = load_state_dict_from_url(model_url, progress=progress)
    for key in list(state_dict.keys()):
        new_key = None
        if pattern_conv.match(key):
            new_key = re.sub(pattern_conv, r"\1conv.\2", key)
        elif pattern_bn.match(key):
            new_key = re.sub(pattern_bn, r"\1conv\2adn.N.\3", key)
        elif pattern_se.match(key):
            state_dict[key] = state_dict[key].squeeze()  # 1x1 conv weights -> linear layer weights
            new_key = re.sub(pattern_se, r"\1se_layer.fc.0.\2", key)
        elif pattern_se2.match(key):
            state_dict[key] = state_dict[key].squeeze()  # 1x1 conv weights -> linear layer weights
            new_key = re.sub(pattern_se2, r"\1se_layer.fc.2.\2", key)
        elif pattern_down_conv.match(key):
            new_key = re.sub(pattern_down_conv, r"\1project.conv.\2", key)
        elif pattern_down_bn.match(key):
            new_key = re.sub(pattern_down_bn, r"\1project.adn.N.\2", key)
        if new_key:
            state_dict[new_key] = state_dict[key]
            del state_dict[key]

    model_dict = model.state_dict()
    state_dict = {
        k: v for k, v in state_dict.items() if (k in model_dict) and (model_dict[k].shape == state_dict[k].shape)
    }
    model_dict.update(state_dict)
    model.load_state_dict(model_dict)
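# Example of the batch-norm remapping above on a representative checkpoint key:

import re

pattern_bn = re.compile(r"^(layer[1-4]\.\d\.)(?:bn)(\d\.)(\w*)$")
print(re.sub(pattern_bn, r"\1conv\2adn.N.\3", "layer1.0.bn1.weight"))
# -> layer1.0.conv1.adn.N.weight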
def __init__(
    self,
    feature_map_scales: Union[Sequence[int], Sequence[float]] = (1, 2, 4, 8),
    base_anchor_shapes: Union[Sequence[Sequence[int]], Sequence[Sequence[float]]] = (
        (32, 32, 32),
        (48, 20, 20),
        (20, 48, 20),
        (20, 20, 48),
    ),
    indexing: str = "ij",
) -> None:
    nn.Module.__init__(self)

    spatial_dims = len(base_anchor_shapes[0])
    spatial_dims = look_up_option(spatial_dims, [2, 3])
    self.spatial_dims = spatial_dims

    self.indexing = look_up_option(indexing, ["ij", "xy"])

    base_anchor_shapes_t = torch.Tensor(base_anchor_shapes)
    self.cell_anchors = [self.generate_anchors_using_shape(s * base_anchor_shapes_t) for s in feature_map_scales]
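# The comprehension above scales one shared set of base shapes per feature map,
# e.g. (32, 32, 32) becomes (64, 64, 64) at scale 2. A standalone check:

import torch

base = torch.Tensor([(32, 32, 32), (48, 20, 20)])
print([(s * base)[0].tolist() for s in (1, 2, 4)])
# -> [[32.0, 32.0, 32.0], [64.0, 64.0, 64.0], [128.0, 128.0, 128.0]]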
def __init__(
    self,
    device: torch.device,
    val_data_loader: Iterable | DataLoader,
    epoch_length: int | None = None,
    non_blocking: bool = False,
    prepare_batch: Callable = default_prepare_batch,
    iteration_update: Callable[[Engine, Any], Any] | None = None,
    postprocessing: Transform | None = None,
    key_val_metric: dict[str, Metric] | None = None,
    additional_metrics: dict[str, Metric] | None = None,
    metric_cmp_fn: Callable = default_metric_cmp_fn,
    val_handlers: Sequence | None = None,
    amp: bool = False,
    mode: ForwardMode | str = ForwardMode.EVAL,
    event_names: list[str | EventEnum] | None = None,
    event_to_attr: dict | None = None,
    decollate: bool = True,
    to_kwargs: dict | None = None,
    amp_kwargs: dict | None = None,
) -> None:
    super().__init__(
        device=device,
        max_epochs=1,
        data_loader=val_data_loader,
        epoch_length=epoch_length,
        non_blocking=non_blocking,
        prepare_batch=prepare_batch,
        iteration_update=iteration_update,
        postprocessing=postprocessing,
        key_metric=key_val_metric,
        additional_metrics=additional_metrics,
        metric_cmp_fn=metric_cmp_fn,
        handlers=val_handlers,
        amp=amp,
        event_names=event_names,
        event_to_attr=event_to_attr,
        decollate=decollate,
        to_kwargs=to_kwargs,
        amp_kwargs=amp_kwargs,
    )
    mode = look_up_option(mode, ForwardMode)
    if mode == ForwardMode.EVAL:
        self.mode = eval_mode
    elif mode == ForwardMode.TRAIN:
        self.mode = train_mode
    else:
        raise ValueError(f"unsupported mode: {mode}, should be 'eval' or 'train'.")
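# How the mode lookup behaves (the same applies to the legacy constructor
# further below): look_up_option accepts either the enum member or its string
# value and raises on anything else. A runnable sketch with a stand-in for
# monai.utils.ForwardMode, assuming it is a str-mixin enum:

from enum import Enum

class ForwardMode(str, Enum):  # stand-in, not the real import
    TRAIN = "train"
    EVAL = "eval"

assert ForwardMode("eval") is ForwardMode.EVAL
assert "eval" == ForwardMode.EVAL  # str-mixin members compare equal to their value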
def __init__(
    self,
    kernel_type: str = "gaussian",
    num_bins: int = 23,
    sigma_ratio: float = 0.5,
    reduction: Union[LossReduction, str] = LossReduction.MEAN,
    smooth_nr: float = 1e-7,
    smooth_dr: float = 1e-7,
) -> None:
    """
    Args:
        kernel_type: {``"gaussian"``, ``"b-spline"``}
            ``"gaussian"``: adapted from DeepReg.
            Reference: https://dspace.mit.edu/handle/1721.1/123142,
            Section 3.1, equations 3.1-3.5, Algorithm 1.
            ``"b-spline"``: based on the method of Mattes et al [1,2] and adapted from ITK.
            References:
            [1] "Nonrigid multimodality image registration"
            D. Mattes, D. R. Haynor, H. Vesselle, T. Lewellen and W. Eubank
            Medical Imaging 2001: Image Processing, 2001, pp. 1609-1620.
            [2] "PET-CT Image Registration in the Chest Using Free-form Deformations"
            D. Mattes, D. R. Haynor, H. Vesselle, T. Lewellen and W. Eubank
            IEEE Transactions on Medical Imaging. Vol. 22, No. 1, January 2003. pp. 120-128.
        num_bins: number of bins for intensity.
        sigma_ratio: a hyperparameter for the Gaussian kernel function.
        reduction: {``"none"``, ``"mean"``, ``"sum"``}
            Specifies the reduction to apply to the output. Defaults to ``"mean"``.

            - ``"none"``: no reduction will be applied.
            - ``"mean"``: the sum of the output will be divided by the number of elements in the output.
            - ``"sum"``: the output will be summed.
        smooth_nr: a small constant added to the numerator to avoid nan.
        smooth_dr: a small constant added to the denominator to avoid nan.
    """
    super().__init__(reduction=LossReduction(reduction).value)
    if num_bins <= 0:
        raise ValueError(f"num_bins must be > 0, got {num_bins}")
    bin_centers = torch.linspace(0.0, 1.0, num_bins)  # (num_bins,)
    sigma = torch.mean(bin_centers[1:] - bin_centers[:-1]) * sigma_ratio
    self.kernel_type = look_up_option(kernel_type, ["gaussian", "b-spline"])
    self.num_bins = num_bins
    if self.kernel_type == "gaussian":
        self.preterm = 1 / (2 * sigma**2)
        self.bin_centers = bin_centers[None, None, ...]
    self.smooth_nr = float(smooth_nr)
    self.smooth_dr = float(smooth_dr)
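# Worked example for the Gaussian branch above: with num_bins bins on [0, 1],
# the bin spacing is 1 / (num_bins - 1), sigma = spacing * sigma_ratio, and
# preterm = 1 / (2 * sigma**2). For the defaults num_bins=23, sigma_ratio=0.5:

import torch

bin_centers = torch.linspace(0.0, 1.0, 23)
sigma = torch.mean(bin_centers[1:] - bin_centers[:-1]) * 0.5
print(float(sigma), float(1 / (2 * sigma**2)))
# spacing = 1/22, sigma = 1/44 ≈ 0.0227, preterm = 44**2 / 2 = 968.0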
def __init__(
    self,
    spatial_dims: int = 3,
    kernel_size: int = 3,
    kernel_type: str = "rectangular",
    reduction: Union[LossReduction, str] = LossReduction.MEAN,
    smooth_nr: float = 1e-5,
    smooth_dr: float = 1e-5,
    ndim: Optional[int] = None,
) -> None:
    """
    Args:
        spatial_dims: number of spatial dimensions, {``1``, ``2``, ``3``}. Defaults to 3.
        kernel_size: kernel spatial size, must be odd.
        kernel_type: {``"rectangular"``, ``"triangular"``, ``"gaussian"``}. Defaults to ``"rectangular"``.
        reduction: {``"none"``, ``"mean"``, ``"sum"``}
            Specifies the reduction to apply to the output. Defaults to ``"mean"``.

            - ``"none"``: no reduction will be applied.
            - ``"mean"``: the sum of the output will be divided by the number of elements in the output.
            - ``"sum"``: the output will be summed.
        smooth_nr: a small constant added to the numerator to avoid nan.
        smooth_dr: a small constant added to the denominator to avoid nan.

    .. deprecated:: 0.6.0
        ``ndim`` is deprecated, use ``spatial_dims``.
    """
    super().__init__(reduction=LossReduction(reduction).value)
    if ndim is not None:
        spatial_dims = ndim
    self.ndim = spatial_dims
    if self.ndim not in {1, 2, 3}:
        raise ValueError(f"Unsupported ndim: {self.ndim}-d, only 1-d, 2-d, and 3-d inputs are supported")

    self.kernel_size = kernel_size
    if self.kernel_size % 2 == 0:
        raise ValueError(f"kernel_size must be odd, got {self.kernel_size}")

    _kernel = look_up_option(kernel_type, kernel_dict)
    self.kernel = _kernel(self.kernel_size)
    self.kernel_vol = self.get_kernel_vol()

    self.smooth_nr = float(smooth_nr)
    self.smooth_dr = float(smooth_dr)
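# kernel_dict is defined elsewhere and not shown here. For illustration only, a
# plausible "rectangular" entry is a constant (box) window, so for kernel_size=3
# the 1-d kernel is [1, 1, 1] and the separable 3-d kernel volume is 3**3 = 27:

import torch

def make_rectangular_kernel(kernel_size: int) -> torch.Tensor:  # hypothetical helper
    return torch.ones(kernel_size)

kernel = make_rectangular_kernel(3)
print(kernel, kernel.sum() ** 3)  # tensor([1., 1., 1.]) tensor(27.)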
def _make_layer(
    self,
    block: Type[Union[ResNetBlock, ResNetBottleneck]],
    planes: int,
    blocks: int,
    spatial_dims: int,
    shortcut_type: str,
    stride: int = 1,
) -> nn.Sequential:
    conv_type: Callable = Conv[Conv.CONV, spatial_dims]
    norm_type: Callable = Norm[Norm.BATCH, spatial_dims]

    downsample: Union[nn.Module, partial, None] = None
    if stride != 1 or self.in_planes != planes * block.expansion:
        if look_up_option(shortcut_type, {"A", "B"}) == "A":
            downsample = partial(
                self._downsample_basic_block,
                planes=planes * block.expansion,
                stride=stride,
                spatial_dims=spatial_dims,
            )
        else:
            downsample = nn.Sequential(
                conv_type(self.in_planes, planes * block.expansion, kernel_size=1, stride=stride),
                norm_type(planes * block.expansion),
            )

    layers = [
        block(
            in_planes=self.in_planes, planes=planes, spatial_dims=spatial_dims, stride=stride, downsample=downsample
        )
    ]

    self.in_planes = planes * block.expansion
    for _i in range(1, blocks):
        layers.append(block(self.in_planes, planes, spatial_dims=spatial_dims))

    return nn.Sequential(*layers)
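# Shortcut type "A" avoids extra parameters: the identity is strided and
# zero-padded on the channel axis to match the block's output width, while "B"
# uses a learned 1x1 conv + norm. A minimal 2-d sketch of what
# _downsample_basic_block (not shown above) plausibly does:

import torch
import torch.nn.functional as F

def downsample_basic_block_2d(x: torch.Tensor, planes: int, stride: int) -> torch.Tensor:
    out = F.avg_pool2d(x, kernel_size=1, stride=stride)  # spatial downsampling only
    pad = torch.zeros(out.size(0), planes - out.size(1), *out.shape[2:], dtype=out.dtype)
    return torch.cat([out, pad], dim=1)                  # zero-pad the channel axis

x = torch.randn(1, 16, 8, 8)
print(downsample_basic_block_2d(x, planes=32, stride=2).shape)  # torch.Size([1, 32, 4, 4])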
def __init__(
    self,
    device: torch.device,
    val_data_loader: Union[Iterable, DataLoader],
    epoch_length: Optional[int] = None,
    non_blocking: bool = False,
    prepare_batch: Callable = default_prepare_batch,
    iteration_update: Optional[Callable] = None,
    postprocessing: Optional[Transform] = None,
    key_val_metric: Optional[Dict[str, Metric]] = None,
    additional_metrics: Optional[Dict[str, Metric]] = None,
    metric_cmp_fn: Callable = default_metric_cmp_fn,
    val_handlers: Optional[Sequence] = None,
    amp: bool = False,
    mode: Union[ForwardMode, str] = ForwardMode.EVAL,
    event_names: Optional[List[Union[str, EventEnum]]] = None,
    event_to_attr: Optional[dict] = None,
    decollate: bool = True,
) -> None:
    super().__init__(
        device=device,
        max_epochs=1,
        data_loader=val_data_loader,
        epoch_length=epoch_length,
        non_blocking=non_blocking,
        prepare_batch=prepare_batch,
        iteration_update=iteration_update,
        postprocessing=postprocessing,
        key_metric=key_val_metric,
        additional_metrics=additional_metrics,
        metric_cmp_fn=metric_cmp_fn,
        handlers=val_handlers,
        amp=amp,
        event_names=event_names,
        event_to_attr=event_to_attr,
        decollate=decollate,
    )
    # resolve the option to the enum first, then pick the corresponding context manager
    mode = look_up_option(mode, ForwardMode)
    if mode == ForwardMode.EVAL:
        self.mode = eval_mode
    elif mode == ForwardMode.TRAIN:
        self.mode = train_mode
    else:
        raise ValueError(f"unsupported mode: {mode}, should be 'eval' or 'train'.")
def _load_state_dict(model: nn.Module, arch: str, progress: bool, adv_prop: bool) -> None:
    if adv_prop:
        arch = arch.split("efficientnet-")[-1] + "-ap"
    model_url = look_up_option(arch, url_map, None)
    if model_url is None:
        print(f"pretrained weights of {arch} are not provided")
    else:
        # load state dict from url
        pretrain_state_dict = model_zoo.load_url(model_url, progress=progress)
        model_state_dict = model.state_dict()

        pattern = re.compile(r"(.+)\.\d+(\.\d+\..+)")

        for key, value in model_state_dict.items():
            pretrain_key = re.sub(pattern, r"\1\2", key)
            if pretrain_key in pretrain_state_dict and value.shape == pretrain_state_dict[pretrain_key].shape:
                model_state_dict[key] = pretrain_state_dict[pretrain_key]

        model.load_state_dict(model_state_dict)
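# Mechanically, the substitution above drops the first of two consecutive
# numeric indices in a parameter name (the example key below is illustrative,
# not an actual checkpoint key):

import re

pattern = re.compile(r"(.+)\.\d+(\.\d+\..+)")
print(re.sub(pattern, r"\1\2", "_blocks.3.0._expand_conv.conv.weight"))
# -> _blocks.0._expand_conv.conv.weight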
def __init__(
    self,
    in_channels: int,
    img_size: Union[Sequence[int], int],
    patch_size: Union[Sequence[int], int],
    hidden_size: int,
    num_heads: int,
    pos_embed: str,
    dropout_rate: float = 0.0,
    spatial_dims: int = 3,
) -> None:
    """
    Args:
        in_channels: dimension of input channels.
        img_size: dimension of input image.
        patch_size: dimension of patch size.
        hidden_size: dimension of hidden layer.
        num_heads: number of attention heads.
        pos_embed: position embedding layer type.
        dropout_rate: fraction of the input units to drop.
        spatial_dims: number of spatial dimensions.
    """
    super().__init__()

    if not (0 <= dropout_rate <= 1):
        raise ValueError("dropout_rate should be between 0 and 1.")

    if hidden_size % num_heads != 0:
        raise ValueError("hidden_size should be divisible by num_heads.")

    self.pos_embed = look_up_option(pos_embed, SUPPORTED_EMBEDDING_TYPES)

    img_size = ensure_tuple_rep(img_size, spatial_dims)
    patch_size = ensure_tuple_rep(patch_size, spatial_dims)
    for m, p in zip(img_size, patch_size):
        if m < p:
            raise ValueError("patch_size should be smaller than img_size.")
        if self.pos_embed == "perceptron" and m % p != 0:
            raise ValueError("img_size should be divisible by patch_size for perceptron.")
    self.n_patches = np.prod([im_d // p_d for im_d, p_d in zip(img_size, patch_size)])
    self.patch_dim = int(in_channels * np.prod(patch_size))

    self.patch_embeddings: nn.Module
    if self.pos_embed == "conv":
        self.patch_embeddings = Conv[Conv.CONV, spatial_dims](
            in_channels=in_channels, out_channels=hidden_size, kernel_size=patch_size, stride=patch_size
        )
    elif self.pos_embed == "perceptron":
        # for 3d: "b c (h p1) (w p2) (d p3) -> b (h w d) (p1 p2 p3 c)"
        chars = (("h", "p1"), ("w", "p2"), ("d", "p3"))[:spatial_dims]
        from_chars = "b c " + " ".join(f"({k} {v})" for k, v in chars)
        to_chars = f"b ({' '.join([c[0] for c in chars])}) ({' '.join([c[1] for c in chars])} c)"
        axes_len = {f"p{i+1}": p for i, p in enumerate(patch_size)}
        self.patch_embeddings = nn.Sequential(
            Rearrange(f"{from_chars} -> {to_chars}", **axes_len), nn.Linear(self.patch_dim, hidden_size)
        )
    self.position_embeddings = nn.Parameter(torch.zeros(1, self.n_patches, hidden_size))
    self.dropout = nn.Dropout(dropout_rate)
    trunc_normal_(self.position_embeddings, mean=0.0, std=0.02, a=-2.0, b=2.0)
    self.apply(self._init_weights)
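# Standalone reproduction of the einops pattern built in the perceptron branch
# above, printed for spatial_dims=3 and patch_size=(16, 16, 16):

chars = (("h", "p1"), ("w", "p2"), ("d", "p3"))[:3]
from_chars = "b c " + " ".join(f"({k} {v})" for k, v in chars)
to_chars = f"b ({' '.join([c[0] for c in chars])}) ({' '.join([c[1] for c in chars])} c)"
print(f"{from_chars} -> {to_chars}")
# -> b c (h p1) (w p2) (d p3) -> b (h w d) (p1 p2 p3 c)
print({f"p{i+1}": p for i, p in enumerate((16, 16, 16))})
# -> {'p1': 16, 'p2': 16, 'p3': 16}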
def dtype_numpy_to_torch(dtype):
    """Convert a numpy dtype to its torch equivalent."""
    # np dtypes may be given either as np.float32 or as np.dtype(np.float32), so unify them
    dtype = np.dtype(dtype) if isinstance(dtype, (type, str)) else dtype
    return look_up_option(dtype, _np_to_torch_dtype)
def dtype_torch_to_numpy(dtype):
    """Convert a torch dtype to its numpy equivalent."""
    return look_up_option(dtype, _torch_to_np_dtype)
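# Usage sketch for the converters above, assuming the (not shown) lookup tables
# pair the standard dtypes. A self-contained stand-in that mirrors the idea:

import numpy as np
import torch

_np_to_torch = {np.dtype(np.float32): torch.float32, np.dtype(np.float64): torch.float64}

def _to_torch(dtype):  # mirrors dtype_numpy_to_torch with a stand-in table
    dtype = np.dtype(dtype) if isinstance(dtype, (type, str)) else dtype
    return _np_to_torch[dtype]

assert _to_torch(np.float32) is torch.float32  # a plain type is unified via np.dtype
assert _to_torch("float64") is torch.float64   # dtype strings work the same way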