Example #1
    def __init__(self, num_classes, trunk=None, criterion=None):
        
        super(GSCNN, self).__init__()
        self.criterion = criterion
        self.num_classes = num_classes

        wide_resnet = wider_resnet38_a2(classes=1000, dilation=True)
        wide_resnet = torch.nn.DataParallel(wide_resnet)
        
        wide_resnet = wide_resnet.module
        self.mod1 = wide_resnet.mod1
        self.mod2 = wide_resnet.mod2
        self.mod3 = wide_resnet.mod3
        self.mod4 = wide_resnet.mod4
        self.mod5 = wide_resnet.mod5
        self.mod6 = wide_resnet.mod6
        self.mod7 = wide_resnet.mod7
        self.pool2 = wide_resnet.pool2
        self.pool3 = wide_resnet.pool3
        self.interpolate = F.interpolate
        del wide_resnet

        self.dsn1 = nn.Conv2d(64, 1, 1)
        self.dsn3 = nn.Conv2d(256, 1, 1)
        self.dsn4 = nn.Conv2d(512, 1, 1)
        self.dsn7 = nn.Conv2d(4096, 1, 1)

        self.res1 = Resnet.BasicBlock(64, 64, stride=1, downsample=None)
        self.d1 = nn.Conv2d(64, 32, 1)
        self.res2 = Resnet.BasicBlock(32, 32, stride=1, downsample=None)
        self.d2 = nn.Conv2d(32, 16, 1)
        self.res3 = Resnet.BasicBlock(16, 16, stride=1, downsample=None)
        self.d3 = nn.Conv2d(16, 8, 1)
        self.fuse = nn.Conv2d(8, 1, kernel_size=1, padding=0, bias=False)

        self.cw = nn.Conv2d(2, 1, kernel_size=1, padding=0, bias=False)

        self.gate1 = gsc.GatedSpatialConv2d(32, 32)
        self.gate2 = gsc.GatedSpatialConv2d(16, 16)
        self.gate3 = gsc.GatedSpatialConv2d(8, 8)
         
        self.aspp = _AtrousSpatialPyramidPoolingModule(4096, 256,
                                                       output_stride=8)

        self.bot_fine = nn.Conv2d(128, 48, kernel_size=1, bias=False)
        self.bot_aspp = nn.Conv2d(1280 + 256, 256, kernel_size=1, bias=False)

        self.final_seg = nn.Sequential(
            nn.Conv2d(256 + 48, 256, kernel_size=3, padding=1, bias=False),
            Norm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
            Norm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, num_classes, kernel_size=1, bias=False))

        self.sigmoid = nn.Sigmoid()
        initialize_weights(self.final_seg)
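A minimal construction sketch for the snippet above (a sketch only: it assumes GSCNN is importable from this codebase, and the criterion here is an illustrative stand-in; the GSCNN repository defines its own joint segmentation/edge loss, which is not shown in this example):

import torch.nn as nn

# 19 output classes, matching the Cityscapes-style head hard-coded in Example #5
net = GSCNN(num_classes=19, criterion=nn.CrossEntropyLoss(ignore_index=255))
print(sum(p.numel() for p in net.parameters()))  # quick parameter-count sanity check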
Example #2
    def __init__(self,
                 num_classes,
                 input_size,
                 trunk='WideResnet38',
                 criterion=None):

        super(DeepWV3Plus, self).__init__()
        self.criterion = criterion
        logging.info("Trunk: %s", trunk)
        wide_resnet = wider_resnet38_a2(classes=1000, dilation=True)
        # TODO: should this even be here?
        wide_resnet = torch.nn.DataParallel(wide_resnet)
        try:
            checkpoint = torch.load('weights/ResNet/wider_resnet38.pth.tar',
                                    map_location='cpu')
            wide_resnet.load_state_dict(checkpoint['state_dict'])
            del checkpoint
        except Exception:
            print(
                "=====================Could not load imagenet weights======================="
            )

        wide_resnet = wide_resnet.module

        self.mod1 = wide_resnet.mod1
        self.mod2 = wide_resnet.mod2
        self.mod3 = wide_resnet.mod3
        self.mod4 = wide_resnet.mod4
        self.mod5 = wide_resnet.mod5
        self.mod6 = wide_resnet.mod6
        self.mod7 = wide_resnet.mod7
        self.pool2 = wide_resnet.pool2
        self.pool3 = wide_resnet.pool3
        del wide_resnet

        self.aspp = _AtrousSpatialPyramidPoolingModule(4096,
                                                       256,
                                                       output_stride=8)

        self.bot_fine = nn.Conv2d(128, 48, kernel_size=1, bias=False)
        self.bot_aspp = nn.Conv2d(1280, 256, kernel_size=1, bias=False)

        self.final = nn.Sequential(
            nn.Conv2d(256 + 48, 256, kernel_size=3, padding=1, bias=False),
            Norm2d(256), nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
            Norm2d(256), nn.ReLU(inplace=True),
            nn.Conv2d(256, num_classes, kernel_size=1, bias=False))

        initialize_weights(self.final)
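Every example here wraps the freshly built backbone in torch.nn.DataParallel before calling load_state_dict and only then unwraps it via .module, presumably because the wider_resnet38.pth.tar checkpoint was saved from a DataParallel-wrapped model and its parameter keys therefore carry a "module." prefix; wrapping the new backbone the same way makes the key names line up. A hedged sketch of the equivalent key-stripping alternative, assuming that checkpoint layout:

import torch

backbone = wider_resnet38_a2(classes=1000, dilation=True)
checkpoint = torch.load('weights/ResNet/wider_resnet38.pth.tar', map_location='cpu')
# drop the "module." prefix that DataParallel adds to every parameter name
state_dict = {k.replace('module.', '', 1): v
              for k, v in checkpoint['state_dict'].items()}
backbone.load_state_dict(state_dict)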
Example #3
    def __init__(self, num_classes, trunk='WideResnet38', criterion=None):

        super(DeepWV3Plus, self).__init__()
        self.criterion = criterion
        logging.info("Trunk: %s", trunk)

        wide_resnet = wider_resnet38_a2(classes=1000, dilation=True)
        wide_resnet = torch.nn.DataParallel(wide_resnet)
        if criterion is not None:
            try:
                checkpoint = torch.load(
                    './pretrained_models/wider_resnet38.pth.tar',
                    map_location='cpu')
                wide_resnet.load_state_dict(checkpoint['state_dict'])
                del checkpoint
            except Exception as exc:
                print(
                    "Please download the ImageNet weights of WideResNet38 from our repo to ./pretrained_models/wider_resnet38.pth.tar."
                )
                raise RuntimeError(
                    "=====================Could not load ImageNet weights of WideResNet38 network.======================="
                ) from exc
        wide_resnet = wide_resnet.module

        self.mod1 = wide_resnet.mod1
        self.mod2 = wide_resnet.mod2
        self.mod3 = wide_resnet.mod3
        self.mod4 = wide_resnet.mod4
        self.mod5 = wide_resnet.mod5
        self.mod6 = wide_resnet.mod6
        self.mod7 = wide_resnet.mod7
        self.pool2 = wide_resnet.pool2
        self.pool3 = wide_resnet.pool3
        del wide_resnet

        self.aspp = _AtrousSpatialPyramidPoolingModule(4096,
                                                       256,
                                                       output_stride=8)

        self.bot_fine = nn.Conv2d(128, 48, kernel_size=1, bias=False)
        self.bot_aspp = nn.Conv2d(1280, 256, kernel_size=1, bias=False)

        self.final = nn.Sequential(
            nn.Conv2d(256 + 48, 256, kernel_size=3, padding=1, bias=False),
            Norm2d(256), nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
            Norm2d(256), nn.ReLU(inplace=True),
            nn.Conv2d(256, num_classes, kernel_size=1, bias=False))

        initialize_weights(self.final)
Example #4
    def __init__(self, num_classes, trunk='WideResnet38', criterion=None,
                 ngf=64, input_nc=1, output_nc=2):

        super(DeepWV3Plus, self).__init__()
        self.criterion = criterion
        self.MSEcriterion = torch.nn.MSELoss()
        logging.info("Trunk: %s", trunk)
        wide_resnet = wider_resnet38_a2(classes=1000, dilation=True)
        wide_resnet = torch.nn.DataParallel(wide_resnet)
        try:
            checkpoint = torch.load('/srv/beegfs02/scratch/language_vision/data/Sound_Event_Prediction/semantic-segmentation-master/pretrained_models/wider_resnet38.pth.tar', map_location='cpu')
            wide_resnet.load_state_dict(checkpoint['state_dict'])
            del checkpoint
        except Exception:
            print("=====================Could not load ImageNet weights=======================")
            print("Please download the ImageNet weights of WideResNet38 from our repo to ./pretrained_models.")

        #audio_unet = AudioNet_multitask(ngf=64,input_nc=2)
        #Acheckpoint = torch.load('/srv/beegfs02/scratch/language_vision/data/Sound_Event_Prediction/audio/audioSynthesis/checkpoints/synBi2Bi_16_25/3_audio.pth', map_location='cpu')
        #pretrained_dict = Acheckpoint
        #model_dict = audio_unet.state_dict()
        #pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and k!='audionet_upconvlayer1.0.weight' and k!='audionet_upconvlayer5.0.weight' and k!='audionet_upconvlayer5.0.bias' and k!='conv1x1.0.weight' and k!='conv1x1.0.bias' and k!='conv1x1.1.weight' and k!='conv1x1.1.bias' and k!='conv1x1.1.running_mean' and k!='conv1x1.1.running_var'}
        #model_dict.update(pretrained_dict) 
        #audio_unet.load_state_dict(model_dict)

        self.audionet_convlayer1 = unet_conv(input_nc, ngf)
        self.audionet_convlayer2 = unet_conv(ngf, ngf * 2)
        self.audionet_convlayer3 = unet_conv(ngf * 2, ngf * 4)
        self.audionet_convlayer4 = unet_conv(ngf * 4, ngf * 8)
        self.audionet_convlayer5 = unet_conv(ngf * 8, ngf * 8)
        self.audionet_upconvlayer1 = unet_upconv(1024, ngf * 8)  # concatenated audio-visual feature; note the upstream comment's 1296 (= 784 visual + 512 audio) does not match the 1024 input channels used here
        self.audionet_upconvlayer2 = unet_upconv(ngf * 8, ngf *4)
        self.audionet_upconvlayer3 = unet_upconv(ngf * 4, ngf * 2)
        self.audionet_upconvlayer4 = unet_upconv(ngf * 2, ngf)
        self.audionet_upconvlayer5 = unet_upconv(ngf, output_nc, True)  # outermost layer uses a sigmoid to bound the mask
        self.conv1x1 = create_conv(4096, 2, 1, 0)

        wide_resnet = wide_resnet.module
        #self.unet= audio_unet
        #print(wide_resnet)
        '''
        self.mod1 = wide_resnet.mod1
        self.mod2 = wide_resnet.mod2
        self.mod3 = wide_resnet.mod3
        self.mod4 = wide_resnet.mod4
        self.mod5 = wide_resnet.mod5
        self.mod6 = wide_resnet.mod6
        self.mod7 = wide_resnet.mod7
        self.pool2 = wide_resnet.pool2
        self.pool3 = wide_resnet.pool3
        '''
        del wide_resnet
        
        self.aspp = _AtrousSpatialPyramidPoolingModule(512, 64,
                                                       output_stride=8)
        self.depthaspp = _AtrousSpatialPyramidPoolingModule(512, 64,
                                                            output_stride=8)

        self.bot_aud1 = nn.Conv2d(512, 256, kernel_size=1, bias=False)
        self.bot_multiaud = nn.Conv2d(512, 512, kernel_size=1, bias=False)
        self.bot_fine = nn.Conv2d(128, 48, kernel_size=1, bias=False)
        self.bot_aspp = nn.Conv2d(320, 256, kernel_size=1, bias=False)
        self.bot_depthaspp = nn.Conv2d(320, 128, kernel_size=1, bias=False)

        self.final = nn.Sequential(
            nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
            Norm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
            Norm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, num_classes, kernel_size=1, bias=False))
        self.final_depth = nn.Sequential(
            nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=False),
            Norm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=False),
            Norm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 1, kernel_size=1, bias=False))

        initialize_weights(self.final)
        initialize_weights(self.bot_aud1)
        initialize_weights(self.bot_multiaud)
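The audio branch above relies on unet_conv, unet_upconv, and create_conv helpers defined elsewhere in that codebase. As rough orientation only, a common shape for such helpers in audio-visual U-Net code looks like the sketch below; this is an assumption about their typical structure, not the definitions this example actually imports:

import torch.nn as nn

def unet_conv(input_nc, output_nc, norm_layer=nn.BatchNorm2d):
    # 4x4 stride-2 convolution: halves the spatial resolution of the spectrogram features
    return nn.Sequential(
        nn.Conv2d(input_nc, output_nc, kernel_size=4, stride=2, padding=1),
        norm_layer(output_nc),
        nn.LeakyReLU(0.2, inplace=True))

def unet_upconv(input_nc, output_nc, outermost=False, norm_layer=nn.BatchNorm2d):
    # 4x4 stride-2 transposed convolution: doubles the spatial resolution
    upconv = nn.ConvTranspose2d(input_nc, output_nc, kernel_size=4, stride=2, padding=1)
    if outermost:
        # the outermost layer is bounded with a sigmoid, matching the inline comment above
        return nn.Sequential(upconv, nn.Sigmoid())
    return nn.Sequential(upconv, norm_layer(output_nc), nn.ReLU(inplace=True))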
Example #5
    def __init__(self, trunk='WideResnet38', criterion=None, criterion2=None):

        super(DeepWV3Plus, self).__init__()
        self.criterion = criterion
        self.criterion2 = criterion2
        logging.info("Trunk: %s", trunk)

        tasks = ['semantic', 'traversability']
        wide_resnet = wider_resnet38_a2(classes=1000,
                                        dilation=True,
                                        tasks=tasks)
        wide_resnet = torch.nn.DataParallel(wide_resnet)
        if criterion is not None:
            try:
                checkpoint = torch.load(
                    './pretrained_models/wider_resnet38.pth.tar',
                    map_location='cpu')
                #                wide_resnet.load_state_dict(checkpoint['state_dict'])

                net_state_dict = wide_resnet.state_dict()
                loaded_dict = checkpoint['state_dict']
                new_loaded_dict = {}
                for k in net_state_dict:
                    if (k in loaded_dict
                            and net_state_dict[k].size() == loaded_dict[k].size()):
                        new_loaded_dict[k] = loaded_dict[k]
                    else:
                        logging.info("Skipped loading parameter %s", k)
                net_state_dict.update(new_loaded_dict)
                wide_resnet.load_state_dict(net_state_dict)

                del checkpoint
            except Exception as exc:
                print(
                    "Please download the ImageNet weights of WideResNet38 from our repo to ./pretrained_models/wider_resnet38.pth.tar."
                )
                raise RuntimeError(
                    "=====================Could not load ImageNet weights of WideResNet38 network.======================="
                ) from exc
        wide_resnet = wide_resnet.module

        self.task_weights = torch.nn.Parameter(
            torch.zeros(2, requires_grad=True))

        self.mod1 = wide_resnet.mod1
        self.mod2 = wide_resnet.mod2
        self.mod3 = wide_resnet.mod3
        self.mod4 = wide_resnet.mod4
        self.mod5 = wide_resnet.mod5
        self.mod6 = wide_resnet.mod6
        self.mod7 = wide_resnet.mod7
        self.pool2 = wide_resnet.pool2
        self.pool3 = wide_resnet.pool3
        del wide_resnet

        self.aspp = _AtrousSpatialPyramidPoolingModule(4096,
                                                       256,
                                                       output_stride=8)

        self.bot_fine = nn.Conv2d(128, 48, kernel_size=1, bias=False)
        self.bot_aspp = nn.Conv2d(1280, 256, kernel_size=1, bias=False)

        self.final = nn.Sequential(
            nn.Conv2d(256 + 48, 256, kernel_size=3, padding=1, bias=False),
            Norm2d(256), nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
            Norm2d(256), nn.ReLU(inplace=True),
            nn.Conv2d(256, 19, kernel_size=1, bias=False))

        self.aspp2 = _AtrousSpatialPyramidPoolingModule(4096,
                                                        256,
                                                        output_stride=8)

        self.bot_fine2 = nn.Conv2d(128, 48, kernel_size=1, bias=False)
        self.bot_aspp2 = nn.Conv2d(1280, 256, kernel_size=1, bias=False)

        self.final2 = nn.Sequential(
            nn.Conv2d(256 + 48, 256, kernel_size=3, padding=1, bias=False),
            Norm2d(256), nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
            Norm2d(256), nn.ReLU(inplace=True),
            nn.Conv2d(256, 2, kernel_size=1, bias=False))
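Example #5 does not load the checkpoint verbatim, presumably because the task-aware backbone (tasks=['semantic', 'traversability']) has parameters that differ from the ImageNet checkpoint; instead it copies only the entries whose names and shapes match. The same filtering step as a self-contained helper (a sketch; the function name and return value are illustrative):

import torch

def load_matching_weights(model, checkpoint_path):
    """Copy checkpoint parameters into model wherever name and shape agree."""
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    loaded_dict = checkpoint['state_dict']
    own_state = model.state_dict()
    # keep only entries that exist in the model with an identical shape
    matched = {k: v for k, v in loaded_dict.items()
               if k in own_state and own_state[k].size() == v.size()}
    own_state.update(matched)
    model.load_state_dict(own_state)
    return sorted(matched)  # names of the parameters that were actually loaded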
Example #6
    def __init__(self, num_classes, trunk=None, criterion=None):

        super(GSCNN, self).__init__()
        self.criterion = criterion
        self.num_classes = num_classes

        wide_resnet = wider_resnet38_a2(classes=1000, dilation=True)
        wide_resnet = torch.nn.DataParallel(wide_resnet)

        try:
            checkpoint = torch.load(
                './network/pretrained_models/wider_resnet38.pth.tar',
                map_location='cpu')
            wide_resnet.load_state_dict(checkpoint['state_dict'])
            del checkpoint
        except Exception as exc:
            print(
                "Please download the ImageNet weights of WideResNet38 from our repo to ./pretrained_models/wider_resnet38.pth.tar."
            )
            raise RuntimeError(
                "=====================Could not load ImageNet weights of WideResNet38 network.======================="
            ) from exc

        wide_resnet = wide_resnet.module
        self.mod1 = wide_resnet.mod1
        self.mod2 = wide_resnet.mod2
        self.mod3 = wide_resnet.mod3
        self.mod4 = wide_resnet.mod4
        self.mod5 = wide_resnet.mod5
        self.mod6 = wide_resnet.mod6
        self.mod7 = wide_resnet.mod7
        self.pool2 = wide_resnet.pool2
        self.pool3 = wide_resnet.pool3
        self.interpolate = F.interpolate
        del wide_resnet

        self.dsn1 = nn.Conv2d(64, 1, 1)
        self.dsn3 = nn.Conv2d(256, 1, 1)
        self.dsn4 = nn.Conv2d(512, 1, 1)
        self.dsn7 = nn.Conv2d(4096, 1, 1)

        self.res1 = Resnet.BasicBlock(64, 64, stride=1, downsample=None)
        self.d1 = nn.Conv2d(64, 32, 1)
        self.res2 = Resnet.BasicBlock(32, 32, stride=1, downsample=None)
        self.d2 = nn.Conv2d(32, 16, 1)
        self.res3 = Resnet.BasicBlock(16, 16, stride=1, downsample=None)
        self.d3 = nn.Conv2d(16, 8, 1)
        self.fuse = nn.Conv2d(8, 1, kernel_size=1, padding=0, bias=False)

        self.cw = nn.Conv2d(2, 1, kernel_size=1, padding=0, bias=False)

        self.gate1 = gsc.GatedSpatialConv2d(32, 32)
        self.gate2 = gsc.GatedSpatialConv2d(16, 16)
        self.gate3 = gsc.GatedSpatialConv2d(8, 8)

        self.aspp = _AtrousSpatialPyramidPoolingModule(4096,
                                                       256,
                                                       output_stride=8)

        self.bot_fine = nn.Conv2d(128, 48, kernel_size=1, bias=False)
        self.bot_aspp = nn.Conv2d(1280 + 256, 256, kernel_size=1, bias=False)

        self.final_seg = nn.Sequential(
            nn.Conv2d(256 + 48, 256, kernel_size=3, padding=1, bias=False),
            Norm2d(256), nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
            Norm2d(256), nn.ReLU(inplace=True),
            nn.Conv2d(256, num_classes, kernel_size=1, bias=False))

        self.sigmoid = nn.Sigmoid()
        initialize_weights(self.final_seg)