Пример #1
0
    def init_ref_frame_dic(self):
        """Build ``self.ref_frame_dic`` with one reference frame per sequence.

        For every sequence in ``self.seqs`` one of the scribble files
        ``001.json`` / ``002.json`` / ``003.json`` under
        ``<db_root_dir>/Scribbles/<seq>/`` is picked with ``np.random.choice``.
        The first annotated frame of that scribble becomes the reference frame.
        Its scribbles are rasterised with ``scribbles2mask`` at the frame's
        resolution.

        Each ``self.ref_frame_dic[seq]`` entry holds:
            ref_frame:      JPEG path of the frame, relative to ``db_root_dir``
            scribble_label: ``scribbles2mask`` output for that frame
            ref_frame_gt:   annotation PNG path, relative to ``db_root_dir``
            ref_frame_num:  index of the annotated frame

        Raises:
            FileNotFoundError: if OpenCV cannot read the reference JPEG.
        """
        self.ref_frame_dic = {}
        for seq in self.seqs:
            # Same RNG call as before, so seeded runs still pick the same files.
            selected_json = np.random.choice(['001.json', '002.json', '003.json'], 1)[0]
            scribble_file = os.path.join(self.db_root_dir, 'Scribbles', seq, selected_json)
            # Only the JSON parse needs the handle; release it immediately.
            with open(scribble_file) as f:
                scribble = json.load(f)

            scr_frame = annotated_frames(scribble)[0]
            # 5-digit zero padding; zfill cannot loop forever on longer numbers.
            scr_f = str(scr_frame).zfill(5)

            ref_frame_path = os.path.join('JPEGImages/480p', seq, scr_f + '.jpg')
            ref_frame_gt = os.path.join('Annotations/480p/', seq, scr_f + '.png')

            full_path = os.path.join(self.db_root_dir, ref_frame_path)
            ref_tmp = cv2.imread(full_path)
            if ref_tmp is None:
                # cv2.imread signals failure by returning None, not by raising.
                raise FileNotFoundError('Could not read reference frame: ' + full_path)
            h_, w_ = ref_tmp.shape[:2]
            scribble_masks = scribbles2mask(scribble, (h_, w_))
            scribble_label = scribble_masks[scr_frame]
            self.ref_frame_dic[seq] = {
                'ref_frame': ref_frame_path,
                'scribble_label': scribble_label,
                'ref_frame_gt': ref_frame_gt,
                'ref_frame_num': scr_frame,
            }
Пример #2
0
    def update_ref_frame_and_label(self, round_scribble=None, frame_num=None, prev_round_label_dic=None):
        """Refresh ``self.ref_frame_dic`` for every sequence from a new scribble round.

        Args:
            round_scribble: mapping ``seq -> DAVIS scribble dict`` for this round.
            frame_num: optional mapping ``seq -> frame index``. If given, it
                overrides the first annotated frame of the scribble, and the
                label taken is ``scribble_masks[0]``.
            prev_round_label_dic: optional mapping ``seq -> previous-round label``.
                If given, it is stored under ``'prev_round_label'``.

        Raises:
            FileNotFoundError: if OpenCV cannot read a reference JPEG.
        """
        for seq in self.seqs:
            scribble = round_scribble[seq]
            if frame_num is None:
                scr_frame = annotated_frames(scribble)[0]
            else:
                scr_frame = int(frame_num[seq])
            # 5-digit zero padding; zfill cannot loop forever on longer numbers.
            scr_f = str(scr_frame).zfill(5)
            ref_frame_path = os.path.join('JPEGImages/480p', seq, scr_f + '.jpg')
            ref_frame_gt = os.path.join('Annotations/480p/', seq, scr_f + '.png')

            full_path = os.path.join(self.db_root_dir, ref_frame_path)
            ref_tmp = cv2.imread(full_path)
            if ref_tmp is None:
                # cv2.imread signals failure by returning None, not by raising.
                raise FileNotFoundError('Could not read reference frame: ' + full_path)
            h_, w_ = ref_tmp.shape[:2]
            scribble_masks = scribbles2mask(scribble, (h_, w_))
            if frame_num is None:
                scribble_label = scribble_masks[scr_frame]
            else:
                # NOTE(review): index 0 rather than scr_frame; presumably the
                # scribble dict then carries only the selected frame — confirm.
                scribble_label = scribble_masks[0]

            entry = {
                'ref_frame': ref_frame_path,
                'scribble_label': scribble_label,
                'ref_frame_gt': ref_frame_gt,
                'ref_frame_num': scr_frame,
            }
            self.ref_frame_dic[seq] = entry
            if prev_round_label_dic is not None:
                entry['prev_round_label'] = prev_round_label_dic[seq]
Пример #3
0
    def Run_interaction(self, scribbles):
        """Apply one user interaction with the interaction network.

        The scribbles of frame ``scribbles['annotated_frame']`` are processed
        as follows:
          * rasterised at ``(self.height, self.width)``;
          * dilated with ``Dilate_mask(..., 1)``;
          * converted by ``ToCudaPN`` into ``self.tar_P`` / ``self.tar_N``.

        ``self.model_I`` is then run on that frame. Its first output overwrites
        ``self.all_E[:, frame]`` and its third output is kept in ``self.ref``.
        """
        frame = scribbles['annotated_frame']

        # DAVIS scribble dict -> per-frame label maps -> dilated map of this frame
        frame_mask = scribbles2mask(scribbles, (self.height, self.width))[frame]
        self.tar_P, self.tar_N = ToCudaPN(Dilate_mask(frame_mask, 1))

        frame_feats = self.all_F[:, :, frame]
        prev_estimate = self.all_E[:, frame]
        # The original code notes the output shape as [batch, 256,512,2].
        self.all_E[:, frame], _, self.ref = self.model_I(
            frame_feats, prev_estimate, self.tar_P, self.tar_N, self.dummy_M, [1, 0, 0, 0, 0])

        print('[MODEL: interaction network] User Interaction on {}'.format(frame))
Пример #4
0
    def Run_interaction(self, scribbles):
        """Apply one user interaction with the interaction network (no autograd).

        The scribbles of frame ``scribbles['annotated_frame']`` are processed
        as follows:
          * rasterised at the original image size ``(self.height, self.width)``;
          * dilated with ``Dilate_mask(..., 1)``;
          * turned into ``self.tar_P`` / ``self.tar_N`` via ``ToCudaPN``.

        Inside ``torch.no_grad()``, ``self.model_I`` is then run on that frame.
        Its first output overwrites ``self.all_E[:, frame]`` and its third
        output is kept in ``self.ref``.
        """
        frame = scribbles['annotated_frame']
        print('height:{} width:{} target:{}'.format(self.height, self.width, frame))

        # Rasterise at the image's original height/width; keep only this frame.
        frame_mask = scribbles2mask(scribbles, (self.height, self.width))[frame]
        self.tar_P, self.tar_N = ToCudaPN(Dilate_mask(frame_mask, 1))

        with torch.no_grad():
            frame_feats = self.all_F[:, :, frame]
            prev_estimate = self.all_E[:, frame]
            # The original code notes the output shape as [batch, 256,512,2].
            self.all_E[:, frame], _, self.ref = self.model_I(
                frame_feats, prev_estimate, self.tar_P, self.tar_N, self.dummy_M, [1, 0, 0, 0, 0])

        print('[MODEL: interaction network] User Interaction on {}'.format(frame))
Пример #5
0
    def update_ref_frame_and_label(self, round_scribble):
        """Replace ``self.ref_frame_dic`` with the reference frame of a new round.

        Args:
            round_scribble: DAVIS scribble dict for ``self.seq_name``. Its first
                annotated frame becomes the reference frame.

        Sets ``self.ref_frame_dic`` to ``{'ref_frame': <JPEG path relative to
        db_root_dir>, 'scribble_label': <scribbles2mask output for that frame>}``.

        Raises:
            FileNotFoundError: if OpenCV cannot read the reference JPEG.
        """
        scribble = round_scribble
        scr_frame = annotated_frames(scribble)[0]
        # 5-digit zero padding; zfill cannot loop forever on longer numbers.
        scr_f = str(scr_frame).zfill(5)

        ref_frame_path = os.path.join('JPEGImages/480p', self.seq_name, scr_f + '.jpg')
        full_path = os.path.join(self.db_root_dir, ref_frame_path)
        ref_tmp = cv2.imread(full_path)
        if ref_tmp is None:
            # cv2.imread signals failure by returning None, not by raising.
            raise FileNotFoundError('Could not read reference frame: ' + full_path)
        h_, w_ = ref_tmp.shape[:2]
        scribble_masks = scribbles2mask(scribble, (h_, w_))
        scribble_label = scribble_masks[scr_frame]
        self.ref_frame_dic = {'ref_frame': ref_frame_path, 'scribble_label': scribble_label}
Пример #6
0
    def to_mask(self, scribble):
        """Convert a DAVIS scribble dict into an object mask with the S2M network.

        Only the first frame that carries scribbles is used. For each object id
        ``ki`` in ``1..self.k``:
          * a positive map (pixels labelled ``ki``) and a negative map (pixels
            with any other label except -1, presumably "no scribble") are
            dilated with a 3x3 kernel;
          * both maps are padded with ``pad_divide_by(..., 16, ...)``;
          * they are concatenated with the buffered image of that frame and the
            current hard mask of object ``ki``, and fed to ``self.s2m_net``.

        The per-object sigmoid outputs are merged with
        ``aggregate_wbg(..., keep_bg=True, hard=True)``.

        Args:
            scribble: DAVIS-interactive scribble dict with a per-frame
                ``'scribbles'`` list. It is not modified.

        Returns:
            ``(mask, idx)``, where ``idx`` is the index of the frame whose
            scribbles were used. If every frame is empty, ``idx`` is the last
            frame index.

        Raises:
            ValueError: if ``scribble['scribbles']`` has no frames at all.
        """
        all_scr = scribble['scribbles']
        if len(all_scr) == 0:
            raise ValueError('scribble contains no frames')

        # First we select the only frame with scribble. Work on a shallow copy
        # so the caller's dict keeps its full per-frame list.
        for idx, s in enumerate(all_scr):
            if len(s) != 0:
                scribble = dict(scribble)
                scribble['scribbles'] = [s]
                break

        # Pass to DAVIS to change the path to an array
        scr_mask = scribbles2mask(scribble, (self.h, self.w))[0]

        # Run our S2M
        kernel = np.ones((3, 3), np.uint8)
        mask = torch.zeros((self.k, 1, self.nh, self.nw),
                           dtype=torch.float32,
                           device=self.device)
        for ki in range(1, self.k + 1):
            # np.bool was removed in NumPy 1.24; builtin bool is the same dtype.
            p_srb = (scr_mask == ki).astype(np.uint8)
            p_srb = cv2.dilate(p_srb, kernel).astype(bool)

            n_srb = ((scr_mask != ki) * (scr_mask != -1)).astype(np.uint8)
            n_srb = cv2.dilate(n_srb, kernel).astype(bool)

            Rs = torch.from_numpy(np.stack(
                [p_srb, n_srb], 0)).unsqueeze(0).float().to(self.device)
            Rs, _ = pad_divide_by(Rs, 16, Rs.shape[-2:])

            # Use hard mask because we train S2M with such
            inputs = torch.cat([
                self.processor.get_image_buffered(idx),
                (self.processor.masks[idx] == ki).to(
                    self.device).float().unsqueeze(0), Rs
            ], 1)
            mask[ki - 1] = torch.sigmoid(self.s2m_net(inputs))
        mask = aggregate_wbg(mask, keep_bg=True, hard=True)
        return mask, idx
Пример #7
0
    def _init_ref_frame_dic(self):
        """Build ``self.ref_frame_dic`` from each sequence's ``001.json`` scribble.

        For every sequence in ``self.seqs``, the first annotated frame of
        ``<db_root_dir>/Scribbles/<seq>/001.json`` becomes the reference frame.
        Its scribbles are rasterised with ``scribbles2mask`` at that frame's
        resolution.

        Each entry is ``{'ref_frame': <JPEG path relative to db_root_dir>,
        'scribble_label': <scribbles2mask output for that frame>}``.

        Raises:
            FileNotFoundError: if OpenCV cannot read a reference JPEG.
        """
        self.ref_frame_dic = {}
        for seq in self.seqs:
            scribble_file = os.path.join(self.db_root_dir, 'Scribbles', seq,
                                         '001.json')
            # Only the JSON parse needs the handle; release it immediately.
            with open(scribble_file) as f:
                scribble = json.load(f)

            scr_frame = annotated_frames(scribble)[0]
            # 5-digit zero padding; zfill cannot loop forever on longer numbers.
            scr_f = str(scr_frame).zfill(5)

            ref_frame_path = os.path.join('JPEGImages/480p', seq,
                                          scr_f + '.jpg')
            full_path = os.path.join(self.db_root_dir, ref_frame_path)
            ref_tmp = cv2.imread(full_path)
            if ref_tmp is None:
                # cv2.imread signals failure by returning None, not by raising.
                raise FileNotFoundError(
                    'Could not read reference frame: ' + full_path)
            h_, w_ = ref_tmp.shape[:2]
            scribble_masks = scribbles2mask(scribble, (h_, w_))
            scribble_label = scribble_masks[scr_frame]
            self.ref_frame_dic[seq] = {
                'ref_frame': ref_frame_path,
                'scribble_label': scribble_label
            }
Пример #8
0
def _frame_name(frame_idx):
    """Return the 5-digit zero-padded file stem for frame index *frame_idx*."""
    return str(frame_idx).zfill(5)


def _interaction_dir(sequence, n_interaction):
    """Return ``<RESULT_ROOT>/<sequence>/interactive<n>``, creating it if missing."""
    out_dir = os.path.join(cfg.RESULT_ROOT, sequence, 'interactive' + str(n_interaction))
    os.makedirs(out_dir, exist_ok=True)
    return out_dir


def _save_palette_png(label_array, out_path):
    """Save an integer label map as a palettised PNG using the module's ``_palette``."""
    im = Image.fromarray(label_array.astype('uint8')).convert('P')
    im.putpalette(_palette)
    im.save(out_path)


def main():
    """Evaluate the interactive VOS model on DAVIS-2017 val via a DavisInteractiveSession.

    Setup:
      * Reads the sequence list (``ImageSets/2017/val.txt``) and counts the
        JPEG frames of each sequence.
      * Loads per-sequence object counts from
        ``ImageSets/2017/v_a_l_instances.txt``.
      * Loads the checkpoint
        ``./saved_model_inter_net_new_re/save_step_80000.pth``.

    For every interaction served by the session:
      1. The annotated frame is segmented with ``model.int_seghead``, using
         the user scribble and, after the first round, the previous round's
         prediction read back from disk.
      2. The prediction is propagated forward to the last frame, then backward
         to frame 0, with ``model.before_seghead_process``.
      3. All frame masks are submitted.

    Frame embeddings are extracted only in the first round of a sequence and
    cached in ``embedding_memory`` for later rounds. Predicted and scribble
    PNGs are written under ``<RESULT_ROOT>/<sequence>/interactive<n>/``. A
    global summary is saved as ``summary.json`` in ``cfg.RESULT_ROOT``.
    """
    # ---- Dataset bookkeeping -------------------------------------------
    seqs = []
    with open(os.path.join(cfg.DATA_ROOT, 'ImageSets', '2017', 'val' + '.txt')) as f:
        seqs.extend(line.strip() for line in f.readlines())
    total_frame_num_dic = {
        seq_name: len(os.listdir(os.path.join(cfg.DATA_ROOT, 'JPEGImages/480p/', seq_name.strip())))
        for seq_name in seqs
    }
    # NOTE(review): the 'v_a_l' file name is kept verbatim; confirm it is intended.
    _seq_list_file = os.path.join(cfg.DATA_ROOT, 'ImageSets', '2017',
                                  'v_a_l' + '_instances.txt')
    with open(_seq_list_file, 'r') as f:
        seq_dict = json.load(f)

    # ---- Configuration -------------------------------------------------
    is_save_image = True
    report_save_dir = cfg.RESULT_ROOT
    save_res_dir = './saved_model_inter_net_new_re'
    # Configuration used in the challenges
    max_nb_interactions = 8  # Maximum number of interactions
    max_time_per_interaction = 10000  # Maximum time per interaction per object
    # Total time available to interact with a sequence and an initial set of scribbles
    max_time = max_nb_interactions * max_time_per_interaction  # Maximum time per object
    # Interactive parameters
    subset = 'val'
    host = 'localhost'  # 'localhost' for subsets train and val.

    # ---- Model ---------------------------------------------------------
    feature_extracter = DeepLab(backbone='resnet', freeze_bn=False)
    model = IntVOS(cfg, feature_extracter)
    model = model.cuda()
    print('model loading...')

    saved_model_dict = os.path.join(save_res_dir, 'save_step_80000.pth')
    pretrained_dict = torch.load(saved_model_dict)
    load_network(model, pretrained_dict)

    print('model loading finished!')
    model.eval()

    with torch.no_grad():
        with DavisInteractiveSession(host=host,
                                     davis_root=cfg.DATA_ROOT,
                                     subset=subset,
                                     report_save_dir=report_save_dir,
                                     max_nb_interactions=max_nb_interactions,
                                     max_time=max_time,
                                     metric_to_optimize='J') as sess:
            while sess.next():
                t_total = timeit.default_timer()
                # Get the current iteration scribbles
                sequence, scribbles, first_scribble = sess.get_scribbles(only_last=True)
                print(sequence)
                start_annotated_frame = annotated_frames(scribbles)[0]

                pred_masks = []
                pred_masks_reverse = []

                if first_scribble:
                    # New sequence: reset all per-sequence state.
                    n_interaction = 1
                    eval_global_map_tmp_dic = {}
                    local_map_dics = ({}, {})
                    total_frame_num = total_frame_num_dic[sequence]
                    obj_nums = seq_dict[sequence][-1]
                    eval_data_manager = DAVIS2017_Test_Manager(split='val', root=cfg.DATA_ROOT,
                                                               transform=tr.ToTensor(),
                                                               seq_name=sequence)
                else:
                    n_interaction += 1

                # ---- Reference (annotated) frame ----------------------
                print(start_annotated_frame)
                scr_f = _frame_name(start_annotated_frame)
                ref_img_path = os.path.join(cfg.DATA_ROOT, 'JPEGImages/480p', sequence, scr_f + '.jpg')
                ref_img = cv2.imread(ref_img_path)
                if ref_img is None:
                    # cv2.imread signals failure by returning None, not by raising.
                    raise FileNotFoundError('Could not read reference frame: ' + ref_img_path)
                h_, w_ = ref_img.shape[:2]
                ref_img = np.array(ref_img, dtype=np.float32)

                scribble_masks = scribbles2mask(scribbles, (h_, w_))
                scribble_label = scribble_masks[start_annotated_frame]
                sample = tr.ToTensor()({'ref_img': ref_img, 'scribble_label': scribble_label})
                ref_img = sample['ref_img'].unsqueeze(0).cuda()
                scribble_label = sample['scribble_label'].unsqueeze(0).cuda()

                # Always dump the rasterised user scribble for inspection.
                _save_palette_png(scribble_label.cpu().squeeze().numpy(),
                                  os.path.join(_interaction_dir(sequence, n_interaction),
                                               'inter_' + scr_f + '.png'))

                if first_scribble:
                    # Embeddings are extracted once per sequence and cached.
                    ref_frame_embedding = model.extract_feature(ref_img)
                    _, channel, emb_h, emb_w = ref_frame_embedding.size()
                    embedding_memory = torch.zeros((total_frame_num, channel, emb_h, emb_w)).cuda()
                    embedding_memory[start_annotated_frame] = ref_frame_embedding[0]
                    prev_label = None
                else:
                    ref_frame_embedding = embedding_memory[start_annotated_frame].unsqueeze(0)
                    # Previous round's prediction for this frame, read back from disk.
                    prev_label_path = os.path.join(cfg.RESULT_ROOT, sequence,
                                                   'interactive' + str(n_interaction - 1),
                                                   scr_f + '.png')
                    with Image.open(prev_label_path) as prev_im:
                        prev_label = np.array(prev_im, dtype=np.uint8)
                    prev_label = tr.ToTensor()({'label': prev_label})['label'].unsqueeze(0).cuda()

                # ---- Interaction head on the annotated frame ----------
                tmp_dic, local_map_dics = model.int_seghead(
                    ref_frame_embedding=ref_frame_embedding, ref_scribble_label=scribble_label,
                    prev_round_label=prev_label, global_map_tmp_dic=eval_global_map_tmp_dic,
                    local_map_dics=local_map_dics, interaction_num=n_interaction,
                    seq_names=[sequence], gt_ids=torch.Tensor([obj_nums]),
                    frame_num=[start_annotated_frame], first_inter=first_scribble)
                pred_label = nn.functional.interpolate(tmp_dic[sequence], size=(h_, w_),
                                                       mode='bilinear', align_corners=True)
                pred_label = torch.argmax(pred_label, dim=1)
                pred_masks.append(pred_label.float())
                if is_save_image:
                    _save_palette_png(pred_label.squeeze(0).cpu().numpy(),
                                      os.path.join(_interaction_dir(sequence, n_interaction),
                                                   _frame_name(start_annotated_frame) + '.png'))

                # ---- Propagation: forward to the end, then backward to 0 ----
                ref_prev_label = pred_label.unsqueeze(0)
                forward_frames = range(start_annotated_frame + 1, total_frame_num)
                backward_frames = range(start_annotated_frame - 1, -1, -1)
                for frame_order, sink in ((forward_frames, pred_masks),
                                          (backward_frames, pred_masks_reverse)):
                    # Each direction starts again from the annotated frame.
                    prev_label = ref_prev_label
                    prev_embedding = ref_frame_embedding
                    for ii in frame_order:
                        print('evaluating sequence:{} frame:{}'.format(sequence, ii))
                        img = eval_data_manager.get_image(ii)['img'].unsqueeze(0)
                        _, _, h, w = img.size()
                        img = img.cuda()

                        if first_scribble:
                            current_embedding = model.extract_feature(img)
                            embedding_memory[ii] = current_embedding[0]
                        else:
                            current_embedding = embedding_memory[ii].unsqueeze(0)

                        prev_label = prev_label.cuda()
                        tmp_dic, eval_global_map_tmp_dic, local_map_dics = model.before_seghead_process(
                            ref_frame_embedding, prev_embedding, current_embedding,
                            scribble_label, prev_label,
                            normalize_nearest_neighbor_distances=True, use_local_map=True,
                            seq_names=[sequence], gt_ids=torch.Tensor([obj_nums]),
                            k_nearest_neighbors=cfg.KNNS,
                            global_map_tmp_dic=eval_global_map_tmp_dic,
                            local_map_dics=local_map_dics,
                            interaction_num=n_interaction,
                            start_annotated_frame=start_annotated_frame,
                            frame_num=[ii], dynamic_seghead=model.dynamic_seghead)
                        pred_label = nn.functional.interpolate(tmp_dic[sequence], size=(h, w),
                                                               mode='bilinear', align_corners=True)
                        pred_label = torch.argmax(pred_label, dim=1)
                        sink.append(pred_label.float())
                        prev_label = pred_label.unsqueeze(0)
                        prev_embedding = current_embedding

                        if is_save_image:
                            _save_palette_png(pred_label.squeeze(0).cpu().numpy(),
                                              os.path.join(_interaction_dir(sequence, n_interaction),
                                                           _frame_name(ii) + '.png'))

                # Backward predictions were collected last-to-first; restore frame order.
                pred_masks_reverse.reverse()
                final_masks = torch.cat(pred_masks_reverse + pred_masks, 0)
                sess.submit_masks(final_masks.cpu().numpy())
                t_end = timeit.default_timer()
                print('Total time for single interaction: ' + str(t_end - t_total))
            report = sess.get_report()
            summary = sess.get_global_summary(save_file=os.path.join(report_save_dir, 'summary.json'))
Пример #9
0
    def test_step(self, weights, parallel=True, is_save_image=True, **cfg):
        # 1. Construct model.
        cfg['MODEL'].head.pretrained = ''
        cfg['MODEL'].head.test_mode = True
        model = build_model(cfg['MODEL'])
        if parallel:
            model = paddle.DataParallel(model)

        # 2. Construct data.
        sequence = cfg["video_path"].split('/')[-1].split('.')[0]
        obj_nums = 1
        images, _ = load_video(cfg["video_path"], 480)
        print("stage1 load_video success")
        # [195, 389, 238, 47, 244, 374, 175, 399]
        # .shape: (502, 480, 600, 3)
        report_save_dir = cfg.get("output_dir",
                                  f"./output/{cfg['model_name']}")
        if not os.path.exists(report_save_dir):
            os.makedirs(report_save_dir)
            # Configuration used in the challenges
        max_nb_interactions = 8  # Maximum number of interactions
        # Interactive parameters
        model.eval()

        state_dicts_ = load(weights)['state_dict']
        state_dicts = {}
        for k, v in state_dicts_.items():
            if 'num_batches_tracked' not in k:
                state_dicts['head.' + k] = v
                if ('head.' + k) not in model.state_dict().keys():
                    print(f'pretrained -----{k} -------is not in model')
        write_dict(state_dicts, 'model_for_infer.txt', **cfg)
        model.set_state_dict(state_dicts)
        inter_file = open(
            os.path.join(
                cfg.get("output_dir", f"./output/{cfg['model_name']}"),
                'inter_file.txt'), 'w')
        seen_seq = False

        with paddle.no_grad():

            # Get the current iteration scribbles
            for scribbles, first_scribble in get_scribbles():
                t_total = timeit.default_timer()
                f, h, w = images.shape[:3]
                if 'prev_label_storage' not in locals().keys():
                    prev_label_storage = paddle.zeros([f, h, w])
                if len(annotated_frames(scribbles)) == 0:
                    final_masks = prev_label_storage
                    # ToDo To AP-kai: save_path传过来了
                    submit_masks(cfg["save_path"], final_masks.numpy(), images)
                    continue

                # if no scribbles return, keep masks in previous round
                start_annotated_frame = annotated_frames(scribbles)[0]
                pred_masks = []
                pred_masks_reverse = []

                if first_scribble:  # If in the first round, initialize memories
                    n_interaction = 1
                    eval_global_map_tmp_dic = {}
                    local_map_dics = ({}, {})
                    total_frame_num = f

                else:
                    n_interaction += 1
                inter_file.write(sequence + ' ' + 'interaction' +
                                 str(n_interaction) + ' ' + 'frame' +
                                 str(start_annotated_frame) + '\n')

                if first_scribble:  # if in the first round, extract pixel embbedings.
                    if not seen_seq:
                        seen_seq = True
                        inter_turn = 1
                        embedding_memory = []
                        places = paddle.set_device('cpu')

                        for imgs in images:
                            if cfg['PIPELINE'].get('test'):
                                imgs = paddle.to_tensor([
                                    build_pipeline(cfg['PIPELINE'].test)({
                                        'img1':
                                        imgs
                                    })['img1']
                                ])
                            else:
                                imgs = paddle.to_tensor([imgs])
                            if parallel:
                                for c in model.children():
                                    frame_embedding = c.head.extract_feature(
                                        imgs)
                            else:
                                frame_embedding = model.head.extract_feature(
                                    imgs)
                            embedding_memory.append(frame_embedding)

                        del frame_embedding

                        embedding_memory = paddle.concat(embedding_memory, 0)
                        _, _, emb_h, emb_w = embedding_memory.shape
                        ref_frame_embedding = embedding_memory[
                            start_annotated_frame]
                        ref_frame_embedding = ref_frame_embedding.unsqueeze(0)
                    else:
                        inter_turn += 1
                        ref_frame_embedding = embedding_memory[
                            start_annotated_frame]
                        ref_frame_embedding = ref_frame_embedding.unsqueeze(0)

                else:
                    ref_frame_embedding = embedding_memory[
                        start_annotated_frame]
                    ref_frame_embedding = ref_frame_embedding.unsqueeze(0)
                ########
                scribble_masks = scribbles2mask(scribbles, (emb_h, emb_w))
                scribble_label = scribble_masks[start_annotated_frame]
                scribble_sample = {'scribble_label': scribble_label}
                scribble_sample = ToTensor_manet()(scribble_sample)
                #                     print(ref_frame_embedding, ref_frame_embedding.shape)
                scribble_label = scribble_sample['scribble_label']

                scribble_label = scribble_label.unsqueeze(0)
                model_name = cfg['model_name']
                output_dir = cfg.get("output_dir", f"./output/{model_name}")
                inter_file_path = os.path.join(
                    output_dir, sequence, 'interactive' + str(n_interaction),
                    'turn' + str(inter_turn))
                if is_save_image:
                    ref_scribble_to_show = scribble_label.squeeze().numpy()
                    im_ = Image.fromarray(
                        ref_scribble_to_show.astype('uint8')).convert('P', )
                    im_.putpalette(_palette)
                    ref_img_name = str(start_annotated_frame)

                    if not os.path.exists(inter_file_path):
                        os.makedirs(inter_file_path)
                    im_.save(
                        os.path.join(inter_file_path,
                                     'inter_' + ref_img_name + '.png'))
                if first_scribble:
                    prev_label = None
                    prev_label_storage = paddle.zeros([f, h, w])
                else:
                    prev_label = prev_label_storage[start_annotated_frame]
                    prev_label = prev_label.unsqueeze(0).unsqueeze(0)
                # check if no scribbles.
                if not first_scribble and paddle.unique(
                        scribble_label).shape[0] == 1:
                    print(
                        'not first_scribble and paddle.unique(scribble_label).shape[0] == 1'
                    )
                    print(paddle.unique(scribble_label))
                    final_masks = prev_label_storage
                    submit_masks(cfg["save_path"], final_masks.numpy(), images)
                    continue

                ###inteaction segmentation head
                if parallel:
                    for c in model.children():
                        tmp_dic, local_map_dics = c.head.int_seghead(
                            ref_frame_embedding=ref_frame_embedding,
                            ref_scribble_label=scribble_label,
                            prev_round_label=prev_label,
                            global_map_tmp_dic=eval_global_map_tmp_dic,
                            local_map_dics=local_map_dics,
                            interaction_num=n_interaction,
                            seq_names=[sequence],
                            gt_ids=paddle.to_tensor([obj_nums]),
                            frame_num=[start_annotated_frame],
                            first_inter=first_scribble)
                else:
                    tmp_dic, local_map_dics = model.head.int_seghead(
                        ref_frame_embedding=ref_frame_embedding,
                        ref_scribble_label=scribble_label,
                        prev_round_label=prev_label,
                        global_map_tmp_dic=eval_global_map_tmp_dic,
                        local_map_dics=local_map_dics,
                        interaction_num=n_interaction,
                        seq_names=[sequence],
                        gt_ids=paddle.to_tensor([obj_nums]),
                        frame_num=[start_annotated_frame],
                        first_inter=first_scribble)
                pred_label = tmp_dic[sequence]
                pred_label = nn.functional.interpolate(pred_label,
                                                       size=(h, w),
                                                       mode='bilinear',
                                                       align_corners=True)
                pred_label = paddle.argmax(pred_label, axis=1)
                pred_masks.append(float_(pred_label))
                # np.unique(pred_label)
                # array([0], dtype=int64)
                prev_label_storage[start_annotated_frame] = float_(
                    pred_label[0])

                if is_save_image:  # save image
                    pred_label_to_save = pred_label.squeeze(0).numpy()
                    im = Image.fromarray(
                        pred_label_to_save.astype('uint8')).convert('P', )
                    im.putpalette(_palette)
                    imgname = str(start_annotated_frame)
                    while len(imgname) < 5:
                        imgname = '0' + imgname
                    if not os.path.exists(inter_file_path):
                        os.makedirs(inter_file_path)
                    im.save(os.path.join(inter_file_path, imgname + '.png'))
                #######################################
                if first_scribble:
                    scribble_label = rough_ROI(scribble_label)

                ##############################
                ref_prev_label = pred_label.unsqueeze(0)
                prev_label = pred_label.unsqueeze(0)
                prev_embedding = ref_frame_embedding
                for ii in range(start_annotated_frame + 1, total_frame_num):
                    current_embedding = embedding_memory[ii]
                    current_embedding = current_embedding.unsqueeze(0)
                    prev_label = prev_label
                    if parallel:
                        for c in model.children():
                            tmp_dic, eval_global_map_tmp_dic, local_map_dics = c.head.prop_seghead(
                                ref_frame_embedding,
                                prev_embedding,
                                current_embedding,
                                scribble_label,
                                prev_label,
                                normalize_nearest_neighbor_distances=True,
                                use_local_map=True,
                                seq_names=[sequence],
                                gt_ids=paddle.to_tensor([obj_nums]),
                                k_nearest_neighbors=cfg['knns'],
                                global_map_tmp_dic=eval_global_map_tmp_dic,
                                local_map_dics=local_map_dics,
                                interaction_num=n_interaction,
                                start_annotated_frame=start_annotated_frame,
                                frame_num=[ii],
                                dynamic_seghead=c.head.dynamic_seghead)
                    else:
                        tmp_dic, eval_global_map_tmp_dic, local_map_dics = model.head.prop_seghead(
                            ref_frame_embedding,
                            prev_embedding,
                            current_embedding,
                            scribble_label,
                            prev_label,
                            normalize_nearest_neighbor_distances=True,
                            use_local_map=True,
                            seq_names=[sequence],
                            gt_ids=paddle.to_tensor([obj_nums]),
                            k_nearest_neighbors=cfg['knns'],
                            global_map_tmp_dic=eval_global_map_tmp_dic,
                            local_map_dics=local_map_dics,
                            interaction_num=n_interaction,
                            start_annotated_frame=start_annotated_frame,
                            frame_num=[ii],
                            dynamic_seghead=model.head.dynamic_seghead)
                    pred_label = tmp_dic[sequence]
                    pred_label = nn.functional.interpolate(pred_label,
                                                           size=(h, w),
                                                           mode='bilinear',
                                                           align_corners=True)
                    pred_label = paddle.argmax(pred_label, axis=1)
                    pred_masks.append(float_(pred_label))
                    prev_label = pred_label.unsqueeze(0)
                    prev_embedding = current_embedding
                    prev_label_storage[ii] = float_(pred_label[0])
                    if is_save_image:
                        pred_label_to_save = pred_label.squeeze(0).numpy()
                        im = Image.fromarray(
                            pred_label_to_save.astype('uint8')).convert('P', )
                        im.putpalette(_palette)
                        imgname = str(ii)
                        while len(imgname) < 5:
                            imgname = '0' + imgname
                        if not os.path.exists(inter_file_path):
                            os.makedirs(inter_file_path)
                        im.save(os.path.join(inter_file_path,
                                             imgname + '.png'))
                #######################################
                prev_label = ref_prev_label
                prev_embedding = ref_frame_embedding
                #######
                # Propagation <-
                for ii in range(start_annotated_frame):
                    current_frame_num = start_annotated_frame - 1 - ii
                    current_embedding = embedding_memory[current_frame_num]
                    current_embedding = current_embedding.unsqueeze(0)
                    prev_label = prev_label
                    if parallel:
                        for c in model.children():
                            tmp_dic, eval_global_map_tmp_dic, local_map_dics = c.head.prop_seghead(
                                ref_frame_embedding,
                                prev_embedding,
                                current_embedding,
                                scribble_label,
                                prev_label,
                                normalize_nearest_neighbor_distances=True,
                                use_local_map=True,
                                seq_names=[sequence],
                                gt_ids=paddle.to_tensor([obj_nums]),
                                k_nearest_neighbors=cfg['knns'],
                                global_map_tmp_dic=eval_global_map_tmp_dic,
                                local_map_dics=local_map_dics,
                                interaction_num=n_interaction,
                                start_annotated_frame=start_annotated_frame,
                                frame_num=[current_frame_num],
                                dynamic_seghead=c.head.dynamic_seghead)
                    else:
                        tmp_dic, eval_global_map_tmp_dic, local_map_dics = model.head.prop_seghead(
                            ref_frame_embedding,
                            prev_embedding,
                            current_embedding,
                            scribble_label,
                            prev_label,
                            normalize_nearest_neighbor_distances=True,
                            use_local_map=True,
                            seq_names=[sequence],
                            gt_ids=paddle.to_tensor([obj_nums]),
                            k_nearest_neighbors=cfg['knns'],
                            global_map_tmp_dic=eval_global_map_tmp_dic,
                            local_map_dics=local_map_dics,
                            interaction_num=n_interaction,
                            start_annotated_frame=start_annotated_frame,
                            frame_num=[current_frame_num],
                            dynamic_seghead=model.head.dynamic_seghead)
                    pred_label = tmp_dic[sequence]
                    pred_label = nn.functional.interpolate(pred_label,
                                                           size=(h, w),
                                                           mode='bilinear',
                                                           align_corners=True)

                    pred_label = paddle.argmax(pred_label, axis=1)
                    pred_masks_reverse.append(float_(pred_label))
                    prev_label = pred_label.unsqueeze(0)
                    prev_embedding = current_embedding
                    ####
                    prev_label_storage[current_frame_num] = float_(
                        pred_label[0])
                    ###
                    if is_save_image:
                        pred_label_to_save = pred_label.squeeze(0).numpy()
                        im = Image.fromarray(
                            pred_label_to_save.astype('uint8')).convert('P', )
                        im.putpalette(_palette)
                        imgname = str(current_frame_num)
                        while len(imgname) < 5:
                            imgname = '0' + imgname
                        if not os.path.exists(inter_file_path):
                            os.makedirs(inter_file_path)
                        im.save(os.path.join(inter_file_path,
                                             imgname + '.png'))
                pred_masks_reverse.reverse()
                pred_masks_reverse.extend(pred_masks)
                final_masks = paddle.concat(pred_masks_reverse, 0)
                submit_masks(cfg["save_path"], final_masks.numpy(), images)

                t_end = timeit.default_timer()
                print('Total time for single interaction: ' +
                      str(t_end - t_total))
        inter_file.close()
        return None
Пример #10
0
    def Run(self, variables):
        """Run one interaction round: refine every object from the new
        scribbles on the annotated frame, propagate it, and aggregate.

        For every object id ``n_obj`` in ``1..num_objects``:

        1. Build a per-object scribble map on the annotated frame ``target``.
           Strokes of this object are relabelled to id 1, strokes with
           ``object_id == 0`` are kept, and other objects' strokes are dropped.
        2. Run the interaction model ``self.model_I`` on frame ``target``.
        3. Propagate with ``self.model_P`` forward over frames
           ``target + 1 .. right_end`` and backward over
           ``target - 1 .. left_end``. Both bounds come from ``Get_weight``.
        4. Blend each frame's new estimate with the previous round's estimate
           using the per-frame ``weight`` returned by ``Get_weight``.

        The per-object estimates are then merged into a
        ``(num_objects + 1)``-way distribution by logit aggregation followed
        by a softmax. Index 0 is background (``prod(1 - masks)``).

        Args:
            variables: mutable state dict. Keys read here are ``'all_F'``,
                ``'info'`` (``'num_objs'``, ``'num_frames'``, ``'height'``,
                ``'width'``), ``'prev_targets'``, ``'scribbles'``
                (``'annotated_frame'``, ``'scribbles'``), ``'mask_objs'``,
                ``'ref'`` and ``'all_M'``.

        Side effects (results are written back into ``variables``):
            - ``'prev_targets'``: has ``target`` appended.
            - ``'ref'[n_obj - 1]``: replaced for every object.
            - ``'masks'``: per-frame argmax labels as float, shape
              ``(num_frames, height, width)``.
            - ``'mask_objs'``: per-object estimates, shape
              ``(num_objects, num_frames, height, width)``.
            - ``'probs'``: the softmaxed aggregate, shape
              ``(1, num_objects + 1, num_frames, height, width)``.
            - ``'loss'``: the sum of every ``batch_CE`` term divided by
              ``num_objects``.
        """
        all_F = variables['all_F']
        num_objects = variables['info']['num_objs']
        num_frames = variables['info']['num_frames']
        height = variables['info']['height']
        width = variables['info']['width']
        prev_targets = variables['prev_targets']
        scribbles = variables['scribbles']
        # Index of the frame the user annotated in this round.
        target = scribbles['annotated_frame']

        loss = 0
        # Per-object estimates for this round, filled object by object below.
        masks = torch.zeros(num_objects, num_frames, height, width)
        for n_obj in range(1, num_objects + 1):

            # Variables for the current object.
            # A slice is never None, so no None-guard is needed here (the
            # previous `x if x is not None else x` test was dead code).
            # Assuming mask_objs is a torch tensor (as this method stores it
            # below), `.data` detaches from autograd but still shares storage
            # with variables['mask_objs']. The in-place writes into all_E_n
            # therefore also touch that tensor. prev_E_n is cloned first so
            # it keeps the pre-round values.
            all_E_n = variables['mask_objs'][n_obj - 1:n_obj].data
            a_ref = variables['ref'][n_obj - 1]
            prev_E_n = all_E_n.clone()
            all_M_n = (variables['all_M'] == n_obj).long()

            # Divide the scribbles for the current object. Keep this object's
            # strokes (relabelled to id 1) plus strokes with object_id 0, and
            # drop strokes belonging to other objects.
            n_scribble = copy.deepcopy(scribbles)
            n_scribble['scribbles'][target] = []
            for p in scribbles['scribbles'][target]:
                if p['object_id'] == n_obj:
                    n_scribble['scribbles'][target].append(copy.deepcopy(p))
                    n_scribble['scribbles'][target][-1]['object_id'] = 1
                elif p['object_id'] == 0:
                    n_scribble['scribbles'][target].append(copy.deepcopy(p))

            scribble_mask = scribbles2mask(n_scribble, (height, width))[target]
            # Computed from the ORIGINAL zeros of scribble_mask, before they
            # are overwritten with -1. It selects pixels labelled 0 where the
            # previous estimate for this object exceeds 0.5.
            scribble_mask_N = (prev_E_n[0, target].cpu() >
                               0.5) & (torch.tensor(scribble_mask) == 0)
            scribble_mask[scribble_mask == 0] = -1
            scribble_mask[scribble_mask_N] = 0
            scribble_mask = Dilate_mask(scribble_mask, 1)

            # Interaction on the annotated frame.
            # NOTE(review): presumably ToCudaPN splits the map into positive /
            # negative scribble tensors on the GPU; confirm against its
            # definition.
            tar_P, tar_N = ToCudaPN(scribble_mask)
            all_E_n[:, target], batch_CE, ref = self.model_I(
                all_F[:, :, target], all_E_n[:, target], tar_P, tar_N,
                all_M_n[:, target], [1, 1, 1, 1, 1])  # [batch, 256,512,2]
            loss += batch_CE

            # Propagation range [left_end, right_end] and per-frame blend
            # weights.
            left_end, right_end, weight = Get_weight(target,
                                                     prev_targets,
                                                     num_frames,
                                                     at_least=-1)

            # Prop_forward: each frame is conditioned on the estimate of the
            # frame before it.
            next_a_ref = None
            for n in range(target + 1, right_end + 1):  #[1,2,...,N-1]
                all_E_n[:, n], batch_CE, next_a_ref = self.model_P(
                    ref, a_ref, all_F[:, :, n], prev_E_n[:, n],
                    all_E_n[:, n - 1], all_M_n[:, n], [1, 1, 1, 1, 1],
                    next_a_ref)
                loss += batch_CE

            # Prop_backward: each frame is conditioned on the estimate of the
            # frame after it.
            # NOTE(review): next_a_ref is carried over from the forward loop
            # (it is not reset here). Confirm this is intended.
            for n in reversed(range(left_end, target)):
                all_E_n[:, n], batch_CE, next_a_ref = self.model_P(
                    ref, a_ref, all_F[:, :, n], prev_E_n[:, n],
                    all_E_n[:, n + 1], all_M_n[:, n], [1, 1, 1, 1, 1],
                    next_a_ref)
                loss += batch_CE

            # Per-frame blend of this round's estimate with the previous one.
            for f in range(num_frames):
                all_E_n[:, f, :, :] = weight[f] * all_E_n[:, f, :, :] + (
                    1 - weight[f]) * prev_E_n[:, f, :, :]

            masks[n_obj - 1] = all_E_n[0]
            # NOTE(review): if neither propagation loop runs, next_a_ref is
            # still None and None is stored here. Confirm that is acceptable.
            variables['ref'][n_obj - 1] = next_a_ref

        loss /= num_objects

        # Aggregate the objects. Channel 0 is background = prod(1 - p_obj),
        # and channels 1.. hold the per-object estimates. Clamp the values,
        # convert them to logits, then softmax over the object axis.
        em = torch.zeros(1, num_objects + 1, num_frames, height,
                         width).to(masks.device)
        em[0, 0, :, :, :] = torch.prod(1 - masks, dim=0)  # bg prob
        em[0, 1:num_objects + 1, :, :] = masks  # obj prob
        em = torch.clamp(em, 1e-7, 1 - 1e-7)
        all_E = torch.log((em / (1 - em)))

        all_E = F.softmax(all_E, dim=1)
        # Index of the most likely channel per pixel (0 = background).
        final_masks = all_E[0].max(0)[1].float()

        variables['prev_targets'].append(target)
        variables['masks'] = final_masks
        variables['mask_objs'] = masks
        variables['probs'] = all_E
        variables['loss'] = loss