def __init__(self, in_size, out_size, is_batchnorm):
    """Build the gating-signal layers.

    ``conv1`` is a 1x1x1 Conv3d (in_size -> in_size // 2), an optional
    BatchNorm3d, a ReLU and an adaptive average pool to a fixed 4x4x4 map.
    ``fc1`` is a Linear layer whose input width equals the number of elements
    in that pooled map, producing ``out_size`` features.

    Args:
        in_size: channels of the incoming feature map.
        out_size: output features of ``fc1``.
        is_batchnorm: insert a BatchNorm3d after the 1x1x1 convolution.
    """
    super(UnetGatingSignal3, self).__init__()
    self.fmap_size = (4, 4, 4)
    half = in_size // 2

    layers = [nn.Conv3d(in_size, half, (1, 1, 1), (1, 1, 1), (0, 0, 0))]
    if is_batchnorm:
        layers.append(nn.BatchNorm3d(half))
    layers += [
        nn.ReLU(inplace=True),
        nn.AdaptiveAvgPool3d(output_size=self.fmap_size),
    ]
    self.conv1 = nn.Sequential(*layers)

    # fc1's input width is channels * 4 * 4 * 4 (presumably the pooled map is
    # flattened in forward — verify there).
    pooled_volume = self.fmap_size[0] * self.fmap_size[1] * self.fmap_size[2]
    self.fc1 = nn.Linear(in_features=half * pooled_volume, out_features=out_size, bias=True)

    # initialise the blocks
    for child in self.children():
        init_weights(child, init_type='kaiming')
def __init__(self, in_size, out_size, kernel=(3, 3), pad=(1, 1), stride=(1, 1), bn=True):
    """Build two 2-D conv stages, each Conv2d -> [BatchNorm2d] -> ReLU.

    ``stride`` applies only to the first convolution; the second always uses
    stride 1. Both convolutions share ``kernel`` and ``pad``.
    """
    super(UNetConv2D, self).__init__()

    def stage(channels_in, conv_stride):
        ops = [nn.Conv2d(channels_in, out_size, kernel, conv_stride, pad)]
        if bn:
            ops.append(nn.BatchNorm2d(out_size))
        ops.append(nn.ReLU(inplace=True))
        return nn.Sequential(*ops)

    self.conv1 = stage(in_size, stride)
    self.conv2 = stage(out_size, 1)

    # initialise the blocks
    for child in self.children():
        init_weights(child, init_type='kaiming')
def __init__(self, in_size, out_size, is_batchnorm, kernel_size=(3, 3, 1), padding_size=(1, 1, 0), init_stride=(1, 1, 1)):
    """Build two 3-D conv stages, each Conv3d -> [BatchNorm3d] -> ReLU.

    ``init_stride`` applies only to the first convolution; the second uses
    stride 1. With the default kernel (3, 3, 1) and padding (1, 1, 0), no
    mixing happens along the third spatial axis.
    """
    super(UnetConv3, self).__init__()

    def make_stage(channels_in, stride):
        modules = [nn.Conv3d(channels_in, out_size, kernel_size, stride, padding_size)]
        if is_batchnorm:
            modules.append(nn.BatchNorm3d(out_size))
        modules.append(nn.ReLU(inplace=True))
        return nn.Sequential(*modules)

    self.conv1 = make_stage(in_size, init_stride)
    self.conv2 = make_stage(out_size, 1)

    # initialise the blocks
    for child in self.children():
        init_weights(child, init_type='kaiming')
def __init__(self, d_model, n_heads=2):
    """Build the multi-head cross-attention parameters.

    Each head owns three bias-free d_model -> d_model projections (stored in
    q, k, v order inside ``all_w[head]``). ``wo`` maps the concatenation of
    all heads (d_model * n_heads) back to d_model. The block also holds two
    LayerNorms and a d_model -> d_model feed-forward Linear.
    """
    super(CrossAttention, self).__init__()
    self.d_model = d_model
    self.n_heads = n_heads
    self.norm1 = nn.LayerNorm(d_model, eps=1e-5)

    def head_projections():
        # q, k, v projections for one head, created in that order.
        return nn.ModuleList(
            [nn.Linear(self.d_model, self.d_model, bias=False) for _ in range(3)]
        )

    self.all_w = nn.ModuleList([head_projections() for _ in range(self.n_heads)])
    self.wo = nn.Linear(self.d_model * self.n_heads, self.d_model, bias=False)
    self.norm2 = nn.LayerNorm(d_model, eps=1e-5)
    self.feed_forward = nn.Linear(d_model, d_model)

    # initialise the blocks
    for child in self.children():
        init_weights(child, init_type='kaiming')
def __init__(self, in_size, gate_size, inter_size, nonlocal_mode, sub_sample_factor):
    """Build two identically configured grid-attention gates plus a fusion layer.

    ``combine_gates`` maps the concatenation of both gate outputs
    (in_size * 2 channels) back to in_size with a 1x1x1 Conv3d -> BN -> ReLU.
    """
    super(MultiAttentionBlock, self).__init__()
    gate_config = dict(
        in_channels=in_size,
        gating_channels=gate_size,
        inter_channels=inter_size,
        mode=nonlocal_mode,
        sub_sample_factor=sub_sample_factor,
    )
    self.gate_block_1 = GridAttentionBlock3D(**gate_config)
    self.gate_block_2 = GridAttentionBlock3D(**gate_config)
    self.combine_gates = nn.Sequential(
        nn.Conv3d(in_size * 2, in_size, kernel_size=1, stride=1, padding=0),
        nn.BatchNorm3d(in_size),
        nn.ReLU(inplace=True),
    )

    # initialise the blocks; the attention gates are skipped
    # (presumably GridAttentionBlock3D initialises itself — verify).
    for child in self.children():
        if 'GridAttentionBlock3D' in child.__class__.__name__:
            continue
        init_weights(child, init_type='kaiming')
def __init__(self, in_size, out_size, is_batchnorm, n=2, ks=3, stride=1, padding=1):
    """Build ``n`` stacked 2-D conv stages named ``conv1`` .. ``conv<n>``.

    Each stage is Conv2d -> [BatchNorm2d] -> ReLU. The first stage maps
    in_size -> out_size channels; every later stage maps out_size -> out_size.
    All stages share ``ks``, ``stride`` and ``padding``.
    """
    super(unetConv2, self).__init__()
    self.n = n
    self.ks = ks
    self.stride = stride
    self.padding = padding

    channels = in_size
    for idx in range(1, n + 1):
        layers = [nn.Conv2d(channels, out_size, ks, stride, padding)]
        if is_batchnorm:
            layers.append(nn.BatchNorm2d(out_size))
        layers.append(nn.ReLU(inplace=True))
        setattr(self, 'conv%d' % idx, nn.Sequential(*layers))
        channels = out_size

    # initialise the blocks
    for child in self.children():
        init_weights(child, init_type='kaiming')
def __init__(self, d_model, filters, mode="deconv", bn=True):
    """Build a skip-path decoder: one stage per entry of ``filters``.

    Stage i is Conv3d(3x3x3) -> [BatchNorm3d] -> ReLU -> ConvTranspose3d
    (kernel 2, stride 2, which doubles every spatial dimension). Channel widths
    go d_model -> filters[0] -> filters[1] -> ...

    Args:
        d_model: channels of the input feature map.
        filters: list of output channels, one per stage.
        mode: accepted for interface compatibility; not used here.
        bn: insert a BatchNorm3d after each Conv3d.
    """
    super(UNETRSkip, self).__init__()
    self.d_model = d_model
    self.filters = [d_model] + filters
    self.n_module = len(filters)
    # Previously hard-coded to True, so self.bn disagreed with the layers when bn=False.
    self.bn = bn

    stages = []
    for c_in, c_out in zip(self.filters[:-1], self.filters[1:]):
        layers = [nn.Conv3d(c_in, c_out, kernel_size=3, padding=1)]
        if bn:
            layers.append(nn.BatchNorm3d(c_out))
        layers.append(nn.ReLU(inplace=True))
        layers.append(nn.ConvTranspose3d(c_out, c_out, (2, 2, 2), stride=(2, 2, 2)))
        stages.append(nn.Sequential(*layers))
    self.module_list = nn.Sequential(*stages)

    # Walk all submodules. The only direct child is ``module_list`` (a Sequential),
    # so iterating self.children() never matched a conv and initialised nothing.
    for m in self.modules():
        if isinstance(m, (nn.Conv3d, nn.ConvTranspose3d)):
            init_weights(m, init_type='kaiming')
def __init__(self, in_size, out_size, is_batchnorm=True, depth=3):
    """Build a dense up-sampling block.

    Args:
        in_size: channels of the coarser input.
        out_size: channels of every convolution output.
        is_batchnorm: forwarded to the UnetConv3 block.
        depth: number of dense convolutions; values below 1 are clamped to 1.
    """
    super(UnetUp3_dense_CT, self).__init__()
    self.conv = UnetConv3(in_size + out_size, out_size, is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1))
    # align_corners=False is PyTorch's default for 'trilinear'; stated explicitly to
    # avoid the default-behaviour warning on older releases and to match
    # UnetUp3_CT_deformable.
    self.up = nn.Upsample(scale_factor=(2, 2, 2), mode='trilinear', align_corners=False)
    self.relu = nn.ReLU(inplace=False)
    self.is_batchnorm = is_batchnorm

    # Depth of the dense block (minimal depth is 1)
    depth = max(depth, 1)
    self.dense_depth = depth

    # Initial convolution operation on the input data
    self.init_conv = nn.Conv3d(in_size + out_size, out_size, kernel_size=(3,3,3), stride=(1,1,1), padding=(1,1,1))

    # Further convolutions: op_k takes k * out_size input channels (presumably the
    # concatenation of the previous k outputs — verify in forward).
    # An nn.ModuleDict (instead of a plain dict plus .cuda()) registers these convs as
    # submodules, so they are trained, saved in state_dict and moved by .to()/.cuda()
    # with the rest of the model, and construction no longer requires a GPU.
    # NOTE: checkpoints saved before this change lack the ops_dict.* keys; load them
    # with strict=False.
    self.ops_dict = nn.ModuleDict()
    for ii in range(depth - 1):
        name = "op_{}".format(ii + 1)
        self.ops_dict[name] = nn.Conv3d((ii + 1) * out_size, out_size, kernel_size=(3,3,3), stride=(1,1,1), padding=(1,1,1))

    self.batchnorm = nn.BatchNorm3d(out_size, track_running_stats=True)

    # initialise the blocks; UnetConv3 already initialises its own layers in its constructor
    for m in self.children():
        if m.__class__.__name__.find('UnetConv3') != -1:
            continue
        init_weights(m, init_type='kaiming')
def __init__(self, channels, groups):
    """Hold a GroupNorm (``groups`` groups over ``channels``) and a bias-free
    3x3x3 Conv3d that preserves the channel count."""
    super(ResidualInner, self).__init__()
    self.gn = nn.GroupNorm(groups, channels)
    self.conv = nn.Conv3d(channels, channels, 3, padding=1, bias=False)

    for layer in self.children():
        init_weights(layer, init_type='kaiming')
def __init__(self, feature_scale=4, n_classes=21, in_channels=3, is_batchnorm=True, n_convs=None):
    """Build SonoNet: five unetConv2 stages with max-pool layers and a 1x1 adaptation head.

    Channel widths are [64, 128, 256, 512] divided by ``feature_scale``.
    ``n_convs`` gives the number of convolutions in each of the five stages
    (default [2, 2, 3, 3, 3]).
    """
    super(sononet, self).__init__()
    self.in_channels = in_channels
    self.is_batchnorm = is_batchnorm
    self.feature_scale = feature_scale
    self.n_classes = n_classes

    scale = self.feature_scale
    filters = [int(width / scale) for width in (64, 128, 256, 512)]
    if n_convs is None:
        n_convs = [2, 2, 3, 3, 3]
    bn = self.is_batchnorm

    # downsampling
    self.conv1 = unetConv2(self.in_channels, filters[0], bn, n=n_convs[0])
    self.maxpool1 = nn.MaxPool2d(kernel_size=2)
    self.conv2 = unetConv2(filters[0], filters[1], bn, n=n_convs[1])
    self.maxpool2 = nn.MaxPool2d(kernel_size=2)
    self.conv3 = unetConv2(filters[1], filters[2], bn, n=n_convs[2])
    self.maxpool3 = nn.MaxPool2d(kernel_size=2)
    self.conv4 = unetConv2(filters[2], filters[3], bn, n=n_convs[3])
    self.maxpool4 = nn.MaxPool2d(kernel_size=2)
    self.conv5 = unetConv2(filters[3], filters[3], bn, n=n_convs[4])

    # adaptation layer
    self.conv5_p = conv2DBatchNormRelu(filters[3], filters[2], 1, 1, 0)
    self.conv6_p = conv2DBatchNorm(filters[2], self.n_classes, 1, 1, 0)

    # initialise weights of every conv and batch-norm layer
    for module in self.modules():
        if isinstance(module, (nn.Conv2d, nn.BatchNorm2d)):
            init_weights(module, init_type='kaiming')
def __init__(self, in_size, out_size, is_batchnorm=True):
    """Build an up-sampling stage: trilinear x2 upsampling plus a 3x3x3 UnetConv3.

    The conv block takes ``in_size + out_size`` channels (presumably the
    upsampled map concatenated with an out_size-channel skip — verify in
    forward) and produces ``out_size`` channels.
    """
    super(UnetUp3_CT, self).__init__()
    self.conv = UnetConv3(in_size + out_size, out_size, is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1))
    # align_corners=False is PyTorch's default for 'trilinear'; stated explicitly to
    # avoid the default-behaviour warning on older releases and to match
    # UnetUp3_CT_deformable.
    self.up = nn.Upsample(scale_factor=(2, 2, 2), mode='trilinear', align_corners=False)

    # initialise the blocks; UnetConv3 already initialises its own layers in its constructor
    for m in self.children():
        if m.__class__.__name__.find('UnetConv3') != -1:
            continue
        init_weights(m, init_type='kaiming')
def __init__(self, inChannels, outChannels, depth, upsample=True):
    """Build a decoder stage: reversible blocks at ``inChannels`` and, when
    ``upsample`` is true, a 1x1x1 Conv3d mapping inChannels -> outChannels."""
    super(DecoderModule, self).__init__()
    self.reversibleBlocks = makeReversibleComponent(inChannels, depth)
    self.upsample = upsample
    if upsample:
        self.conv = nn.Conv3d(inChannels, outChannels, 1)

    conv_children = (c for c in self.children() if isinstance(c, nn.Conv3d))
    for conv in conv_children:
        init_weights(conv, init_type='kaiming')
def __init__(self, in_size, out_size, init_kernel=3, init_padding=1, is_batchnorm=True):
    """Build a deformable up-sampling stage.

    It combines trilinear x2 upsampling with an UnetConv3_deformable block
    that takes ``in_size + out_size`` channels.
    """
    super(UnetUp3_CT_deformable, self).__init__()
    self.conv = UnetConv3_deformable(
        in_size + out_size, out_size, is_batchnorm,
        kernel_size=init_kernel, padding_size=init_padding,
    )
    self.up = nn.Upsample(scale_factor=(2, 2, 2), mode='trilinear', align_corners=False)

    # initialise the blocks; the deformable conv block is skipped
    # (presumably it initialises itself — verify).
    for child in self.children():
        if 'UnetConv3_deformable' not in child.__class__.__name__:
            init_weights(child, init_type='kaiming')
def __init__(self, filters, n_classes=2, in_channels=1, dim='2d', bn=True, up_mode='biline'):
    """Build a 2-D or 3-D U-Net with four pooling levels.

    ``dim`` ('2d' or '3d') selects the conv, up-sampling, pooling and final
    layer classes; any other value raises KeyError. ``filters`` must have at
    least five entries (encoder levels 1-4 plus the centre).
    """
    super(UNet, self).__init__()
    self.in_channels = in_channels
    self.dim = dim
    self.filters = filters

    layer_kinds = {
        '2d': (UNetConv2D, UnetUp2D, nn.MaxPool2d, nn.Conv2d),
        '3d': (UNetConv3D, UnetUp3D, nn.MaxPool3d, nn.Conv3d),
    }
    self.UNetConv, self.UNetUpLayer, self.maxpool, self.final_layer = layer_kinds[self.dim]

    # encoder
    self.conv1 = self.UNetConv(self.in_channels, filters[0], bn=bn)
    self.maxpool1 = self.maxpool(kernel_size=2)
    self.conv2 = self.UNetConv(filters[0], filters[1], bn=bn)
    self.maxpool2 = self.maxpool(kernel_size=2)
    self.conv3 = self.UNetConv(filters[1], filters[2], bn=bn)
    self.maxpool3 = self.maxpool(kernel_size=2)
    self.conv4 = self.UNetConv(filters[2], filters[3], bn=bn)
    self.maxpool4 = self.maxpool(kernel_size=2)
    self.center = self.UNetConv(filters[3], filters[4], bn=bn)

    # upsampling
    up_options = dict(bn=bn, up_mode=up_mode)
    self.up_concat4 = self.UNetUpLayer(filters[4], filters[3], **up_options)
    self.up_concat3 = self.UNetUpLayer(filters[3], filters[2], **up_options)
    self.up_concat2 = self.UNetUpLayer(filters[2], filters[1], **up_options)
    self.up_concat1 = self.UNetUpLayer(filters[1], filters[0], **up_options)

    # final conv (without any concat)
    self.final = self.final_layer(filters[0], n_classes, 1)

    # initialise every conv of the selected dimensionality
    conv_cls = self.final_layer
    for module in self.modules():
        if isinstance(module, conv_cls):
            init_weights(module, init_type='kaiming')
def __init__(self, feature_scale=4, n_classes=21, is_deconv=True, in_channels=3,
             nonlocal_mode='concatenation', attention_dsample=(2,2,2), is_batchnorm=True):
    """Build a 3-D U-Net with dense blocks, multi-gate attention and deep supervision.

    Channel widths are [64, 128, 256, 512, 1024] divided by ``feature_scale``.
    The four deep-supervision outputs (n_classes channels each) feed a final
    1x1x1 Conv3d taking n_classes * 4 channels.
    """
    super(unet_CT_dense_multi_att_dsv_3D, self).__init__()
    self.is_deconv = is_deconv
    self.in_channels = in_channels
    self.is_batchnorm = is_batchnorm
    self.feature_scale = feature_scale

    filters = [int(width / self.feature_scale) for width in (64, 128, 256, 512, 1024)]
    cube = dict(kernel_size=(3, 3, 3), padding_size=(1, 1, 1))
    bn = self.is_batchnorm

    # downsampling
    self.conv1 = DenseBlock3D(self.in_channels, filters[0], bn, **cube)
    self.maxpool1 = nn.MaxPool3d(kernel_size=(2, 2, 2))
    self.conv2 = DenseBlock3D(filters[0], filters[1], bn, **cube)
    self.maxpool2 = nn.MaxPool3d(kernel_size=(2, 2, 2))
    self.conv3 = DenseBlock3D(filters[1], filters[2], bn, **cube)
    self.maxpool3 = nn.MaxPool3d(kernel_size=(2, 2, 2))
    self.conv4 = DenseBlock3D(filters[2], filters[3], bn, **cube)
    self.maxpool4 = nn.MaxPool3d(kernel_size=(2, 2, 2))
    self.center = DenseBlock3D(filters[3], filters[4], bn, **cube)
    self.gating = UnetGridGatingSignal3(filters[4], filters[4], kernel_size=(1, 1, 1), is_batchnorm=bn)

    # attention blocks: level k attends over filters[k-1] channels gated by filters[k]
    def attention(level):
        return MultiAttentionBlock(in_size=filters[level - 1], gate_size=filters[level],
                                   inter_size=filters[level - 1], nonlocal_mode=nonlocal_mode,
                                   sub_sample_factor=attention_dsample)

    self.attentionblock2 = attention(2)
    self.attentionblock3 = attention(3)
    self.attentionblock4 = attention(4)

    # upsampling
    self.up_concat4 = UnetUp3_dense_CT(filters[4], filters[3], bn)
    self.up_concat3 = UnetUp3_dense_CT(filters[3], filters[2], bn)
    self.up_concat2 = UnetUp3_dense_CT(filters[2], filters[1], bn)
    self.up_concat1 = UnetUp3_dense_CT(filters[1], filters[0], bn)

    # deep supervision
    self.dsv4 = UnetDsv3(in_size=filters[3], out_size=n_classes, scale_factor=8)
    self.dsv3 = UnetDsv3(in_size=filters[2], out_size=n_classes, scale_factor=4)
    self.dsv2 = UnetDsv3(in_size=filters[1], out_size=n_classes, scale_factor=2)
    self.dsv1 = nn.Conv3d(in_channels=filters[0], out_channels=n_classes, kernel_size=1)

    # final conv (without any concat)
    self.final = nn.Conv3d(n_classes * 4, n_classes, 1)

    # initialise weights of every conv and batch-norm layer
    for module in self.modules():
        if isinstance(module, (nn.Conv3d, nn.BatchNorm3d)):
            init_weights(module, init_type='kaiming')
def __init__(self, n_classes, feature_scale=4, is_deconv=True, in_channels=3, is_batchnorm=True,
             nonlocal_mode='embedded_gaussian', nonlocal_sf=4):
    """Build a 3-D U-Net with non-local blocks after encoder levels 2 and 3.

    Pooling uses 2x2x1 kernels, so the third spatial axis is never halved.
    Channel widths are [64, 128, 256, 512, 1024] divided by ``feature_scale``.
    """
    super(unet_nonlocal_3D, self).__init__()
    self.is_deconv = is_deconv
    self.in_channels = in_channels
    self.is_batchnorm = is_batchnorm
    self.feature_scale = feature_scale

    filters = [int(width / self.feature_scale) for width in (64, 128, 256, 512, 1024)]
    bn = self.is_batchnorm

    def nonlocal_block(channels):
        # the non-local block's inner width is a quarter of its input width
        return NONLocalBlock3D(in_channels=channels, inter_channels=channels // 4,
                               sub_sample_factor=nonlocal_sf, mode=nonlocal_mode)

    # downsampling
    self.conv1 = UnetConv3(self.in_channels, filters[0], bn)
    self.maxpool1 = nn.MaxPool3d(kernel_size=(2, 2, 1))
    self.conv2 = UnetConv3(filters[0], filters[1], bn)
    self.nonlocal2 = nonlocal_block(filters[1])
    self.maxpool2 = nn.MaxPool3d(kernel_size=(2, 2, 1))
    self.conv3 = UnetConv3(filters[1], filters[2], bn)
    self.nonlocal3 = nonlocal_block(filters[2])
    self.maxpool3 = nn.MaxPool3d(kernel_size=(2, 2, 1))
    self.conv4 = UnetConv3(filters[2], filters[3], bn)
    self.maxpool4 = nn.MaxPool3d(kernel_size=(2, 2, 1))
    self.center = UnetConv3(filters[3], filters[4], bn)

    # upsampling (UnetUp3 keeps its own default for is_batchnorm here)
    self.up_concat4 = UnetUp3(filters[4], filters[3], self.is_deconv)
    self.up_concat3 = UnetUp3(filters[3], filters[2], self.is_deconv)
    self.up_concat2 = UnetUp3(filters[2], filters[1], self.is_deconv)
    self.up_concat1 = UnetUp3(filters[1], filters[0], self.is_deconv)

    # final conv (without any concat)
    self.final = nn.Conv3d(filters[0], n_classes, 1)

    # initialise weights of every conv and batch-norm layer
    for module in self.modules():
        if isinstance(module, (nn.Conv3d, nn.BatchNorm3d)):
            init_weights(module, init_type='kaiming')
def __init__(self, base_model, filters=512, n_classes=14):
    """Wrap ``base_model`` and add a 1x1x1 Conv3d head mapping ``filters``
    channels to ``n_classes``."""
    super(DebugCrossPatch3DTr, self).__init__()
    self.base_model = base_model
    self.final_conv = nn.Conv3d(filters, n_classes, 1)

    # NOTE(review): self.modules() also walks base_model, so every Conv3d inside the
    # wrapped model is passed to init_weights as well (presumably overwriting any
    # pretrained conv weights). Confirm that is intended.
    conv_layers = (m for m in self.modules() if isinstance(m, nn.Conv3d))
    for conv in conv_layers:
        init_weights(conv, init_type='kaiming')
def __init__(self, in_size, out_size, is_deconv):
    """Build an up-sampling stage.

    The up path is an x2 ``nn.Upsample`` (default mode) followed by an
    RRCNN block (in_size -> out_size). ``conv`` is a single unetConv2 without
    batch norm.
    """
    super(unetUp, self).__init__()
    self.conv = unetConv2(in_size, out_size, is_batchnorm=False, n=1)
    # NOTE(review): when is_deconv is False no ``self.up`` is created, so forward will
    # presumably fail with AttributeError in that configuration — confirm.
    if is_deconv:
        upsample = nn.Upsample(scale_factor=2)
        self.up = nn.Sequential(upsample, RRCNN_block(in_size, out_size))

    # initialise the blocks; unetConv2 already initialises its own layers
    for child in self.children():
        if 'unetConv2' in child.__class__.__name__:
            continue
        init_weights(child, init_type='kaiming')
def __init__(self, in_size, out_size, is_deconv):
    """Build an up-sampling stage.

    The up path is a learned ConvTranspose2d (in_size -> out_size, x2) when
    ``is_deconv`` is true, and x2 bilinear interpolation otherwise. ``conv``
    is a unetConv2 without batch norm.
    """
    super(unetUp, self).__init__()
    self.conv = unetConv2(in_size, out_size, False)
    self.up = (
        nn.ConvTranspose2d(in_size, out_size, kernel_size=4, stride=2, padding=1)
        if is_deconv
        else nn.UpsamplingBilinear2d(scale_factor=2)
    )

    # initialise the blocks; unetConv2 already initialises its own layers
    for child in self.children():
        if 'unetConv2' not in child.__class__.__name__:
            init_weights(child, init_type='kaiming')
def __init__(self, in_size, out_size, is_deconv, is_batchnorm=True):
    """Build an up-sampling stage that doubles only the first two spatial axes.

    With ``is_deconv``, a ConvTranspose3d (kernel 4x4x1, stride 2x2x1) maps
    in_size -> out_size channels, and the conv block takes ``in_size`` channels.
    Otherwise, trilinear interpolation keeps ``in_size`` channels, and the conv
    block takes ``in_size + out_size`` channels.
    """
    super(UnetUp3, self).__init__()
    if is_deconv:
        self.conv = UnetConv3(in_size, out_size, is_batchnorm)
        self.up = nn.ConvTranspose3d(in_size, out_size, kernel_size=(4,4,1), stride=(2,2,1), padding=(1,1,0))
    else:
        self.conv = UnetConv3(in_size+out_size, out_size, is_batchnorm)
        # align_corners=False is PyTorch's default for 'trilinear'; stated explicitly
        # to avoid the default-behaviour warning on older releases and to match
        # UnetUp3_CT_deformable.
        self.up = nn.Upsample(scale_factor=(2, 2, 1), mode='trilinear', align_corners=False)

    # initialise the blocks; UnetConv3 already initialises its own layers in its constructor
    for m in self.children():
        if m.__class__.__name__.find('UnetConv3') != -1:
            continue
        init_weights(m, init_type='kaiming')
def __init__(self, feature_scale=4, n_classes=21, is_deconv=True, in_channels=3,
             nonlocal_mode='concatenation', attention_dsample=(2,2,2), is_batchnorm=True):
    """Build a 3-D U-Net with grid attention gates on skip levels 2-4.

    A single gating signal (filters[4] -> filters[3] channels) is shared by
    all attention blocks. Pooling is 2x2x1. Channel widths are
    [64, 128, 256, 512, 1024] divided by ``feature_scale``.
    """
    super(unet_grid_attention_3D, self).__init__()
    self.is_deconv = is_deconv
    self.in_channels = in_channels
    self.is_batchnorm = is_batchnorm
    self.feature_scale = feature_scale

    filters = [int(width / self.feature_scale) for width in (64, 128, 256, 512, 1024)]
    bn = self.is_batchnorm

    # downsampling
    self.conv1 = UnetConv3(self.in_channels, filters[0], bn)
    self.maxpool1 = nn.MaxPool3d(kernel_size=(2, 2, 1))
    self.conv2 = UnetConv3(filters[0], filters[1], bn)
    self.maxpool2 = nn.MaxPool3d(kernel_size=(2, 2, 1))
    self.conv3 = UnetConv3(filters[1], filters[2], bn)
    self.maxpool3 = nn.MaxPool3d(kernel_size=(2, 2, 1))
    self.conv4 = UnetConv3(filters[2], filters[3], bn)
    self.maxpool4 = nn.MaxPool3d(kernel_size=(2, 2, 1))
    self.center = UnetConv3(filters[3], filters[4], bn)
    self.gating = UnetGridGatingSignal3(filters[4], filters[3], kernel_size=(1, 1, 1), is_batchnorm=bn)

    # attention blocks: each gates its own level with the filters[3]-channel signal
    def attention(channels):
        return GridAttentionBlock3D(in_channels=channels, gating_channels=filters[3],
                                    inter_channels=channels, sub_sample_factor=attention_dsample,
                                    mode=nonlocal_mode)

    self.attentionblock2 = attention(filters[1])
    self.attentionblock3 = attention(filters[2])
    self.attentionblock4 = attention(filters[3])

    # upsampling
    self.up_concat4 = UnetUp3(filters[4], filters[3], self.is_deconv, bn)
    self.up_concat3 = UnetUp3(filters[3], filters[2], self.is_deconv, bn)
    self.up_concat2 = UnetUp3(filters[2], filters[1], self.is_deconv, bn)
    self.up_concat1 = UnetUp3(filters[1], filters[0], self.is_deconv, bn)

    # final conv (without any concat)
    self.final = nn.Conv3d(filters[0], n_classes, 1)

    # initialise weights of every conv and batch-norm layer
    for module in self.modules():
        if isinstance(module, (nn.Conv3d, nn.BatchNorm3d)):
            init_weights(module, init_type='kaiming')
def __init__(self, filters=None, patch_size=None, d_model=256, n_classes=14, in_channels=1,
             n_cheads=2, n_sheads=8, bn=True, up_mode='deconv', n_strans=6, do_cross=False):
    """Build the model: a CNN + transformer encoder, cross-attention and a 3-D U-Net-style decoder.

    Args:
        filters: channel widths of the five levels; defaults to [16, 32, 64, 128, 256].
        patch_size: patch size per axis; defaults to [1, 1, 1].
        d_model: transformer width.
        n_classes: output channels of the final and deep-supervision heads.
        in_channels: channels of the input volume.
        n_cheads: heads of the cross-attention block.
        n_sheads: heads of the self-attention encoder.
        bn: batch norm in the conv blocks.
        up_mode: up-sampling mode forwarded to UnetUp3D.
        n_strans: number of self-attention layers.
        do_cross: stored flag (its use is outside this constructor).
    """
    super(CrossPatch3DTr, self).__init__()
    # Fresh per-instance lists: the former list defaults were shared between
    # instances and stored on ``self``.
    if filters is None:
        filters = [16, 32, 64, 128, 256]
    if patch_size is None:
        patch_size = [1, 1, 1]

    self.PE = None
    self.in_channels = in_channels
    self.filters = filters
    self.n_sheads = n_sheads
    self.d_model = d_model
    self.patch_size = patch_size
    self.do_cross = do_cross
    # Decoder configuration, stored because reinit_decoder reads self.n_classes,
    # self.bn and self.up_mode (previously never set, so it raised AttributeError).
    self.n_classes = n_classes
    self.bn = bn
    self.up_mode = up_mode

    # CNN + Trans encoder
    self.encoder = SelfTransEncoder(filters=filters, patch_size=patch_size, d_model=d_model,
                                    in_channels=in_channels, n_sheads=n_sheads, bn=bn, n_strans=n_strans)

    # Transformer for cross attention
    self.p_enc_3d = PositionalEncodingPermute3D(filters[-1])
    self.cross_trans = CrossAttention(self.d_model, n_cheads)

    # CNN decoder
    self.before_d_model = filters[4] * np.prod(self.patch_size)

    ## Decode like 3D UNet
    self.up_concat4 = UnetUp3D(filters[4], filters[3], bn=bn, up_mode=up_mode)
    self.up_concat3 = UnetUp3D(filters[3], filters[2], bn=bn, up_mode=up_mode)
    self.up_concat2 = UnetUp3D(filters[2], filters[1], bn=bn, up_mode=up_mode)
    self.up_concat1 = UnetUp3D(filters[1], filters[0], bn=bn, up_mode=up_mode)
    self.final_conv = nn.Conv3d(filters[0], n_classes, 1)

    # Deep Supervision
    self.ds_cv1 = nn.Conv3d(filters[3], n_classes, 1)
    self.ds_cv2 = nn.Conv3d(filters[2], n_classes, 1)
    self.ds_cv3 = nn.Conv3d(filters[1], n_classes, 1)

    # initialise weights
    for m in self.modules():
        if isinstance(m, nn.Conv3d):
            init_weights(m, init_type='kaiming')
def __init__(self, feature_scale=4, n_classes=21, is_deconv=False, in_channels=3, is_batchnorm=True):
    """Build CloudSegNet: three single-conv encoder stages with 2x2 max-pool layers,
    three cloudSegNetUp stages and a 5x5 output convolution.

    Channel widths are fixed at [16, 8, 8, 8, 8]; ``feature_scale`` is stored
    but not used to scale them.
    """
    super(cloudSegNet, self).__init__()
    self.is_deconv = is_deconv
    self.in_channels = in_channels
    self.is_batchnorm = is_batchnorm
    self.feature_scale = feature_scale

    filters = [16, 8, 8, 8, 8]
    bn = self.is_batchnorm
    deconv = self.is_deconv

    # downsampling
    self.down1 = unetConv2(self.in_channels, filters[0], bn, n=1)
    self.maxpool1 = nn.MaxPool2d(kernel_size=2)
    self.down2 = unetConv2(filters[0], filters[1], bn, n=1)
    self.maxpool2 = nn.MaxPool2d(kernel_size=2)
    self.down3 = unetConv2(filters[1], filters[2], bn, n=1)
    self.maxpool3 = nn.MaxPool2d(kernel_size=2)

    # upsampling
    self.up_concat3 = cloudSegNetUp(filters[3], filters[2], deconv, n=1)
    self.up_concat2 = cloudSegNetUp(filters[2], filters[1], deconv, n=1)
    self.up_concat1 = cloudSegNetUp(filters[1], filters[0], deconv, n=1)

    # final conv (without any concat); padding 2 keeps the spatial size for a 5x5 kernel
    self.final = nn.Conv2d(filters[0], n_classes, kernel_size=5, padding=2)

    # initialise weights of every conv and batch-norm layer
    for module in self.modules():
        if isinstance(module, (nn.Conv2d, nn.BatchNorm2d)):
            init_weights(module, init_type='kaiming')
def reinit_decoder(self):
    """Rebuild the decoder with fresh weights.

    This covers the up-sampling path, the final conv and the deep-supervision
    heads. Other submodules (encoder, attention) are left untouched. The new
    modules are placed on the device of the current ``final_conv``.
    """
    # Read configuration before anything is deleted. n_classes falls back to the
    # current head's width; bn/up_mode fall back to the constructor defaults.
    # NOTE(review): the True / 'deconv' fallbacks only apply to instances that never
    # stored self.bn / self.up_mode — confirm they match how the model was built.
    n_classes = getattr(self, 'n_classes', self.final_conv.out_channels)
    bn = getattr(self, 'bn', True)
    up_mode = getattr(self, 'up_mode', 'deconv')
    device = self.final_conv.weight.device
    f = self.filters

    ## Decode like 3D UNet
    del self.up_concat4, self.up_concat3, self.up_concat2, self.up_concat1
    self.up_concat4 = UnetUp3D(f[4], f[3], bn=bn, up_mode=up_mode)
    self.up_concat3 = UnetUp3D(f[3], f[2], bn=bn, up_mode=up_mode)
    self.up_concat2 = UnetUp3D(f[2], f[1], bn=bn, up_mode=up_mode)
    self.up_concat1 = UnetUp3D(f[1], f[0], bn=bn, up_mode=up_mode)
    del self.final_conv
    self.final_conv = nn.Conv3d(f[0], n_classes, 1)

    # Deep Supervision
    del self.ds_cv1, self.ds_cv2, self.ds_cv3
    self.ds_cv1 = nn.Conv3d(f[3], n_classes, 1)
    self.ds_cv2 = nn.Conv3d(f[2], n_classes, 1)
    self.ds_cv3 = nn.Conv3d(f[1], n_classes, 1)

    new_modules = (self.up_concat4, self.up_concat3, self.up_concat2, self.up_concat1,
                   self.final_conv, self.ds_cv1, self.ds_cv2, self.ds_cv3)
    for module in new_modules:
        # Kaiming-initialise every Conv3d in the rebuilt decoder (previously only the
        # heads), consistent with the constructor's init loop over all nn.Conv3d.
        for m in module.modules():
            if isinstance(m, nn.Conv3d):
                init_weights(m, init_type='kaiming')
        # Fresh modules are created on CPU; keep them with the rest of the model.
        module.to(device)
def __init__(self, feature_scale=4, n_classes=21, is_deconv=True, in_channels=3, is_batchnorm=True):
    """Build a 3-D U-Net with four 2x2x1 pooling levels.

    The third spatial axis is never halved. Channel widths are
    [64, 128, 256, 512, 1024] divided by ``feature_scale``.
    """
    super(unet_3D, self).__init__()
    self.is_deconv = is_deconv
    self.in_channels = in_channels
    self.is_batchnorm = is_batchnorm
    self.feature_scale = feature_scale

    filters = [int(width / self.feature_scale) for width in (64, 128, 256, 512, 1024)]
    bn = self.is_batchnorm

    # downsampling
    self.conv1 = UnetConv3(self.in_channels, filters[0], bn)
    self.maxpool1 = nn.MaxPool3d(kernel_size=(2, 2, 1))
    self.conv2 = UnetConv3(filters[0], filters[1], bn)
    self.maxpool2 = nn.MaxPool3d(kernel_size=(2, 2, 1))
    self.conv3 = UnetConv3(filters[1], filters[2], bn)
    self.maxpool3 = nn.MaxPool3d(kernel_size=(2, 2, 1))
    self.conv4 = UnetConv3(filters[2], filters[3], bn)
    self.maxpool4 = nn.MaxPool3d(kernel_size=(2, 2, 1))
    self.center = UnetConv3(filters[3], filters[4], bn)

    # upsampling
    self.up_concat4 = UnetUp3(filters[4], filters[3], self.is_deconv, bn)
    self.up_concat3 = UnetUp3(filters[3], filters[2], self.is_deconv, bn)
    self.up_concat2 = UnetUp3(filters[2], filters[1], self.is_deconv, bn)
    self.up_concat1 = UnetUp3(filters[1], filters[0], self.is_deconv, bn)

    # final conv (without any concat)
    self.final = nn.Conv3d(filters[0], n_classes, 1)

    # initialise weights of every conv and batch-norm layer
    for module in self.modules():
        if isinstance(module, (nn.Conv3d, nn.BatchNorm3d)):
            init_weights(module, init_type='kaiming')
def __init__(self, filters=None, use_trans=None, in_channels=1, n_sheads=8, bn=True, n_strans=6):
    """Build a six-stage CNN encoder with optional per-stage transformer encoders.

    For stage k (1..6), ``conv<k>`` is a UNetConv3D. When ``use_trans[k-1]`` is
    truthy, ``trans<k>`` is a SimpleTransEncoder(channels, n_sheads, n_strans).
    Stages 2..6 also get ``maxpool<k>``: kernel 2 for stages 2-5, and 2x2x1 for
    stage 6, which keeps the last spatial axis.

    Args:
        filters: output channels of the six stages; defaults to [16, 32, 64, 128, 256, 512].
        use_trans: per-stage flags; defaults to all six enabled.
        in_channels: channels of the input volume.
        n_sheads: forwarded to each SimpleTransEncoder.
        bn: batch norm in the conv stages.
        n_strans: forwarded to each SimpleTransEncoder.
    """
    super(SelfTransEncoder, self).__init__()
    # Fresh per-instance lists: the former list defaults were shared between
    # instances and stored on ``self``.
    if filters is None:
        filters = [16, 32, 64, 128, 256, 512]
    if use_trans is None:
        use_trans = [1, 1, 1, 1, 1, 1]
    self.in_channels = in_channels
    self.filters = filters
    self.n_sheads = n_sheads
    self.use_trans = use_trans

    # Modules are registered in the same order as the hand-written version:
    # conv1, trans1, maxpool2, conv2, trans2, ..., maxpool6, conv6, trans6.
    channels_in = self.in_channels
    for level in range(1, 7):
        if level > 1:
            kernel = (2, 2, 1) if level == 6 else 2
            setattr(self, 'maxpool%d' % level, nn.MaxPool3d(kernel_size=kernel))
        channels_out = filters[level - 1]
        setattr(self, 'conv%d' % level, UNetConv3D(channels_in, channels_out, bn=bn))
        if use_trans[level - 1]:
            setattr(self, 'trans%d' % level, SimpleTransEncoder(channels_out, n_sheads, n_strans))
        channels_in = channels_out

    # initialise weights
    for m in self.modules():
        if isinstance(m, nn.Conv3d):
            init_weights(m, init_type='kaiming')
def __init__(self, inChannels, outChannels, depth, downsample=True, groups=2):
    """Build an encoder stage.

    It holds an optional 1x1x1 Conv3d (inChannels -> outChannels), created
    when ``downsample`` is true, and reversible blocks built at ``outChannels``
    with ``groups``.
    """
    super(EncoderModule, self).__init__()
    self.downsample = downsample
    if downsample:
        self.conv = nn.Conv3d(inChannels, outChannels, 1)
    self.reversibleBlocks = makeReversibleComponent(outChannels, depth, groups)

    conv_children = [c for c in self.children() if isinstance(c, nn.Conv3d)]
    for conv in conv_children:
        init_weights(conv, init_type='kaiming')
def __init__(self, inc, outc=None, kernel_size=3, padding=1, bias=None):
    """Build the deformable 3-D convolution.

    ``conv_kernel`` uses ``stride=kernel_size`` (presumably because forward lays
    the sampled offsets out as non-overlapping kernel_size cells — verify).

    Args:
        inc: input channels.
        outc: output channels. It is required: the old default ``[]`` was not a
            valid channel count and always failed inside nn.Conv3d.
        kernel_size: kernel size, also used as the stride of ``conv_kernel``.
        padding: stored on the module; not used by this constructor.
        bias: passed to nn.Conv3d. Falsy values, including the default None,
            create the conv without a bias.

    Raises:
        TypeError: if ``outc`` is not given (the same exception type the old
            default produced).
    """
    super(DeformConv3D, self).__init__()
    if outc is None:
        raise TypeError("DeformConv3D requires 'outc' (number of output channels)")
    self.kernel_size = kernel_size
    self.padding = padding
    self.conv_kernel = nn.Conv3d(inc, outc, kernel_size=kernel_size, stride=kernel_size, bias=bias)
    # TODO: remove before uploading to GitHub

    # initialise the blocks
    for m in self.children():
        init_weights(m, init_type='kaiming')
def __init__(self, filters=None, patch_size=None, d_model=256, in_channels=1, n_sheads=8, bn=True, n_strans=6):
    """Build a five-stage CNN encoder followed by a transformer encoder.

    ``conv1``..``conv5`` are UNetConv3D stages; ``maxpool2``..``maxpool5`` use
    kernel 2. ``linear`` projects filters[4] * prod(patch_size) features to
    d_model. ``self_trans`` stacks ``n_strans`` TransformerEncoderLayers with
    ``n_sheads`` heads. ``last`` projects d_model back to
    filters[4] * prod(patch_size).

    Args:
        filters: output channels of the five stages; defaults to [16, 32, 64, 128, 256].
        patch_size: patch size per axis; defaults to [1, 1, 1].
        d_model: transformer width.
        in_channels: channels of the input volume.
        n_sheads: attention heads per transformer layer.
        bn: batch norm in the conv stages.
        n_strans: number of transformer layers.
    """
    super(SelfTransEncoder, self).__init__()
    # Fresh per-instance lists: the former list defaults were shared between
    # instances and stored on ``self``.
    if filters is None:
        filters = [16, 32, 64, 128, 256]
    if patch_size is None:
        patch_size = [1, 1, 1]
    self.in_channels = in_channels
    self.filters = filters
    self.n_sheads = n_sheads
    self.d_model = d_model
    self.patch_size = patch_size

    # CNN encoder: conv1, maxpool2, conv2, ..., maxpool5, conv5 (same registration order)
    self.conv1 = UNetConv3D(self.in_channels, filters[0], bn=bn)
    for level in range(2, 6):
        setattr(self, 'maxpool%d' % level, nn.MaxPool3d(kernel_size=2))
        setattr(self, 'conv%d' % level, UNetConv3D(filters[level - 2], filters[level - 1], bn=bn))

    # Transformer for self attention
    self.before_d_model = filters[4] * np.prod(self.patch_size)
    self.linear = nn.Linear(self.before_d_model, self.d_model)
    trans_layer = nn.TransformerEncoderLayer(d_model=self.d_model, nhead=self.n_sheads)
    self.self_trans = nn.TransformerEncoder(trans_layer, n_strans)

    # Feed Forward projection
    self.last = nn.Linear(self.d_model, self.before_d_model)

    # initialise weights
    for m in self.modules():
        if isinstance(m, nn.Conv3d):
            init_weights(m, init_type='kaiming')
def __init__(self, in_size, out_size, kernel_size=(1,1,1), is_batchnorm=True):
    """Build the grid gating signal.

    ``conv1`` is a Conv3d (in_size -> out_size, stride 1, no padding), an
    optional BatchNorm3d that tracks running stats, and a ReLU.
    """
    super(UnetGridGatingSignal3, self).__init__()
    stages = [nn.Conv3d(in_size, out_size, kernel_size, (1,1,1), (0,0,0))]
    if is_batchnorm:
        stages.append(nn.BatchNorm3d(out_size, track_running_stats=True))
    stages.append(nn.ReLU(inplace=True))
    self.conv1 = nn.Sequential(*stages)

    # initialise the blocks
    for child in self.children():
        init_weights(child, init_type='kaiming')