def _make_up_layers(self):
    """Build the decoder stages together with the upsampling modules feeding them.

    For stage ``i`` of ``n_up = len(self.blocks_up)`` stages, the incoming width
    is ``self.init_filters * 2 ** (n_up - i)`` channels and the stage works on
    half of that:

    * ``up_samples[i]`` is ``nn.Sequential(1x1 conv wide -> narrow,
      upsample layer built with self.upsample_mode)``;
    * ``up_layers[i]`` is ``nn.Sequential`` of ``self.blocks_up[i]`` ``ResBlock``s
      on the narrow width, each configured with ``norm=self.norm``.

    Returns:
        tuple: ``(up_layers, up_samples)``, both ``nn.ModuleList``.
    """
    up_layers = nn.ModuleList()
    up_samples = nn.ModuleList()

    # Snapshot the configuration up front.
    mode = self.upsample_mode
    stage_blocks = self.blocks_up
    dims = self.spatial_dims
    base = self.init_filters
    norm = self.norm

    depth = len(stage_blocks)
    for stage, n_blocks in enumerate(stage_blocks):
        wide = base * 2 ** (depth - stage)
        narrow = wide // 2

        # Residual blocks are built before the upsampling modules of the same
        # stage; construction order determines parameter initialisation order.
        res_blocks = [ResBlock(dims, narrow, norm=norm) for _ in range(n_blocks)]
        up_layers.append(nn.Sequential(*res_blocks))

        reduce_then_upsample = [
            get_conv_layer(dims, wide, narrow, kernel_size=1),
            get_upsample_layer(dims, narrow, upsample_mode=mode),
        ]
        up_samples.append(nn.Sequential(*reduce_then_upsample))

    return up_layers, up_samples
def _make_down_layers(self):
    """Build the encoder stages.

    Stage ``i`` operates on ``self.init_filters * 2 ** i`` channels. The first
    stage starts with ``nn.Identity``. Every later stage starts with a stride-2
    conv taking half that many channels (the previous stage's width) up to the
    stage width. Each stage then stacks ``self.blocks_down[i]`` ``ResBlock``s
    configured with ``self.norm_name`` and ``self.num_groups``.

    Returns:
        nn.ModuleList: one ``nn.Sequential`` per encoder stage.
    """
    stages = nn.ModuleList()

    counts = self.blocks_down
    dims = self.spatial_dims
    base = self.init_filters
    norm_name = self.norm_name
    num_groups = self.num_groups

    for level, n_blocks in enumerate(counts):
        channels = base * 2 ** level

        if level > 0:
            entry = get_conv_layer(dims, channels // 2, channels, stride=2)
        else:
            entry = nn.Identity()

        res_blocks = [
            ResBlock(dims, channels, norm_name=norm_name, num_groups=num_groups)
            for _ in range(n_blocks)
        ]
        stages.append(nn.Sequential(entry, *res_blocks))

    return stages
def __init__(
    self,
    spatial_dims: int = 3,
    init_filters: int = 8,
    in_channels: int = 1,
    out_channels: int = 2,
    dropout_prob: Optional[float] = None,
    norm_name: str = "group",
    num_groups: int = 8,
    use_conv_final: bool = True,
    blocks_down: tuple = (1, 2, 2, 4),
    blocks_up: tuple = (1, 1, 1),
    upsample_mode: Union[UpsampleMode, str] = UpsampleMode.NONTRAINABLE,
):
    """Construct the network (variant using ``norm_name`` / ``num_groups``).

    Args:
        spatial_dims: number of spatial dimensions; must be 2 or 3.
        init_filters: output channels of the initial conv (``self.convInit``);
            also stored for the layer-builder methods.
        in_channels: number of input channels fed to ``self.convInit``.
        out_channels: number of output channels, passed to
            ``self._make_final_conv``.
        dropout_prob: if not ``None``, a
            ``Dropout[Dropout.DROPOUT, spatial_dims]`` layer with this
            probability is created as ``self.dropout``; otherwise no
            ``self.dropout`` attribute is set here.
        norm_name: normalisation name, stored for the layer builders.
        num_groups: group count, stored for the layer builders.
        use_conv_final: stored as ``self.use_conv_final`` (consumed outside
            this method).
        blocks_down: stored as ``self.blocks_down`` for the down-layer builder.
        blocks_up: stored as ``self.blocks_up`` for the up-layer builder.
        upsample_mode: converted with ``UpsampleMode(...)``.

    Raises:
        AssertionError: if ``spatial_dims`` is not 2 or 3.
    """
    super().__init__()

    # Explicit check instead of ``assert`` so the validation still runs under
    # ``python -O``. AssertionError is kept so callers that catch it still work.
    if spatial_dims not in (2, 3):
        raise AssertionError("spatial_dims can only be 2 or 3.")

    self.spatial_dims = spatial_dims
    self.init_filters = init_filters
    self.in_channels = in_channels
    self.blocks_down = blocks_down
    self.blocks_up = blocks_up
    self.dropout_prob = dropout_prob
    self.norm_name = norm_name
    self.num_groups = num_groups
    self.upsample_mode = UpsampleMode(upsample_mode)
    self.use_conv_final = use_conv_final

    # Submodule creation/registration order is kept as-is: it fixes
    # parameter-initialisation order and state_dict key order.
    self.convInit = get_conv_layer(spatial_dims, in_channels, init_filters)
    self.down_layers = self._make_down_layers()
    self.up_layers, self.up_samples = self._make_up_layers()
    # Must exist before _make_final_conv, which reads ``self.relu``.
    self.relu = Act[Act.RELU](inplace=True)
    self.conv_final = self._make_final_conv(out_channels)

    if dropout_prob is not None:
        self.dropout = Dropout[Dropout.DROPOUT, spatial_dims](dropout_prob)
def _make_final_conv(self, out_channels: int):
    """Build the output head: normalisation -> activation -> 1x1 conv (with bias).

    The normalisation is built from ``self.norm`` over ``self.init_filters``
    channels. The conv maps ``self.init_filters`` channels to ``out_channels``.
    The very same ``self.act`` object is placed in the head, not a copy. It is
    assumed to already be an ``nn.Module`` (e.g. produced by ``get_act_layer``);
    TODO confirm for every constructor variant.

    Args:
        out_channels: number of channels produced by the final conv.

    Returns:
        nn.Sequential: the three-stage output head.
    """
    channels = self.init_filters
    norm_layer = get_norm_layer(name=self.norm, spatial_dims=self.spatial_dims, channels=channels)
    activation = self.act
    projection = get_conv_layer(self.spatial_dims, channels, out_channels, kernel_size=1, bias=True)
    return nn.Sequential(norm_layer, activation, projection)
def _prepare_vae_modules(self):
    """Create the VAE-branch modules and register them on ``self``.

    Terms used below:

    * ``deepest`` is ``self.init_filters * 2 ** (len(self.blocks_down) - 1)``.
    * ``flat`` is ``int(self.smallest_filters * prod(self.fc_insize))``.

    Modules created:

    * ``self.vae_down``: norm -> relu -> stride-2 conv (deepest ->
      smallest_filters, with bias) -> norm -> relu.
    * ``self.vae_fc1`` / ``self.vae_fc2``: ``Linear(flat, self.vae_nz)``.
      Presumably these are the mean / log-variance heads; confirm against the
      forward pass.
    * ``self.vae_fc3``: ``Linear(self.vae_nz, flat)``.
    * ``self.vae_fc_up_sample``: 1x1 conv (smallest_filters -> deepest) ->
      upsample (``self.upsample_mode``) -> norm -> relu.

    Every norm layer uses ``self.norm_name`` / ``self.num_groups``. The shared
    ``self.relu`` module instance is reused in each place.
    """
    deepest = self.init_filters * 2 ** (len(self.blocks_down) - 1)
    flat = int(self.smallest_filters * np.prod(self.fc_insize))

    def make_norm(channels):
        # All VAE-branch normalisation layers share the same configuration.
        return get_norm_layer(
            self.spatial_dims, channels, norm_name=self.norm_name, num_groups=self.num_groups
        )

    # Creation/assignment order matches registration (and init) order.
    self.vae_down = nn.Sequential(
        make_norm(deepest),
        self.relu,
        get_conv_layer(self.spatial_dims, deepest, self.smallest_filters, stride=2, bias=True),
        make_norm(self.smallest_filters),
        self.relu,
    )
    self.vae_fc1 = nn.Linear(flat, self.vae_nz)
    self.vae_fc2 = nn.Linear(flat, self.vae_nz)
    self.vae_fc3 = nn.Linear(self.vae_nz, flat)
    self.vae_fc_up_sample = nn.Sequential(
        get_conv_layer(self.spatial_dims, self.smallest_filters, deepest, kernel_size=1),
        get_upsample_layer(self.spatial_dims, deepest, upsample_mode=self.upsample_mode),
        make_norm(deepest),
        self.relu,
    )
def _make_final_conv(self, out_channels: int):
    """Build the output head: normalisation -> ReLU -> 1x1 conv (with bias).

    The normalisation covers ``self.init_filters`` channels and is configured
    by ``self.norm_name`` / ``self.num_groups``. The shared ``self.relu`` module
    is reused as the activation. The conv maps ``self.init_filters`` channels
    to ``out_channels``.

    Args:
        out_channels: number of channels produced by the final conv.

    Returns:
        nn.Sequential: the three-stage output head.
    """
    channels = self.init_filters
    norm_layer = get_norm_layer(
        self.spatial_dims, channels, norm_name=self.norm_name, num_groups=self.num_groups
    )
    activation = self.relu
    projection = get_conv_layer(
        self.spatial_dims, channels, out_channels=out_channels, kernel_size=1, bias=True
    )
    return nn.Sequential(norm_layer, activation, projection)
def __init__(
    self,
    spatial_dims: int = 3,
    init_filters: int = 8,
    in_channels: int = 1,
    out_channels: int = 2,
    dropout_prob: Optional[float] = None,
    act: Union[Tuple, str] = ("RELU", {"inplace": True}),
    norm: Union[Tuple, str] = ("GROUP", {"num_groups": 8}),
    norm_name: str = "",
    num_groups: int = 8,
    use_conv_final: bool = True,
    blocks_down: tuple = (1, 2, 2, 4),
    blocks_up: tuple = (1, 1, 1),
    upsample_mode: Union[UpsampleMode, str] = UpsampleMode.NONTRAINABLE,
):
    """Construct the network (variant with ``act``/``norm`` plus legacy ``norm_name``).

    Args:
        spatial_dims: number of spatial dimensions; must be 2 or 3.
        init_filters: output channels of the initial conv (``self.convInit``).
        in_channels: number of input channels.
        out_channels: number of output channels, passed to
            ``self._make_final_conv``.
        dropout_prob: if not ``None``, ``self.dropout`` is created with this
            probability.
        act: activation spec. It is stored unchanged as ``self.act``, and
            ``get_act_layer(act)`` is stored as ``self.act_mod``.
        norm: normalisation spec, stored as ``self.norm`` unless overridden by
            ``norm_name``.
        norm_name: deprecated. If non-empty it must equal ``"group"``
            (case-insensitive), and then ``norm`` is replaced by
            ``("group", {"num_groups": num_groups})``.
        num_groups: only used together with a non-empty ``norm_name``.
        use_conv_final: stored as ``self.use_conv_final``.
        blocks_down: stored as ``self.blocks_down``.
        blocks_up: stored as ``self.blocks_up``.
        upsample_mode: converted with ``UpsampleMode(...)``.

    Raises:
        ValueError: if ``spatial_dims`` is not 2 or 3, or ``norm_name`` is
            non-empty and not ``"group"``.
    """
    super().__init__()

    supported_dims = (2, 3)
    if spatial_dims not in supported_dims:
        raise ValueError("`spatial_dims` can only be 2 or 3.")

    self.spatial_dims = spatial_dims
    self.init_filters = init_filters
    self.in_channels = in_channels
    self.blocks_down = blocks_down
    self.blocks_up = blocks_up
    self.dropout_prob = dropout_prob

    # Keep the raw activation spec and an instantiated module side by side.
    self.act = act
    self.act_mod = get_act_layer(act)

    # Legacy ``norm_name`` handling: reject anything but group norm, and let a
    # group-norm request override ``norm`` using ``num_groups``.
    legacy_requested = bool(norm_name)
    if legacy_requested and norm_name.lower() != "group":
        raise ValueError(
            f"Deprecating option 'norm_name={norm_name}', please use 'norm' instead."
        )
    self.norm = ("group", {"num_groups": num_groups}) if legacy_requested else norm

    self.upsample_mode = UpsampleMode(upsample_mode)
    self.use_conv_final = use_conv_final

    # Submodules are created in a fixed order (initialisation / state_dict order).
    self.convInit = get_conv_layer(spatial_dims, in_channels, init_filters)
    self.down_layers = self._make_down_layers()
    self.up_layers, self.up_samples = self._make_up_layers()
    self.conv_final = self._make_final_conv(out_channels)

    if dropout_prob is not None:
        self.dropout = Dropout[Dropout.DROPOUT, spatial_dims](dropout_prob)
def __init__(
    self,
    spatial_dims: int = 3,
    init_filters: int = 8,
    in_channels: int = 1,
    out_channels: int = 2,
    dropout_prob: Optional[float] = None,
    act: Union[Tuple, str] = ("RELU", {"inplace": True}),
    norm: Union[Tuple, str] = ("GROUP", {"num_groups": 8}),
    use_conv_final: bool = True,
    blocks_down: tuple = (1, 2, 2, 4),
    blocks_up: tuple = (1, 1, 1),
    upsample_mode: Union[UpsampleMode, str] = UpsampleMode.NONTRAINABLE,
):
    """Construct the network (variant with ``act`` / ``norm`` specs).

    Args:
        spatial_dims: number of spatial dimensions; must be 2 or 3.
        init_filters: output channels of the initial conv (``self.convInit``).
        in_channels: number of input channels.
        out_channels: number of output channels, passed to
            ``self._make_final_conv``.
        dropout_prob: if not ``None``, ``self.dropout`` is created with this
            probability.
        act: activation spec; ``self.act`` holds the module built by
            ``get_act_layer(act)``.
        norm: normalisation spec, stored as ``self.norm``.
        use_conv_final: stored as ``self.use_conv_final``.
        blocks_down: stored as ``self.blocks_down``.
        blocks_up: stored as ``self.blocks_up``.
        upsample_mode: converted with ``UpsampleMode(...)``.

    Raises:
        AssertionError: if ``spatial_dims`` is not 2 or 3.
    """
    super().__init__()

    dims_ok = spatial_dims in (2, 3)
    if not dims_ok:
        raise AssertionError("spatial_dims can only be 2 or 3.")

    # Plain configuration values.
    self.spatial_dims = spatial_dims
    self.init_filters = init_filters
    self.in_channels = in_channels
    self.blocks_down = blocks_down
    self.blocks_up = blocks_up
    self.dropout_prob = dropout_prob

    # Activation is instantiated once here and reused by the layer builders.
    self.act = get_act_layer(act)
    self.norm = norm
    self.upsample_mode = UpsampleMode(upsample_mode)
    self.use_conv_final = use_conv_final

    # Submodules are created in a fixed order (initialisation / state_dict order).
    self.convInit = get_conv_layer(spatial_dims, in_channels, init_filters)
    self.down_layers = self._make_down_layers()
    self.up_layers, self.up_samples = self._make_up_layers()
    self.conv_final = self._make_final_conv(out_channels)

    if dropout_prob is not None:
        self.dropout = Dropout[Dropout.DROPOUT, spatial_dims](dropout_prob)