Example #1
@torch.no_grad()
def validate_full_vcmr(model, val_loader, split, opts, model_opts):
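    """Run full-split VCMR evaluation (plus SVMR and VR when they are listed
    in opts.full_eval_tasks): cache frame embeddings for every video in the
    split, score all queries against them, and return a
    (val_log, eval_submission) pair."""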
    LOGGER.info("start running full VCMR evaluation on {opts.task} {split} split...")
    model.eval()
    n_ex = 0
    st = time()
    val_log = {}
    has_gt_target = True
    val_vid2idx = val_loader.dataset.vid2idx
    if split in val_vid2idx:
        video2idx_global = val_vid2idx[split]
    else:
        video2idx_global = val_vid2idx
    video_ids = sorted(list(video2idx_global.keys()))
    video2idx_local = {e: i for i, e in enumerate(video_ids)}
    query_data = val_loader.dataset.query_data

    partial_query_data = []
    total_frame_embeddings, total_c_attn_masks = None, None
    video_batch, video_idx = [], []
    max_clip_len = 0
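    # first pass: encode every video in the split once and cache its frame
    # embeddings and attention masks in padded buffers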
    for video_i, (vid, vidx) in tqdm(enumerate(video2idx_local.items()), desc="Computing Video Embeddings", total=len(video2idx_local)):
        video_item = val_loader.dataset.video_db[vid]
        video_batch.append(video_item)
        video_idx.append(vidx)
        if len(video_batch) == opts.vcmr_eval_video_batch_size or video_i == len(video2idx_local) - 1:
            video_batch = move_to_cuda(video_collate(video_batch))
            # Safeguard fp16
            for k, item in video_batch.items():
                if isinstance(item, torch.Tensor) and item.dtype == torch.float32:
                    video_batch[k] = video_batch[k].to(dtype=next(model.parameters()).dtype)

            curr_frame_embeddings = model.v_encoder(video_batch, 'repr')
            curr_c_attn_masks = video_batch['c_attn_masks']
            curr_clip_len = curr_frame_embeddings.size(-2)
            assert curr_clip_len <= model_opts.max_clip_len

            if total_frame_embeddings is None:
                feat_dim = curr_frame_embeddings.size(-1)
                total_frame_embeddings = torch.zeros(
                    (len(video2idx_local), model_opts.max_clip_len, feat_dim),
                    dtype=curr_frame_embeddings.dtype,
                    device=curr_frame_embeddings.device)
                total_c_attn_masks = torch.zeros(
                    (len(video2idx_local), model_opts.max_clip_len),
                    dtype=curr_c_attn_masks.dtype,
                    device=curr_frame_embeddings.device)

            indices = torch.tensor(video_idx)
            total_frame_embeddings[indices, :curr_clip_len] = curr_frame_embeddings
            total_c_attn_masks[indices, :curr_clip_len] = curr_c_attn_masks
            max_clip_len = max(max_clip_len, curr_clip_len)
            video_batch, video_idx = [], []

    total_frame_embeddings = total_frame_embeddings[:, :max_clip_len, :]
    total_c_attn_masks = total_c_attn_masks[:, :max_clip_len]

    svmr_st_probs_total, svmr_ed_probs_total = None, None
    sorted_q2c_indices, sorted_q2c_scores = None, None
    flat_st_ed_sorted_scores, flat_st_ed_scores_sorted_indices = None, None
    total_qids, total_vids = [], []
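    # second pass: score each query batch against the cached video embeddings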
    for batch in tqdm(val_loader, desc="Computing q2vScores"):
        qids = batch['qids']
        vids = batch['vids']
        targets = batch['targets']
        if has_gt_target and targets.min() < 0:
            has_gt_target = False
            LOGGER.info("No GT annotations provided, only generate predictions")
        del batch['targets'], batch['qids'], batch['vids']  # drop non-model fields before forwarding the batch

        total_qids.extend(qids)
        total_vids.extend(vids)
        for qid in qids:
            partial_query_data.append(query_data[qid])
        # Safeguard fp16
        for k, item in batch.items():
            if isinstance(item, torch.Tensor) and item.dtype == torch.float32:
                batch[k] = batch[k].to(dtype=next(model.parameters()).dtype)

        # FIXME
        _q2video_scores, _st_probs, _ed_probs = \
            model.get_pred_from_raw_query(total_frame_embeddings, total_c_attn_masks, **batch, cross=True, val_gather_gpus=False)

        _st_probs = F.softmax(_st_probs, dim=-1)
        _ed_probs = F.softmax(_ed_probs, dim=-1)
        n_ex += len(qids)

        if "SVMR" in opts.full_eval_tasks and has_gt_target:
            row_indices = torch.arange(0, len(_st_probs))
            svmr_gt_vidx = torch.tensor([video2idx_local[e] for e in vids])
            svmr_st_probs = _st_probs[row_indices, svmr_gt_vidx].float().cpu().numpy()
            svmr_ed_probs = _ed_probs[row_indices, svmr_gt_vidx].float().cpu().numpy()
            if svmr_st_probs_total is None:
                svmr_st_probs_total = svmr_st_probs
                svmr_ed_probs_total = svmr_ed_probs
            else:
                svmr_st_probs_total = np.concatenate((svmr_st_probs_total, svmr_st_probs), axis=0)
                svmr_ed_probs_total = np.concatenate((svmr_ed_probs_total, svmr_ed_probs), axis=0)

        if "VR" not in opts.full_eval_tasks or _q2video_scores is None:
            continue

        _q2video_scores = _q2video_scores.float()
        # emphasize top scores: the larger model_opts.q2c_alpha is, the more
        # weight the highest-scoring videos receive
        q2video_scores = torch.exp(model_opts.q2c_alpha * _q2video_scores)
        _sorted_q2c_scores, _sorted_q2c_indices = torch.topk(q2video_scores, model_opts.max_vcmr_video, dim=1, largest=True)
        if sorted_q2c_indices is None:
            sorted_q2c_indices = _sorted_q2c_indices.cpu().numpy()
            sorted_q2c_scores = _sorted_q2c_scores.cpu().numpy()
        else:
            sorted_q2c_indices = np.concatenate((sorted_q2c_indices, _sorted_q2c_indices.cpu().numpy()), axis=0)
            sorted_q2c_scores = np.concatenate((sorted_q2c_scores, _sorted_q2c_scores.cpu().numpy()), axis=0)

        if "VCMR" not in opts.full_eval_tasks:
            continue

        row_indices = torch.arange(0, len(_st_probs), device=_st_probs.device).unsqueeze(1)
        _st_probs = _st_probs[row_indices, _sorted_q2c_indices]  # (_N_q, max_vcmr_video, L)
        _ed_probs = _ed_probs[row_indices, _sorted_q2c_indices]
        # joint span scores, shape (_N_q, max_vcmr_video, L, L):
        # score[q, v, m, n] = P_start[q, v, m] * q2c_score[q, v] * P_end[q, v, n]
        _st_ed_scores = torch.einsum("qvm,qv,qvn->qvmn", _st_probs, _sorted_q2c_scores, _ed_probs)
        valid_prob_mask = generate_min_max_length_mask(_st_ed_scores.shape, min_l=model_opts.min_pred_l, max_l=model_opts.max_pred_l)
        _st_ed_scores *= torch.from_numpy(valid_prob_mask).to(_st_ed_scores.device)  # invalid location will become zero!
        # sort across the top-max_n_videos videos (by flatten from the 2nd dim)
        # the indices here are local indices, not global indices
        _n_q = _st_ed_scores.shape[0]
        _flat_st_ed_scores = _st_ed_scores.reshape(_n_q, -1)  # (N_q, max_vcmr_video*L*L)
        _flat_st_ed_sorted_scores, _flat_st_ed_scores_sorted_indices = torch.sort(_flat_st_ed_scores, dim=1, descending=True)

        if flat_st_ed_sorted_scores is None:
            flat_st_ed_scores_sorted_indices = _flat_st_ed_scores_sorted_indices[:, :model_opts.max_before_nms].cpu().numpy()
            flat_st_ed_sorted_scores = _flat_st_ed_sorted_scores[:, :model_opts.max_before_nms].cpu().numpy()
        else:
            flat_st_ed_scores_sorted_indices = \
                np.concatenate((flat_st_ed_scores_sorted_indices, _flat_st_ed_scores_sorted_indices[:, :model_opts.max_before_nms].cpu().numpy()), axis=0)
            flat_st_ed_sorted_scores = \
                np.concatenate((flat_st_ed_sorted_scores, _flat_st_ed_sorted_scores[:, :model_opts.max_before_nms].cpu().numpy()), axis=0)

    svmr_res, vr_res, vcmr_res = [], [], []
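    # turn the accumulated scores into ranked prediction lists per task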
    if "SVMR" in opts.full_eval_tasks and has_gt_target:
        st_ed_prob_product = np.einsum("bm,bn->bmn", svmr_st_probs_total, svmr_ed_probs_total)  # (N, L, L)
        valid_prob_mask = generate_min_max_length_mask(st_ed_prob_product.shape, min_l=model_opts.min_pred_l, max_l=model_opts.max_pred_l)
        # invalid location will become zero!
        st_ed_prob_product *= valid_prob_mask
        batched_sorted_triples = find_max_triples_from_upper_triangle_product(st_ed_prob_product, top_n=model_opts.max_before_nms, prob_thd=None)
        for svmr_i, (qid, vid) in tqdm(enumerate(zip(total_qids, total_vids)), desc="[SVMR] Loop over queries to generate predictions", total=len(total_qids)):
            vidx = video2idx_global[vid]
            _sorted_triples = batched_sorted_triples[svmr_i]
            # ed_idx points inside the moment (it is inclusive), so shift it
            # by 1 to get an exclusive end index
            _sorted_triples[:, 1] += 1
            # convert frame indices to seconds via the feature sampling interval
            _sorted_triples[:, :2] = _sorted_triples[:, :2] * model_opts.vfeat_interval
            cur_ranked_predictions = [[vidx, ] + row for row in _sorted_triples.tolist()]
            cur_query_pred = dict(desc_id=int(qid), desc="", predictions=cur_ranked_predictions)
            svmr_res.append(cur_query_pred)

    if "VR" in opts.full_eval_tasks:
        for vr_i, (_sorted_q2c_scores_row, _sorted_q2c_indices_row) in tqdm(
                enumerate(zip(sorted_q2c_scores[:, :100], sorted_q2c_indices[:, :100])),
                desc="[VR] Loop over queries to generate predictions",
                total=len(total_qids)):
            cur_vr_predictions = []
            for v_score, v_meta_idx in zip(_sorted_q2c_scores_row, _sorted_q2c_indices_row):
                video_idx = video2idx_global[video_ids[v_meta_idx]]
                cur_vr_predictions.append([video_idx, 0, 0, float(v_score)])
            cur_query_pred = dict(desc_id=int(total_qids[vr_i]), desc="", predictions=cur_vr_predictions)
            vr_res.append(cur_query_pred)

    if "VCMR" in opts.full_eval_tasks:
        for vcmr_i, (_flat_st_ed_scores_sorted_indices, _flat_st_ed_sorted_scores) in \
                tqdm(enumerate(zip(flat_st_ed_scores_sorted_indices, flat_st_ed_sorted_scores)),
                     desc="[VCMR] Loop over queries to generate predictions",
                     total=len(total_qids)):  # i is query_idx
            # list([video_idx(int), st(float),
            #       ed(float), score(float)])
            video_meta_indices_local, pred_st_indices, pred_ed_indices = \
                np.unravel_index(_flat_st_ed_scores_sorted_indices,
                                 shape=(model_opts.max_vcmr_video, model_opts.max_clip_len, model_opts.max_clip_len))
            # video_meta_indices_local indexes the per-query top-max_vcmr_video
            # candidates; map it through sorted_q2c_indices to recover indices
            # over all videos in the split
            video_meta_indices = sorted_q2c_indices[vcmr_i, video_meta_indices_local]

            pred_st_in_seconds = pred_st_indices.astype(
                np.float32) * model_opts.vfeat_interval
            pred_ed_in_seconds = pred_ed_indices.astype(
                np.float32
                ) * model_opts.vfeat_interval + model_opts.vfeat_interval
            cur_vcmr_predictions = []
            for j, (v_meta_idx, v_score) in enumerate(
                    zip(video_meta_indices,
                        _flat_st_ed_sorted_scores)):  # videos
                video_idx = video2idx_global[video_ids[v_meta_idx.item()]]
                cur_vcmr_predictions.append(
                    [video_idx, float(pred_st_in_seconds[j]),
                     float(pred_ed_in_seconds[j]), float(v_score)])

            cur_query_pred = dict(
                desc_id=int(total_qids[vcmr_i]),
                desc="",
                predictions=cur_vcmr_predictions)
            vcmr_res.append(cur_query_pred)

    eval_res = dict(SVMR=svmr_res, VCMR=vcmr_res, VR=vr_res)
    eval_res = {k: v for k, v in eval_res.items() if len(v) != 0}
    eval_res["video2idx"] = video2idx_global
    eval_submission = get_submission_top_n(eval_res, top_n=model_opts.max_after_nms)

    if has_gt_target:
        metrics = eval_retrieval(eval_submission, partial_query_data,
                                 iou_thds=VCMR_IOU_THDS,
                                 match_number=True,
                                 verbose=False,
                                 use_desc_type=model_opts.eval_with_query_type)

        if model_opts.distributed_eval:
            n_ex_per_rank = all_gather_list(n_ex)
            metrics_per_rank = all_gather_list(metrics)
        else:
            n_ex_per_rank = [n_ex]
            metrics_per_rank = [metrics]
        n_ex = sum(n_ex_per_rank)
        val_log = {}
        gathered_metrics = {}
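        # example-count-weighted average of each metric across ranks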
        for task_type, task_metric in metrics.items():
            gathered_metrics[task_type] = {}
            for k in task_metric.keys():
                if k == "desc_type_ratio":
                    continue
                gathered_v = 0
                for idx, n in enumerate(n_ex_per_rank):
                    gathered_v += n*metrics_per_rank[idx][task_type][k]
                gathered_v = gathered_v / n_ex
                gathered_metrics[task_type][k] = gathered_v
                val_log[
                    f'valid_{split}_{task_type}/{task_type}_{k}'] = gathered_v

        if "VCMR" in gathered_metrics:
            LOGGER.info("metrics_no_nms_VCMR \n{}".format(pprint.pformat(gathered_metrics["VCMR"], indent=4)))
        elif "SVMR" in gathered_metrics:
            LOGGER.info("metrics_no_nms_SVMR \n{}".format(pprint.pformat(gathered_metrics["SVMR"], indent=4)))

        if "VR" in gathered_metrics:
            LOGGER.info("metrics_no_nms_VR \n{}".format(pprint.pformat(gathered_metrics["VR"], indent=4)))

        if model_opts.nms_thd != -1:
            LOGGER.info("Performing nms with nms_thd {}".format(model_opts.nms_thd))
            eval_res_after_nms = dict(video2idx=eval_res["video2idx"])
            if "SVMR" in eval_res:
                eval_res_after_nms["SVMR"] =\
                    post_processing_svmr_nms(
                    eval_res["SVMR"], nms_thd=model_opts.nms_thd,
                    max_before_nms=model_opts.max_before_nms,
                    max_after_nms=model_opts.max_after_nms)
            if "VCMR" in eval_res:
                eval_res_after_nms["VCMR"] =\
                    post_processing_vcmr_nms(
                    eval_res["VCMR"], nms_thd=model_opts.nms_thd,
                    max_before_nms=model_opts.max_before_nms,
                    max_after_nms=model_opts.max_after_nms)
            metrics_nms = eval_retrieval(
                eval_res_after_nms, partial_query_data,
                iou_thds=VCMR_IOU_THDS,
                match_number=True,
                verbose=False,
                use_desc_type=model_opts.eval_with_query_type)

            if model_opts.distributed_eval:
                metrics_nms_per_rank = all_gather_list(metrics_nms)
            else:
                metrics_nms_per_rank = [metrics_nms]
            gathered_metrics_nms = {}

            for task_type, task_metric in metrics_nms.items():
                gathered_metrics_nms[task_type] = {}
                for k in task_metric.keys():
                    if k == "desc_type_ratio":
                        continue
                    gathered_v_nms = 0
                    for idx, n in enumerate(n_ex_per_rank):
                        gathered_v_nms += (n * metrics_nms_per_rank[idx][task_type][k])
                    gathered_v_nms = gathered_v_nms / n_ex
                    gathered_metrics_nms[task_type][k] = gathered_v_nms
                    val_log[f'valid_{split}_{task_type}_nms_{model_opts.nms_thd}/{task_type}_{k}'] = gathered_v_nms
            if "VCMR" in gathered_metrics_nms:
                LOGGER.info("metrics_nms_VCMR \n{}".format(pprint.pformat(gathered_metrics_nms["VCMR"], indent=4)))
            elif "SVMR" in gathered_metrics_nms:
                LOGGER.info("metrics_nms_SVMR \n{}".format(pprint.pformat(gathered_metrics_nms["SVMR"], indent=4)))

        tot_time = time()-st
        val_log.update({f'valid/vcmr_{split}_ex_per_s': n_ex/tot_time})
        LOGGER.info(f"validation finished in {int(tot_time)} seconds")
    model.train()
    return val_log, eval_submission
Example #2
@torch.no_grad()
def validate_full_vr(model, val_loader, split, opts, model_opts):
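    """Run full-split video retrieval (VR) evaluation: cache frame embeddings
    for every video, rank videos for each query, and return a
    (val_log, eval_submission) pair."""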
    LOGGER.info("start running  full VR evaluation"
                f"on {opts.task} {split} split...")
    model.eval()
    n_ex = 0
    st = time()
    val_log = {}
    has_gt_target = True  # MSRVTT test set has annotations
    try:
        video2idx_global = val_loader.dataset.vid2idx[split]
    except Exception:
        video2idx_global = val_loader.dataset.vid2idx
    video_ids = sorted(list(video2idx_global.keys()))
    video2idx_local = {e: i for i, e in enumerate(video_ids)}
    query_data = val_loader.dataset.query_data

    partial_query_data = []
    total_frame_embeddings = None
    video_batch, video_idx = [], []
    max_clip_len = 0
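    # encode every video in the split once and cache its frame embeddings
    # (same first pass as in validate_full_vcmr)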
    for video_i, (vid, vidx) in tqdm(enumerate(video2idx_local.items()),
                                     desc="Computing Video Embeddings",
                                     total=len(video2idx_local)):
        video_item = val_loader.dataset.video_db[vid]
        video_batch.append(video_item)
        video_idx.append(vidx)
        if len(video_batch) == opts.vr_eval_video_batch_size or\
                video_i == len(video2idx_local) - 1:
            video_batch = move_to_cuda(video_collate(video_batch))
            # Safeguard fp16
            for k, item in video_batch.items():
                if isinstance(item, torch.Tensor) and\
                        item.dtype == torch.float32:
                    video_batch[k] = video_batch[k].to(
                        dtype=next(model.parameters()).dtype)
            curr_frame_embeddings = model.v_encoder(video_batch, 'repr')
            curr_c_attn_masks = video_batch['c_attn_masks']
            curr_clip_len = curr_frame_embeddings.size(-2)
            assert curr_clip_len <= model_opts.max_clip_len

            if total_frame_embeddings is None:
                feat_dim = curr_frame_embeddings.size(-1)
                total_frame_embeddings = torch.zeros(
                    (len(video2idx_local), model_opts.max_clip_len, feat_dim),
                    dtype=curr_frame_embeddings.dtype,
                    device=curr_frame_embeddings.device)
                total_c_attn_masks = torch.zeros(
                    (len(video2idx_local), model_opts.max_clip_len),
                    dtype=curr_c_attn_masks.dtype,
                    device=curr_frame_embeddings.device)
            indices = torch.LongTensor(video_idx)
            total_frame_embeddings[indices, :curr_clip_len] =\
                curr_frame_embeddings
            total_c_attn_masks[indices, :curr_clip_len] =\
                curr_c_attn_masks
            max_clip_len = max(max_clip_len, curr_clip_len)
            video_batch, video_idx = [], []
    total_frame_embeddings = total_frame_embeddings[:, :max_clip_len, :]
    total_c_attn_masks = total_c_attn_masks[:, :max_clip_len]

    sorted_q2c_indices, sorted_q2c_scores = None, None
    total_qids, total_vids = [], []
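    # score each query batch against the cached video embeddings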
    for batch in tqdm(val_loader, desc="Computing q2vScores"):
        qids = batch['qids']
        vids = batch['vids']

        del batch['targets']
        del batch['qids']
        del batch['vids']

        total_qids.extend(qids)
        total_vids.extend(vids)
        for qid in qids:
            # fix msrvtt query data to have tvr format
            gt = query_data[qid]
            gt["desc_id"] = qid
            gt["vid_name"] = gt["clip_name"]
            partial_query_data.append(gt)
        # Safeguard fp16
        for k, item in batch.items():
            if isinstance(item, torch.Tensor) and item.dtype == torch.float32:
                batch[k] = batch[k].to(dtype=next(model.parameters()).dtype)

        # FIXME
        _q2video_scores = model.get_pred_from_raw_query(total_frame_embeddings,
                                                        total_c_attn_masks,
                                                        **batch,
                                                        cross=True,
                                                        val_gather_gpus=False)
        n_ex += len(qids)

        _q2video_scores = _q2video_scores.float()

        q2video_scores = _q2video_scores
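        # rank all videos per query and keep the top-max_vr_video candidates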
        _sorted_q2c_scores, _sorted_q2c_indices = \
            torch.topk(q2video_scores, model_opts.max_vr_video,
                       dim=1, largest=True)
        if sorted_q2c_indices is None:
            sorted_q2c_indices = _sorted_q2c_indices.cpu().numpy()
            sorted_q2c_scores = _sorted_q2c_scores.cpu().numpy()
        else:
            sorted_q2c_indices = np.concatenate(
                (sorted_q2c_indices, _sorted_q2c_indices.cpu().numpy()),
                axis=0)
            sorted_q2c_scores = np.concatenate(
                (sorted_q2c_scores, _sorted_q2c_scores.cpu().numpy()), axis=0)

    vr_res = []
    for vr_i, (_sorted_q2c_scores_row, _sorted_q2c_indices_row) in tqdm(
            enumerate(
                zip(sorted_q2c_scores[:, :100], sorted_q2c_indices[:, :100])),
            desc="[VR] Loop over queries to generate predictions",
            total=len(total_qids)):
        cur_vr_predictions = []
        for v_score, v_meta_idx in zip(_sorted_q2c_scores_row,
                                       _sorted_q2c_indices_row):
            video_idx = video2idx_global[video_ids[v_meta_idx]]
            cur_vr_predictions.append([video_idx, 0, 0, float(v_score)])
        cur_query_pred = dict(desc_id=total_qids[vr_i],
                              desc="",
                              predictions=cur_vr_predictions)
        vr_res.append(cur_query_pred)
    eval_res = dict(VR=vr_res)
    eval_res = {k: v for k, v in eval_res.items() if len(v) != 0}
    eval_res["video2idx"] = video2idx_global

    eval_submission = get_submission_top_n(eval_res,
                                           top_n=model_opts.max_vr_video)

    if has_gt_target:
        metrics = eval_retrieval(eval_submission,
                                 partial_query_data,
                                 iou_thds=VCMR_IOU_THDS,
                                 match_number=True,
                                 verbose=False,
                                 use_desc_type=False)

        if model_opts.distributed_eval:
            n_ex_per_rank = all_gather_list(n_ex)
            metrics_per_rank = all_gather_list(metrics)
        else:
            n_ex_per_rank = [n_ex]
            metrics_per_rank = [metrics]
        n_ex = sum(n_ex_per_rank)
        val_log = {}
        gathered_metrics = {}
        for task_type, task_metric in metrics.items():
            gathered_metrics[task_type] = {}
            for k in task_metric.keys():
                if k == "desc_type_ratio":
                    continue
                gathered_v = 0
                for idx, n in enumerate(n_ex_per_rank):
                    gathered_v += n * metrics_per_rank[idx][task_type][k]
                gathered_v = gathered_v / n_ex
                gathered_metrics[task_type][k] = gathered_v
                val_log[
                    f'valid_{split}_{task_type}/{task_type}_{k}'] = gathered_v

        LOGGER.info("metrics_VR \n{}".format(
            pprint.pformat(gathered_metrics["VR"], indent=4)))

        tot_time = time() - st
        val_log.update({f'valid/vr_{split}_ex_per_s': n_ex / tot_time})
        LOGGER.info(f"validation finished in {int(tot_time)} seconds")
    model.train()
    return val_log, eval_submission