Example #1
    def __init__(self,
                 in_channels,
                 n_filters,
                 batch_norm: bool = True,
                 group_norm=0):
        super().__init__()

        self.in_dim = in_channels
        self.num_filters = n_filters
        activation = nn.ReLU
        self.batch_norm = batch_norm

        # Down sampling
        self.down_1 = conv_block_2_3d(self.in_dim, self.num_filters,
                                      activation, self.batch_norm, group_norm)
        self.pool_1 = max_pooling_3d()
        self.down_2 = conv_block_2_3d(self.num_filters, self.num_filters,
                                      activation, self.batch_norm, group_norm)
        self.pool_2 = max_pooling_3d()
        self.down_3 = conv_block_2_3d(self.num_filters, self.num_filters * 2,
                                      activation, self.batch_norm, group_norm)
        self.pool_3 = max_pooling_3d()
        self.down_4 = conv_block_2_3d(self.num_filters * 2,
                                      self.num_filters * 4, activation,
                                      self.batch_norm, group_norm)
        self.pool_4 = max_pooling_3d()
        # self.down_5 = conv_block_2_3d(self.num_filters, self.num_filters, activation, self.batch_norm, group_norm)
        # self.pool_5 = max_pooling_3d()
        # Bridge at the bottom of the U
        self.bridge = conv_block_2_3d(self.num_filters * 4,
                                      self.num_filters * 8, activation,
                                      self.batch_norm, group_norm)
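
All three examples call two helpers, conv_block_2_3d and max_pooling_3d, whose definitions are not shown (the snippets also assume the usual import torch / import torch.nn as nn, plus from torch.distributions.normal import Normal in Example #3). Below is a minimal sketch consistent with the call sites above; the original repo's kernel sizes and layer ordering may differ:

import torch.nn as nn


def max_pooling_3d():
    # halves each spatial dimension
    return nn.MaxPool3d(kernel_size=2, stride=2)


def conv_block_2_3d(in_dim, out_dim, activation,
                    batch_norm=False, group_norm=0):
    # two 3x3x3 convolutions, each followed by an optional norm and the
    # activation; group_norm > 0 selects GroupNorm with that many groups
    layers = []
    for c_in, c_out in ((in_dim, out_dim), (out_dim, out_dim)):
        layers.append(nn.Conv3d(c_in, c_out, kernel_size=3, padding=1))
        if batch_norm:
            layers.append(nn.BatchNorm3d(c_out))
        elif group_norm > 0:
            layers.append(nn.GroupNorm(group_norm, c_out))
        layers.append(activation())
    return nn.Sequential(*layers)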
Example #2
    def __init__(self, group_norm):
        super().__init__()
        activation = nn.ReLU
        self.conv = nn.Sequential(
            # nn.Conv3d(256, 128, kernel_size=1),
            # nn.GroupNorm(group_norm, 128),
            # activation(),
            nn.Conv3d(128, 64, kernel_size=1),
            nn.GroupNorm(group_norm, 64),
            activation(),
            max_pooling_3d(),
            nn.Conv3d(64, 32, kernel_size=1),
            nn.GroupNorm(group_norm, 32),
            activation(),
            nn.Conv3d(32, 8, kernel_size=1),
            nn.GroupNorm(group_norm, 8),
            activation(),
            # nn.Conv3d(16, 8, kernel_size=1),
            # nn.BatchNorm3d(8),
            # activation(),
        )
        self.regress = nn.Sequential(
            nn.Linear(400, 200),
            # nn.BatchNorm1d(200),
            nn.GroupNorm(group_norm, 200),
            activation(),
            nn.Linear(200, 96),
            # nn.BatchNorm1d(100),
            nn.GroupNorm(group_norm, 96),
            activation(),
            nn.Linear(96, 12),
            nn.Tanh()  # affine_grid expects normalized coordinates in [-1, 1]
        )
        # Start at the identity transform: zero the final layer's weights and
        # set its bias to the flattened 3x4 identity affine matrix.
        bias = torch.tensor([1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0],
                            dtype=torch.float)
        self.regress[-2].weight.data.zero_()
        self.regress[-2].bias.data.copy_(bias)
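
The 12 regressor outputs form the flattened 3x4 theta matrix that torch.nn.functional.affine_grid expects for 5-D (volumetric) input, which is why the bias is initialised to the identity transform. A hedged usage sketch; the function name apply_affine and the shapes are illustrative, not from the source:

import torch.nn.functional as F


def apply_affine(theta_flat, moving):
    # theta_flat: (N, 12) output of self.regress; moving: (N, C, D, H, W)
    theta = theta_flat.view(-1, 3, 4)  # one 3x4 affine matrix per batch item
    grid = F.affine_grid(theta, moving.shape, align_corners=False)
    return F.grid_sample(moving, grid, align_corners=False)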
Example #3
    def __init__(self, num_filters, group_norm: int):
        super().__init__()
        activation = nn.ReLU
        self.down1 = max_pooling_3d()
        self.down_conv1 = conv_block_2_3d(6,
                                          num_filters,
                                          activation,
                                          group_norm=group_norm)
        self.down2 = max_pooling_3d()
        self.down_conv2 = conv_block_2_3d(num_filters,
                                          num_filters,
                                          activation,
                                          group_norm=group_norm)
        self.down3 = max_pooling_3d()
        self.down_conv3 = conv_block_2_3d(num_filters,
                                          num_filters * 2,
                                          activation,
                                          group_norm=group_norm)

        self.affine_down = max_pooling_3d()
        self.affine_conv = conv_block_2_3d(num_filters * 2,
                                           num_filters * 2,
                                           activation,
                                           group_norm=group_norm)
        self.affine_down2 = max_pooling_3d()
        self.affine_conv2 = conv_block_2_3d(num_filters * 2,
                                            num_filters * 2,
                                            activation,
                                            group_norm=group_norm)

        self.affine_regressor = nn.Sequential(
            nn.Linear(2048, 400),
            # nn.BatchNorm1d(200),
            nn.GroupNorm(group_norm, 400),
            activation(),
            nn.Linear(400, 12),
            nn.Tanh()  # affine_grid expects normalized coordinates in [-1, 1]
        )
        # Start at the identity transform: zero the final layer's weights and
        # set its bias to the flattened 3x4 identity affine matrix.
        bias = torch.tensor([1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0],
                            dtype=torch.float)
        self.affine_regressor[-2].weight.data.zero_()
        self.affine_regressor[-2].bias.data.copy_(bias)

        # self.up0 = nn.Upsample(scale_factor=2, mode='nearest')
        self.up1 = nn.Upsample(scale_factor=4, mode='nearest')
        self.up_conv1 = conv_block_2_3d(num_filters * 4,
                                        num_filters * 2,
                                        activation,
                                        group_norm=group_norm)
        self.up2 = nn.Upsample(scale_factor=2, mode="nearest")
        self.up_conv2 = conv_block_2_3d(num_filters * 3,
                                        num_filters,
                                        activation,
                                        group_norm=group_norm)
        self.up3 = nn.Upsample(scale_factor=2, mode="nearest")
        self.up_conv3 = conv_block_2_3d(num_filters * 2,
                                        num_filters,
                                        activation,
                                        group_norm=group_norm)
        self.flow = nn.Conv3d(num_filters, 3, kernel_size=3, padding=1)
        # initialise the flow head with near-zero weights and zero bias so the
        # predicted displacement field starts close to the identity warp
        self.flow.weight = nn.Parameter(
            Normal(0, 1e-5).sample(self.flow.weight.shape))
        self.flow.bias = nn.Parameter(torch.zeros(self.flow.bias.shape))
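
The flow head predicts a dense 3-channel displacement field, and its near-zero initialisation means the network starts out producing an identity warp. The examples don't show how the field is applied; one common pattern (a sketch under that assumption, not the repo's own warping code) adds the displacements to an identity voxel grid, normalises to [-1, 1], and resamples with grid_sample:

import torch
import torch.nn.functional as F


def warp(moving, flow):
    # moving: (N, C, D, H, W) volume; flow: (N, 3, D, H, W) displacement
    # field in voxel units, channel order (z, y, x)
    _, _, d, h, w = flow.shape
    zs, ys, xs = torch.meshgrid(torch.arange(d), torch.arange(h),
                                torch.arange(w), indexing="ij")
    grid = torch.stack((zs, ys, xs)).float().to(flow.device)  # (3, D, H, W)
    coords = grid.unsqueeze(0) + flow                         # displaced voxels
    # normalise each axis to [-1, 1] (align_corners=True convention)
    sizes = torch.tensor([d, h, w], dtype=torch.float32,
                         device=flow.device).view(1, 3, 1, 1, 1)
    coords = 2 * coords / (sizes - 1) - 1
    # grid_sample wants the grid as (N, D, H, W, 3) in (x, y, z) order
    coords = coords.permute(0, 2, 3, 4, 1)[..., [2, 1, 0]]
    return F.grid_sample(moving, coords, align_corners=True)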