    def __init__(self,
                 use_gpu=True,
                 time_budget=None,
                 save_result_dir=cfg.SAVE_RESULT_DIR,
                 pretrained=True,
                 interactive_test=False,
                 freeze_bn=False):

        self.save_res_dir = save_result_dir
        self.time_budget = time_budget
        self.feature_extracter = DeepLab(backbone='resnet',
                                         freeze_bn=freeze_bn)
        #        self.feature_extracter= deeplabv3plus(cfg)

        if pretrained:
            # Restore DeepLab backbone weights from the pretrained checkpoint
            pretrained_dict = torch.load(cfg.PRETRAINED_MODEL)
            pretrained_dict = pretrained_dict['state_dict']
            self.load_network(self.feature_extracter, pretrained_dict)
            print('Loaded pretrained model successfully.')
        self.model = IntVOS(cfg, self.feature_extracter, freeze_bn=freeze_bn)
        print(self.model)
        #        self.model.freeze_bn()
        # Restore the full interactive-VOS weights from the saved training checkpoint
        pd = torch.load(
            os.path.join(cfg.SAVE_VOS_RESULT_DIR, 'save_step_100000.pth'))
        self.load_network(self.model, pd)
        self.use_gpu = use_gpu
        if use_gpu:
            self.model = self.model.cuda()
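
Both constructors in this example depend on a load_network helper that the snippet does not show. Below is a minimal sketch of such a partial state-dict loader, assuming checkpoint keys may carry a DataParallel-style 'module.' prefix; the repo's actual helper may differ:

def load_network(net, pretrained_dict):
    # Copy matching weights into net, skipping absent or shape-mismatched keys.
    model_dict = net.state_dict()
    # Strip a possible 'module.' prefix left over from DataParallel training
    pretrained_dict = {k.replace('module.', '', 1): v for k, v in pretrained_dict.items()}
    matched = {k: v for k, v in pretrained_dict.items()
               if k in model_dict and v.size() == model_dict[k].size()}
    model_dict.update(matched)
    net.load_state_dict(model_dict)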
Example #2
    def __init__(self,
                 use_gpu=True,
                 time_budget=None,
                 save_result_dir='./saved_model_total_ytb_davis',
                 pretrained=True,
                 interactive_test=False,
                 freeze_bn=False):

        self.save_res_dir = save_result_dir
        self.time_budget = time_budget
        self.feature_extracter = DeepLab(backbone='resnet', freeze_bn=freeze_bn)
        #        self.feature_extracter= deeplabv3plus(cfg)

        if pretrained:
            # Restore DeepLab backbone weights from the pretrained checkpoint
            pretrained_dict = torch.load(cfg.PRETRAINED_MODEL)
            pretrained_dict = pretrained_dict['state_dict']
            self.load_network(self.feature_extracter, pretrained_dict)
            print('Loaded pretrained model successfully.')
        self.model = IntVOS(cfg, self.feature_extracter, freeze_bn=freeze_bn)
        #        self.model.freeze_bn()
        self.use_gpu = use_gpu
        if use_gpu:
            self.model = self.model.cuda()
Example #3
def get_model(args):
    if args.network_name == "FPN":
        model = FPNSeg(args)

        if args.weight_type == "moco_v2":
            assert args.n_layers == 50, args.n_layers
            # Path to the moco_v2 weights; relative to the scripts dir.
            # (The original wrapped this in a try/except whose except branch
            # reloaded the identical path, so it could never recover; removed.)
            self_sup_weights = torch.load("../networks/backbones/pretrained/moco_v2_800ep_pretrain.pth.tar")["state_dict"]

            model_state_dict = model.encoder.state_dict()

            for k in list(self_sup_weights.keys()):
                if k.replace("module.encoder_q.", '') in ["fc.0.weight", "fc.0.bias", "fc.2.weight", "fc.2.bias"]:
                    self_sup_weights.pop(k)
            for name, param in self_sup_weights.items():
                name = name.replace("encoder_q.", '').replace("module", 'base')

                if name.replace("base.", '') in ["conv1.weight", "bn1.weight", "bn1.bias", "bn1.running_mean",
                                                 "bn1.running_var", "bn1.num_batches_tracked"]:
                    name = name.replace("base", "base.prefix")

                if name not in model_state_dict:
                    print(f"{name} is not applied!")
                    continue

                if isinstance(param, torch.nn.Parameter):
                    param = param.data
                model_state_dict[name].copy_(param)
            model.encoder.load_state_dict(model_state_dict)
            print("moco_v2 weights are loaded successfully.")

    elif args.network_name == "deeplab":
        model = DeepLab(args)
    else:
        # Fail loudly instead of raising UnboundLocalError on `model` below
        raise ValueError(f"Unknown network_name: {args.network_name}")
    return model
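
A hypothetical call site for get_model, assuming an argparse-style namespace carrying the fields read above (network_name, weight_type, n_layers; FPNSeg itself may require additional fields):

from types import SimpleNamespace

args = SimpleNamespace(network_name="FPN", weight_type="moco_v2", n_layers=50)
model = get_model(args)  # FPN encoder initialized from the moco_v2 checkpoint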
Example #4
print('[%s] Loading data' % (datetime.datetime.now()))
# augmenters
train_xtransform_us, train_ytransform_us, test_xtransform_us, test_ytransform_us = get_augmenters_2d(augment_noise=(args.augment_noise==1))
train_xtransform, train_ytransform, test_xtransform, test_ytransform = get_augmenters_2d(augment_noise=(args.augment_noise==1))
# load data
train = StronglyLabeledVolumeDataset(args.data_train, args.labels_train, input_shape, transform=train_xtransform, target_transform=train_ytransform, preprocess=args.preprocess)
test = StronglyLabeledVolumeDataset(args.data_test, args.labels_test, input_shape, transform=test_xtransform, target_transform=test_ytransform, preprocess=args.preprocess)
train_loader = DataLoader(train, batch_size=args.train_batch_size)
test_loader = DataLoader(test, batch_size=args.test_batch_size)
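# Note: the script builds the training loader without shuffling; a shuffled
# variant (an assumption about intent, not in the original) would be:
# train_loader = DataLoader(train, batch_size=args.train_batch_size, shuffle=True)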

"""
    Setup optimization for finetuning
"""
print('[%s] Setting up optimization for finetuning' % (datetime.datetime.now()))
# initialize the segmentation network
net = DeepLab()

"""
    Train the network
"""
print('[%s] Training network' % (datetime.datetime.now()))
net.train_net(train_loader=train_loader, test_loader=test_loader,
              loss_fn=loss_fn_seg, lr=args.lr, step_size=args.step_size, gamma=args.gamma,
              epochs=args.epochs, test_freq=args.test_freq, print_stats=args.print_stats,
              log_dir=args.log_dir)

"""
    Validate the trained network
"""
print('[%s] Validating the trained network' % (datetime.datetime.now()))
test_data = test.data
Example #5
def main():
    total_frame_num_dic = {}
    # Collect the validation sequence names and per-sequence frame counts
    seqs = []
    with open(os.path.join(cfg.DATA_ROOT, 'ImageSets', '2017', 'val' + '.txt')) as f:
        seqs_tmp = f.readlines()
        seqs_tmp = list(map(lambda elem: elem.strip(), seqs_tmp))
        seqs.extend(seqs_tmp)
    for seq_name in seqs:
        images = np.sort(os.listdir(os.path.join(cfg.DATA_ROOT, 'JPEGImages/480p/', seq_name.strip())))
        total_frame_num_dic[seq_name] = len(images)
    _seq_list_file = os.path.join(cfg.DATA_ROOT, 'ImageSets', '2017',
                                  'v_a_l' + '_instances.txt')
    seq_dict = json.load(open(_seq_list_file, 'r'))
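    # seq_dict maps each sequence name to a list whose last element is the
    # number of annotated objects (used as obj_nums below).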


    is_save_image = True
    report_save_dir = cfg.RESULT_ROOT
    save_res_dir = './saved_model_inter_net_new_re'

    # Configuration used in the DAVIS interactive challenges
    max_nb_interactions = 8  # Maximum number of interactions
    max_time_per_interaction = 10000  # Maximum time per interaction per object
    # Total time available to interact with a sequence and an initial set of scribbles
    max_time = max_nb_interactions * max_time_per_interaction  # Maximum time per object
    # Interactive parameters
    subset = 'val'
    host = 'localhost'  # 'localhost' for the train and val subsets

    feature_extracter = DeepLab(backbone='resnet', freeze_bn=False)
    model = IntVOS(cfg, feature_extracter)
    model = model.cuda()
    print('model loading...')

    saved_model_dict = os.path.join(save_res_dir, 'save_step_80000.pth')
    pretrained_dict = torch.load(saved_model_dict)
    load_network(model, pretrained_dict)

    print('model loading finished!')
    model.eval()

    seen_seq = {}
    with torch.no_grad():
        with DavisInteractiveSession(host=host,
                                     davis_root=cfg.DATA_ROOT,
                                     subset=subset,
                                     report_save_dir=report_save_dir,
                                     max_nb_interactions=max_nb_interactions,
                                     max_time=max_time,
                                     metric_to_optimize='J') as sess:
            while sess.next():
                t_total = timeit.default_timer()
                # Get the current iteration scribbles
                sequence, scribbles, first_scribble = sess.get_scribbles(only_last=True)
                print(sequence)
                start_annotated_frame = annotated_frames(scribbles)[0]
                
                pred_masks = []
                pred_masks_reverse = []

                if first_scribble:
                    anno_frame_list = []
                    n_interaction = 1
                    eval_global_map_tmp_dic = {}
                    local_map_dics = ({}, {})
                    total_frame_num = total_frame_num_dic[sequence]
                    obj_nums = seq_dict[sequence][-1]
                    eval_data_manager = DAVIS2017_Test_Manager(split='val', root=cfg.DATA_ROOT,
                                                               transform=tr.ToTensor(), seq_name=sequence)
                else:
                    n_interaction += 1

                # Reference image process: build network inputs from the scribbles
                scr_f = str(start_annotated_frame).zfill(5)
                anno_frame_list.append(start_annotated_frame)
                print(start_annotated_frame)
                ref_img = os.path.join(cfg.DATA_ROOT, 'JPEGImages/480p', sequence, scr_f + '.jpg')
                ref_img = cv2.imread(ref_img)
                h_, w_ = ref_img.shape[:2]
                ref_img = np.array(ref_img, dtype=np.float32)

                scribble_masks = scribbles2mask(scribbles, (h_, w_))
                scribble_label = scribble_masks[start_annotated_frame]
                sample = {'ref_img': ref_img, 'scribble_label': scribble_label}
                sample = tr.ToTensor()(sample)
                ref_img = sample['ref_img'].unsqueeze(0).cuda()
                scribble_label = sample['scribble_label'].unsqueeze(0).cuda()

                # Save a visualization of the scribble labels
                ref_scribble_to_show = scribble_label.cpu().squeeze().numpy()
                im_ = Image.fromarray(ref_scribble_to_show.astype('uint8')).convert('P')
                im_.putpalette(_palette)
                ref_img_name = scr_f

                inter_dir = os.path.join(cfg.RESULT_ROOT, sequence, 'interactive' + str(n_interaction))
                os.makedirs(inter_dir, exist_ok=True)
                im_.save(os.path.join(inter_dir, 'inter_' + ref_img_name + '.png'))

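                # Cache per-frame embeddings on the first interaction round;
                # later rounds reuse embedding_memory instead of re-extracting features.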
                if first_scribble:
                    ref_frame_embedding = model.extract_feature(ref_img)
                    _, channel, emb_h, emb_w = ref_frame_embedding.size()
                    embedding_memory = torch.zeros((total_frame_num, channel, emb_h, emb_w)).cuda()
                    embedding_memory[start_annotated_frame] = ref_frame_embedding[0]
                else:
                    ref_frame_embedding = embedding_memory[start_annotated_frame].unsqueeze(0)
                # Load the previous round's mask for this frame (none on the first interaction)
                if first_scribble:
                    prev_label = None
                else:
                    prev_label = os.path.join(cfg.RESULT_ROOT, sequence,
                                              'interactive' + str(n_interaction - 1), scr_f + '.png')
                    prev_label = Image.open(prev_label)
                    prev_label = np.array(prev_label, dtype=np.uint8)
                    prev_label = tr.ToTensor()({'label': prev_label})
                    prev_label = prev_label['label'].unsqueeze(0).cuda()

                # Segment the annotated frame from the scribble interaction
                tmp_dic, local_map_dics = model.int_seghead(
                        ref_frame_embedding=ref_frame_embedding, ref_scribble_label=scribble_label,
                        prev_round_label=prev_label, global_map_tmp_dic=eval_global_map_tmp_dic,
                        local_map_dics=local_map_dics, interaction_num=n_interaction,
                        seq_names=[sequence], gt_ids=torch.Tensor([obj_nums]),
                        frame_num=[start_annotated_frame], first_inter=first_scribble)
                pred_label = tmp_dic[sequence]
                pred_label = nn.functional.interpolate(pred_label, size=(h_, w_),
                                                       mode='bilinear', align_corners=True)
                pred_label = torch.argmax(pred_label, dim=1)
                pred_masks.append(pred_label.float())
                if is_save_image:
                    pred_label_to_save = pred_label.squeeze(0).cpu().numpy()
                    im = Image.fromarray(pred_label_to_save.astype('uint8')).convert('P')
                    im.putpalette(_palette)
                    imgname = str(start_annotated_frame).zfill(5)
                    os.makedirs(inter_dir, exist_ok=True)
                    im.save(os.path.join(inter_dir, imgname + '.png'))
                # Keep the reference-frame state so the backward pass can restart from it
                ref_prev_label = pred_label.unsqueeze(0)
                prev_label = pred_label.unsqueeze(0)
                prev_img = ref_img
                prev_embedding = ref_frame_embedding

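                # Propagate the mask forward, frame by frame, from the annotated
                # frame to the end of the sequence.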
                for ii in range(start_annotated_frame + 1, total_frame_num):
                    print('evaluating sequence:{} frame:{}'.format(sequence, ii))
                    sample = eval_data_manager.get_image(ii)
                    img = sample['img'].unsqueeze(0)
                    _, _, h, w = img.size()
                    img = img.cuda()

                    if first_scribble:
                        current_embedding = model.extract_feature(img)
                        embedding_memory[ii] = current_embedding[0]
                    else:
                        current_embedding = embedding_memory[ii].unsqueeze(0)

                    prev_img = prev_img.cuda()
                    prev_label = prev_label.cuda()
                    tmp_dic, eval_global_map_tmp_dic, local_map_dics = model.before_seghead_process(
                            ref_frame_embedding, prev_embedding, current_embedding,
                            scribble_label, prev_label,
                            normalize_nearest_neighbor_distances=True, use_local_map=True,
                            seq_names=[sequence], gt_ids=torch.Tensor([obj_nums]),
                            k_nearest_neighbors=cfg.KNNS,
                            global_map_tmp_dic=eval_global_map_tmp_dic, local_map_dics=local_map_dics,
                            interaction_num=n_interaction, start_annotated_frame=start_annotated_frame,
                            frame_num=[ii], dynamic_seghead=model.dynamic_seghead)
                    pred_label = tmp_dic[sequence]
                    pred_label = nn.functional.interpolate(pred_label, size=(h, w),
                                                           mode='bilinear', align_corners=True)
                    pred_label = torch.argmax(pred_label, dim=1)
                    pred_masks.append(pred_label.float())
                    prev_label = pred_label.unsqueeze(0)
                    prev_img = img
                    prev_embedding = current_embedding
                    if is_save_image:
                        pred_label_to_save = pred_label.squeeze(0).cpu().numpy()
                        im = Image.fromarray(pred_label_to_save.astype('uint8')).convert('P')
                        im.putpalette(_palette)
                        imgname = str(ii).zfill(5)
                        os.makedirs(inter_dir, exist_ok=True)
                        im.save(os.path.join(inter_dir, imgname + '.png'))
                # Reset to the reference-frame state before propagating backward
                prev_label = ref_prev_label
                prev_img = ref_img
                prev_embedding = ref_frame_embedding
                # Propagate backward from the annotated frame toward frame 0;
                # predictions are collected in reverse and flipped afterwards.
                for ii in range(start_annotated_frame):
                    current_frame_num = start_annotated_frame - 1 - ii
                    print('evaluating sequence:{} frame:{}'.format(sequence, current_frame_num))
                    sample = eval_data_manager.get_image(current_frame_num)
                    img = sample['img'].unsqueeze(0)
                    _, _, h, w = img.size()
                    img = img.cuda()

                    if first_scribble:
                        current_embedding = model.extract_feature(img)
                        embedding_memory[current_frame_num] = current_embedding[0]
                    else:
                        current_embedding = embedding_memory[current_frame_num].unsqueeze(0)
                    prev_img = prev_img.cuda()
                    prev_label = prev_label.cuda()
                    tmp_dic, eval_global_map_tmp_dic, local_map_dics = model.before_seghead_process(
                            ref_frame_embedding, prev_embedding, current_embedding,
                            scribble_label, prev_label,
                            normalize_nearest_neighbor_distances=True, use_local_map=True,
                            seq_names=[sequence], gt_ids=torch.Tensor([obj_nums]),
                            k_nearest_neighbors=cfg.KNNS,
                            global_map_tmp_dic=eval_global_map_tmp_dic, local_map_dics=local_map_dics,
                            interaction_num=n_interaction, start_annotated_frame=start_annotated_frame,
                            frame_num=[current_frame_num], dynamic_seghead=model.dynamic_seghead)
                    pred_label = tmp_dic[sequence]
                    pred_label = nn.functional.interpolate(pred_label, size=(h, w),
                                                           mode='bilinear', align_corners=True)
                    pred_label = torch.argmax(pred_label, dim=1)
                    pred_masks_reverse.append(pred_label.float())
                    prev_label = pred_label.unsqueeze(0)
                    prev_img = img
                    prev_embedding = current_embedding
                    if is_save_image:
                        pred_label_to_save = pred_label.squeeze(0).cpu().numpy()
                        im = Image.fromarray(pred_label_to_save.astype('uint8')).convert('P')
                        im.putpalette(_palette)
                        imgname = str(current_frame_num).zfill(5)
                        os.makedirs(inter_dir, exist_ok=True)
                        im.save(os.path.join(inter_dir, imgname + '.png'))
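                # Stitch results: the backward predictions are reversed into
                # ascending frame order, then followed by the forward predictions.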
                pred_masks_reverse.reverse()
                pred_masks_reverse.extend(pred_masks)
                final_masks = torch.cat(pred_masks_reverse, 0)
                sess.submit_masks(final_masks.cpu().numpy())
                t_end = timeit.default_timer()
                print('Total time for single interaction: ' + str(t_end - t_total))
            report = sess.get_report()
            summary = sess.get_global_summary(save_file=os.path.join(report_save_dir, 'summary.json'))
Example #6
# NOTE: this snippet begins mid if/elif chain; the 'unet' branch header below
# is reconstructed from context.
if args.model == 'unet':
    net = UNet(1, args.num_classes).cuda()
    if args.is_pretrain:
        print('loading pretrain')
        state_dict = torch.load('/home/viplab/data/unet_carvana_scale1_epoch5.pth')
        model_dict = net.state_dict()
        # Keep only pretrained tensors whose names and shapes match the target model
        pretrained_state = {k: v for k, v in state_dict.items()
                            if k in model_dict and v.size() == model_dict[k].size()}
        model_dict.update(pretrained_state)
        net.load_state_dict(model_dict)
elif args.model == 'denseunet':
    if args.is_pretrain:
        net = DenseUNet(args.num_classes, pretrained_encoder_uri='https://download.pytorch.org/models/densenet121-a639ec97.pth').cuda()
    else:
        net = DenseUNet(args.num_classes).cuda()
elif args.model == 'deeplab':
    net = DeepLab(sync_bn=False, num_classes=args.num_classes).cuda()
elif args.model == 'deeplab_xception':
    net = DeepLab(sync_bn=False, num_classes=args.num_classes, backbone='xception', output_stride=8).cuda()
elif args.model == 'deeplab_resnest':
    MODEL_CFG = MODEL_CFG.copy()
    MODEL_CFG.update({'num_classes': args.num_classes})
    if args.norm == 'groupnorm':
        MODEL_CFG.update({'norm_cfg': {'type': 'groupnorm', 'opts': {}}})
    net = Deeplabv3Plus(MODEL_CFG, mode='TRAIN').cuda()
    if args.is_pretrain:  # resume from a checkpoint path
        state_dict = torch.load(args.is_pretrain)
        print('loading', args.is_pretrain)
        if 'model' in state_dict.keys():
            state_dict = state_dict['model']
            new_state_dict = {}
            for k in state_dict.keys():