Ejemplo n.º 1
0
    def __init__(self, resolution, in_nchannel=512):
        """Build the generative upsampling decoder.

        The network expects its input sparse tensor at tensor stride 128.
        Seven stride-2 generative transposed convolutions bring it back to
        stride 1: two in ``block1`` and one in each of ``block2``..``block6``.
        Every block is followed by a kernel-size-1 convolution
        (``block<k>_cls``) with a single output channel. ``self.pruning``
        holds a ``MinkowskiPruning`` module.

        Args:
            resolution: stored unchanged on ``self.resolution``.
            in_nchannel: channel count of the incoming sparse tensor.
        """
        nn.Module.__init__(self)

        self.resolution = resolution

        # Class-level channel schedule; indices 0..6 are used below.
        widths = self.CHANNELS

        def stage(c_in, c_out):
            # One upsampling stage, returned as a list of modules:
            # - generative transposed conv (kernel 2, stride 2)
            # - kernel-3 conv
            # Each conv is followed by batch norm and ELU.
            return [
                ME.MinkowskiGenerativeConvolutionTranspose(
                    c_in, c_out, kernel_size=2, stride=2, dimension=3
                ),
                ME.MinkowskiBatchNorm(c_out),
                ME.MinkowskiELU(),
                ME.MinkowskiConvolution(c_out, c_out, kernel_size=3, dimension=3),
                ME.MinkowskiBatchNorm(c_out),
                ME.MinkowskiELU(),
            ]

        def head(c_in):
            # Single-output-channel, kernel-size-1 prediction layer.
            return ME.MinkowskiConvolution(
                c_in, 1, kernel_size=1, bias=True, dimension=3
            )

        # Block 1 stacks two stages: in_nchannel -> widths[0] -> widths[1].
        self.block1 = nn.Sequential(
            *(stage(in_nchannel, widths[0]) + stage(widths[0], widths[1]))
        )
        self.block1_cls = head(widths[1])

        # Blocks 2..6 are single stages widths[k - 1] -> widths[k]. Each block
        # is built and registered before its head, which keeps the module
        # registration (and parameter-initialisation) order of the unrolled form.
        for k in range(2, 7):
            setattr(
                self,
                "block{}".format(k),
                nn.Sequential(*stage(widths[k - 1], widths[k])),
            )
            setattr(self, "block{}_cls".format(k), head(widths[k]))

        # pruning
        self.pruning = ME.MinkowskiPruning()
Ejemplo n.º 2
0
    def __init__(self, config=None, **kwargs):
        """Build the sparse (MinkowskiEngine) Dumoulin-style decoder.

        The decoder is assembled from three ``nn.Sequential`` containers:

        * ``efi`` ("encoder feature inverse"): two kernel-size-1 convolutions
          mapping ``n_latents`` channels to
          ``hidden_channels * 2 ** n_conv_layers`` channels, each followed by
          batch norm and ELU. A separate single-channel ``efi_cls`` conv is
          also built.
        * ``gfi`` ("global feature inverse"): for conv layers
          ``n_conv_layers - 1`` down to ``feature_layer + 1``, an upsampling
          block ``conv_<id>_i`` that halves the channel count. Each is
          followed by a single-channel ``cls_<id>_i`` conv.
        * ``lfi`` ("local feature inverse"): the same pattern for conv layers
          ``feature_layer`` down to ``1``. After them comes a final
          ``conv_0_i`` block producing ``n_channels`` channels, followed by
          its ``cls_0_i`` conv.

        Args:
            config: configuration forwarded to ``MEDecoder.__init__``. ``None``
                (the default) is replaced by a fresh empty dict, so no dict
                instance is shared between calls.
            **kwargs: forwarded unchanged to ``MEDecoder.__init__``.

        Raises:
            AssertionError: if the input size is not square, if its side is
                not a power of two >= 16, or if ``n_conv_layers`` differs from
                ``log2(input_size)``.
        """
        # Fresh dict per call instead of a shared mutable default; passing
        # ``{}`` explicitly behaves exactly as before.
        if config is None:
            config = {}
        MEDecoder.__init__(self, config=config, **kwargs)

        # Need a square, power-of-2 image size input.
        # math.log2 is exact for powers of two. math.log(x, 2) divides two
        # rounded logarithms and can land an ulp away from an integer, which
        # would spuriously fail the ``power % 1 == 0.0`` check.
        power = math.log2(self.config.input_size[0])
        assert (power % 1 == 0.0) and (
            power > 3
        ), "Dumoulin Decoder needs a power of 2 as image input size (>=16)"
        assert self.config.input_size[0] == self.config.input_size[
            1], "Dumoulin Decoder needs a square image input size"

        assert self.config.n_conv_layers == power, "The number of convolutional layers in DumoulinDecoder must be log2(input_size) "

        # Network architecture, per conv layer i:
        # - index 2 * i is the kernel-3 / stride-1 refinement conv;
        # - index 2 * i + 1 is the kernel-2 / stride-2 generative transposed conv.
        # No dilation argument is passed, so the library default is used.
        kernels_size = [3, 2] * self.config.n_conv_layers
        strides = [1, 2] * self.config.n_conv_layers

        if self.config.hidden_channels is None:
            self.config.hidden_channels = 8

        def cls_head(in_channels, bias=True):
            # Single-output-channel, kernel-size-1 convolution.
            return ME.MinkowskiConvolution(in_channels,
                                           1,
                                           kernel_size=1,
                                           bias=bias,
                                           dimension=self.spatial_dims)

        def upsampling_block(conv_layer_id, in_channels, out_channels):
            # Intermediate layer block, in this order:
            # generative transposed conv -> BN -> ELU -> conv -> BN -> ELU.
            return nn.Sequential(
                ME.MinkowskiGenerativeConvolutionTranspose(
                    in_channels,
                    out_channels,
                    kernel_size=kernels_size[2 * conv_layer_id + 1],
                    stride=strides[2 * conv_layer_id + 1],
                    dimension=self.spatial_dims,
                    bias=False),
                ME.MinkowskiBatchNorm(out_channels),
                ME.MinkowskiELU(inplace=True),
                ME.MinkowskiConvolution(
                    out_channels,
                    out_channels,
                    kernel_size=kernels_size[2 * conv_layer_id],
                    stride=strides[2 * conv_layer_id],
                    dimension=self.spatial_dims,
                    bias=False),
                ME.MinkowskiBatchNorm(out_channels),
                ME.MinkowskiELU(inplace=True),
            )

        def add_upsampling_stages(container, layer_ids, channels):
            # Append "conv_<id>_i" then "cls_<id>_i" for each id, halving the
            # channel count per layer; return the final channel count.
            # Modules are created in the same order as the original unrolled
            # loops, so parameter-initialisation order is unchanged.
            for conv_layer_id in layer_ids:
                container.add_module(
                    "conv_{}_i".format(conv_layer_id),
                    upsampling_block(conv_layer_id, channels, channels // 2))
                channels = channels // 2
                container.add_module("cls_{}_i".format(conv_layer_id),
                                     cls_head(channels))
            return channels

        # encoder feature inverse
        hidden_channels = int(self.config.hidden_channels *
                              2 ** self.config.n_conv_layers)
        self.efi = nn.Sequential(
            ME.MinkowskiConvolution(
                self.config.n_latents,
                hidden_channels,
                kernel_size=1,
                stride=1,
                dimension=self.spatial_dims,
                bias=False),
            ME.MinkowskiBatchNorm(hidden_channels),
            ME.MinkowskiELU(inplace=True),
            ME.MinkowskiConvolution(
                hidden_channels,
                hidden_channels,
                kernel_size=1,
                stride=1,
                dimension=self.spatial_dims,
                bias=False),
            ME.MinkowskiBatchNorm(hidden_channels),
            ME.MinkowskiELU(inplace=True),
        )
        self.efi.out_connection_type = ("conv", hidden_channels)
        self.efi_cls = cls_head(hidden_channels)

        # global feature inverse: conv layers n_conv_layers - 1 .. feature_layer + 1
        self.gfi = nn.Sequential()
        hidden_channels = add_upsampling_stages(
            self.gfi,
            range(self.config.n_conv_layers - 1, self.config.feature_layer, -1),
            hidden_channels)
        self.gfi.out_connection_type = ("conv", hidden_channels)

        # local feature inverse: conv layers feature_layer .. 1, then layer 0
        self.lfi = nn.Sequential()
        hidden_channels = add_upsampling_stages(
            self.lfi, range(self.config.feature_layer, 0, -1), hidden_channels)

        # Layer 0 outputs n_channels directly, with no BN/ELU after its last conv.
        self.lfi.add_module(
            "conv_0_i",
            nn.Sequential(
                ME.MinkowskiGenerativeConvolutionTranspose(
                    hidden_channels,
                    hidden_channels // 2,
                    kernel_size=kernels_size[1],
                    stride=strides[1],
                    dimension=self.spatial_dims,
                    bias=False),
                ME.MinkowskiBatchNorm(hidden_channels // 2),
                ME.MinkowskiELU(inplace=True),
                ME.MinkowskiConvolution(
                    hidden_channels // 2,
                    self.config.n_channels,
                    kernel_size=kernels_size[0],
                    stride=strides[0],
                    dimension=self.spatial_dims,
                    bias=False),
            ))
        # NOTE(review): unlike every other cls head, this one has bias=False.
        # Kept as-is for checkpoint compatibility; confirm it is intentional.
        self.lfi.add_module("cls_0_i",
                            cls_head(self.config.n_channels, bias=False))

        self.lfi.out_connection_type = ("conv", self.config.n_channels)

        # pruning
        self.pruning = ME.MinkowskiPruning()