def _prepare_vae_modules(self):
    """Build and register the layers of the variational auto-encoder (VAE) branch.

    Registers five attributes on ``self``:

    * ``vae_down``: norm -> ``self.relu`` -> stride-2 conv (``v_filters`` to
      ``self.smallest_filters`` channels, with bias) -> norm -> ``self.relu``.
    * ``vae_fc1`` / ``vae_fc2``: linear maps from ``total_elements`` flattened
      features to ``self.vae_nz``.
    * ``vae_fc3``: linear map from ``self.vae_nz`` back to ``total_elements``.
    * ``vae_fc_up_sample``: 1x1 conv (``self.smallest_filters`` to ``v_filters``
      channels) -> upsample -> norm -> ``self.relu``.

    Here ``v_filters = self.init_filters * 2 ** (len(self.blocks_down) - 1)`` and
    ``total_elements = int(self.smallest_filters * prod(self.fc_insize))``.
    """
    # Channel width of the feature map fed into the VAE branch. Presumably the
    # deepest encoder stage (init_filters doubled once per extra down block);
    # confirm against the encoder construction.
    depth_scale = 2**(len(self.blocks_down) - 1)
    v_filters = self.init_filters * depth_scale
    # Size of the reduced feature map once flattened for the linear layers.
    flat_size = int(self.smallest_filters * np.prod(self.fc_insize))

    def make_norm(channels):
        # Every norm layer in this branch shares the same configuration.
        return get_norm_layer(self.spatial_dims, channels, norm_name=self.norm_name, num_groups=self.num_groups)

    # NOTE: layers are instantiated in a fixed order on purpose; reordering would
    # change RNG-driven parameter initialization and state_dict key order.
    down_layers = [
        make_norm(v_filters),
        self.relu,
        get_conv_layer(self.spatial_dims, v_filters, self.smallest_filters, stride=2, bias=True),
        make_norm(self.smallest_filters),
        self.relu,
    ]
    self.vae_down = nn.Sequential(*down_layers)

    # vae_fc1 / vae_fc2 presumably yield the latent distribution parameters used
    # for sampling; vae_fc3 maps a latent vector back. Confirm in forward().
    self.vae_fc1 = nn.Linear(flat_size, self.vae_nz)
    self.vae_fc2 = nn.Linear(flat_size, self.vae_nz)
    self.vae_fc3 = nn.Linear(self.vae_nz, flat_size)

    up_layers = [
        get_conv_layer(self.spatial_dims, self.smallest_filters, v_filters, kernel_size=1),
        get_upsample_layer(self.spatial_dims, v_filters, upsample_mode=self.upsample_mode),
        make_norm(v_filters),
        self.relu,
    ]
    self.vae_fc_up_sample = nn.Sequential(*up_layers)
def _make_final_conv(self, out_channels: int):
    """Build the output head: norm -> ``self.relu`` -> biased 1x1 convolution.

    Args:
        out_channels: number of channels produced by the final convolution.

    Returns:
        An ``nn.Sequential`` whose norm layer is configured for
        ``self.init_filters`` channels, followed by ``self.relu`` and a 1x1
        convolution (with bias) from ``self.init_filters`` to ``out_channels``.
    """
    width = self.init_filters
    norm = get_norm_layer(self.spatial_dims, width, norm_name=self.norm_name, num_groups=self.num_groups)
    projection = get_conv_layer(self.spatial_dims, width, out_channels=out_channels, kernel_size=1, bias=True)
    return nn.Sequential(norm, self.relu, projection)