Example #1
    def __init__(self, in_size, out_size, is_batchnorm):
        super(UnetGatingSignal3, self).__init__()
        self.fmap_size = (4, 4, 4)

        if is_batchnorm:
            self.conv1 = nn.Sequential(
                nn.Conv3d(in_size, in_size // 2, (1, 1, 1), (1, 1, 1),
                          (0, 0, 0)),
                nn.InstanceNorm3d(in_size // 2),
                nn.ReLU(inplace=True),
                nn.AdaptiveAvgPool3d(output_size=self.fmap_size),
            )
            self.fc1 = nn.Linear(in_features=(in_size // 2) *
                                 self.fmap_size[0] * self.fmap_size[1] *
                                 self.fmap_size[2],
                                 out_features=out_size,
                                 bias=True)
        else:
            self.conv1 = nn.Sequential(
                nn.Conv3d(in_size, in_size // 2, (1, 1, 1), (1, 1, 1),
                          (0, 0, 0)),
                nn.ReLU(inplace=True),
                nn.AdaptiveAvgPool3d(output_size=self.fmap_size),
            )
            self.fc1 = nn.Linear(in_features=(in_size // 2) *
                                 self.fmap_size[0] * self.fmap_size[1] *
                                 self.fmap_size[2],
                                 out_features=out_size,
                                 bias=True)

        # initialise the blocks
        for m in self.children():
            init_weights(m, init_type='kaiming')
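
Every example here calls an init_weights helper that is not part of these excerpts. A minimal sketch of what such a helper could look like, assuming Kaiming initialisation for conv/linear weights and constant initialisation for affine norm layers (the actual implementation may differ):

import torch.nn as nn

def init_weights(net, init_type='kaiming'):
    # Hypothetical helper: Kaiming-normal for conv/linear weights, (1, 0) for
    # affine norm layers; applied recursively via Module.apply.
    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and ('Conv' in classname or 'Linear' in classname):
            if init_type == 'kaiming':
                nn.init.kaiming_normal_(m.weight.data)
            if m.bias is not None:
                nn.init.constant_(m.bias.data, 0.0)
        elif 'BatchNorm' in classname or 'InstanceNorm' in classname:
            if m.weight is not None:
                nn.init.constant_(m.weight.data, 1.0)
            if m.bias is not None:
                nn.init.constant_(m.bias.data, 0.0)

    net.apply(init_func)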
Example #2
    def __init__(self, in_size, gate_size, inter_size, nonlocal_mode,
                 sub_sample_factor):
        super(MultiAttentionBlock, self).__init__()
        self.gate_block_1 = GridAttentionBlock3D(
            in_channels=in_size,
            gating_channels=gate_size,
            inter_channels=inter_size,
            mode=nonlocal_mode,
            sub_sample_factor=sub_sample_factor)
        self.gate_block_2 = GridAttentionBlock3D(
            in_channels=in_size,
            gating_channels=gate_size,
            inter_channels=inter_size,
            mode=nonlocal_mode,
            sub_sample_factor=sub_sample_factor)
        self.combine_gates = nn.Sequential(
            nn.Conv3d(in_size * 2, in_size, kernel_size=1,
                      stride=1, padding=0), nn.BatchNorm3d(in_size),
            nn.ReLU(inplace=True))

        # initialise the blocks
        for m in self.children():
            if m.__class__.__name__.find('GridAttentionBlock3D') != -1:
                continue
            init_weights(m, init_type='kaiming')
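
The constructor alone fixes the wiring: combine_gates expects in_size * 2 channels, which implies the outputs of the two gate blocks are concatenated. A hypothetical forward pass inferred from that (assuming GridAttentionBlock3D returns an (output, attention) pair and torch is imported):

    def forward(self, x, gating_signal):
        # Run both attention gates on the same input/gating pair.
        gate_1, attention_1 = self.gate_block_1(x, gating_signal)
        gate_2, attention_2 = self.gate_block_2(x, gating_signal)
        # Concatenate the gated feature maps and fuse back to in_size channels.
        combined = self.combine_gates(torch.cat([gate_1, gate_2], dim=1))
        return combined, torch.cat([attention_1, attention_2], dim=1)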
Example #3
    def __init__(self,
                 in_size,
                 out_size,
                 is_batchnorm,
                 kernel_size=(3, 3, 1),
                 padding_size=(1, 1, 0),
                 init_stride=(1, 1, 1)):
        super(UnetConv3, self).__init__()

        if is_batchnorm:
            self.conv1 = nn.Sequential(
                nn.Conv3d(in_size, out_size, kernel_size, init_stride,
                          padding_size),
                nn.InstanceNorm3d(out_size),
                nn.ReLU(inplace=True),
            )
            self.conv2 = nn.Sequential(
                nn.Conv3d(out_size, out_size, kernel_size, 1, padding_size),
                nn.InstanceNorm3d(out_size),
                nn.ReLU(inplace=True),
            )
        else:
            self.conv1 = nn.Sequential(
                nn.Conv3d(in_size, out_size, kernel_size, init_stride,
                          padding_size),
                nn.ReLU(inplace=True),
            )
            self.conv2 = nn.Sequential(
                nn.Conv3d(out_size, out_size, kernel_size, 1, padding_size),
                nn.ReLU(inplace=True),
            )

        # initialise the blocks
        for m in self.children():
            init_weights(m, init_type='kaiming')
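
Only the constructor is shown; the matching forward pass, assuming the two blocks are simply chained (a sketch, not the repository's verbatim code):

    def forward(self, inputs):
        # Apply conv1 then conv2: conv -> (InstanceNorm) -> ReLU, twice.
        outputs = self.conv1(inputs)
        return self.conv2(outputs)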
Example #4
    def __init__(self, in_size, out_size, is_batchnorm=True):
        super(UnetUp3_CT, self).__init__()
        self.conv = UnetConv3(in_size + out_size,
                              out_size,
                              is_batchnorm,
                              kernel_size=(3, 3, 3),
                              padding_size=(1, 1, 1))
        self.up = nn.Upsample(scale_factor=(2, 2, 2), mode='trilinear')

        # initialise the blocks
        for m in self.children():
            if m.__class__.__name__.find('UnetConv3') != -1:
                continue
            init_weights(m, init_type='kaiming')
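
Because self.conv expects in_size + out_size channels, the decoder stage presumably upsamples the coarse features and concatenates them with the encoder skip connection. A minimal sketch (ignoring any shape padding the original may perform, and assuming torch is imported):

    def forward(self, inputs1, inputs2):
        # inputs1: encoder skip connection, inputs2: coarser decoder features.
        outputs2 = self.up(inputs2)
        return self.conv(torch.cat([inputs1, outputs2], dim=1))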
Example #5
    def __init__(self, feature_scale=4, n_classes=21, is_deconv=True, in_channels=3, is_batchnorm=True):
        super(unet_3D, self).__init__()
        self.is_deconv = is_deconv
        self.in_channels = in_channels
        self.is_batchnorm = is_batchnorm
        self.feature_scale = feature_scale

        filters = [64, 128, 256, 512, 1024]
        filters = [int(x / self.feature_scale) for x in filters]

        # downsampling
        self.conv1 = UnetConv3(self.in_channels, filters[0], self.is_batchnorm,
                               kernel_size=(3, 3, 3), padding_size=(1, 1, 1))
        self.maxpool1 = nn.MaxPool3d(kernel_size=(2, 2, 2))

        self.conv2 = UnetConv3(filters[0], filters[1], self.is_batchnorm,
                               kernel_size=(3, 3, 3), padding_size=(1, 1, 1))
        self.maxpool2 = nn.MaxPool3d(kernel_size=(2, 2, 2))

        self.conv3 = UnetConv3(filters[1], filters[2], self.is_batchnorm,
                               kernel_size=(3, 3, 3), padding_size=(1, 1, 1))
        self.maxpool3 = nn.MaxPool3d(kernel_size=(2, 2, 2))

        self.conv4 = UnetConv3(filters[2], filters[3], self.is_batchnorm,
                               kernel_size=(3, 3, 3), padding_size=(1, 1, 1))
        self.maxpool4 = nn.MaxPool3d(kernel_size=(2, 2, 2))

        self.center = UnetConv3(filters[3], filters[4], self.is_batchnorm,
                                kernel_size=(3, 3, 3), padding_size=(1, 1, 1))

        # upsampling
        self.up_concat4 = UnetUp3_CT(filters[4], filters[3], is_batchnorm)
        self.up_concat3 = UnetUp3_CT(filters[3], filters[2], is_batchnorm)
        self.up_concat2 = UnetUp3_CT(filters[2], filters[1], is_batchnorm)
        self.up_concat1 = UnetUp3_CT(filters[1], filters[0], is_batchnorm)

        # final conv (without any concat)
        self.final = nn.Conv3d(filters[0], n_classes, 1)

        self.dropout1 = nn.Dropout(p=0.3)
        self.dropout2 = nn.Dropout(p=0.3)

        # initialise weights
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                init_weights(m, init_type='kaiming')
            elif isinstance(m, nn.BatchNorm3d):
                init_weights(m, init_type='kaiming')
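
A hypothetical instantiation (values chosen for illustration, not taken from the source). With feature_scale=4 the filter widths become [16, 32, 64, 128, 256], and the four MaxPool3d(2) stages mean input volumes whose spatial sizes are divisible by 16 upsample back to matching shapes in the decoder:

# Hypothetical usage: single-channel volumes, binary segmentation.
model = unet_3D(feature_scale=4, n_classes=2, in_channels=1, is_batchnorm=True)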
Example #6
    def __init__(self, in_size, out_size, is_deconv):
        super(unetUp, self).__init__()
        self.conv = unetConv2(in_size, out_size, False)
        if is_deconv:
            self.up = nn.ConvTranspose2d(in_size,
                                         out_size,
                                         kernel_size=4,
                                         stride=2,
                                         padding=1)
        else:
            self.up = nn.UpsamplingBilinear2d(scale_factor=2)

        # initialise the blocks
        for m in self.children():
            if m.__class__.__name__.find('unetConv2') != -1:
                continue
            init_weights(m, init_type='kaiming')
Example #7
    def __init__(self, in_size, out_size, is_deconv, is_batchnorm=True):
        super(UnetUp3, self).__init__()
        if is_deconv:
            self.conv = UnetConv3(in_size, out_size, is_batchnorm)
            self.up = nn.ConvTranspose3d(in_size,
                                         out_size,
                                         kernel_size=(4, 4, 1),
                                         stride=(2, 2, 1),
                                         padding=(1, 1, 0))
        else:
            self.conv = UnetConv3(in_size + out_size, out_size, is_batchnorm)
            self.up = nn.Upsample(scale_factor=(2, 2, 1), mode='trilinear')

        # initialise the blocks
        for m in self.children():
            if m.__class__.__name__.find('UnetConv3') != -1:
                continue
            init_weights(m, init_type='kaiming')
Example #8
    def __init__(self,
                 in_size,
                 out_size,
                 is_batchnorm,
                 n=2,
                 ks=3,
                 stride=1,
                 padding=1):
        super(unetConv2, self).__init__()
        self.n = n
        self.ks = ks
        self.stride = stride
        self.padding = padding
        s = stride
        p = padding
        if is_batchnorm:
            for i in range(1, n + 1):
                conv = nn.Sequential(
                    nn.Conv2d(in_size, out_size, ks, s, p),
                    nn.BatchNorm2d(out_size),
                    nn.ReLU(inplace=True),
                )
                setattr(self, 'conv%d' % i, conv)
                in_size = out_size

        else:
            for i in range(1, n + 1):
                conv = nn.Sequential(
                    nn.Conv2d(in_size, out_size, ks, s, p),
                    nn.ReLU(inplace=True),
                )
                setattr(self, 'conv%d' % i, conv)
                in_size = out_size

        # initialise the blocks
        for m in self.children():
            init_weights(m, init_type='kaiming')
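
The blocks are registered as attributes conv1 ... convn via setattr, so the forward pass presumably walks them in order; a sketch:

    def forward(self, inputs):
        # Apply conv1 ... convn in sequence.
        x = inputs
        for i in range(1, self.n + 1):
            conv = getattr(self, 'conv%d' % i)
            x = conv(x)
        return x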
Example #9
    def __init__(self,
                 in_size,
                 out_size,
                 kernel_size=(1, 1, 1),
                 is_batchnorm=True):
        super(UnetGridGatingSignal3, self).__init__()

        if is_batchnorm:
            self.conv1 = nn.Sequential(
                nn.Conv3d(in_size, out_size, kernel_size, (1, 1, 1),
                          (0, 0, 0)),
                nn.InstanceNorm3d(out_size),
                nn.ReLU(inplace=True),
            )
        else:
            self.conv1 = nn.Sequential(
                nn.Conv3d(in_size, out_size, kernel_size, (1, 1, 1),
                          (0, 0, 0)),
                nn.ReLU(inplace=True),
            )

        # initialise the blocks
        for m in self.children():
            init_weights(m, init_type='kaiming')
Example #10
    def __init__(self,
                 feature_scale=4,
                 n_classes=21,
                 is_deconv=True,
                 in_channels=3,
                 nonlocal_mode='concatenation',
                 attention_dsample=(2, 2, 2),
                 is_batchnorm=True):
        super(Attention_UNet, self).__init__()
        self.is_deconv = is_deconv
        self.in_channels = in_channels
        self.is_batchnorm = is_batchnorm
        self.feature_scale = feature_scale

        filters = [64, 128, 256, 512, 1024]
        filters = [int(x / self.feature_scale) for x in filters]

        # downsampling
        self.conv1 = UnetConv3(self.in_channels,
                               filters[0],
                               self.is_batchnorm,
                               kernel_size=(3, 3, 3),
                               padding_size=(1, 1, 1))
        self.maxpool1 = nn.MaxPool3d(kernel_size=(2, 2, 2))

        self.conv2 = UnetConv3(filters[0],
                               filters[1],
                               self.is_batchnorm,
                               kernel_size=(3, 3, 3),
                               padding_size=(1, 1, 1))
        self.maxpool2 = nn.MaxPool3d(kernel_size=(2, 2, 2))

        self.conv3 = UnetConv3(filters[1],
                               filters[2],
                               self.is_batchnorm,
                               kernel_size=(3, 3, 3),
                               padding_size=(1, 1, 1))
        self.maxpool3 = nn.MaxPool3d(kernel_size=(2, 2, 2))

        self.conv4 = UnetConv3(filters[2],
                               filters[3],
                               self.is_batchnorm,
                               kernel_size=(3, 3, 3),
                               padding_size=(1, 1, 1))
        self.maxpool4 = nn.MaxPool3d(kernel_size=(2, 2, 2))

        self.center = UnetConv3(filters[3],
                                filters[4],
                                self.is_batchnorm,
                                kernel_size=(3, 3, 3),
                                padding_size=(1, 1, 1))
        self.gating = UnetGridGatingSignal3(filters[4],
                                            filters[4],
                                            kernel_size=(1, 1, 1),
                                            is_batchnorm=self.is_batchnorm)

        # attention blocks
        self.attentionblock2 = MultiAttentionBlock(
            in_size=filters[1],
            gate_size=filters[2],
            inter_size=filters[1],
            nonlocal_mode=nonlocal_mode,
            sub_sample_factor=attention_dsample)
        self.attentionblock3 = MultiAttentionBlock(
            in_size=filters[2],
            gate_size=filters[3],
            inter_size=filters[2],
            nonlocal_mode=nonlocal_mode,
            sub_sample_factor=attention_dsample)
        self.attentionblock4 = MultiAttentionBlock(
            in_size=filters[3],
            gate_size=filters[4],
            inter_size=filters[3],
            nonlocal_mode=nonlocal_mode,
            sub_sample_factor=attention_dsample)

        # upsampling
        self.up_concat4 = UnetUp3_CT(filters[4], filters[3], is_batchnorm)
        self.up_concat3 = UnetUp3_CT(filters[3], filters[2], is_batchnorm)
        self.up_concat2 = UnetUp3_CT(filters[2], filters[1], is_batchnorm)
        self.up_concat1 = UnetUp3_CT(filters[1], filters[0], is_batchnorm)

        # deep supervision
        self.dsv4 = UnetDsv3(in_size=filters[3],
                             out_size=n_classes,
                             scale_factor=8)
        self.dsv3 = UnetDsv3(in_size=filters[2],
                             out_size=n_classes,
                             scale_factor=4)
        self.dsv2 = UnetDsv3(in_size=filters[1],
                             out_size=n_classes,
                             scale_factor=2)
        self.dsv1 = nn.Conv3d(in_channels=filters[0],
                              out_channels=n_classes,
                              kernel_size=1)

        # final conv (without any concat)
        self.final = nn.Conv3d(n_classes * 4, n_classes, 1)

        # initialise weights
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                init_weights(m, init_type='kaiming')
            elif isinstance(m, nn.BatchNorm3d):
                init_weights(m, init_type='kaiming')
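
The deep-supervision heads feed a final 1x1x1 convolution with n_classes * 4 input channels, which implies the four class maps are concatenated before the last layer. A hypothetical excerpt of the corresponding forward tail (up1 ... up4 stand for the decoder outputs and are assumed names, not from the source):

        # Hypothetical fusion of the deep-supervision outputs:
        dsv4 = self.dsv4(up4)
        dsv3 = self.dsv3(up3)
        dsv2 = self.dsv2(up2)
        dsv1 = self.dsv1(up1)
        final = self.final(torch.cat([dsv1, dsv2, dsv3, dsv4], dim=1))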
Example #11
    def __init__(self,
                 in_channels,
                 gating_channels,
                 inter_channels=None,
                 dimension=3,
                 mode='concatenation',
                 sub_sample_factor=(2, 2, 2)):
        super(_GridAttentionBlockND, self).__init__()

        assert dimension in [2, 3]
        assert mode in [
            'concatenation', 'concatenation_debug', 'concatenation_residual'
        ]

        # Downsampling rate for the input feature map
        if isinstance(sub_sample_factor, tuple):
            self.sub_sample_factor = sub_sample_factor
        elif isinstance(sub_sample_factor, list):
            self.sub_sample_factor = tuple(sub_sample_factor)
        else:
            self.sub_sample_factor = tuple([sub_sample_factor]) * dimension

        # Default parameter set
        self.mode = mode
        self.dimension = dimension
        self.sub_sample_kernel_size = self.sub_sample_factor

        # Number of channels (pixel dimensions)
        self.in_channels = in_channels
        self.gating_channels = gating_channels
        self.inter_channels = inter_channels

        if self.inter_channels is None:
            self.inter_channels = in_channels // 2
            if self.inter_channels == 0:
                self.inter_channels = 1

        if dimension == 3:
            conv_nd = nn.Conv3d
            bn = nn.BatchNorm3d
            self.upsample_mode = 'trilinear'
        elif dimension == 2:
            conv_nd = nn.Conv2d
            bn = nn.BatchNorm2d
            self.upsample_mode = 'bilinear'
        else:
            raise NotImplementedError

        # Output transform
        self.W = nn.Sequential(
            conv_nd(in_channels=self.in_channels,
                    out_channels=self.in_channels,
                    kernel_size=1,
                    stride=1,
                    padding=0),
            bn(self.in_channels),
        )

        # Theta^T * x_ij + Phi^T * gating_signal + bias
        self.theta = conv_nd(in_channels=self.in_channels,
                             out_channels=self.inter_channels,
                             kernel_size=self.sub_sample_kernel_size,
                             stride=self.sub_sample_factor,
                             padding=0,
                             bias=False)
        self.phi = conv_nd(in_channels=self.gating_channels,
                           out_channels=self.inter_channels,
                           kernel_size=1,
                           stride=1,
                           padding=0,
                           bias=True)
        self.psi = conv_nd(in_channels=self.inter_channels,
                           out_channels=1,
                           kernel_size=1,
                           stride=1,
                           padding=0,
                           bias=True)

        # Initialise weights
        for m in self.children():
            init_weights(m, init_type='kaiming')

        # Define the operation
        if mode == 'concatenation':
            self.operation_function = self._concatenation
        elif mode == 'concatenation_debug':
            self.operation_function = self._concatenation_debug
        elif mode == 'concatenation_residual':
            self.operation_function = self._concatenation_residual
        else:
            raise NotImplementedError('Unknown operation function.')
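
The mode dispatch above binds self._concatenation, which is not included in the excerpt. A minimal sketch of the additive attention it names (Theta^T x + Phi^T g, a sigmoid gate psi, then the output transform W), assuming torch and torch.nn.functional as F are imported; the repository's version may differ in detail:

    def _concatenation(self, x, g):
        # x: input feature map, g: coarser gating signal.
        input_size = x.size()

        theta_x = self.theta(x)                        # project and sub-sample x
        phi_g = F.interpolate(self.phi(g), size=theta_x.size()[2:],
                              mode=self.upsample_mode)
        f = F.relu(theta_x + phi_g, inplace=True)      # additive attention
        sigm_psi_f = torch.sigmoid(self.psi(f))        # per-voxel gate in [0, 1]

        # Upsample the gate to the input resolution and re-weight x.
        sigm_psi_f = F.interpolate(sigm_psi_f, size=input_size[2:],
                                   mode=self.upsample_mode)
        y = sigm_psi_f.expand_as(x) * x
        return self.W(y), sigm_psi_f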
Example #12
    def __init__(self,
                 in_channels,
                 gating_channels,
                 inter_channels=None,
                 dimension=3,
                 mode='concatenation',
                 sub_sample_factor=(1, 1, 1),
                 bn_layer=True,
                 use_W=True,
                 use_phi=True,
                 use_theta=True,
                 use_psi=True,
                 nonlinearity1='relu'):
        super(_GridAttentionBlockND_TORR, self).__init__()

        assert dimension in [2, 3]
        assert mode in [
            'concatenation', 'concatenation_softmax', 'concatenation_sigmoid',
            'concatenation_mean', 'concatenation_range_normalise',
            'concatenation_mean_flow'
        ]

        # Default parameter set
        self.mode = mode
        self.dimension = dimension
        self.sub_sample_factor = (sub_sample_factor
                                  if isinstance(sub_sample_factor, tuple)
                                  else tuple([sub_sample_factor]) * dimension)
        self.sub_sample_kernel_size = self.sub_sample_factor

        # Number of channels (pixel dimensions)
        self.in_channels = in_channels
        self.gating_channels = gating_channels
        self.inter_channels = inter_channels

        if self.inter_channels is None:
            self.inter_channels = in_channels // 2
            if self.inter_channels == 0:
                self.inter_channels = 1

        if dimension == 3:
            conv_nd = nn.Conv3d
            bn = nn.BatchNorm3d
            self.upsample_mode = 'trilinear'
        elif dimension == 2:
            conv_nd = nn.Conv2d
            bn = nn.BatchNorm2d
            self.upsample_mode = 'bilinear'
        else:
            raise NotImplementedError

        # Initialise identity placeholders. Plain lambdas (not nn.Modules) are
        # used here, so disabled branches pass features through unchanged and
        # are skipped by self.children() during weight initialisation below.
        # Theta^T * x_ij + Phi^T * gating_signal + bias
        self.W = lambda x: x
        self.theta = lambda x: x
        self.psi = lambda x: x
        self.phi = lambda x: x
        self.nl1 = lambda x: x

        if use_W:
            if bn_layer:
                self.W = nn.Sequential(
                    conv_nd(in_channels=self.in_channels,
                            out_channels=self.in_channels,
                            kernel_size=1,
                            stride=1,
                            padding=0),
                    bn(self.in_channels),
                )
            else:
                self.W = conv_nd(in_channels=self.in_channels,
                                 out_channels=self.in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)

        if use_theta:
            self.theta = conv_nd(in_channels=self.in_channels,
                                 out_channels=self.inter_channels,
                                 kernel_size=self.sub_sample_kernel_size,
                                 stride=self.sub_sample_factor,
                                 padding=0,
                                 bias=False)

        if use_phi:
            self.phi = conv_nd(in_channels=self.gating_channels,
                               out_channels=self.inter_channels,
                               kernel_size=self.sub_sample_kernel_size,
                               stride=self.sub_sample_factor,
                               padding=0,
                               bias=False)

        if use_psi:
            self.psi = conv_nd(in_channels=self.inter_channels,
                               out_channels=1,
                               kernel_size=1,
                               stride=1,
                               padding=0,
                               bias=True)

        if nonlinearity1:
            if nonlinearity1 == 'relu':
                self.nl1 = lambda x: F.relu(x, inplace=True)

        if 'concatenation' in mode:
            self.operation_function = self._concatenation
        else:
            raise NotImplementedError('Unknown operation function.')

        # Initialise weights
        for m in self.children():
            init_weights(m, init_type='kaiming')

        if use_psi and self.mode == 'concatenation_sigmoid':
            nn.init.constant_(self.psi.bias.data, 3.0)

        if use_psi and self.mode == 'concatenation_softmax':
            nn.init.constant_(self.psi.bias.data, 10.0)

        # if use_psi and self.mode == 'concatenation_mean':
        #     nn.init.constant(self.psi.bias.data, 3.0)

        # if use_psi and self.mode == 'concatenation_range_normalise':
        #     nn.init.constant(self.psi.bias.data, 3.0)

        parallel = False
        if parallel:
            if use_W:
                self.W = nn.DataParallel(self.W)
            if use_phi:
                self.phi = nn.DataParallel(self.phi)
            if use_psi:
                self.psi = nn.DataParallel(self.psi)
            if use_theta:
                self.theta = nn.DataParallel(self.theta)