Example 1
    def test_VOS(self,use_gpu=True):
        seqs = []

        with open(os.path.join(cfg.DATA_ROOT, 'ImageSets', '2017', 'val' + '.txt')) as f:
            seqs_tmp = f.readlines()
        seqs_tmp = list(map(lambda elem: elem.strip(), seqs_tmp))
        seqs.extend(seqs_tmp)
        print('model loading...')
        saved_model_dict = os.path.join(self.save_res_dir,'save_step_60000.pth')
        pretrained_dict = torch.load(saved_model_dict)
        self.model = self.load_network(self.model,pretrained_dict)
        print('model loading finished')

        self.model.eval()
        with torch.no_grad():
            for seq_name in seqs:
                print('processing seq: {}'.format(seq_name))
                test_dataset = DAVIS2017_VOS_Test(root=cfg.DATA_ROOT, transform=tr.ToTensor(),
                                                  result_root=cfg.RESULT_ROOT, seq_name=seq_name)
                test_dataloader = DataLoader(test_dataset, batch_size=1,
                                             shuffle=False, num_workers=0, pin_memory=True)
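                # evaluate one sequence at a time and write each frame's prediction
                # as an indexed PNG under RESULT_ROOT/<seq_name>/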
                if not os.path.exists(os.path.join(cfg.RESULT_ROOT,seq_name)):
                    os.makedirs(os.path.join(cfg.RESULT_ROOT,seq_name))
                time_start= time.time()
                for ii,sample in enumerate(test_dataloader):
                    ref_img = sample['ref_img']
                    prev_img = sample['prev_img']
                    current_img = sample['current_img']
                    ref_label = sample['ref_label']
                    prev_label = sample['prev_label']
                    obj_num = sample['meta']['obj_num']
                    seqnames = sample['meta']['seq_name']
                    imgname = sample['meta']['current_name']
                    bs,_,h,w = current_img.size()

                    inputs = torch.cat((ref_img,prev_img,current_img),0)
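                    # the three frames share one backbone forward pass; the features
                    # are split back into per-frame chunks below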
                    if use_gpu:
                        inputs = inputs.cuda()
                        ref_label=ref_label.cuda()
                        prev_label=prev_label.cuda()
##                    tmp_dic = self.model(inputs,ref_label,prev_label,seq_names=seqnames,gt_ids=obj_num,k_nearest_neighbors=cfg.KNNS)

                    # timed two-stage forward: backbone features first, then the
                    # matching + segmentation head
                    t1 = time.time()
                    tmp = self.model.extract_feature(inputs)
                    ref_frame_embedding, previous_frame_embedding, current_frame_embedding = torch.split(
                        tmp, split_size_or_sections=int(tmp.size(0) / 3), dim=0)
                    t2 = time.time()
                    print('feature extraction time: {}'.format(t2 - t1))
                    tmp_dic = self.model.before_seghead_process(
                        ref_frame_embedding, previous_frame_embedding, current_frame_embedding,
                        ref_label, prev_label, True,
                        seqnames, obj_num, cfg.KNNS, self.model.dynamic_seghead)
                    t3 = time.time()
                    print('seghead time: {}'.format(t3 - t2))

                    pred_label = tmp_dic[seq_name].cpu()
                    pred_label = nn.functional.interpolate(pred_label,size=(h,w),mode = 'bilinear',align_corners=True)

                    pred_label=torch.argmax(pred_label,dim=1)
                    pred_label = pred_label.squeeze(0)
                    #pred_label = pred_label.permute(1,2,0)
                    pred_label=pred_label.numpy()
                    im = Image.fromarray(pred_label.astype('uint8')).convert('P')
                    im.putpalette(_palette)
                    im.save(os.path.join(cfg.RESULT_ROOT,seq_name,imgname[0].split('.')[0]+'.png'))
#                    cv2.imwrite(os.path.join('./result_2',seq_name,imgname[0].split('.')[0]+'.png'),pred_label)
                    one_frametime = time.time()
                    print('seq name:{} frame:{} time:{}'.format(seq_name,imgname[0],one_frametime-time_start))
                    time_start=time.time()
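
The test loop above batches the reference, previous, and current frames along dim 0 so the backbone runs only once per step, then splits the features back into thirds. A minimal sketch of that concat/split round trip (shapes are illustrative, not taken from the repo):

import torch

# three frames of identical shape, e.g. reference, previous, current
ref, prev, cur = (torch.randn(1, 3, 64, 64) for _ in range(3))

# stack along the batch dim so one backbone forward covers all three frames
batch = torch.cat((ref, prev, cur), 0)  # shape (3, 3, 64, 64)

# split the batched features back into equal per-frame chunks
f_ref, f_prev, f_cur = torch.split(batch, split_size_or_sections=batch.size(0) // 3, dim=0)
assert f_ref.shape == ref.shape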
Example 2
    def train_VOS(self,damage_initial_previous_frame_mask=False,lossfunc='cross_entropy',model_resume=False):
        ###################
        self.model.train()
        running_loss = AverageMeter()
        optimizer = optim.SGD(self.model.parameters(),lr=cfg.TRAIN_LR,momentum=cfg.TRAIN_MOMENTUM)
        #scheduler = optim.lr_scheduler.StepLR(optimizer,step_size=cfg.TRAIN_LR_STEPSIZE,gamma=cfg.TRAIN_LR_GAMMA)
        

        ###################

        composed_transforms = transforms.Compose([tr.RandomHorizontalFlip(cfg.DATA_RANDOMFLIP),
                                                  tr.RandomScale(),
                                                  tr.RandomCrop((cfg.DATA_RANDOMCROP,cfg.DATA_RANDOMCROP)),
                                                  tr.Resize(cfg.DATA_RESCALE),
                                                  tr.ToTensor()])
        print('dataset processing...')
#        train_dataset = DAVIS2017_Train(root=cfg.DATA_ROOT, transform=composed_transforms)
        train_dataset = DAVIS2017_VOS_Train(root=cfg.DATA_ROOT, transform=composed_transforms)
#        trainloader = DataLoader(train_dataset,batch_size=cfg.TRAIN_BATCH_SIZE,
#                        sampler = RandomIdentitySampler(train_dataset.sample_list), 
#                        shuffle=False,num_workers=cfg.NUM_WORKER,pin_memory=True)
        trainloader = DataLoader(train_dataset,batch_size=cfg.TRAIN_BATCH_SIZE,
                        shuffle=True,num_workers=cfg.NUM_WORKER,pin_memory=True)
        print('dataset processing finished.')
        if lossfunc=='bce':
            criterion = Added_BCEWithLogitsLoss(cfg.TRAIN_TOP_K_PERCENT_PIXELS,cfg.TRAIN_HARD_MINING_STEP)
        elif lossfunc=='cross_entropy':
            criterion = Added_CrossEntropyLoss(cfg.TRAIN_TOP_K_PERCENT_PIXELS,cfg.TRAIN_HARD_MINING_STEP)
        else:
            raise ValueError('unsupported loss function. Please choose from [cross_entropy, bce]')

        max_itr = cfg.TRAIN_TOTAL_STEPS

        step=0

        if model_resume:
            saved_model_=os.path.join(self.save_res_dir,'save_step_30000.pth')

            saved_model_ = torch.load(saved_model_)
            self.model=self.load_network(self.model,saved_model_)
            step=30000
            print('resume from step {}'.format(step))
        while step<cfg.TRAIN_TOTAL_STEPS:

            # sample['meta'] = {'seq_name': seqname, 'frame_num': frame_num, 'obj_num': obj_num}
            for ii, sample in enumerate(trainloader):
                now_lr=self._adjust_lr(optimizer,step,max_itr)
                ref_imgs = sample['ref_img'] #batch_size * 3 * h * w
                img1s = sample['img1'] 
                img2s = sample['img2']
                ref_scribble_labels = sample['ref_scribble_label'] #batch_size * 1 * h * w
                label1s = sample['label1']
                label2s = sample['label2']
                seq_names = sample['meta']['seq_name'] 
                obj_nums = sample['meta']['obj_num']
                bs,_,h,w = img2s.size()
                inputs = torch.cat((ref_imgs,img1s,img2s),0)
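                # ref/prev/current stacked along the batch dim, mirroring test time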
                if damage_initial_previous_frame_mask:
                    try:
                        label1s = damage_masks(label1s)
                    except Exception:
                        # keep the clean masks if damaging fails
                        print('damage_error')




                ##########
                if self.use_gpu:
                    inputs = inputs.cuda()
                    ref_scribble_labels=ref_scribble_labels.cuda()
                    label1s = label1s.cuda()
                    label2s = label2s.cuda()
                 
                ##########


                tmp_dic = self.model(inputs,ref_scribble_labels,label1s,seq_names=seq_names,gt_ids=obj_nums,k_nearest_neighbors=cfg.KNNS)

                label_and_obj_dic={}
                label_dic={}
                for i, seq_ in enumerate(seq_names):
                    label_and_obj_dic[seq_]=(label2s[i],obj_nums[i])
                for seq_ in tmp_dic.keys():
                    tmp_pred_logits = tmp_dic[seq_]
                    tmp_pred_logits = nn.functional.interpolate(tmp_pred_logits,size=(h,w),mode = 'bilinear',align_corners=True)
                    tmp_dic[seq_]=tmp_pred_logits

                    label_tmp,obj_num = label_and_obj_dic[seq_]
                    obj_ids = np.arange(1,obj_num+1)
                    obj_ids = torch.from_numpy(obj_ids)
                    obj_ids = obj_ids.int()
                    if torch.cuda.is_available():
                        obj_ids = obj_ids.cuda()
                    if lossfunc == 'bce':
                        label_tmp = label_tmp.permute(1,2,0)
                        label = (label_tmp.float()==obj_ids.float())
                        label = label.unsqueeze(-1).permute(3,2,0,1)
                        label_dic[seq_]=label.float()
                    elif lossfunc =='cross_entropy':
                        label_dic[seq_]=label_tmp.long()



                loss = criterion(tmp_dic,label_dic,step)
                loss =loss/bs
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                #scheduler.step()
                
                running_loss.update(loss.item(),bs)
                if step%50==0:
                    #print(torch.cuda.memory_allocated())
                    print(torch.cuda.max_memory_cached())
                    torch.cuda.empty_cache()
                    torch.cuda.reset_max_memory_allocated()
                    print('step:{}, now_lr:{}, loss:{:.4f}({:.4f})'.format(step, now_lr, running_loss.val, running_loss.avg))
                #    print('step:{}'.format(step))
                    
                    show_ref_img = ref_imgs.cpu().numpy()[0]
                    show_img1 = img1s.cpu().numpy()[0]
                    show_img2 = img2s.cpu().numpy()[0]

                    # undo ImageNet normalization (per-channel mean/std) for display
                    mean = np.array([[[0.485]], [[0.456]], [[0.406]]])
                    sigma = np.array([[[0.229]], [[0.224]], [[0.225]]])

                    show_ref_img = show_ref_img*sigma+mean
                    show_img1 = show_img1*sigma+mean
                    show_img2 = show_img2*sigma+mean


                    show_gt = label2s.cpu()[0]

                    show_gt = show_gt.squeeze(0).numpy()
                    show_gtf = label2colormap(show_gt).transpose((2,0,1))

                    ##########
                    show_preds = tmp_dic[seq_names[0]].cpu()
                    show_preds=nn.functional.interpolate(show_preds,size=(h,w),mode = 'bilinear',align_corners=True)
                    show_preds = show_preds.squeeze(0)
                    if lossfunc=='bce':
                        show_preds = (torch.sigmoid(show_preds)>0.5)
                        show_preds_s = torch.zeros((h,w))
                        for i in range(show_preds.size(0)):
                            show_preds_s[show_preds[i]]=i+1
                    elif lossfunc=='cross_entropy':
                        show_preds_s = torch.argmax(show_preds,dim=0)
                    show_preds_s = show_preds_s.numpy()
                    show_preds_sf = label2colormap(show_preds_s).transpose((2,0,1))

                    pix_acc = np.sum(show_preds_s==show_gt)/(h*w)




                    tblogger.add_scalar('loss', running_loss.avg, step)
                    tblogger.add_scalar('pix_acc', pix_acc, step)
                    tblogger.add_scalar('now_lr', now_lr, step)
                    tblogger.add_image('Reference image', show_ref_img, step)
                    tblogger.add_image('Previous frame image', show_img1, step)
                    tblogger.add_image('Current frame image', show_img2, step)
                    tblogger.add_image('Ground Truth', show_gtf, step)
                    tblogger.add_image('Predict label', show_preds_sf, step)




                    ###########TODO
                if step%5000==0 and step!=0:
                    self.save_network(self.model,step)



                step+=1
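
train_VOS adjusts the learning rate every iteration through self._adjust_lr(optimizer, step, max_itr), which is not shown in these excerpts. A hypothetical polynomial ("poly") decay consistent with the call site, a common choice for segmentation training (the repo's actual schedule may differ):

    def _adjust_lr(self, optimizer, itr, max_itr, power=0.9):
        # assumed poly schedule: decay from cfg.TRAIN_LR toward 0 over max_itr steps
        now_lr = cfg.TRAIN_LR * (1 - itr / (max_itr + 1)) ** power
        for param_group in optimizer.param_groups:
            param_group['lr'] = now_lr
        return now_lr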
Example 3
    def train(self,damage_initial_previous_frame_mask=True,lossfunc='cross_entropy',model_resume=False,eval_total=False,init_prev=False):
        ###################
        interactor = interactive_robot.InteractiveScribblesRobot()
        self.model.train()
        running_loss = AverageMeter()
        optimizer = optim.SGD(self.model.inter_seghead.parameters(),lr=cfg.TRAIN_LR,momentum=cfg.TRAIN_MOMENTUM,weight_decay=cfg.TRAIN_WEIGHT_DECAY)
#        optimizer = optim.SGD(self.model.inter_seghead.parameters(),lr=cfg.TRAIN_LR,momentum=cfg.TRAIN_MOMENTUM)
        #scheduler = optim.lr_scheduler.StepLR(optimizer,step_size=cfg.TRAIN_LR_STEPSIZE,gamma=cfg.TRAIN_LR_GAMMA)
        

        ###################

        composed_transforms = transforms.Compose([tr.RandomHorizontalFlip(cfg.DATA_RANDOMFLIP),
                                                  tr.RandomScale(),
                                                  tr.RandomCrop((cfg.DATA_RANDOMCROP,cfg.DATA_RANDOMCROP),10),
                                                  tr.Resize(cfg.DATA_RESCALE),
                                                  tr.ToTensor()])
        print('dataset processing...')
        train_dataset = DAVIS2017_Train(root=cfg.DATA_ROOT, transform=composed_transforms)
        train_list = train_dataset.seqs

#        train_dataset = DAVIS2017_VOS_Train(root=cfg.DATA_ROOT, transform=composed_transforms)

#        trainloader = DataLoader(train_dataset,batch_size=cfg.TRAIN_BATCH_SIZE,
#                        shuffle=True,num_workers=cfg.NUM_WORKER,pin_memory=True)
        print('dataset processing finished.')
        if lossfunc=='bce':
            criterion = Added_BCEWithLogitsLoss(cfg.TRAIN_TOP_K_PERCENT_PIXELS,cfg.TRAIN_HARD_MINING_STEP)
        elif lossfunc=='cross_entropy':
            criterion = Added_CrossEntropyLoss(cfg.TRAIN_TOP_K_PERCENT_PIXELS,cfg.TRAIN_HARD_MINING_STEP)
        else:
            raise ValueError('unsupported loss function. Please choose from [cross_entropy, bce]')

        max_itr = cfg.TRAIN_TOTAL_STEPS

        step=0
        round_=3
        epoch_per_round=30
        if model_resume:
            saved_model_=os.path.join(self.save_res_dir,'save_step_75000.pth')

            saved_model_ = torch.load(saved_model_)
            self.model=self.load_network(self.model,saved_model_)
            step=75000
            print('resume from step {}'.format(step))
        while step<cfg.TRAIN_TOTAL_STEPS:
            
            if step>100001:
                break

            for r in range(round_):
                if r==0:
                    print('start new')
                    global_map_tmp_dic={}
                    train_dataset.transform=transforms.Compose([tr.RandomHorizontalFlip(cfg.DATA_RANDOMFLIP),
                                                                tr.RandomScale(),
                                                                tr.RandomCrop((cfg.DATA_RANDOMCROP,cfg.DATA_RANDOMCROP)),
                                                                tr.Resize(cfg.DATA_RESCALE),
                                                                tr.ToTensor()])
                    train_dataset.init_ref_frame_dic()

                trainloader = DataLoader(train_dataset,batch_size=cfg.TRAIN_BATCH_SIZE,
                        sampler = RandomIdentitySampler(train_dataset.sample_list), 
                        shuffle=False,num_workers=cfg.NUM_WORKER,pin_memory=True)
                print('round:{} start'.format(r))
                for epoch in range(epoch_per_round):



                    for ii, sample in enumerate(trainloader):
                        now_lr=self._adjust_lr(optimizer,step,max_itr)
                        ref_imgs = sample['ref_img'] #batch_size * 3 * h * w
                        #img1s = sample['img1'] 
                        #img2s = sample['img2']
                        ref_scribble_labels = sample['ref_scribble_label'] #batch_size * 1 * h * w
                        #label1s = sample['label1']
                        #label2s = sample['label2']
                        seq_names = sample['meta']['seq_name'] 
                        obj_nums = sample['meta']['obj_num']
                        #frame_nums = sample['meta']['frame_num']
                        ref_frame_nums = sample['meta']['ref_frame_num']
                        ref_frame_gts=sample['ref_frame_gt']
                        bs,_,h,w = ref_imgs.size()
#                        print(ref_imgs.size())

#                        if r==0:
#                            ref_scribble_labels=self.rough_ROI(ref_scribble_labels)
                        ##########
                        if self.use_gpu:
                            inputs = ref_imgs.cuda()
                        
                            ref_scribble_labels=ref_scribble_labels.cuda()
                            ref_frame_gts = ref_frame_gts.cuda()
                            #label1s = label1s.cuda()
                            #label2s = label2s.cuda()
                        #print(inputs.size()) 
                        ##########        
                        with torch.no_grad():
                            self.model.feature_extracter.eval()
                            self.model.semantic_embedding.eval()
                            ref_frame_embedding = self.model.extract_feature(inputs)
                        if r==0:
                            first_inter=True

                            tmp_dic = self.model.int_seghead(ref_frame_embedding=ref_frame_embedding,ref_scribble_label=ref_scribble_labels,
                                    prev_round_label=None,normalize_nearest_neighbor_distances=True,global_map_tmp_dic={},
                                 seq_names=seq_names,gt_ids=obj_nums,k_nearest_neighbors=cfg.KNNS,
                                 frame_num=ref_frame_nums,first_inter=first_inter)
                        else:
                            first_inter=False
                            prev_round_label=sample['prev_round_label']
                        #    print(prev_round_label.size())
                            #prev_round_label=prev_round_label_dic[seq_names[0]]
                            prev_round_label=prev_round_label.cuda()
                            tmp_dic = self.model.int_seghead(ref_frame_embedding=ref_frame_embedding,ref_scribble_label=ref_scribble_labels,
                                    prev_round_label=prev_round_label,normalize_nearest_neighbor_distances=True,global_map_tmp_dic={},
                                 seq_names=seq_names,gt_ids=obj_nums,k_nearest_neighbors=cfg.KNNS,
                                 frame_num=ref_frame_nums,first_inter=first_inter)




                        label_and_obj_dic={}
                        label_dic={}
                        for i, seq_ in enumerate(seq_names):
                            label_and_obj_dic[seq_]=(ref_frame_gts[i],obj_nums[i])
                        for seq_ in tmp_dic.keys():
                            tmp_pred_logits = tmp_dic[seq_]
                            tmp_pred_logits = nn.functional.interpolate(tmp_pred_logits,size=(h,w),mode = 'bilinear',align_corners=True)
                            tmp_dic[seq_]=tmp_pred_logits        

                            label_tmp,obj_num = label_and_obj_dic[seq_]
                            obj_ids = np.arange(0,obj_num+1)
                            obj_ids = torch.from_numpy(obj_ids)
                            obj_ids = obj_ids.int()
                            if torch.cuda.is_available():
                                obj_ids = obj_ids.cuda()
                            if lossfunc == 'bce':
                                label_tmp = label_tmp.permute(1,2,0)
                                label = (label_tmp.float()==obj_ids.float())
                                label = label.unsqueeze(-1).permute(3,2,0,1)
                                label_dic[seq_]=label.float()
                            elif lossfunc =='cross_entropy':
                                label_dic[seq_]=label_tmp.long()        
        
        
                        loss = criterion(tmp_dic,label_dic,step)
                        loss =loss/bs
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()
                        #scheduler.step()
                        

                        running_loss.update(loss.item(),bs)
                        if step%50==0:
                            #print(torch.cuda.memory_allocated())
                            #print(torch.cuda.max_memory_cached())
                            torch.cuda.empty_cache()
                            #torch.cuda.reset_max_memory_allocated()
                            print('step:{}, now_lr:{}, loss:{:.4f}({:.4f})'.format(step, now_lr, running_loss.val, running_loss.avg))
                        #    print('step:{}'.format(step))
                            
                            show_ref_img = ref_imgs.cpu().numpy()[0]
                            #show_img1 = img1s.cpu().numpy()[0]
                            #show_img2 = img2s.cpu().numpy()[0]        

                            # undo ImageNet normalization (per-channel mean/std) for display
                            mean = np.array([[[0.485]], [[0.456]], [[0.406]]])
                            sigma = np.array([[[0.229]], [[0.224]], [[0.225]]])

                            show_ref_img = show_ref_img*sigma+mean
                            #show_img1 = show_img1*sigma+mean
                            #show_img2 = show_img2*sigma+mean        
        

                            #show_gt = label2s.cpu()[0]        

                            show_gt = ref_frame_gts.cpu()[0].squeeze(0).numpy()
                            show_gtf = label2colormap(show_gt).transpose((2,0,1))        
                            show_scrbble=ref_scribble_labels.cpu()[0].squeeze(0).numpy()
                            show_scrbble=label2colormap(show_scrbble).transpose((2,0,1))
                            if r!=0:
                                show_prev_round_label=prev_round_label.cpu()[0].squeeze(0).numpy()
                                show_prev_round_label=label2colormap(show_prev_round_label).transpose((2,0,1))
                            else:
                                show_prev_round_label = np.zeros_like(show_gt)
                                
                                show_prev_round_label = label2colormap(show_prev_round_label).transpose((2,0,1))

                            ##########
                            show_preds = tmp_dic[seq_names[0]].cpu()
                            show_preds=nn.functional.interpolate(show_preds,size=(h,w),mode = 'bilinear',align_corners=True)
                            show_preds = show_preds.squeeze(0)
                            if lossfunc=='bce':
                                show_preds = show_preds[1:]

                                show_preds = (torch.sigmoid(show_preds)>0.5)
                                marker = torch.argmax(show_preds,dim=0)
                                show_preds_s = torch.zeros((h,w))
                                for i in range(show_preds.size(0)):
                                    tmp_mask = (marker==i) & (show_preds[i]>0.5)
                                    show_preds_s[tmp_mask]=i+1
                            elif lossfunc=='cross_entropy':
                                show_preds_s = torch.argmax(show_preds,dim=0)
                            show_preds_s = show_preds_s.numpy()
                            show_preds_sf = label2colormap(show_preds_s).transpose((2,0,1))        

                            pix_acc = np.sum(show_preds_s==show_gt)/(h*w)        
        
        
        
                            if cfg.TRAIN_TBLOG:
                                tblogger.add_scalar('loss', running_loss.avg, step)
                                tblogger.add_scalar('pix_acc', pix_acc, step)
                                tblogger.add_scalar('now_lr', now_lr, step)
                                tblogger.add_image('Reference image', show_ref_img, step)
                                #tblogger.add_image('Previous frame image', show_img1, step)
                                
                           #     tblogger.add_image('Current frame image', show_img2, step)
                                tblogger.add_image('Ground Truth', show_gtf, step)
                                tblogger.add_image('Predict label', show_preds_sf, step)        
        
                                tblogger.add_image('Scribble', show_scrbble, step)
                                tblogger.add_image('prev_round_label', show_prev_round_label, step)

                            ###########TODO
                        if step%20000==0 and step!=0:
                            self.save_network(self.model,step)        
        
        
                
                
                        step+=1
                print('trainset evaluating...')
                print('*'*100)


                if cfg.TRAIN_INTER_USE_TRUE_RESULT:
                    if r!=round_-1:
                        if r ==0:
                            prev_round_label_dic={}
                        self.model.eval()
                        with torch.no_grad():
                            round_scribble={}

                            frame_num_dic= {}
                            train_dataset.transform=transforms.Compose([tr.Resize(cfg.DATA_RESCALE),tr.ToTensor()])
#                            train_dataset.transform=composed_transforms
                            trainloader = DataLoader(train_dataset,batch_size=1,
                                sampler = RandomIdentitySampler(train_dataset.sample_list), 
                                shuffle=False,num_workers=cfg.NUM_WORKER,pin_memory=True)
                            for ii, sample in enumerate(trainloader):
                                ref_imgs = sample['ref_img'] #batch_size * 3 * h * w
                                img1s = sample['img1'] 
                                img2s = sample['img2']
                                ref_scribble_labels = sample['ref_scribble_label'] #batch_size * 1 * h * w
                                label1s = sample['label1']
                                label2s = sample['label2']
                                seq_names = sample['meta']['seq_name'] 
                                obj_nums = sample['meta']['obj_num']
                                frame_nums = sample['meta']['frame_num']
                                bs,_,h,w = img2s.size()
                                inputs = torch.cat((ref_imgs,img1s,img2s),0)
                                if r==0:
                                    ref_scribble_labels=self.rough_ROI(ref_scribble_labels)
                                print(seq_names[0])
#                                    if damage_initial_previous_frame_mask:
#                                        try:
#                                            label1s = damage_masks(label1s)
#                                        except:
#                                            label1s = label1s
#                                            print('damage_error') 
                                label1s_tocat=torch.Tensor()
                                for i in range(bs):
                                    l = label1s[i]
                                    l = l.unsqueeze(0)
                                    l = mask_damager(l,0.0)
                                    l = torch.from_numpy(l)

                                    l = l.unsqueeze(0).unsqueeze(0)
                        
                                    label1s_tocat = torch.cat((label1s_tocat,l.float()),0)
                                label1s = label1s_tocat
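                                # the perturbed previous-frame masks stand in for imperfect
                                # predictions from an earlier interaction round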
                                if self.use_gpu:
                                    inputs = inputs.cuda()
                                    ref_scribble_labels=ref_scribble_labels.cuda()
                                    label1s = label1s.cuda()
                                    
                                tmp_dic, global_map_tmp_dic = self.model(inputs,ref_scribble_labels,label1s,seq_names=seq_names,
                                                            gt_ids=obj_nums,k_nearest_neighbors=cfg.KNNS,global_map_tmp_dic=global_map_tmp_dic,
                                                           frame_num=frame_nums)
                                pred_label = tmp_dic[seq_names[0]].detach().cpu()
                                pred_label = nn.functional.interpolate(pred_label,size=(h,w),mode = 'bilinear',align_corners=True)            

                                pred_label=torch.argmax(pred_label,dim=1)
                                pred_label= pred_label.unsqueeze(0)
                                try:
                                    pred_label = damage_masks(pred_label)
                                except Exception:
                                    # keep the undamaged prediction if damaging fails
                                    pass
                                pred_label=pred_label.squeeze(0)
                                round_scribble[seq_names[0]]=interactor.interact(seq_names[0],pred_label.numpy(),label2s.float().squeeze(0).numpy(),obj_nums)
                                frame_num_dic[seq_names[0]]=frame_nums[0]        
                                pred_label=pred_label.unsqueeze(0)
                                img_ww=Image.open(os.path.join(cfg.DATA_ROOT,'JPEGImages/480p/',seq_names[0],'00000.jpg'))
                                img_ww=np.array(img_ww)
                                or_h,or_w = img_ww.shape[:2]
                                pred_label = torch.nn.functional.interpolate(pred_label.float(),(or_h,or_w),mode='nearest')

                                prev_round_label_dic[seq_names[0]]=pred_label.squeeze(0)
#                                torch.cuda.empty_cache() 
                        train_dataset.update_ref_frame_and_label(round_scribble,frame_num_dic,prev_round_label_dic)
                    print('trainset evaluating finished!')
                    print('*'*100)
                    self.model.train()
                    print('updating ref frame and label')


                    train_dataset.transform=composed_transforms 
                    print('updating ref frame and label finished!')



                else:
                    if r!=round_-1:
                        round_scribble={}    

                        if r ==0:
                            prev_round_label_dic={}
                        #    eval_global_map_tmp_dic={}
                        frame_num_dic= {}
                        train_dataset.transform=tr.ToTensor()
                        trainloader = DataLoader(train_dataset,batch_size=1,
                            sampler = RandomIdentitySampler(train_dataset.sample_list), 
                            shuffle=False,num_workers=1,pin_memory=True)
                        
                        self.model.eval()
                        with torch.no_grad():
                            for ii, sample in enumerate(trainloader):
                                ref_imgs = sample['ref_img'] #batch_size * 3 * h * w
                                img1s = sample['img1'] 
                                img2s = sample['img2']
                                ref_scribble_labels = sample['ref_scribble_label'] #batch_size * 1 * h * w
                                label1s = sample['label1']
                                label2s = sample['label2']
                                seq_names = sample['meta']['seq_name'] 
                                obj_nums = sample['meta']['obj_num']
                                frame_nums = sample['meta']['frame_num']
                                bs,_,h,w = img2s.size()
                                
 #                               inputs=torch.cat((ref_imgs,img1s,img2s),0)
#                                if r==0:
#                                    ref_scribble_labels=self.rough_ROI(ref_scribble_labels)
                                print(seq_names[0])
                                label2s_ = mask_damager(label2s,0.1)
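                                # a damaged copy of the ground truth simulates the previous
                                # round's prediction, so no model forward pass is needed here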
                                # inputs = inputs.cuda()
                                # ref_scribble_labels=ref_scribble_labels.cuda()
                                # label1s = label1s.cuda()
                                
                                  
                                        
                                # tmp_dic, global_map_tmp_dic = self.model(inputs,ref_scribble_labels,label1s,seq_names=seq_names,
                                #                                 gt_ids=obj_nums,k_nearest_neighbors=cfg.KNNS,global_map_tmp_dic=global_map_tmp_dic,
                                #                                frame_num=frame_nums)    

                                #label2s_show = label2s.squeeze().numpy()
                                #label2s_im = Image.fromarray(label2s_show.astype('uint8')).convert('P')
                                #label2s_im.putpalette(_palette)
                                #label2s_im.save('label.png')
                                #label2s__show = label2s_
                                #label2s_im_ = Image.fromarray(label2s__show.astype('uint8')).convert('P')
                                #label2s_im_.putpalette(_palette)
                                #label2s_im_.save('damage_label.png')
                        #        exit()
                                
                                                            
                                #print(label2s_)
                                #print(label2s.size())
                                round_scribble[seq_names[0]]=interactor.interact(seq_names[0],np.expand_dims(label2s_,axis=0),label2s.float().squeeze(0).numpy(),obj_nums)
                                label2s__=torch.from_numpy(label2s_)
                                # img_ww=Image.open(os.path.join(cfg.DATA_ROOT,'JPEGImages/480p/',seq_names[0],'00000.jpg'))
                                # img_ww=np.array(img_ww)
                                # or_h,or_w = img_ww.shape[:2]
                                # label2s__=label2s__.unsqueeze(0).unsqueeze(0)        
                                # label2s__ = torch.nn.functional.interpolate(label2s__.float(),(or_h,or_w),mode='nearest')
                                # label2s__=label2s__.squeeze(0)
#                                print(label2s__.size())

                                frame_num_dic[seq_names[0]]=frame_nums[0]
                                prev_round_label_dic[seq_names[0]]=label2s__    
    

                                #torch.cuda.empty_cache()     
                        print('trainset evaluating finished!')
                        print('*'*100)
                        print('updating ref frame and label')    

                        train_dataset.update_ref_frame_and_label(round_scribble,frame_num_dic,prev_round_label_dic)
                        self.model.train()
                        train_dataset.transform=composed_transforms 
                        print('updating ref frame and label finished!')
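
Both training loops above track the loss with AverageMeter through .update(val, n), .val and .avg. The helper is not defined in these excerpts; a minimal implementation consistent with that usage:

class AverageMeter(object):
    # tracks the most recent value and a running (weighted) average
    def __init__(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count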
def main():
    total_frame_num_dic={}
    #################
    seqs = []
    with open(os.path.join(cfg.DATA_ROOT, 'ImageSets', '2017', 'val' + '.txt')) as f:
        seqs_tmp = f.readlines()
        seqs_tmp = list(map(lambda elem: elem.strip(), seqs_tmp))
        seqs.extend(seqs_tmp)
    for seq_name in seqs:
        images = np.sort(os.listdir(os.path.join(cfg.DATA_ROOT, 'JPEGImages/480p/', seq_name.strip())))
        total_frame_num_dic[seq_name]=len(images)
    _seq_list_file=os.path.join(cfg.DATA_ROOT, 'ImageSets', '2017',
                                          'val' + '_instances.txt')
    seq_dict = json.load(open(_seq_list_file, 'r'))

    ##################

    is_save_image=True
    report_save_dir= cfg.RESULT_ROOT
    save_res_dir = './saved_model_inter_net_new_re'
    # Configuration used in the challenges
    max_nb_interactions = 8  # Maximum number of interactions
    max_time_per_interaction = 10000  # Maximum time per interaction per object
    # Total time available to interact with a sequence and an initial set of scribbles
    max_time = max_nb_interactions * max_time_per_interaction  # Maximum time per object
    # Interactive parameters
    subset = 'val'
    host = 'localhost'  # 'localhost' for subsets train and val.

    feature_extracter = DeepLab(backbone='resnet',freeze_bn=False)
    model = IntVOS(cfg,feature_extracter)
    model= model.cuda()
    print('model loading...')

    saved_model_dict = os.path.join(save_res_dir,'save_step_80000.pth')
    pretrained_dict = torch.load(saved_model_dict)
    load_network(model,pretrained_dict)

    print('model loading finished!')
    model.eval()

    seen_seq={}
    with torch.no_grad():
        with DavisInteractiveSession(host=host,
                                davis_root=cfg.DATA_ROOT,
                                subset=subset,
                                report_save_dir=report_save_dir,
                                max_nb_interactions=max_nb_interactions,
                                max_time=max_time,
                                metric_to_optimize='J') as sess:
            while sess.next():
                t_total=timeit.default_timer()
                # Get the current iteration scribbles

                sequence, scribbles, first_scribble = sess.get_scribbles(only_last=True)
                print(sequence)
                start_annotated_frame=annotated_frames(scribbles)[0]
                
                pred_masks=[]
                pred_masks_reverse=[]


                if first_scribble:
                    anno_frame_list=[]
                    n_interaction=1
                    eval_global_map_tmp_dic={}
                    local_map_dics=({},{})
                    total_frame_num=total_frame_num_dic[sequence]
                    obj_nums = seq_dict[sequence][-1]
                    eval_data_manager=DAVIS2017_Test_Manager(split='val',root=cfg.DATA_ROOT,transform=tr.ToTensor(),
                                                            seq_name=sequence)

                else:
                    n_interaction+=1



                ##########################Reference image process
                scr_f = start_annotated_frame
                anno_frame_list.append(start_annotated_frame)
                print(start_annotated_frame)
                scr_f = str(scr_f).zfill(5)
                ref_img = os.path.join(cfg.DATA_ROOT,'JPEGImages/480p',sequence,scr_f+'.jpg')
                ref_img = cv2.imread(ref_img)
                h_,w_ = ref_img.shape[:2]
                ref_img = np.array(ref_img,dtype=np.float32)

                #ref_img = tr.ToTensor()(ref_img)
                #ref_img = ref_img.unsqueeze(0)
                scribble_masks=scribbles2mask(scribbles,(h_,w_))
                scribble_label=scribble_masks[start_annotated_frame]
                sample = {'ref_img':ref_img,'scribble_label':scribble_label}
                sample = tr.ToTensor()(sample)
                ref_img = sample['ref_img']
                scribble_label = sample['scribble_label']
                ref_img = ref_img.unsqueeze(0)
                scribble_label =scribble_label.unsqueeze(0)
                ref_img= ref_img.cuda()
                scribble_label = scribble_label.cuda()
                ######
                ref_scribble_to_show = scribble_label.cpu().squeeze().numpy()
                im_ = Image.fromarray(ref_scribble_to_show.astype('uint8')).convert('P')
                im_.putpalette(_palette)
                ref_img_name= scr_f
    
                if not os.path.exists(os.path.join(cfg.RESULT_ROOT,sequence,'interactive'+str(n_interaction))):
                    os.makedirs(os.path.join(cfg.RESULT_ROOT,sequence,'interactive'+str(n_interaction)))
                im_.save(os.path.join(cfg.RESULT_ROOT,sequence,'interactive'+str(n_interaction),'inter_'+ref_img_name+'.png'))

                #ref_frame_embedding = model.extract_feature(ref_img)

                if first_scribble:
                    # first interaction: run the backbone and cache every frame's
                    # features in embedding_memory so later rounds can reuse them
                    ref_frame_embedding = model.extract_feature(ref_img)
                    _,channel,emb_h,emb_w = ref_frame_embedding.size()
                    embedding_memory=torch.zeros((total_frame_num,channel,emb_h,emb_w))
                    embedding_memory = embedding_memory.cuda()
                    embedding_memory[start_annotated_frame]=ref_frame_embedding[0]
                else:
                    # later interactions: reuse the cached feature for this frame
                    ref_frame_embedding = embedding_memory[start_annotated_frame]
                    ref_frame_embedding = ref_frame_embedding.unsqueeze(0)
                ########
                if first_scribble:

                    prev_label= None
                else:
                    prev_label = os.path.join(cfg.RESULT_ROOT,sequence,'interactive'+str(n_interaction-1),scr_f+'.png')
                    prev_label = Image.open(prev_label)
                    prev_label = np.array(prev_label,dtype=np.uint8)
                    prev_label = tr.ToTensor()({'label':prev_label})
                    prev_label = prev_label['label'].unsqueeze(0)
                    prev_label = prev_label.cuda()

                ###############
                #tmp_dic, eval_global_map_tmp_dic= model.before_seghead_process(ref_frame_embedding,ref_frame_embedding,
                #                            ref_frame_embedding,scribble_label,prev_label,
                #                            normalize_nearest_neighbor_distances=True,use_local_map=False,
                #                        seq_names=[sequence],gt_ids=torch.Tensor([obj_nums]),k_nearest_neighbors=cfg.KNNS,
                #                        global_map_tmp_dic=eval_global_map_tmp_dic,frame_num=[start_annotated_frame],dynamic_seghead=model.dynamic_seghead)
                tmp_dic,local_map_dics = model.int_seghead(ref_frame_embedding=ref_frame_embedding,ref_scribble_label=scribble_label,prev_round_label=prev_label,
                        global_map_tmp_dic=eval_global_map_tmp_dic,local_map_dics=local_map_dics,interaction_num=n_interaction,
                        seq_names=[sequence],gt_ids=torch.Tensor([obj_nums]),frame_num=[start_annotated_frame],first_inter=first_scribble)
                pred_label = tmp_dic[sequence]
                pred_label = nn.functional.interpolate(pred_label,size=(h_,w_),mode = 'bilinear',align_corners=True)    

                pred_label=torch.argmax(pred_label,dim=1)
                pred_masks.append(pred_label.float())
                    ####
                if is_save_image:
                    pred_label_to_save=pred_label.squeeze(0).cpu().numpy()
                    im = Image.fromarray(pred_label_to_save.astype('uint8')).convert('P')
                    im.putpalette(_palette)
                    imgname = str(start_annotated_frame).zfill(5)
                    if not os.path.exists(os.path.join(cfg.RESULT_ROOT,sequence,'interactive'+str(n_interaction))):
                        os.makedirs(os.path.join(cfg.RESULT_ROOT,sequence,'interactive'+str(n_interaction)))
                    im.save(os.path.join(cfg.RESULT_ROOT,sequence,'interactive'+str(n_interaction),imgname+'.png'))
                #######################################
                ##############################
                ref_prev_label = pred_label.unsqueeze(0)
                prev_label = pred_label.unsqueeze(0)
                prev_img = ref_img
                prev_embedding = ref_frame_embedding

                for ii in range(start_annotated_frame+1,total_frame_num):
                    print('evaluating sequence:{} frame:{}'.format(sequence,ii))
                    sample = eval_data_manager.get_image(ii)
                    img = sample['img']
                    img = img.unsqueeze(0)
                    _,_,h,w = img.size()
                    img = img.cuda()

                #    current_embedding= model.extract_feature(img)
                    if first_scribble:
                        current_embedding= model.extract_feature(img)
                        embedding_memory[ii]=current_embedding[0]
                    else:
                        current_embedding=embedding_memory[ii]
                        current_embedding = current_embedding.unsqueeze(0)


                    prev_img =prev_img.cuda()
                    prev_label =prev_label.cuda()
                    tmp_dic, eval_global_map_tmp_dic,local_map_dics= model.before_seghead_process(ref_frame_embedding,prev_embedding,
                                            current_embedding,scribble_label,prev_label,
                                            normalize_nearest_neighbor_distances=True,use_local_map=True,
                                        seq_names=[sequence],gt_ids=torch.Tensor([obj_nums]),k_nearest_neighbors=cfg.KNNS,
                                        global_map_tmp_dic=eval_global_map_tmp_dic,local_map_dics=local_map_dics,
                                        interaction_num=n_interaction,start_annotated_frame=start_annotated_frame,
                                        frame_num=[ii],dynamic_seghead=model.dynamic_seghead)
                    pred_label = tmp_dic[sequence]
                    pred_label = nn.functional.interpolate(pred_label,size=(h,w),mode = 'bilinear',align_corners=True)    

                    pred_label=torch.argmax(pred_label,dim=1)
                    pred_masks.append(pred_label.float())
                    prev_label = pred_label.unsqueeze(0)
                    prev_img = img
                    prev_embedding = current_embedding
                    ####
                    if is_save_image:
                        pred_label_to_save=pred_label.squeeze(0).cpu().numpy()
                        im = Image.fromarray(pred_label_to_save.astype('uint8')).convert('P')
                        im.putpalette(_palette)
                        imgname = str(ii).zfill(5)
                        if not os.path.exists(os.path.join(cfg.RESULT_ROOT,sequence,'interactive'+str(n_interaction))):
                            os.makedirs(os.path.join(cfg.RESULT_ROOT,sequence,'interactive'+str(n_interaction)))
                        im.save(os.path.join(cfg.RESULT_ROOT,sequence,'interactive'+str(n_interaction),imgname+'.png'))
                #######################################
                prev_label = ref_prev_label
                prev_img = ref_img
                prev_embedding = ref_frame_embedding
                ######################################
                for ii in range(start_annotated_frame):
                    current_frame_num=start_annotated_frame-1-ii
                    print('evaluating sequence:{} frame:{}'.format(sequence,current_frame_num))
                    sample = eval_data_manager.get_image(current_frame_num)
                    img = sample['img']
                    img = img.unsqueeze(0)
                    _,_,h,w = img.size()
                    img = img.cuda()

                    #current_embedding= model.extract_feature(img)
                    if first_scribble:
                        current_embedding= model.extract_feature(img)
                        embedding_memory[current_frame_num]=current_embedding[0]
                    else:
                        current_embedding = embedding_memory[current_frame_num]
                        current_embedding = current_embedding.unsqueeze(0)
                    prev_img =prev_img.cuda()
                    prev_label =prev_label.cuda()
                    tmp_dic, eval_global_map_tmp_dic,local_map_dics= model.before_seghead_process(ref_frame_embedding,prev_embedding,
                                            current_embedding,scribble_label,prev_label,
                                            normalize_nearest_neighbor_distances=True,use_local_map=True,
                                        seq_names=[sequence],gt_ids=torch.Tensor([obj_nums]),k_nearest_neighbors=cfg.KNNS,
                                        global_map_tmp_dic=eval_global_map_tmp_dic,local_map_dics=local_map_dics,interaction_num=n_interaction,start_annotated_frame=start_annotated_frame,frame_num=[current_frame_num],dynamic_seghead=model.dynamic_seghead)
                    pred_label = tmp_dic[sequence]
                    pred_label = nn.functional.interpolate(pred_label,size=(h,w),mode = 'bilinear',align_corners=True)    

                    pred_label=torch.argmax(pred_label,dim=1)
                    pred_masks_reverse.append(pred_label.float())
                    prev_label = pred_label.unsqueeze(0)
                    prev_img = img
                    prev_embedding = current_embedding
                    ####
                    if is_save_image:
                        pred_label_to_save=pred_label.squeeze(0).cpu().numpy()
                        im = Image.fromarray(pred_label_to_save.astype('uint8')).convert('P')
                        im.putpalette(_palette)
                        imgname = str(current_frame_num).zfill(5)
                        if not os.path.exists(os.path.join(cfg.RESULT_ROOT,sequence,'interactive'+str(n_interaction))):
                            os.makedirs(os.path.join(cfg.RESULT_ROOT,sequence,'interactive'+str(n_interaction)))
                        im.save(os.path.join(cfg.RESULT_ROOT,sequence,'interactive'+str(n_interaction),imgname+'.png'))
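                # frames before the annotated one were predicted backwards in time;
                # reverse them, then append the forward predictions to restore frame order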
                pred_masks_reverse.reverse()
                pred_masks_reverse.extend(pred_masks)
                final_masks = torch.cat(pred_masks_reverse,0)
                sess.submit_masks(final_masks.cpu().numpy())
                t_end = timeit.default_timer()
                print('Total time for single interaction: ' + str(t_end - t_total))
            report = sess.get_report()  
            summary = sess.get_global_summary(save_file=os.path.join(report_save_dir, 'summary.json'))
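
load_network(model, pretrained_dict) and self.load_network(...) are used throughout but not defined in these excerpts. A common pattern that strips DataParallel's 'module.' prefix and skips mismatched keys, offered as a sketch rather than the repo's exact helper:

def load_network(net, pretrained_dict):
    # drop the 'module.' prefix that DataParallel checkpoints carry
    pretrained_dict = {k.replace('module.', '', 1): v
                       for k, v in pretrained_dict.items()}
    model_dict = net.state_dict()
    # keep only entries whose name and shape match the current model
    filtered = {k: v for k, v in pretrained_dict.items()
                if k in model_dict and v.size() == model_dict[k].size()}
    model_dict.update(filtered)
    net.load_state_dict(model_dict)
    return net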