def __init__(self, in_dim, reduction_dim=256, output_stride=16, rates=(6, 12, 18)):
    """Atrous Spatial Pyramid Pooling: parallel 1x1 conv, dilated 3x3 convs,
    and an image-level pooling branch, each reducing to `reduction_dim` channels.

    Args:
        in_dim: number of input channels.
        reduction_dim: output channels of every branch.
        output_stride: trunk output stride; 8 doubles the dilation rates
            (higher-resolution features need larger dilations for the same
            receptive field), 16 keeps them as given.
        rates: base dilation rates for the 3x3 branches.

    Raises:
        ValueError: if output_stride is neither 8 nor 16.
    """
    super(_AtrousSpatialPyramidPoolingModule, self).__init__()

    if output_stride == 8:
        rates = [2 * r for r in rates]
    elif output_stride == 16:
        pass
    else:
        # BUG FIX: the original raised a bare string, which is a TypeError
        # in Python 3 ("exceptions must derive from BaseException").
        raise ValueError('output stride of {} not supported'.format(output_stride))

    self.features = []
    # 1x1 branch
    self.features.append(
        nn.Sequential(nn.Conv2d(in_dim, reduction_dim, kernel_size=1, bias=False),
                      Norm2d(reduction_dim), nn.ReLU(inplace=True)))
    # dilated 3x3 branches; padding == dilation keeps spatial size unchanged
    for r in rates:
        self.features.append(nn.Sequential(
            nn.Conv2d(in_dim, reduction_dim, kernel_size=3,
                      dilation=r, padding=r, bias=False),
            Norm2d(reduction_dim),
            nn.ReLU(inplace=True)
        ))
    self.features = torch.nn.ModuleList(self.features)

    # image-level features: global average pool then 1x1 projection
    self.img_pooling = nn.AdaptiveAvgPool2d(1)
    self.img_conv = nn.Sequential(
        nn.Conv2d(in_dim, reduction_dim, kernel_size=1, bias=False),
        Norm2d(reduction_dim), nn.ReLU(inplace=True))
def _make_layer(self, block, planes, blocks, stride=1):
    """Build one residual stage of `blocks` blocks.

    The first block may change stride/width and therefore gets a 1x1
    projection shortcut; the remaining blocks keep identity shortcuts.
    Updates `self.inplanes` as a side effect.
    """
    out_channels = planes * block.expansion

    shortcut = None
    if stride != 1 or self.inplanes != out_channels:
        shortcut = nn.Sequential(
            nn.Conv2d(self.inplanes, out_channels,
                      kernel_size=1, stride=stride, bias=False),
            Norm2d(out_channels),
        )

    stage = [block(self.inplanes, planes, stride, downsample=shortcut,
                   stype='stage', baseWidth=self.baseWidth, scale=self.scale)]
    self.inplanes = out_channels
    for _ in range(blocks - 1):
        stage.append(block(self.inplanes, planes,
                           baseWidth=self.baseWidth, scale=self.scale))
    return nn.Sequential(*stage)
def __init__(self, in_planes, dim=64, maxpool_size=8, avgpool_size=8, matcher_kernel_size=3, edge_points=64):
    """Point-flow module using adaptive max/avg pooling and a point matcher.

    Projects high- and low-level feature maps down to `dim` channels,
    pools them to fixed grids, and produces a single-channel edge map.
    """
    super(PointFlowModuleWithMaxAvgpool, self).__init__()
    # Scalar configuration.
    self.dim = dim
    self.maxpool_size = maxpool_size
    self.avgpool_size = avgpool_size
    self.edge_points = edge_points

    self.point_matcher = PointMatcher(dim, matcher_kernel_size)
    # 1x1 projections for the high- and low-level streams.
    self.down_h = nn.Conv2d(in_planes, dim, 1)
    self.down_l = nn.Conv2d(in_planes, dim, 1)
    self.softmax = nn.Softmax(dim=-1)

    # Max pooling returns indices so pooled responses can be mapped back
    # to their spatial locations.
    self.max_pool = nn.AdaptiveMaxPool2d((maxpool_size, maxpool_size),
                                         return_indices=True)
    self.avg_pool = nn.AdaptiveAvgPool2d((avgpool_size, avgpool_size))

    # Edge head: 3x3 conv -> norm -> relu -> 3x3 conv to one channel.
    self.edge_final = nn.Sequential(
        nn.Conv2d(in_channels=in_planes, out_channels=in_planes,
                  kernel_size=3, padding=1, bias=False),
        Norm2d(in_planes),
        nn.ReLU(),
        nn.Conv2d(in_channels=in_planes, out_channels=1,
                  kernel_size=3, padding=1, bias=False))
def __init__(self, num_classes, trunk='WideResnet38', criterion=None):
    """DeepLabV3+-style head on a WideResNet38 trunk.

    Args:
        num_classes: number of output segmentation classes.
        trunk: trunk name (logged only; the WideResNet38 trunk is always built).
        criterion: training loss; when not None, ImageNet weights are loaded
            for the trunk (training mode), otherwise weights stay random.

    Raises:
        RuntimeError: if the pretrained trunk weights cannot be loaded.
    """
    super(DeepWV3Plus, self).__init__()
    self.criterion = criterion
    logging.info("Trunk: %s", trunk)

    wide_resnet = wider_resnet38_a2(classes=1000, dilation=True)
    # Checkpoint was saved from a DataParallel model ("module." prefixes),
    # so wrap before load_state_dict, then unwrap below.
    wide_resnet = torch.nn.DataParallel(wide_resnet)
    if criterion is not None:
        try:
            # NOTE(review): this cluster-specific path differs from the
            # './pretrained_models/...' path the message below refers to —
            # confirm which location is intended.
            checkpoint = torch.load('/mnt/lustre/share_data/lixiangtai/wider_resnet38_imagenet.pth',
                                    map_location='cpu')
            wide_resnet.load_state_dict(checkpoint)
            del checkpoint
        except Exception as err:
            # BUG FIX: narrowed the bare `except:` (which also swallowed
            # KeyboardInterrupt/SystemExit) and chained the original cause.
            print("Please download the ImageNet weights of WideResNet38 in our repo to ./pretrained_models/wider_resnet38.pth.tar.")
            raise RuntimeError("=====================Could not load ImageNet weights of WideResNet38 network.=======================") from err
    wide_resnet = wide_resnet.module

    self.mod1 = wide_resnet.mod1
    self.mod2 = wide_resnet.mod2
    self.mod3 = wide_resnet.mod3
    self.mod4 = wide_resnet.mod4
    self.mod5 = wide_resnet.mod5
    self.mod6 = wide_resnet.mod6
    self.mod7 = wide_resnet.mod7
    self.pool2 = wide_resnet.pool2
    self.pool3 = wide_resnet.pool3
    del wide_resnet

    self.aspp = _AtrousSpatialPyramidPoolingModule(4096, 256, output_stride=8)
    # Skip connection from low-level (128-ch) features, squeezed to 48 ch.
    self.bot_fine = nn.Conv2d(128, 48, kernel_size=1, bias=False)
    # ASPP concat output is 5 branches * 256 = 1280 channels.
    self.bot_aspp = nn.Conv2d(1280, 256, kernel_size=1, bias=False)

    self.final = nn.Sequential(
        nn.Conv2d(256 + 48, 256, kernel_size=3, padding=1, bias=False),
        Norm2d(256),
        nn.ReLU(inplace=True),
        nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
        Norm2d(256),
        nn.ReLU(inplace=True),
        nn.Conv2d(256, num_classes, kernel_size=1, bias=False))

    initialize_weights(self.final)
def _make_res_layer(self, block, planes, blocks, stride=1):
    """Stack `blocks` residual blocks into one stage.

    A projection shortcut (optionally with normalization, per `self.use_bn`)
    is added to the first block whenever stride or width changes.
    Updates `self.inplanes` as a side effect.
    """
    out_channels = planes * block.expansion

    shortcut = None
    if stride != 1 or self.inplanes != out_channels:
        proj = [nn.Conv2d(self.inplanes, out_channels,
                          kernel_size=1, stride=stride, bias=False)]
        if self.use_bn:
            proj.append(Norm2d(out_channels))
        shortcut = nn.Sequential(*proj)

    stage = [block(self.inplanes, planes, stride, shortcut)]
    self.inplanes = out_channels
    for _ in range(1, blocks):
        stage.append(block(self.inplanes, planes))
    return nn.Sequential(*stage)
def __init__(self, in_dim, middle_dim=256, node=32):
    """Squeeze-and-reason head: 1x1 channel reduction followed by a
    channel reasoning module with `node` graph nodes."""
    super(SRHead, self).__init__()
    self.down = nn.Sequential(
        nn.Conv2d(in_dim, middle_dim, kernel_size=1, bias=False),
        Norm2d(middle_dim),
        nn.ReLU(inplace=True),
    )
    self.sr = ChannelReasonModule(middle_dim, middle_dim, node_num=node)
def __init__(self, num_classes, trunk='seresnext-50', criterion=None, variant='D', skip='m1', skip_num=48):
    """PSPNet with decoupled body/edge heads on a deep-stem ResNet trunk.

    Args:
        num_classes: number of output segmentation classes.
        trunk: 'resnet-50-deep' or 'resnet-101-deep'.
        criterion: training loss (stored, not used here).
        variant: 'D' dilates layer3/4 (output stride 8), 'D16' dilates
            layer4 only (output stride 16); anything else keeps strides.
        skip: which low-level stage feeds the skip path ('m1' -> layer1
            256 ch, 'm2' -> layer2 512 ch).
        skip_num: channels of the squeezed skip features.

    Raises:
        ValueError: on an unknown trunk or skip setting.
    """
    super(PSPNet, self).__init__()
    self.criterion = criterion
    self.variant = variant
    self.skip = skip
    self.skip_num = skip_num

    if trunk == 'resnet-50-deep':
        resnet = Resnet_Deep.resnet50()
        resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
    elif trunk == 'resnet-101-deep':
        resnet = Resnet_Deep.resnet101()
        resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
    else:
        raise ValueError("Not a valid network arch")

    self.layer0 = resnet.layer0
    self.layer1, self.layer2, self.layer3, self.layer4 = \
        resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4

    # Convert striding to dilation so deeper stages keep spatial resolution.
    if self.variant == 'D':
        for n, m in self.layer3.named_modules():
            if 'conv2' in n:
                m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
            elif 'downsample.0' in n:
                m.stride = (1, 1)
        for n, m in self.layer4.named_modules():
            if 'conv2' in n:
                m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1)
            elif 'downsample.0' in n:
                m.stride = (1, 1)
    elif self.variant == 'D16':
        for n, m in self.layer4.named_modules():
            if 'conv2' in n:
                m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
            elif 'downsample.0' in n:
                m.stride = (1, 1)
    else:
        print("Not using Dilation ")

    self.ppm = PSPModule(2048, 256, norm_layer=Norm2d)

    if self.skip == 'm1':
        self.bot_fine = nn.Conv2d(256, self.skip_num, kernel_size=1, bias=False)
    elif self.skip == 'm2':
        self.bot_fine = nn.Conv2d(512, self.skip_num, kernel_size=1, bias=False)
    else:
        # CONSISTENCY FIX: use ValueError like the trunk check above instead
        # of the broad base Exception (still caught by existing handlers).
        raise ValueError('Not a valid skip')

    # body_edge module
    self.squeeze_body_edge = SqueezeBodyEdge(256, Norm2d)

    # fusion of the different edge parts
    self.edge_fusion = nn.Conv2d(256 + 48, 256, 1, bias=False)
    self.sigmoid_edge = nn.Sigmoid()
    self.edge_out = nn.Sequential(
        nn.Conv2d(256, 48, kernel_size=3, padding=1, bias=False),
        Norm2d(48),
        nn.ReLU(inplace=True),
        nn.Conv2d(48, 1, kernel_size=1, bias=False))

    # DSN (deep supervision) head for the seg body part
    self.dsn_seg_body = nn.Sequential(
        nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
        Norm2d(256),
        nn.ReLU(inplace=True),
        nn.Conv2d(256, num_classes, kernel_size=1, bias=False))

    # Final segmentation head over the fused (body + edge) 512-ch features.
    self.final_seg = nn.Sequential(
        nn.Conv2d(512, 256, kernel_size=3, padding=1, bias=False),
        Norm2d(256),
        nn.ReLU(inplace=True),
        nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
        Norm2d(256),
        nn.ReLU(inplace=True),
        nn.Conv2d(256, num_classes, kernel_size=1, bias=False))
def __init__(self, num_classes, trunk='WideResnet38', criterion=None):
    """Decoupled body/edge segmentation head on a WideResNet38 trunk.

    Args:
        num_classes: number of output segmentation classes.
        trunk: trunk name (logged only; the WideResNet38 trunk is always built).
        criterion: training loss; when not None, ImageNet weights are loaded
            for the trunk (training mode), otherwise weights stay random.

    Raises:
        RuntimeError: if the pretrained trunk weights cannot be loaded.
    """
    super(DeepWV3PlusDecoupleEdgeBody, self).__init__()
    self.criterion = criterion
    logging.info("Trunk: %s", trunk)

    wide_resnet = wider_resnet38_a2(classes=1000, dilation=True)
    # Checkpoint was saved from a DataParallel model ("module." prefixes),
    # so wrap before load_state_dict, then unwrap below.
    wide_resnet = torch.nn.DataParallel(wide_resnet)
    if criterion is not None:
        try:
            checkpoint = torch.load(
                './pretrained_models/wider_resnet38.pth.tar',
                map_location='cpu')
            wide_resnet.load_state_dict(checkpoint)
            del checkpoint
        except Exception as err:
            # BUG FIX: narrowed the bare `except:` (which also swallowed
            # KeyboardInterrupt/SystemExit) and chained the original cause.
            print(
                "Please download the ImageNet weights of WideResNet38 in our repo to ./pretrained_models/wider_resnet38.pth.tar."
            )
            raise RuntimeError(
                "=====================Could not load ImageNet weights of WideResNet38 network.======================="
            ) from err
    wide_resnet = wide_resnet.module

    self.mod1 = wide_resnet.mod1
    self.mod2 = wide_resnet.mod2
    self.mod3 = wide_resnet.mod3
    self.mod4 = wide_resnet.mod4
    self.mod5 = wide_resnet.mod5
    self.mod6 = wide_resnet.mod6
    self.mod7 = wide_resnet.mod7
    self.pool2 = wide_resnet.pool2
    self.pool3 = wide_resnet.pool3
    del wide_resnet

    self.aspp = _AtrousSpatialPyramidPoolingModule(4096, 256, output_stride=8)
    # ASPP concat output is 5 branches * 256 = 1280 channels.
    self.bot_aspp = nn.Conv2d(1280, 256, kernel_size=1, bias=False)
    # Skip connection from low-level (128-ch) features, squeezed to 48 ch.
    self.bot_fine = nn.Conv2d(128, 48, kernel_size=1, bias=False)

    # Lift a 1-channel edge map into a 256-channel feature.
    edge_dim = 256
    self.edge_conv = nn.Sequential(
        nn.Conv2d(1, edge_dim, kernel_size=1, bias=False),
        Norm2d(edge_dim), nn.ReLU(inplace=True))

    self.squeeze_body_edge = SqueezeBodyEdge(256, Norm2d)

    # fusion of the different edges
    self.edge_fusion = nn.Conv2d(256 + 48, 256, 1, bias=False)
    self.sigmoid_edge = nn.Sigmoid()
    self.edge_out = nn.Sequential(
        nn.Conv2d(256, 48, kernel_size=3, padding=1, bias=False),
        Norm2d(48),
        nn.ReLU(inplace=True),
        nn.Conv2d(48, 1, kernel_size=1, bias=False))

    # DSN (deep supervision) head for the seg body part.
    self.dsn_seg_body = nn.Sequential(
        nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
        Norm2d(256),
        nn.ReLU(inplace=True),
        nn.Conv2d(256, num_classes, kernel_size=1, bias=False))

    # Final segmentation head over the fused (body + edge) 512-ch features.
    self.final_seg = nn.Sequential(
        nn.Conv2d(512, 256, kernel_size=3, padding=1, bias=False),
        Norm2d(256),
        nn.ReLU(inplace=True),
        nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
        Norm2d(256),
        nn.ReLU(inplace=True),
        nn.Conv2d(256, num_classes, kernel_size=1, bias=False))

    initialize_weights(self.final_seg, self.dsn_seg_body)
def __init__(self, num_classes, trunk='seresnext-50', criterion=None, variant='D', skip='m1', skip_num=48):
    """FCN head with a low-level skip connection on a deep-stem ResNet trunk.

    Args:
        num_classes: number of output segmentation classes.
        trunk: 'resnet-50-deep' or 'resnet-101-deep'.
        criterion: training loss (stored, not used here).
        variant: 'D' dilates layer3/4 (output stride 8), 'D16' dilates
            layer4 only (output stride 16); anything else keeps strides.
        skip: which low-level stage feeds the skip path ('m1' -> layer1
            256 ch, 'm2' -> layer2 512 ch).
        skip_num: channels of the squeezed skip features.

    Raises:
        ValueError: on an unknown trunk or skip setting.
    """
    super(DeepFCN, self).__init__()
    self.criterion = criterion
    self.variant = variant
    self.skip = skip
    self.skip_num = skip_num

    if trunk == 'resnet-50-deep':
        resnet = Resnet_Deep.resnet50()
        resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
    elif trunk == 'resnet-101-deep':
        resnet = Resnet_Deep.resnet101()
        resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
    else:
        raise ValueError("Not a valid network arch")

    self.layer0 = resnet.layer0
    self.layer1, self.layer2, self.layer3, self.layer4 = \
        resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4

    # Convert striding to dilation so deeper stages keep spatial resolution.
    if self.variant == 'D':
        for n, m in self.layer3.named_modules():
            if 'conv2' in n:
                m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
            elif 'downsample.0' in n:
                m.stride = (1, 1)
        for n, m in self.layer4.named_modules():
            if 'conv2' in n:
                m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1)
            elif 'downsample.0' in n:
                m.stride = (1, 1)
    elif self.variant == 'D16':
        for n, m in self.layer4.named_modules():
            if 'conv2' in n:
                m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
            elif 'downsample.0' in n:
                m.stride = (1, 1)
    else:
        print("Not using Dilation ")

    self.fcn_head = nn.Sequential(
        nn.Conv2d(2048, 256, kernel_size=3, padding=1, bias=False),
        Norm2d(256),
        nn.ReLU(inplace=True),
    )

    if self.skip == 'm1':
        self.bot_fine = nn.Conv2d(256, self.skip_num, kernel_size=1, bias=False)
    elif self.skip == 'm2':
        self.bot_fine = nn.Conv2d(512, self.skip_num, kernel_size=1, bias=False)
    else:
        # CONSISTENCY FIX: use ValueError like the trunk check above instead
        # of the broad base Exception (still caught by existing handlers).
        raise ValueError('Not a valid skip')

    self.final = nn.Sequential(
        nn.Conv2d(256 + self.skip_num, 256, kernel_size=3, padding=1, bias=False),
        Norm2d(256),
        nn.ReLU(inplace=True),
        nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
        Norm2d(256),
        nn.ReLU(inplace=True),
        nn.Conv2d(256, num_classes, kernel_size=1, bias=False))

    initialize_weights(self.fcn_head)
    initialize_weights(self.bot_fine)
    initialize_weights(self.final)