示例#1
0
文件: segresnet.py 项目: lsho76/MONAI
    def _prepare_vae_modules(self):
        """Create the layers of the VAE branch and attach them to ``self``.

        Attributes are created in this order:

        * ``vae_down``: norm -> ``self.relu`` -> stride-2 conv
          (``v_filters`` -> ``smallest_filters`` channels) -> norm -> ``self.relu``.
        * ``vae_fc1`` / ``vae_fc2``: linear maps ``total_elements -> vae_nz``
          (presumably the two heads of the latent distribution; confirm in ``forward``).
        * ``vae_fc3``: linear map ``vae_nz -> total_elements``.
        * ``vae_fc_up_sample``: 1x1 conv (``smallest_filters`` -> ``v_filters``)
          -> upsample -> norm -> ``self.relu``.

        Note that the same ``self.relu`` module instance is reused at every
        activation position (assumed to hold no parameters/state -- TODO confirm).
        """
        # Channel width at the VAE entry/exit: init_filters scaled by
        # 2 ** (number of down stages - 1); presumably the deepest encoder width.
        depth = len(self.blocks_down) - 1
        v_filters = self.init_filters * (2 ** depth)
        # Flattened feature count consumed/produced by the linear layers.
        # Assumes fc_insize is the spatial size of the vae_down output -- TODO confirm.
        flat_size = int(self.smallest_filters * np.prod(self.fc_insize))

        # NOTE: keep the order in which layers are built and assigned below;
        # it determines submodule registration order (state_dict keys) and the
        # order in which parameters are drawn from the RNG at initialisation.
        pre_norm = get_norm_layer(
            self.spatial_dims, v_filters, norm_name=self.norm_name, num_groups=self.num_groups
        )
        down_conv = get_conv_layer(
            self.spatial_dims, v_filters, self.smallest_filters, stride=2, bias=True
        )
        post_norm = get_norm_layer(
            self.spatial_dims, self.smallest_filters, norm_name=self.norm_name, num_groups=self.num_groups
        )
        self.vae_down = nn.Sequential(pre_norm, self.relu, down_conv, post_norm, self.relu)

        self.vae_fc1 = nn.Linear(flat_size, self.vae_nz)
        self.vae_fc2 = nn.Linear(flat_size, self.vae_nz)
        self.vae_fc3 = nn.Linear(self.vae_nz, flat_size)

        channel_up = get_conv_layer(
            self.spatial_dims, self.smallest_filters, v_filters, kernel_size=1
        )
        upsampler = get_upsample_layer(
            self.spatial_dims, v_filters, upsample_mode=self.upsample_mode
        )
        out_norm = get_norm_layer(
            self.spatial_dims, v_filters, norm_name=self.norm_name, num_groups=self.num_groups
        )
        self.vae_fc_up_sample = nn.Sequential(channel_up, upsampler, out_norm, self.relu)
示例#2
0
文件: segresnet.py 项目: lsho76/MONAI
 def _make_final_conv(self, out_channels: int):
     """Build the output head mapping ``init_filters`` channels to ``out_channels``.

     The head is: norm over ``init_filters`` channels -> ``self.relu`` ->
     1x1 conv (``kernel_size=1``, ``bias=True``) producing ``out_channels``.

     Args:
         out_channels: number of channels produced by the final conv.

     Returns:
         An ``nn.Sequential`` holding the three layers above, in that order.
     """
     in_ch = self.init_filters
     # Norm is built before the conv, matching the original construction order.
     head_norm = get_norm_layer(
         self.spatial_dims, in_ch, norm_name=self.norm_name, num_groups=self.num_groups
     )
     head_conv = get_conv_layer(
         self.spatial_dims, in_ch, out_channels=out_channels, kernel_size=1, bias=True
     )
     return nn.Sequential(head_norm, self.relu, head_conv)