def init_ref_frame_dic(self):
    """Build ``self.ref_frame_dic`` with one reference frame per sequence.

    For every sequence in ``self.seqs``, one of ``001.json``/``002.json``/
    ``003.json`` under ``<db_root_dir>/Scribbles/<seq>`` is chosen with
    ``np.random.choice``. The first frame reported by ``annotated_frames``
    for that scribble file becomes the reference frame.

    Each ``self.ref_frame_dic[seq]`` entry holds:
        'ref_frame'      -- JPEG path relative to ``self.db_root_dir``
        'scribble_label' -- ``scribbles2mask`` output at the reference frame
        'ref_frame_gt'   -- annotation PNG path (relative)
        'ref_frame_num'  -- reference frame index

    Raises:
        FileNotFoundError: if OpenCV cannot read the reference JPEG.
    """
    self.ref_frame_dic = {}
    for seq in self.seqs:
        # Same RNG call as before: draw a 1-element array and take its item.
        selected_json = np.random.choice(['001.json', '002.json', '003.json'], 1)[0]
        scribble_file = os.path.join(self.db_root_dir, 'Scribbles', seq, selected_json)
        with open(scribble_file) as f:
            scribble = json.load(f)
        scr_frame = annotated_frames(scribble)[0]
        # Frame files use 5-digit zero-padded stems; zfill cannot loop forever
        # on longer indices the way a `while len(s) != 5` loop does.
        scr_f = str(scr_frame).zfill(5)
        ref_frame_path = os.path.join('JPEGImages/480p', seq, scr_f + '.jpg')
        full_ref_path = os.path.join(self.db_root_dir, ref_frame_path)
        ref_tmp = cv2.imread(full_ref_path)
        if ref_tmp is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise FileNotFoundError('Cannot read reference frame: {}'.format(full_ref_path))
        h_, w_ = ref_tmp.shape[:2]
        scribble_masks = scribbles2mask(scribble, (h_, w_))
        ref_frame_gt = os.path.join('Annotations/480p/', seq, scr_f + '.png')
        scribble_label = scribble_masks[scr_frame]
        self.ref_frame_dic[seq] = {
            'ref_frame': ref_frame_path,
            'scribble_label': scribble_label,
            'ref_frame_gt': ref_frame_gt,
            'ref_frame_num': scr_frame,
        }
def update_ref_frame_and_label(self, round_scribble=None, frame_num=None, prev_round_label_dic=None):
    """Update the reference frame and scribble label of every sequence.

    Args:
        round_scribble: mapping ``seq -> DAVIS scribble dict`` for this round.
            It is indexed for every sequence in ``self.seqs``, so it must be
            provided whenever ``self.seqs`` is non-empty.
        frame_num: optional mapping ``seq -> frame index``. When given, that
            index (cast to ``int``) is used as the reference frame and the
            label is taken from ``scribbles2mask(...)[0]``. Otherwise the first
            annotated frame is used and the label is taken at that index.
        prev_round_label_dic: optional mapping ``seq -> previous-round label``;
            when given it is stored under ``'prev_round_label'``.

    Raises:
        FileNotFoundError: if OpenCV cannot read the reference JPEG.
    """
    for seq in self.seqs:
        scribble = round_scribble[seq]
        if frame_num is None:
            scr_frame = annotated_frames(scribble)[0]
        else:
            scr_frame = int(frame_num[seq])
        # 5-digit zero-padded frame stem; zfill cannot loop forever.
        scr_f = str(scr_frame).zfill(5)
        ref_frame_path = os.path.join('JPEGImages/480p', seq, scr_f + '.jpg')
        ref_frame_gt = os.path.join('Annotations/480p/', seq, scr_f + '.png')
        full_ref_path = os.path.join(self.db_root_dir, ref_frame_path)
        ref_tmp = cv2.imread(full_ref_path)
        if ref_tmp is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise FileNotFoundError('Cannot read reference frame: {}'.format(full_ref_path))
        h_, w_ = ref_tmp.shape[:2]
        scribble_masks = scribbles2mask(scribble, (h_, w_))
        if frame_num is None:
            scribble_label = scribble_masks[scr_frame]
        else:
            scribble_label = scribble_masks[0]
        entry = {
            'ref_frame': ref_frame_path,
            'scribble_label': scribble_label,
            'ref_frame_gt': ref_frame_gt,
            'ref_frame_num': scr_frame,
        }
        # Build the entry once; optionally attach the previous-round label.
        if prev_round_label_dic is not None:
            entry['prev_round_label'] = prev_round_label_dic[seq]
        self.ref_frame_dic[seq] = entry
def Run_interaction(self, scribbles):
    """Apply the user's latest scribbles with the interaction network.

    The scribbles are rasterised at ``(self.height, self.width)`` with
    ``scribbles2mask``. The annotated frame's mask is dilated and converted by
    ``ToCudaPN`` into ``self.tar_P`` / ``self.tar_N`` (presumably
    positive/negative maps). These are fed to ``self.model_I``. Its first
    output overwrites ``self.all_E`` at the annotated frame and its third
    output is stored in ``self.ref``.

    Args:
        scribbles: DAVIS scribble dict; ``'annotated_frame'`` selects the frame.
    """
    frame_idx = scribbles['annotated_frame']
    rasterised = scribbles2mask(scribbles, (self.height, self.width))
    frame_mask = Dilate_mask(rasterised[frame_idx], 1)
    self.tar_P, self.tar_N = ToCudaPN(frame_mask)

    # author's shape note: [batch, 256,512,2]
    refined_E, _, ref_feat = self.model_I(
        self.all_F[:, :, frame_idx],
        self.all_E[:, frame_idx],
        self.tar_P,
        self.tar_N,
        self.dummy_M,
        [1, 0, 0, 0, 0],
    )
    self.all_E[:, frame_idx] = refined_E
    self.ref = ref_feat
    print('[MODEL: interaction network] User Interaction on {}'.format(frame_idx))
def Run_interaction(self, scribbles):
    """Apply the user's latest scribbles with the interaction network (no grad).

    Same flow as the training-free interaction step: rasterise the scribbles
    at the original image height/width, dilate the annotated frame's mask, and
    convert it with ``ToCudaPN``. Then run ``self.model_I`` under
    ``torch.no_grad()``, writing its outputs into ``self.all_E`` (at the
    annotated frame) and ``self.ref``.

    Args:
        scribbles: DAVIS scribble dict; ``'annotated_frame'`` selects the frame.
    """
    frame_idx = scribbles['annotated_frame']
    print('height:{} width:{} target:{}'.format(self.height, self.width, frame_idx))
    # Rasterise at the original image height/width.
    rasterised = scribbles2mask(scribbles, (self.height, self.width))
    self.tar_P, self.tar_N = ToCudaPN(Dilate_mask(rasterised[frame_idx], 1))
    with torch.no_grad():
        # Assignments stay inside no_grad so the in-place slice write into
        # self.all_E is not recorded by autograd.
        refined_E, _, ref_feat = self.model_I(
            self.all_F[:, :, frame_idx],
            self.all_E[:, frame_idx],
            self.tar_P,
            self.tar_N,
            self.dummy_M,
            [1, 0, 0, 0, 0],
        )  # author's shape note: [batch, 256,512,2]
        self.all_E[:, frame_idx] = refined_E
        self.ref = ref_feat
    print('[MODEL: interaction network] User Interaction on {}'.format(frame_idx))
def update_ref_frame_and_label(self, round_scribble):
    """Reset ``self.ref_frame_dic`` from one round of scribbles for ``self.seq_name``.

    The first annotated frame of ``round_scribble`` becomes the reference
    frame. ``self.ref_frame_dic`` is replaced by a dict with its relative JPEG
    path (``'ref_frame'``) and the ``scribbles2mask`` output at that frame
    (``'scribble_label'``).

    Raises:
        FileNotFoundError: if OpenCV cannot read the reference JPEG.
    """
    scribble = round_scribble
    scr_frame = annotated_frames(scribble)[0]
    # 5-digit zero-padded frame stem; zfill cannot loop forever.
    scr_f = str(scr_frame).zfill(5)
    ref_frame_path = os.path.join('JPEGImages/480p', self.seq_name, scr_f + '.jpg')
    full_ref_path = os.path.join(self.db_root_dir, ref_frame_path)
    ref_tmp = cv2.imread(full_ref_path)
    if ref_tmp is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise FileNotFoundError('Cannot read reference frame: {}'.format(full_ref_path))
    h_, w_ = ref_tmp.shape[:2]
    scribble_masks = scribbles2mask(scribble, (h_, w_))
    scribble_label = scribble_masks[scr_frame]
    self.ref_frame_dic = {'ref_frame': ref_frame_path, 'scribble_label': scribble_label}
def to_mask(self, scribble):
    """Convert the scribbles of one frame into a mask with the S2M network.

    The first entry of ``scribble['scribbles']`` that contains strokes is
    selected and rasterised with ``scribbles2mask``. For each object id ``ki``
    in ``1..self.k``:
      - a positive map (pixels equal to ``ki``) is built and dilated with a
        3x3 kernel;
      - a negative map (pixels not equal to ``ki`` and not -1) is built and
        dilated the same way;
      - both are fed to ``self.s2m_net`` together with the buffered image and
        the current hard mask of object ``ki``.
    The per-object sigmoid outputs are merged with ``aggregate_wbg``.

    Args:
        scribble: DAVIS scribble dict with a ``'scribbles'`` list (one entry
            per frame). The caller's dict is not modified.

    Returns:
        ``(mask, idx)``: the aggregated mask and the frame index used.

    Raises:
        ValueError: if ``scribble['scribbles']`` is empty (previously this
            surfaced later as an unbound ``idx``).
    """
    all_scr = scribble['scribbles']
    if len(all_scr) == 0:
        raise ValueError('to_mask() received a scribble dict with no frames')

    # First we select the only frame with scribble. Work on a shallow copy so
    # the caller's dict keeps its full 'scribbles' list.
    for idx, s in enumerate(all_scr):
        if len(s) != 0:
            scribble = dict(scribble, scribbles=[s])
            break

    # Pass to DAVIS to change the path to an array
    scr_mask = scribbles2mask(scribble, (self.h, self.w))[0]

    # Run our S2M
    kernel = np.ones((3, 3), np.uint8)
    mask = torch.zeros((self.k, 1, self.nh, self.nw), dtype=torch.float32, device=self.device)
    for ki in range(1, self.k + 1):
        p_srb = (scr_mask == ki).astype(np.uint8)
        # np.bool was removed in NumPy 1.24; the builtin bool is the same dtype.
        p_srb = cv2.dilate(p_srb, kernel).astype(bool)

        n_srb = ((scr_mask != ki) * (scr_mask != -1)).astype(np.uint8)
        n_srb = cv2.dilate(n_srb, kernel).astype(bool)

        Rs = torch.from_numpy(np.stack([p_srb, n_srb], 0)).unsqueeze(0).float().to(self.device)
        Rs, _ = pad_divide_by(Rs, 16, Rs.shape[-2:])

        # Use hard mask because we train S2M with such
        inputs = torch.cat([
            self.processor.get_image_buffered(idx),
            (self.processor.masks[idx] == ki).to(self.device).float().unsqueeze(0),
            Rs,
        ], 1)
        mask[ki - 1] = torch.sigmoid(self.s2m_net(inputs))
    mask = aggregate_wbg(mask, keep_bg=True, hard=True)
    return mask, idx
def _init_ref_frame_dic(self):
    """Build ``self.ref_frame_dic`` from each sequence's ``001.json`` scribbles.

    For every sequence in ``self.seqs``, the first frame reported by
    ``annotated_frames`` becomes the reference frame. Its relative JPEG path
    (``'ref_frame'``) and the ``scribbles2mask`` output at that frame
    (``'scribble_label'``) are stored.

    Raises:
        FileNotFoundError: if OpenCV cannot read the reference JPEG.
    """
    self.ref_frame_dic = {}
    for seq in self.seqs:
        scribble_file = os.path.join(self.db_root_dir, 'Scribbles', seq, '001.json')
        with open(scribble_file) as f:
            scribble = json.load(f)
        scr_frame = annotated_frames(scribble)[0]
        # 5-digit zero-padded frame stem; zfill cannot loop forever.
        scr_f = str(scr_frame).zfill(5)
        ref_frame_path = os.path.join('JPEGImages/480p', seq, scr_f + '.jpg')
        full_ref_path = os.path.join(self.db_root_dir, ref_frame_path)
        ref_tmp = cv2.imread(full_ref_path)
        if ref_tmp is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise FileNotFoundError('Cannot read reference frame: {}'.format(full_ref_path))
        h_, w_ = ref_tmp.shape[:2]
        scribble_masks = scribbles2mask(scribble, (h_, w_))
        scribble_label = scribble_masks[scr_frame]
        self.ref_frame_dic[seq] = {
            'ref_frame': ref_frame_path,
            'scribble_label': scribble_label,
        }
def main():
    """Run the DAVIS-2017 ``val`` interactive benchmark with an IntVOS model.

    Steps:
      1. Read the sequence list and per-sequence frame counts from
         ``cfg.DATA_ROOT``.
      2. Build DeepLab + IntVOS and load
         ``./saved_model_inter_net_new_re/save_step_80000.pth``.
      3. Drive a ``DavisInteractiveSession`` (localhost, 8 interactions,
         metric 'J'). For every round:
         - segment the annotated frame with ``model.int_seghead``;
         - propagate forward, then backward, with
           ``model.before_seghead_process``;
         - submit the stacked masks.
      4. Write ``summary.json`` into ``cfg.RESULT_ROOT``.

    Per-round PNGs go to ``cfg.RESULT_ROOT/<seq>/interactive<n>/``.
    """

    def _frame_name(frame_idx):
        # 5-digit zero-padded frame stem used by DAVIS file names.
        return str(frame_idx).zfill(5)

    def _interactive_dir(sequence, n_interaction):
        # Output folder of one interaction round, created on demand.
        out_dir = os.path.join(cfg.RESULT_ROOT, sequence, 'interactive' + str(n_interaction))
        os.makedirs(out_dir, exist_ok=True)
        return out_dir

    def _save_palette_png(label_np, path):
        # Save a 2-D label array as a palettised PNG.
        im = Image.fromarray(label_np.astype('uint8')).convert('P')
        im.putpalette(_palette)
        im.save(path)

    total_frame_num_dic = {}
    seqs = []
    with open(os.path.join(cfg.DATA_ROOT, 'ImageSets', '2017', 'val' + '.txt')) as f:
        seqs.extend(line.strip() for line in f)
    for seq_name in seqs:
        frame_files = os.listdir(os.path.join(cfg.DATA_ROOT, 'JPEGImages/480p/', seq_name))
        total_frame_num_dic[seq_name] = len(frame_files)
    _seq_list_file = os.path.join(cfg.DATA_ROOT, 'ImageSets', '2017', 'v_a_l' + '_instances.txt')
    # Context manager closes the handle (json.load(open(...)) leaked it).
    with open(_seq_list_file, 'r') as f:
        seq_dict = json.load(f)

    is_save_image = True
    report_save_dir = cfg.RESULT_ROOT
    save_res_dir = './saved_model_inter_net_new_re'
    # Configuration used in the challenges
    max_nb_interactions = 8  # Maximum number of interactions
    max_time_per_interaction = 10000  # Maximum time per interaction per object
    # Total time available to interact with a sequence and an initial set of scribbles
    max_time = max_nb_interactions * max_time_per_interaction  # Maximum time per object
    # Interactive parameters
    subset = 'val'
    host = 'localhost'  # 'localhost' for subsets train and val.

    feature_extracter = DeepLab(backbone='resnet', freeze_bn=False)
    model = IntVOS(cfg, feature_extracter)
    model = model.cuda()
    print('model loading...')
    saved_model_dict = os.path.join(save_res_dir, 'save_step_80000.pth')
    pretrained_dict = torch.load(saved_model_dict)
    load_network(model, pretrained_dict)
    print('model loading finished!')
    model.eval()

    with torch.no_grad():
        with DavisInteractiveSession(host=host, davis_root=cfg.DATA_ROOT, subset=subset,
                                     report_save_dir=report_save_dir,
                                     max_nb_interactions=max_nb_interactions,
                                     max_time=max_time, metric_to_optimize='J') as sess:
            while sess.next():
                t_total = timeit.default_timer()
                # Get the current iteration scribbles
                sequence, scribbles, first_scribble = sess.get_scribbles(only_last=True)
                print(sequence)
                start_annotated_frame = annotated_frames(scribbles)[0]
                pred_masks = []
                pred_masks_reverse = []
                if first_scribble:
                    # New sequence: reset all per-sequence state.
                    n_interaction = 1
                    eval_global_map_tmp_dic = {}
                    local_map_dics = ({}, {})
                    total_frame_num = total_frame_num_dic[sequence]
                    obj_nums = seq_dict[sequence][-1]
                    eval_data_manager = DAVIS2017_Test_Manager(
                        split='val', root=cfg.DATA_ROOT, transform=tr.ToTensor(), seq_name=sequence)
                else:
                    n_interaction += 1

                # ---- Reference image / scribble processing ----
                print(start_annotated_frame)
                scr_f = _frame_name(start_annotated_frame)
                ref_img_path = os.path.join(cfg.DATA_ROOT, 'JPEGImages/480p', sequence, scr_f + '.jpg')
                ref_img = cv2.imread(ref_img_path)
                if ref_img is None:
                    raise FileNotFoundError('Cannot read reference frame: {}'.format(ref_img_path))
                h_, w_ = ref_img.shape[:2]
                ref_img = np.array(ref_img, dtype=np.float32)
                scribble_masks = scribbles2mask(scribbles, (h_, w_))
                scribble_label = scribble_masks[start_annotated_frame]
                sample = tr.ToTensor()({'ref_img': ref_img, 'scribble_label': scribble_label})
                ref_img = sample['ref_img'].unsqueeze(0).cuda()
                scribble_label = sample['scribble_label'].unsqueeze(0).cuda()

                # The reference scribble image is always saved.
                _save_palette_png(
                    scribble_label.cpu().squeeze().numpy(),
                    os.path.join(_interactive_dir(sequence, n_interaction), 'inter_' + scr_f + '.png'))

                # Embeddings are computed only in the first round and cached.
                if first_scribble:
                    ref_frame_embedding = model.extract_feature(ref_img)
                    _, channel, emb_h, emb_w = ref_frame_embedding.size()
                    embedding_memory = torch.zeros((total_frame_num, channel, emb_h, emb_w)).cuda()
                    embedding_memory[start_annotated_frame] = ref_frame_embedding[0]
                else:
                    ref_frame_embedding = embedding_memory[start_annotated_frame].unsqueeze(0)

                if first_scribble:
                    prev_label = None
                else:
                    # NOTE: reads the previous round's PNG, so it requires
                    # is_save_image to have been True in that round.
                    prev_path = os.path.join(cfg.RESULT_ROOT, sequence,
                                             'interactive' + str(n_interaction - 1), scr_f + '.png')
                    with Image.open(prev_path) as prev_im:
                        prev_np = np.array(prev_im, dtype=np.uint8)
                    prev_label = tr.ToTensor()({'label': prev_np})['label'].unsqueeze(0).cuda()

                # ---- Interaction head on the annotated frame ----
                tmp_dic, local_map_dics = model.int_seghead(
                    ref_frame_embedding=ref_frame_embedding, ref_scribble_label=scribble_label,
                    prev_round_label=prev_label, global_map_tmp_dic=eval_global_map_tmp_dic,
                    local_map_dics=local_map_dics, interaction_num=n_interaction,
                    seq_names=[sequence], gt_ids=torch.Tensor([obj_nums]),
                    frame_num=[start_annotated_frame], first_inter=first_scribble)
                pred_label = tmp_dic[sequence]
                pred_label = nn.functional.interpolate(pred_label, size=(h_, w_), mode='bilinear', align_corners=True)
                pred_label = torch.argmax(pred_label, dim=1)
                pred_masks.append(pred_label.float())
                if is_save_image:
                    _save_palette_png(
                        pred_label.squeeze(0).cpu().numpy(),
                        os.path.join(_interactive_dir(sequence, n_interaction),
                                     _frame_name(start_annotated_frame) + '.png'))

                # ---- Propagation: forward to the last frame, then backward to frame 0 ----
                ref_prev_label = pred_label.unsqueeze(0)
                directions = (
                    (range(start_annotated_frame + 1, total_frame_num), pred_masks),
                    (range(start_annotated_frame - 1, -1, -1), pred_masks_reverse),
                )
                for frame_order, collected in directions:
                    # Each direction restarts from the annotated frame.
                    prev_label = ref_prev_label
                    prev_embedding = ref_frame_embedding
                    for frame_idx in frame_order:
                        print('evaluating sequence:{} frame:{}'.format(sequence, frame_idx))
                        img = eval_data_manager.get_image(frame_idx)['img'].unsqueeze(0)
                        _, _, h, w = img.size()
                        if first_scribble:
                            current_embedding = model.extract_feature(img.cuda())
                            embedding_memory[frame_idx] = current_embedding[0]
                        else:
                            current_embedding = embedding_memory[frame_idx].unsqueeze(0)
                        prev_label = prev_label.cuda()
                        tmp_dic, eval_global_map_tmp_dic, local_map_dics = model.before_seghead_process(
                            ref_frame_embedding, prev_embedding, current_embedding,
                            scribble_label, prev_label,
                            normalize_nearest_neighbor_distances=True, use_local_map=True,
                            seq_names=[sequence], gt_ids=torch.Tensor([obj_nums]),
                            k_nearest_neighbors=cfg.KNNS,
                            global_map_tmp_dic=eval_global_map_tmp_dic,
                            local_map_dics=local_map_dics,
                            interaction_num=n_interaction,
                            start_annotated_frame=start_annotated_frame,
                            frame_num=[frame_idx],
                            dynamic_seghead=model.dynamic_seghead)
                        pred_label = tmp_dic[sequence]
                        pred_label = nn.functional.interpolate(pred_label, size=(h, w), mode='bilinear', align_corners=True)
                        pred_label = torch.argmax(pred_label, dim=1)
                        collected.append(pred_label.float())
                        prev_label = pred_label.unsqueeze(0)
                        prev_embedding = current_embedding
                        if is_save_image:
                            _save_palette_png(
                                pred_label.squeeze(0).cpu().numpy(),
                                os.path.join(_interactive_dir(sequence, n_interaction),
                                             _frame_name(frame_idx) + '.png'))

                # Backward predictions were collected last-to-first; restore frame order.
                final_masks = torch.cat(pred_masks_reverse[::-1] + pred_masks, 0)
                sess.submit_masks(final_masks.cpu().numpy())
                t_end = timeit.default_timer()
                print('Total time for single interaction: ' + str(t_end - t_total))

            # Finalise the session report and write the global summary.
            sess.get_report()
            sess.get_global_summary(save_file=os.path.join(report_save_dir, 'summary.json'))
def test_step(self, weights, parallel=True, is_save_image=True, **cfg):
    """Run interactive video segmentation inference on ``cfg['video_path']``.

    The model is built from ``cfg['MODEL']`` (optionally wrapped in
    ``paddle.DataParallel``). Weights from ``load(weights)['state_dict']``
    are loaded with every key prefixed by ``'head.'``. For each
    ``(scribbles, first_scribble)`` yielded by ``get_scribbles()``:
      - the annotated frame is segmented with ``int_seghead``;
      - the result is propagated forward, then backward, with
        ``prop_seghead``;
      - all frame masks are passed to ``submit_masks(cfg['save_path'], ...)``.

    Args:
        weights: checkpoint passed to ``load``.
        parallel: wrap the model in ``paddle.DataParallel``.
        is_save_image: also save scribble / prediction PNGs under
            ``<output_dir>/<sequence>/interactive<n>/turn<t>/``.
        **cfg: configuration. Keys read: 'MODEL', 'video_path', 'output_dir'
            (optional), 'model_name', 'PIPELINE', 'save_path', 'knns'.

    Returns:
        None. Per-round lines are appended to ``<output_dir>/inter_file.txt``.
    """

    def _head_of(net):
        # Module owning the segmentation head. Under DataParallel the head sits
        # on the wrapped child layer; as with the original per-call
        # `for c in model.children()` loops, the last child wins
        # (presumably exactly one child - TODO confirm for the Paddle version).
        if not parallel:
            return net.head
        found = None
        for c in net.children():
            found = c.head
        return found

    def _save_palette_png(label_np, out_dir, file_name):
        # Save a 2-D label array as a palettised PNG, creating out_dir if needed.
        im = Image.fromarray(label_np.astype('uint8')).convert('P', )
        im.putpalette(_palette)
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)
        im.save(os.path.join(out_dir, file_name))

    # 1. Construct model.
    cfg['MODEL'].head.pretrained = ''
    cfg['MODEL'].head.test_mode = True
    model = build_model(cfg['MODEL'])
    if parallel:
        model = paddle.DataParallel(model)

    # 2. Construct data.
    sequence = cfg["video_path"].split('/')[-1].split('.')[0]
    obj_nums = 1
    images, _ = load_video(cfg["video_path"], 480)
    print("stage1 load_video success")
    # images.shape[:3] is (frames, H, W), e.g. (502, 480, 600, 3) overall.
    output_dir = cfg.get("output_dir", f"./output/{cfg['model_name']}")
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    model.eval()

    # 3. Load weights: every checkpoint key is prefixed with 'head.'.
    state_dicts_ = load(weights)['state_dict']
    model_keys = set(model.state_dict().keys())  # hoisted out of the loop
    state_dicts = {}
    for k, v in state_dicts_.items():
        if 'num_batches_tracked' in k:
            continue
        state_dicts['head.' + k] = v
        if ('head.' + k) not in model_keys:
            print(f'pretrained -----{k} -------is not in model')
    write_dict(state_dicts, 'model_for_infer.txt', **cfg)
    model.set_state_dict(state_dicts)
    head = _head_of(model)

    seen_seq = False
    # Latest label of every frame; created on the first yielded scribble.
    prev_label_storage = None

    # `with` guarantees the log file is closed even if inference raises.
    with open(os.path.join(output_dir, 'inter_file.txt'), 'w') as inter_file:
        with paddle.no_grad():
            # Get the current iteration scribbles
            for scribbles, first_scribble in get_scribbles():
                t_total = timeit.default_timer()
                f, h, w = images.shape[:3]
                if prev_label_storage is None:
                    prev_label_storage = paddle.zeros([f, h, w])
                if len(annotated_frames(scribbles)) == 0:
                    # If no scribbles are returned, keep the previous round's masks.
                    # ToDo To AP-kai: save_path has been passed in.
                    final_masks = prev_label_storage
                    submit_masks(cfg["save_path"], final_masks.numpy(), images)
                    continue

                start_annotated_frame = annotated_frames(scribbles)[0]
                pred_masks = []
                pred_masks_reverse = []

                if first_scribble:  # If in the first round, initialize memories
                    n_interaction = 1
                    eval_global_map_tmp_dic = {}
                    local_map_dics = ({}, {})
                    total_frame_num = f
                else:
                    n_interaction += 1
                inter_file.write(sequence + ' ' + 'interaction' + str(n_interaction) + ' ' + 'frame' +
                                 str(start_annotated_frame) + '\n')

                if first_scribble and not seen_seq:
                    # First round ever: extract pixel embeddings of every frame once.
                    seen_seq = True
                    inter_turn = 1
                    # NOTE: switches Paddle's global device to CPU from here on.
                    paddle.set_device('cpu')
                    # Build the test pipeline once instead of once per frame
                    # (assumes the pipeline is stateless - TODO confirm).
                    test_pipeline = (build_pipeline(cfg['PIPELINE'].test)
                                     if cfg['PIPELINE'].get('test') else None)
                    embedding_memory = []
                    for imgs in images:
                        if test_pipeline is not None:
                            imgs = paddle.to_tensor([test_pipeline({'img1': imgs})['img1']])
                        else:
                            imgs = paddle.to_tensor([imgs])
                        embedding_memory.append(head.extract_feature(imgs))
                    embedding_memory = paddle.concat(embedding_memory, 0)
                    _, _, emb_h, emb_w = embedding_memory.shape
                elif first_scribble:
                    inter_turn += 1
                ref_frame_embedding = embedding_memory[start_annotated_frame].unsqueeze(0)

                # Rasterise the scribbles at embedding resolution.
                scribble_masks = scribbles2mask(scribbles, (emb_h, emb_w))
                scribble_label = scribble_masks[start_annotated_frame]
                scribble_sample = ToTensor_manet()({'scribble_label': scribble_label})
                scribble_label = scribble_sample['scribble_label'].unsqueeze(0)

                inter_file_path = os.path.join(output_dir, sequence, 'interactive' + str(n_interaction),
                                               'turn' + str(inter_turn))
                if is_save_image:
                    _save_palette_png(scribble_label.squeeze().numpy(), inter_file_path,
                                      'inter_' + str(start_annotated_frame) + '.png')

                if first_scribble:
                    prev_label = None
                    prev_label_storage = paddle.zeros([f, h, w])
                else:
                    prev_label = prev_label_storage[start_annotated_frame].unsqueeze(0).unsqueeze(0)

                # check if no scribbles.
                if not first_scribble and paddle.unique(scribble_label).shape[0] == 1:
                    print('not first_scribble and paddle.unique(scribble_label).shape[0] == 1')
                    print(paddle.unique(scribble_label))
                    final_masks = prev_label_storage
                    submit_masks(cfg["save_path"], final_masks.numpy(), images)
                    continue

                # Interaction segmentation head on the annotated frame.
                tmp_dic, local_map_dics = head.int_seghead(
                    ref_frame_embedding=ref_frame_embedding,
                    ref_scribble_label=scribble_label,
                    prev_round_label=prev_label,
                    global_map_tmp_dic=eval_global_map_tmp_dic,
                    local_map_dics=local_map_dics,
                    interaction_num=n_interaction,
                    seq_names=[sequence],
                    gt_ids=paddle.to_tensor([obj_nums]),
                    frame_num=[start_annotated_frame],
                    first_inter=first_scribble)
                pred_label = tmp_dic[sequence]
                pred_label = nn.functional.interpolate(pred_label, size=(h, w), mode='bilinear', align_corners=True)
                pred_label = paddle.argmax(pred_label, axis=1)
                pred_masks.append(float_(pred_label))
                prev_label_storage[start_annotated_frame] = float_(pred_label[0])
                if is_save_image:
                    _save_palette_png(pred_label.squeeze(0).numpy(), inter_file_path,
                                      str(start_annotated_frame).zfill(5) + '.png')

                if first_scribble:
                    scribble_label = rough_ROI(scribble_label)

                # Propagation: forward (->) to the last frame, then backward (<-) to frame 0.
                ref_prev_label = pred_label.unsqueeze(0)
                directions = (
                    (range(start_annotated_frame + 1, total_frame_num), pred_masks),
                    (range(start_annotated_frame - 1, -1, -1), pred_masks_reverse),
                )
                for frame_order, collected in directions:
                    # Each direction restarts from the annotated frame.
                    prev_label = ref_prev_label
                    prev_embedding = ref_frame_embedding
                    for frame_idx in frame_order:
                        current_embedding = embedding_memory[frame_idx].unsqueeze(0)
                        tmp_dic, eval_global_map_tmp_dic, local_map_dics = head.prop_seghead(
                            ref_frame_embedding,
                            prev_embedding,
                            current_embedding,
                            scribble_label,
                            prev_label,
                            normalize_nearest_neighbor_distances=True,
                            use_local_map=True,
                            seq_names=[sequence],
                            gt_ids=paddle.to_tensor([obj_nums]),
                            k_nearest_neighbors=cfg['knns'],
                            global_map_tmp_dic=eval_global_map_tmp_dic,
                            local_map_dics=local_map_dics,
                            interaction_num=n_interaction,
                            start_annotated_frame=start_annotated_frame,
                            frame_num=[frame_idx],
                            dynamic_seghead=head.dynamic_seghead)
                        pred_label = tmp_dic[sequence]
                        pred_label = nn.functional.interpolate(pred_label, size=(h, w), mode='bilinear',
                                                               align_corners=True)
                        pred_label = paddle.argmax(pred_label, axis=1)
                        collected.append(float_(pred_label))
                        prev_label = pred_label.unsqueeze(0)
                        prev_embedding = current_embedding
                        prev_label_storage[frame_idx] = float_(pred_label[0])
                        if is_save_image:
                            _save_palette_png(pred_label.squeeze(0).numpy(), inter_file_path,
                                              str(frame_idx).zfill(5) + '.png')

                # Backward predictions were collected last-to-first; restore frame order.
                final_masks = paddle.concat(pred_masks_reverse[::-1] + pred_masks, 0)
                submit_masks(cfg["save_path"], final_masks.numpy(), images)

                t_end = timeit.default_timer()
                print('Total time for single interaction: ' + str(t_end - t_total))

    return None
def Run(self, variables):
    """One interaction + propagation round over all objects, accumulating loss.

    Reads from ``variables``: 'all_F', 'info' ('num_objs', 'num_frames',
    'height', 'width'), 'prev_targets', 'scribbles', 'mask_objs', 'ref' and
    'all_M'.

    For each object id ``n_obj`` (1-based):
      1. Keep only that object's scribbles (relabelled to id 1) plus scribbles
         with ``object_id == 0`` on the annotated frame.
      2. Run ``self.model_I`` on the annotated frame.
      3. Propagate with ``self.model_P`` forward up to ``right_end``, then
         backward down to ``left_end`` (both from ``Get_weight``).
      4. Blend every frame with the previous estimate using ``weight[f]``.
    The ``batch_CE`` values returned by every model call are summed into
    ``loss``, which is then divided by the number of objects.

    The per-object maps are then combined:
      - background probability = ``prod(1 - masks)``;
      - all probabilities are clamped, turned into logits and softmaxed over
        the object axis;
      - ``final_masks`` is the argmax over that axis.

    Writes back into ``variables``: appends the annotated frame to
    'prev_targets', sets 'masks', 'mask_objs', 'probs', 'loss', and updates
    'ref' per object.
    """
    all_F = variables['all_F']
    num_objects = variables['info']['num_objs']
    num_frames = variables['info']['num_frames']
    height = variables['info']['height']
    width = variables['info']['width']
    prev_targets = variables['prev_targets']
    scribbles = variables['scribbles']
    target = scribbles['annotated_frame']

    loss = 0
    masks = torch.zeros(num_objects, num_frames, height, width)
    for n_obj in range(1, num_objects + 1):

        # variables for current obj
        # NOTE(review): slicing a tensor never yields None, so the `else` branch
        # below looks unreachable; `.data` detaches the slice from autograd.
        all_E_n = variables['mask_objs'][n_obj-1:n_obj].data if variables['mask_objs'][n_obj-1:n_obj] is not None \
            else variables['mask_objs'][n_obj-1:n_obj]
        a_ref = variables['ref'][n_obj - 1]
        prev_E_n = all_E_n.clone()  # estimate before this round, used for blending
        all_M_n = (variables['all_M'] == n_obj).long()  # binary target mask of this object

        # divide scribbles for current object: this object's strokes become id 1,
        # strokes with object_id 0 are kept, other objects' strokes are dropped.
        n_scribble = copy.deepcopy(scribbles)
        n_scribble['scribbles'][target] = []
        for p in scribbles['scribbles'][target]:
            if p['object_id'] == n_obj:
                n_scribble['scribbles'][target].append(copy.deepcopy(p))
                n_scribble['scribbles'][target][-1]['object_id'] = 1
            else:
                if p['object_id'] == 0:
                    n_scribble['scribbles'][target].append(
                        copy.deepcopy(p))

        scribble_mask = scribbles2mask(n_scribble, (height, width))[target]
        # Pixels previously estimated as this object (>0.5) that now carry a
        # value-0 scribble (presumably the object_id-0 strokes - TODO confirm).
        scribble_mask_N = (prev_E_n[0, target].cpu() > 0.5) & (torch.tensor(scribble_mask) == 0)
        # Drop all value-0 pixels, then re-add only those overlapping the old estimate.
        scribble_mask[scribble_mask == 0] = -1
        scribble_mask[scribble_mask_N] = 0
        scribble_mask = Dilate_mask(scribble_mask, 1)

        # interaction
        tar_P, tar_N = ToCudaPN(scribble_mask)
        all_E_n[:, target], batch_CE, ref = self.model_I(
            all_F[:, :, target], all_E_n[:, target], tar_P, tar_N,
            all_M_n[:, target], [1, 1, 1, 1, 1])  # [batch, 256,512,2]
        loss += batch_CE

        # propagation
        left_end, right_end, weight = Get_weight(target, prev_targets, num_frames, at_least=-1)

        # Prop_forward
        # NOTE(review): if both propagation ranges are empty, next_a_ref stays
        # None and variables['ref'][n_obj - 1] is overwritten with None below.
        next_a_ref = None
        for n in range(target + 1, right_end + 1):  #[1,2,...,N-1]
            all_E_n[:, n], batch_CE, next_a_ref = self.model_P(
                ref, a_ref, all_F[:, :, n], prev_E_n[:, n],
                all_E_n[:, n - 1], all_M_n[:, n], [1, 1, 1, 1, 1],
                next_a_ref)
            loss += batch_CE

        # Prop_backward (continues from the forward pass's next_a_ref)
        for n in reversed(range(left_end, target)):
            all_E_n[:, n], batch_CE, next_a_ref = self.model_P(
                ref, a_ref, all_F[:, :, n], prev_E_n[:, n],
                all_E_n[:, n + 1], all_M_n[:, n], [1, 1, 1, 1, 1],
                next_a_ref)
            loss += batch_CE

        # Blend new and previous estimates per frame: weight*new + (1-weight)*prev.
        for f in range(num_frames):
            all_E_n[:, f, :, :] = weight[f] * all_E_n[:, f, :, :] + (
                1 - weight[f]) * prev_E_n[:, f, :, :]

        masks[n_obj - 1] = all_E_n[0]
        variables['ref'][n_obj - 1] = next_a_ref

    loss /= num_objects

    # Combine per-object probabilities: background = prod(1 - masks), then
    # clamp, convert to logits and softmax over the object axis (dim=1).
    em = torch.zeros(1, num_objects + 1, num_frames, height, width).to(masks.device)
    em[0, 0, :, :, :] = torch.prod(1 - masks, dim=0)  # bg prob
    em[0, 1:num_objects + 1, :, :] = masks  # obj prob
    em = torch.clamp(em, 1e-7, 1 - 1e-7)
    all_E = torch.log((em / (1 - em)))

    all_E = F.softmax(all_E, dim=1)
    final_masks = all_E[0].max(0)[1].float()  # argmax label per pixel

    variables['prev_targets'].append(target)
    variables['masks'] = final_masks
    variables['mask_objs'] = masks
    variables['probs'] = all_E
    variables['loss'] = loss