Ejemplo n.º 1
0
def _nearest_neighbor_features_per_object_in_chunks(reference_embeddings_flat,
                                                    query_embeddings_flat,
                                                    reference_labels_flat,
                                                    ref_obj_ids,
                                                    k_nearest_neighbors,
                                                    n_chunks, **cfg):
    """Computes per-object nearest neighbor features one query chunk at a time.

    The query embeddings are cut into `n_chunks` consecutive row slices, each
    slice is handed to `_nn_features_per_object_for_chunk`, and the per-slice
    results are joined along axis 0. Processing slices separately bounds the
    memory needed for the pairwise distances.

    Args:
      reference_embeddings_flat: Tensor of shape [n, embedding_dim],
        the embedding vectors for the reference frame.
      query_embeddings_flat: Tensor of shape [m, embedding_dim], the embedding
        vectors for the query frames.
      reference_labels_flat: Tensor of shape [n], the class labels of the
        reference frame.
      ref_obj_ids: int tensor of unique object ids in the reference labels.
      k_nearest_neighbors: Integer, the number of nearest neighbors to use.
      n_chunks: Integer, the number of chunks to use to save memory
        (set to 1 for no chunking).
      **cfg: Optional settings. When `test_mode` is truthy the reference
        labels and embeddings are first replaced by the output of
        `_selected_pixel`.
    Returns:
      nn_features: A float32 tensor of nearest neighbor features of shape
        [m, n_objects, feature_dim].
    """
    n_queries = float_(query_embeddings_flat.shape[0])
    chunk_size = int_(np.ceil((n_queries / n_chunks).numpy()))

    if cfg.get('test_mode'):
        reference_labels_flat, reference_embeddings_flat = _selected_pixel(
            reference_labels_flat, reference_embeddings_flat)

    # True wherever a reference label differs from an object id; the ids are
    # given an extra axis so the comparison broadcasts against the labels.
    wrong_label_mask = (reference_labels_flat != paddle.unsqueeze(
        ref_obj_ids, 1))

    # `ys` is whatever the previous chunk call handed back; the first chunk
    # starts without one.
    ys = None
    chunk_features = []
    for chunk_idx in range(n_chunks):
        if n_chunks == 1:
            query_chunk = query_embeddings_flat
        else:
            lo = chunk_idx * chunk_size
            hi = (chunk_idx + 1) * chunk_size
            query_chunk = query_embeddings_flat[lo:hi]
        features, ys = _nn_features_per_object_for_chunk(
            reference_embeddings_flat, query_chunk, wrong_label_mask,
            k_nearest_neighbors, ys)
        chunk_features.append(features)

    if n_chunks == 1:
        return chunk_features[0]
    return paddle.concat(chunk_features, axis=0)
Ejemplo n.º 2
0
def _nn_features_per_object_for_chunk(reference_embeddings, query_embeddings,
                                      wrong_label_mask, k_nearest_neighbors,
                                      ys):
    """Extracts features for each object using nearest neighbor attention.

    Args:
      reference_embeddings: Tensor of shape [n_chunk, embedding_dim],
        the embedding vectors for the reference frame.
      query_embeddings: Tensor of shape [m_chunk, embedding_dim], the embedding
        vectors for the query frames.
      wrong_label_mask: Mask that is true where a reference pixel does not
        carry an object's label; presumably of shape [n_objects, n_chunk]
        (TODO confirm against the caller).
      k_nearest_neighbors: Integer, the number of nearest neighbors to use.
      ys: Value forwarded to `_flattened_pairwise_distances` and replaced by
        what that helper returns; its meaning is defined by that helper.
        Pass None for the first chunk.
    Returns:
      A tuple `(nn_features, ys)` where `nn_features` is a tensor of nearest
      neighbor features of shape [m_chunk, n_objects, feature_dim] and `ys`
      is the value to pass to the next chunk call.
    """
    dists, ys = _flattened_pairwise_distances(reference_embeddings,
                                              query_embeddings, ys)

    # Add a large constant to every reference pixel with the wrong label so
    # it cannot win the nearest neighbor search for that object.
    dists = (paddle.unsqueeze(dists, 1) +
             paddle.unsqueeze(float_(wrong_label_mask), 0) *
             WRONG_LABEL_PADDING_DISTANCE)
    if k_nearest_neighbors == 1:
        features = paddle.min(dists, 2, keepdim=True)
    else:
        # The k smallest distances are the negated top-k of the negated
        # distances.
        neg_top_dists, _ = paddle.topk(-dists, k=k_nearest_neighbors, axis=2)
        dists = -neg_top_dists
        valid_mask = (dists < WRONG_LABEL_PADDING_DISTANCE)
        # Use the file's `float_` helper (as above) instead of the
        # torch-style `.float()` method.
        masked_dists = dists * float_(valid_mask)
        # paddle.max returns the reduced tensor itself, not a
        # (values, indices) pair, so it must not be indexed with [0]:
        # doing so would reuse the first query's padding for every query.
        pad_dist = paddle.max(masked_dists, axis=2, keepdim=True).tile(
            (1, 1, masked_dists.shape[-1]))
        # Replace padded (wrong-label) neighbors by the largest valid
        # distance so they do not blow up the mean.
        dists = paddle.where(valid_mask, dists, pad_dist)
        # take mean of distances
        features = paddle.mean(dists, axis=2, keepdim=True)

    return features, ys
Ejemplo n.º 3
0
    def int_seghead(self,
                    ref_frame_embedding=None,
                    ref_scribble_label=None,
                    prev_round_label=None,
                    normalize_nearest_neighbor_distances=True,
                    global_map_tmp_dic=None,
                    local_map_dics=None,
                    interaction_num=None,
                    seq_names=None,
                    gt_ids=None,
                    k_nearest_neighbors=1,
                    frame_num=None,
                    first_inter=True):
        """Runs the interaction segmentation head on the annotated frames.

        For every batch element the method:
          1. computes local nearest neighbor features of the reference frame
             against itself, using the rescaled scribble labels;
          2. when `global_map_tmp_dic` is given, stores the element-wise
             minimum of those features and the remembered map for this
             sequence/frame back into the dict (in place);
          3. when `local_map_dics` is given, creates the per-sequence local
             map memories if missing and zeroes the distance entry of this
             frame for the current interaction;
          4. feeds the frame embedding, the per-object scribble mask and the
             per-object previous-round mask to `self.inter_seghead`.

        Args:
          ref_frame_embedding: Tensor of shape [bs, c, h, w].
          ref_scribble_label: Scribble labels; resized to (h, w) with nearest
            interpolation. Presumably [bs, 1, H, W] (TODO confirm).
          prev_round_label: Labels of the previous interaction round; only
            read when `first_inter` is False.
          normalize_nearest_neighbor_distances: Unused in this method.
          global_map_tmp_dic: Optional dict, sequence name -> global map
            memory. Updated in place; skipped when None.
          local_map_dics: Optional tuple
            `(local_map_tmp_dic, local_map_dist_dic)` of per-sequence dicts.
          interaction_num: 1-based index of the current interaction round.
          seq_names: Sequence name for each batch element.
          gt_ids: Per batch element, the largest object id; ids
            0..gt_ids[n] are used.
          k_nearest_neighbors: Unused in this method.
          frame_num: Frame index for each batch element.
          first_inter: True for the first interaction round; the
            previous-round mask is then all zeros except for its first
            entry along axis 0, which is set to 1.

        Returns:
          `dic_tmp` (sequence name -> output of `self.inter_seghead`,
          transposed with [1, 0, 2, 3]) when `local_map_dics` is None,
          otherwise the tuple `(dic_tmp, local_map_dics)`.
        """
        dic_tmp = {}
        bs, _, h, w = ref_frame_embedding.shape
        scale_ref_scribble_label = int_(
            paddle.nn.functional.interpolate(float_(ref_scribble_label),
                                             size=(h, w),
                                             mode='nearest'))
        if not first_inter:
            scale_prev_round_label = int_(
                paddle.nn.functional.interpolate(float_(prev_round_label),
                                                 size=(h, w),
                                                 mode='nearest'))
        for n in range(bs):
            seq_name = seq_names[n]
            gt_id = int_(paddle.arange(0, gt_ids[n] + 1))

            # ---- Local distance map of the reference frame against itself.
            # Channels are moved last before calling the matching helper.
            seq_ref_frame_embedding = paddle.transpose(ref_frame_embedding[n],
                                                       [1, 2, 0])
            seq_ref_scribble_label = paddle.transpose(
                scale_ref_scribble_label[n], [1, 2, 0])
            nn_features_n = local_previous_frame_nearest_neighbor_features_per_object(
                prev_frame_embedding=seq_ref_frame_embedding,
                query_embedding=seq_ref_frame_embedding,
                prev_frame_labels=seq_ref_scribble_label,
                gt_ids=gt_id,
                max_distance=self.cfg['model_max_local_distance'])

            # ---- Global map update (skipped when no memory dict is given,
            # mirroring the guard used in prop_seghead).
            if global_map_tmp_dic is not None:
                if seq_name not in global_map_tmp_dic:
                    global_map_tmp_dic[seq_name] = paddle.ones_like(
                        nn_features_n).tile([1000, 1, 1, 1, 1])
                stored_map = global_map_tmp_dic[seq_name][
                    frame_num[n]].unsqueeze(0)
                # Keep the smaller of the new and the remembered distance.
                nn_features_n_ = paddle.where(nn_features_n <= stored_map,
                                              nn_features_n, stored_map)
                global_map_tmp_dic[seq_name][
                    frame_num[n]] = nn_features_n_.detach()[0]

            # ---- Local map update.
            if local_map_dics is not None:
                local_map_tmp_dic, local_map_dist_dic = local_map_dics
                if seq_name not in local_map_dist_dic:
                    local_map_dist_dic[seq_name] = paddle.zeros([1000, 9])
                if seq_name not in local_map_tmp_dic:
                    local_map_tmp_dic[seq_name] = paddle.ones_like(
                        nn_features_n).unsqueeze(0).tile([1000, 9, 1, 1, 1, 1])
                # Index with a single tuple: a chained `t[i][j] = v` may
                # assign into a temporary produced by `t[i]` instead of the
                # stored tensor.
                local_map_dist_dic[seq_name][frame_num[n],
                                             interaction_num - 1] = 0

                local_map_dics = (local_map_tmp_dic, local_map_dist_dic)

            # ---- Assemble the head input: one copy of the frame embedding
            # per object id plus the two per-object masks.
            to_cat_current_frame_embedding = ref_frame_embedding[n].unsqueeze(
                0).tile((gt_id.shape[0], 1, 1, 1))

            # Labels are compared to the ids as floats; the result is then
            # rearranged with transpose([2, 3, 0, 1]).
            to_cat_scribble_mask_to_cat = (
                float_(seq_ref_scribble_label) == float_(gt_id))
            to_cat_scribble_mask_to_cat = float_(
                to_cat_scribble_mask_to_cat.unsqueeze(-1).transpose(
                    [2, 3, 0, 1]))
            if not first_inter:
                seq_prev_round_label = scale_prev_round_label[n].transpose(
                    [1, 2, 0])
                to_cat_prev_round_to_cat = (
                    float_(seq_prev_round_label) == float_(gt_id))
                to_cat_prev_round_to_cat = float_(
                    to_cat_prev_round_to_cat.unsqueeze(-1).transpose(
                        [2, 3, 0, 1]))
            else:
                # No previous round: all zeros except the first entry
                # along axis 0, which is filled with ones.
                to_cat_prev_round_to_cat = paddle.zeros_like(
                    to_cat_scribble_mask_to_cat)
                to_cat_prev_round_to_cat[0] = 1.

            to_cat = paddle.concat(
                (to_cat_current_frame_embedding, to_cat_scribble_mask_to_cat,
                 to_cat_prev_round_to_cat), 1)

            pred_ = self.inter_seghead(to_cat)
            pred_ = pred_.transpose([1, 0, 2, 3])
            dic_tmp[seq_name] = pred_

        if local_map_dics is None:
            return dic_tmp
        return dic_tmp, local_map_dics
Ejemplo n.º 4
0
    def prop_seghead(
        self,
        ref_frame_embedding=None,
        previous_frame_embedding=None,
        current_frame_embedding=None,
        ref_scribble_label=None,
        previous_frame_mask=None,
        normalize_nearest_neighbor_distances=True,
        use_local_map=True,
        seq_names=None,
        gt_ids=None,
        k_nearest_neighbors=1,
        global_map_tmp_dic=None,
        local_map_dics=None,
        interaction_num=None,
        start_annotated_frame=None,
        frame_num=None,
        dynamic_seghead=None,
    ):
        """Propagates the segmentation from the previous frame to the current one.

        For every batch element the method:
          1. matches the current frame against the reference (scribbled)
             frame with `nearest_neighbor_features_per_object` (global map);
          2. when `global_map_tmp_dic` is given, keeps the element-wise
             minimum of the new global map and the remembered one and stores
             it back (in place);
          3. matches the current frame against the previous frame, either
             locally (`use_local_map`) or globally (local map);
          4. when `local_map_dics` is given, stores the local map and a
             weight `1 / |frame - start_annotated_frame|` for this
             interaction, then uses the stored map of the current or the
             previous interaction, whichever has the larger weight (the
             previous one wins ties);
          5. feeds embedding, both maps and the per-object previous-frame
             mask to `dynamic_seghead`.

        Args:
          ref_frame_embedding: Embedding of the annotated frame, [bs, c, h, w].
          previous_frame_embedding: Embedding of the previous frame.
          current_frame_embedding: Embedding of the frame to segment,
            [bs, c, h, w].
          ref_scribble_label: Scribble labels of the annotated frame. Used
            as-is when `self.cfg` has a truthy `test_mode`, otherwise resized
            to (h, w) with nearest interpolation.
          previous_frame_mask: Mask of the previous frame; resized to (h, w).
          normalize_nearest_neighbor_distances: When true, maps the global
            features through `(sigmoid(x) - 0.5) * 2`.
          use_local_map: Selects local vs. global matching against the
            previous frame.
          seq_names: Sequence name for each batch element.
          gt_ids: Per batch element value forwarded to
            `nearest_neighbor_features_per_object`.
          k_nearest_neighbors: Number of nearest neighbors for the global
            matching.
          global_map_tmp_dic: Optional dict, sequence name -> global map
            memory; updated in place.
          local_map_dics: Optional tuple
            `(local_map_tmp_dic, local_map_dist_dic)` of per-sequence dicts.
          interaction_num: 1-based index of the current interaction round.
          start_annotated_frame: Index of the annotated frame.
          frame_num: Frame index for each batch element.
          dynamic_seghead: Callable applied to the concatenated features.

        Returns:
          `dic_tmp` (sequence name -> prediction transposed with
          [1, 0, 2, 3]) when `global_map_tmp_dic` is None;
          `(dic_tmp, global_map_tmp_dic)` when only the global memory is
          given; `(dic_tmp, global_map_tmp_dic, local_map_dics)` when both
          memories are given.
        """
        cfg = self.cfg
        dic_tmp = {}
        bs, _, h, w = current_frame_embedding.shape
        if cfg.get('test_mode'):
            scale_ref_scribble_label = float_(ref_scribble_label)
        else:
            scale_ref_scribble_label = paddle.nn.functional.interpolate(
                float_(ref_scribble_label), size=(h, w), mode='nearest')
        scale_ref_scribble_label = int_(scale_ref_scribble_label)
        scale_previous_frame_label = int_(
            paddle.nn.functional.interpolate(float_(previous_frame_mask),
                                             size=(h, w),
                                             mode='nearest'))
        for n in range(bs):
            seq_name = seq_names[n]
            # Channels are moved last before calling the matching helpers.
            seq_ref_frame_embedding = ref_frame_embedding[n].transpose(
                [1, 2, 0])
            seq_current_frame_embedding = current_frame_embedding[n].transpose(
                [1, 2, 0])
            seq_prev_frame_embedding = previous_frame_embedding[n].transpose(
                [1, 2, 0])
            seq_ref_scribble_label = scale_ref_scribble_label[n].transpose(
                [1, 2, 0])
            seq_previous_frame_label = scale_previous_frame_label[n].transpose(
                [1, 2, 0])

            # ---- Global map: current frame vs. annotated reference frame.
            nn_features_n, ref_obj_ids = nearest_neighbor_features_per_object(
                reference_embeddings=seq_ref_frame_embedding,
                query_embeddings=seq_current_frame_embedding,
                reference_labels=seq_ref_scribble_label,
                k_nearest_neighbors=k_nearest_neighbors,
                gt_ids=gt_ids[n],
                n_chunks=10)
            if normalize_nearest_neighbor_distances:
                nn_features_n = (paddle.nn.functional.sigmoid(nn_features_n) -
                                 0.5) * 2

            if global_map_tmp_dic is not None:  # when testing, use global map memory
                if seq_name not in global_map_tmp_dic:
                    global_map_tmp_dic[seq_name] = paddle.ones_like(
                        nn_features_n).tile([1000, 1, 1, 1, 1])
                stored_map = global_map_tmp_dic[seq_name][
                    frame_num[n]].unsqueeze(0)
                # Keep the smaller of the new and the remembered distance.
                nn_features_n = paddle.where(nn_features_n <= stored_map,
                                             nn_features_n, stored_map)
                global_map_tmp_dic[seq_name][
                    frame_num[n]] = nn_features_n.detach()[0]

            # ---- Local map: current frame vs. previous frame.
            if use_local_map:
                prev_frame_nn_features_n = local_previous_frame_nearest_neighbor_features_per_object(
                    prev_frame_embedding=seq_prev_frame_embedding,
                    query_embedding=seq_current_frame_embedding,
                    prev_frame_labels=seq_previous_frame_label,
                    gt_ids=ref_obj_ids,
                    max_distance=cfg['model_max_local_distance'])
            else:
                prev_frame_nn_features_n, _ = nearest_neighbor_features_per_object(
                    reference_embeddings=seq_prev_frame_embedding,
                    query_embeddings=seq_current_frame_embedding,
                    reference_labels=seq_previous_frame_label,
                    k_nearest_neighbors=k_nearest_neighbors,
                    gt_ids=gt_ids[n],
                    n_chunks=20)
                prev_frame_nn_features_n = (
                    paddle.nn.functional.sigmoid(prev_frame_nn_features_n) -
                    0.5) * 2

            if local_map_dics is not None:  # when testing, use local map memory
                local_map_tmp_dic, local_map_dist_dic = local_map_dics
                if seq_name not in local_map_dist_dic:
                    print(seq_names[n], 'not in local_map_dist_dic')
                    # paddle.zeros takes the shape as one list argument (the
                    # second positional argument is the dtype).
                    local_map_dist_dic[seq_name] = paddle.zeros([1000, 9])
                if seq_name not in local_map_tmp_dic:
                    print(seq_names[n], 'not in local_map_tmp_dic')
                    local_map_tmp_dic[seq_name] = paddle.zeros_like(
                        prev_frame_nn_features_n).unsqueeze(0).tile(
                            [1000, 9, 1, 1, 1, 1])

                cur_idx = interaction_num - 1
                # Weight of this interaction for this frame: the closer the
                # frame is to the annotated one, the larger the weight.
                # NOTE(review): divides by zero if frame_num[n] equals
                # start_annotated_frame — confirm callers never propagate
                # onto the annotated frame itself.
                local_map_dist_dic[seq_name][frame_num[n], cur_idx] = 1.0 / (
                    abs(frame_num[n] - start_annotated_frame))
                local_map_tmp_dic[seq_name][
                    frame_num[n],
                    cur_idx] = prev_frame_nn_features_n.squeeze(0).detach()

                # Use the current interaction's map unless the previous
                # interaction has an equal or larger weight for this frame.
                frame_weights = local_map_dist_dic[seq_name][frame_num[n]]
                if interaction_num == 1 or \
                        frame_weights[cur_idx] > frame_weights[cur_idx - 1]:
                    chosen_idx = cur_idx
                else:
                    chosen_idx = cur_idx - 1
                prev_frame_nn_features_n = local_map_tmp_dic[seq_name][
                    frame_num[n]][chosen_idx].unsqueeze(0)

                local_map_dics = (local_map_tmp_dic, local_map_dist_dic)

            # ---- Assemble the head input, one entry per object id.
            to_cat_previous_frame = (
                float_(seq_previous_frame_label) == float_(ref_obj_ids))

            to_cat_current_frame_embedding = current_frame_embedding[
                n].unsqueeze(0).tile((ref_obj_ids.shape[0], 1, 1, 1))

            to_cat_nn_feature_n = nn_features_n.squeeze(0).transpose(
                [2, 3, 0, 1])
            to_cat_previous_frame = float_(
                to_cat_previous_frame.unsqueeze(-1).transpose([2, 3, 0, 1]))
            to_cat_prev_frame_nn_feature_n = prev_frame_nn_features_n.squeeze(
                0).transpose([2, 3, 0, 1])
            to_cat = paddle.concat(
                (to_cat_current_frame_embedding, to_cat_nn_feature_n,
                 to_cat_prev_frame_nn_feature_n, to_cat_previous_frame), 1)
            pred_ = dynamic_seghead(to_cat)
            pred_ = pred_.transpose([1, 0, 2, 3])
            dic_tmp[seq_name] = pred_

        if global_map_tmp_dic is None:
            return dic_tmp
        if local_map_dics is None:
            return dic_tmp, global_map_tmp_dic
        return dic_tmp, global_map_tmp_dic, local_map_dics
Ejemplo n.º 5
0
def local_previous_frame_nearest_neighbor_features_per_object(
        prev_frame_embedding,
        query_embedding,
        prev_frame_labels,
        gt_ids,
        max_distance=12):
    """Computes per-object nearest neighbor features from local matches only.

    Distances come from `local_pairwise_distances2`. For every object id a
    mask of label matches is built per offset; distances at offsets whose
    mask is false are replaced by 1, and the minimum over the offsets is
    taken.

    Args:
      prev_frame_embedding: Tensor of shape [height, width, embedding_dim],
        the embedding vectors for the last frame.
      query_embedding: Tensor of shape [height, width, embedding_dim],
        the embedding vectors for the query frames.
      prev_frame_labels: Tensor of shape [height, width, 1], the class labels
        of the previous frame.
      gt_ids: Int Tensor of shape [n_objs] of the sorted unique ground truth
        ids in the first frame.
      max_distance: Integer, the maximum distance allowed for local matching.
    Returns:
      nn_features: Tensor of nearest neighbor features reshaped to
        [1, height, width, n_objs, 1].
    """
    d = local_pairwise_distances2(query_embedding,
                                  prev_frame_embedding,
                                  max_distance=max_distance)
    height, width = prev_frame_embedding.shape[:2]
    n_objs = gt_ids.shape[0]

    if MODEL_UNFOLD:
        # Unfold path: pad the label map by 2 * max_distance on every side
        # and extract (height, width) windows with stride 2.
        margin = 2 * max_distance
        labels = float_(prev_frame_labels).transpose([2, 0, 1]).unsqueeze(0)
        padded_labels = F.pad(labels, (margin, margin, margin, margin))
        unfolded = F.unfold(padded_labels,
                            kernel_sizes=[height, width],
                            strides=[2, 2])
        offset_labels = unfolded.reshape([height, width, -1, 1])
        ids = float_(gt_ids).unsqueeze(0).unsqueeze(0).unsqueeze(0)
        offset_masks = paddle.equal(offset_labels, ids)
    else:
        # Explicit path: pad the per-object masks and stack one shifted
        # (height, width) crop per offset inside the search window.
        masks = paddle.equal(prev_frame_labels,
                             gt_ids.unsqueeze(0).unsqueeze(0))
        pad_spec = (0, 0, max_distance, max_distance, max_distance,
                    max_distance)
        padded_masks = nn.functional.pad(masks, pad_spec)
        window = 2 * max_distance + 1
        shifted_masks = []
        for y_start in range(window):
            rows = padded_masks[y_start:y_start + height]
            shifted_masks.extend(rows[:, x_start:x_start + width]
                                 for x_start in range(window))
        offset_masks = paddle.stack(shifted_masks, axis=2)

    # Repeat the distances once per object, replace the distance by 1
    # wherever the offset mask is false, then keep the minimum over the
    # offsets.
    d_tiled = d.unsqueeze(-1).tile((1, 1, 1, n_objs))
    d_masked = paddle.where(offset_masks, d_tiled, paddle.ones_like(d_tiled))
    nearest = paddle.min(d_masked, axis=2)
    return nearest.reshape([1, height, width, n_objs, 1])
Ejemplo n.º 6
0
    def test_step(self, weights, parallel=True, is_save_image=True, **cfg):
        # 1. Construct model.
        cfg['MODEL'].head.pretrained = ''
        cfg['MODEL'].head.test_mode = True
        model = build_model(cfg['MODEL'])
        if parallel:
            model = paddle.DataParallel(model)

        # 2. Construct data.
        sequence = cfg["video_path"].split('/')[-1].split('.')[0]
        obj_nums = 1
        images, _ = load_video(cfg["video_path"], 480)
        print("stage1 load_video success")
        # [195, 389, 238, 47, 244, 374, 175, 399]
        # .shape: (502, 480, 600, 3)
        report_save_dir = cfg.get("output_dir",
                                  f"./output/{cfg['model_name']}")
        if not os.path.exists(report_save_dir):
            os.makedirs(report_save_dir)
            # Configuration used in the challenges
        max_nb_interactions = 8  # Maximum number of interactions
        # Interactive parameters
        model.eval()

        state_dicts_ = load(weights)['state_dict']
        state_dicts = {}
        for k, v in state_dicts_.items():
            if 'num_batches_tracked' not in k:
                state_dicts['head.' + k] = v
                if ('head.' + k) not in model.state_dict().keys():
                    print(f'pretrained -----{k} -------is not in model')
        write_dict(state_dicts, 'model_for_infer.txt', **cfg)
        model.set_state_dict(state_dicts)
        inter_file = open(
            os.path.join(
                cfg.get("output_dir", f"./output/{cfg['model_name']}"),
                'inter_file.txt'), 'w')
        seen_seq = False

        with paddle.no_grad():

            # Get the current iteration scribbles
            for scribbles, first_scribble in get_scribbles():
                t_total = timeit.default_timer()
                f, h, w = images.shape[:3]
                if 'prev_label_storage' not in locals().keys():
                    prev_label_storage = paddle.zeros([f, h, w])
                if len(annotated_frames(scribbles)) == 0:
                    final_masks = prev_label_storage
                    # ToDo To AP-kai: save_path传过来了
                    submit_masks(cfg["save_path"], final_masks.numpy(), images)
                    continue

                # if no scribbles return, keep masks in previous round
                start_annotated_frame = annotated_frames(scribbles)[0]
                pred_masks = []
                pred_masks_reverse = []

                if first_scribble:  # If in the first round, initialize memories
                    n_interaction = 1
                    eval_global_map_tmp_dic = {}
                    local_map_dics = ({}, {})
                    total_frame_num = f

                else:
                    n_interaction += 1
                inter_file.write(sequence + ' ' + 'interaction' +
                                 str(n_interaction) + ' ' + 'frame' +
                                 str(start_annotated_frame) + '\n')

                if first_scribble:  # if in the first round, extract pixel embbedings.
                    if not seen_seq:
                        seen_seq = True
                        inter_turn = 1
                        embedding_memory = []
                        places = paddle.set_device('cpu')

                        for imgs in images:
                            if cfg['PIPELINE'].get('test'):
                                imgs = paddle.to_tensor([
                                    build_pipeline(cfg['PIPELINE'].test)({
                                        'img1':
                                        imgs
                                    })['img1']
                                ])
                            else:
                                imgs = paddle.to_tensor([imgs])
                            if parallel:
                                for c in model.children():
                                    frame_embedding = c.head.extract_feature(
                                        imgs)
                            else:
                                frame_embedding = model.head.extract_feature(
                                    imgs)
                            embedding_memory.append(frame_embedding)

                        del frame_embedding

                        embedding_memory = paddle.concat(embedding_memory, 0)
                        _, _, emb_h, emb_w = embedding_memory.shape
                        ref_frame_embedding = embedding_memory[
                            start_annotated_frame]
                        ref_frame_embedding = ref_frame_embedding.unsqueeze(0)
                    else:
                        inter_turn += 1
                        ref_frame_embedding = embedding_memory[
                            start_annotated_frame]
                        ref_frame_embedding = ref_frame_embedding.unsqueeze(0)

                else:
                    ref_frame_embedding = embedding_memory[
                        start_annotated_frame]
                    ref_frame_embedding = ref_frame_embedding.unsqueeze(0)
                ########
                scribble_masks = scribbles2mask(scribbles, (emb_h, emb_w))
                scribble_label = scribble_masks[start_annotated_frame]
                scribble_sample = {'scribble_label': scribble_label}
                scribble_sample = ToTensor_manet()(scribble_sample)
                #                     print(ref_frame_embedding, ref_frame_embedding.shape)
                scribble_label = scribble_sample['scribble_label']

                scribble_label = scribble_label.unsqueeze(0)
                model_name = cfg['model_name']
                output_dir = cfg.get("output_dir", f"./output/{model_name}")
                inter_file_path = os.path.join(
                    output_dir, sequence, 'interactive' + str(n_interaction),
                    'turn' + str(inter_turn))
                if is_save_image:
                    ref_scribble_to_show = scribble_label.squeeze().numpy()
                    im_ = Image.fromarray(
                        ref_scribble_to_show.astype('uint8')).convert('P', )
                    im_.putpalette(_palette)
                    ref_img_name = str(start_annotated_frame)

                    if not os.path.exists(inter_file_path):
                        os.makedirs(inter_file_path)
                    im_.save(
                        os.path.join(inter_file_path,
                                     'inter_' + ref_img_name + '.png'))
                if first_scribble:
                    prev_label = None
                    prev_label_storage = paddle.zeros([f, h, w])
                else:
                    prev_label = prev_label_storage[start_annotated_frame]
                    prev_label = prev_label.unsqueeze(0).unsqueeze(0)
                # check if no scribbles.
                if not first_scribble and paddle.unique(
                        scribble_label).shape[0] == 1:
                    print(
                        'not first_scribble and paddle.unique(scribble_label).shape[0] == 1'
                    )
                    print(paddle.unique(scribble_label))
                    final_masks = prev_label_storage
                    submit_masks(cfg["save_path"], final_masks.numpy(), images)
                    continue

                ###inteaction segmentation head
                if parallel:
                    for c in model.children():
                        tmp_dic, local_map_dics = c.head.int_seghead(
                            ref_frame_embedding=ref_frame_embedding,
                            ref_scribble_label=scribble_label,
                            prev_round_label=prev_label,
                            global_map_tmp_dic=eval_global_map_tmp_dic,
                            local_map_dics=local_map_dics,
                            interaction_num=n_interaction,
                            seq_names=[sequence],
                            gt_ids=paddle.to_tensor([obj_nums]),
                            frame_num=[start_annotated_frame],
                            first_inter=first_scribble)
                else:
                    tmp_dic, local_map_dics = model.head.int_seghead(
                        ref_frame_embedding=ref_frame_embedding,
                        ref_scribble_label=scribble_label,
                        prev_round_label=prev_label,
                        global_map_tmp_dic=eval_global_map_tmp_dic,
                        local_map_dics=local_map_dics,
                        interaction_num=n_interaction,
                        seq_names=[sequence],
                        gt_ids=paddle.to_tensor([obj_nums]),
                        frame_num=[start_annotated_frame],
                        first_inter=first_scribble)
                pred_label = tmp_dic[sequence]
                pred_label = nn.functional.interpolate(pred_label,
                                                       size=(h, w),
                                                       mode='bilinear',
                                                       align_corners=True)
                pred_label = paddle.argmax(pred_label, axis=1)
                pred_masks.append(float_(pred_label))
                # np.unique(pred_label)
                # array([0], dtype=int64)
                prev_label_storage[start_annotated_frame] = float_(
                    pred_label[0])

                if is_save_image:  # save image
                    pred_label_to_save = pred_label.squeeze(0).numpy()
                    im = Image.fromarray(
                        pred_label_to_save.astype('uint8')).convert('P', )
                    im.putpalette(_palette)
                    imgname = str(start_annotated_frame)
                    while len(imgname) < 5:
                        imgname = '0' + imgname
                    if not os.path.exists(inter_file_path):
                        os.makedirs(inter_file_path)
                    im.save(os.path.join(inter_file_path, imgname + '.png'))
                #######################################
                if first_scribble:
                    scribble_label = rough_ROI(scribble_label)

                ##############################
                ref_prev_label = pred_label.unsqueeze(0)
                prev_label = pred_label.unsqueeze(0)
                prev_embedding = ref_frame_embedding
                for ii in range(start_annotated_frame + 1, total_frame_num):
                    current_embedding = embedding_memory[ii]
                    current_embedding = current_embedding.unsqueeze(0)
                    prev_label = prev_label
                    if parallel:
                        for c in model.children():
                            tmp_dic, eval_global_map_tmp_dic, local_map_dics = c.head.prop_seghead(
                                ref_frame_embedding,
                                prev_embedding,
                                current_embedding,
                                scribble_label,
                                prev_label,
                                normalize_nearest_neighbor_distances=True,
                                use_local_map=True,
                                seq_names=[sequence],
                                gt_ids=paddle.to_tensor([obj_nums]),
                                k_nearest_neighbors=cfg['knns'],
                                global_map_tmp_dic=eval_global_map_tmp_dic,
                                local_map_dics=local_map_dics,
                                interaction_num=n_interaction,
                                start_annotated_frame=start_annotated_frame,
                                frame_num=[ii],
                                dynamic_seghead=c.head.dynamic_seghead)
                    else:
                        tmp_dic, eval_global_map_tmp_dic, local_map_dics = model.head.prop_seghead(
                            ref_frame_embedding,
                            prev_embedding,
                            current_embedding,
                            scribble_label,
                            prev_label,
                            normalize_nearest_neighbor_distances=True,
                            use_local_map=True,
                            seq_names=[sequence],
                            gt_ids=paddle.to_tensor([obj_nums]),
                            k_nearest_neighbors=cfg['knns'],
                            global_map_tmp_dic=eval_global_map_tmp_dic,
                            local_map_dics=local_map_dics,
                            interaction_num=n_interaction,
                            start_annotated_frame=start_annotated_frame,
                            frame_num=[ii],
                            dynamic_seghead=model.head.dynamic_seghead)
                    pred_label = tmp_dic[sequence]
                    pred_label = nn.functional.interpolate(pred_label,
                                                           size=(h, w),
                                                           mode='bilinear',
                                                           align_corners=True)
                    pred_label = paddle.argmax(pred_label, axis=1)
                    pred_masks.append(float_(pred_label))
                    prev_label = pred_label.unsqueeze(0)
                    prev_embedding = current_embedding
                    prev_label_storage[ii] = float_(pred_label[0])
                    if is_save_image:
                        pred_label_to_save = pred_label.squeeze(0).numpy()
                        im = Image.fromarray(
                            pred_label_to_save.astype('uint8')).convert('P', )
                        im.putpalette(_palette)
                        imgname = str(ii)
                        while len(imgname) < 5:
                            imgname = '0' + imgname
                        if not os.path.exists(inter_file_path):
                            os.makedirs(inter_file_path)
                        im.save(os.path.join(inter_file_path,
                                             imgname + '.png'))
                #######################################
                prev_label = ref_prev_label
                prev_embedding = ref_frame_embedding
                #######
                # Propagation <-
                for ii in range(start_annotated_frame):
                    current_frame_num = start_annotated_frame - 1 - ii
                    current_embedding = embedding_memory[current_frame_num]
                    current_embedding = current_embedding.unsqueeze(0)
                    prev_label = prev_label
                    if parallel:
                        for c in model.children():
                            tmp_dic, eval_global_map_tmp_dic, local_map_dics = c.head.prop_seghead(
                                ref_frame_embedding,
                                prev_embedding,
                                current_embedding,
                                scribble_label,
                                prev_label,
                                normalize_nearest_neighbor_distances=True,
                                use_local_map=True,
                                seq_names=[sequence],
                                gt_ids=paddle.to_tensor([obj_nums]),
                                k_nearest_neighbors=cfg['knns'],
                                global_map_tmp_dic=eval_global_map_tmp_dic,
                                local_map_dics=local_map_dics,
                                interaction_num=n_interaction,
                                start_annotated_frame=start_annotated_frame,
                                frame_num=[current_frame_num],
                                dynamic_seghead=c.head.dynamic_seghead)
                    else:
                        tmp_dic, eval_global_map_tmp_dic, local_map_dics = model.head.prop_seghead(
                            ref_frame_embedding,
                            prev_embedding,
                            current_embedding,
                            scribble_label,
                            prev_label,
                            normalize_nearest_neighbor_distances=True,
                            use_local_map=True,
                            seq_names=[sequence],
                            gt_ids=paddle.to_tensor([obj_nums]),
                            k_nearest_neighbors=cfg['knns'],
                            global_map_tmp_dic=eval_global_map_tmp_dic,
                            local_map_dics=local_map_dics,
                            interaction_num=n_interaction,
                            start_annotated_frame=start_annotated_frame,
                            frame_num=[current_frame_num],
                            dynamic_seghead=model.head.dynamic_seghead)
                    pred_label = tmp_dic[sequence]
                    pred_label = nn.functional.interpolate(pred_label,
                                                           size=(h, w),
                                                           mode='bilinear',
                                                           align_corners=True)

                    pred_label = paddle.argmax(pred_label, axis=1)
                    pred_masks_reverse.append(float_(pred_label))
                    prev_label = pred_label.unsqueeze(0)
                    prev_embedding = current_embedding
                    ####
                    prev_label_storage[current_frame_num] = float_(
                        pred_label[0])
                    ###
                    if is_save_image:
                        pred_label_to_save = pred_label.squeeze(0).numpy()
                        im = Image.fromarray(
                            pred_label_to_save.astype('uint8')).convert('P', )
                        im.putpalette(_palette)
                        imgname = str(current_frame_num)
                        while len(imgname) < 5:
                            imgname = '0' + imgname
                        if not os.path.exists(inter_file_path):
                            os.makedirs(inter_file_path)
                        im.save(os.path.join(inter_file_path,
                                             imgname + '.png'))
                pred_masks_reverse.reverse()
                pred_masks_reverse.extend(pred_masks)
                final_masks = paddle.concat(pred_masks_reverse, 0)
                submit_masks(cfg["save_path"], final_masks.numpy(), images)

                t_end = timeit.default_timer()
                print('Total time for single interaction: ' +
                      str(t_end - t_total))
        inter_file.close()
        return None