Example #1
def main():
    args = get_arguments()
    print("=====> Configure dataset and model")
    configure_dataset_model(args)
    print(args)

    print("=====> Set GPU for training")
    if args.cuda:
        print("====> Use gpu id: '{}'".format(args.gpus))
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
        if not torch.cuda.is_available():
            raise Exception(
                "No GPU found or wrong GPU id; please run without --cuda")
    model = CoattentionNet(num_classes=args.num_classes)

    saved_state_dict = torch.load(args.restore_from,
                                  map_location=lambda storage, loc: storage)
    #print(saved_state_dict.keys())
    #model.load_state_dict({k.replace('pspmodule.',''):v for k,v in torch.load(args.restore_from)['state_dict'].items()})
    model.load_state_dict(convert_state_dict(saved_state_dict["model"]))
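    # Note (assumption): convert_state_dict typically just strips the 'module.'
    # prefix that nn.DataParallel adds to parameter names, e.g.
    #   {k.replace('module.', '', 1): v for k, v in state_dict.items()}
    # The exact helper used by this repository is not shown in this snippet.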

    model.eval()
    model.cuda()
    if args.dataset == 'voc12':
        testloader = data.DataLoader(VOCDataTestSet(args.data_dir,
                                                    args.data_list,
                                                    crop_size=(505, 505),
                                                    mean=args.img_mean),
                                     batch_size=1,
                                     shuffle=False,
                                     pin_memory=True)
        interp = nn.Upsample(size=(505, 505), mode='bilinear')
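        # On PyTorch >= 0.4.1, bilinear Upsample warns unless align_corners is set
        # explicitly; pass align_corners=True to reproduce the old behaviour.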
        voc_colorize = VOCColorize()

    elif args.dataset == 'cityscapes':
        testloader = data.DataLoader(
            CityscapesTestDataSet(args.data_dir,
                                  args.data_list,
                                  f_scale=args.f_scale,
                                  mean=args.img_mean),
            batch_size=1,
            shuffle=False,
            pin_memory=True
        )  # f_scale: resize the input image by this scale factor
        interp = nn.Upsample(size=(1024, 2048), mode='bilinear')  #size = (h,w)
        voc_colorize = VOCColorize()

    elif args.dataset == 'davis':  #for davis 2016
        db_test = db.PairwiseImg(
            train=False,
            inputRes=(473, 473),
            db_root_dir=args.data_dir,
            transform=None,
            seq_name=None,
            sample_range=args.sample_range
        )  #db_root_dir() --> '/path/to/DAVIS-2016' train path
        testloader = data.DataLoader(db_test,
                                     batch_size=1,
                                     shuffle=False,
                                     num_workers=0)
        voc_colorize = VOCColorize()
    else:
        print("dataset error")

    data_list = []

    if args.save_segimage:
        if not os.path.exists(args.seg_save_dir):
            os.makedirs(args.seg_save_dir)
        if not os.path.exists(args.vis_save_dir):
            os.makedirs(args.vis_save_dir)
    print("======> test set size:", len(testloader))
    my_index = 0
    old_temp = ''
    for index, batch in enumerate(testloader):
        print('%d processed' % (index))
        target = batch['target']
        #search = batch['search']
        temp = batch['seq_name']
        args.seq_name = temp[0]
        print(args.seq_name)
        if old_temp == args.seq_name:
            my_index = my_index + 1
        else:
            my_index = 0
        output_sum = 0
        for i in range(0, args.sample_range):
            search = batch['search' + '_' + str(i)]
            search_im = search
            #print(search_im.size())
            output = model(
                Variable(target, volatile=True).cuda(),
                Variable(search_im, volatile=True).cuda())
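            # Note: `volatile=True` only works on PyTorch < 0.4; on newer versions
            # wrap the forward pass in `with torch.no_grad():` instead.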
            #print(output[0])  # output has two elements
            output_sum = output_sum + output[0].data[
                0, 0].cpu().numpy()  # result of the segmentation branch
            #np.save('infer'+str(i)+'.npy',output1)
            #output2 = output[1].data[0, 0].cpu().numpy() #interp'

        output1 = output_sum / args.sample_range
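        # output1 now holds the segmentation-branch prediction averaged over the
        # `sample_range` reference frames sampled for this target frame.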

        first_image = np.array(
            Image.open(args.data_dir + '/JPEGImages/480p/blackswan/00000.jpg'))
        original_shape = first_image.shape
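        # All DAVIS-2016 480p frames share the same resolution (854x480), so the
        # first 'blackswan' frame is read only to recover that original shape.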
        output1 = cv2.resize(output1, (original_shape[1], original_shape[0]))
        if 0:  # optional DenseCRF post-processing (disabled)
            original_image = target[0]
            #print('image type:',type(original_image.numpy()))
            original_image = original_image.numpy()
            original_image = original_image.transpose((2, 1, 0))
            original_image = cv2.resize(original_image,
                                        (original_shape[1], original_shape[0]))
            unary = np.zeros((2, original_shape[0] * original_shape[1]),
                             dtype='float32')
            #unary[0, :, :] = res_saliency/255
            #unary[1, :, :] = 1-res_saliency/255
            EPSILON = 1e-8
            tau = 1.05

            crf = dcrf.DenseCRF(original_shape[1] * original_shape[0], 2)

            anno_norm = (output1 - np.min(output1)) / (
                np.max(output1) - np.min(output1))  #res_saliency/ 255.
            n_energy = 1.0 - anno_norm + EPSILON  #-np.log((1.0 - anno_norm + EPSILON)) #/ (tau * sigmoid(1 - anno_norm))
            p_energy = anno_norm + EPSILON  #-np.log(anno_norm + EPSILON) #/ (tau * sigmoid(anno_norm))

            #unary = unary.reshape((2, -1))
            #print(unary.shape)
            unary[1, :] = p_energy.flatten()
            unary[0, :] = n_energy.flatten()

            crf.setUnaryEnergy(unary_from_softmax(unary))

            feats = create_pairwise_gaussian(sdims=(3, 3),
                                             shape=original_shape[:2])

            crf.addPairwiseEnergy(feats,
                                  compat=3,
                                  kernel=dcrf.DIAG_KERNEL,
                                  normalization=dcrf.NORMALIZE_SYMMETRIC)

            feats = create_pairwise_bilateral(
                sdims=(10, 10),
                schan=(1, 1, 1),  # originally sdims=(60, 60), schan=(5, 5, 5)
                img=original_image,
                chdim=2)
            crf.addPairwiseEnergy(feats,
                                  compat=5,
                                  kernel=dcrf.DIAG_KERNEL,
                                  normalization=dcrf.NORMALIZE_SYMMETRIC)

            Q = crf.inference(5)
            MAP = np.argmax(Q, axis=0)
            output1 = MAP.reshape((original_shape[0], original_shape[1]))

        mask = (output1 * 255).astype(np.uint8)
        #print(mask.shape[0])
        mask = Image.fromarray(mask)

        if args.dataset == 'voc12':
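            # Note: `size` and `name` below are never defined in this snippet
            # (they normally come from the VOC test batch), so this branch would
            # fail as written.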
            print(output.shape)
            print(size)
            output = output[:, :size[0], :size[1]]
            output = output.transpose(1, 2, 0)
            output = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)
            if args.save_segimage:
                seg_filename = os.path.join(args.seg_save_dir,
                                            '{}.png'.format(name[0]))
                color_file = Image.fromarray(
                    voc_colorize(output).transpose(1, 2, 0), 'RGB')
                color_file.save(seg_filename)

        elif args.dataset == 'davis':

            save_dir_res = os.path.join(args.seg_save_dir, 'Results',
                                        args.seq_name)
            old_temp = args.seq_name
            if not os.path.exists(save_dir_res):
                os.makedirs(save_dir_res)
            if args.save_segimage:
                my_index1 = str(my_index).zfill(5)
                seg_filename = os.path.join(save_dir_res,
                                            '{}.png'.format(my_index1))
                #color_file = Image.fromarray(voc_colorize(output).transpose(1, 2, 0), 'RGB')
                mask.save(seg_filename)
                #np.concatenate((torch.zeros(1, 473, 473), mask, torch.zeros(1, 512, 512)),axis = 0)
                #save_image(output1 * 0.8 + target.data, args.vis_save_dir, normalize=True)

        elif args.dataset == 'cityscapes':
            output = output.transpose(1, 2, 0)
            output = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)
            if args.save_segimage:
                output_color = cityscapes_colorize_mask(output)
                output = Image.fromarray(output)
                output.save('%s/%s.png' % (args.seg_save_dir, name[0]))
                output_color.save('%s/%s_color.png' %
                                  (args.seg_save_dir, name[0]))
        else:
            print("dataset error")
Example #2
def main():
    # `args` is used throughout below but was not defined in the original
    # snippet; obtaining it as in Examples #1 and #3:
    args = get_arguments()
    print("=====> Configure dataset and pretrained model")
    configure_dataset_init_model(args)
    print(args)

    print("    current dataset:  ", args.dataset)
    print("    init model: ", args.restore_from)
    print("=====> Set GPU for training")
    if args.cuda:
        print("====> Use gpu id: '{}'".format(args.gpus))
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
        if not torch.cuda.is_available():
            raise Exception("No GPU found or Wrong gpu id, please run without --cuda")
    # Select which GPU, -1 if CPU
    #gpu_id = args.gpus
    #device = torch.device("cuda:"+str(gpu_id) if torch.cuda.is_available() else "cpu")
    print("=====> Random Seed: ", args.random_seed)
    torch.manual_seed(args.random_seed)
    if args.cuda:
        torch.cuda.manual_seed(args.random_seed) 

    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    cudnn.enabled = True

    print("=====> Building network")
    saved_state_dict = torch.load(args.restore_from)
    # saved_state_dict = torch.load(args.restore_from, map_location='cpu') ####
    model = CoattentionNet(num_classes=args.num_classes)
    #print(model)
    new_params = model.state_dict().copy()
    for i in saved_state_dict["model"]:
        #Scale.layer5.conv2d_list.3.weight
        i_parts = i.split('.')  # handles checkpoints saved in multi-GPU (DataParallel) form
        #i_parts.pop(1)
        #print('i_parts:  ', '.'.join(i_parts[1:-1]))
        #if  not i_parts[1]=='main_classifier': #and not '.'.join(i_parts[1:-1]) == 'layer5.bottleneck' and not '.'.join(i_parts[1:-1]) == 'layer5.bn':  #init model pretrained on COCO, class name=21, layer5 is ASPP
        new_params['encoder'+'.'+'.'.join(i_parts[1:])] = saved_state_dict["model"][i]
            #print('copy {}'.format('.'.join(i_parts[1:])))
    
   
    print("=====> Loading init weights,  pretrained COCO for VOC2012, and pretrained Coarse cityscapes for cityscapes")
 
            
    model.load_state_dict(new_params)  # only the parameters of ResNet's fifth conv block are used
    #print(model.keys())
    if args.cuda:
        #model.to(device)
        if torch.cuda.device_count()>1:
            print("torch.cuda.device_count()=",torch.cuda.device_count())
            model = torch.nn.DataParallel(model).cuda()  #multi-card data parallel
        else:
            print("single GPU for training")
            model = model.cuda()  # single GPU
    start_epoch=0
    
    print("=====> Whether resuming from a checkpoint, for continuing training")
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint["epoch"] 
            model.load_state_dict(checkpoint["model"])
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))


    model.train()
    cudnn.benchmark = True

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)
    
    print('=====> Computing network parameters')
    total_paramters = netParams(model)
    print('Total network parameters: ' + str(total_paramters))
 
    print("=====> Preparing training data")
    if args.dataset == 'voc12':
        trainloader = data.DataLoader(VOCDataSet(args.data_dir, args.data_list, max_iters=None, crop_size=input_size, 
                                                 scale=args.random_scale, mirror=args.random_mirror, mean=args.img_mean), 
                                      batch_size= args.batch_size, shuffle=True, num_workers=0, pin_memory=True, drop_last=True)
    elif args.dataset == 'cityscapes':
        trainloader = data.DataLoader(CityscapesDataSet(args.data_dir, args.data_list, max_iters=None, crop_size=input_size, 
                                                 scale=args.random_scale, mirror=args.random_mirror, mean=args.img_mean), 
                                      batch_size = args.batch_size, shuffle=True, num_workers=0, pin_memory=True, drop_last=True)
    elif args.dataset == 'davis':  #for davis 2016
        db_train = db.PairwiseImg(train=True, inputRes=input_size, db_root_dir=args.data_dir, img_root_dir=args.img_dir,  transform=None) #db_root_dir() --> '/path/to/DAVIS-2016' train path
        # trainloader = data.DataLoader(db_train, batch_size= args.batch_size, shuffle=True, num_workers=0)
        trainloader = data.DataLoader(db_train, batch_size= args.batch_size, shuffle=True, num_workers=0, drop_last=True) ####
    else:
        print("dataset error")

    optimizer = optim.SGD([{'params': get_1x_lr_params(model), 'lr': 1*args.learning_rate },  # train only specific layer groups; some layers are not updated
                {'params': get_10x_lr_params(model), 'lr': 10*args.learning_rate}], 
                lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay)
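    # Assumption: get_1x_lr_params / get_10x_lr_params yield the pretrained
    # backbone parameters and the newly added layers' parameters respectively,
    # so the new layers train with a 10x learning rate.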
    optimizer.zero_grad()


    
    logFileLoc = args.snapshot_dir + args.logFile
    if os.path.isfile(logFileLoc):
        logger = open(logFileLoc, 'a')
    else:
        logger = open(logFileLoc, 'w')
        logger.write("Parameters: %s" % (str(total_paramters)))
        logger.write("\n%s\t\t%s" % ('iter', 'Loss(train)\n'))
    logger.flush()

    print("=====> Begin to train")
    train_len=len(trainloader)
    print("  iteration numbers  of per epoch: ", train_len)
    print("  epoch num: ", args.maxEpoches)
    print("  max iteration: ", args.maxEpoches*train_len)
    
    for epoch in range(start_epoch, int(args.maxEpoches)):
        
        np.random.seed(args.random_seed + epoch)
        for i_iter, batch in enumerate(trainloader,0): #i_iter from 0 to len-1
            #print("i_iter=", i_iter, "epoch=", epoch)
            target, target_gt, search, search_gt = batch['target'], batch['target_gt'], batch['search'], batch['search_gt']
            images, labels = batch['img'], batch['img_gt']
            #print(labels.size())
            images.requires_grad_()
            images = Variable(images).cuda()
            labels = Variable(labels.float().unsqueeze(1)).cuda()
            
            target.requires_grad_()
            target = Variable(target).cuda()
            target_gt = Variable(target_gt.float().unsqueeze(1)).cuda()
            
            search.requires_grad_()
            search = Variable(search).cuda()
            search_gt = Variable(search_gt.float().unsqueeze(1)).cuda()
            
            optimizer.zero_grad()
            
            lr = adjust_learning_rate(optimizer, i_iter+epoch*train_len, epoch,
                    max_iter = args.maxEpoches * train_len)
            #print(images.size())
            if i_iter % 3 == 0:  # training step on static images
                
                pred1, pred2, pred3 = model(images, images)
                loss = 0.1*(loss_calc1(pred3, labels) + 0.8* loss_calc2(pred3, labels) )
                loss.backward()
                
            else:
                    
                pred1, pred2, pred3 = model(target, search)
                loss = loss_calc1(pred1, target_gt) + 0.8* loss_calc2(pred1, target_gt) + loss_calc1(pred2, search_gt) + 0.8* loss_calc2(pred2, search_gt)#class_balanced_cross_entropy_loss(pred, labels, size_average=False)
                loss.backward()
            
            optimizer.step()
                
            print("===> Epoch[{}]({}/{}): Loss: {:.10f}  lr: {:.5f}".format(epoch, i_iter, train_len, loss.data, lr))
            logger.write("Epoch[{}]({}/{}):     Loss: {:.10f}      lr: {:.5f}\n".format(epoch, i_iter, train_len, loss.data, lr))
            logger.flush()
                
        print("=====> saving model")
        state={"epoch": epoch+1, "model": model.state_dict()}
        torch.save(state, osp.join(args.snapshot_dir, 'co_attention_'+str(args.dataset)+"_"+str(epoch)+'.pth'))


    end = timeit.default_timer()
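    # `start` is assumed to be set with timeit.default_timer() at module level
    # (not shown in this snippet).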
    print( float(end-start)/3600, 'h')
    logger.write("total training time: {:.2f} h\n".format(float(end-start)/3600))
    logger.close()
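
The training loop above calls adjust_learning_rate(...) every iteration, but the helper itself is not shown. A minimal sketch, assuming the common DeepLab-style polynomial ("poly") decay and the two parameter groups created for the optimizer above (the original repository's schedule may differ):

def adjust_learning_rate(optimizer, i_iter, epoch, max_iter,
                         base_lr=2.5e-4, power=0.9):
    """Polynomial decay: lr = base_lr * (1 - i_iter / max_iter) ** power.

    `epoch` is accepted only to match the call site above; it is unused here.
    """
    lr = base_lr * ((1 - float(i_iter) / max_iter) ** power)
    optimizer.param_groups[0]['lr'] = lr           # 1x group (pretrained backbone)
    if len(optimizer.param_groups) > 1:
        optimizer.param_groups[1]['lr'] = 10 * lr  # 10x group (newly added layers)
    return lr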
Example #3
def main():
    args = get_arguments()
    print("=====> Configure dataset and model")
    configure_dataset_model(args)
    print(args)
    model = CoattentionNet(num_classes=args.num_classes)

    saved_state_dict = torch.load(args.restore_from,
                                  map_location=lambda storage, loc: storage)
    #print(saved_state_dict.keys())
    #model.load_state_dict({k.replace('pspmodule.',''):v for k,v in torch.load(args.restore_from)['state_dict'].items()})
    model.load_state_dict(convert_state_dict(saved_state_dict["model"]))

    model.eval()
    model.cuda()
    if args.dataset == 'voc12':
        testloader = data.DataLoader(VOCDataTestSet(args.data_dir,
                                                    args.data_list,
                                                    crop_size=(505, 505),
                                                    mean=args.img_mean),
                                     batch_size=1,
                                     shuffle=False,
                                     pin_memory=True)
        interp = nn.Upsample(size=(505, 505), mode='bilinear')
        voc_colorize = VOCColorize()

    elif args.dataset == 'davis':  #for davis 2016
        db_test = db.PairwiseImg(
            train=False,
            inputRes=(473, 473),
            db_root_dir=args.data_dir,
            transform=None,
            seq_name=None,
            sample_range=args.sample_range
        )  #db_root_dir() --> '/path/to/DAVIS-2016' train path
        testloader = data.DataLoader(db_test,
                                     batch_size=1,
                                     shuffle=False,
                                     num_workers=0)
        #voc_colorize = VOCColorize()
    else:
        print("dataset error")

    data_list = []

    if args.save_segimage:
        if not os.path.exists(args.seg_save_dir):
            os.makedirs(args.seg_save_dir)
        if not os.path.exists(args.vis_save_dir):
            os.makedirs(args.vis_save_dir)
    print("======> test set size:", len(testloader))
    my_index = 0
    old_temp = ''
    for index, batch in enumerate(testloader):
        print('%d processed' % (index))
        target = batch['target']
        #search = batch['search']
        temp = batch['seq_name']
        args.seq_name = temp[0]
        print(args.seq_name)
        if old_temp == args.seq_name:
            my_index = my_index + 1
        else:
            my_index = 0
        output_sum = 0
        for i in range(0, args.sample_range):
            search = batch['search' + '_' + str(i)]
            search_im = search
            #print(search_im.size())
            output = model(
                Variable(target, volatile=True).cuda(),
                Variable(search_im, volatile=True).cuda())
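            # See the note in Example #1: `volatile=True` requires PyTorch < 0.4;
            # use `with torch.no_grad():` on newer versions.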
            #print(output[0])  # output has two elements
            output_sum = output_sum + output[0].data[
                0, 0].cpu().numpy()  # result of the segmentation branch
            #np.save('infer'+str(i)+'.npy',output1)
            #output2 = output[1].data[0, 0].cpu().numpy() #interp'

        output1 = output_sum / args.sample_range

        first_image = np.array(
            Image.open(args.data_dir + '/JPEGImages/480p/blackswan/00000.jpg'))
        original_shape = first_image.shape
        output1 = cv2.resize(output1, (original_shape[1], original_shape[0]))

        mask = (output1 * 255).astype(np.uint8)
        #print(mask.shape[0])
        mask = Image.fromarray(mask)

        if args.dataset == 'voc12':
            print(output.shape)
            print(size)
            output = output[:, :size[0], :size[1]]
            output = output.transpose(1, 2, 0)
            output = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)
            if args.save_segimage:
                seg_filename = os.path.join(args.seg_save_dir,
                                            '{}.png'.format(name[0]))
                color_file = Image.fromarray(
                    voc_colorize(output).transpose(1, 2, 0), 'RGB')
                color_file.save(seg_filename)

        elif args.dataset == 'davis':

            save_dir_res = os.path.join(args.seg_save_dir, 'Results',
                                        args.seq_name)
            old_temp = args.seq_name
            if not os.path.exists(save_dir_res):
                os.makedirs(save_dir_res)
            if args.save_segimage:
                my_index1 = str(my_index).zfill(5)
                seg_filename = os.path.join(save_dir_res,
                                            '{}.png'.format(my_index1))
                #color_file = Image.fromarray(voc_colorize(output).transpose(1, 2, 0), 'RGB')
                mask.save(seg_filename)
                #np.concatenate((torch.zeros(1, 473, 473), mask, torch.zeros(1, 512, 512)),axis = 0)
                #save_image(output1 * 0.8 + target.data, args.vis_save_dir, normalize=True)
        else:
            print("dataset error")