def _make_up_layers(self):
    """Build the per-level decoder modules.

    One level is created for every entry of ``self.blocks_up``. For the
    level at index ``i`` (out of ``n`` levels) the input channel count is
    ``self.init_filters * 2 ** (n - i)`` and the output channel count is
    half of that.

    Returns:
        A tuple ``(up_layers, up_samples)`` of two ``nn.ModuleList``
        objects with one ``nn.Sequential`` per level each:

        * ``up_layers[i]`` chains ``self.blocks_up[i]`` ``ResBlock``
          modules operating on the level's output channel count.
        * ``up_samples[i]`` chains a ``kernel_size=1`` convolution that
          maps input channels to output channels, followed by the layer
          returned by ``get_upsample_layer``.
    """
    # Read the configuration once, before any module is constructed.
    mode = self.upsample_mode
    block_counts = self.blocks_up
    dims = self.spatial_dims
    base_filters = self.init_filters
    norm = self.norm

    num_levels = len(block_counts)
    up_layers = nn.ModuleList()
    up_samples = nn.ModuleList()

    for level, num_blocks in enumerate(block_counts):
        in_channels = base_filters * 2 ** (num_levels - level)
        out_channels = in_channels // 2

        # Construction order (residual blocks, then conv, then upsample)
        # is kept fixed so that parameter initialisation consumes the
        # random number generator in the same sequence.
        res_blocks = [
            ResBlock(dims, out_channels, norm=norm)
            for _ in range(num_blocks)
        ]
        up_layers.append(nn.Sequential(*res_blocks))

        channel_reducer = get_conv_layer(
            dims, in_channels, out_channels, kernel_size=1
        )
        upsampler = get_upsample_layer(
            dims, out_channels, upsample_mode=mode
        )
        up_samples.append(nn.Sequential(channel_reducer, upsampler))

    return up_layers, up_samples
def _prepare_vae_modules(self):
    """Create the VAE-branch submodules and attach them to ``self``.

    The following attributes are assigned, in this order:

    * ``vae_down``: norm -> ``self.relu`` -> convolution with
      ``stride=2`` and ``bias=True`` mapping ``v_filters`` channels to
      ``self.smallest_filters`` -> norm -> ``self.relu``.
    * ``vae_fc1`` / ``vae_fc2``: two ``nn.Linear`` layers from the
      flattened feature count to ``self.vae_nz``.
      NOTE(review): presumably the two latent-statistics heads; confirm
      against the forward pass.
    * ``vae_fc3``: ``nn.Linear`` from ``self.vae_nz`` back to the
      flattened feature count.
    * ``vae_fc_up_sample``: ``kernel_size=1`` convolution from
      ``self.smallest_filters`` to ``v_filters`` channels -> the layer
      returned by ``get_upsample_layer`` -> norm -> ``self.relu``.

    Here ``v_filters`` is
    ``self.init_filters * 2 ** (len(self.blocks_down) - 1)`` and the
    flattened feature count is
    ``int(self.smallest_filters * np.prod(self.fc_insize))``.
    """
    v_filters = self.init_filters * 2 ** (len(self.blocks_down) - 1)
    small_filters = self.smallest_filters
    flat_features = int(small_filters * np.prod(self.fc_insize))
    dims = self.spatial_dims

    def build_norm(channels):
        # Every normalisation layer in this branch shares one configuration.
        return get_norm_layer(
            dims,
            channels,
            norm_name=self.norm_name,
            num_groups=self.num_groups,
        )

    # Modules are instantiated in the listed order; the same ``self.relu``
    # instance is reused at each activation position.
    down_modules = [
        build_norm(v_filters),
        self.relu,
        get_conv_layer(dims, v_filters, small_filters, stride=2, bias=True),
        build_norm(small_filters),
        self.relu,
    ]
    self.vae_down = nn.Sequential(*down_modules)

    latent_size = self.vae_nz
    self.vae_fc1 = nn.Linear(flat_features, latent_size)
    self.vae_fc2 = nn.Linear(flat_features, latent_size)
    self.vae_fc3 = nn.Linear(latent_size, flat_features)

    up_modules = [
        get_conv_layer(dims, small_filters, v_filters, kernel_size=1),
        get_upsample_layer(dims, v_filters, upsample_mode=self.upsample_mode),
        build_norm(v_filters),
        self.relu,
    ]
    self.vae_fc_up_sample = nn.Sequential(*up_modules)