def __init__(self, in_size, gate_size, inter_size, nonlocal_mode,
                 sub_sample_factor):
        super(MultiAttentionBlock, self).__init__()
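        # two parallel attention gates with identical configuration, applied
        # to the same input features and gating signal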
        self.gate_block_1 = GridAttentionBlock3D(
            in_channels=in_size,
            gating_channels=gate_size,
            inter_channels=inter_size,
            mode=nonlocal_mode,
            sub_sample_factor=sub_sample_factor)
        self.gate_block_2 = GridAttentionBlock3D(
            in_channels=in_size,
            gating_channels=gate_size,
            inter_channels=inter_size,
            mode=nonlocal_mode,
            sub_sample_factor=sub_sample_factor)
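        # fuse the two gated feature maps back to in_size channels with a 1x1x1 conv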
        self.combine_gates = nn.Sequential(
            nn.Conv3d(in_size * 2, in_size, kernel_size=1,
                      stride=1, padding=0), nn.BatchNorm3d(in_size),
            nn.ReLU(inplace=True))

        # initialise the blocks
        for m in self.children():
            if isinstance(m, GridAttentionBlock3D):
                continue
            init_weights(m, init_type='kaiming')
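
A minimal, self-contained sketch of the fusion step above (the in_size value
and tensor shapes are illustrative, not taken from the source): two gated
feature maps are concatenated on the channel axis and reduced back to
in_size channels by the 1x1x1 combine_gates convolution.

import torch
import torch.nn as nn

in_size = 32
combine_gates = nn.Sequential(
    nn.Conv3d(in_size * 2, in_size, kernel_size=1, stride=1, padding=0),
    nn.BatchNorm3d(in_size),
    nn.ReLU(inplace=True),
)
gate_1 = torch.randn(2, in_size, 8, 8, 4)  # dummy output of gate_block_1
gate_2 = torch.randn(2, in_size, 8, 8, 4)  # dummy output of gate_block_2
fused = combine_gates(torch.cat([gate_1, gate_2], dim=1))
print(fused.shape)  # torch.Size([2, 32, 8, 8, 4])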
Example #2
    def __init__(self,
                 in_size,
                 out_size,
                 is_batchnorm,
                 kernel_size=(3, 3, 1),
                 padding_size=(1, 1, 0),
                 init_stride=(1, 1, 1)):
        super(UnetConv3, self).__init__()

        if is_batchnorm:
            self.conv1 = nn.Sequential(
                nn.Conv3d(in_size, out_size, kernel_size, init_stride,
                          padding_size),
                nn.BatchNorm3d(out_size),
                nn.ReLU(inplace=True),
            )
            self.conv2 = nn.Sequential(
                nn.Conv3d(out_size, out_size, kernel_size, 1, padding_size),
                nn.BatchNorm3d(out_size),
                nn.ReLU(inplace=True),
            )
        else:
            self.conv1 = nn.Sequential(
                nn.Conv3d(in_size, out_size, kernel_size, init_stride,
                          padding_size),
                nn.ReLU(inplace=True),
            )
            self.conv2 = nn.Sequential(
                nn.Conv3d(out_size, out_size, kernel_size, 1, padding_size),
                nn.ReLU(inplace=True),
            )

        # initialise the blocks
        for m in self.children():
            init_weights(m, init_type='kaiming')
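
A quick standalone check of the default anisotropic kernel (dummy shapes, not
from the source): a (3, 3, 1) kernel with (1, 1, 0) padding convolves only
in-plane, so the last axis (e.g. the through-slice direction) passes through
unchanged when the stride is 1.

import torch
import torch.nn as nn

conv = nn.Conv3d(1, 8, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0))
x = torch.randn(1, 1, 32, 32, 5)  # dummy 5-D volume; the (3, 3, 1) kernel leaves the last axis untouched
print(conv(x).shape)              # torch.Size([1, 8, 32, 32, 5])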
Example #3
    def __init__(self,
                 feature_scale=4,
                 n_classes=21,
                 is_deconv=True,
                 in_channels=3,
                 is_batchnorm=True,
                 nonlocal_mode='embedded_gaussian',
                 nonlocal_sf=4):
        super(unet_nonlocal_3D, self).__init__()
        self.is_deconv = is_deconv
        self.in_channels = in_channels
        self.is_batchnorm = is_batchnorm
        self.feature_scale = feature_scale

        filters = [64, 128, 256, 512, 1024]
        filters = [int(x / self.feature_scale) for x in filters]

        # downsampling
        self.conv1 = UnetConv3(self.in_channels, filters[0], self.is_batchnorm)
        self.maxpool1 = nn.MaxPool3d(kernel_size=(2, 2, 1))

        self.conv2 = UnetConv3(filters[0], filters[1], self.is_batchnorm)
        self.nonlocal2 = NONLocalBlock3D(in_channels=filters[1],
                                         inter_channels=filters[1] // 4,
                                         sub_sample_factor=nonlocal_sf,
                                         mode=nonlocal_mode)
        self.maxpool2 = nn.MaxPool3d(kernel_size=(2, 2, 1))

        self.conv3 = UnetConv3(filters[1], filters[2], self.is_batchnorm)
        self.nonlocal3 = NONLocalBlock3D(in_channels=filters[2],
                                         inter_channels=filters[2] // 4,
                                         sub_sample_factor=nonlocal_sf,
                                         mode=nonlocal_mode)
        self.maxpool3 = nn.MaxPool3d(kernel_size=(2, 2, 1))

        self.conv4 = UnetConv3(filters[2], filters[3], self.is_batchnorm)
        self.maxpool4 = nn.MaxPool3d(kernel_size=(2, 2, 1))

        self.center = UnetConv3(filters[3], filters[4], self.is_batchnorm)

        # upsampling
        self.up_concat4 = UnetUp3(filters[4], filters[3], self.is_deconv)
        self.up_concat3 = UnetUp3(filters[3], filters[2], self.is_deconv)
        self.up_concat2 = UnetUp3(filters[2], filters[1], self.is_deconv)
        self.up_concat1 = UnetUp3(filters[1], filters[0], self.is_deconv)

        # final conv (without any concat)
        self.final = nn.Conv3d(filters[0], n_classes, 1)

        # initialise weights
        for m in self.modules():
            if isinstance(m, (nn.Conv3d, nn.BatchNorm3d)):
                init_weights(m, init_type='kaiming')
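
The feature_scale argument divides every stage width; a one-line check
(values follow directly from the code above):

filters = [int(x / 4) for x in [64, 128, 256, 512, 1024]]
print(filters)  # [16, 32, 64, 128, 256]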
Example #4
    def __init__(self,
                 in_size,
                 out_size,
                 is_batchnorm,
                 n=2,
                 ks=3,
                 stride=1,
                 padding=1):
        super(unetConv2, self).__init__()
        self.n = n
        self.ks = ks
        self.stride = stride
        self.padding = padding
        s = stride
        p = padding
        if is_batchnorm:
            for i in range(1, n + 1):
                conv = nn.Sequential(
                    nn.Conv2d(in_size, out_size, ks, s, p),
                    nn.BatchNorm2d(out_size),
                    nn.ReLU(inplace=True),
                )
                setattr(self, 'conv%d' % i, conv)
                in_size = out_size

        else:
            for i in range(1, n + 1):
                conv = nn.Sequential(
                    nn.Conv2d(in_size, out_size, ks, s, p),
                    nn.ReLU(inplace=True),
                )
                setattr(self, 'conv%d' % i, conv)
                in_size = out_size

        # initialise the blocks
        for m in self.children():
            init_weights(m, init_type='kaiming')
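
Since the blocks are registered as conv1..convn via setattr, the forward pass
(not shown here) presumably chains them by index; a hedged sketch:

    def forward(self, inputs):
        x = inputs
        for i in range(1, self.n + 1):
            x = getattr(self, 'conv%d' % i)(x)
        return x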
Example #5
    def __init__(self,
                 feature_scale=4,
                 n_classes=21,
                 is_deconv=True,
                 in_channels=3,
                 nonlocal_mode='concatenation',
                 attention_dsample=(2, 2, 2),
                 is_batchnorm=True):
        super(unet_grid_attention_3D, self).__init__()
        self.is_deconv = is_deconv
        self.in_channels = in_channels
        self.is_batchnorm = is_batchnorm
        self.feature_scale = feature_scale

        filters = [64, 128, 256, 512, 1024]
        filters = [int(x / self.feature_scale) for x in filters]

        # downsampling
        self.conv1 = UnetConv3(self.in_channels, filters[0], self.is_batchnorm)
        self.maxpool1 = nn.MaxPool3d(kernel_size=(2, 2, 1))

        self.conv2 = UnetConv3(filters[0], filters[1], self.is_batchnorm)
        self.maxpool2 = nn.MaxPool3d(kernel_size=(2, 2, 1))

        self.conv3 = UnetConv3(filters[1], filters[2], self.is_batchnorm)
        self.maxpool3 = nn.MaxPool3d(kernel_size=(2, 2, 1))

        self.conv4 = UnetConv3(filters[2], filters[3], self.is_batchnorm)
        self.maxpool4 = nn.MaxPool3d(kernel_size=(2, 2, 1))

        self.center = UnetConv3(filters[3], filters[4], self.is_batchnorm)
        self.gating = UnetGridGatingSignal3(filters[4],
                                            filters[3],
                                            kernel_size=(1, 1, 1),
                                            is_batchnorm=self.is_batchnorm)

        # attention blocks
        self.attentionblock2 = GridAttentionBlock3D(
            in_channels=filters[1],
            gating_channels=filters[3],
            inter_channels=filters[1],
            sub_sample_factor=attention_dsample,
            mode=nonlocal_mode)
        self.attentionblock3 = GridAttentionBlock3D(
            in_channels=filters[2],
            gating_channels=filters[3],
            inter_channels=filters[2],
            sub_sample_factor=attention_dsample,
            mode=nonlocal_mode)
        self.attentionblock4 = GridAttentionBlock3D(
            in_channels=filters[3],
            gating_channels=filters[3],
            inter_channels=filters[3],
            sub_sample_factor=attention_dsample,
            mode=nonlocal_mode)

        # upsampling
        self.up_concat4 = UnetUp3(filters[4], filters[3], self.is_deconv,
                                  self.is_batchnorm)
        self.up_concat3 = UnetUp3(filters[3], filters[2], self.is_deconv,
                                  self.is_batchnorm)
        self.up_concat2 = UnetUp3(filters[2], filters[1], self.is_deconv,
                                  self.is_batchnorm)
        self.up_concat1 = UnetUp3(filters[1], filters[0], self.is_deconv,
                                  self.is_batchnorm)

        # final conv (without any concat)
        self.final = nn.Conv3d(filters[0], n_classes, 1)

        # initialise weights
        for m in self.modules():
            if isinstance(m, (nn.Conv3d, nn.BatchNorm3d)):
                init_weights(m, init_type='kaiming')
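
A standalone sketch of the gating signal (channel counts assume the default
feature_scale of 4; UnetGridGatingSignal3's definition is not shown above,
but its arguments indicate a 1x1x1 projection from filters[4] down to
filters[3] channels, which each attention block then compares against its
skip-connection features).

import torch
import torch.nn as nn

filters = [16, 32, 64, 128, 256]
gating = nn.Conv3d(filters[4], filters[3], kernel_size=1)
center_feats = torch.randn(1, filters[4], 4, 4, 16)  # dummy bottleneck features
print(gating(center_feats).shape)  # torch.Size([1, 128, 4, 4, 16])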
Example #6
    def __init__(self, in_channels, inter_channels=None, dimension=3, mode='embedded_gaussian',
                 sub_sample_factor=4, bn_layer=True):
        super(_NonLocalBlockND, self).__init__()

        assert dimension in [1, 2, 3]
        assert mode in ['embedded_gaussian', 'gaussian', 'dot_product', 'concatenation', 'concat_proper', 'concat_proper_down']

        # print('Dimension: %d, mode: %s' % (dimension, mode))

        self.mode = mode
        self.dimension = dimension
        # accept an int or a per-dimension list/tuple of sub-sampling factors
        self.sub_sample_factor = (list(sub_sample_factor)
                                  if isinstance(sub_sample_factor, (list, tuple))
                                  else [sub_sample_factor])

        self.in_channels = in_channels
        self.inter_channels = inter_channels

        if self.inter_channels is None:
            self.inter_channels = in_channels // 2
            if self.inter_channels == 0:
                self.inter_channels = 1

        if dimension == 3:
            conv_nd = nn.Conv3d
            max_pool = nn.MaxPool3d
            bn = nn.BatchNorm3d
        elif dimension == 2:
            conv_nd = nn.Conv2d
            max_pool = nn.MaxPool2d
            bn = nn.BatchNorm2d
        else:
            conv_nd = nn.Conv1d
            max_pool = nn.MaxPool1d
            bn = nn.BatchNorm1d

        # g: 1x1 projection producing the 'value' features that the attention map aggregates
        self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                         kernel_size=1, stride=1, padding=0)

        if bn_layer:
            self.W = nn.Sequential(
                conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                        kernel_size=1, stride=1, padding=0),
                bn(self.in_channels)
            )
            nn.init.constant_(self.W[1].weight, 0)
            nn.init.constant_(self.W[1].bias, 0)
        else:
            self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                             kernel_size=1, stride=1, padding=0)
            nn.init.constant_(self.W.weight, 0)
            nn.init.constant_(self.W.bias, 0)

        self.theta = None
        self.phi = None

        if mode in ['embedded_gaussian', 'dot_product', 'concatenation', 'concat_proper', 'concat_proper_down']:
            # theta/phi: 1x1 projections whose pairwise comparison forms the affinity map
            self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                                 kernel_size=1, stride=1, padding=0)
            self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                               kernel_size=1, stride=1, padding=0)

            if mode in ['concatenation']:
                self.wf_phi = nn.Linear(self.inter_channels, 1, bias=False)
                self.wf_theta = nn.Linear(self.inter_channels, 1, bias=False)
            elif mode in ['concat_proper', 'concat_proper_down']:
                self.psi = nn.Conv2d(in_channels=self.inter_channels, out_channels=1, kernel_size=1, stride=1,
                                     padding=0, bias=True)

        if mode == 'embedded_gaussian':
            self.operation_function = self._embedded_gaussian
        elif mode == 'dot_product':
            self.operation_function = self._dot_product
        elif mode == 'gaussian':
            self.operation_function = self._gaussian
        elif mode == 'concatenation':
            self.operation_function = self._concatenation
        elif mode == 'concat_proper':
            self.operation_function = self._concatenation_proper
        elif mode == 'concat_proper_down':
            self.operation_function = self._concatenation_proper_down
        else:
            raise NotImplementedError('Unknown operation function.')

        if any(ss > 1 for ss in self.sub_sample_factor):
            self.g = nn.Sequential(self.g, max_pool(kernel_size=sub_sample_factor))
            if self.phi is None:
                self.phi = max_pool(kernel_size=sub_sample_factor)
            else:
                self.phi = nn.Sequential(self.phi, max_pool(kernel_size=sub_sample_factor))
            if mode == 'concat_proper_down':
                self.theta = nn.Sequential(self.theta, max_pool(kernel_size=sub_sample_factor))

        # Initialise weights
        for m in self.children():
            init_weights(m, init_type='kaiming')
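
A self-contained sketch of the embedded-Gaussian operation the block
dispatches to (tensor sizes are illustrative): theta(x) and phi(x) act as
projected queries and keys, g(x) as values, and the attention weights are a
softmax over their pairwise dot products.

import torch

b, c_inter, n = 2, 8, 64                 # batch, inter_channels, spatial positions
theta_x = torch.randn(b, n, c_inter)     # theta(x), flattened over space
phi_x = torch.randn(b, c_inter, n)       # phi(x), flattened (and possibly subsampled)
g_x = torch.randn(b, n, c_inter)         # g(x), flattened over space
f = torch.matmul(theta_x, phi_x)         # pairwise affinities, (b, n, n)
f_div_C = torch.softmax(f, dim=-1)       # normalise over key positions
y = torch.matmul(f_div_C, g_x)           # (b, n, c_inter); reshape and apply W next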
Example #7

    def __init__(self,
                 feature_scale=4,
                 n_classes=2,
                 is_deconv=True,
                 in_channels=4,
                 nonlocal_mode='concatenation',
                 attention_dsample=(2, 2, 2),
                 is_batchnorm=True):
        super(unet_pCT_multi_att_dsv_3D, self).__init__()
        self.is_deconv = is_deconv
        self.in_channels = in_channels
        self.is_batchnorm = is_batchnorm
        self.feature_scale = feature_scale

        filters = [64, 128, 256, 512, 1024]
        filters = [int(x / self.feature_scale) for x in filters]

        # downsampling
        self.conv1 = UnetConv3(self.in_channels,
                               filters[0],
                               self.is_batchnorm,
                               kernel_size=(3, 3, 3),
                               padding_size=(1, 1, 1))
        self.maxpool1 = nn.MaxPool3d(kernel_size=(2, 2, 2))

        self.conv2 = UnetConv3(filters[0],
                               filters[1],
                               self.is_batchnorm,
                               kernel_size=(3, 3, 3),
                               padding_size=(1, 1, 1))
        self.maxpool2 = nn.MaxPool3d(kernel_size=(2, 2, 2))

        self.conv3 = UnetConv3(filters[1],
                               filters[2],
                               self.is_batchnorm,
                               kernel_size=(3, 3, 3),
                               padding_size=(1, 1, 1))
        self.maxpool3 = nn.MaxPool3d(kernel_size=(2, 2, 2))

        self.conv4 = UnetConv3(filters[2],
                               filters[3],
                               self.is_batchnorm,
                               kernel_size=(3, 3, 3),
                               padding_size=(1, 1, 1))
        self.maxpool4 = nn.MaxPool3d(kernel_size=(2, 2, 2))

        self.center = UnetConv3(filters[3],
                                filters[4],
                                self.is_batchnorm,
                                kernel_size=(3, 3, 3),
                                padding_size=(1, 1, 1))
        self.gating = UnetGridGatingSignal3(filters[4],
                                            filters[4],
                                            kernel_size=(1, 1, 1),
                                            is_batchnorm=self.is_batchnorm)

        # attention blocks
        self.attentionblock2 = MultiAttentionBlock(
            in_size=filters[1],
            gate_size=filters[2],
            inter_size=filters[1],
            nonlocal_mode=nonlocal_mode,
            sub_sample_factor=attention_dsample)
        self.attentionblock3 = MultiAttentionBlock(
            in_size=filters[2],
            gate_size=filters[3],
            inter_size=filters[2],
            nonlocal_mode=nonlocal_mode,
            sub_sample_factor=attention_dsample)
        self.attentionblock4 = MultiAttentionBlock(
            in_size=filters[3],
            gate_size=filters[4],
            inter_size=filters[3],
            nonlocal_mode=nonlocal_mode,
            sub_sample_factor=attention_dsample)

        # upsampling
        self.up_concat4 = UnetUp3_CT(filters[4], filters[3], is_batchnorm)
        self.up_concat3 = UnetUp3_CT(filters[3], filters[2], is_batchnorm)
        self.up_concat2 = UnetUp3_CT(filters[2], filters[1], is_batchnorm)
        self.up_concat1 = UnetUp3_CT(filters[1], filters[0], is_batchnorm)

        # deep supervision
        self.dsv4 = UnetDsv3(in_size=filters[3],
                             out_size=n_classes,
                             scale_factor=8)
        self.dsv3 = UnetDsv3(in_size=filters[2],
                             out_size=n_classes,
                             scale_factor=4)
        self.dsv2 = UnetDsv3(in_size=filters[1],
                             out_size=n_classes,
                             scale_factor=2)
        self.dsv1 = nn.Conv3d(in_channels=filters[0],
                              out_channels=n_classes,
                              kernel_size=1)

        # final conv (without any concat)
        self.final = nn.Conv3d(n_classes * 4, n_classes, 1)

        # initialise weights
        for m in self.modules():
            if isinstance(m, (nn.Conv3d, nn.BatchNorm3d)):
                init_weights(m, init_type='kaiming')
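
A hedged sketch of the deep-supervision head (UnetDsv3's definition is not
shown above; here it is assumed to be a 1x1x1 conv to n_classes followed by
trilinear upsampling to full resolution, and channel counts assume
feature_scale=4). The four per-scale predictions are concatenated and fused
by the final 1x1x1 conv.

import torch
import torch.nn as nn

n_classes = 2

def make_dsv(in_ch, scale):
    # assumed UnetDsv3 behaviour: project to n_classes, then upsample to full size
    return nn.Sequential(
        nn.Conv3d(in_ch, n_classes, kernel_size=1),
        nn.Upsample(scale_factor=scale, mode='trilinear', align_corners=False),
    )

up4 = make_dsv(128, 8)(torch.randn(1, 128, 4, 4, 4))       # decoder stage 4 (dummy)
up3 = make_dsv(64, 4)(torch.randn(1, 64, 8, 8, 8))         # decoder stage 3 (dummy)
up2 = make_dsv(32, 2)(torch.randn(1, 32, 16, 16, 16))      # decoder stage 2 (dummy)
up1 = nn.Conv3d(16, n_classes, 1)(torch.randn(1, 16, 32, 32, 32))  # full resolution
final = nn.Conv3d(n_classes * 4, n_classes, 1)
out = final(torch.cat([up1, up2, up3, up4], dim=1))
print(out.shape)  # torch.Size([1, 2, 32, 32, 32])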