def init_ref_frame_dic(self):
    """Build ``self.ref_frame_dic`` for every sequence from a randomly chosen scribble.

    For each sequence one of the three DAVIS scribble files
    ('001.json'/'002.json'/'003.json') is sampled, its first annotated frame is
    taken as the reference frame, and the scribble is rasterized to a label map
    at the reference image's resolution.

    Populates, per sequence: 'ref_frame' (relative JPEG path),
    'scribble_label' (rasterized scribble mask), 'ref_frame_gt'
    (relative annotation PNG path) and 'ref_frame_num' (frame index).
    """
    self.ref_frame_dic = {}
    for seq in self.seqs:
        # Randomly pick one of the three provided scribble annotations.
        selected_json = np.random.choice(['001.json', '002.json', '003.json'], 1)[0]
        scribble_file = os.path.join(self.db_root_dir, 'Scribbles', seq, selected_json)
        with open(scribble_file) as f:
            scribble = json.load(f)
        scr_frame = annotated_frames(scribble)[0]
        # DAVIS frame filenames are zero-padded to five digits.
        scr_f = str(scr_frame).zfill(5)
        ref_frame_path = os.path.join('JPEGImages/480p', seq, scr_f + '.jpg')
        # Read the reference frame only to obtain its spatial size for rasterization.
        ref_tmp = cv2.imread(os.path.join(self.db_root_dir, ref_frame_path))
        h_, w_ = ref_tmp.shape[:2]
        scribble_masks = scribbles2mask(scribble, (h_, w_))
        ref_frame_gt = os.path.join('Annotations/480p/', seq, scr_f + '.png')
        scribble_label = scribble_masks[scr_frame]
        self.ref_frame_dic[seq] = {
            'ref_frame': ref_frame_path,
            'scribble_label': scribble_label,
            'ref_frame_gt': ref_frame_gt,
            'ref_frame_num': scr_frame,
        }
def update_ref_frame_and_label(self, round_scribble=None, frame_num=None, prev_round_label_dic=None):
    """Update ``self.ref_frame_dic`` with this interaction round's scribbles.

    Args:
        round_scribble: dict mapping sequence name -> scribble data for this round.
        frame_num: optional dict mapping sequence name -> annotated frame index.
            When None, the frame index is taken from the scribble itself.
        prev_round_label_dic: optional dict mapping sequence name -> previous
            round's predicted label; stored alongside the reference entry.
    """
    for seq in self.seqs:
        scribble = round_scribble[seq]
        if frame_num is None:
            scr_frame = annotated_frames(scribble)[0]
        else:
            scr_frame = int(frame_num[seq])
        # DAVIS frame filenames are zero-padded to five digits.
        scr_f = str(scr_frame).zfill(5)
        ref_frame_path = os.path.join('JPEGImages/480p', seq, scr_f + '.jpg')
        ref_frame_gt = os.path.join('Annotations/480p/', seq, scr_f + '.png')
        # Read the reference frame only to obtain its spatial size for rasterization.
        ref_tmp = cv2.imread(os.path.join(self.db_root_dir, ref_frame_path))
        h_, w_ = ref_tmp.shape[:2]
        scribble_masks = scribbles2mask(scribble, (h_, w_))
        if frame_num is None:
            scribble_label = scribble_masks[scr_frame]
        else:
            # presumably the scribble covers a single frame here, so the
            # rasterized mask lives at index 0 — TODO confirm against caller
            scribble_label = scribble_masks[0]
        # Build the entry once; previously the dict literal was duplicated
        # for the prev_round_label case.
        entry = {
            'ref_frame': ref_frame_path,
            'scribble_label': scribble_label,
            'ref_frame_gt': ref_frame_gt,
            'ref_frame_num': scr_frame,
        }
        if prev_round_label_dic is not None:
            entry['prev_round_label'] = prev_round_label_dic[seq]
        self.ref_frame_dic[seq] = entry
def test_load_scribble(self):
    """Loading the first 'bear' scribble yields a non-empty scribble on frame 39."""
    root = Path(__file__).parent.joinpath('test_data', 'DAVIS')
    dataset = Davis(root)
    loaded = dataset.load_scribble('bear', 1)
    assert loaded['sequence'] == 'bear'
    assert not is_empty(loaded)
    assert annotated_frames(loaded) == [39]
def test_integration_single(self, mock_davis):
    """End-to-end session on 'bear': 5 interactions, report CSV lifecycle checked.

    Verifies that the temporary report exists (and the final one does not)
    while the session runs, that the first round serves the stored '001.json'
    scribble, that later rounds annotate frame 1 (the requested candidate),
    and that the report accumulates two rows per interaction.
    """
    dataset_dir = Path(__file__).parent.joinpath('test_data', 'DAVIS')
    tmp_dir = Path(tempfile.mkdtemp())
    with DavisInteractiveSession(
            davis_root=dataset_dir,
            subset='train',
            max_nb_interactions=5,
            report_save_dir=tmp_dir,
            max_time=None) as session:
        count = 0

        temp_csv = tmp_dir / ("%s.tmp.csv" % session.report_name)
        final_csv = tmp_dir / ("%s.csv" % session.report_name)

        while session.next():
            # While the session is in progress only the temporary CSV exists.
            assert not final_csv.exists()
            assert temp_csv.exists()

            df = pd.read_csv(temp_csv, index_col=0)
            # Two report rows per completed interaction.
            assert df.shape == (count * 2, 10)

            seq, scribble, new_seq = session.get_scribbles(only_last=True)
            # Only the very first round starts a new sequence.
            assert new_seq == (count == 0)
            assert seq == 'bear'
            if count == 0:
                # First round must serve the stored starting scribble verbatim.
                with dataset_dir.joinpath('Scribbles', 'bear',
                                          '001.json').open() as fp:
                    sc = json.load(fp)
                assert sc == scribble
            else:
                # Robot-generated scribble on the requested candidate frame.
                assert annotated_frames(scribble) == [1]
                assert not is_empty(scribble)
                assert len(scribble['scribbles']) == 2
                assert len(scribble['scribbles'][1]) > 0
                assert len(scribble['scribbles'][0]) == 0

            # Simulate model predicting masks
            pred_masks = np.zeros((2, 480, 854))
            session.submit_masks(
                pred_masks, next_scribble_frame_candidates=[1])

            if count > 0:
                assert df.sequence.unique() == ['bear']
                assert np.all(df.interaction.unique() ==
                              [i + 1 for i in range(count)])
                assert np.all(df.object_id.unique() == [1])

            count += 1

        assert count == 5
        assert final_csv.exists()
        assert not temp_csv.exists()

        assert mock_davis.call_count == 0
def test_starting_scribble(self, _):
    """The evaluation service serves the starting 'bear' scribble on frame 39."""
    root = Path(__file__).parent.parent / 'dataset' / 'test_data' / 'DAVIS'
    svc = EvaluationService('train', davis_root=root)
    svc.get_samples()
    result = svc.get_scribble('bear', 1)
    assert result['sequence'] == 'bear'
    assert not is_empty(result)
    assert annotated_frames(result) == [39]
def test_interaction_equal(self):
    """When prediction equals ground truth the robot produces an empty scribble."""
    nb_frames, h, w = 10, 300, 500
    # `np.int` was removed in NumPy 1.24; the builtin `int` is the exact
    # replacement (the alias always mapped to it).
    gt_empty = np.zeros((nb_frames, h, w), dtype=int)
    gt_empty[0, 100:200, 100:200] = 1
    pred_empty = gt_empty.copy()

    robot = InteractiveScribblesRobot()
    scribble = robot.interact('test', pred_empty, gt_empty)
    # Nothing to correct -> empty scribble, but still one entry per frame.
    assert is_empty(scribble)
    assert annotated_frames(scribble) == []
    assert len(scribble['scribbles']) == nb_frames
def test_integration_single_only_last(self, mock_davis):
    """Session on 'blackswan' with only_last=True: each round annotates a new frame.

    Checks that the first round serves the stored '001.json' scribble and that
    every later round's robot scribble targets exactly one frame not used in a
    previous round.
    """
    dataset_dir = Path(__file__).parent.joinpath('test_data', 'DAVIS')
    with DavisInteractiveSession(
            davis_root=dataset_dir,
            subset='train',
            max_nb_interactions=4,
            report_save_dir=tempfile.mkdtemp(),
            max_time=None) as session:
        count = 0

        # Frames annotated by the robot so far; each round must pick a new one.
        annotated_frames_list = []
        while session.next():
            seq, scribble, new_seq = session.get_scribbles(only_last=True)
            # Only the very first round starts a new sequence.
            assert new_seq == (count == 0)
            assert seq == 'blackswan'
            if count == 0:
                # First round must serve the stored starting scribble verbatim.
                with dataset_dir.joinpath('Scribbles', 'blackswan',
                                          '001.json').open() as fp:
                    sc = json.load(fp)
                assert sc == scribble
            else:
                assert len(annotated_frames(scribble)) == 1
                a_fr = annotated_frames(scribble)[0]
                assert a_fr not in annotated_frames_list
                annotated_frames_list.append(a_fr)
                assert not is_empty(scribble)

            # Simulate model predicting masks
            pred_masks = np.zeros((6, 480, 854))
            session.submit_masks(pred_masks)

            count += 1

        assert count == 4

        assert mock_davis.call_count == 0
def update_ref_frame_and_label(self, round_scribble):
    """Update ``self.ref_frame_dic`` for the single tracked sequence.

    Args:
        round_scribble: scribble data for this interaction round; its first
            annotated frame becomes the new reference frame.
    """
    scribble = round_scribble
    scr_frame = annotated_frames(scribble)[0]
    # DAVIS frame filenames are zero-padded to five digits.
    scr_f = str(scr_frame).zfill(5)
    ref_frame_path = os.path.join('JPEGImages/480p', self.seq_name, scr_f + '.jpg')
    # Read the reference frame only to obtain its spatial size for rasterization.
    ref_tmp = cv2.imread(os.path.join(self.db_root_dir, ref_frame_path))
    h_, w_ = ref_tmp.shape[:2]
    scribble_masks = scribbles2mask(scribble, (h_, w_))
    scribble_label = scribble_masks[scr_frame]
    self.ref_frame_dic = {
        'ref_frame': ref_frame_path,
        'scribble_label': scribble_label,
    }
def test_interaction(self):
    """Robot scribbles the missed object (false negative) inside its gt box."""
    nb_frames, h, w = 10, 300, 500
    # `np.int` was removed in NumPy 1.24; the builtin `int` is the exact
    # replacement (the alias always mapped to it).
    gt_empty = np.zeros((nb_frames, h, w), dtype=int)
    pred_empty = gt_empty.copy()
    gt_empty[5, 100:200, 100:200] = 1

    robot = InteractiveScribblesRobot()
    scribble = robot.interact('test', pred_empty, gt_empty)
    assert not is_empty(scribble)
    assert annotated_frames(scribble) == [5]
    assert len(scribble['scribbles']) == nb_frames

    lines = scribble['scribbles'][5]
    for l in lines:
        # Every path point must fall inside the normalized gt box
        # (x in [100/500, 200/500], y in [100/300, 200/300]).
        assert l['object_id'] == 1
        path = np.asarray(l['path'])
        x, y = path[:, 0], path[:, 1]
        assert np.all((x >= .2) & (x <= .4))
        assert np.all((y >= 1 / 3) & (y <= 2 / 3))
def test_interaction_false_positive_single_frame(self):
    """Robot scribbles background (object_id 0) outside the gt box on a false positive."""
    nb_frames, h, w = 1, 300, 500
    # `np.int` was removed in NumPy 1.24; the builtin `int` is the exact
    # replacement (the alias always mapped to it).
    gt_empty = np.zeros((nb_frames, h, w), dtype=int)
    pred_empty = np.ones((nb_frames, h, w), dtype=int)
    gt_empty[0, 100:200, 100:200] = 1

    robot = InteractiveScribblesRobot()
    scribble = robot.interact('test', pred_empty, gt_empty)
    assert not is_empty(scribble)
    assert annotated_frames(scribble) == [0]
    assert len(scribble['scribbles']) == nb_frames

    lines = scribble['scribbles'][0]
    assert lines
    for l in lines:
        # Correction scribbles mark background and must avoid the gt box.
        assert l['object_id'] == 0
        path = np.asarray(l['path'])
        x, y = path[:, 0], path[:, 1]
        inside = (x >= .2) & (x <= .4) & (y >= 1 / 3) & (y <= 2 / 3)
        assert not np.any(inside)
def _init_ref_frame_dic(self):
    """Build ``self.ref_frame_dic`` from each sequence's '001.json' scribble.

    Deterministic variant of the initializer: always uses the first scribble
    file and stores only the reference frame path and its rasterized
    scribble label per sequence.
    """
    self.ref_frame_dic = {}
    for seq in self.seqs:
        scribble_file = os.path.join(self.db_root_dir, 'Scribbles', seq, '001.json')
        with open(scribble_file) as f:
            scribble = json.load(f)
        scr_frame = annotated_frames(scribble)[0]
        # DAVIS frame filenames are zero-padded to five digits.
        scr_f = str(scr_frame).zfill(5)
        ref_frame_path = os.path.join('JPEGImages/480p', seq, scr_f + '.jpg')
        # Read the reference frame only to obtain its spatial size.
        ref_tmp = cv2.imread(
            os.path.join(self.db_root_dir, ref_frame_path))
        h_, w_ = ref_tmp.shape[:2]
        scribble_masks = scribbles2mask(scribble, (h_, w_))
        scribble_label = scribble_masks[scr_frame]
        self.ref_frame_dic[seq] = {
            'ref_frame': ref_frame_path,
            'scribble_label': scribble_label
        }
def main():
    """Run interactive VOS evaluation of IntVOS on DAVIS 2017 'val'.

    For each interaction round served by DavisInteractiveSession: rasterize
    the round's scribble on its annotated frame, run the interaction
    segmentation head on that frame, then propagate the mask forward
    (annotated frame + 1 .. end) and backward (annotated frame - 1 .. 0),
    and submit the concatenated per-frame masks. Frame embeddings are
    extracted once in the first round and cached in ``embedding_memory``.
    """
    # Per-sequence frame counts, read from the JPEG folders.
    total_frame_num_dic = {}
    #################
    seqs = []
    with open(os.path.join(cfg.DATA_ROOT, 'ImageSets', '2017',
                           'val' + '.txt')) as f:
        seqs_tmp = f.readlines()
    seqs_tmp = list(map(lambda elem: elem.strip(), seqs_tmp))
    seqs.extend(seqs_tmp)
    for seq_name in seqs:
        images = np.sort(os.listdir(os.path.join(cfg.DATA_ROOT,
                                                 'JPEGImages/480p/',
                                                 seq_name.strip())))
        total_frame_num_dic[seq_name] = len(images)
    # NOTE(review): 'v_a_l' yields 'v_a_l_instances.txt' — looks suspicious;
    # verify the instances filename against the dataset layout.
    _seq_list_file = os.path.join(cfg.DATA_ROOT, 'ImageSets', '2017',
                                  'v_a_l' + '_instances.txt')
    seq_dict = json.load(open(_seq_list_file, 'r'))
    ##################
    is_save_image = True  # also dump per-frame PNG masks under RESULT_ROOT
    report_save_dir = cfg.RESULT_ROOT
    save_res_dir = './saved_model_inter_net_new_re'

    # Configuration used in the challenges
    max_nb_interactions = 8  # Maximum number of interactions
    max_time_per_interaction = 10000  # Maximum time per interaction per object
    # Total time available to interact with a sequence and an initial set of scribbles
    max_time = max_nb_interactions * max_time_per_interaction  # Maximum time per object
    # Interactive parameters
    subset = 'val'
    host = 'localhost'  # 'localhost' for subsets train and val.

    # Build the network and load the trained checkpoint.
    feature_extracter = DeepLab(backbone='resnet', freeze_bn=False)
    model = IntVOS(cfg, feature_extracter)
    model = model.cuda()
    print('model loading...')
    saved_model_dict = os.path.join(save_res_dir, 'save_step_80000.pth')
    pretrained_dict = torch.load(saved_model_dict)
    load_network(model, pretrained_dict)
    print('model loading finished!')
    model.eval()
    seen_seq = {}  # NOTE(review): never written below — presumably vestigial
    with torch.no_grad():
        with DavisInteractiveSession(host=host,
                                     davis_root=cfg.DATA_ROOT,
                                     subset=subset,
                                     report_save_dir=report_save_dir,
                                     max_nb_interactions=max_nb_interactions,
                                     max_time=max_time,
                                     metric_to_optimize='J') as sess:
            while sess.next():
                t_total = timeit.default_timer()
                # Get the current iteration scribbles
                sequence, scribbles, first_scribble = sess.get_scribbles(only_last=True)
                print(sequence)
                start_annotated_frame = annotated_frames(scribbles)[0]

                pred_masks = []          # forward-propagated masks
                pred_masks_reverse = []  # backward-propagated masks

                if first_scribble:
                    # First round for this sequence: reset per-sequence state.
                    anno_frame_list = []
                    n_interaction = 1
                    eval_global_map_tmp_dic = {}
                    local_map_dics = ({}, {})
                    total_frame_num = total_frame_num_dic[sequence]
                    obj_nums = seq_dict[sequence][-1]
                    eval_data_manager = DAVIS2017_Test_Manager(split='val',
                                                               root=cfg.DATA_ROOT,
                                                               transform=tr.ToTensor(),
                                                               seq_name=sequence)
                else:
                    n_interaction += 1

                ########################## Reference image process
                scr_f = start_annotated_frame
                anno_frame_list.append(start_annotated_frame)
                print(start_annotated_frame)
                # Zero-pad the frame number to the 5-digit DAVIS filename.
                scr_f = str(scr_f)
                while len(scr_f) != 5:
                    scr_f = '0' + scr_f
                ref_img = os.path.join(cfg.DATA_ROOT, 'JPEGImages/480p',
                                       sequence, scr_f + '.jpg')
                ref_img = cv2.imread(ref_img)
                h_, w_ = ref_img.shape[:2]
                ref_img = np.array(ref_img, dtype=np.float32)
                #ref_img = tr.ToTensor()(ref_img)
                #ref_img = ref_img.unsqueeze(0)
                scribble_masks = scribbles2mask(scribbles, (h_, w_))
                scribble_label = scribble_masks[start_annotated_frame]
                sample = {'ref_img': ref_img, 'scribble_label': scribble_label}
                sample = tr.ToTensor()(sample)
                ref_img = sample['ref_img']
                scribble_label = sample['scribble_label']
                ref_img = ref_img.unsqueeze(0)
                scribble_label = scribble_label.unsqueeze(0)
                ref_img = ref_img.cuda()
                scribble_label = scribble_label.cuda()

                ###### Save the round's scribble rasterization for inspection.
                ref_scribble_to_show = scribble_label.cpu().squeeze().numpy()
                im_ = Image.fromarray(ref_scribble_to_show.astype('uint8')).convert('P')
                im_.putpalette(_palette)
                ref_img_name = scr_f
                if not os.path.exists(os.path.join(cfg.RESULT_ROOT, sequence,
                                                   'interactive' + str(n_interaction))):
                    os.makedirs(os.path.join(cfg.RESULT_ROOT, sequence,
                                             'interactive' + str(n_interaction)))
                im_.save(os.path.join(cfg.RESULT_ROOT, sequence,
                                      'interactive' + str(n_interaction),
                                      'inter_' + ref_img_name + '.png'))

                #ref_frame_embedding = model.extract_feature(ref_img)
                if first_scribble:
                    # First round: extract and cache the reference embedding.
                    ref_frame_embedding = model.extract_feature(ref_img)
                    _, channel, emb_h, emb_w = ref_frame_embedding.size()
                    embedding_memory = torch.zeros((total_frame_num, channel, emb_h, emb_w))
                    embedding_memory = embedding_memory.cuda()
                    embedding_memory[start_annotated_frame] = ref_frame_embedding[0]
                else:
                    # Later rounds: reuse the cached embedding.
                    ref_frame_embedding = embedding_memory[start_annotated_frame]
                    ref_frame_embedding = ref_frame_embedding.unsqueeze(0)

                ########
                if first_scribble:
                    prev_label = None
                else:
                    # Previous round's prediction for this frame, read back from disk.
                    prev_label = os.path.join(cfg.RESULT_ROOT, sequence,
                                              'interactive' + str(n_interaction - 1),
                                              scr_f + '.png')
                    prev_label = Image.open(prev_label)
                    prev_label = np.array(prev_label, dtype=np.uint8)
                    prev_label = tr.ToTensor()({'label': prev_label})
                    prev_label = prev_label['label'].unsqueeze(0)
                    prev_label = prev_label.cuda()
                ###############
                #tmp_dic, eval_global_map_tmp_dic= model.before_seghead_process(ref_frame_embedding,ref_frame_embedding,
                #        ref_frame_embedding,scribble_label,prev_label,
                #        normalize_nearest_neighbor_distances=True,use_local_map=False,
                #        seq_names=[sequence],gt_ids=torch.Tensor([obj_nums]),k_nearest_neighbors=cfg.KNNS,
                #        global_map_tmp_dic=eval_global_map_tmp_dic,frame_num=[start_annotated_frame],dynamic_seghead=model.dynamic_seghead)

                # Interaction segmentation head on the annotated frame.
                tmp_dic, local_map_dics = model.int_seghead(ref_frame_embedding=ref_frame_embedding,
                                                            ref_scribble_label=scribble_label,
                                                            prev_round_label=prev_label,
                                                            global_map_tmp_dic=eval_global_map_tmp_dic,
                                                            local_map_dics=local_map_dics,
                                                            interaction_num=n_interaction,
                                                            seq_names=[sequence],
                                                            gt_ids=torch.Tensor([obj_nums]),
                                                            frame_num=[start_annotated_frame],
                                                            first_inter=first_scribble)
                pred_label = tmp_dic[sequence]
                pred_label = nn.functional.interpolate(pred_label, size=(h_, w_),
                                                       mode='bilinear', align_corners=True)
                pred_label = torch.argmax(pred_label, dim=1)
                pred_masks.append(pred_label.float())
                ####
                if is_save_image:
                    pred_label_to_save = pred_label.squeeze(0).cpu().numpy()
                    im = Image.fromarray(pred_label_to_save.astype('uint8')).convert('P')
                    im.putpalette(_palette)
                    imgname = str(start_annotated_frame)
                    while len(imgname) < 5:
                        imgname = '0' + imgname
                    if not os.path.exists(os.path.join(cfg.RESULT_ROOT, sequence,
                                                       'interactive' + str(n_interaction))):
                        os.makedirs(os.path.join(cfg.RESULT_ROOT, sequence,
                                                 'interactive' + str(n_interaction)))
                    im.save(os.path.join(cfg.RESULT_ROOT, sequence,
                                         'interactive' + str(n_interaction),
                                         imgname + '.png'))
                #######################################

                ############################## Forward propagation.
                ref_prev_label = pred_label.unsqueeze(0)
                prev_label = pred_label.unsqueeze(0)
                prev_img = ref_img
                prev_embedding = ref_frame_embedding
                for ii in range(start_annotated_frame + 1, total_frame_num):
                    print('evaluating sequence:{} frame:{}'.format(sequence, ii))
                    sample = eval_data_manager.get_image(ii)
                    img = sample['img']
                    img = img.unsqueeze(0)
                    _, _, h, w = img.size()
                    img = img.cuda()
                    # current_embedding= model.extract_feature(img)
                    if first_scribble:
                        current_embedding = model.extract_feature(img)
                        embedding_memory[ii] = current_embedding[0]
                    else:
                        current_embedding = embedding_memory[ii]
                        current_embedding = current_embedding.unsqueeze(0)
                    prev_img = prev_img.cuda()
                    prev_label = prev_label.cuda()
                    tmp_dic, eval_global_map_tmp_dic, local_map_dics = model.before_seghead_process(
                        ref_frame_embedding, prev_embedding,
                        current_embedding, scribble_label, prev_label,
                        normalize_nearest_neighbor_distances=True, use_local_map=True,
                        seq_names=[sequence], gt_ids=torch.Tensor([obj_nums]),
                        k_nearest_neighbors=cfg.KNNS,
                        global_map_tmp_dic=eval_global_map_tmp_dic,
                        local_map_dics=local_map_dics,
                        interaction_num=n_interaction,
                        start_annotated_frame=start_annotated_frame,
                        frame_num=[ii], dynamic_seghead=model.dynamic_seghead)
                    pred_label = tmp_dic[sequence]
                    pred_label = nn.functional.interpolate(pred_label, size=(h, w),
                                                           mode='bilinear', align_corners=True)
                    pred_label = torch.argmax(pred_label, dim=1)
                    pred_masks.append(pred_label.float())
                    prev_label = pred_label.unsqueeze(0)
                    prev_img = img
                    prev_embedding = current_embedding
                    ####
                    if is_save_image:
                        pred_label_to_save = pred_label.squeeze(0).cpu().numpy()
                        im = Image.fromarray(pred_label_to_save.astype('uint8')).convert('P')
                        im.putpalette(_palette)
                        imgname = str(ii)
                        while len(imgname) < 5:
                            imgname = '0' + imgname
                        if not os.path.exists(os.path.join(cfg.RESULT_ROOT, sequence,
                                                           'interactive' + str(n_interaction))):
                            os.makedirs(os.path.join(cfg.RESULT_ROOT, sequence,
                                                     'interactive' + str(n_interaction)))
                        im.save(os.path.join(cfg.RESULT_ROOT, sequence,
                                             'interactive' + str(n_interaction),
                                             imgname + '.png'))
                #######################################

                # Reset to the annotated frame for the backward pass.
                prev_label = ref_prev_label
                prev_img = ref_img
                prev_embedding = ref_frame_embedding
                ###################################### Backward propagation.
                for ii in range(start_annotated_frame):
                    current_frame_num = start_annotated_frame - 1 - ii
                    print('evaluating sequence:{} frame:{}'.format(sequence, current_frame_num))
                    sample = eval_data_manager.get_image(current_frame_num)
                    img = sample['img']
                    img = img.unsqueeze(0)
                    _, _, h, w = img.size()
                    img = img.cuda()
                    #current_embedding= model.extract_feature(img)
                    if first_scribble:
                        current_embedding = model.extract_feature(img)
                        embedding_memory[current_frame_num] = current_embedding[0]
                    else:
                        current_embedding = embedding_memory[current_frame_num]
                        current_embedding = current_embedding.unsqueeze(0)
                    prev_img = prev_img.cuda()
                    prev_label = prev_label.cuda()
                    tmp_dic, eval_global_map_tmp_dic, local_map_dics = model.before_seghead_process(
                        ref_frame_embedding, prev_embedding,
                        current_embedding, scribble_label, prev_label,
                        normalize_nearest_neighbor_distances=True, use_local_map=True,
                        seq_names=[sequence], gt_ids=torch.Tensor([obj_nums]),
                        k_nearest_neighbors=cfg.KNNS,
                        global_map_tmp_dic=eval_global_map_tmp_dic,
                        local_map_dics=local_map_dics,
                        interaction_num=n_interaction,
                        start_annotated_frame=start_annotated_frame,
                        frame_num=[current_frame_num],
                        dynamic_seghead=model.dynamic_seghead)
                    pred_label = tmp_dic[sequence]
                    pred_label = nn.functional.interpolate(pred_label, size=(h, w),
                                                           mode='bilinear', align_corners=True)
                    pred_label = torch.argmax(pred_label, dim=1)
                    pred_masks_reverse.append(pred_label.float())
                    prev_label = pred_label.unsqueeze(0)
                    prev_img = img
                    prev_embedding = current_embedding
                    ####
                    if is_save_image:
                        pred_label_to_save = pred_label.squeeze(0).cpu().numpy()
                        im = Image.fromarray(pred_label_to_save.astype('uint8')).convert('P')
                        im.putpalette(_palette)
                        imgname = str(current_frame_num)
                        while len(imgname) < 5:
                            imgname = '0' + imgname
                        if not os.path.exists(os.path.join(cfg.RESULT_ROOT, sequence,
                                                           'interactive' + str(n_interaction))):
                            os.makedirs(os.path.join(cfg.RESULT_ROOT, sequence,
                                                     'interactive' + str(n_interaction)))
                        im.save(os.path.join(cfg.RESULT_ROOT, sequence,
                                             'interactive' + str(n_interaction),
                                             imgname + '.png'))

                # Reorder: backward masks (reversed back to ascending) + forward masks.
                pred_masks_reverse.reverse()
                pred_masks_reverse.extend(pred_masks)
                final_masks = torch.cat(pred_masks_reverse, 0)
                sess.submit_masks(final_masks.cpu().numpy())

                t_end = timeit.default_timer()
                print('Total time for single interaction: ' + str(t_end - t_total))

            report = sess.get_report()
            summary = sess.get_global_summary(save_file=os.path.join(report_save_dir,
                                                                     'summary.json'))
def test_step(self, weights, parallel=True, is_save_image=True, **cfg):
    """Run interactive VOS inference with a Paddle model on a single video.

    For each scribble round from ``get_scribbles()``: build the scribble label
    at embedding resolution, run the interaction head on the annotated frame,
    propagate the predicted mask forward and backward over all frames, and
    submit the concatenated masks. Frame embeddings are extracted once on the
    first round and cached in ``embedding_memory``; ``prev_label_storage``
    keeps the latest per-frame predictions across rounds.

    Args:
        weights: path to the checkpoint to load.
        parallel: wrap the model in ``paddle.DataParallel`` and reach the head
            through ``model.children()``.
        is_save_image: also dump per-frame PNG masks under the output dir.
        **cfg: configuration (MODEL, PIPELINE, video_path, model_name,
            output_dir, save_path, knns, ...).
    """
    # 1. Construct model.
    cfg['MODEL'].head.pretrained = ''
    cfg['MODEL'].head.test_mode = True
    model = build_model(cfg['MODEL'])
    if parallel:
        model = paddle.DataParallel(model)

    # 2. Construct data.
    sequence = cfg["video_path"].split('/')[-1].split('.')[0]
    obj_nums = 1
    images, _ = load_video(cfg["video_path"], 480)
    print("stage1 load_video success")
    # [195, 389, 238, 47, 244, 374, 175, 399]
    # .shape: (502, 480, 600, 3)
    report_save_dir = cfg.get("output_dir",
                              f"./output/{cfg['model_name']}")
    if not os.path.exists(report_save_dir):
        os.makedirs(report_save_dir)
    # Configuration used in the challenges
    max_nb_interactions = 8  # Maximum number of interactions
    # Interactive parameters
    model.eval()

    # Load checkpoint weights into the head, skipping BN bookkeeping tensors.
    state_dicts_ = load(weights)['state_dict']
    state_dicts = {}
    for k, v in state_dicts_.items():
        if 'num_batches_tracked' not in k:
            state_dicts['head.' + k] = v
            if ('head.' + k) not in model.state_dict().keys():
                print(f'pretrained -----{k} -------is not in model')
    write_dict(state_dicts, 'model_for_infer.txt', **cfg)
    model.set_state_dict(state_dicts)

    inter_file = open(
        os.path.join(cfg.get("output_dir", f"./output/{cfg['model_name']}"),
                     'inter_file.txt'), 'w')
    seen_seq = False
    with paddle.no_grad():
        # Get the current iteration scribbles
        for scribbles, first_scribble in get_scribbles():
            t_total = timeit.default_timer()
            f, h, w = images.shape[:3]
            # Lazy one-time init of the per-frame prediction storage.
            if 'prev_label_storage' not in locals().keys():
                prev_label_storage = paddle.zeros([f, h, w])
            if len(annotated_frames(scribbles)) == 0:
                final_masks = prev_label_storage
                # TODO(AP-kai): save_path is passed in here
                submit_masks(cfg["save_path"], final_masks.numpy(), images)
                continue

            # if no scribbles return, keep masks in previous round
            start_annotated_frame = annotated_frames(scribbles)[0]

            pred_masks = []          # forward-propagated masks
            pred_masks_reverse = []  # backward-propagated masks

            if first_scribble:
                # If in the first round, initialize memories
                n_interaction = 1
                eval_global_map_tmp_dic = {}
                local_map_dics = ({}, {})
                total_frame_num = f
            else:
                n_interaction += 1
            inter_file.write(sequence + ' ' + 'interaction' +
                             str(n_interaction) + ' ' + 'frame' +
                             str(start_annotated_frame) + '\n')

            if first_scribble:
                # if in the first round, extract pixel embbedings.
                if not seen_seq:
                    seen_seq = True
                    inter_turn = 1
                    embedding_memory = []
                    places = paddle.set_device('cpu')  # NOTE(review): return value unused

                    # Extract and cache an embedding for every frame once.
                    for imgs in images:
                        if cfg['PIPELINE'].get('test'):
                            imgs = paddle.to_tensor([
                                build_pipeline(cfg['PIPELINE'].test)({
                                    'img1': imgs
                                })['img1']
                            ])
                        else:
                            imgs = paddle.to_tensor([imgs])
                        if parallel:
                            for c in model.children():
                                frame_embedding = c.head.extract_feature(imgs)
                        else:
                            frame_embedding = model.head.extract_feature(imgs)
                        embedding_memory.append(frame_embedding)
                        del frame_embedding
                    embedding_memory = paddle.concat(embedding_memory, 0)
                    _, _, emb_h, emb_w = embedding_memory.shape
                    ref_frame_embedding = embedding_memory[start_annotated_frame]
                    ref_frame_embedding = ref_frame_embedding.unsqueeze(0)
                else:
                    inter_turn += 1
                    ref_frame_embedding = embedding_memory[start_annotated_frame]
                    ref_frame_embedding = ref_frame_embedding.unsqueeze(0)
            else:
                ref_frame_embedding = embedding_memory[start_annotated_frame]
                ref_frame_embedding = ref_frame_embedding.unsqueeze(0)
            ########
            # Rasterize this round's scribble at embedding resolution.
            scribble_masks = scribbles2mask(scribbles, (emb_h, emb_w))
            scribble_label = scribble_masks[start_annotated_frame]
            scribble_sample = {'scribble_label': scribble_label}
            scribble_sample = ToTensor_manet()(scribble_sample)
            # print(ref_frame_embedding, ref_frame_embedding.shape)
            scribble_label = scribble_sample['scribble_label']

            scribble_label = scribble_label.unsqueeze(0)
            model_name = cfg['model_name']
            output_dir = cfg.get("output_dir", f"./output/{model_name}")
            inter_file_path = os.path.join(output_dir, sequence,
                                           'interactive' + str(n_interaction),
                                           'turn' + str(inter_turn))
            if is_save_image:
                ref_scribble_to_show = scribble_label.squeeze().numpy()
                im_ = Image.fromarray(
                    ref_scribble_to_show.astype('uint8')).convert('P', )
                im_.putpalette(_palette)
                ref_img_name = str(start_annotated_frame)

                if not os.path.exists(inter_file_path):
                    os.makedirs(inter_file_path)
                im_.save(
                    os.path.join(inter_file_path,
                                 'inter_' + ref_img_name + '.png'))
            if first_scribble:
                prev_label = None
                prev_label_storage = paddle.zeros([f, h, w])
            else:
                prev_label = prev_label_storage[start_annotated_frame]
                prev_label = prev_label.unsqueeze(0).unsqueeze(0)
            # check if no scribbles.
            if not first_scribble and paddle.unique(
                    scribble_label).shape[0] == 1:
                print(
                    'not first_scribble and paddle.unique(scribble_label).shape[0] == 1'
                )
                print(paddle.unique(scribble_label))
                final_masks = prev_label_storage
                submit_masks(cfg["save_path"], final_masks.numpy(), images)
                continue

            ###inteaction segmentation head
            if parallel:
                for c in model.children():
                    tmp_dic, local_map_dics = c.head.int_seghead(
                        ref_frame_embedding=ref_frame_embedding,
                        ref_scribble_label=scribble_label,
                        prev_round_label=prev_label,
                        global_map_tmp_dic=eval_global_map_tmp_dic,
                        local_map_dics=local_map_dics,
                        interaction_num=n_interaction,
                        seq_names=[sequence],
                        gt_ids=paddle.to_tensor([obj_nums]),
                        frame_num=[start_annotated_frame],
                        first_inter=first_scribble)
            else:
                tmp_dic, local_map_dics = model.head.int_seghead(
                    ref_frame_embedding=ref_frame_embedding,
                    ref_scribble_label=scribble_label,
                    prev_round_label=prev_label,
                    global_map_tmp_dic=eval_global_map_tmp_dic,
                    local_map_dics=local_map_dics,
                    interaction_num=n_interaction,
                    seq_names=[sequence],
                    gt_ids=paddle.to_tensor([obj_nums]),
                    frame_num=[start_annotated_frame],
                    first_inter=first_scribble)
            pred_label = tmp_dic[sequence]
            pred_label = nn.functional.interpolate(pred_label,
                                                   size=(h, w),
                                                   mode='bilinear',
                                                   align_corners=True)
            pred_label = paddle.argmax(pred_label, axis=1)
            pred_masks.append(float_(pred_label))
            # np.unique(pred_label)
            # array([0], dtype=int64)
            prev_label_storage[start_annotated_frame] = float_(
                pred_label[0])

            if is_save_image:  # save image
                pred_label_to_save = pred_label.squeeze(0).numpy()
                im = Image.fromarray(
                    pred_label_to_save.astype('uint8')).convert('P', )
                im.putpalette(_palette)
                imgname = str(start_annotated_frame)
                while len(imgname) < 5:
                    imgname = '0' + imgname
                if not os.path.exists(inter_file_path):
                    os.makedirs(inter_file_path)
                im.save(os.path.join(inter_file_path, imgname + '.png'))
            #######################################
            if first_scribble:
                scribble_label = rough_ROI(scribble_label)

            ############################## Forward propagation.
            ref_prev_label = pred_label.unsqueeze(0)
            prev_label = pred_label.unsqueeze(0)
            prev_embedding = ref_frame_embedding
            for ii in range(start_annotated_frame + 1, total_frame_num):
                current_embedding = embedding_memory[ii]
                current_embedding = current_embedding.unsqueeze(0)
                prev_label = prev_label
                if parallel:
                    for c in model.children():
                        tmp_dic, eval_global_map_tmp_dic, local_map_dics = c.head.prop_seghead(
                            ref_frame_embedding,
                            prev_embedding,
                            current_embedding,
                            scribble_label,
                            prev_label,
                            normalize_nearest_neighbor_distances=True,
                            use_local_map=True,
                            seq_names=[sequence],
                            gt_ids=paddle.to_tensor([obj_nums]),
                            k_nearest_neighbors=cfg['knns'],
                            global_map_tmp_dic=eval_global_map_tmp_dic,
                            local_map_dics=local_map_dics,
                            interaction_num=n_interaction,
                            start_annotated_frame=start_annotated_frame,
                            frame_num=[ii],
                            dynamic_seghead=c.head.dynamic_seghead)
                else:
                    tmp_dic, eval_global_map_tmp_dic, local_map_dics = model.head.prop_seghead(
                        ref_frame_embedding,
                        prev_embedding,
                        current_embedding,
                        scribble_label,
                        prev_label,
                        normalize_nearest_neighbor_distances=True,
                        use_local_map=True,
                        seq_names=[sequence],
                        gt_ids=paddle.to_tensor([obj_nums]),
                        k_nearest_neighbors=cfg['knns'],
                        global_map_tmp_dic=eval_global_map_tmp_dic,
                        local_map_dics=local_map_dics,
                        interaction_num=n_interaction,
                        start_annotated_frame=start_annotated_frame,
                        frame_num=[ii],
                        dynamic_seghead=model.head.dynamic_seghead)
                pred_label = tmp_dic[sequence]
                pred_label = nn.functional.interpolate(pred_label,
                                                       size=(h, w),
                                                       mode='bilinear',
                                                       align_corners=True)
                pred_label = paddle.argmax(pred_label, axis=1)
                pred_masks.append(float_(pred_label))
                prev_label = pred_label.unsqueeze(0)
                prev_embedding = current_embedding
                prev_label_storage[ii] = float_(pred_label[0])
                if is_save_image:
                    pred_label_to_save = pred_label.squeeze(0).numpy()
                    im = Image.fromarray(
                        pred_label_to_save.astype('uint8')).convert('P', )
                    im.putpalette(_palette)
                    imgname = str(ii)
                    while len(imgname) < 5:
                        imgname = '0' + imgname
                    if not os.path.exists(inter_file_path):
                        os.makedirs(inter_file_path)
                    im.save(os.path.join(inter_file_path, imgname + '.png'))
            #######################################
            # Reset to the annotated frame for the backward pass.
            prev_label = ref_prev_label
            prev_embedding = ref_frame_embedding
            #######
            # Propagation <-
            for ii in range(start_annotated_frame):
                current_frame_num = start_annotated_frame - 1 - ii
                current_embedding = embedding_memory[current_frame_num]
                current_embedding = current_embedding.unsqueeze(0)
                prev_label = prev_label
                if parallel:
                    for c in model.children():
                        tmp_dic, eval_global_map_tmp_dic, local_map_dics = c.head.prop_seghead(
                            ref_frame_embedding,
                            prev_embedding,
                            current_embedding,
                            scribble_label,
                            prev_label,
                            normalize_nearest_neighbor_distances=True,
                            use_local_map=True,
                            seq_names=[sequence],
                            gt_ids=paddle.to_tensor([obj_nums]),
                            k_nearest_neighbors=cfg['knns'],
                            global_map_tmp_dic=eval_global_map_tmp_dic,
                            local_map_dics=local_map_dics,
                            interaction_num=n_interaction,
                            start_annotated_frame=start_annotated_frame,
                            frame_num=[current_frame_num],
                            dynamic_seghead=c.head.dynamic_seghead)
                else:
                    tmp_dic, eval_global_map_tmp_dic, local_map_dics = model.head.prop_seghead(
                        ref_frame_embedding,
                        prev_embedding,
                        current_embedding,
                        scribble_label,
                        prev_label,
                        normalize_nearest_neighbor_distances=True,
                        use_local_map=True,
                        seq_names=[sequence],
                        gt_ids=paddle.to_tensor([obj_nums]),
                        k_nearest_neighbors=cfg['knns'],
                        global_map_tmp_dic=eval_global_map_tmp_dic,
                        local_map_dics=local_map_dics,
                        interaction_num=n_interaction,
                        start_annotated_frame=start_annotated_frame,
                        frame_num=[current_frame_num],
                        dynamic_seghead=model.head.dynamic_seghead)
                pred_label = tmp_dic[sequence]
                pred_label = nn.functional.interpolate(pred_label,
                                                       size=(h, w),
                                                       mode='bilinear',
                                                       align_corners=True)

                pred_label = paddle.argmax(pred_label, axis=1)
                pred_masks_reverse.append(float_(pred_label))
                prev_label = pred_label.unsqueeze(0)
                prev_embedding = current_embedding
                ####
                prev_label_storage[current_frame_num] = float_(
                    pred_label[0])
                ###
                if is_save_image:
                    pred_label_to_save = pred_label.squeeze(0).numpy()
                    im = Image.fromarray(
                        pred_label_to_save.astype('uint8')).convert('P', )
                    im.putpalette(_palette)
                    imgname = str(current_frame_num)
                    while len(imgname) < 5:
                        imgname = '0' + imgname
                    if not os.path.exists(inter_file_path):
                        os.makedirs(inter_file_path)
                    im.save(os.path.join(inter_file_path, imgname + '.png'))
            # Reorder: backward masks (reversed back to ascending) + forward masks.
            pred_masks_reverse.reverse()
            pred_masks_reverse.extend(pred_masks)
            final_masks = paddle.concat(pred_masks_reverse, 0)
            submit_masks(cfg["save_path"], final_masks.numpy(), images)

            t_end = timeit.default_timer()
            print('Total time for single interaction: ' + str(t_end - t_total))
    inter_file.close()
    return None