Code Example #1
    def __init__(self, in_size, out_size, is_batchnorm):
        super(UnetGatingSignal3, self).__init__()
        self.fmap_size = (4, 4, 4)

        if is_batchnorm:
            self.conv1 = nn.Sequential(
                nn.Conv3d(in_size, in_size // 2, (1, 1, 1), (1, 1, 1),
                          (0, 0, 0)),
                nn.BatchNorm3d(in_size // 2),
                nn.ReLU(inplace=True),
                nn.AdaptiveAvgPool3d(output_size=self.fmap_size),
            )
            self.fc1 = nn.Linear(in_features=(in_size // 2) *
                                 self.fmap_size[0] * self.fmap_size[1] *
                                 self.fmap_size[2],
                                 out_features=out_size,
                                 bias=True)
        else:
            self.conv1 = nn.Sequential(
                nn.Conv3d(in_size, in_size // 2, (1, 1, 1), (1, 1, 1),
                          (0, 0, 0)),
                nn.ReLU(inplace=True),
                nn.AdaptiveAvgPool3d(output_size=self.fmap_size),
            )
            self.fc1 = nn.Linear(in_features=(in_size // 2) *
                                 self.fmap_size[0] * self.fmap_size[1] *
                                 self.fmap_size[2],
                                 out_features=out_size,
                                 bias=True)

        # initialise the blocks
        for m in self.children():
            init_weights(m, init_type='kaiming')
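Every snippet in this listing delegates the actual initialisation to an init_weights(m, init_type='kaiming') helper that is imported from elsewhere in the project and is not shown here. As rough orientation only, here is a minimal sketch of what such a helper typically looks like; the names and details below are assumptions, not the project's actual code:

from torch.nn import init

def init_weights(net, init_type='kaiming'):
    # Hypothetical sketch: recursively visit every submodule and apply the
    # chosen scheme; conv/linear weights get Kaiming init, norm layers get a
    # near-identity init. The project's real helper may differ.
    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and ('Conv' in classname or 'Linear' in classname):
            if init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            else:
                init.normal_(m.weight.data, 0.0, 0.02)
            if getattr(m, 'bias', None) is not None:
                init.constant_(m.bias.data, 0.0)
        elif 'BatchNorm' in classname or 'GroupNorm' in classname:
            init.normal_(m.weight.data, 1.0, 0.02)
            init.constant_(m.bias.data, 0.0)
    net.apply(init_func)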
Code Example #2
File: utils.py Project: Myyyr/segmentation3D
    def __init__(self,
                 in_size,
                 out_size,
                 kernel=(3, 3),
                 pad=(1, 1),
                 stride=(1, 1),
                 bn=True):
        super(UNetConv2D, self).__init__()
        if bn:
            self.conv1 = nn.Sequential(
                nn.Conv2d(in_size, out_size, kernel, stride, pad),
                nn.BatchNorm2d(out_size),
                nn.ReLU(inplace=True),
            )
            self.conv2 = nn.Sequential(
                nn.Conv2d(out_size, out_size, kernel, 1, pad),
                nn.BatchNorm2d(out_size),
                nn.ReLU(inplace=True),
            )
        else:
            self.conv1 = nn.Sequential(
                nn.Conv2d(in_size, out_size, kernel, stride, pad),
                nn.ReLU(inplace=True),
            )
            self.conv2 = nn.Sequential(
                nn.Conv2d(out_size, out_size, kernel, 1, pad),
                nn.ReLU(inplace=True),
            )

        #initialise the blocks
        for m in self.children():
            init_weights(m, init_type='kaiming')
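The forward pass of UNetConv2D is not part of this listing, but since the block only registers conv1 and conv2, a quick smoke test can chain them by hand. The shapes below assume the usual conv1-then-conv2 ordering and that init_weights is importable:

import torch

block = UNetConv2D(in_size=3, out_size=16, kernel=(3, 3), pad=(1, 1), stride=(1, 1), bn=True)
x = torch.randn(2, 3, 64, 64)        # (batch, channels, H, W)
y = block.conv2(block.conv1(x))      # 3x3 kernels with padding 1 keep H and W
print(y.shape)                       # torch.Size([2, 16, 64, 64])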
Code Example #3
    def __init__(self,
                 in_size,
                 out_size,
                 is_batchnorm,
                 kernel_size=(3, 3, 1),
                 padding_size=(1, 1, 0),
                 init_stride=(1, 1, 1)):
        super(UnetConv3, self).__init__()

        if is_batchnorm:
            self.conv1 = nn.Sequential(
                nn.Conv3d(in_size, out_size, kernel_size, init_stride,
                          padding_size),
                nn.BatchNorm3d(out_size),
                nn.ReLU(inplace=True),
            )
            self.conv2 = nn.Sequential(
                nn.Conv3d(out_size, out_size, kernel_size, 1, padding_size),
                nn.BatchNorm3d(out_size),
                nn.ReLU(inplace=True),
            )
        else:
            self.conv1 = nn.Sequential(
                nn.Conv3d(in_size, out_size, kernel_size, init_stride,
                          padding_size),
                nn.ReLU(inplace=True),
            )
            self.conv2 = nn.Sequential(
                nn.Conv3d(out_size, out_size, kernel_size, 1, padding_size),
                nn.ReLU(inplace=True),
            )

        # initialise the blocks
        for m in self.children():
            init_weights(m, init_type='kaiming')
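Note the default kernel_size=(3, 3, 1) with padding_size=(1, 1, 0): the block convolves only the first two spatial axes and leaves the last one (typically the slice axis) untouched, which matches the (2, 2, 1) max-pooling used by several of the 3D networks later in this listing. A quick shape check, again assuming a conv1-then-conv2 forward and an importable init_weights:

import torch

block = UnetConv3(in_size=1, out_size=16, is_batchnorm=True)
x = torch.randn(1, 1, 64, 64, 12)    # 5D input; the last axis sees a size-1 kernel
y = block.conv2(block.conv1(x))      # 3x3 in-plane, last axis left unchanged
print(y.shape)                       # torch.Size([1, 16, 64, 64, 12])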
Code Example #4
    def __init__(self, d_model, n_heads=2):
        super(CrossAttention, self).__init__()

        self.d_model = d_model
        self.n_heads = n_heads

        self.norm1 = nn.LayerNorm(d_model, eps=1e-5)

        self.all_w = []
        for i in range(self.n_heads):
            wq = nn.Linear(self.d_model, self.d_model, bias=False)
            wk = nn.Linear(self.d_model, self.d_model, bias=False)
            wv = nn.Linear(self.d_model, self.d_model, bias=False)
            self.all_w.append(nn.ModuleList([wq, wk, wv]))
        self.all_w = nn.ModuleList(self.all_w)

        self.wo = nn.Linear(self.d_model * self.n_heads,
                            self.d_model,
                            bias=False)
        self.norm2 = nn.LayerNorm(d_model, eps=1e-5)

        self.feed_forward = nn.Linear(d_model, d_model)

        #initialise the blocks
        for m in self.children():
            init_weights(m, init_type='kaiming')
Code Example #5
    def __init__(self, in_size, gate_size, inter_size, nonlocal_mode,
                 sub_sample_factor):
        super(MultiAttentionBlock, self).__init__()
        self.gate_block_1 = GridAttentionBlock3D(
            in_channels=in_size,
            gating_channels=gate_size,
            inter_channels=inter_size,
            mode=nonlocal_mode,
            sub_sample_factor=sub_sample_factor)
        self.gate_block_2 = GridAttentionBlock3D(
            in_channels=in_size,
            gating_channels=gate_size,
            inter_channels=inter_size,
            mode=nonlocal_mode,
            sub_sample_factor=sub_sample_factor)
        self.combine_gates = nn.Sequential(
            nn.Conv3d(in_size * 2, in_size, kernel_size=1,
                      stride=1, padding=0), nn.BatchNorm3d(in_size),
            nn.ReLU(inplace=True))

        # initialise the blocks
        for m in self.children():
            if m.__class__.__name__.find('GridAttentionBlock3D') != -1:
                continue
            init_weights(m, init_type='kaiming')
Code Example #6
    def __init__(self, in_size, out_size, is_batchnorm, n=2, ks=3, stride=1, padding=1):
        super(unetConv2, self).__init__()
        self.n = n
        self.ks = ks
        self.stride = stride
        self.padding = padding
        s = stride
        p = padding
        if is_batchnorm:
            for i in range(1, n+1):
                conv = nn.Sequential(nn.Conv2d(in_size, out_size, ks, s, p),
                                     nn.BatchNorm2d(out_size),
                                     nn.ReLU(inplace=True),)
                setattr(self, 'conv%d'%i, conv)
                in_size = out_size

        else:
            for i in range(1, n+1):
                conv = nn.Sequential(nn.Conv2d(in_size, out_size, ks, s, p),
                                     nn.ReLU(inplace=True),)
                setattr(self, 'conv%d'%i, conv)
                in_size = out_size

        # initialise the blocks
        for m in self.children():
            init_weights(m, init_type='kaiming')
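unetConv2 registers its conv blocks dynamically as conv1 ... convN via setattr, so downstream code (the forward pass is not shown here) conventionally retrieves them with getattr and chains them in order. A hedged usage sketch, assuming exactly that pattern and an importable init_weights:

import torch

block = unetConv2(in_size=1, out_size=8, is_batchnorm=True, n=2, ks=3, stride=1, padding=1)
x = torch.randn(1, 1, 32, 32)
for i in range(1, block.n + 1):      # chain conv1, conv2, ... in order
    x = getattr(block, 'conv%d' % i)(x)
print(x.shape)                       # torch.Size([1, 8, 32, 32])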
Code Example #7
File: utils.py Project: Myyyr/segmentation3D
    def __init__(self, d_model, filters, mode="deconv", bn=True):
        super(UNETRSkip, self).__init__()

        self.d_model = d_model
        self.filters = [d_model] + filters
        self.n_module = len(filters)
        self.bn = bn

        self.module_list = []
        for i in range(self.n_module):
            l = [
                nn.Conv3d(self.filters[i],
                          self.filters[i + 1],
                          kernel_size=3,
                          padding=1)
            ]
            if bn: l.append(nn.BatchNorm3d(self.filters[i + 1]))
            l.append(nn.ReLU(inplace=True))
            l.append(
                nn.ConvTranspose3d(self.filters[i + 1],
                                   self.filters[i + 1], (2, 2, 2),
                                   stride=(2, 2, 2)))

            self.module_list.append(nn.Sequential(*l))
        self.module_list = nn.Sequential(*self.module_list)

        for m in self.children():
            if isinstance(m, nn.Conv3d):
                init_weights(m, init_type='kaiming')
            elif isinstance(m, nn.ConvTranspose3d):
                init_weights(m, init_type='kaiming')
Code Example #8
    def __init__(self, in_size, out_size, is_batchnorm=True, depth=3):
        super(UnetUp3_dense_CT, self).__init__()
        self.conv = UnetConv3(in_size + out_size, out_size, is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1))
        self.up = nn.Upsample(scale_factor=(2, 2, 2), mode='trilinear')

        self.relu = nn.ReLU(inplace=False)
        self.is_batchnorm = is_batchnorm

        # Depth of the dense block
        # -----------------------------------
        if depth < 1:
            depth = 1  # Minimal depth is 1
        self.dense_depth = depth

        # Initial convolution operation on the input data
        self.init_conv = nn.Conv3d(in_size + out_size, out_size, kernel_size=(3,3,3), stride=(1,1,1), padding=(1,1,1))

        # Define further convolutions in a for loop; nn.ModuleDict registers them
        # as submodules, so they are reached by .parameters(), .to(device) and the
        # init loop below
        self.ops_dict = nn.ModuleDict()
        for ii in range(depth - 1):
            name = "op_{}".format(ii + 1)
            self.ops_dict[name] = nn.Conv3d((ii + 1) * out_size, out_size, kernel_size=(3,3,3), stride=(1,1,1), padding=(1,1,1))

        self.batchnorm = nn.BatchNorm3d(out_size, track_running_stats=True)

        # initialise the blocks
        for m in self.children():
            if m.__class__.__name__.find('UnetConv3') != -1: continue
            init_weights(m, init_type='kaiming')
Code Example #9
    def __init__(self, channels, groups):
        super(ResidualInner, self).__init__()
        # self.gn = nn.BatchNorm3d(channels)
        self.gn = nn.GroupNorm(groups, channels)
        self.conv = nn.Conv3d(channels, channels, 3, padding=1, bias=False)

        for m in self.children():
            init_weights(m, init_type='kaiming')
Code Example #10
    def __init__(self,
                 feature_scale=4,
                 n_classes=21,
                 in_channels=3,
                 is_batchnorm=True,
                 n_convs=None):
        super(sononet, self).__init__()
        self.in_channels = in_channels
        self.is_batchnorm = is_batchnorm
        self.feature_scale = feature_scale
        self.n_classes = n_classes

        filters = [64, 128, 256, 512]
        filters = [int(x / self.feature_scale) for x in filters]

        if n_convs is None:
            n_convs = [2, 2, 3, 3, 3]

        # downsampling
        self.conv1 = unetConv2(self.in_channels,
                               filters[0],
                               self.is_batchnorm,
                               n=n_convs[0])
        self.maxpool1 = nn.MaxPool2d(kernel_size=2)

        self.conv2 = unetConv2(filters[0],
                               filters[1],
                               self.is_batchnorm,
                               n=n_convs[1])
        self.maxpool2 = nn.MaxPool2d(kernel_size=2)

        self.conv3 = unetConv2(filters[1],
                               filters[2],
                               self.is_batchnorm,
                               n=n_convs[2])
        self.maxpool3 = nn.MaxPool2d(kernel_size=2)

        self.conv4 = unetConv2(filters[2],
                               filters[3],
                               self.is_batchnorm,
                               n=n_convs[3])
        self.maxpool4 = nn.MaxPool2d(kernel_size=2)

        self.conv5 = unetConv2(filters[3],
                               filters[3],
                               self.is_batchnorm,
                               n=n_convs[4])

        # adaptation layer
        self.conv5_p = conv2DBatchNormRelu(filters[3], filters[2], 1, 1, 0)
        self.conv6_p = conv2DBatchNorm(filters[2], self.n_classes, 1, 1, 0)

        # initialise weights
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init_weights(m, init_type='kaiming')
            elif isinstance(m, nn.BatchNorm2d):
                init_weights(m, init_type='kaiming')
Code Example #11
    def __init__(self, in_size, out_size, is_batchnorm=True):
        super(UnetUp3_CT, self).__init__()
        self.conv = UnetConv3(in_size + out_size, out_size, is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1))
        self.up = nn.Upsample(scale_factor=(2, 2, 2), mode='trilinear')

        # initialise the blocks
        for m in self.children():
            if m.__class__.__name__.find('UnetConv3') != -1: continue
            init_weights(m, init_type='kaiming')
Code Example #12
    def __init__(self, inChannels, outChannels, depth, upsample=True):
        super(DecoderModule, self).__init__()
        self.reversibleBlocks = makeReversibleComponent(inChannels, depth)
        self.upsample = upsample
        if self.upsample:
            self.conv = nn.Conv3d(inChannels, outChannels, 1)
        for m in self.children():
            if isinstance(m, nn.Conv3d):
                init_weights(m, init_type='kaiming')
Code Example #13
    def __init__(self, in_size, out_size, init_kernel=3, init_padding=1, is_batchnorm=True):
        super(UnetUp3_CT_deformable, self).__init__()
        self.conv = UnetConv3_deformable(in_size + out_size, out_size, is_batchnorm, kernel_size=init_kernel, padding_size=init_padding)
        self.up = nn.Upsample(scale_factor=(2, 2, 2), mode='trilinear', align_corners=False)

        # initialise the blocks
        for m in self.children():
            if m.__class__.__name__.find('UnetConv3_deformable') != -1: continue
            init_weights(m, init_type='kaiming')
Code Example #14
File: UNet.py Project: Myyyr/segmentation3D
    def __init__(self,
                 filters,
                 n_classes=2,
                 in_channels=1,
                 dim='2d',
                 bn=True,
                 up_mode='biline'):
        super(UNet, self).__init__()

        self.in_channels = in_channels
        self.dim = dim
        self.filters = filters
        self.UNetConv = {'2d': UNetConv2D, '3d': UNetConv3D}[self.dim]
        self.UNetUpLayer = {'2d': UnetUp2D, '3d': UnetUp3D}[self.dim]
        self.maxpool = {'2d': nn.MaxPool2d, '3d': nn.MaxPool3d}[self.dim]
        self.final_layer = {'2d': nn.Conv2d, '3d': nn.Conv3d}[self.dim]
        # encoder
        self.conv1 = self.UNetConv(self.in_channels, filters[0], bn=bn)
        self.maxpool1 = self.maxpool(kernel_size=2)

        self.conv2 = self.UNetConv(filters[0], filters[1], bn=bn)
        self.maxpool2 = self.maxpool(kernel_size=2)

        self.conv3 = self.UNetConv(filters[1], filters[2], bn=bn)
        self.maxpool3 = self.maxpool(kernel_size=2)

        self.conv4 = self.UNetConv(filters[2], filters[3], bn=bn)
        self.maxpool4 = self.maxpool(kernel_size=2)

        self.center = self.UNetConv(filters[3], filters[4], bn=bn)

        # upsampling
        self.up_concat4 = self.UNetUpLayer(filters[4],
                                           filters[3],
                                           bn=bn,
                                           up_mode=up_mode)
        self.up_concat3 = self.UNetUpLayer(filters[3],
                                           filters[2],
                                           bn=bn,
                                           up_mode=up_mode)
        self.up_concat2 = self.UNetUpLayer(filters[2],
                                           filters[1],
                                           bn=bn,
                                           up_mode=up_mode)
        self.up_concat1 = self.UNetUpLayer(filters[1],
                                           filters[0],
                                           bn=bn,
                                           up_mode=up_mode)

        # final conv (without any concat)
        self.final = self.final_layer(filters[0], n_classes, 1)

        # initialise weights
        for m in self.modules():
            if isinstance(m, self.final_layer):
                init_weights(m, init_type='kaiming')
Code Example #15
    def __init__(self, feature_scale=4, n_classes=21, is_deconv=True, in_channels=3,
                 nonlocal_mode='concatenation', attention_dsample=(2,2,2), is_batchnorm=True):
        super(unet_CT_dense_multi_att_dsv_3D, self).__init__()
        self.is_deconv = is_deconv
        self.in_channels = in_channels
        self.is_batchnorm = is_batchnorm
        self.feature_scale = feature_scale

        filters = [64, 128, 256, 512, 1024]
        filters = [int(x / self.feature_scale) for x in filters]

        # downsampling
        self.conv1 = DenseBlock3D(self.in_channels, filters[0], self.is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1))
        self.maxpool1 = nn.MaxPool3d(kernel_size=(2, 2, 2))

        self.conv2 = DenseBlock3D(filters[0], filters[1], self.is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1))
        self.maxpool2 = nn.MaxPool3d(kernel_size=(2, 2, 2))

        self.conv3 = DenseBlock3D(filters[1], filters[2], self.is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1))
        self.maxpool3 = nn.MaxPool3d(kernel_size=(2, 2, 2))

        self.conv4 = DenseBlock3D(filters[2], filters[3], self.is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1))
        self.maxpool4 = nn.MaxPool3d(kernel_size=(2, 2, 2))

        self.center = DenseBlock3D(filters[3], filters[4], self.is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1))
        self.gating = UnetGridGatingSignal3(filters[4], filters[4], kernel_size=(1, 1, 1), is_batchnorm=self.is_batchnorm)

        # attention blocks
        self.attentionblock2 = MultiAttentionBlock(in_size=filters[1], gate_size=filters[2], inter_size=filters[1],
                                                   nonlocal_mode=nonlocal_mode, sub_sample_factor= attention_dsample)
        self.attentionblock3 = MultiAttentionBlock(in_size=filters[2], gate_size=filters[3], inter_size=filters[2],
                                                   nonlocal_mode=nonlocal_mode, sub_sample_factor= attention_dsample)
        self.attentionblock4 = MultiAttentionBlock(in_size=filters[3], gate_size=filters[4], inter_size=filters[3],
                                                   nonlocal_mode=nonlocal_mode, sub_sample_factor= attention_dsample)

        # upsampling
        self.up_concat4 = UnetUp3_dense_CT(filters[4], filters[3], is_batchnorm)
        self.up_concat3 = UnetUp3_dense_CT(filters[3], filters[2], is_batchnorm)
        self.up_concat2 = UnetUp3_dense_CT(filters[2], filters[1], is_batchnorm)
        self.up_concat1 = UnetUp3_dense_CT(filters[1], filters[0], is_batchnorm)

        # deep supervision
        self.dsv4 = UnetDsv3(in_size=filters[3], out_size=n_classes, scale_factor=8)
        self.dsv3 = UnetDsv3(in_size=filters[2], out_size=n_classes, scale_factor=4)
        self.dsv2 = UnetDsv3(in_size=filters[1], out_size=n_classes, scale_factor=2)
        self.dsv1 = nn.Conv3d(in_channels=filters[0], out_channels=n_classes, kernel_size=1)

        # final conv (without any concat)
        self.final = nn.Conv3d(n_classes*4, n_classes, 1)

        # initialise weights
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                init_weights(m, init_type='kaiming')
            elif isinstance(m, nn.BatchNorm3d):
                init_weights(m, init_type='kaiming')
Code Example #16
    def __init__(self,
                 n_classes,
                 feature_scale=4,
                 is_deconv=True,
                 in_channels=3,
                 is_batchnorm=True,
                 nonlocal_mode='embedded_gaussian',
                 nonlocal_sf=4):
        super(unet_nonlocal_3D, self).__init__()
        self.is_deconv = is_deconv
        self.in_channels = in_channels
        self.is_batchnorm = is_batchnorm
        self.feature_scale = feature_scale

        filters = [64, 128, 256, 512, 1024]
        filters = [int(x / self.feature_scale) for x in filters]

        # downsampling
        self.conv1 = UnetConv3(self.in_channels, filters[0], self.is_batchnorm)
        self.maxpool1 = nn.MaxPool3d(kernel_size=(2, 2, 1))

        self.conv2 = UnetConv3(filters[0], filters[1], self.is_batchnorm)
        self.nonlocal2 = NONLocalBlock3D(in_channels=filters[1],
                                         inter_channels=filters[1] // 4,
                                         sub_sample_factor=nonlocal_sf,
                                         mode=nonlocal_mode)
        self.maxpool2 = nn.MaxPool3d(kernel_size=(2, 2, 1))

        self.conv3 = UnetConv3(filters[1], filters[2], self.is_batchnorm)
        self.nonlocal3 = NONLocalBlock3D(in_channels=filters[2],
                                         inter_channels=filters[2] // 4,
                                         sub_sample_factor=nonlocal_sf,
                                         mode=nonlocal_mode)
        self.maxpool3 = nn.MaxPool3d(kernel_size=(2, 2, 1))

        self.conv4 = UnetConv3(filters[2], filters[3], self.is_batchnorm)
        self.maxpool4 = nn.MaxPool3d(kernel_size=(2, 2, 1))

        self.center = UnetConv3(filters[3], filters[4], self.is_batchnorm)

        # upsampling
        self.up_concat4 = UnetUp3(filters[4], filters[3], self.is_deconv)
        self.up_concat3 = UnetUp3(filters[3], filters[2], self.is_deconv)
        self.up_concat2 = UnetUp3(filters[2], filters[1], self.is_deconv)
        self.up_concat1 = UnetUp3(filters[1], filters[0], self.is_deconv)

        # final conv (without any concat)
        self.final = nn.Conv3d(filters[0], n_classes, 1)

        # initialise weights
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                init_weights(m, init_type='kaiming')
            elif isinstance(m, nn.BatchNorm3d):
                init_weights(m, init_type='kaiming')
Code Example #17
    def __init__(self, base_model, filters=512, n_classes=14):
        super(DebugCrossPatch3DTr, self).__init__()

        self.base_model = base_model
        # print(self.base_model)
        self.final_conv = nn.Conv3d(filters, n_classes, 1)

        # initialise weights
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                init_weights(m, init_type='kaiming')
Code Example #18
    def __init__(self, in_size, out_size, is_deconv):
        super(unetUp, self).__init__()
        self.conv = unetConv2(in_size, out_size, is_batchnorm=False, n=1)
        if is_deconv:
            self.up = nn.Sequential(nn.Upsample(scale_factor=2),
                                    RRCNN_block(in_size, out_size))

        # initialise the blocks
        for m in self.children():
            if m.__class__.__name__.find('unetConv2') != -1: continue
            init_weights(m, init_type='kaiming')
Code Example #19
    def __init__(self, in_size, out_size, is_deconv):
        super(unetUp, self).__init__()
        self.conv = unetConv2(in_size, out_size, False)
        if is_deconv:
            self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=4, stride=2, padding=1)
        else:
            self.up = nn.UpsamplingBilinear2d(scale_factor=2)

        # initialise the blocks
        for m in self.children():
            if m.__class__.__name__.find('unetConv2') != -1: continue
            init_weights(m, init_type='kaiming')
Code Example #20
    def __init__(self, in_size, out_size, is_deconv, is_batchnorm=True):
        super(UnetUp3, self).__init__()
        if is_deconv:
            self.conv = UnetConv3(in_size, out_size, is_batchnorm)
            self.up = nn.ConvTranspose3d(in_size, out_size, kernel_size=(4,4,1), stride=(2,2,1), padding=(1,1,0))
        else:
            self.conv = UnetConv3(in_size+out_size, out_size, is_batchnorm)
            self.up = nn.Upsample(scale_factor=(2, 2, 1), mode='trilinear')

        # initialise the blocks
        for m in self.children():
            if m.__class__.__name__.find('UnetConv3') != -1: continue
            init_weights(m, init_type='kaiming')
Code Example #21
    def __init__(self, feature_scale=4, n_classes=21, is_deconv=True, in_channels=3,
                 nonlocal_mode='concatenation', attention_dsample=(2,2,2), is_batchnorm=True):
        super(unet_grid_attention_3D, self).__init__()
        self.is_deconv = is_deconv
        self.in_channels = in_channels
        self.is_batchnorm = is_batchnorm
        self.feature_scale = feature_scale

        filters = [64, 128, 256, 512, 1024]
        filters = [int(x / self.feature_scale) for x in filters]

        # downsampling
        self.conv1 = UnetConv3(self.in_channels, filters[0], self.is_batchnorm)
        self.maxpool1 = nn.MaxPool3d(kernel_size=(2, 2, 1))

        self.conv2 = UnetConv3(filters[0], filters[1], self.is_batchnorm)
        self.maxpool2 = nn.MaxPool3d(kernel_size=(2, 2, 1))

        self.conv3 = UnetConv3(filters[1], filters[2], self.is_batchnorm)
        self.maxpool3 = nn.MaxPool3d(kernel_size=(2, 2, 1))

        self.conv4 = UnetConv3(filters[2], filters[3], self.is_batchnorm)
        self.maxpool4 = nn.MaxPool3d(kernel_size=(2, 2, 1))

        self.center = UnetConv3(filters[3], filters[4], self.is_batchnorm)
        self.gating = UnetGridGatingSignal3(filters[4], filters[3], kernel_size=(1, 1, 1), is_batchnorm=self.is_batchnorm)

        # attention blocks
        self.attentionblock2 = GridAttentionBlock3D(in_channels=filters[1], gating_channels=filters[3],
                                                    inter_channels=filters[1], sub_sample_factor=attention_dsample, mode=nonlocal_mode)
        self.attentionblock3 = GridAttentionBlock3D(in_channels=filters[2], gating_channels=filters[3],
                                                    inter_channels=filters[2], sub_sample_factor=attention_dsample, mode=nonlocal_mode)
        self.attentionblock4 = GridAttentionBlock3D(in_channels=filters[3], gating_channels=filters[3],
                                                    inter_channels=filters[3], sub_sample_factor=attention_dsample, mode=nonlocal_mode)

        # upsampling
        self.up_concat4 = UnetUp3(filters[4], filters[3], self.is_deconv, self.is_batchnorm)
        self.up_concat3 = UnetUp3(filters[3], filters[2], self.is_deconv, self.is_batchnorm)
        self.up_concat2 = UnetUp3(filters[2], filters[1], self.is_deconv, self.is_batchnorm)
        self.up_concat1 = UnetUp3(filters[1], filters[0], self.is_deconv, self.is_batchnorm)

        # final conv (without any concat)
        self.final = nn.Conv3d(filters[0], n_classes, 1)

        # initialise weights
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                init_weights(m, init_type='kaiming')
            elif isinstance(m, nn.BatchNorm3d):
                init_weights(m, init_type='kaiming')
Code Example #22
    def __init__(self,
                 filters=[16, 32, 64, 128, 256],
                 patch_size=[1, 1, 1],
                 d_model=256,
                 n_classes=14,
                 in_channels=1,
                 n_cheads=2,
                 n_sheads=8,
                 bn=True,
                 up_mode='deconv',
                 n_strans=6,
                 do_cross=False):
        super(CrossPatch3DTr, self).__init__()
        self.PE = None

        self.in_channels = in_channels
        self.filters = filters
        self.n_sheads = n_sheads
        self.d_model = d_model
        self.patch_size = patch_size
        self.do_cross = do_cross

        # CNN + Trans encoder
        self.encoder = SelfTransEncoder(filters=filters, patch_size=patch_size, d_model=d_model, in_channels=in_channels, n_sheads=n_sheads, bn=bn, n_strans=n_strans)


        # Transformer for cross attention
        # self.avgpool = nn.AvgPool3d((4,4,2), (4,4,2))
        # self.positional_encoder = PositionalEncoding(self.d_model, dropout=0.1, max_len = 20000)
        self.p_enc_3d = PositionalEncodingPermute3D(filters[-1])
        self.cross_trans = CrossAttention(self.d_model, n_cheads)


        # CNN decoder 
        self.before_d_model = filters[4]*np.prod(self.patch_size)
        ## Rescale progressively feature map from cross attention
        # a = int(self.before_d_model/self.patch_size[0])
        # b = int(a/self.patch_size[1])
        # c = int(b/self.patch_size[2])
        # self.center = nn.Sequential(nn.ConvTranspose3d(self.d_model, a, 2, stride=2),
        #                             nn.Conv3d(a,b, 3, padding=1),
        #                             nn.Conv3d(b,c, 3, padding=1))

        ## Decode like 3D UNet
        self.up_concat4 = UnetUp3D(filters[4], filters[3], bn=bn, up_mode=up_mode)
        self.up_concat3 = UnetUp3D(filters[3], filters[2], bn=bn, up_mode=up_mode)
        self.up_concat2 = UnetUp3D(filters[2], filters[1], bn=bn, up_mode=up_mode)
        self.up_concat1 = UnetUp3D(filters[1], filters[0], bn=bn, up_mode=up_mode)
        

        self.final_conv = nn.Conv3d(filters[0], n_classes, 1)

        # Deep Supervision
        self.ds_cv1 = nn.Conv3d(filters[3], n_classes, 1)
        self.ds_cv2 = nn.Conv3d(filters[2], n_classes, 1)
        self.ds_cv3 = nn.Conv3d(filters[1], n_classes, 1)

        # initialise weights
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                init_weights(m, init_type='kaiming')
Code Example #23
    def __init__(self,
                 feature_scale=4,
                 n_classes=21,
                 is_deconv=False,
                 in_channels=3,
                 is_batchnorm=True):
        super(cloudSegNet, self).__init__()
        self.is_deconv = is_deconv
        self.in_channels = in_channels
        self.is_batchnorm = is_batchnorm
        self.feature_scale = feature_scale

        filters = [16, 8, 8, 8, 8]

        # downsampling
        self.down1 = unetConv2(self.in_channels,
                               filters[0],
                               self.is_batchnorm,
                               n=1)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2)

        self.down2 = unetConv2(filters[0], filters[1], self.is_batchnorm, n=1)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2)

        self.down3 = unetConv2(filters[1], filters[2], self.is_batchnorm, n=1)
        self.maxpool3 = nn.MaxPool2d(kernel_size=2)

        # upsampling
        self.up_concat3 = cloudSegNetUp(filters[3],
                                        filters[2],
                                        self.is_deconv,
                                        n=1)
        self.up_concat2 = cloudSegNetUp(filters[2],
                                        filters[1],
                                        self.is_deconv,
                                        n=1)
        self.up_concat1 = cloudSegNetUp(filters[1],
                                        filters[0],
                                        self.is_deconv,
                                        n=1)

        # final conv (without any concat)
        self.final = nn.Conv2d(filters[0], n_classes, kernel_size=5, padding=2)

        # initialise weights
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init_weights(m, init_type='kaiming')
            elif isinstance(m, nn.BatchNorm2d):
                init_weights(m, init_type='kaiming')
Code Example #24
    def reinit_decoder(self):
        ## Decode like 3D UNet
        del self.up_concat4, self.up_concat3, self.up_concat2, self.up_concat1
        self.up_concat4 = UnetUp3D(self.filters[4],
                                   self.filters[3],
                                   bn=self.bn,
                                   up_mode=self.up_mode)
        self.up_concat3 = UnetUp3D(self.filters[3],
                                   self.filters[2],
                                   bn=self.bn,
                                   up_mode=self.up_mode)
        self.up_concat2 = UnetUp3D(self.filters[2],
                                   self.filters[1],
                                   bn=self.bn,
                                   up_mode=self.up_mode)
        self.up_concat1 = UnetUp3D(self.filters[1],
                                   self.filters[0],
                                   bn=self.bn,
                                   up_mode=self.up_mode)

        del self.final_conv
        self.final_conv = nn.Conv3d(self.filters[0], self.n_classes, 1)

        # Deep Supervision
        del self.ds_cv1, self.ds_cv2, self.ds_cv3
        self.ds_cv1 = nn.Conv3d(self.filters[3], self.n_classes, 1)
        self.ds_cv2 = nn.Conv3d(self.filters[2], self.n_classes, 1)
        self.ds_cv3 = nn.Conv3d(self.filters[1], self.n_classes, 1)

        init_weights(self.final_conv, init_type='kaiming')
        init_weights(self.ds_cv1, init_type='kaiming')
        init_weights(self.ds_cv2, init_type='kaiming')
        init_weights(self.ds_cv3, init_type='kaiming')
Code Example #25
    def __init__(self,
                 feature_scale=4,
                 n_classes=21,
                 is_deconv=True,
                 in_channels=3,
                 is_batchnorm=True):
        super(unet_3D, self).__init__()
        self.is_deconv = is_deconv
        self.in_channels = in_channels
        self.is_batchnorm = is_batchnorm
        self.feature_scale = feature_scale

        filters = [64, 128, 256, 512, 1024]
        filters = [int(x / self.feature_scale) for x in filters]

        # downsampling
        self.conv1 = UnetConv3(self.in_channels, filters[0], self.is_batchnorm)
        self.maxpool1 = nn.MaxPool3d(kernel_size=(2, 2, 1))

        self.conv2 = UnetConv3(filters[0], filters[1], self.is_batchnorm)
        self.maxpool2 = nn.MaxPool3d(kernel_size=(2, 2, 1))

        self.conv3 = UnetConv3(filters[1], filters[2], self.is_batchnorm)
        self.maxpool3 = nn.MaxPool3d(kernel_size=(2, 2, 1))

        self.conv4 = UnetConv3(filters[2], filters[3], self.is_batchnorm)
        self.maxpool4 = nn.MaxPool3d(kernel_size=(2, 2, 1))

        self.center = UnetConv3(filters[3], filters[4], self.is_batchnorm)

        # upsampling
        self.up_concat4 = UnetUp3(filters[4], filters[3], self.is_deconv,
                                  is_batchnorm)
        self.up_concat3 = UnetUp3(filters[3], filters[2], self.is_deconv,
                                  is_batchnorm)
        self.up_concat2 = UnetUp3(filters[2], filters[1], self.is_deconv,
                                  is_batchnorm)
        self.up_concat1 = UnetUp3(filters[1], filters[0], self.is_deconv,
                                  is_batchnorm)

        # final conv (without any concat)
        self.final = nn.Conv3d(filters[0], n_classes, 1)

        # initialise weights
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                init_weights(m, init_type='kaiming')
            elif isinstance(m, nn.BatchNorm3d):
                init_weights(m, init_type='kaiming')
Code Example #26
    def __init__(self,
                 filters=[16, 32, 64, 128, 256, 512],
                 use_trans=[1, 1, 1, 1, 1, 1],
                 in_channels=1,
                 n_sheads=8,
                 bn=True,
                 n_strans=6):
        super(SelfTransEncoder, self).__init__()
        self.in_channels = in_channels
        self.filters = filters
        self.n_sheads = n_sheads
        self.use_trans = use_trans

        # CNN encoder
        # self.first_conv = nn.Conv3d(self.in_channels, filters[0], 1)
        # self.conv1 = UNetConv3D(filters[0], filters[0], bn=bn)
        self.conv1 = UNetConv3D(self.in_channels, filters[0], bn=bn)
        if use_trans[0]:
            self.trans1 = SimpleTransEncoder(filters[0], n_sheads, n_strans)

        self.maxpool2 = nn.MaxPool3d(kernel_size=2)
        self.conv2 = UNetConv3D(filters[0], filters[1], bn=bn)
        if use_trans[1]:
            self.trans2 = SimpleTransEncoder(filters[1], n_sheads, n_strans)

        self.maxpool3 = nn.MaxPool3d(kernel_size=2)
        self.conv3 = UNetConv3D(filters[1], filters[2], bn=bn)
        if use_trans[2]:
            self.trans3 = SimpleTransEncoder(filters[2], n_sheads, n_strans)

        self.maxpool4 = nn.MaxPool3d(kernel_size=2)
        self.conv4 = UNetConv3D(filters[2], filters[3], bn=bn)
        if use_trans[3]:
            self.trans4 = SimpleTransEncoder(filters[3], n_sheads, n_strans)

        self.maxpool5 = nn.MaxPool3d(kernel_size=2)
        self.conv5 = UNetConv3D(filters[3], filters[4], bn=bn)
        if use_trans[4]:
            self.trans5 = SimpleTransEncoder(filters[4], n_sheads, n_strans)

        self.maxpool6 = nn.MaxPool3d(kernel_size=(2, 2, 1))
        self.conv6 = UNetConv3D(filters[4], filters[5], bn=bn)
        if use_trans[5]:
            self.trans6 = SimpleTransEncoder(filters[5], n_sheads, n_strans)

        # initialise weights
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                init_weights(m, init_type='kaiming')
Code Example #27
File: revunet_3D.py Project: Myyyr/segmentation3D
    def __init__(self,
                 inChannels,
                 outChannels,
                 depth,
                 downsample=True,
                 groups=2):
        super(EncoderModule, self).__init__()
        self.downsample = downsample
        if downsample:
            self.conv = nn.Conv3d(inChannels, outChannels, 1)
        self.reversibleBlocks = makeReversibleComponent(
            outChannels, depth, groups)
        for m in self.children():
            if isinstance(m, nn.Conv3d):
                init_weights(m, init_type='kaiming')
Code Example #28
    def __init__(self, inc, outc=[], kernel_size=3, padding=1, bias=None):
        super(DeformConv3D, self).__init__()
        self.kernel_size = kernel_size
        self.padding = padding
        #self.zero_padding = nn.functional.pad(padding)
        self.conv_kernel = nn.Conv3d(inc,
                                     outc,
                                     kernel_size=kernel_size,
                                     stride=kernel_size,
                                     bias=bias)

        # TODO: remove before uploading to GitHub
        # initialise the blocks
        for m in self.children():
            init_weights(m, init_type='kaiming')
Code Example #29
    def __init__(self,
                 filters=[16, 32, 64, 128, 256],
                 patch_size=[1, 1, 1],
                 d_model=256,
                 in_channels=1,
                 n_sheads=8,
                 bn=True,
                 n_strans=6):
        super(SelfTransEncoder, self).__init__()
        self.in_channels = in_channels
        self.filters = filters
        self.n_sheads = n_sheads
        self.d_model = d_model
        self.patch_size = patch_size

        # CNN encoder
        # self.first_conv = nn.Conv3d(self.in_channels, filters[0], 1)
        # self.conv1 = UNetConv3D(filters[0], filters[0], bn=bn)
        self.conv1 = UNetConv3D(self.in_channels, filters[0], bn=bn)

        self.maxpool2 = nn.MaxPool3d(kernel_size=2)
        self.conv2 = UNetConv3D(filters[0], filters[1], bn=bn)

        self.maxpool3 = nn.MaxPool3d(kernel_size=2)
        self.conv3 = UNetConv3D(filters[1], filters[2], bn=bn)

        self.maxpool4 = nn.MaxPool3d(kernel_size=2)
        self.conv4 = UNetConv3D(filters[2], filters[3], bn=bn)

        self.maxpool5 = nn.MaxPool3d(kernel_size=2)
        self.conv5 = UNetConv3D(filters[3], filters[4], bn=bn)

        # Transformer for self attention
        self.before_d_model = filters[4] * np.prod(self.patch_size)
        self.linear = nn.Linear(self.before_d_model, self.d_model)
        # self.positional_encoder = PositionalEncoding(self.d_model, dropout=0.1, max_len = 1000)
        # self.p_enc_3d = PositionalEncodingPermute3D(filters[-1])
        trans_layer = nn.TransformerEncoderLayer(d_model=self.d_model,
                                                 nhead=self.n_sheads)
        self.self_trans = nn.TransformerEncoder(trans_layer, n_strans)

        # Feed Forward projection
        self.last = nn.Linear(self.d_model, self.before_d_model)

        # initialise weights
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                init_weights(m, init_type='kaiming')
Code Example #30
    def __init__(self, in_size, out_size, kernel_size=(1,1,1), is_batchnorm=True):
        super(UnetGridGatingSignal3, self).__init__()

        if is_batchnorm:
            self.conv1 = nn.Sequential(nn.Conv3d(in_size, out_size, kernel_size, (1,1,1), (0,0,0)),
                                       nn.BatchNorm3d(out_size, track_running_stats=True), # OS
                                       nn.ReLU(inplace=True),
                                       )
        else:
            self.conv1 = nn.Sequential(nn.Conv3d(in_size, out_size, kernel_size, (1,1,1), (0,0,0)),
                                       nn.ReLU(inplace=True),
                                       )

        # initialise the blocks
        for m in self.children():
            init_weights(m, init_type='kaiming')
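Two initialisation idioms recur throughout these examples: the building blocks walk self.children() (immediate submodules only) and skip composite children such as UnetConv3 or GridAttentionBlock3D that already initialise themselves, while the full networks walk self.modules() (the entire module tree) and filter by layer type. The difference in what each iterator yields:

import torch.nn as nn

net = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1),
    nn.Sequential(nn.BatchNorm2d(8), nn.ReLU()),
)
print(len(list(net.children())))   # 2 -> only the immediate submodules
print(len(list(net.modules())))    # 5 -> net itself plus every nested module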