def __init__(self, in_size, gate_size, inter_size, nonlocal_mode, sub_sample_factor):
    """Create two identically configured grid-attention gates plus a fusion stage.

    Args:
        in_size: channels of the gated feature map; also the output width of
            ``combine_gates``.
        gate_size: channels of the gating signal given to both gates.
        inter_size: intermediate width used inside each gate.
        nonlocal_mode: ``mode`` forwarded to both GridAttentionBlock3D gates.
        sub_sample_factor: forwarded unchanged to both gates.
    """
    super(MultiAttentionBlock, self).__init__()

    # Both gates share one configuration (including the same
    # sub_sample_factor object).
    gate_cfg = dict(in_channels=in_size, gating_channels=gate_size,
                    inter_channels=inter_size, mode=nonlocal_mode,
                    sub_sample_factor=sub_sample_factor)
    self.gate_block_1 = GridAttentionBlock3D(**gate_cfg)
    self.gate_block_2 = GridAttentionBlock3D(**gate_cfg)

    # 1x1x1 conv from 2 * in_size back to in_size channels, presumably applied
    # to the channel-wise concatenation of both gates' outputs (TODO confirm
    # against forward).
    self.combine_gates = nn.Sequential(
        nn.Conv3d(in_size * 2, in_size, kernel_size=1, stride=1, padding=0),
        nn.BatchNorm3d(in_size),
        nn.ReLU(inplace=True),
    )

    # Kaiming-initialise every direct child except the attention gates, which
    # are left with whatever initialisation GridAttentionBlock3D applies.
    non_gate_children = (child for child in self.children()
                         if 'GridAttentionBlock3D' not in type(child).__name__)
    for child in non_gate_children:
        init_weights(child, init_type='kaiming')
def __init__(self, in_size, out_size, is_batchnorm, kernel_size=(3, 3, 1), padding_size=(1, 1, 0), init_stride=(1, 1, 1)):
    """Two stacked ``Conv3d -> [BatchNorm3d] -> ReLU`` stages (``conv1``, ``conv2``).

    Args:
        in_size: input channels of the first convolution.
        out_size: output channels of both convolutions.
        is_batchnorm: if truthy, a BatchNorm3d follows each convolution.
        kernel_size: kernel of both convolutions.
        padding_size: padding of both convolutions.
        init_stride: stride of the first convolution only; the second uses 1.
    """
    super(UnetConv3, self).__init__()

    def _stage(channels_in, stride):
        # Build one conv stage; layers are created in Conv -> BN -> ReLU order.
        layers = [nn.Conv3d(channels_in, out_size, kernel_size, stride, padding_size)]
        if is_batchnorm:
            layers.append(nn.BatchNorm3d(out_size))
        layers.append(nn.ReLU(inplace=True))
        return nn.Sequential(*layers)

    self.conv1 = _stage(in_size, init_stride)
    self.conv2 = _stage(out_size, 1)

    # initialise the blocks
    for child in self.children():
        init_weights(child, init_type='kaiming')
def __init__(self, feature_scale=4, n_classes=21, is_deconv=True, in_channels=3, is_batchnorm=True,
             nonlocal_mode='embedded_gaussian', nonlocal_sf=4):
    """Build a 3-D U-Net with non-local blocks after encoder stages 2 and 3.

    Args:
        feature_scale: divisor applied to the base widths (64, 128, 256, 512, 1024).
        n_classes: output channels of the final 1x1x1 convolution.
        is_deconv: forwarded to every UnetUp3 decoder stage.
        in_channels: channels of the input volume.
        is_batchnorm: whether the encoder AND decoder conv blocks use BatchNorm.
        nonlocal_mode: ``mode`` passed to both NONLocalBlock3D modules.
        nonlocal_sf: ``sub_sample_factor`` passed to both NONLocalBlock3D modules.
    """
    super(unet_nonlocal_3D, self).__init__()
    self.is_deconv = is_deconv
    self.in_channels = in_channels
    self.is_batchnorm = is_batchnorm
    self.feature_scale = feature_scale

    filters = [64, 128, 256, 512, 1024]
    filters = [int(x / self.feature_scale) for x in filters]

    # downsampling; each MaxPool3d with kernel (2, 2, 1) halves the first two
    # of the three spatial axes and leaves the last one unchanged
    self.conv1 = UnetConv3(self.in_channels, filters[0], self.is_batchnorm)
    self.maxpool1 = nn.MaxPool3d(kernel_size=(2, 2, 1))

    self.conv2 = UnetConv3(filters[0], filters[1], self.is_batchnorm)
    self.nonlocal2 = NONLocalBlock3D(in_channels=filters[1], inter_channels=filters[1] // 4,
                                     sub_sample_factor=nonlocal_sf, mode=nonlocal_mode)
    self.maxpool2 = nn.MaxPool3d(kernel_size=(2, 2, 1))

    self.conv3 = UnetConv3(filters[1], filters[2], self.is_batchnorm)
    self.nonlocal3 = NONLocalBlock3D(in_channels=filters[2], inter_channels=filters[2] // 4,
                                     sub_sample_factor=nonlocal_sf, mode=nonlocal_mode)
    self.maxpool3 = nn.MaxPool3d(kernel_size=(2, 2, 1))

    self.conv4 = UnetConv3(filters[2], filters[3], self.is_batchnorm)
    self.maxpool4 = nn.MaxPool3d(kernel_size=(2, 2, 1))

    self.center = UnetConv3(filters[3], filters[4], self.is_batchnorm)

    # upsampling. is_batchnorm is passed explicitly so the decoder follows the
    # same setting as the encoder instead of UnetUp3's own default.
    self.up_concat4 = UnetUp3(filters[4], filters[3], self.is_deconv, self.is_batchnorm)
    self.up_concat3 = UnetUp3(filters[3], filters[2], self.is_deconv, self.is_batchnorm)
    self.up_concat2 = UnetUp3(filters[2], filters[1], self.is_deconv, self.is_batchnorm)
    self.up_concat1 = UnetUp3(filters[1], filters[0], self.is_deconv, self.is_batchnorm)

    # final conv (without any concat)
    self.final = nn.Conv3d(filters[0], n_classes, 1)

    # initialise weights. self.modules() is recursive, so this also reaches the
    # Conv3d/BatchNorm3d layers inside nonlocal2/nonlocal3.
    # NOTE(review): that may override any special initialisation the non-local
    # blocks apply to their own layers; confirm this is intended.
    for m in self.modules():
        if isinstance(m, (nn.Conv3d, nn.BatchNorm3d)):
            init_weights(m, init_type='kaiming')
def __init__(self, in_size, out_size, is_batchnorm, n=2, ks=3, stride=1, padding=1):
    """Stack ``n`` ``Conv2d -> [BatchNorm2d] -> ReLU`` stages as conv1..conv<n>.

    Args:
        in_size: input channels of the first stage; later stages take out_size.
        out_size: output channels of every stage.
        is_batchnorm: if truthy, each convolution is followed by BatchNorm2d.
        n: number of stages.
        ks: kernel size of every convolution.
        stride: stride of every convolution.
        padding: padding of every convolution.
    """
    super(unetConv2, self).__init__()
    self.n = n
    self.ks = ks
    self.stride = stride
    self.padding = padding

    channels_in = in_size
    for idx in range(1, n + 1):
        stage_layers = [nn.Conv2d(channels_in, out_size, ks, stride, padding)]
        if is_batchnorm:
            stage_layers.append(nn.BatchNorm2d(out_size))
        stage_layers.append(nn.ReLU(inplace=True))
        setattr(self, 'conv%d' % idx, nn.Sequential(*stage_layers))
        # every stage after the first consumes the previous stage's output
        channels_in = out_size

    # initialise the blocks
    for child in self.children():
        init_weights(child, init_type='kaiming')
def __init__(self, feature_scale=4, n_classes=21, is_deconv=True, in_channels=3,
             nonlocal_mode='concatenation', attention_dsample=(2, 2, 2), is_batchnorm=True):
    """Build a 3-D U-Net whose skip connections pass through grid-attention gates.

    Registers conv1..conv4 / maxpool1..maxpool4, center, gating,
    attentionblock2..attentionblock4, up_concat4..up_concat1 and final.

    Args:
        feature_scale: divisor applied to the base widths (64, 128, 256, 512, 1024).
        n_classes: output channels of the final 1x1x1 convolution.
        is_deconv: forwarded to every UnetUp3 decoder stage.
        in_channels: channels of the input volume.
        nonlocal_mode: ``mode`` passed to every GridAttentionBlock3D.
        attention_dsample: ``sub_sample_factor`` passed to every GridAttentionBlock3D.
        is_batchnorm: forwarded to the conv blocks, gating block and decoder stages.
    """
    super(unet_grid_attention_3D, self).__init__()
    self.is_deconv = is_deconv
    self.in_channels = in_channels
    self.is_batchnorm = is_batchnorm
    self.feature_scale = feature_scale

    widths = [int(base / self.feature_scale) for base in (64, 128, 256, 512, 1024)]
    use_bn = self.is_batchnorm

    # Encoder: convK then maxpoolK for K = 1..4. The (2, 2, 1) pooling kernel
    # halves the first two of the three spatial axes and keeps the last one.
    channels_in = self.in_channels
    for level, width in enumerate(widths[:4], start=1):
        setattr(self, 'conv%d' % level, UnetConv3(channels_in, width, use_bn))
        setattr(self, 'maxpool%d' % level, nn.MaxPool3d(kernel_size=(2, 2, 1)))
        channels_in = width
    self.center = UnetConv3(widths[3], widths[4], use_bn)
    self.gating = UnetGridGatingSignal3(widths[4], widths[3], kernel_size=(1, 1, 1),
                                        is_batchnorm=use_bn)

    # Attention gates for skip levels 2..4; all expect a widths[3]-channel gate.
    for level in (2, 3, 4):
        skip_width = widths[level - 1]
        setattr(self, 'attentionblock%d' % level, GridAttentionBlock3D(
            in_channels=skip_width, gating_channels=widths[3], inter_channels=skip_width,
            sub_sample_factor=attention_dsample, mode=nonlocal_mode))

    # Decoder: up_concatK maps widths[K] -> widths[K - 1], built from K = 4 down to 1.
    for level in (4, 3, 2, 1):
        setattr(self, 'up_concat%d' % level,
                UnetUp3(widths[level], widths[level - 1], self.is_deconv, use_bn))

    # final conv (without any concat)
    self.final = nn.Conv3d(widths[0], n_classes, 1)

    # Kaiming init for every Conv3d / BatchNorm3d in the whole module tree,
    # including those inside the attention gates.
    for module in self.modules():
        if isinstance(module, (nn.Conv3d, nn.BatchNorm3d)):
            init_weights(module, init_type='kaiming')
def __init__(self, in_channels, inter_channels=None, dimension=3, mode='embedded_gaussian',
             sub_sample_factor=4, bn_layer=True):
    """Build a non-local block for 1-D, 2-D or 3-D feature maps.

    Args:
        in_channels: channels of the input feature map; also the output width of W.
        inter_channels: width of the g / theta / phi embeddings. Defaults to
            ``in_channels // 2``; a zero width is bumped to 1.
        dimension: 1, 2 or 3; selects the Conv/MaxPool/BatchNorm flavour.
        mode: one of 'embedded_gaussian', 'gaussian', 'dot_product',
            'concatenation', 'concat_proper', 'concat_proper_down'. It selects
            ``self.operation_function``.
        sub_sample_factor: int, or list/tuple of per-axis factors. If any
            factor is > 1, g and phi (and theta for 'concat_proper_down') are
            followed by a max-pool with this kernel.
        bn_layer: if True, W is ``conv -> batchnorm``; otherwise a single conv.
            W's final layer is zero-initialised in both cases.
    """
    super(_NonLocalBlockND, self).__init__()

    assert dimension in [1, 2, 3]
    assert mode in ['embedded_gaussian', 'gaussian', 'dot_product', 'concatenation',
                    'concat_proper', 'concat_proper_down']

    self.mode = mode
    self.dimension = dimension
    # Normalise to a list so the "any factor > 1" check works for int, list and
    # tuple alike. A tuple used to be wrapped as [tuple], so `ss > 1` compared a
    # tuple with an int and raised TypeError.
    if isinstance(sub_sample_factor, list):
        self.sub_sample_factor = sub_sample_factor
    elif isinstance(sub_sample_factor, tuple):
        self.sub_sample_factor = list(sub_sample_factor)
    else:
        self.sub_sample_factor = [sub_sample_factor]

    self.in_channels = in_channels
    self.inter_channels = inter_channels
    if self.inter_channels is None:
        self.inter_channels = in_channels // 2
    # Avoid a zero-width embedding (e.g. in_channels == 1, or an explicit 0).
    if self.inter_channels == 0:
        self.inter_channels = 1

    conv_nd, max_pool, bn = {
        1: (nn.Conv1d, nn.MaxPool1d, nn.BatchNorm1d),
        2: (nn.Conv2d, nn.MaxPool2d, nn.BatchNorm2d),
        3: (nn.Conv3d, nn.MaxPool3d, nn.BatchNorm3d),
    }[dimension]

    self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                     kernel_size=1, stride=1, padding=0)

    # W projects back to in_channels. Its last layer starts at zero, so W's
    # output is zero at construction time.
    if bn_layer:
        self.W = nn.Sequential(
            conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                    kernel_size=1, stride=1, padding=0),
            bn(self.in_channels),
        )
        nn.init.constant_(self.W[1].weight, 0)
        nn.init.constant_(self.W[1].bias, 0)
    else:
        self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                         kernel_size=1, stride=1, padding=0)
        nn.init.constant_(self.W.weight, 0)
        nn.init.constant_(self.W.bias, 0)

    # 'gaussian' has no learned theta/phi embeddings.
    self.theta = None
    self.phi = None
    if mode in ['embedded_gaussian', 'dot_product', 'concatenation', 'concat_proper',
                'concat_proper_down']:
        self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                             kernel_size=1, stride=1, padding=0)
        self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                           kernel_size=1, stride=1, padding=0)

    if mode in ['concatenation']:
        self.wf_phi = nn.Linear(self.inter_channels, 1, bias=False)
        self.wf_theta = nn.Linear(self.inter_channels, 1, bias=False)
    elif mode in ['concat_proper', 'concat_proper_down']:
        # NOTE(review): psi is always a Conv2d regardless of `dimension`,
        # presumably because the forward pass applies it to a 4-D pairwise map.
        # Confirm before changing it to conv_nd.
        self.psi = nn.Conv2d(in_channels=self.inter_channels, out_channels=1,
                             kernel_size=1, stride=1, padding=0, bias=True)

    operation_names = {
        'embedded_gaussian': '_embedded_gaussian',
        'dot_product': '_dot_product',
        'gaussian': '_gaussian',
        'concatenation': '_concatenation',
        'concat_proper': '_concatenation_proper',
        'concat_proper_down': '_concatenation_proper_down',
    }
    if mode not in operation_names:
        raise NotImplementedError('Unknown operation function.')
    self.operation_function = getattr(self, operation_names[mode])

    if any(ss > 1 for ss in self.sub_sample_factor):
        self.g = nn.Sequential(self.g, max_pool(kernel_size=sub_sample_factor))
        if self.phi is None:
            # 'gaussian': phi is just the pooled input
            self.phi = max_pool(kernel_size=sub_sample_factor)
        else:
            self.phi = nn.Sequential(self.phi, max_pool(kernel_size=sub_sample_factor))
        if mode == 'concat_proper_down':
            self.theta = nn.Sequential(self.theta, max_pool(kernel_size=sub_sample_factor))

    # Initialise weights.
    # NOTE(review): this runs after W was zero-initialised above and also
    # visits W. If init_weights re-initialises Conv/BatchNorm layers, it
    # overwrites that zero init. Confirm the intended order.
    for m in self.children():
        init_weights(m, init_type='kaiming')
def __init__(self, feature_scale=4, n_classes=2, is_deconv=True, in_channels=4,
             nonlocal_mode='concatenation', attention_dsample=(2, 2, 2), is_batchnorm=True):
    """Build a 3-D multi-attention U-Net with isotropic convs and deep supervision.

    Args:
        feature_scale: divisor applied to the base widths (64, 128, 256, 512, 1024).
        n_classes: output channels of each supervision head and of ``final``.
        is_deconv: stored on the instance; not forwarded to the UnetUp3_CT stages.
        in_channels: channels of the input volume.
        nonlocal_mode: ``nonlocal_mode`` for every MultiAttentionBlock.
        attention_dsample: ``sub_sample_factor`` for every MultiAttentionBlock.
        is_batchnorm: forwarded to the conv blocks, gating signal and decoder stages.
    """
    super(unet_pCT_multi_att_dsv_3D, self).__init__()
    self.is_deconv = is_deconv
    self.in_channels = in_channels
    self.is_batchnorm = is_batchnorm
    self.feature_scale = feature_scale

    # wK is the width produced by convK (K = 1..4); w5 is the bottleneck width.
    w1, w2, w3, w4, w5 = [int(base / self.feature_scale) for base in (64, 128, 256, 512, 1024)]
    bn = self.is_batchnorm
    # 3x3x3 kernels with padding 1 on every axis (UnetConv3 defaults to (3, 3, 1)).
    iso = dict(kernel_size=(3, 3, 3), padding_size=(1, 1, 1))

    # downsampling: each (2, 2, 2) max-pool halves all three spatial axes
    self.conv1 = UnetConv3(self.in_channels, w1, bn, **iso)
    self.maxpool1 = nn.MaxPool3d(kernel_size=(2, 2, 2))
    self.conv2 = UnetConv3(w1, w2, bn, **iso)
    self.maxpool2 = nn.MaxPool3d(kernel_size=(2, 2, 2))
    self.conv3 = UnetConv3(w2, w3, bn, **iso)
    self.maxpool3 = nn.MaxPool3d(kernel_size=(2, 2, 2))
    self.conv4 = UnetConv3(w3, w4, bn, **iso)
    self.maxpool4 = nn.MaxPool3d(kernel_size=(2, 2, 2))
    self.center = UnetConv3(w4, w5, bn, **iso)
    self.gating = UnetGridGatingSignal3(w5, w5, kernel_size=(1, 1, 1), is_batchnorm=bn)

    # attention blocks: gate_size is the width one level deeper than the skip
    self.attentionblock2 = MultiAttentionBlock(in_size=w2, gate_size=w3, inter_size=w2,
                                               nonlocal_mode=nonlocal_mode,
                                               sub_sample_factor=attention_dsample)
    self.attentionblock3 = MultiAttentionBlock(in_size=w3, gate_size=w4, inter_size=w3,
                                               nonlocal_mode=nonlocal_mode,
                                               sub_sample_factor=attention_dsample)
    self.attentionblock4 = MultiAttentionBlock(in_size=w4, gate_size=w5, inter_size=w4,
                                               nonlocal_mode=nonlocal_mode,
                                               sub_sample_factor=attention_dsample)

    # upsampling
    self.up_concat4 = UnetUp3_CT(w5, w4, bn)
    self.up_concat3 = UnetUp3_CT(w4, w3, bn)
    self.up_concat2 = UnetUp3_CT(w3, w2, bn)
    self.up_concat1 = UnetUp3_CT(w2, w1, bn)

    # deep supervision heads; scale factors 8/4/2 presumably bring each level
    # back to full resolution after its 3/2/1 preceding pools (TODO confirm)
    self.dsv4 = UnetDsv3(in_size=w4, out_size=n_classes, scale_factor=8)
    self.dsv3 = UnetDsv3(in_size=w3, out_size=n_classes, scale_factor=4)
    self.dsv2 = UnetDsv3(in_size=w2, out_size=n_classes, scale_factor=2)
    self.dsv1 = nn.Conv3d(in_channels=w1, out_channels=n_classes, kernel_size=1)

    # final conv takes 4 * n_classes channels, presumably the concatenated
    # outputs of the four supervision heads
    self.final = nn.Conv3d(n_classes * 4, n_classes, 1)

    # Kaiming init for every Conv3d / BatchNorm3d in the whole module tree
    for module in self.modules():
        if isinstance(module, (nn.Conv3d, nn.BatchNorm3d)):
            init_weights(module, init_type='kaiming')