def test_VOS(self, use_gpu=True):
    # Collect the DAVIS 2017 validation sequence names.
    seqs = []
    with open(os.path.join(cfg.DATA_ROOT, 'ImageSets', '2017', 'val' + '.txt')) as f:
        seqs_tmp = f.readlines()
    seqs_tmp = list(map(lambda elem: elem.strip(), seqs_tmp))
    seqs.extend(seqs_tmp)

    print('model loading...')
    saved_model_dict = os.path.join(self.save_res_dir, 'save_step_60000.pth')
    pretrained_dict = torch.load(saved_model_dict)
    self.model = self.load_network(self.model, pretrained_dict)
    print('model load finished')

    self.model.eval()
    with torch.no_grad():
        for seq_name in seqs:
            print('processing seq:{}'.format(seq_name))
            test_dataset = DAVIS2017_VOS_Test(root=cfg.DATA_ROOT, transform=tr.ToTensor(),
                                              result_root=cfg.RESULT_ROOT, seq_name=seq_name)
            test_dataloader = DataLoader(test_dataset, batch_size=1,
                                         shuffle=False, num_workers=0, pin_memory=True)
            if not os.path.exists(os.path.join(cfg.RESULT_ROOT, seq_name)):
                os.makedirs(os.path.join(cfg.RESULT_ROOT, seq_name))
            time_start = time.time()
            for ii, sample in enumerate(test_dataloader):
                ref_img = sample['ref_img']
                prev_img = sample['prev_img']
                current_img = sample['current_img']
                ref_label = sample['ref_label']
                prev_label = sample['prev_label']
                obj_num = sample['meta']['obj_num']
                seqnames = sample['meta']['seq_name']
                imgname = sample['meta']['current_name']
                bs, _, h, w = current_img.size()

                # Stack reference, previous and current frames along the batch
                # dimension so the backbone runs once for all three.
                inputs = torch.cat((ref_img, prev_img, current_img), 0)
                if use_gpu:
                    inputs = inputs.cuda()
                    ref_label = ref_label.cuda()
                    prev_label = prev_label.cuda()

                # Equivalent single call:
                # tmp_dic = self.model(inputs, ref_label, prev_label, seq_names=seqnames,
                #                      gt_ids=obj_num, k_nearest_neighbors=cfg.KNNS)
                t1 = time.time()
                tmp = self.model.extract_feature(inputs)
                ref_frame_embedding, previous_frame_embedding, current_frame_embedding = torch.split(
                    tmp, split_size_or_sections=int(tmp.size(0) / 3), dim=0)
                t2 = time.time()
                print('feature_extracter time:{}'.format(t2 - t1))
                tmp_dic = self.model.before_seghead_process(ref_frame_embedding, previous_frame_embedding,
                                                            current_frame_embedding, ref_label, prev_label,
                                                            True, seqnames, obj_num, cfg.KNNS,
                                                            self.model.dynamic_seghead)
                t3 = time.time()
                print('seghead time:{}'.format(t3 - t2))

                # Upsample the logits to frame resolution, take the per-pixel
                # argmax and save the result as a palettized PNG.
                pred_label = tmp_dic[seq_name].cpu()
                pred_label = nn.functional.interpolate(pred_label, size=(h, w),
                                                       mode='bilinear', align_corners=True)
                pred_label = torch.argmax(pred_label, dim=1)
                pred_label = pred_label.squeeze(0).numpy()
                im = Image.fromarray(pred_label.astype('uint8')).convert('P')
                im.putpalette(_palette)
                im.save(os.path.join(cfg.RESULT_ROOT, seq_name, imgname[0].split('.')[0] + '.png'))
                one_frametime = time.time()
                print('seq name:{} frame:{} time:{}'.format(seq_name, imgname[0],
                                                            one_frametime - time_start))
                time_start = time.time()
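# `AverageMeter` (used by the training loops below) comes from the repo's
# utilities. A minimal sketch consistent with how it is used here --
# update(val, n), .val and .avg -- follows for reference; the name
# `AverageMeterSketch` and any internals beyond that interface are assumptions.
class AverageMeterSketch(object):
    """Tracks the most recent value and the running average of a scalar."""

    def __init__(self):
        self.val = 0.0   # most recent value
        self.sum = 0.0   # weighted sum of all observed values
        self.count = 0   # total weight
        self.avg = 0.0   # running average

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count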
def train_VOS(self, damage_initial_previous_frame_mask=False, lossfunc='cross_entropy',
              model_resume=False):
    self.model.train()
    running_loss = AverageMeter()
    optimizer = optim.SGD(self.model.parameters(), lr=cfg.TRAIN_LR, momentum=cfg.TRAIN_MOMENTUM)

    composed_transforms = transforms.Compose([tr.RandomHorizontalFlip(cfg.DATA_RANDOMFLIP),
                                              tr.RandomScale(),
                                              tr.RandomCrop((cfg.DATA_RANDOMCROP, cfg.DATA_RANDOMCROP)),
                                              tr.Resize(cfg.DATA_RESCALE),
                                              tr.ToTensor()])
    print('dataset processing...')
    train_dataset = DAVIS2017_VOS_Train(root=cfg.DATA_ROOT, transform=composed_transforms)
    trainloader = DataLoader(train_dataset, batch_size=cfg.TRAIN_BATCH_SIZE,
                             shuffle=True, num_workers=cfg.NUM_WORKER, pin_memory=True)
    print('dataset processing finished.')

    if lossfunc == 'bce':
        criterion = Added_BCEWithLogitsLoss(cfg.TRAIN_TOP_K_PERCENT_PIXELS, cfg.TRAIN_HARD_MINING_STEP)
    elif lossfunc == 'cross_entropy':
        criterion = Added_CrossEntropyLoss(cfg.TRAIN_TOP_K_PERCENT_PIXELS, cfg.TRAIN_HARD_MINING_STEP)
    else:
        raise ValueError('unsupported loss function. Please choose from [cross_entropy, bce]')

    max_itr = cfg.TRAIN_TOTAL_STEPS
    step = 0
    if model_resume:
        saved_model_ = os.path.join(self.save_res_dir, 'save_step_30000.pth')
        saved_model_ = torch.load(saved_model_)
        self.model = self.load_network(self.model, saved_model_)
        step = 30000
        print('resume from step {}'.format(step))

    while step < cfg.TRAIN_TOTAL_STEPS:
        # sample['meta'] = {'seq_name': seq_name, 'frame_num': frame_num, 'obj_num': obj_num}
        for ii, sample in enumerate(trainloader):
            now_lr = self._adjust_lr(optimizer, step, max_itr)
            ref_imgs = sample['ref_img']                        # batch_size * 3 * h * w
            img1s = sample['img1']
            img2s = sample['img2']
            ref_scribble_labels = sample['ref_scribble_label']  # batch_size * 1 * h * w
            label1s = sample['label1']
            label2s = sample['label2']
            seq_names = sample['meta']['seq_name']
            obj_nums = sample['meta']['obj_num']
            bs, _, h, w = img2s.size()
            inputs = torch.cat((ref_imgs, img1s, img2s), 0)

            # Optionally corrupt the previous-frame masks so the model does not
            # overfit to perfect guidance.
            if damage_initial_previous_frame_mask:
                try:
                    label1s = damage_masks(label1s)
                except Exception:
                    print('damage_error')

            if self.use_gpu:
                inputs = inputs.cuda()
                ref_scribble_labels = ref_scribble_labels.cuda()
                label1s = label1s.cuda()
                label2s = label2s.cuda()

            tmp_dic = self.model(inputs, ref_scribble_labels, label1s, seq_names=seq_names,
                                 gt_ids=obj_nums, k_nearest_neighbors=cfg.KNNS)

            # Build the per-sequence ground-truth dictionary that mirrors the
            # prediction dictionary returned by the model.
            label_and_obj_dic = {}
            label_dic = {}
            for i, seq_ in enumerate(seq_names):
                label_and_obj_dic[seq_] = (label2s[i], obj_nums[i])
            for seq_ in tmp_dic.keys():
                tmp_pred_logits = tmp_dic[seq_]
                tmp_pred_logits = nn.functional.interpolate(tmp_pred_logits, size=(h, w),
                                                            mode='bilinear', align_corners=True)
                tmp_dic[seq_] = tmp_pred_logits
                label_tmp, obj_num = label_and_obj_dic[seq_]
                obj_ids = torch.from_numpy(np.arange(1, obj_num + 1)).int()
                if torch.cuda.is_available():
                    obj_ids = obj_ids.cuda()
                if lossfunc == 'bce':
                    # One binary channel per object id.
                    label_tmp = label_tmp.permute(1, 2, 0)
                    label = (label_tmp.float() == obj_ids.float())
                    label = label.unsqueeze(-1).permute(3, 2, 0, 1)
                    label_dic[seq_] = label.float()
                elif lossfunc == 'cross_entropy':
                    label_dic[seq_] = label_tmp.long()

            loss = criterion(tmp_dic, label_dic, step)
            loss = loss / bs
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss.update(loss.item(), bs)

            if step % 50 == 0:
                print(torch.cuda.max_memory_cached())
                torch.cuda.empty_cache()
                torch.cuda.reset_max_memory_allocated()
                print('step:{}, now_lr:{}, loss:{:.4f}({:.4f})'.format(step, now_lr,
                                                                       running_loss.val, running_loss.avg))
                # Un-normalize the images for TensorBoard visualization.
                show_ref_img = ref_imgs.cpu().numpy()[0]
                show_img1 = img1s.cpu().numpy()[0]
                show_img2 = img2s.cpu().numpy()[0]
                mean = np.array([[[0.485]], [[0.456]], [[0.406]]])
                sigma = np.array([[[0.229]], [[0.224]], [[0.225]]])
                show_ref_img = show_ref_img * sigma + mean
                show_img1 = show_img1 * sigma + mean
                show_img2 = show_img2 * sigma + mean

                show_gt = label2s.cpu()[0].squeeze(0).numpy()
                show_gtf = label2colormap(show_gt).transpose((2, 0, 1))

                show_preds = tmp_dic[seq_names[0]].cpu()
                show_preds = nn.functional.interpolate(show_preds, size=(h, w),
                                                       mode='bilinear', align_corners=True)
                show_preds = show_preds.squeeze(0)
                if lossfunc == 'bce':
                    show_preds = (torch.sigmoid(show_preds) > 0.5)
                    show_preds_s = torch.zeros((h, w))
                    for i in range(show_preds.size(0)):
                        show_preds_s[show_preds[i]] = i + 1
                elif lossfunc == 'cross_entropy':
                    show_preds_s = torch.argmax(show_preds, dim=0)
                show_preds_s = show_preds_s.numpy()
                show_preds_sf = label2colormap(show_preds_s).transpose((2, 0, 1))
                pix_acc = np.sum(show_preds_s == show_gt) / (h * w)

                tblogger.add_scalar('loss', running_loss.avg, step)
                tblogger.add_scalar('pix_acc', pix_acc, step)
                tblogger.add_scalar('now_lr', now_lr, step)
                tblogger.add_image('Reference image', show_ref_img, step)
                tblogger.add_image('Previous frame image', show_img1, step)
                tblogger.add_image('Current frame image', show_img2, step)
                tblogger.add_image('Ground Truth', show_gtf, step)
                tblogger.add_image('Predicted label', show_preds_sf, step)

            if step % 5000 == 0 and step != 0:
                self.save_network(self.model, step)
            step += 1
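# `self._adjust_lr`, called once per step above, is defined elsewhere in this
# class. Segmentation repos in this family typically use the DeepLab-style
# "poly" decay; the sketch below illustrates that convention only. The helper
# name `_poly_lr_sketch` and the power 0.9 are assumptions, not necessarily
# this repo's exact schedule.
def _poly_lr_sketch(optimizer, base_lr, itr, max_itr, power=0.9):
    now_lr = base_lr * (1 - float(itr) / max_itr) ** power
    for param_group in optimizer.param_groups:
        param_group['lr'] = now_lr
    return now_lr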
def train(self, damage_initial_previous_frame_mask=True, lossfunc='cross_entropy',
          model_resume=False, eval_total=False, init_prev=False):
    interactor = interactive_robot.InteractiveScribblesRobot()
    self.model.train()
    running_loss = AverageMeter()
    # Only the interaction segmentation head is optimized in this stage.
    optimizer = optim.SGD(self.model.inter_seghead.parameters(), lr=cfg.TRAIN_LR,
                          momentum=cfg.TRAIN_MOMENTUM, weight_decay=cfg.TRAIN_WEIGHT_DECAY)

    composed_transforms = transforms.Compose([tr.RandomHorizontalFlip(cfg.DATA_RANDOMFLIP),
                                              tr.RandomScale(),
                                              tr.RandomCrop((cfg.DATA_RANDOMCROP, cfg.DATA_RANDOMCROP), 10),
                                              tr.Resize(cfg.DATA_RESCALE),
                                              tr.ToTensor()])
    print('dataset processing...')
    train_dataset = DAVIS2017_Train(root=cfg.DATA_ROOT, transform=composed_transforms)
    train_list = train_dataset.seqs
    print('dataset processing finished.')

    if lossfunc == 'bce':
        criterion = Added_BCEWithLogitsLoss(cfg.TRAIN_TOP_K_PERCENT_PIXELS, cfg.TRAIN_HARD_MINING_STEP)
    elif lossfunc == 'cross_entropy':
        criterion = Added_CrossEntropyLoss(cfg.TRAIN_TOP_K_PERCENT_PIXELS, cfg.TRAIN_HARD_MINING_STEP)
    else:
        raise ValueError('unsupported loss function. Please choose from [cross_entropy, bce]')

    max_itr = cfg.TRAIN_TOTAL_STEPS
    step = 0
    round_ = 3
    epoch_per_round = 30
    if model_resume:
        saved_model_ = os.path.join(self.save_res_dir, 'save_step_75000.pth')
        saved_model_ = torch.load(saved_model_)
        self.model = self.load_network(self.model, saved_model_)
        step = 75000
        print('resume from step {}'.format(step))

    while step < cfg.TRAIN_TOTAL_STEPS:
        if step > 100001:
            break
        # Each round simulates one interaction: round 0 starts from fresh
        # scribbles, later rounds refine the previous round's predictions.
        for r in range(round_):
            if r == 0:
                print('start new')
                global_map_tmp_dic = {}
                train_dataset.transform = transforms.Compose([tr.RandomHorizontalFlip(cfg.DATA_RANDOMFLIP),
                                                              tr.RandomScale(),
                                                              tr.RandomCrop((cfg.DATA_RANDOMCROP, cfg.DATA_RANDOMCROP)),
                                                              tr.Resize(cfg.DATA_RESCALE),
                                                              tr.ToTensor()])
                train_dataset.init_ref_frame_dic()
                trainloader = DataLoader(train_dataset, batch_size=cfg.TRAIN_BATCH_SIZE,
                                         sampler=RandomIdentitySampler(train_dataset.sample_list),
                                         shuffle=False, num_workers=cfg.NUM_WORKER, pin_memory=True)
            print('round:{} start'.format(r))
            for epoch in range(epoch_per_round):
                for ii, sample in enumerate(trainloader):
                    now_lr = self._adjust_lr(optimizer, step, max_itr)
                    ref_imgs = sample['ref_img']                        # batch_size * 3 * h * w
                    ref_scribble_labels = sample['ref_scribble_label']  # batch_size * 1 * h * w
                    seq_names = sample['meta']['seq_name']
                    obj_nums = sample['meta']['obj_num']
                    ref_frame_nums = sample['meta']['ref_frame_num']
                    ref_frame_gts = sample['ref_frame_gt']
                    bs, _, h, w = ref_imgs.size()

                    if self.use_gpu:
                        inputs = ref_imgs.cuda()
                        ref_scribble_labels = ref_scribble_labels.cuda()
                        ref_frame_gts = ref_frame_gts.cuda()

                    # The backbone is frozen in this stage; features are
                    # extracted without gradients.
                    with torch.no_grad():
                        self.model.feature_extracter.eval()
                        self.model.semantic_embedding.eval()
                        ref_frame_embedding = self.model.extract_feature(inputs)

                    if r == 0:
                        first_inter = True
                        tmp_dic = self.model.int_seghead(ref_frame_embedding=ref_frame_embedding,
                                                         ref_scribble_label=ref_scribble_labels,
                                                         prev_round_label=None,
                                                         normalize_nearest_neighbor_distances=True,
                                                         global_map_tmp_dic={},
                                                         seq_names=seq_names, gt_ids=obj_nums,
                                                         k_nearest_neighbors=cfg.KNNS,
                                                         frame_num=ref_frame_nums,
                                                         first_inter=first_inter)
                    else:
                        first_inter = False
                        prev_round_label = sample['prev_round_label']
                        prev_round_label = prev_round_label.cuda()
                        tmp_dic = self.model.int_seghead(ref_frame_embedding=ref_frame_embedding,
                                                         ref_scribble_label=ref_scribble_labels,
                                                         prev_round_label=prev_round_label,
                                                         normalize_nearest_neighbor_distances=True,
                                                         global_map_tmp_dic={},
                                                         seq_names=seq_names, gt_ids=obj_nums,
                                                         k_nearest_neighbors=cfg.KNNS,
                                                         frame_num=ref_frame_nums,
                                                         first_inter=first_inter)

                    # Build the per-sequence ground-truth dictionary that
                    # mirrors the prediction dictionary.
                    label_and_obj_dic = {}
                    label_dic = {}
                    for i, seq_ in enumerate(seq_names):
                        label_and_obj_dic[seq_] = (ref_frame_gts[i], obj_nums[i])
                    for seq_ in tmp_dic.keys():
                        tmp_pred_logits = tmp_dic[seq_]
                        tmp_pred_logits = nn.functional.interpolate(tmp_pred_logits, size=(h, w),
                                                                    mode='bilinear', align_corners=True)
                        tmp_dic[seq_] = tmp_pred_logits
                        label_tmp, obj_num = label_and_obj_dic[seq_]
                        obj_ids = torch.from_numpy(np.arange(0, obj_num + 1)).int()
                        if torch.cuda.is_available():
                            obj_ids = obj_ids.cuda()
                        if lossfunc == 'bce':
                            label_tmp = label_tmp.permute(1, 2, 0)
                            label = (label_tmp.float() == obj_ids.float())
                            label = label.unsqueeze(-1).permute(3, 2, 0, 1)
                            label_dic[seq_] = label.float()
                        elif lossfunc == 'cross_entropy':
                            label_dic[seq_] = label_tmp.long()

                    loss = criterion(tmp_dic, label_dic, step)
                    loss = loss / bs
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    running_loss.update(loss.item(), bs)

                    if step % 50 == 0:
                        torch.cuda.empty_cache()
                        print('step:{}, now_lr:{}, loss:{:.4f}({:.4f})'.format(step, now_lr,
                                                                               running_loss.val, running_loss.avg))
                        # Un-normalize the image for TensorBoard visualization.
                        show_ref_img = ref_imgs.cpu().numpy()[0]
                        mean = np.array([[[0.485]], [[0.456]], [[0.406]]])
                        sigma = np.array([[[0.229]], [[0.224]], [[0.225]]])
                        show_ref_img = show_ref_img * sigma + mean

                        show_gt = ref_frame_gts.cpu()[0].squeeze(0).numpy()
                        show_gtf = label2colormap(show_gt).transpose((2, 0, 1))
                        show_scrbble = ref_scribble_labels.cpu()[0].squeeze(0).numpy()
                        show_scrbble = label2colormap(show_scrbble).transpose((2, 0, 1))
                        if r != 0:
                            show_prev_round_label = prev_round_label.cpu()[0].squeeze(0).numpy()
                        else:
                            show_prev_round_label = np.zeros_like(show_gt)
                        show_prev_round_label = label2colormap(show_prev_round_label).transpose((2, 0, 1))

                        show_preds = tmp_dic[seq_names[0]].cpu()
                        show_preds = nn.functional.interpolate(show_preds, size=(h, w),
                                                               mode='bilinear', align_corners=True)
                        show_preds = show_preds.squeeze(0)
                        if lossfunc == 'bce':
                            show_preds = show_preds[1:]
                            show_preds = (torch.sigmoid(show_preds) > 0.5)
                            marker = torch.argmax(show_preds, dim=0)
                            show_preds_s = torch.zeros((h, w))
                            for i in range(show_preds.size(0)):
                                tmp_mask = (marker == i) & (show_preds[i] > 0.5)
                                show_preds_s[tmp_mask] = i + 1
                        elif lossfunc == 'cross_entropy':
                            show_preds_s = torch.argmax(show_preds, dim=0)
                        show_preds_s = show_preds_s.numpy()
                        show_preds_sf = label2colormap(show_preds_s).transpose((2, 0, 1))
                        pix_acc = np.sum(show_preds_s == show_gt) / (h * w)

                        if cfg.TRAIN_TBLOG:
                            tblogger.add_scalar('loss', running_loss.avg, step)
                            tblogger.add_scalar('pix_acc', pix_acc, step)
                            tblogger.add_scalar('now_lr', now_lr, step)
                            tblogger.add_image('Reference image', show_ref_img, step)
                            tblogger.add_image('Ground Truth', show_gtf, step)
                            tblogger.add_image('Predicted label', show_preds_sf, step)
                            tblogger.add_image('Scribble', show_scrbble, step)
                            tblogger.add_image('prev_round_label', show_prev_round_label, step)

                    if step % 20000 == 0 and step != 0:
                        self.save_network(self.model, step)
                    step += 1

            print('trainset evaluating...')
            print('*' * 100)
            if cfg.TRAIN_INTER_USE_TRUE_RESULT:
                # Generate the next round's scribbles from the model's actual
                # predictions on the training set.
                if r != round_ - 1:
                    if r == 0:
                        prev_round_label_dic = {}
                    self.model.eval()
                    with torch.no_grad():
                        round_scribble = {}
                        frame_num_dic = {}
                        train_dataset.transform = transforms.Compose([tr.Resize(cfg.DATA_RESCALE), tr.ToTensor()])
                        trainloader = DataLoader(train_dataset, batch_size=1,
                                                 sampler=RandomIdentitySampler(train_dataset.sample_list),
                                                 shuffle=False, num_workers=cfg.NUM_WORKER, pin_memory=True)
                        for ii, sample in enumerate(trainloader):
                            ref_imgs = sample['ref_img']                        # batch_size * 3 * h * w
                            img1s = sample['img1']
                            img2s = sample['img2']
                            ref_scribble_labels = sample['ref_scribble_label']  # batch_size * 1 * h * w
                            label1s = sample['label1']
                            label2s = sample['label2']
                            seq_names = sample['meta']['seq_name']
                            obj_nums = sample['meta']['obj_num']
                            frame_nums = sample['meta']['frame_num']
                            bs, _, h, w = img2s.size()
                            inputs = torch.cat((ref_imgs, img1s, img2s), 0)
                            if r == 0:
                                ref_scribble_labels = self.rough_ROI(ref_scribble_labels)
                            print(seq_names[0])

                            # Lightly perturb the previous-frame masks.
                            label1s_tocat = torch.Tensor()
                            for i in range(bs):
                                l = label1s[i]
                                l = l.unsqueeze(0)
                                l = mask_damager(l, 0.0)
                                l = torch.from_numpy(l)
                                l = l.unsqueeze(0).unsqueeze(0)
                                label1s_tocat = torch.cat((label1s_tocat, l.float()), 0)
                            label1s = label1s_tocat
                            if self.use_gpu:
                                inputs = inputs.cuda()
                                ref_scribble_labels = ref_scribble_labels.cuda()
                                label1s = label1s.cuda()

                            tmp_dic, global_map_tmp_dic = self.model(inputs, ref_scribble_labels, label1s,
                                                                     seq_names=seq_names, gt_ids=obj_nums,
                                                                     k_nearest_neighbors=cfg.KNNS,
                                                                     global_map_tmp_dic=global_map_tmp_dic,
                                                                     frame_num=frame_nums)
                            pred_label = tmp_dic[seq_names[0]].detach().cpu()
                            pred_label = nn.functional.interpolate(pred_label, size=(h, w),
                                                                   mode='bilinear', align_corners=True)
                            pred_label = torch.argmax(pred_label, dim=1)
                            pred_label = pred_label.unsqueeze(0)
                            try:
                                pred_label = damage_masks(pred_label)
                            except Exception:
                                pass
                            pred_label = pred_label.squeeze(0)
                            # The scribble robot draws the next round's scribbles
                            # on the error regions of the prediction.
                            round_scribble[seq_names[0]] = interactor.interact(seq_names[0], pred_label.numpy(),
                                                                               label2s.float().squeeze(0).numpy(),
                                                                               obj_nums)
                            frame_num_dic[seq_names[0]] = frame_nums[0]
                            pred_label = pred_label.unsqueeze(0)
                            # Rescale the prediction back to the original frame size.
                            img_ww = Image.open(os.path.join(cfg.DATA_ROOT, 'JPEGImages/480p/',
                                                             seq_names[0], '00000.jpg'))
                            img_ww = np.array(img_ww)
                            or_h, or_w = img_ww.shape[:2]
                            pred_label = torch.nn.functional.interpolate(pred_label.float(), (or_h, or_w),
                                                                         mode='nearest')
                            prev_round_label_dic[seq_names[0]] = pred_label.squeeze(0)
                    train_dataset.update_ref_frame_and_label(round_scribble, frame_num_dic, prev_round_label_dic)
                    print('trainset evaluating finished!')
                    print('*' * 100)
                    self.model.train()
                    print('updating ref frame and label')
                    train_dataset.transform = composed_transforms
                    print('updating ref frame and label finished!')
            else:
                # Generate the next round's scribbles from damaged ground-truth
                # masks instead of real predictions.
                if r != round_ - 1:
                    round_scribble = {}
                    if r == 0:
                        prev_round_label_dic = {}
                    frame_num_dic = {}
                    train_dataset.transform = tr.ToTensor()
                    trainloader = DataLoader(train_dataset, batch_size=1,
                                             sampler=RandomIdentitySampler(train_dataset.sample_list),
                                             shuffle=False, num_workers=1, pin_memory=True)
                    self.model.eval()
                    with torch.no_grad():
                        for ii, sample in enumerate(trainloader):
                            ref_imgs = sample['ref_img']                        # batch_size * 3 * h * w
                            img1s = sample['img1']
                            img2s = sample['img2']
                            ref_scribble_labels = sample['ref_scribble_label']  # batch_size * 1 * h * w
                            label1s = sample['label1']
                            label2s = sample['label2']
                            seq_names = sample['meta']['seq_name']
                            obj_nums = sample['meta']['obj_num']
                            frame_nums = sample['meta']['frame_num']
                            bs, _, h, w = img2s.size()
                            print(seq_names[0])
                            label2s_ = mask_damager(label2s, 0.1)
                            round_scribble[seq_names[0]] = interactor.interact(seq_names[0],
                                                                               np.expand_dims(label2s_, axis=0),
                                                                               label2s.float().squeeze(0).numpy(),
                                                                               obj_nums)
                            label2s__ = torch.from_numpy(label2s_)
                            frame_num_dic[seq_names[0]] = frame_nums[0]
                            prev_round_label_dic[seq_names[0]] = label2s__
                    print('trainset evaluating finished!')
                    print('*' * 100)
                    print('updating ref frame and label')
                    train_dataset.update_ref_frame_and_label(round_scribble, frame_num_dic, prev_round_label_dic)
                    self.model.train()
                    train_dataset.transform = composed_transforms
                    print('updating ref frame and label finished!')
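# `self.save_network` / `self.load_network` are defined elsewhere in this
# class. The sketches below are consistent with how they are used in this
# file: checkpoints are written as 'save_step_<step>.pth' (matching the names
# loaded above), and loading keeps only keys present in the current model.
# The helper names and the 'module.' stripping for nn.DataParallel
# checkpoints are assumptions.
def _save_network_sketch(model, step, save_dir):
    torch.save(model.state_dict(), os.path.join(save_dir, 'save_step_{}.pth'.format(step)))

def _load_network_sketch(model, pretrained_dict):
    model_dict = model.state_dict()
    cleaned = {}
    for k, v in pretrained_dict.items():
        if k.startswith('module.'):   # checkpoint saved from nn.DataParallel (assumption)
            k = k[len('module.'):]
        if k in model_dict:           # drop keys the current model does not have
            cleaned[k] = v
    model_dict.update(cleaned)
    model.load_state_dict(model_dict)
    return model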
def main():
    total_frame_num_dic = {}
    # Collect validation sequence names and per-sequence frame counts.
    seqs = []
    with open(os.path.join(cfg.DATA_ROOT, 'ImageSets', '2017', 'val' + '.txt')) as f:
        seqs_tmp = f.readlines()
    seqs_tmp = list(map(lambda elem: elem.strip(), seqs_tmp))
    seqs.extend(seqs_tmp)
    for seq_name in seqs:
        images = np.sort(os.listdir(os.path.join(cfg.DATA_ROOT, 'JPEGImages/480p/', seq_name.strip())))
        total_frame_num_dic[seq_name] = len(images)
    _seq_list_file = os.path.join(cfg.DATA_ROOT, 'ImageSets', '2017', 'v_a_l' + '_instances.txt')
    seq_dict = json.load(open(_seq_list_file, 'r'))

    is_save_image = True
    report_save_dir = cfg.RESULT_ROOT
    save_res_dir = './saved_model_inter_net_new_re'

    # Configuration used in the DAVIS interactive challenge.
    max_nb_interactions = 8           # Maximum number of interactions
    max_time_per_interaction = 10000  # Maximum time per interaction per object
    # Total time available to interact with a sequence and an initial set of scribbles.
    max_time = max_nb_interactions * max_time_per_interaction  # Maximum time per object

    # Interactive parameters.
    subset = 'val'
    host = 'localhost'  # 'localhost' for subsets train and val.

    feature_extracter = DeepLab(backbone='resnet', freeze_bn=False)
    model = IntVOS(cfg, feature_extracter)
    model = model.cuda()
    print('model loading...')
    saved_model_dict = os.path.join(save_res_dir, 'save_step_80000.pth')
    pretrained_dict = torch.load(saved_model_dict)
    load_network(model, pretrained_dict)
    print('model loading finished!')
    model.eval()

    seen_seq = {}
    with torch.no_grad():
        with DavisInteractiveSession(host=host,
                                     davis_root=cfg.DATA_ROOT,
                                     subset=subset,
                                     report_save_dir=report_save_dir,
                                     max_nb_interactions=max_nb_interactions,
                                     max_time=max_time,
                                     metric_to_optimize='J') as sess:
            while sess.next():
                t_total = timeit.default_timer()
                # Get the current iteration scribbles.
                sequence, scribbles, first_scribble = sess.get_scribbles(only_last=True)
                print(sequence)
                start_annotated_frame = annotated_frames(scribbles)[0]
                pred_masks = []
                pred_masks_reverse = []
                if first_scribble:
                    anno_frame_list = []
                    n_interaction = 1
                    eval_global_map_tmp_dic = {}
                    local_map_dics = ({}, {})
                    total_frame_num = total_frame_num_dic[sequence]
                    obj_nums = seq_dict[sequence][-1]
                    eval_data_manager = DAVIS2017_Test_Manager(split='val', root=cfg.DATA_ROOT,
                                                               transform=tr.ToTensor(), seq_name=sequence)
                else:
                    n_interaction += 1

                # ---- Process the annotated (reference) frame ----
                anno_frame_list.append(start_annotated_frame)
                print(start_annotated_frame)
                scr_f = str(start_annotated_frame).zfill(5)
                ref_img = os.path.join(cfg.DATA_ROOT, 'JPEGImages/480p', sequence, scr_f + '.jpg')
                ref_img = cv2.imread(ref_img)
                h_, w_ = ref_img.shape[:2]
                ref_img = np.array(ref_img, dtype=np.float32)
                scribble_masks = scribbles2mask(scribbles, (h_, w_))
                scribble_label = scribble_masks[start_annotated_frame]
                sample = {'ref_img': ref_img, 'scribble_label': scribble_label}
                sample = tr.ToTensor()(sample)
                ref_img = sample['ref_img'].unsqueeze(0).cuda()
                scribble_label = sample['scribble_label'].unsqueeze(0).cuda()

                # Save the scribble mask for inspection.
                ref_scribble_to_show = scribble_label.cpu().squeeze().numpy()
                im_ = Image.fromarray(ref_scribble_to_show.astype('uint8')).convert('P')
                im_.putpalette(_palette)
                ref_img_name = scr_f
                if not os.path.exists(os.path.join(cfg.RESULT_ROOT, sequence, 'interactive' + str(n_interaction))):
                    os.makedirs(os.path.join(cfg.RESULT_ROOT, sequence, 'interactive' + str(n_interaction)))
                im_.save(os.path.join(cfg.RESULT_ROOT, sequence, 'interactive' + str(n_interaction),
                                      'inter_' + ref_img_name + '.png'))

                # Frame embeddings are computed and cached on the first
                # interaction, then reused in later rounds.
                if first_scribble:
                    ref_frame_embedding = model.extract_feature(ref_img)
                    _, channel, emb_h, emb_w = ref_frame_embedding.size()
                    embedding_memory = torch.zeros((total_frame_num, channel, emb_h, emb_w)).cuda()
                    embedding_memory[start_annotated_frame] = ref_frame_embedding[0]
                else:
                    ref_frame_embedding = embedding_memory[start_annotated_frame]
                    ref_frame_embedding = ref_frame_embedding.unsqueeze(0)

                if first_scribble:
                    prev_label = None
                else:
                    # Previous round's prediction for the annotated frame.
                    prev_label = os.path.join(cfg.RESULT_ROOT, sequence,
                                              'interactive' + str(n_interaction - 1), scr_f + '.png')
                    prev_label = Image.open(prev_label)
                    prev_label = np.array(prev_label, dtype=np.uint8)
                    prev_label = tr.ToTensor()({'label': prev_label})
                    prev_label = prev_label['label'].unsqueeze(0).cuda()

                # Segment the annotated frame from the scribbles.
                tmp_dic, local_map_dics = model.int_seghead(ref_frame_embedding=ref_frame_embedding,
                                                            ref_scribble_label=scribble_label,
                                                            prev_round_label=prev_label,
                                                            global_map_tmp_dic=eval_global_map_tmp_dic,
                                                            local_map_dics=local_map_dics,
                                                            interaction_num=n_interaction,
                                                            seq_names=[sequence],
                                                            gt_ids=torch.Tensor([obj_nums]),
                                                            frame_num=[start_annotated_frame],
                                                            first_inter=first_scribble)
                pred_label = tmp_dic[sequence]
                pred_label = nn.functional.interpolate(pred_label, size=(h_, w_),
                                                       mode='bilinear', align_corners=True)
                pred_label = torch.argmax(pred_label, dim=1)
                pred_masks.append(pred_label.float())
                if is_save_image:
                    pred_label_to_save = pred_label.squeeze(0).cpu().numpy()
                    im = Image.fromarray(pred_label_to_save.astype('uint8')).convert('P')
                    im.putpalette(_palette)
                    imgname = str(start_annotated_frame).zfill(5)
                    if not os.path.exists(os.path.join(cfg.RESULT_ROOT, sequence,
                                                       'interactive' + str(n_interaction))):
                        os.makedirs(os.path.join(cfg.RESULT_ROOT, sequence,
                                                 'interactive' + str(n_interaction)))
                    im.save(os.path.join(cfg.RESULT_ROOT, sequence,
                                         'interactive' + str(n_interaction), imgname + '.png'))

                # ---- Propagate forward from the annotated frame ----
                ref_prev_label = pred_label.unsqueeze(0)
                prev_label = pred_label.unsqueeze(0)
                prev_img = ref_img
                prev_embedding = ref_frame_embedding
                for ii in range(start_annotated_frame + 1, total_frame_num):
                    print('evaluating sequence:{} frame:{}'.format(sequence, ii))
                    sample = eval_data_manager.get_image(ii)
                    img = sample['img'].unsqueeze(0)
                    _, _, h, w = img.size()
                    img = img.cuda()
                    if first_scribble:
                        current_embedding = model.extract_feature(img)
                        embedding_memory[ii] = current_embedding[0]
                    else:
                        current_embedding = embedding_memory[ii]
                        current_embedding = current_embedding.unsqueeze(0)
                    prev_img = prev_img.cuda()
                    prev_label = prev_label.cuda()
                    tmp_dic, eval_global_map_tmp_dic, local_map_dics = model.before_seghead_process(
                        ref_frame_embedding, prev_embedding, current_embedding, scribble_label, prev_label,
                        normalize_nearest_neighbor_distances=True, use_local_map=True,
                        seq_names=[sequence], gt_ids=torch.Tensor([obj_nums]), k_nearest_neighbors=cfg.KNNS,
                        global_map_tmp_dic=eval_global_map_tmp_dic, local_map_dics=local_map_dics,
                        interaction_num=n_interaction, start_annotated_frame=start_annotated_frame,
                        frame_num=[ii], dynamic_seghead=model.dynamic_seghead)
                    pred_label = tmp_dic[sequence]
                    pred_label = nn.functional.interpolate(pred_label, size=(h, w),
                                                           mode='bilinear', align_corners=True)
                    pred_label = torch.argmax(pred_label, dim=1)
                    pred_masks.append(pred_label.float())
                    prev_label = pred_label.unsqueeze(0)
                    prev_img = img
                    prev_embedding = current_embedding
                    if is_save_image:
                        pred_label_to_save = pred_label.squeeze(0).cpu().numpy()
                        im = Image.fromarray(pred_label_to_save.astype('uint8')).convert('P')
                        im.putpalette(_palette)
                        imgname = str(ii).zfill(5)
                        if not os.path.exists(os.path.join(cfg.RESULT_ROOT, sequence,
                                                           'interactive' + str(n_interaction))):
                            os.makedirs(os.path.join(cfg.RESULT_ROOT, sequence,
                                                     'interactive' + str(n_interaction)))
                        im.save(os.path.join(cfg.RESULT_ROOT, sequence,
                                             'interactive' + str(n_interaction), imgname + '.png'))

                # ---- Propagate backward from the annotated frame ----
                prev_label = ref_prev_label
                prev_img = ref_img
                prev_embedding = ref_frame_embedding
                for ii in range(start_annotated_frame):
                    current_frame_num = start_annotated_frame - 1 - ii
                    print('evaluating sequence:{} frame:{}'.format(sequence, current_frame_num))
                    sample = eval_data_manager.get_image(current_frame_num)
                    img = sample['img'].unsqueeze(0)
                    _, _, h, w = img.size()
                    img = img.cuda()
                    if first_scribble:
                        current_embedding = model.extract_feature(img)
                        embedding_memory[current_frame_num] = current_embedding[0]
                    else:
                        current_embedding = embedding_memory[current_frame_num]
                        current_embedding = current_embedding.unsqueeze(0)
                    prev_img = prev_img.cuda()
                    prev_label = prev_label.cuda()
                    tmp_dic, eval_global_map_tmp_dic, local_map_dics = model.before_seghead_process(
                        ref_frame_embedding, prev_embedding, current_embedding, scribble_label, prev_label,
                        normalize_nearest_neighbor_distances=True, use_local_map=True,
                        seq_names=[sequence], gt_ids=torch.Tensor([obj_nums]), k_nearest_neighbors=cfg.KNNS,
                        global_map_tmp_dic=eval_global_map_tmp_dic, local_map_dics=local_map_dics,
                        interaction_num=n_interaction, start_annotated_frame=start_annotated_frame,
                        frame_num=[current_frame_num], dynamic_seghead=model.dynamic_seghead)
                    pred_label = tmp_dic[sequence]
                    pred_label = nn.functional.interpolate(pred_label, size=(h, w),
                                                           mode='bilinear', align_corners=True)
                    pred_label = torch.argmax(pred_label, dim=1)
                    pred_masks_reverse.append(pred_label.float())
                    prev_label = pred_label.unsqueeze(0)
                    prev_img = img
                    prev_embedding = current_embedding
                    if is_save_image:
                        pred_label_to_save = pred_label.squeeze(0).cpu().numpy()
                        im = Image.fromarray(pred_label_to_save.astype('uint8')).convert('P')
                        im.putpalette(_palette)
                        imgname = str(current_frame_num).zfill(5)
                        if not os.path.exists(os.path.join(cfg.RESULT_ROOT, sequence,
                                                           'interactive' + str(n_interaction))):
                            os.makedirs(os.path.join(cfg.RESULT_ROOT, sequence,
                                                     'interactive' + str(n_interaction)))
                        im.save(os.path.join(cfg.RESULT_ROOT, sequence,
                                             'interactive' + str(n_interaction), imgname + '.png'))

                # Frames before the annotated one were predicted in reverse
                # order, so flip them before concatenating the full sequence.
                pred_masks_reverse.reverse()
                pred_masks_reverse.extend(pred_masks)
                final_masks = torch.cat(pred_masks_reverse, 0)
                sess.submit_masks(final_masks.cpu().numpy())
                t_end = timeit.default_timer()
                print('Total time for single interaction: ' + str(t_end - t_total))
            report = sess.get_report()
            summary = sess.get_global_summary(save_file=os.path.join(report_save_dir, 'summary.json'))
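# Conventional entry-point guard, assuming this script is meant to be launched
# directly for the DAVIS interactive evaluation.
if __name__ == '__main__':
    main()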