def __init__(self, in_size, out_size, is_batchnorm):
    """Build the global gating-signal head.

    A 1x1x1 convolution halves the channel count. It is followed by an optional
    ``nn.InstanceNorm3d`` and a ReLU. Adaptive average pooling then squeezes
    the volume to ``self.fmap_size``, and ``fc1`` maps the flattened pooled
    features to ``out_size`` outputs.

    Args:
        in_size: number of input channels.
        out_size: number of output features of ``fc1``.
        is_batchnorm: when truthy, insert ``nn.InstanceNorm3d`` after the
            convolution (despite the name, no BatchNorm layer is used).
    """
    super(UnetGatingSignal3, self).__init__()
    self.fmap_size = (4, 4, 4)
    half_channels = in_size // 2

    stages = [nn.Conv3d(in_size, half_channels, (1, 1, 1), (1, 1, 1), (0, 0, 0))]
    if is_batchnorm:
        stages.append(nn.InstanceNorm3d(half_channels))
    stages += [
        nn.ReLU(inplace=True),
        nn.AdaptiveAvgPool3d(output_size=self.fmap_size),
    ]
    self.conv1 = nn.Sequential(*stages)

    # Flattened size after pooling: channels * D * H * W of fmap_size.
    pooled_voxels = self.fmap_size[0] * self.fmap_size[1] * self.fmap_size[2]
    self.fc1 = nn.Linear(in_features=half_channels * pooled_voxels,
                         out_features=out_size, bias=True)

    # Kaiming-initialise each direct child (conv1, then fc1).
    for child in self.children():
        init_weights(child, init_type='kaiming')
def __init__(self, in_size, gate_size, inter_size, nonlocal_mode, sub_sample_factor):
    """Create two identically configured attention gates plus a fusion layer.

    Args:
        in_size: channels of the features being gated; also the output
            channels of ``combine_gates``.
        gate_size: channels of the gating signal.
        inter_size: intermediate channels inside each gate.
        nonlocal_mode: ``mode`` forwarded to both ``GridAttentionBlock3D``s.
        sub_sample_factor: forwarded to both ``GridAttentionBlock3D``s.
    """
    super(MultiAttentionBlock, self).__init__()
    gate_config = dict(in_channels=in_size,
                       gating_channels=gate_size,
                       inter_channels=inter_size,
                       mode=nonlocal_mode,
                       sub_sample_factor=sub_sample_factor)
    self.gate_block_1 = GridAttentionBlock3D(**gate_config)
    self.gate_block_2 = GridAttentionBlock3D(**gate_config)

    # 1x1x1 conv mapping 2 * in_size channels back to in_size
    # (presumably the two gate outputs concatenated in forward -- not visible here).
    self.combine_gates = nn.Sequential(
        nn.Conv3d(in_size * 2, in_size, kernel_size=1, stride=1, padding=0),
        nn.BatchNorm3d(in_size),
        nn.ReLU(inplace=True),
    )

    # The gate blocks are skipped here; they are assumed to handle their
    # own initialisation. Everything else gets Kaiming init.
    for child in self.children():
        if 'GridAttentionBlock3D' in child.__class__.__name__:
            continue
        init_weights(child, init_type='kaiming')
def __init__(self, in_size, out_size, is_batchnorm, kernel_size=(3, 3, 1), padding_size=(1, 1, 0), init_stride=(1, 1, 1)):
    """Two stacked Conv3d -> [InstanceNorm3d] -> ReLU stages.

    Args:
        in_size: input channels of the first stage.
        out_size: output channels of both stages.
        is_batchnorm: when truthy, add ``nn.InstanceNorm3d`` after each conv.
        kernel_size: kernel of both convolutions.
        padding_size: padding of both convolutions.
        init_stride: stride of the first convolution only; the second
            always uses stride 1.
    """
    super(UnetConv3, self).__init__()

    def make_stage(channels_in, stride):
        # One conv stage; the norm layer is optional.
        parts = [nn.Conv3d(channels_in, out_size, kernel_size, stride, padding_size)]
        if is_batchnorm:
            parts.append(nn.InstanceNorm3d(out_size))
        parts.append(nn.ReLU(inplace=True))
        return nn.Sequential(*parts)

    self.conv1 = make_stage(in_size, init_stride)
    self.conv2 = make_stage(out_size, 1)

    for child in self.children():
        init_weights(child, init_type='kaiming')
def __init__(self, in_size, out_size, is_batchnorm=True):
    """Trilinear 2x upsampling followed by a 3x3x3 ``UnetConv3``.

    Args:
        in_size: channels of the coarser input.
        out_size: channels of the skip connection and of the output; the
            conv block consumes ``in_size + out_size`` channels.
        is_batchnorm: forwarded to ``UnetConv3``.
    """
    super(UnetUp3_CT, self).__init__()
    merged_channels = in_size + out_size
    self.conv = UnetConv3(merged_channels, out_size, is_batchnorm,
                          kernel_size=(3, 3, 3), padding_size=(1, 1, 1))
    self.up = nn.Upsample(scale_factor=(2, 2, 2), mode='trilinear')

    # UnetConv3 children are skipped (left to UnetConv3 itself).
    for child in self.children():
        if 'UnetConv3' not in child.__class__.__name__:
            init_weights(child, init_type='kaiming')
def __init__(self, feature_scale=4, n_classes=21, is_deconv=True, in_channels=3, is_batchnorm=True):
    """Build a plain 3-D U-Net (4 encoder levels, bottleneck, 4 decoder levels).

    Args:
        feature_scale: divisor applied to the base widths
            (64, 128, 256, 512, 1024) to obtain ``filters``.
        n_classes: output channels of the final 1x1x1 convolution.
        is_deconv: stored on the instance; not otherwise used here.
        in_channels: channels of the input volume.
        is_batchnorm: normalisation switch forwarded to every conv block.
    """
    super(unet_3D, self).__init__()
    self.is_deconv = is_deconv
    self.in_channels = in_channels
    self.is_batchnorm = is_batchnorm
    self.feature_scale = feature_scale

    base_widths = (64, 128, 256, 512, 1024)
    filters = [int(width / self.feature_scale) for width in base_widths]

    def encoder_block(channels_in, channels_out):
        return UnetConv3(channels_in, channels_out, self.is_batchnorm,
                         kernel_size=(3, 3, 3), padding_size=(1, 1, 1))

    # Encoder: each conv block is followed by a 2x2x2 max-pool.
    self.conv1 = encoder_block(self.in_channels, filters[0])
    self.maxpool1 = nn.MaxPool3d(kernel_size=(2, 2, 2))
    self.conv2 = encoder_block(filters[0], filters[1])
    self.maxpool2 = nn.MaxPool3d(kernel_size=(2, 2, 2))
    self.conv3 = encoder_block(filters[1], filters[2])
    self.maxpool3 = nn.MaxPool3d(kernel_size=(2, 2, 2))
    self.conv4 = encoder_block(filters[2], filters[3])
    self.maxpool4 = nn.MaxPool3d(kernel_size=(2, 2, 2))
    self.center = encoder_block(filters[3], filters[4])

    # Decoder, coarsest level first.
    self.up_concat4 = UnetUp3_CT(filters[4], filters[3], is_batchnorm)
    self.up_concat3 = UnetUp3_CT(filters[3], filters[2], is_batchnorm)
    self.up_concat2 = UnetUp3_CT(filters[2], filters[1], is_batchnorm)
    self.up_concat1 = UnetUp3_CT(filters[1], filters[0], is_batchnorm)

    # Final 1x1x1 projection to class scores (no concatenation).
    self.final = nn.Conv3d(filters[0], n_classes, 1)

    self.dropout1 = nn.Dropout(p=0.3)
    self.dropout2 = nn.Dropout(p=0.3)

    # Re-initialise every Conv3d / BatchNorm3d anywhere in the module tree.
    for module in self.modules():
        if isinstance(module, (nn.Conv3d, nn.BatchNorm3d)):
            init_weights(module, init_type='kaiming')
def __init__(self, in_size, out_size, is_deconv):
    """2-D decoder step: an upsampler plus a ``unetConv2`` without batch norm.

    Args:
        in_size: input channels of both the conv block and the transposed conv.
        out_size: output channels.
        is_deconv: choose a learnable ``ConvTranspose2d`` (truthy) or a
            parameter-free ``UpsamplingBilinear2d`` (falsy).
    """
    super(unetUp, self).__init__()
    self.conv = unetConv2(in_size, out_size, False)
    self.up = (
        nn.ConvTranspose2d(in_size, out_size, kernel_size=4, stride=2, padding=1)
        if is_deconv
        else nn.UpsamplingBilinear2d(scale_factor=2)
    )

    # unetConv2 children are skipped (left to unetConv2 itself).
    for child in self.children():
        if 'unetConv2' in child.__class__.__name__:
            continue
        init_weights(child, init_type='kaiming')
def __init__(self, in_size, out_size, is_deconv, is_batchnorm=True):
    """3-D decoder step that upsamples only the first two spatial axes.

    Args:
        in_size: channels of the coarser input.
        out_size: output channels.
        is_deconv: truthy -> learnable ``ConvTranspose3d`` and a conv block
            taking ``in_size`` channels; falsy -> trilinear ``Upsample`` and
            a conv block taking ``in_size + out_size`` channels.
        is_batchnorm: forwarded to ``UnetConv3``.
    """
    super(UnetUp3, self).__init__()
    if not is_deconv:
        self.conv = UnetConv3(in_size + out_size, out_size, is_batchnorm)
        self.up = nn.Upsample(scale_factor=(2, 2, 1), mode='trilinear')
    else:
        self.conv = UnetConv3(in_size, out_size, is_batchnorm)
        self.up = nn.ConvTranspose3d(in_size, out_size,
                                     kernel_size=(4, 4, 1),
                                     stride=(2, 2, 1),
                                     padding=(1, 1, 0))

    # UnetConv3 children are skipped (left to UnetConv3 itself).
    for child in self.children():
        if 'UnetConv3' not in child.__class__.__name__:
            init_weights(child, init_type='kaiming')
def __init__(self, in_size, out_size, is_batchnorm, n=2, ks=3, stride=1, padding=1):
    """Stack ``n`` Conv2d -> [BatchNorm2d] -> ReLU stages named conv1..conv<n>.

    Args:
        in_size: input channels of the first stage.
        out_size: output channels of every stage (later stages take
            ``out_size`` as input).
        is_batchnorm: when truthy, add ``nn.BatchNorm2d`` after each conv.
        n: number of stages.
        ks: kernel size of each conv.
        stride: stride of each conv.
        padding: padding of each conv.
    """
    super(unetConv2, self).__init__()
    self.n = n
    self.ks = ks
    self.stride = stride
    self.padding = padding

    channels = in_size
    for idx in range(1, n + 1):
        parts = [nn.Conv2d(channels, out_size, ks, stride, padding)]
        if is_batchnorm:
            parts.append(nn.BatchNorm2d(out_size))
        parts.append(nn.ReLU(inplace=True))
        setattr(self, 'conv%d' % idx, nn.Sequential(*parts))
        channels = out_size

    for child in self.children():
        init_weights(child, init_type='kaiming')
def __init__(self, in_size, out_size, kernel_size=(1, 1, 1), is_batchnorm=True):
    """Single Conv3d -> [InstanceNorm3d] -> ReLU stage producing the gating signal.

    Args:
        in_size: input channels.
        out_size: output channels.
        kernel_size: conv kernel (stride 1, no padding).
        is_batchnorm: when truthy, add ``nn.InstanceNorm3d`` after the conv.
    """
    super(UnetGridGatingSignal3, self).__init__()
    parts = [nn.Conv3d(in_size, out_size, kernel_size, (1, 1, 1), (0, 0, 0))]
    if is_batchnorm:
        parts.append(nn.InstanceNorm3d(out_size))
    parts.append(nn.ReLU(inplace=True))
    self.conv1 = nn.Sequential(*parts)

    for child in self.children():
        init_weights(child, init_type='kaiming')
def __init__(self, feature_scale=4, n_classes=21, is_deconv=True, in_channels=3,
             nonlocal_mode='concatenation', attention_dsample=(2, 2, 2), is_batchnorm=True):
    """Build a 3-D attention-gated U-Net with deep-supervision heads.

    Args:
        feature_scale: divisor applied to the base widths
            (64, 128, 256, 512, 1024) to obtain ``filters``.
        n_classes: output channels of every supervision head and of ``final``.
        is_deconv: stored on the instance; not otherwise used here.
        in_channels: channels of the input volume.
        nonlocal_mode: ``nonlocal_mode`` forwarded to every
            ``MultiAttentionBlock``.
        attention_dsample: ``sub_sample_factor`` forwarded to every
            ``MultiAttentionBlock``.
        is_batchnorm: normalisation switch forwarded to the encoder, gating
            and decoder blocks.
    """
    super(Attention_UNet, self).__init__()
    self.is_deconv = is_deconv
    self.in_channels = in_channels
    self.is_batchnorm = is_batchnorm
    self.feature_scale = feature_scale

    base_widths = (64, 128, 256, 512, 1024)
    filters = [int(width / self.feature_scale) for width in base_widths]

    def encoder_block(channels_in, channels_out):
        return UnetConv3(channels_in, channels_out, self.is_batchnorm,
                         kernel_size=(3, 3, 3), padding_size=(1, 1, 1))

    def attention_block(level):
        # Features at `level`; gating channels come from the next level down.
        return MultiAttentionBlock(in_size=filters[level],
                                   gate_size=filters[level + 1],
                                   inter_size=filters[level],
                                   nonlocal_mode=nonlocal_mode,
                                   sub_sample_factor=attention_dsample)

    # Encoder: each conv block is followed by a 2x2x2 max-pool.
    self.conv1 = encoder_block(self.in_channels, filters[0])
    self.maxpool1 = nn.MaxPool3d(kernel_size=(2, 2, 2))
    self.conv2 = encoder_block(filters[0], filters[1])
    self.maxpool2 = nn.MaxPool3d(kernel_size=(2, 2, 2))
    self.conv3 = encoder_block(filters[1], filters[2])
    self.maxpool3 = nn.MaxPool3d(kernel_size=(2, 2, 2))
    self.conv4 = encoder_block(filters[2], filters[3])
    self.maxpool4 = nn.MaxPool3d(kernel_size=(2, 2, 2))
    self.center = encoder_block(filters[3], filters[4])
    self.gating = UnetGridGatingSignal3(filters[4], filters[4], kernel_size=(1, 1, 1),
                                        is_batchnorm=self.is_batchnorm)

    # Attention gates for encoder levels 2, 3 and 4.
    self.attentionblock2 = attention_block(1)
    self.attentionblock3 = attention_block(2)
    self.attentionblock4 = attention_block(3)

    # Decoder, coarsest level first.
    self.up_concat4 = UnetUp3_CT(filters[4], filters[3], is_batchnorm)
    self.up_concat3 = UnetUp3_CT(filters[3], filters[2], is_batchnorm)
    self.up_concat2 = UnetUp3_CT(filters[2], filters[1], is_batchnorm)
    self.up_concat1 = UnetUp3_CT(filters[1], filters[0], is_batchnorm)

    # Deep-supervision heads, one per decoder level, each with n_classes outputs.
    self.dsv4 = UnetDsv3(in_size=filters[3], out_size=n_classes, scale_factor=8)
    self.dsv3 = UnetDsv3(in_size=filters[2], out_size=n_classes, scale_factor=4)
    self.dsv2 = UnetDsv3(in_size=filters[1], out_size=n_classes, scale_factor=2)
    self.dsv1 = nn.Conv3d(in_channels=filters[0], out_channels=n_classes, kernel_size=1)

    # Final 1x1x1 conv over n_classes * 4 channels
    # (presumably the four heads concatenated in forward -- not visible here).
    self.final = nn.Conv3d(n_classes * 4, n_classes, 1)

    # Re-initialise every Conv3d / BatchNorm3d anywhere in the module tree.
    for module in self.modules():
        if isinstance(module, (nn.Conv3d, nn.BatchNorm3d)):
            init_weights(module, init_type='kaiming')
def __init__(self, in_channels, gating_channels, inter_channels=None, dimension=3, mode='concatenation',
             sub_sample_factor=(2, 2, 2)):
    """Create the layers of a grid attention gate.

    Four learnable pieces are built:

    - ``theta``: strided projection of the input features, with
      kernel == stride == ``sub_sample_factor`` and no bias.
    - ``phi``: 1x1 projection of the gating signal, with bias.
    - ``psi``: 1x1 projection from ``inter_channels`` to a single channel.
    - ``W``: 1x1 conv + batch-norm output transform.

    ``mode`` selects which ``_concatenation*`` method is stored in
    ``self.operation_function``.

    Args:
        in_channels: channels of the feature map being gated.
        gating_channels: channels of the gating signal.
        inter_channels: width of the theta/phi projections. Defaults to
            ``in_channels // 2``, and to at least 1.
        dimension: 2 or 3; selects Conv2d/BatchNorm2d or Conv3d/BatchNorm3d
            and the matching ``upsample_mode``.
        mode: 'concatenation', 'concatenation_debug' or
            'concatenation_residual'.
        sub_sample_factor: tuple, list, or scalar. A scalar is repeated for
            every spatial axis; a list is converted to a tuple.

    Raises:
        AssertionError: unsupported ``dimension`` or ``mode`` (asserts enabled).
        NotImplementedError: the same inputs when asserts are stripped
            (``python -O``).
    """
    super(_GridAttentionBlockND, self).__init__()

    assert dimension in [2, 3]
    assert mode in ['concatenation', 'concatenation_debug', 'concatenation_residual']

    # Downsampling rate for the input feature map.
    if isinstance(sub_sample_factor, tuple):
        self.sub_sample_factor = sub_sample_factor
    elif isinstance(sub_sample_factor, list):
        self.sub_sample_factor = tuple(sub_sample_factor)
    else:
        self.sub_sample_factor = (sub_sample_factor,) * dimension

    # Default parameter set
    self.mode = mode
    self.dimension = dimension
    # theta uses a kernel equal to its stride: non-overlapping windows.
    self.sub_sample_kernel_size = self.sub_sample_factor

    # Number of channels (pixel dimensions)
    self.in_channels = in_channels
    self.gating_channels = gating_channels
    self.inter_channels = inter_channels
    if self.inter_channels is None:
        self.inter_channels = in_channels // 2
        if self.inter_channels == 0:
            self.inter_channels = 1

    if dimension == 3:
        conv_nd = nn.Conv3d
        bn = nn.BatchNorm3d
        self.upsample_mode = 'trilinear'
    elif dimension == 2:
        conv_nd = nn.Conv2d
        bn = nn.BatchNorm2d
        self.upsample_mode = 'bilinear'
    else:
        # Only reachable when asserts are stripped. The previous
        # `raise NotImplemented` produced a TypeError, because NotImplemented
        # is a constant, not an exception class.
        raise NotImplementedError('Unsupported dimension: {}'.format(dimension))

    # Output transform
    self.W = nn.Sequential(
        conv_nd(in_channels=self.in_channels, out_channels=self.in_channels,
                kernel_size=1, stride=1, padding=0),
        bn(self.in_channels),
    )

    # Theta^T * x_ij + Phi^T * gating_signal + bias
    self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                         kernel_size=self.sub_sample_kernel_size, stride=self.sub_sample_factor,
                         padding=0, bias=False)
    self.phi = conv_nd(in_channels=self.gating_channels, out_channels=self.inter_channels,
                       kernel_size=1, stride=1, padding=0, bias=True)
    self.psi = conv_nd(in_channels=self.inter_channels, out_channels=1,
                       kernel_size=1, stride=1, padding=0, bias=True)

    # Initialise weights of W, theta, phi, psi.
    for m in self.children():
        init_weights(m, init_type='kaiming')

    # Define the operation
    if mode == 'concatenation':
        self.operation_function = self._concatenation
    elif mode == 'concatenation_debug':
        self.operation_function = self._concatenation_debug
    elif mode == 'concatenation_residual':
        self.operation_function = self._concatenation_residual
    else:
        raise NotImplementedError('Unknown operation function.')
def __init__(self, in_channels, gating_channels, inter_channels=None, dimension=3, mode='concatenation',
             sub_sample_factor=(1, 1, 1), bn_layer=True, use_W=True, use_phi=True, use_theta=True,
             use_psi=True, nonlinearity1='relu'):
    """Create a configurable grid attention gate.

    Each projection (``W``, ``theta``, ``phi``, ``psi``) and the first
    non-linearity ``nl1`` default to the identity. The ``use_*`` flags swap
    in learnable layers instead. Every mode containing 'concatenation' uses
    ``self._concatenation``.

    Args:
        in_channels: channels of the feature map being gated.
        gating_channels: channels of the gating signal.
        inter_channels: width of the theta/phi projections. Defaults to
            ``in_channels // 2``, and to at least 1.
        dimension: 2 or 3; selects Conv2d/BatchNorm2d or Conv3d/BatchNorm3d.
        mode: one of the concatenation variants listed in the assert below.
            For 'concatenation_sigmoid' / 'concatenation_softmax', the bias of
            ``psi`` is initialised to 3.0 / 10.0.
        sub_sample_factor: tuple, list, or scalar. A scalar is repeated for
            every spatial axis; a list is converted to a tuple. It is used as
            kernel and stride of ``theta`` and ``phi``.
        bn_layer: append batch norm to ``W`` (only if ``use_W``).
        use_W, use_phi, use_theta, use_psi: replace the identity with a conv.
        nonlinearity1: 'relu' makes ``nl1`` an in-place ReLU. Any other
            value (or a falsy one) leaves ``nl1`` as the identity.

    Raises:
        AssertionError: unsupported ``dimension`` or ``mode`` (asserts enabled).
        NotImplementedError: the same inputs when asserts are stripped
            (``python -O``).
    """
    super(_GridAttentionBlockND_TORR, self).__init__()

    assert dimension in [2, 3]
    assert mode in [
        'concatenation',
        'concatenation_softmax',
        'concatenation_sigmoid',
        'concatenation_mean',
        'concatenation_range_normalise',
        'concatenation_mean_flow',
    ]

    # Default parameter set
    self.mode = mode
    self.dimension = dimension
    # A list used to be wrapped whole, e.g. ([2, 2, 2], [2, 2, 2], [2, 2, 2]),
    # which is an invalid kernel_size/stride. Normalise it like
    # _GridAttentionBlockND does.
    if isinstance(sub_sample_factor, tuple):
        self.sub_sample_factor = sub_sample_factor
    elif isinstance(sub_sample_factor, list):
        self.sub_sample_factor = tuple(sub_sample_factor)
    else:
        self.sub_sample_factor = (sub_sample_factor,) * dimension
    self.sub_sample_kernel_size = self.sub_sample_factor

    # Number of channels (pixel dimensions)
    self.in_channels = in_channels
    self.gating_channels = gating_channels
    self.inter_channels = inter_channels
    if self.inter_channels is None:
        self.inter_channels = in_channels // 2
        if self.inter_channels == 0:
            self.inter_channels = 1

    if dimension == 3:
        conv_nd = nn.Conv3d
        bn = nn.BatchNorm3d
        self.upsample_mode = 'trilinear'
    elif dimension == 2:
        conv_nd = nn.Conv2d
        bn = nn.BatchNorm2d
        self.upsample_mode = 'bilinear'
    else:
        # Only reachable when asserts are stripped. `raise NotImplemented`
        # would have raised TypeError instead.
        raise NotImplementedError('Unsupported dimension: {}'.format(dimension))

    # Initialise identity functions; the flags below replace them with layers.
    # Theta^T * x_ij + Phi^T * gating_signal + bias
    self.W = lambda x: x
    self.theta = lambda x: x
    self.psi = lambda x: x
    self.phi = lambda x: x
    self.nl1 = lambda x: x

    if use_W:
        if bn_layer:
            self.W = nn.Sequential(
                conv_nd(in_channels=self.in_channels, out_channels=self.in_channels,
                        kernel_size=1, stride=1, padding=0),
                bn(self.in_channels),
            )
        else:
            self.W = conv_nd(in_channels=self.in_channels, out_channels=self.in_channels,
                             kernel_size=1, stride=1, padding=0)

    if use_theta:
        self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                             kernel_size=self.sub_sample_kernel_size, stride=self.sub_sample_factor,
                             padding=0, bias=False)

    if use_phi:
        self.phi = conv_nd(in_channels=self.gating_channels, out_channels=self.inter_channels,
                           kernel_size=self.sub_sample_kernel_size, stride=self.sub_sample_factor,
                           padding=0, bias=False)

    if use_psi:
        self.psi = conv_nd(in_channels=self.inter_channels, out_channels=1,
                           kernel_size=1, stride=1, padding=0, bias=True)

    if nonlinearity1:
        if nonlinearity1 == 'relu':
            self.nl1 = lambda x: F.relu(x, inplace=True)

    if 'concatenation' in mode:
        self.operation_function = self._concatenation
    else:
        raise NotImplementedError('Unknown operation function.')

    # Initialise weights (only nn.Module children; identity lambdas are not children).
    for m in self.children():
        init_weights(m, init_type='kaiming')

    # Mode-specific psi bias. nn.init.constant_ replaces the deprecated
    # nn.init.constant alias with identical effect.
    if use_psi and self.mode == 'concatenation_sigmoid':
        nn.init.constant_(self.psi.bias, 3.0)
    if use_psi and self.mode == 'concatenation_softmax':
        nn.init.constant_(self.psi.bias, 10.0)

    # DataParallel wrapping is hard-disabled; set to True to wrap each
    # learnable projection individually.
    parallel = False
    if parallel:
        if use_W:
            self.W = nn.DataParallel(self.W)
        if use_phi:
            self.phi = nn.DataParallel(self.phi)
        if use_psi:
            self.psi = nn.DataParallel(self.psi)
        if use_theta:
            self.theta = nn.DataParallel(self.theta)