def __init__(self, in_size, out_size, is_batchnorm):
    """Build the global gating-signal head.

    ``conv1`` is a 1x1x1 Conv3d that halves the channel count, optionally
    followed by BatchNorm3d, then ReLU and an adaptive average pool to a fixed
    4x4x4 grid. ``fc1`` maps the pooled features to ``out_size`` values.

    Args:
        in_size: input channels of the 1x1x1 convolution.
        out_size: output features of ``fc1``.
        is_batchnorm: insert a BatchNorm3d after the convolution.
    """
    super(UnetGatingSignal3, self).__init__()
    self.fmap_size = (4, 4, 4)
    half = in_size // 2

    stack = [nn.Conv3d(in_size, half, (1, 1, 1), (1, 1, 1), (0, 0, 0))]
    if is_batchnorm:
        stack.append(nn.BatchNorm3d(half))
    stack.append(nn.ReLU(inplace=True))
    stack.append(nn.AdaptiveAvgPool3d(output_size=self.fmap_size))
    self.conv1 = nn.Sequential(*stack)

    # fc1 expects the pooled (half x 4 x 4 x 4) map as a flat vector;
    # presumably it is flattened in forward — confirm there.
    pooled_features = half * self.fmap_size[0] * self.fmap_size[1] * self.fmap_size[2]
    self.fc1 = nn.Linear(in_features=pooled_features, out_features=out_size, bias=True)

    # apply the kaiming variant of init_weights to conv1 and fc1
    for child in self.children():
        init_weights(child, init_type='kaiming')
def __init__(self, in_size, out_size, is_batchnorm, kernel_size=(3, 3, 1), padding_size=(1, 1, 0), init_stride=(1, 1, 1)):
    """Build two stacked Conv3d -> [BatchNorm3d] -> ReLU stages.

    Args:
        in_size: input channels of the first stage.
        out_size: output channels of both stages.
        is_batchnorm: insert a BatchNorm3d after each convolution.
        kernel_size: kernel of both convolutions.
        padding_size: padding of both convolutions.
        init_stride: stride of the first convolution only; the second
            convolution always uses stride 1.
    """
    super(UnetConv3, self).__init__()

    def stage(channels_in, stride):
        # Conv -> optional BN -> in-place ReLU
        layers = [nn.Conv3d(channels_in, out_size, kernel_size, stride, padding_size)]
        if is_batchnorm:
            layers.append(nn.BatchNorm3d(out_size))
        layers.append(nn.ReLU(inplace=True))
        return nn.Sequential(*layers)

    self.conv1 = stage(in_size, init_stride)
    self.conv2 = stage(out_size, 1)

    # apply the kaiming variant of init_weights to both stages
    for child in self.children():
        init_weights(child, init_type='kaiming')
def __init__(self, in_size, out_size, is_batchnorm=True):
    """Build a 3-D decoder block: 2x trilinear upsampling plus a 3x3x3 UnetConv3.

    Args:
        in_size: channel count added to ``out_size`` to form the input of
            ``conv`` (presumably upsampled features concatenated with a skip
            connection — confirm in forward).
        out_size: output channels of ``conv``.
        is_batchnorm: forwarded to UnetConv3.
    """
    super(UnetUp3_CT, self).__init__()
    self.conv = UnetConv3(in_size + out_size, out_size, is_batchnorm, kernel_size=(3, 3, 3), padding_size=(1, 1, 1))
    # align_corners=False is PyTorch's default for 'trilinear'; passing it
    # explicitly pins the semantics and avoids the default-behaviour warning
    # some PyTorch versions emit when it is left unset.
    self.up = nn.Upsample(scale_factor=(2, 2, 2), mode='trilinear', align_corners=False)

    # UnetConv3 already calls init_weights on its own children; skip it here.
    for m in self.children():
        if 'UnetConv3' in m.__class__.__name__:
            continue
        init_weights(m, init_type='kaiming')
def __init__(self, in_size, out_size, is_deconv):
    """Build a 2-D decoder block: 2x upsampling followed by a unetConv2.

    Args:
        in_size: input channels of ``conv`` (and of ``up`` when ``is_deconv``).
        out_size: output channels of ``conv`` (and of ``up`` when ``is_deconv``).
        is_deconv: if True, upsample with a learned ConvTranspose2d
            (kernel 4, stride 2, padding 1 -> exactly 2x spatial size);
            otherwise use parameter-free bilinear upsampling with
            ``align_corners=True``.
    """
    super(unetUp, self).__init__()
    # NOTE(review): batch norm is always disabled for this block's convs — confirm intended.
    self.conv = unetConv2(in_size, out_size, False)
    if is_deconv:
        self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=4, stride=2, padding=1)
    else:
        # Same computation as the deprecated nn.UpsamplingBilinear2d(scale_factor=2),
        # which is bilinear interpolation with align_corners=True.
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    # unetConv2 already calls init_weights on its own children; skip it here.
    for m in self.children():
        if 'unetConv2' in m.__class__.__name__:
            continue
        init_weights(m, init_type='kaiming')
def __init__(self, in_size, out_size, is_deconv, is_batchnorm=True):
    """Build a 3-D decoder block that upsamples 2x along the first two spatial axes only.

    Args:
        in_size: input channels of ``up``; with ``is_deconv`` also of
            ``conv``, otherwise ``conv`` receives ``in_size + out_size``
            channels (presumably upsampled features concatenated with a skip
            connection — confirm in forward).
        out_size: output channels of ``conv`` (and of ``up`` when ``is_deconv``).
        is_deconv: if True, use a ConvTranspose3d with kernel (4, 4, 1),
            stride (2, 2, 1) and padding (1, 1, 0); otherwise trilinear
            upsampling with scale factor (2, 2, 1).
        is_batchnorm: forwarded to UnetConv3.
    """
    super(UnetUp3, self).__init__()
    if is_deconv:
        self.conv = UnetConv3(in_size, out_size, is_batchnorm)
        self.up = nn.ConvTranspose3d(in_size, out_size, kernel_size=(4, 4, 1), stride=(2, 2, 1), padding=(1, 1, 0))
    else:
        self.conv = UnetConv3(in_size + out_size, out_size, is_batchnorm)
        # align_corners=False is the default for 'trilinear'; it is stated
        # explicitly to pin the semantics and silence the default-behaviour warning.
        self.up = nn.Upsample(scale_factor=(2, 2, 1), mode='trilinear', align_corners=False)

    # UnetConv3 already calls init_weights on its own children; skip it here.
    for m in self.children():
        if 'UnetConv3' in m.__class__.__name__:
            continue
        init_weights(m, init_type='kaiming')
def __init__(self, in_size, out_size, is_batchnorm, n=2, ks=3, stride=1, padding=1):
    """Build ``n`` stacked Conv2d -> [BatchNorm2d] -> ReLU stages named conv1..conv{n}.

    Args:
        in_size: input channels of the first stage; later stages take ``out_size``.
        out_size: output channels of every stage.
        is_batchnorm: insert a BatchNorm2d after each convolution.
        n: number of stages.
        ks: kernel size of every convolution.
        stride: stride of every convolution.
        padding: padding of every convolution.
    """
    super(unetConv2, self).__init__()
    self.n, self.ks, self.stride, self.padding = n, ks, stride, padding

    channels = in_size
    for idx in range(1, n + 1):
        layers = [nn.Conv2d(channels, out_size, ks, stride, padding)]
        if is_batchnorm:
            layers.append(nn.BatchNorm2d(out_size))
        layers.append(nn.ReLU(inplace=True))
        setattr(self, 'conv%d' % idx, nn.Sequential(*layers))
        channels = out_size  # every stage after the first consumes out_size channels

    # apply the kaiming variant of init_weights to every stage
    for child in self.children():
        init_weights(child, init_type='kaiming')
def __init__(self, in_size, out_size, kernel_size=(1, 1, 1), is_batchnorm=True):
    """Build a spatial gating signal: one Conv3d -> [BatchNorm3d] -> ReLU.

    The convolution uses stride 1 and no padding, so a kernel larger than
    (1, 1, 1) shrinks the spatial extent.

    Args:
        in_size: input channels.
        out_size: output channels.
        kernel_size: Conv3d kernel size.
        is_batchnorm: insert a BatchNorm3d after the convolution.
    """
    super(UnetGridGatingSignal3, self).__init__()
    modules = [nn.Conv3d(in_size, out_size, kernel_size, (1, 1, 1), (0, 0, 0))]
    if is_batchnorm:
        modules.append(nn.BatchNorm3d(out_size))
    modules.append(nn.ReLU(inplace=True))
    self.conv1 = nn.Sequential(*modules)

    # apply the kaiming variant of init_weights
    for child in self.children():
        init_weights(child, init_type='kaiming')
def __init__(self, in_channels, gating_channels, inter_channels=None, dimension=3, mode='concatenation',
             sub_sample_factor=(1, 1, 1), bn_layer=True, use_W=True, use_phi=True, use_theta=True,
             use_psi=True, nonlinearity1='relu'):
    """Build the layers of a grid attention block combining an input ``x`` with a gating signal.

    Theta (on ``x``) and phi (on the gating signal) project to
    ``inter_channels``; psi maps to a single channel; W is an optional output
    projection on ``in_channels``. These, together with the non-linearity
    ``nl1``, ``mode`` and the upsampling mode, are passed to ``Concatenation``,
    which becomes ``self.operation_function``.

    Args:
        in_channels: channels of ``x`` (input of theta, input/output of W).
        gating_channels: channels of the gating signal (input of phi).
        inter_channels: output channels of theta and phi; defaults to
            ``max(in_channels // 2, 1)``.
        dimension: 2 (Conv2d/BatchNorm2d, 'bilinear') or 3
            (Conv3d/BatchNorm3d, 'trilinear').
        mode: one of the 'concatenation*' variants asserted below.
        sub_sample_factor: kernel size and stride of theta and phi; a
            non-tuple value is repeated ``dimension`` times.
        bn_layer: append a BatchNorm after W's 1x1 convolution.
        use_W, use_phi, use_theta, use_psi: when False the corresponding
            transform stays the identity function.
        nonlinearity1: 'relu' makes ``nl1`` an in-place ReLU; any other value
            leaves it as the identity.

    Raises:
        AssertionError: for an unsupported ``dimension`` or ``mode``.
        NotImplementedError: for an unsupported ``dimension`` when assertions
            are disabled (``python -O``).
    """
    super(_GridAttentionBlockND_TORR, self).__init__()

    assert dimension in [2, 3]
    assert mode in [
        'concatenation', 'concatenation_softmax', 'concatenation_sigmoid',
        'concatenation_mean', 'concatenation_range_normalise', 'concatenation_mean_flow',
    ]

    # Default parameter set
    self.mode = mode
    self.dimension = dimension
    # A scalar factor is broadcast to every spatial dimension.
    self.sub_sample_factor = (sub_sample_factor if isinstance(sub_sample_factor, tuple)
                              else tuple([sub_sample_factor]) * dimension)
    self.sub_sample_kernel_size = self.sub_sample_factor

    # Number of channels (pixel dimensions)
    self.in_channels = in_channels
    self.gating_channels = gating_channels
    self.inter_channels = inter_channels
    if self.inter_channels is None:
        self.inter_channels = max(in_channels // 2, 1)

    if dimension == 3:
        conv_nd = nn.Conv3d
        bn = nn.BatchNorm3d
        self.upsample_mode = 'trilinear'
    elif dimension == 2:
        conv_nd = nn.Conv2d
        bn = nn.BatchNorm2d
        self.upsample_mode = 'bilinear'
    else:
        # Reachable only when asserts are stripped (python -O).
        # `NotImplemented` is a constant, not an exception, so it cannot be raised.
        raise NotImplementedError('Unsupported dimension: {}'.format(dimension))

    # Every transform starts as the identity; the use_* flags swap in layers.
    # Theta^T * x_ij + Phi^T * gating_signal + bias
    self.W = lambda x: x
    self.theta = lambda x: x
    self.psi = lambda x: x
    self.phi = lambda x: x
    self.nl1 = lambda x: x

    if use_W:
        if bn_layer:
            self.W = nn.Sequential(
                conv_nd(in_channels=self.in_channels, out_channels=self.in_channels, kernel_size=1, stride=1, padding=0),
                bn(self.in_channels),
            )
        else:
            self.W = conv_nd(in_channels=self.in_channels, out_channels=self.in_channels, kernel_size=1, stride=1, padding=0)

    if use_theta:
        self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                             kernel_size=self.sub_sample_kernel_size, stride=self.sub_sample_factor,
                             padding=0, bias=False)

    if use_phi:
        self.phi = conv_nd(in_channels=self.gating_channels, out_channels=self.inter_channels,
                           kernel_size=self.sub_sample_kernel_size, stride=self.sub_sample_factor,
                           padding=0, bias=False)

    if use_psi:
        self.psi = conv_nd(in_channels=self.inter_channels, out_channels=1, kernel_size=1,
                           stride=1, padding=0, bias=True)

    if nonlinearity1 == 'relu':
        self.nl1 = lambda x: F.relu(x, inplace=True)

    # Always true after the mode assert; kept as a guard for python -O.
    if 'concatenation' in mode:
        self.operation_function = Concatenation(
            W=self.W, phi=self.phi, psi=self.psi, theta=self.theta, nl1=self.nl1,
            mode=mode, upsample_mode=self.upsample_mode)
    else:
        raise NotImplementedError('Unknown operation function.')

    # Initialise weights
    for m in self.children():
        init_weights(m, init_type='kaiming')

    # Mode-specific psi bias offsets, applied after the generic init.
    if use_psi and self.mode == 'concatenation_sigmoid':
        nn.init.constant_(self.psi.bias.data, 3.0)

    if use_psi and self.mode == 'concatenation_softmax':
        nn.init.constant_(self.psi.bias.data, 10.0)
def __init__(self, in_channels, inter_channels=None, dimension=3, mode='embedded_gaussian',
             sub_sample_factor=4, bn_layer=True):
    """Build the layers of a non-local block over 1-, 2- or 3-D feature maps.

    Args:
        in_channels: channels of the input (input of g/theta/phi, output of W).
        inter_channels: channels produced by g, theta and phi; defaults to
            ``max(in_channels // 2, 1)``.
        dimension: 1, 2 or 3 — selects the Conv/MaxPool/BatchNorm variant.
        mode: 'embedded_gaussian', 'gaussian', 'dot_product',
            'concatenation', 'concat_proper' or 'concat_proper_down'; selects
            ``self.operation_function`` and which projections exist.
        sub_sample_factor: int, list or tuple. If any entry is > 1, g and phi
            (and theta for 'concat_proper_down') are followed by a max-pool
            with this kernel size.
        bn_layer: append a BatchNorm after W's 1x1 convolution.

    W's final parameters (the BatchNorm's weight/bias, or the conv's
    weight/bias when ``bn_layer`` is False) are set to zero, so W initially
    outputs zeros.

    Raises:
        AssertionError: for an unsupported ``dimension`` or ``mode``.
    """
    super(_NonLocalBlockND, self).__init__()

    assert dimension in [1, 2, 3]
    assert mode in ['embedded_gaussian', 'gaussian', 'dot_product', 'concatenation', 'concat_proper',
                    'concat_proper_down']

    self.mode = mode
    self.dimension = dimension
    # Normalise to a list of numbers so the element-wise `ss > 1` check below
    # works; a tuple must be unpacked too, otherwise it would be wrapped whole
    # and compared against an int (TypeError).
    if isinstance(sub_sample_factor, list):
        self.sub_sample_factor = sub_sample_factor
    elif isinstance(sub_sample_factor, tuple):
        self.sub_sample_factor = list(sub_sample_factor)
    else:
        self.sub_sample_factor = [sub_sample_factor]

    self.in_channels = in_channels
    self.inter_channels = inter_channels
    if self.inter_channels is None:
        self.inter_channels = max(in_channels // 2, 1)

    if dimension == 3:
        conv_nd, max_pool, bn = nn.Conv3d, nn.MaxPool3d, nn.BatchNorm3d
    elif dimension == 2:
        conv_nd, max_pool, bn = nn.Conv2d, nn.MaxPool2d, nn.BatchNorm2d
    else:
        conv_nd, max_pool, bn = nn.Conv1d, nn.MaxPool1d, nn.BatchNorm1d

    self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                     kernel_size=1, stride=1, padding=0)

    if bn_layer:
        self.W = nn.Sequential(
            conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                    kernel_size=1, stride=1, padding=0),
            bn(self.in_channels)
        )
    else:
        self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                         kernel_size=1, stride=1, padding=0)

    self.theta = None
    self.phi = None

    if mode in ['embedded_gaussian', 'dot_product', 'concatenation', 'concat_proper', 'concat_proper_down']:
        self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                             kernel_size=1, stride=1, padding=0)
        self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                           kernel_size=1, stride=1, padding=0)

        if mode in ['concatenation']:
            self.wf_phi = nn.Linear(self.inter_channels, 1, bias=False)
            self.wf_theta = nn.Linear(self.inter_channels, 1, bias=False)
        elif mode in ['concat_proper', 'concat_proper_down']:
            # NOTE(review): Conv2d is used regardless of `dimension`; presumably
            # psi is applied to a 4-D pairwise (b, c, N, M) map built in the
            # forward methods — confirm before changing to conv_nd.
            self.psi = nn.Conv2d(in_channels=self.inter_channels, out_channels=1, kernel_size=1,
                                 stride=1, padding=0, bias=True)

    if mode == 'embedded_gaussian':
        self.operation_function = self._embedded_gaussian
    elif mode == 'dot_product':
        self.operation_function = self._dot_product
    elif mode == 'gaussian':
        self.operation_function = self._gaussian
    elif mode == 'concatenation':
        self.operation_function = self._concatenation
    elif mode == 'concat_proper':
        self.operation_function = self._concatenation_proper
    elif mode == 'concat_proper_down':
        self.operation_function = self._concatenation_proper_down
    else:
        raise NotImplementedError('Unknown operation function.')

    if any(ss > 1 for ss in self.sub_sample_factor):
        self.g = nn.Sequential(self.g, max_pool(kernel_size=sub_sample_factor))
        if self.phi is None:
            # 'gaussian' mode has no phi projection: phi is just the pooling.
            self.phi = max_pool(kernel_size=sub_sample_factor)
        else:
            self.phi = nn.Sequential(self.phi, max_pool(kernel_size=sub_sample_factor))
        if mode == 'concat_proper_down':
            self.theta = nn.Sequential(self.theta, max_pool(kernel_size=sub_sample_factor))

    # Initialise weights
    for m in self.children():
        init_weights(m, init_type='kaiming')

    # Zero-initialise W's output parameters last, so the generic initialisation
    # above (which also visits W) cannot overwrite them.
    if bn_layer:
        nn.init.constant_(self.W[1].weight, 0)
        nn.init.constant_(self.W[1].bias, 0)
    else:
        nn.init.constant_(self.W.weight, 0)
        nn.init.constant_(self.W.bias, 0)
def __init__(self, feature_scale=4, n_classes=21, in_channels=3, is_batchnorm=True, n_convs=None,
             nonlocal_mode='concatenation', aggregation_mode='concat'):
    """Build a ResNet-18 classifier with two grid-attention gates and an aggregation head.

    Args:
        feature_scale: stored on the instance; not used in this constructor.
        n_classes: number of output classes of the classifier(s).
        in_channels: stored on the instance. NOTE(review): the stem conv is
            hard-coded to 4 input channels regardless of this value (whose
            default is 3) — confirm the expected input.
        is_batchnorm: stored on the instance; not used in this constructor.
        n_convs: not used in this constructor.
        nonlocal_mode: ``mode`` passed to both AttentionBlock2D gates.
        aggregation_mode: 'concat', 'mean', 'deep_sup' or 'ft'; selects the
            classifier layout and ``self.aggregate``.

    Raises:
        NotImplementedError: for an unknown ``aggregation_mode``.
    """
    super(resnet_grid_attention, self).__init__()
    self.in_channels = in_channels
    self.is_batchnorm = is_batchnorm
    self.feature_scale = feature_scale
    self.n_classes = n_classes
    self.aggregation_mode = aggregation_mode
    self.deep_supervised = True

    # Pretrained ResNet-18 backbone. Its stem is rebuilt for 4 input channels,
    # reusing the pretrained 3-channel kernels plus a copy of the kernel for
    # input channel 0. NOTE(review): `pretrained=True` is deprecated in recent
    # torchvision in favour of `weights=` — migrate when the minimum version allows.
    self.resnet = torchvision.models.resnet18(pretrained=True)
    conv1_weight = self.resnet.conv1.weight.data
    self.resnet.conv1 = nn.Conv2d(4, 64, kernel_size=7, stride=2, padding=3, bias=False)
    self.resnet.conv1.weight = torch.nn.Parameter(torch.cat((conv1_weight, conv1_weight[:, :1, :, :]), dim=1))

    filters = [64, 64, 128, 256, 512]

    # Feature extraction: aliases into the backbone's stages.
    self.conv1 = self.resnet.conv1
    self.conv2 = self.resnet.layer1
    self.conv3 = self.resnet.layer2
    self.conv4 = self.resnet.layer3
    self.conv5 = self.resnet.layer4

    # Attention maps: gate layer2 and layer3 features with layer4's (filters[4]) signal.
    self.compatibility_score1 = AttentionBlock2D(in_channels=filters[2], gating_channels=filters[4],
                                                 inter_channels=filters[4], sub_sample_factor=(1, 1),
                                                 mode=nonlocal_mode, use_W=False, use_phi=True,
                                                 use_theta=True, use_psi=True, nonlinearity1='relu')

    self.compatibility_score2 = AttentionBlock2D(in_channels=filters[3], gating_channels=filters[4],
                                                 inter_channels=filters[4], sub_sample_factor=(1, 1),
                                                 mode=nonlocal_mode, use_W=False, use_phi=True,
                                                 use_theta=True, use_psi=True, nonlinearity1='relu')

    # Aggregation strategies
    self.attention_filter_sizes = [filters[2], filters[3]]

    if aggregation_mode == 'concat':
        self.classifier = nn.Linear(filters[2] + filters[3] + filters[4], n_classes)
        self.aggregate = self.aggreagation_concat

    else:
        self.classifier1 = nn.Linear(filters[2], n_classes)
        self.classifier2 = nn.Linear(filters[3], n_classes)
        self.classifier3 = nn.Linear(filters[4], n_classes)
        # Plain list: the classifiers are already registered individually above.
        self.classifiers = [self.classifier1, self.classifier2, self.classifier3]

        if aggregation_mode == 'mean':
            self.aggregate = Aggregation_sep(self.classifiers)

        elif aggregation_mode == 'deep_sup':
            self.classifier = nn.Linear(filters[2] + filters[3] + filters[4], n_classes)
            self.aggregate = Aggregation_ds(self.classifiers, self.classifier)

        elif aggregation_mode == 'ft':
            self.classifier = nn.Linear(n_classes * 3, n_classes)
            self.aggregate = Aggregation_ft(self.classifiers, self.classifier)

        else:
            raise NotImplementedError

    # Initialise only the Conv2d/BatchNorm2d layers this model adds. The
    # pretrained backbone (including the rebuilt 4-channel stem) is skipped so
    # its loaded weights are not re-initialised.
    backbone_modules = set(self.resnet.modules())
    for m in self.modules():
        if m in backbone_modules:
            continue
        if isinstance(m, (nn.Conv2d, nn.BatchNorm2d)):
            init_weights(m, init_type='kaiming')