예제 #1
0
    def train(self, damage_initial_previous_frame_mask=True, lossfunc='cross_entropy', model_resume=False,
              resume_step=60000, log_interval=1):
        """Train the full model, alternating passes over YouTube-VOS and DAVIS-2017.

        Each pass of the outer loop iterates the YouTube-VOS loader once and then the DAVIS
        loader once. The learning rate is set every step by
        ``self._adjust_lr(optimizer, step, max_itr)`` with ``max_itr = cfg.TRAIN_TOTAL_STEPS``.
        Training stops as soon as ``max_itr`` steps have been taken, and the weights are
        checkpointed at that point. A checkpoint is also written every 5000 steps via
        ``self.save_network``.

        Args:
            damage_initial_previous_frame_mask: if True, pass the previous-frame masks
                (``sample['label1']``) through ``damage_masks`` before feeding the model. If
                that fails, the clean masks are used.
            lossfunc: ``'cross_entropy'`` or ``'bce'``.
            model_resume: if True, load ``save_step_<resume_step>.pth`` from
                ``self.save_res_dir`` and continue the step counter from ``resume_step``.
            resume_step: checkpoint step to resume from. Defaults to 60000, the value that
                was previously hard-coded.
            log_interval: console/TensorBoard logging period in steps (positive int).
                Defaults to 1, i.e. every step as before.

        Raises:
            ValueError: if ``lossfunc`` is not supported.
            RuntimeError: if the per-batch-normalised loss exceeds 10000. Logits and labels
                are first dumped to ``<seq_name>.npy`` / ``lab<seq_name>.npy`` in the
                working directory.
        """
        # Validate before building datasets. Previously an unknown value only printed a warning
        # and the loop later crashed with a NameError on `criterion`.
        if lossfunc == 'bce':
            criterion = Added_BCEWithLogitsLoss(cfg.TRAIN_TOP_K_PERCENT_PIXELS, cfg.TRAIN_HARD_MINING_STEP)
        elif lossfunc == 'cross_entropy':
            criterion = Added_CrossEntropyLoss(cfg.TRAIN_TOP_K_PERCENT_PIXELS, cfg.TRAIN_HARD_MINING_STEP)
        else:
            raise ValueError(
                'unsupported loss function {!r}. Please choose from [cross_entropy,bce]'.format(lossfunc))

        self.model.train()
        running_loss = AverageMeter()
        optimizer = optim.SGD(self.model.parameters(), lr=cfg.TRAIN_LR, momentum=cfg.TRAIN_MOMENTUM,
                              weight_decay=cfg.TRAIN_WEIGHT_DECAY)

        composed_transforms = transforms.Compose([tr.RandomHorizontalFlip(cfg.DATA_RANDOMFLIP),
                                                  tr.RandomScale(),
                                                  tr.RandomCrop((cfg.DATA_RANDOMCROP, cfg.DATA_RANDOMCROP)),
                                                  tr.Resize(cfg.DATA_RESCALE),
                                                  tr.ToTensor()])
        composed_transforms_ytb = transforms.Compose([tr.RandomHorizontalFlip(cfg.DATA_RANDOMFLIP),
                                                      tr.RandomScale([0.5, 1, 1.25]),
                                                      tr.RandomCrop((800, 800)),
                                                      tr.Resize(cfg.DATA_RESCALE),
                                                      tr.ToTensor()])
        print('dataset processing...')
        train_dataset = DAVIS2017_VOS_Train(root=cfg.DATA_ROOT, transform=composed_transforms)
        ytb_train_dataset = YTB_VOS_Train(root=cfg.YTB_DATAROOT, transform=composed_transforms_ytb)
        trainloader_davis = DataLoader(train_dataset, batch_size=cfg.TRAIN_BATCH_SIZE,
                                       shuffle=True, num_workers=cfg.NUM_WORKER, pin_memory=True)
        trainloader_ytb = DataLoader(ytb_train_dataset, batch_size=cfg.TRAIN_BATCH_SIZE,
                                     shuffle=True, num_workers=cfg.NUM_WORKER, pin_memory=True)
        # One outer pass = a full YouTube-VOS pass followed by a full DAVIS pass.
        trainloader = [trainloader_ytb, trainloader_davis]
        print('dataset processing finished.')

        max_itr = cfg.TRAIN_TOTAL_STEPS
        step = 0
        if model_resume:
            saved_model_ = torch.load(os.path.join(self.save_res_dir, 'save_step_{}.pth'.format(resume_step)))
            self.model = self.load_network(self.model, saved_model_)
            step = resume_step
            print('resume from step {}'.format(step))

        # Per-channel constants used to de-normalise images for display (x * sigma + mean).
        mean = np.array([[[0.485]], [[0.456]], [[0.406]]])
        sigma = np.array([[[0.229]], [[0.224]], [[0.225]]])

        while step < max_itr:
            for train_dataloader in trainloader:
                for sample in train_dataloader:
                    now_lr = self._adjust_lr(optimizer, step, max_itr)
                    ref_imgs = sample['ref_img']  # batch_size * 3 * h * w
                    img1s = sample['img1']
                    img2s = sample['img2']
                    ref_scribble_labels = sample['ref_scribble_label']  # batch_size * 1 * h * w
                    label1s = sample['label1']
                    label2s = sample['label2']
                    seq_names = sample['meta']['seq_name']
                    obj_nums = sample['meta']['obj_num']
                    bs, _, h, w = img2s.size()
                    # Reference, previous and current frames are stacked along dim 0.
                    inputs = torch.cat((ref_imgs, img1s, img2s), 0)
                    if damage_initial_previous_frame_mask:
                        try:
                            label1s = damage_masks(label1s)
                        except Exception as err:
                            # Best effort: keep the clean masks, but report why.
                            print('damage_error: {}'.format(err))

                    if self.use_gpu:
                        inputs = inputs.cuda()
                        ref_scribble_labels = ref_scribble_labels.cuda()
                        label1s = label1s.cuda()
                        label2s = label2s.cuda()

                    tmp_dic = self.model(inputs, ref_scribble_labels, label1s, seq_names=seq_names,
                                         gt_ids=obj_nums, k_nearest_neighbors=cfg.KNNS)

                    label_and_obj_dic = {seq_: (label2s[i], obj_nums[i]) for i, seq_ in enumerate(seq_names)}
                    label_dic = {}
                    for seq_ in list(tmp_dic):
                        # Upsample the per-sequence logits to the input resolution.
                        tmp_dic[seq_] = nn.functional.interpolate(tmp_dic[seq_], size=(h, w), mode='bilinear',
                                                                  align_corners=True)
                        label_tmp, obj_num = label_and_obj_dic[seq_]
                        if lossfunc == 'bce':
                            # One binary target channel per object id 1..obj_num. The ids are built on
                            # the label's own device; choosing CUDA whenever it was available broke
                            # use_gpu=False runs on CUDA machines.
                            obj_ids = torch.arange(1, int(obj_num) + 1, device=label_tmp.device,
                                                   dtype=torch.float32)
                            label = label_tmp.permute(1, 2, 0).float() == obj_ids
                            label_dic[seq_] = label.unsqueeze(-1).permute(3, 2, 0, 1).float()
                        else:
                            label_dic[seq_] = label_tmp.long()

                    loss = criterion(tmp_dic, label_dic, step) / bs
                    if loss.item() > 10000:
                        # Divergence guard: dump logits and labels for offline inspection, then abort
                        # with a non-zero status (the old bare exit() reported success).
                        for k, v in tmp_dic.items():
                            np.save(k + '.npy', v.detach().cpu().numpy())
                            np.save('lab' + k + '.npy', label_dic[k].detach().cpu().numpy())
                        raise RuntimeError('loss diverged at step {}: {:.4f}'.format(step, loss.item()))

                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                    running_loss.update(loss.item(), bs)
                    if step % log_interval == 0:
                        print('step:{},now_lr:{} ,loss:{:.4f}({:.4f})'.format(step, now_lr, running_loss.val,
                                                                               running_loss.avg))

                        show_ref_img = ref_imgs.cpu().numpy()[0] * sigma + mean
                        show_img1 = img1s.cpu().numpy()[0] * sigma + mean
                        show_img2 = img2s.cpu().numpy()[0] * sigma + mean

                        show_gt = label2s.cpu()[0].squeeze(0).numpy()
                        show_gtf = label2colormap(show_gt).transpose((2, 0, 1))

                        # Logits were already resized to (h, w) above; detach before converting.
                        show_preds = tmp_dic[seq_names[0]].detach().cpu().squeeze(0)
                        if lossfunc == 'bce':
                            show_preds = torch.sigmoid(show_preds) > 0.5
                            show_preds_s = torch.zeros((h, w))
                            for i in range(show_preds.size(0)):
                                show_preds_s[show_preds[i]] = i + 1
                        else:
                            show_preds_s = torch.argmax(show_preds, dim=0)
                        show_preds_s = show_preds_s.numpy()
                        show_preds_sf = label2colormap(show_preds_s).transpose((2, 0, 1))

                        pix_acc = np.sum(show_preds_s == show_gt) / (h * w)

                        tblogger.add_scalar('loss', running_loss.avg, step)
                        tblogger.add_scalar('pix_acc', pix_acc, step)
                        tblogger.add_scalar('now_lr', now_lr, step)
                        tblogger.add_image('Reference image', show_ref_img, step)
                        tblogger.add_image('Previous frame image', show_img1, step)
                        tblogger.add_image('Current frame image', show_img2, step)
                        tblogger.add_image('Groud Truth', show_gtf, step)
                        tblogger.add_image('Predict label', show_preds_sf, step)

                    if step % 5000 == 0 and step != 0:
                        self.save_network(self.model, step)

                    step += 1
                    if step >= max_itr:
                        # The while-condition is only re-checked after a full pass over both loaders;
                        # stop here so training (and the LR schedule) never runs past max_itr, and
                        # keep the final weights.
                        self.save_network(self.model, step)
                        return
예제 #2
0
    def train(self, damage_initial_previous_frame_mask=True, lossfunc='cross_entropy', model_resume=False,
              eval_total=False, init_prev=False, resume_step=75000):
        """Train the interaction segmentation head over several simulated interaction rounds.

        Only ``self.model.inter_seghead`` parameters are given to the optimizer. Reference-frame
        features come from ``self.model.extract_feature`` under ``torch.no_grad`` with the
        feature extractor and semantic embedding in eval mode.

        Every pass of the outer loop runs ``round_`` (3) rounds of ``epoch_per_round`` (30)
        epochs. Round 0 resets the dataset's reference frames (``init_ref_frame_dic``). After
        each non-final round, an ``InteractiveScribblesRobot`` derives new scribbles from a
        previous-round mask, and ``train_dataset.update_ref_frame_and_label`` stores them for
        the next round. With ``cfg.TRAIN_INTER_USE_TRUE_RESULT`` that mask is the model's own
        prediction (passed through ``damage_masks``). Otherwise it is ``label2`` damaged by
        ``mask_damager(..., 0.1)``.

        Args:
            damage_initial_previous_frame_mask: unused; kept for interface compatibility.
            lossfunc: ``'cross_entropy'`` or ``'bce'``.
            model_resume: if True, load ``save_step_<resume_step>.pth`` from
                ``self.save_res_dir`` and continue the step counter from ``resume_step``.
            eval_total: unused; kept for interface compatibility.
            init_prev: unused; kept for interface compatibility.
            resume_step: checkpoint step to resume from. Defaults to 75000, the value that
                was previously hard-coded.

        Raises:
            ValueError: if ``lossfunc`` is not supported.
        """
        # Validate first. Previously an unknown value only printed a warning and the loop later
        # crashed with a NameError on `criterion`.
        if lossfunc == 'bce':
            criterion = Added_BCEWithLogitsLoss(cfg.TRAIN_TOP_K_PERCENT_PIXELS, cfg.TRAIN_HARD_MINING_STEP)
        elif lossfunc == 'cross_entropy':
            criterion = Added_CrossEntropyLoss(cfg.TRAIN_TOP_K_PERCENT_PIXELS, cfg.TRAIN_HARD_MINING_STEP)
        else:
            raise ValueError(
                'unsupported loss function {!r}. Please choose from [cross_entropy,bce]'.format(lossfunc))

        interactor = interactive_robot.InteractiveScribblesRobot()
        self.model.train()
        running_loss = AverageMeter()
        optimizer = optim.SGD(self.model.inter_seghead.parameters(), lr=cfg.TRAIN_LR,
                              momentum=cfg.TRAIN_MOMENTUM, weight_decay=cfg.TRAIN_WEIGHT_DECAY)

        composed_transforms = transforms.Compose([tr.RandomHorizontalFlip(cfg.DATA_RANDOMFLIP),
                                                  tr.RandomScale(),
                                                  tr.RandomCrop((cfg.DATA_RANDOMCROP, cfg.DATA_RANDOMCROP), 10),
                                                  tr.Resize(cfg.DATA_RESCALE),
                                                  tr.ToTensor()])
        print('dataset processing...')
        train_dataset = DAVIS2017_Train(root=cfg.DATA_ROOT, transform=composed_transforms)
        print('dataset processing finished.')

        max_itr = cfg.TRAIN_TOTAL_STEPS
        step = 0
        round_ = 3
        epoch_per_round = 30
        if model_resume:
            saved_model_ = torch.load(os.path.join(self.save_res_dir, 'save_step_{}.pth'.format(resume_step)))
            self.model = self.load_network(self.model, saved_model_)
            step = resume_step
            print('resume from step {}'.format(step))

        # Per-channel constants used to de-normalise images for display (x * sigma + mean).
        mean = np.array([[[0.485]], [[0.456]], [[0.406]]])
        sigma = np.array([[[0.229]], [[0.224]], [[0.225]]])

        while step < cfg.TRAIN_TOTAL_STEPS:
            # Extra hard stop independent of TRAIN_TOTAL_STEPS (kept from the original experiments).
            if step > 100001:
                break

            for r in range(round_):
                if r == 0:
                    print('start new')
                    global_map_tmp_dic = {}
                    # NOTE(review): round 0 uses RandomCrop without the extra `10` argument that
                    # composed_transforms passes; the reason is not visible here.
                    train_dataset.transform = transforms.Compose([tr.RandomHorizontalFlip(cfg.DATA_RANDOMFLIP),
                                                                  tr.RandomScale(),
                                                                  tr.RandomCrop((cfg.DATA_RANDOMCROP,
                                                                                 cfg.DATA_RANDOMCROP)),
                                                                  tr.Resize(cfg.DATA_RESCALE),
                                                                  tr.ToTensor()])
                    train_dataset.init_ref_frame_dic()

                trainloader = DataLoader(train_dataset, batch_size=cfg.TRAIN_BATCH_SIZE,
                                         sampler=RandomIdentitySampler(train_dataset.sample_list),
                                         shuffle=False, num_workers=cfg.NUM_WORKER, pin_memory=True)
                print('round:{} start'.format(r))
                for epoch in range(epoch_per_round):
                    for sample in trainloader:
                        now_lr = self._adjust_lr(optimizer, step, max_itr)
                        ref_imgs = sample['ref_img']  # batch_size * 3 * h * w
                        ref_scribble_labels = sample['ref_scribble_label']  # batch_size * 1 * h * w
                        seq_names = sample['meta']['seq_name']
                        obj_nums = sample['meta']['obj_num']
                        ref_frame_nums = sample['meta']['ref_frame_num']
                        ref_frame_gts = sample['ref_frame_gt']
                        bs, _, h, w = ref_imgs.size()
                        # Rounds after the first also condition on the previous round's mask.
                        prev_round_label = sample['prev_round_label'] if r != 0 else None

                        # `inputs` used to be assigned only on the GPU path (NameError on CPU), and
                        # prev_round_label was moved to CUDA unconditionally.
                        inputs = ref_imgs
                        if self.use_gpu:
                            inputs = inputs.cuda()
                            ref_scribble_labels = ref_scribble_labels.cuda()
                            ref_frame_gts = ref_frame_gts.cuda()
                            if prev_round_label is not None:
                                prev_round_label = prev_round_label.cuda()

                        with torch.no_grad():
                            # Backbone features are not trained: only inter_seghead is optimised.
                            self.model.feature_extracter.eval()
                            self.model.semantic_embedding.eval()
                            ref_frame_embedding = self.model.extract_feature(inputs)

                        # NOTE(review): the optimizer uses `inter_seghead` while this calls
                        # `int_seghead`; confirm they refer to the same head.
                        tmp_dic = self.model.int_seghead(
                            ref_frame_embedding=ref_frame_embedding, ref_scribble_label=ref_scribble_labels,
                            prev_round_label=prev_round_label, normalize_nearest_neighbor_distances=True,
                            global_map_tmp_dic={}, seq_names=seq_names, gt_ids=obj_nums,
                            k_nearest_neighbors=cfg.KNNS, frame_num=ref_frame_nums, first_inter=(r == 0))

                        label_and_obj_dic = {seq_: (ref_frame_gts[i], obj_nums[i])
                                             for i, seq_ in enumerate(seq_names)}
                        label_dic = {}
                        for seq_ in list(tmp_dic):
                            # Upsample the per-sequence logits to the input resolution.
                            tmp_dic[seq_] = nn.functional.interpolate(tmp_dic[seq_], size=(h, w), mode='bilinear',
                                                                      align_corners=True)
                            label_tmp, obj_num = label_and_obj_dic[seq_]
                            if lossfunc == 'bce':
                                # One binary channel per id 0..obj_num (background included), built
                                # on the label's own device.
                                obj_ids = torch.arange(0, int(obj_num) + 1, device=label_tmp.device,
                                                       dtype=torch.float32)
                                label = label_tmp.permute(1, 2, 0).float() == obj_ids
                                label_dic[seq_] = label.unsqueeze(-1).permute(3, 2, 0, 1).float()
                            else:
                                label_dic[seq_] = label_tmp.long()

                        loss = criterion(tmp_dic, label_dic, step) / bs
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()

                        running_loss.update(loss.item(), bs)
                        if step % 50 == 0:
                            torch.cuda.empty_cache()
                            print('step:{},now_lr:{} ,loss:{:.4f}({:.4f})'.format(step, now_lr, running_loss.val,
                                                                                   running_loss.avg))

                            show_ref_img = ref_imgs.cpu().numpy()[0] * sigma + mean

                            show_gt = ref_frame_gts.cpu()[0].squeeze(0).numpy()
                            show_gtf = label2colormap(show_gt).transpose((2, 0, 1))
                            show_scrbble = ref_scribble_labels.cpu()[0].squeeze(0).numpy()
                            show_scrbble = label2colormap(show_scrbble).transpose((2, 0, 1))
                            if r != 0:
                                show_prev_round_label = prev_round_label.cpu()[0].squeeze(0).numpy()
                            else:
                                show_prev_round_label = np.zeros_like(show_gt)
                            show_prev_round_label = label2colormap(show_prev_round_label).transpose((2, 0, 1))

                            # Logits were already resized to (h, w) above; detach before converting.
                            show_preds = tmp_dic[seq_names[0]].detach().cpu().squeeze(0)
                            if lossfunc == 'bce':
                                # Channel 0 is background. Label each pixel with its most confident
                                # object, if that object's probability exceeds 0.5. (argmax on the old
                                # thresholded bool tensor is not supported by torch.)
                                probs = torch.sigmoid(show_preds[1:])
                                marker = torch.argmax(probs, dim=0)
                                show_preds_s = torch.zeros((h, w))
                                for i in range(probs.size(0)):
                                    show_preds_s[(marker == i) & (probs[i] > 0.5)] = i + 1
                            else:
                                show_preds_s = torch.argmax(show_preds, dim=0)
                            show_preds_s = show_preds_s.numpy()
                            show_preds_sf = label2colormap(show_preds_s).transpose((2, 0, 1))

                            pix_acc = np.sum(show_preds_s == show_gt) / (h * w)

                            if cfg.TRAIN_TBLOG:
                                tblogger.add_scalar('loss', running_loss.avg, step)
                                tblogger.add_scalar('pix_acc', pix_acc, step)
                                tblogger.add_scalar('now_lr', now_lr, step)
                                tblogger.add_image('Reference image', show_ref_img, step)
                                tblogger.add_image('Groud Truth', show_gtf, step)
                                tblogger.add_image('Predict label', show_preds_sf, step)
                                tblogger.add_image('Scribble', show_scrbble, step)
                                tblogger.add_image('prev_round_label', show_prev_round_label, step)

                        if step % 20000 == 0 and step != 0:
                            self.save_network(self.model, step)

                        step += 1

                print('trainset evaluating...')
                print('*' * 100)

                if cfg.TRAIN_INTER_USE_TRUE_RESULT:
                    if r != round_ - 1:
                        if r == 0:
                            prev_round_label_dic = {}
                        self.model.eval()
                        with torch.no_grad():
                            round_scribble = {}
                            frame_num_dic = {}
                            train_dataset.transform = transforms.Compose([tr.Resize(cfg.DATA_RESCALE), tr.ToTensor()])
                            trainloader = DataLoader(train_dataset, batch_size=1,
                                                     sampler=RandomIdentitySampler(train_dataset.sample_list),
                                                     shuffle=False, num_workers=cfg.NUM_WORKER, pin_memory=True)
                            for sample in trainloader:
                                ref_imgs = sample['ref_img']  # batch_size * 3 * h * w
                                img1s = sample['img1']
                                img2s = sample['img2']
                                ref_scribble_labels = sample['ref_scribble_label']  # batch_size * 1 * h * w
                                label1s = sample['label1']
                                label2s = sample['label2']
                                seq_names = sample['meta']['seq_name']
                                obj_nums = sample['meta']['obj_num']
                                frame_nums = sample['meta']['frame_num']
                                bs, _, h, w = img2s.size()
                                inputs = torch.cat((ref_imgs, img1s, img2s), 0)
                                if r == 0:
                                    ref_scribble_labels = self.rough_ROI(ref_scribble_labels)
                                print(seq_names[0])

                                # Each previous-frame mask goes through mask_damager with ratio 0.0.
                                damaged = []
                                for i in range(bs):
                                    l = mask_damager(label1s[i].unsqueeze(0), 0.0)
                                    damaged.append(torch.from_numpy(l).unsqueeze(0).unsqueeze(0).float())
                                label1s = torch.cat(damaged, 0)

                                if self.use_gpu:
                                    inputs = inputs.cuda()
                                    ref_scribble_labels = ref_scribble_labels.cuda()
                                    label1s = label1s.cuda()

                                tmp_dic, global_map_tmp_dic = self.model(inputs, ref_scribble_labels, label1s,
                                                                         seq_names=seq_names, gt_ids=obj_nums,
                                                                         k_nearest_neighbors=cfg.KNNS,
                                                                         global_map_tmp_dic=global_map_tmp_dic,
                                                                         frame_num=frame_nums)
                                pred_label = tmp_dic[seq_names[0]].detach().cpu()
                                pred_label = nn.functional.interpolate(pred_label, size=(h, w), mode='bilinear',
                                                                       align_corners=True)
                                pred_label = torch.argmax(pred_label, dim=1).unsqueeze(0)
                                try:
                                    pred_label = damage_masks(pred_label)
                                except Exception as err:
                                    # Best effort: keep the undamaged prediction, but report why.
                                    print('damage_error: {}'.format(err))
                                pred_label = pred_label.squeeze(0)
                                round_scribble[seq_names[0]] = interactor.interact(
                                    seq_names[0], pred_label.numpy(), label2s.float().squeeze(0).numpy(), obj_nums)
                                frame_num_dic[seq_names[0]] = frame_nums[0]

                                # Resize the prediction to the resolution of the sequence's first
                                # JPEG frame. Only the header is needed, and the file is closed.
                                pred_label = pred_label.unsqueeze(0)
                                with Image.open(os.path.join(cfg.DATA_ROOT, 'JPEGImages/480p/', seq_names[0],
                                                             '00000.jpg')) as img_ww:
                                    or_w, or_h = img_ww.size
                                pred_label = torch.nn.functional.interpolate(pred_label.float(), (or_h, or_w),
                                                                             mode='nearest')
                                prev_round_label_dic[seq_names[0]] = pred_label.squeeze(0)
                        train_dataset.update_ref_frame_and_label(round_scribble, frame_num_dic, prev_round_label_dic)
                    print('trainset evaluating finished!')
                    print('*' * 100)
                    self.model.train()
                    print('updating ref frame and label')
                    train_dataset.transform = composed_transforms
                    print('updating ref frame and label finished!')

                elif r != round_ - 1:
                    # Simulated interaction from a damaged ground-truth mask instead of the model output.
                    round_scribble = {}
                    if r == 0:
                        prev_round_label_dic = {}
                    frame_num_dic = {}
                    train_dataset.transform = tr.ToTensor()
                    trainloader = DataLoader(train_dataset, batch_size=1,
                                             sampler=RandomIdentitySampler(train_dataset.sample_list),
                                             shuffle=False, num_workers=1, pin_memory=True)

                    self.model.eval()
                    with torch.no_grad():
                        for sample in trainloader:
                            label2s = sample['label2']
                            seq_names = sample['meta']['seq_name']
                            obj_nums = sample['meta']['obj_num']
                            frame_nums = sample['meta']['frame_num']
                            print(seq_names[0])
                            label2s_ = mask_damager(label2s, 0.1)
                            round_scribble[seq_names[0]] = interactor.interact(
                                seq_names[0], np.expand_dims(label2s_, axis=0), label2s.float().squeeze(0).numpy(),
                                obj_nums)
                            frame_num_dic[seq_names[0]] = frame_nums[0]
                            prev_round_label_dic[seq_names[0]] = torch.from_numpy(label2s_)
                    print('trainset evaluating finished!')
                    print('*' * 100)
                    print('updating ref frame and label')

                    train_dataset.update_ref_frame_and_label(round_scribble, frame_num_dic, prev_round_label_dic)
                    self.model.train()
                    train_dataset.transform = composed_transforms
                    print('updating ref frame and label finished!')