def get_nonlinearity(non_type):
  """Return a new MinkowskiEngine activation module selected by name.

  Args:
    non_type: Activation name, matched case-sensitively against
      ``'ReLU'`` and ``'ELU'``.

  Returns:
    A fresh ``ME.MinkowskiReLU()`` or ``ME.MinkowskiELU()`` instance.

  Raises:
    ValueError: If ``non_type`` is not one of the supported names.
  """
  if non_type == 'ReLU':
    return ME.MinkowskiReLU()
  if non_type == 'ELU':
    return ME.MinkowskiELU()
  raise ValueError(f'Type {non_type}, not defined')
Ejemplo n.º 2
0
  def network_initialization(self, in_channels, config, D):
    """Create upsampling layers ``conv_up1..4`` and 1x1 heads ``conv_feat1..4``.

    Args:
      in_channels: Per-level channel counts; indices 0..3 are read.
      config: Provides ``upsample_feat_size``, the output width of every
        ``conv_feat*`` head.
      D: Not used here; every layer is built with ``dimension=3``.
    """
    up_kernel_size = 3

    def upsample(c_in, c_out):
      # Stride-2 transposed conv that generates new coordinates -> BN -> ELU.
      return nn.Sequential(
          ME.MinkowskiConvolutionTranspose(
              c_in, c_out, kernel_size=up_kernel_size, stride=2,
              generate_new_coords=True, dimension=3),
          ME.MinkowskiBatchNorm(c_out),
          ME.MinkowskiELU())

    def project(c_in):
      # 1x1 conv to config.upsample_feat_size channels -> BN -> ELU.
      width = config.upsample_feat_size
      return nn.Sequential(
          ME.MinkowskiConvolution(c_in, width, kernel_size=1, dimension=3),
          ME.MinkowskiBatchNorm(width),
          ME.MinkowskiELU())

    # conv_up(k+1) maps level k's width to level (k-1)'s width; conv_up1
    # keeps level 0's width. Assignments stay in the original order so
    # submodule registration and parameter-creation order are unchanged.
    self.conv_up1 = upsample(in_channels[0], in_channels[0])
    self.conv_up2 = upsample(in_channels[1], in_channels[0])
    self.conv_up3 = upsample(in_channels[2], in_channels[1])
    self.conv_up4 = upsample(in_channels[3], in_channels[2])

    self.conv_feat1 = project(in_channels[0])
    self.conv_feat2 = project(in_channels[1])
    self.conv_feat3 = project(in_channels[2])
    self.conv_feat4 = project(in_channels[3])
Ejemplo n.º 3
0
  def network_initialization(self, in_channels, config, D):
    """Create four 1x1 projection heads, ``conv_feat1`` .. ``conv_feat4``.

    Head ``k`` maps ``in_channels[k-1]`` to ``config.upsample_feat_size``
    channels via conv -> BN -> ELU. ``D`` is not used; every layer is
    built with ``dimension=3``.
    """
    width = config.upsample_feat_size
    for level in range(4):
      head = nn.Sequential(
          ME.MinkowskiConvolution(
              in_channels[level], width, kernel_size=1, dimension=3),
          ME.MinkowskiBatchNorm(width),
          ME.MinkowskiELU())
      # setattr goes through nn.Module.__setattr__, registering the head
      # exactly like a literal `self.conv_featN = ...` assignment.
      setattr(self, f'conv_feat{level + 1}', head)
Ejemplo n.º 4
0
 def network_initialization(self, in_channels, config, D):
   """Create the layers of the proposal head.

   Two 1x1 conv / instance-norm stages map ``in_channels`` features to
   ``config.proposal_feat_size`` channels. Biased 1x1 output convs follow:
   class logits (``self.out_channels * 2``), boxes
   (``self.out_channels * 6``) and, only when ``self.is_rotation_bbox`` is
   truthy, rotations (``self.out_channels *
   self.rotation_criterion.NUM_OUTPUT``). ELU and softmax modules are also
   registered. ``D`` is not used; every layer is built with ``dimension=3``.
   """
   feat = config.proposal_feat_size
   # Shared 1x1 trunk.
   self.conv1 = ME.MinkowskiConvolution(
       in_channels, feat, dimension=3, kernel_size=1)
   self.bn1 = ME.MinkowskiInstanceNorm(feat)
   self.conv2 = ME.MinkowskiConvolution(
       feat, feat, dimension=3, kernel_size=1)
   self.bn2 = ME.MinkowskiInstanceNorm(feat)
   # Biased output heads.
   self.final_class_logits = ME.MinkowskiConvolution(
       feat, self.out_channels * 2,
       has_bias=True, dimension=3, kernel_size=1)
   self.final_bbox = ME.MinkowskiConvolution(
       feat, self.out_channels * 6,
       has_bias=True, dimension=3, kernel_size=1)
   self.elu = ME.MinkowskiELU()
   self.softmax = ME.MinkowskiSoftmax()
   if self.is_rotation_bbox:
      self.final_rotation = ME.MinkowskiConvolution(
          feat, self.out_channels * self.rotation_criterion.NUM_OUTPUT,
          has_bias=True, dimension=3, kernel_size=1)
Ejemplo n.º 5
0
    def __init__(self, resolution, in_nchannel=512):
        """Build the sparse generative encoder/decoder.

        Encoder: a stride-1 stem ``enc_block_s1`` on a 1-channel input,
        followed by six stride-2 stages ``enc_block_s1s2`` ..
        ``enc_block_s32s64`` with widths ``ENC_CHANNELS[0..6]``.

        Decoder: six stride-2 generative transposed-convolution stages
        ``dec_block_s64s32`` .. ``dec_block_s2s1`` with widths
        ``DEC_CHANNELS[5..0]``. Each stage is followed by a one-channel 1x1
        head, ``dec_s32_cls`` .. ``dec_s1_cls``. A ``MinkowskiPruning`` layer
        is stored as ``pruning``.

        Args:
            resolution: Stored as ``self.resolution``.
            in_nchannel: Not used by this constructor.
        """
        nn.Module.__init__(self)

        self.resolution = resolution

        # Input sparse tensor must have tensor stride 128.
        # NOTE(review): the stem below is a stride-1 conv on a 1-channel
        # input, so this note may be stale -- confirm against forward().
        enc_ch = self.ENC_CHANNELS
        dec_ch = self.DEC_CHANNELS

        # Local builders. Each creates its layers in the same order and with
        # the same nn.Sequential indices as fully written-out blocks would.
        # This keeps state_dict keys and parameter-creation order stable.

        def _refine(ch):
            # 3x3x3 stride-1 conv -> BN -> ELU at constant width.
            return [
                ME.MinkowskiConvolution(ch, ch, kernel_size=3, dimension=3),
                ME.MinkowskiBatchNorm(ch),
                ME.MinkowskiELU(),
            ]

        def _down_block(in_ch, out_ch):
            # 2x2x2 stride-2 conv -> BN -> ELU, then one refinement conv.
            return nn.Sequential(
                ME.MinkowskiConvolution(in_ch,
                                        out_ch,
                                        kernel_size=2,
                                        stride=2,
                                        dimension=3),
                ME.MinkowskiBatchNorm(out_ch),
                ME.MinkowskiELU(),
                *_refine(out_ch),
            )

        def _up_block(in_ch, out_ch, kernel_size=2):
            # Stride-2 generative transposed conv -> BN -> ELU, then refine.
            return nn.Sequential(
                ME.MinkowskiGenerativeConvolutionTranspose(
                    in_ch,
                    out_ch,
                    kernel_size=kernel_size,
                    stride=2,
                    dimension=3,
                ),
                ME.MinkowskiBatchNorm(out_ch),
                ME.MinkowskiELU(),
                *_refine(out_ch),
            )

        def _cls_head(ch):
            # Biased 1x1 conv producing a single output channel.
            return ME.MinkowskiConvolution(ch,
                                           1,
                                           kernel_size=1,
                                           bias=True,
                                           dimension=3)

        # Encoder
        self.enc_block_s1 = nn.Sequential(
            ME.MinkowskiConvolution(1,
                                    enc_ch[0],
                                    kernel_size=3,
                                    stride=1,
                                    dimension=3),
            ME.MinkowskiBatchNorm(enc_ch[0]),
            ME.MinkowskiELU(),
        )
        self.enc_block_s1s2 = _down_block(enc_ch[0], enc_ch[1])
        self.enc_block_s2s4 = _down_block(enc_ch[1], enc_ch[2])
        self.enc_block_s4s8 = _down_block(enc_ch[2], enc_ch[3])
        self.enc_block_s8s16 = _down_block(enc_ch[3], enc_ch[4])
        self.enc_block_s16s32 = _down_block(enc_ch[4], enc_ch[5])
        self.enc_block_s32s64 = _down_block(enc_ch[5], enc_ch[6])

        # Decoder. The first stage reads the encoder bottleneck (enc_ch[6]).
        # Every later stage is sized for the preceding decoder stage's output
        # width, dec_ch[i + 1]. That includes dec_block_s32s16, which
        # previously used enc_ch[5] here.
        self.dec_block_s64s32 = _up_block(enc_ch[6], dec_ch[5], kernel_size=4)
        self.dec_s32_cls = _cls_head(dec_ch[5])

        self.dec_block_s32s16 = _up_block(dec_ch[5], dec_ch[4])
        self.dec_s16_cls = _cls_head(dec_ch[4])

        self.dec_block_s16s8 = _up_block(dec_ch[4], dec_ch[3])
        self.dec_s8_cls = _cls_head(dec_ch[3])

        self.dec_block_s8s4 = _up_block(dec_ch[3], dec_ch[2])
        self.dec_s4_cls = _cls_head(dec_ch[2])

        self.dec_block_s4s2 = _up_block(dec_ch[2], dec_ch[1])
        self.dec_s2_cls = _cls_head(dec_ch[1])

        self.dec_block_s2s1 = _up_block(dec_ch[1], dec_ch[0])
        self.dec_s1_cls = _cls_head(dec_ch[0])

        # pruning
        self.pruning = ME.MinkowskiPruning()
Ejemplo n.º 6
0
    def __init__(self):
        """Build a six-block generative upsampling decoder.

        ``block1`` holds two stride-2 generative upsampling stages
        (``CHANNELS[0] -> CHANNELS[0] -> CHANNELS[1]``). ``block2`` ..
        ``block6`` hold one stage each (``CHANNELS[k-2] -> CHANNELS[k-1]``).
        Every ``blockK`` has a matching one-channel 1x1 head ``blockK_cls``.
        A ``MinkowskiPruning`` layer is stored as ``pruning``.
        """
        nn.Module.__init__(self)

        # Input sparse tensor must have tensor stride 128.
        ch = self.CHANNELS

        def up_stage(c_in, c_out):
            # Generative transposed conv (k=2, s=2) -> BN -> ELU, followed by
            # a 3x3x3 conv -> BN -> ELU at the new width.
            return [
                ME.MinkowskiGenerativeConvolutionTranspose(
                    c_in, c_out, kernel_size=2, stride=2, dimension=3
                ),
                ME.MinkowskiBatchNorm(c_out),
                ME.MinkowskiELU(),
                ME.MinkowskiConvolution(c_out, c_out, kernel_size=3, dimension=3),
                ME.MinkowskiBatchNorm(c_out),
                ME.MinkowskiELU(),
            ]

        def cls_head(c_in):
            # Biased 1x1 conv down to a single channel.
            return ME.MinkowskiConvolution(
                c_in, 1, kernel_size=1, bias=True, dimension=3
            )

        # Block 1 upsamples twice; its stages are flattened into one Sequential.
        self.block1 = nn.Sequential(*up_stage(ch[0], ch[0]), *up_stage(ch[0], ch[1]))
        self.block1_cls = cls_head(ch[1])

        # Blocks 2..6: one stage each, registered block then head, in order.
        for idx in range(2, 7):
            setattr(self, f"block{idx}", nn.Sequential(*up_stage(ch[idx - 1], ch[idx])))
            setattr(self, f"block{idx}_cls", cls_head(ch[idx]))

        # pruning
        self.pruning = ME.MinkowskiPruning()
Ejemplo n.º 7
0
    def __init__(self):
        """Build a seven-block strided sparse encoder with two linear heads.

        ``block1`` .. ``block7`` each apply a 3x3x3 stride-2 conv -> BN -> ELU
        followed by a 3x3x3 conv -> BN -> ELU. Widths go
        ``1 -> CHANNELS[0] -> ... -> CHANNELS[6]``. After a global pooling layer,
        ``linear_mean`` and ``linear_log_var`` map ``CHANNELS[6]`` to
        ``CHANNELS[6]``. Finally ``self.weight_initialization()`` is called.
        """
        nn.Module.__init__(self)

        # Input sparse tensor must have tensor stride 128.
        # NOTE(review): the seven stride-2 stages below reduce resolution by
        # 2**7 = 128, so this may describe the *output* stride -- confirm.
        ch = self.CHANNELS

        def down_stage(c_in, c_out):
            # Strided conv halves the resolution, then a same-width 3x3x3 conv.
            return nn.Sequential(
                ME.MinkowskiConvolution(c_in, c_out, kernel_size=3, stride=2, dimension=3),
                ME.MinkowskiBatchNorm(c_out),
                ME.MinkowskiELU(),
                ME.MinkowskiConvolution(c_out, c_out, kernel_size=3, dimension=3),
                ME.MinkowskiBatchNorm(c_out),
                ME.MinkowskiELU(),
            )

        # Register block1..block7 in order; block1 consumes a 1-channel input.
        prev_width = 1
        for idx in range(7):
            width = ch[idx]
            setattr(self, f"block{idx + 1}", down_stage(prev_width, width))
            prev_width = width

        self.global_pool = ME.MinkowskiGlobalPooling()

        self.linear_mean = ME.MinkowskiLinear(ch[6], ch[6], bias=True)
        self.linear_log_var = ME.MinkowskiLinear(ch[6], ch[6], bias=True)
        self.weight_initialization()
Ejemplo n.º 8
0
    def __init__(self, config=None, **kwargs):
        """Build a Dumoulin-style sparse (MinkowskiEngine) image encoder.

        The network is split into the following parts:
          * ``lf``: conv stages ``0 .. feature_layer`` (local feature).
          * ``gf``: conv stages ``feature_layer + 1 .. n_conv_layers - 1``
            (global feature).
          * ``ef``: a 1x1 head ending in ``2 * n_latents`` channels when
            ``encoder_conditional_type == "gaussian"``, or ``n_latents``
            channels when it is ``"deterministic"``. No ``ef`` is created
            for any other value.
          * ``global_pool``.
          * ``af``: only if ``use_attention``, a 1x1 conv to
            ``4 * n_latents`` channels.

        Each conv stage is conv(k=4, s=1) -> BN -> ELU -> conv(k=4, s=2) ->
        BN -> ELU, and its output width is twice its hidden width.

        Args:
            config: Forwarded to ``MEEncoder.__init__``. ``None`` (the
                default) is replaced by a fresh empty dict.
            **kwargs: Forwarded to ``MEEncoder.__init__``.

        Raises:
            AssertionError: If the input is not a square power-of-two size
                >= 16, or if ``n_conv_layers != log2(input_size)``.
        """
        # A fresh dict per call avoids one mutable default shared by instances.
        if config is None:
            config = {}
        MEEncoder.__init__(self, config=config, **kwargs)

        # need square and power of 2 image size input
        power = math.log(self.config.input_size[0], 2)
        assert (power % 1 == 0.0) and (
            power > 3
        ), "Dumoulin Encoder needs a power of 2 as image input size (>=16)"
        # need square image input
        input_size = self.config.input_size
        assert all(size == input_size[0] for size in input_size[1:]
                   ), "Dumoulin Encoder needs a square image input size"

        assert self.config.n_conv_layers == power, "The number of convolutional layers in DumoulinEncoder must be log2(input_size) "

        # network architecture
        if self.config.hidden_channels is None:
            self.config.hidden_channels = 8

        hidden_channels = self.config.hidden_channels
        # Per-convolution settings: conv stage i uses entry 2*i for its
        # stride-1 conv and entry 2*i + 1 for its stride-2 conv.
        kernels_size = [4, 4] * self.config.n_conv_layers
        strides = [1, 2] * self.config.n_conv_layers
        # NOTE(review): `pads` only feeds conv_output_sizes below. The sparse
        # convolutions take no padding argument, so confirm that
        # local_feature_shape matches the real sparse output extent.
        pads = [0, 1] * self.config.n_conv_layers
        dils = [1, 1] * self.config.n_conv_layers

        # feature map size
        feature_map_sizes = conv_output_sizes(self.config.input_size,
                                              2 * self.config.n_conv_layers,
                                              kernels_size, strides, pads,
                                              dils)

        def conv_stage(in_channels, channels, layer_id):
            # conv(in -> channels) -> BN -> ELU -> conv(channels -> 2*channels)
            # -> BN -> ELU, using the settings for convs 2*layer_id and
            # 2*layer_id + 1.
            first, second = 2 * layer_id, 2 * layer_id + 1
            return nn.Sequential(
                ME.MinkowskiConvolution(
                    in_channels,
                    channels,
                    kernel_size=kernels_size[first],
                    stride=strides[first],
                    dilation=dils[first],
                    dimension=self.spatial_dims,
                    bias=True),
                ME.MinkowskiBatchNorm(channels),
                ME.MinkowskiELU(inplace=True),
                ME.MinkowskiConvolution(
                    channels,
                    2 * channels,
                    kernel_size=kernels_size[second],
                    stride=strides[second],
                    dilation=dils[second],
                    dimension=self.spatial_dims,
                    bias=True),
                ME.MinkowskiBatchNorm(2 * channels),
                ME.MinkowskiELU(inplace=True),
            )

        def pointwise_conv(in_channels, out_channels):
            # Biased 1x1 convolution (stride 1, dilation 1).
            return ME.MinkowskiConvolution(in_channels,
                                           out_channels,
                                           kernel_size=1,
                                           stride=1,
                                           dilation=1,
                                           dimension=self.spatial_dims,
                                           bias=True)

        # local feature
        ## convolutional layers
        self.local_feature_shape = (
            int(hidden_channels * math.pow(2, self.config.feature_layer + 1)),
            feature_map_sizes[2 * self.config.feature_layer + 1][0],
            feature_map_sizes[2 * self.config.feature_layer + 1][1])
        self.lf = nn.Sequential()

        for conv_layer_id in range(self.config.feature_layer + 1):
            # Stage 0 reads the raw image channels; later stages read the
            # previous stage's output width.
            in_channels = (self.config.n_channels
                           if conv_layer_id == 0 else hidden_channels)
            self.lf.add_module(
                "conv_{}".format(conv_layer_id),
                conv_stage(in_channels, hidden_channels, conv_layer_id))
            hidden_channels *= 2
        self.lf.out_connection_type = ("conv", hidden_channels)

        # global feature
        self.gf = nn.Sequential()
        ## convolutional layers
        for conv_layer_id in range(self.config.feature_layer + 1,
                                   self.config.n_conv_layers):
            self.gf.add_module(
                "conv_{}".format(conv_layer_id),
                conv_stage(hidden_channels, hidden_channels, conv_layer_id))
            hidden_channels *= 2
        self.gf.out_connection_type = ("conv", hidden_channels)

        # encoding feature
        conditional_type = self.config.encoder_conditional_type
        if conditional_type == "gaussian":
            ef_out_channels = 2 * self.config.n_latents
        elif conditional_type == "deterministic":
            ef_out_channels = self.config.n_latents
        else:
            # NOTE(review): other values create no `ef` head -- confirm that
            # callers never rely on `ef` in that case.
            ef_out_channels = None
        if ef_out_channels is not None:
            self.add_module(
                "ef",
                nn.Sequential(
                    pointwise_conv(hidden_channels, hidden_channels),
                    ME.MinkowskiBatchNorm(hidden_channels),
                    ME.MinkowskiELU(inplace=True),
                    pointwise_conv(hidden_channels, ef_out_channels),
                ))

        # global pool
        self.global_pool = ME.MinkowskiGlobalPooling()

        # attention feature
        if self.config.use_attention:
            self.add_module(
                "af",
                pointwise_conv(hidden_channels, 4 * self.config.n_latents))
Ejemplo n.º 9
0
    def __init__(self, config=None, **kwargs):
        """Build a sparse (MinkowskiEngine) Dumoulin-style decoder.

        Starting from an ``n_latents``-channel input, the decoder is built as:

        * ``efi`` ("encoder feature inverse"): two 1x1 convolution + BatchNorm
          + ELU blocks widening to ``hidden_channels * 2**n_conv_layers``.
        * ``gfi`` ("global feature inverse"): stages ``n_conv_layers - 1`` down
          to ``feature_layer + 1``.
        * ``lfi`` ("local feature inverse"): stages ``feature_layer`` down to
          ``0``; the last stage outputs ``n_channels``.

        Each stage is a stride-2 generative transpose convolution followed by
        a stride-1 convolution, and halves the channel count. A 1-channel
        ``cls_*`` convolution is registered after every stage. Those heads
        presumably feed ``self.pruning`` in ``forward`` (TODO confirm).

        Args:
            config: configuration forwarded to ``MEDecoder.__init__``. ``None``
                (the default) is replaced by a fresh empty dict.
            **kwargs: forwarded unchanged to ``MEDecoder.__init__``.

        Raises:
            AssertionError: if ``input_size`` is not square, is not a power of
                two >= 16, or if ``n_conv_layers != log2(input_size)``.
        """
        # Fresh dict per call instead of a shared mutable default argument.
        if config is None:
            config = {}
        MEDecoder.__init__(self, config=config, **kwargs)

        # Need a square, power-of-2 image size input. math.log2 is exact for
        # powers of two, unlike math.log(x, 2), which divides two rounded logs.
        power = math.log2(self.config.input_size[0])
        assert (power % 1 == 0.0) and (
            power > 3
        ), "Dumoulin Decoder needs a power of 2 as image input size (>=16)"
        assert self.config.input_size[0] == self.config.input_size[
            1], "Dumoulin Decoder needs a square image input size"

        assert self.config.n_conv_layers == power, "The number of convolutional layers in DumoulinDecoder must be log2(input_size)"

        # Per-layer hyper-parameters. Odd indices (kernel 2, stride 2) are used
        # by the generative transpose convolutions; even indices (kernel 3,
        # stride 1) by the convolution that follows each of them.
        kernels_size = [3, 2] * self.config.n_conv_layers
        strides = [1, 2] * self.config.n_conv_layers

        if self.config.hidden_channels is None:
            self.config.hidden_channels = 8

        def _upsample_stage(in_channels, out_channels, layer_id):
            # One stage for `layer_id`: generative transpose conv
            # (in -> out channels) then conv (out -> out), each followed by
            # BatchNorm + ELU. Dilation is left at the library default.
            return nn.Sequential(
                ME.MinkowskiGenerativeConvolutionTranspose(
                    in_channels,
                    out_channels,
                    kernel_size=kernels_size[2 * layer_id + 1],
                    stride=strides[2 * layer_id + 1],
                    dimension=self.spatial_dims,
                    bias=False),
                ME.MinkowskiBatchNorm(out_channels),
                ME.MinkowskiELU(inplace=True),
                ME.MinkowskiConvolution(
                    out_channels,
                    out_channels,
                    kernel_size=kernels_size[2 * layer_id],
                    stride=strides[2 * layer_id],
                    dimension=self.spatial_dims,
                    bias=False),
                ME.MinkowskiBatchNorm(out_channels),
                ME.MinkowskiELU(inplace=True),
            )

        def _cls_head(in_channels, bias=True):
            # 1x1 convolution down to a single output channel.
            return ME.MinkowskiConvolution(in_channels,
                                           1,
                                           kernel_size=1,
                                           bias=bias,
                                           dimension=self.spatial_dims)

        # encoder feature inverse
        hidden_channels = int(self.config.hidden_channels *
                              math.pow(2, self.config.n_conv_layers))
        self.efi = nn.Sequential(
            ME.MinkowskiConvolution(
                self.config.n_latents,
                hidden_channels,
                kernel_size=1,
                stride=1,
                dimension=self.spatial_dims,
                bias=False),
            ME.MinkowskiBatchNorm(hidden_channels),
            ME.MinkowskiELU(inplace=True),
            ME.MinkowskiConvolution(
                hidden_channels,
                hidden_channels,
                kernel_size=1,
                stride=1,
                dimension=self.spatial_dims,
                bias=False),
            ME.MinkowskiBatchNorm(hidden_channels),
            ME.MinkowskiELU(inplace=True),
        )
        self.efi.out_connection_type = ("conv", hidden_channels)
        self.efi_cls = _cls_head(hidden_channels)

        # global feature inverse: stages n_conv_layers-1 .. feature_layer+1.
        # NOTE: the 1-channel cls_* heads are interleaved with the stages, so
        # gfi/lfi cannot be run as a plain Sequential chain; presumably
        # forward iterates over their children (TODO confirm).
        self.gfi = nn.Sequential()
        for conv_layer_id in range(self.config.n_conv_layers - 1,
                                   self.config.feature_layer, -1):
            self.gfi.add_module(
                "conv_{}_i".format(conv_layer_id),
                _upsample_stage(hidden_channels, hidden_channels // 2,
                                conv_layer_id))
            hidden_channels = hidden_channels // 2

            self.gfi.add_module("cls_{}_i".format(conv_layer_id),
                                _cls_head(hidden_channels))

        self.gfi.out_connection_type = ("conv", hidden_channels)

        # local feature inverse: stages feature_layer .. 1, then stage 0.
        self.lfi = nn.Sequential()
        for conv_layer_id in range(self.config.feature_layer, 0, -1):
            self.lfi.add_module(
                "conv_{}_i".format(conv_layer_id),
                _upsample_stage(hidden_channels, hidden_channels // 2,
                                conv_layer_id))
            hidden_channels = hidden_channels // 2

            self.lfi.add_module("cls_{}_i".format(conv_layer_id),
                                _cls_head(hidden_channels))

        # Final stage: outputs n_channels, with no BatchNorm/ELU after the
        # last convolution.
        self.lfi.add_module(
            "conv_0_i",
            nn.Sequential(
                ME.MinkowskiGenerativeConvolutionTranspose(
                    hidden_channels,
                    hidden_channels // 2,
                    kernel_size=kernels_size[1],
                    stride=strides[1],
                    dimension=self.spatial_dims,
                    bias=False),
                ME.MinkowskiBatchNorm(hidden_channels // 2),
                ME.MinkowskiELU(inplace=True),
                ME.MinkowskiConvolution(
                    hidden_channels // 2,
                    self.config.n_channels,
                    kernel_size=kernels_size[0],
                    stride=strides[0],
                    dimension=self.spatial_dims,
                    bias=False),
            ))
        # NOTE(review): unlike every other cls head this one uses bias=False;
        # kept as-is, confirm whether that is intentional.
        self.lfi.add_module("cls_0_i",
                            _cls_head(self.config.n_channels, bias=False))

        self.lfi.out_connection_type = ("conv", self.config.n_channels)

        # pruning
        self.pruning = ME.MinkowskiPruning()