Example #1
def compute_query2ctx_info_svmr_only(model, eval_dataset, opt, ctx_info,
                                     max_before_nms=1000, max_n_videos=200, tasks=("SVMR",)):
    """Use val set to do evaluation, remember to run with torch.no_grad().
    estimated size 20,000 (query) * 500 (hsz) * 4 / (1024**2) = 38.15 MB
    max_n_videos: int, unused in this SVMR-only variant
    """
    model.eval()
    eval_dataset.set_data_mode("query")
    eval_dataset.load_gt_vid_name_for_query(True)
    query_eval_loader = DataLoader(eval_dataset,
                                   collate_fn=start_end_collate,
                                   batch_size=opt.eval_query_bsz,
                                   num_workers=opt.num_workers,
                                   shuffle=False,
                                   pin_memory=opt.pin_memory)
    video2idx = eval_dataset.video2idx
    video_metas = ctx_info["video_metas"]
    n_total_query = len(eval_dataset)
    bsz = opt.eval_query_bsz
    ctx_len = eval_dataset.max_ctx_len  # all pad to this length

    svmr_video2meta_idx = {e["vid_name"]: idx for idx, e in enumerate(video_metas)}
    svmr_gt_st_probs = np.zeros((n_total_query, ctx_len), dtype=np.float32)
    svmr_gt_ed_probs = np.zeros((n_total_query, ctx_len), dtype=np.float32)

    query_metas = []
    for idx, batch in tqdm(
            enumerate(query_eval_loader), desc="Computing q embedding", total=len(query_eval_loader)):
        _query_metas = batch[0]
        query_metas.extend(batch[0])
        model_inputs = prepare_batch_inputs(batch[1], device=opt.device, non_blocking=opt.pin_memory)
        # query_context_scores (_N_q, N_videos), st_prob, ed_prob (_N_q, L)
        query2video_meta_indices = torch.LongTensor([svmr_video2meta_idx[e["vid_name"]] for e in _query_metas])
        _query_context_scores, _st_probs, _ed_probs = \
            model.get_pred_from_raw_query(model_inputs["query_feat"], model_inputs["query_mask"],
                                          index_if_not_none(ctx_info["x_feat"], query2video_meta_indices),
                                          index_if_not_none(ctx_info["x_mask"], query2video_meta_indices),
                                          cross=False)
        _query_context_scores = _query_context_scores + 1  # move cosine similarity to [0, 2]

        # normalize to get true probabilities!!!
        # the probabilities here are already (pad) masked, so only need to do softmax
        _st_probs = F.softmax(_st_probs, dim=-1)  # (_N_q, L)
        _ed_probs = F.softmax(_ed_probs, dim=-1)

        svmr_gt_st_probs[idx * bsz:(idx + 1) * bsz, :_st_probs.shape[1]] = _st_probs.cpu().numpy()
        svmr_gt_ed_probs[idx * bsz:(idx + 1) * bsz, :_ed_probs.shape[1]] = _ed_probs.cpu().numpy()

        if opt.debug:
            break
    svmr_res = get_svmr_res_from_st_ed_probs(svmr_gt_st_probs, svmr_gt_ed_probs,
                                             query_metas, video2idx,
                                             clip_length=opt.clip_length,
                                             min_pred_l=opt.min_pred_l,
                                             max_pred_l=opt.max_pred_l,
                                             max_before_nms=max_before_nms)
    return dict(SVMR=svmr_res)
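The docstring above asks callers to wrap the call in torch.no_grad(); a minimal usage sketch under that assumption (model, eval_dataset, and opt are prepared by the surrounding evaluation script, and ctx_info comes from compute_context_info in Example #2):

with torch.no_grad():
    ctx_info = compute_context_info(model, eval_dataset, opt)  # cached context features and masks
    eval_res = compute_query2ctx_info_svmr_only(model, eval_dataset, opt, ctx_info,
                                                max_before_nms=1000, tasks=("SVMR",))
svmr_predictions = eval_res["SVMR"]  # list of per-query prediction dicts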
Example #2
def compute_context_info(model, eval_dataset, opt):
    """Use val set to do evaluation, remember to run with torch.no_grad().
    estimated 2200 (videos) * 100 (frm) * 500 (hsz) * 4 (B) * 2 (video/sub) * 2 (layers) / (1024 ** 2) ≈ 1678 MB (~1.64 GB)
    """
    model.eval()
    eval_dataset.set_data_mode("context")
    context_dataloader = DataLoader(eval_dataset,
                                    collate_fn=start_end_collate,
                                    batch_size=opt.eval_context_bsz,
                                    num_workers=opt.num_workers,
                                    shuffle=False,
                                    pin_memory=opt.pin_memory)
    metas = []  # list(dicts)
    x_feat = []
    x_mask = []
    for idx, batch in tqdm(enumerate(context_dataloader),
                           desc="Computing query2video scores",
                           total=len(context_dataloader)):
        metas.extend(batch[0])
        model_inputs = prepare_batch_inputs(batch[1], device=opt.device, non_blocking=opt.pin_memory)

        _x_feat = model.encode_context(
            model_inputs["video_feat"], model_inputs["video_mask"],
            model_inputs["sub_feat"], model_inputs["sub_mask"])
        x_feat.append(_x_feat)
        x_mask.append(model_inputs["video_mask"])  # video_mask == sub_mask

    def cat_tensor(tensor_list):
        if len(tensor_list) == 0:
            return None
        else:
            seq_l = [e.shape[1] for e in tensor_list]
            b_sizes = [e.shape[0] for e in tensor_list]
            b_sizes_cumsum = np.cumsum([0] + b_sizes)
            if len(tensor_list[0].shape) == 3:
                hsz = tensor_list[0].shape[2]
                res_tensor = tensor_list[0].new_zeros(sum(b_sizes), max(seq_l), hsz)
            elif len(tensor_list[0].shape) == 2:
                res_tensor = tensor_list[0].new_zeros(sum(b_sizes), max(seq_l))
            else:
                raise ValueError("Only support 2/3 dimensional tensors")
            for i, e in enumerate(tensor_list):
                res_tensor[b_sizes_cumsum[i]:b_sizes_cumsum[i+1], :seq_l[i]] = e
            return res_tensor

    return dict(
        video_metas=metas,  # list(dict) (N_videos)
        x_feat=cat_tensor(x_feat),
        x_mask=cat_tensor(x_mask),
    )
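The nested cat_tensor helper zero-pads each batch's context tensors to a shared max length before concatenating along the batch dimension, which is what lets x_feat and x_mask be indexed per video later. A standalone toy illustration of the same padding-then-concatenate idea (hypothetical helper name, 3-D case only):

import torch

def pad_and_cat(tensor_list):
    # same idea as the nested cat_tensor above, restricted to (batch, seq, hidden) tensors
    max_l = max(t.shape[1] for t in tensor_list)
    total_b = sum(t.shape[0] for t in tensor_list)
    out = tensor_list[0].new_zeros(total_b, max_l, tensor_list[0].shape[2])
    offset = 0
    for t in tensor_list:
        out[offset:offset + t.shape[0], :t.shape[1]] = t
        offset += t.shape[0]
    return out

a = torch.ones(2, 3, 4)           # 2 videos, 3 clips each, hidden size 4
b = torch.ones(3, 5, 4)           # 3 videos, 5 clips each
print(pad_and_cat([a, b]).shape)  # torch.Size([5, 5, 4]); a's rows are zero-padded from 3 to 5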
Example #3
def compute_query2ctx_info(model,
                           eval_dataset,
                           opt,
                           ctx_info,
                           max_before_nms=1000,
                           max_n_videos=100,
                           tasks=("SVMR", )):
    """Use val set to do evaluation, remember to run with torch.no_grad().
    estimated size 20,000 (query) * 500 (hsz) * 4 / (1024**2) = 38.15 MB
    max_n_videos: int, use max_n_videos videos for computing VCMR/VR results
    """
    is_svmr = "SVMR" in tasks
    is_vr = "VR" in tasks
    is_vcmr = "VCMR" in tasks

    video2idx = eval_dataset.video2idx
    video_metas = ctx_info["video_metas"]
    if opt.external_inference_vr_res_path is not None:
        video_idx2meta_idx = {
            video2idx[m["vid_name"]]: i
            for i, m in enumerate(video_metas)
        }
        external_query2video = \
            load_external_vr_res2(opt.external_inference_vr_res_path, top_n_vr_videos=max_n_videos)
        # maps query idx to its list of video meta indices
        external_query2video_meta_idx = \
            {k: [video_idx2meta_idx[e[0]] for e in v] for k, v in external_query2video.items()}
    else:
        external_query2video = None
        external_query2video_meta_idx = None

    model.eval()
    eval_dataset.set_data_mode("query")
    eval_dataset.load_gt_vid_name_for_query(is_svmr)
    query_eval_loader = DataLoader(eval_dataset,
                                   collate_fn=start_end_collate,
                                   batch_size=opt.eval_query_bsz,
                                   num_workers=opt.num_workers,
                                   shuffle=False,
                                   pin_memory=opt.pin_memory)
    n_total_videos = len(video_metas)
    n_total_query = len(eval_dataset)
    bsz = opt.eval_query_bsz

    if is_vcmr:
        flat_st_ed_scores_sorted_indices = np.empty(
            (n_total_query, max_before_nms), dtype=np.int64)
        flat_st_ed_sorted_scores = np.zeros((n_total_query, max_before_nms),
                                            dtype=np.float32)

    if is_vr or is_vcmr:
        sorted_q2c_indices = np.empty((n_total_query, max_n_videos),
                                      dtype=np.int64)
        sorted_q2c_scores = np.empty((n_total_query, max_n_videos),
                                     dtype=np.float32)

    if is_svmr:
        svmr_video2meta_idx = {
            e["vid_name"]: idx
            for idx, e in enumerate(video_metas)
        }
        svmr_gt_st_probs = np.zeros((n_total_query, opt.max_ctx_l),
                                    dtype=np.float32)
        svmr_gt_ed_probs = np.zeros((n_total_query, opt.max_ctx_l),
                                    dtype=np.float32)

    query_metas = []
    for idx, batch in tqdm(enumerate(query_eval_loader),
                           desc="Computing q embedding",
                           total=len(query_eval_loader)):
        _query_metas = batch[0]
        query_metas.extend(batch[0])
        model_inputs = prepare_batch_inputs(batch[1],
                                            device=opt.device,
                                            non_blocking=opt.pin_memory)
        # query_context_scores (_N_q, N_videos), st_prob, ed_prob (_N_q, N_videos, L)
        _query_context_scores, _st_probs, _ed_probs = \
            model.get_pred_from_raw_query(model_inputs["query_feat"], model_inputs["query_mask"],
                                          ctx_info["video_feat1"], ctx_info["video_feat2"],
                                          ctx_info["video_mask"],
                                          ctx_info["sub_feat1"], ctx_info["sub_feat2"],
                                          ctx_info["sub_mask"],
                                          cross=True)
        # _query_context_scores = _query_context_scores + 1  # move cosine similarity to [0, 2]
        # To give more importance to top scores: the higher opt.q2c_alpha is, the more weight they receive
        _query_context_scores = torch.exp(opt.q2c_alpha *
                                          _query_context_scores)

        # normalize to get true probabilities!!!
        # the probabilities here are already (pad) masked, so only need to do softmax
        _st_probs = F.softmax(_st_probs, dim=-1)  # (_N_q, N_videos, L)
        _ed_probs = F.softmax(_ed_probs, dim=-1)

        if is_svmr:  # collect SVMR data
            row_indices = torch.arange(0, len(_st_probs))
            query2video_meta_indices = torch.LongTensor(
                [svmr_video2meta_idx[e["vid_name"]] for e in _query_metas])
            # print("svmr_gt_st_probs[idx * bsz:(idx + 1) * bsz, :_st_probs.shape[1]] {}"
            #       .format(svmr_gt_st_probs[idx * bsz:(idx + 1) * bsz, :_st_probs.shape[1]].shape))
            # print("_st_probs[row_indices, query2video_meta_indices] {}"
            #       .format(_st_probs[row_indices, query2video_meta_indices].shape))
            # print("_st_probs {}".format(_st_probs.shape))
            svmr_gt_st_probs[idx * bsz:(idx + 1) * bsz, :_st_probs.shape[2]] = \
                _st_probs[row_indices, query2video_meta_indices].cpu().numpy()
            svmr_gt_ed_probs[idx * bsz:(idx + 1) * bsz, :_ed_probs.shape[2]] = \
                _ed_probs[row_indices, query2video_meta_indices].cpu().numpy()

        if not (is_vr or is_vcmr):
            continue

        # Get top-max_n_videos videos for each query
        # _sorted_q2c_scores, _sorted_q2c_indices = \
        # torch.sort(_query_context_scores, descending=True)  # (_N_q, N_videos)
        # _sorted_q2c_scores = _sorted_q2c_scores[:, :max_n_videos]  # (N_q, max_n_videos)
        # _sorted_q2c_indices = _sorted_q2c_indices[:, :max_n_videos]
        if external_query2video is None:
            _sorted_q2c_scores, _sorted_q2c_indices = \
                torch.topk(_query_context_scores, max_n_videos, dim=1, largest=True)
        else:
            relevant_video_info = [
                external_query2video[qm["desc_id"]] for qm in _query_metas
            ]
            _sorted_q2c_indices = _query_context_scores.new(
                [[video_idx2meta_idx[sub_e[0]] for sub_e in e]
                 for e in relevant_video_info]).long()
            _sorted_q2c_scores = _query_context_scores.new(
                [[sub_e[3] for sub_e in e] for e in relevant_video_info])
            _sorted_q2c_scores = torch.exp(opt.q2c_alpha * _sorted_q2c_scores)
        # collect data for vr and vcmr
        sorted_q2c_indices[idx * bsz:(idx + 1) * bsz] = _sorted_q2c_indices.cpu().numpy()
        sorted_q2c_scores[idx * bsz:(idx + 1) * bsz] = _sorted_q2c_scores.cpu().numpy()

        if not is_vcmr:
            continue

        # Get VCMR results
        # compute combined scores
        row_indices = torch.arange(0, len(_st_probs),
                                   device=opt.device).unsqueeze(1)
        _st_probs = _st_probs[row_indices,
                              _sorted_q2c_indices]  # (_N_q, max_n_videos, L)
        _ed_probs = _ed_probs[row_indices, _sorted_q2c_indices]

        # (_N_q, max_n_videos, L, L)
        _st_ed_scores = torch.einsum("qvm,qv,qvn->qvmn", _st_probs,
                                     _sorted_q2c_scores, _ed_probs)
        valid_prob_mask = generate_min_max_length_mask(_st_ed_scores.shape,
                                                       min_l=opt.min_pred_l,
                                                       max_l=opt.max_pred_l)
        _st_ed_scores *= torch.from_numpy(valid_prob_mask).to(
            _st_ed_scores.device)  # invalid location will become zero!

        # sort across the top-max_n_videos videos (by flatten from the 2nd dim)
        # the indices here are local indices, not global indices
        _n_q = _st_ed_scores.shape[0]
        _flat_st_ed_scores = _st_ed_scores.reshape(
            _n_q, -1)  # (N_q, max_n_videos*L*L)
        _flat_st_ed_sorted_scores, _flat_st_ed_scores_sorted_indices = \
            torch.sort(_flat_st_ed_scores, dim=1, descending=True)
        # collect data
        flat_st_ed_sorted_scores[idx * bsz:(idx + 1) * bsz] = \
            _flat_st_ed_sorted_scores[:, :max_before_nms].cpu().numpy()
        flat_st_ed_scores_sorted_indices[idx * bsz:(idx + 1) * bsz] = \
            _flat_st_ed_scores_sorted_indices[:, :max_before_nms].cpu().numpy()

        if opt.debug:
            break

    # Numpy starts here!!!
    svmr_res = []
    if is_svmr:
        svmr_res = get_svmr_res_from_st_ed_probs(svmr_gt_st_probs,
                                                 svmr_gt_ed_probs,
                                                 query_metas,
                                                 video2idx,
                                                 clip_length=opt.clip_length,
                                                 min_pred_l=opt.min_pred_l,
                                                 max_pred_l=opt.max_pred_l,
                                                 max_before_nms=max_before_nms)

    vr_res = []
    if is_vr:
        for i, (_sorted_q2c_scores_row, _sorted_q2c_indices_row) in tqdm(
                enumerate(
                    zip(sorted_q2c_scores[:, :100],
                        sorted_q2c_indices[:, :100])),
                desc="[VR] Loop over queries to generate predictions",
                total=n_total_query):
            cur_vr_predictions = []
            for j, (v_score, v_meta_idx) in enumerate(
                    zip(_sorted_q2c_scores_row, _sorted_q2c_indices_row)):
                video_idx = video2idx[video_metas[v_meta_idx]["vid_name"]]
                cur_vr_predictions.append([video_idx, 0, 0, float(v_score)])
            cur_query_pred = dict(desc_id=query_metas[i]["desc_id"],
                                  desc=query_metas[i]["desc"],
                                  predictions=cur_vr_predictions)
            vr_res.append(cur_query_pred)

    vcmr_res = []
    if is_vcmr:
        for i, (_flat_st_ed_scores_sorted_indices,
                _flat_st_ed_sorted_scores) in tqdm(
                    enumerate(
                        zip(flat_st_ed_scores_sorted_indices,
                            flat_st_ed_sorted_scores)),
                    desc="[VCMR] Loop over queries to generate predictions",
                    total=n_total_query):  # i is query_idx
            # list([video_idx(int), st(float), ed(float), score(float)])
            video_meta_indices_local, pred_st_indices, pred_ed_indices = \
                np.unravel_index(_flat_st_ed_scores_sorted_indices,
                                 shape=(max_n_videos, opt.max_ctx_l, opt.max_ctx_l))
            # video_meta_indices_local refers to the indices among the top-max_n_videos
            # video_meta_indices refers to the indices in all the videos, which is the True indices
            video_meta_indices = sorted_q2c_indices[i, video_meta_indices_local]

            pred_st_in_seconds = pred_st_indices.astype(np.float32) * opt.clip_length
            pred_ed_in_seconds = pred_ed_indices.astype(np.float32) * opt.clip_length + opt.clip_length
            cur_vcmr_predictions = []
            for j, (v_meta_idx, v_score) in enumerate(
                    zip(video_meta_indices,
                        _flat_st_ed_sorted_scores)):  # videos
                video_idx = video2idx[video_metas[v_meta_idx]["vid_name"]]
                cur_vcmr_predictions.append([
                    video_idx,
                    float(pred_st_in_seconds[j]),
                    float(pred_ed_in_seconds[j]),
                    float(v_score)
                ])

            cur_query_pred = dict(desc_id=query_metas[i]["desc_id"],
                                  desc=query_metas[i]["desc"],
                                  predictions=cur_vcmr_predictions)
            vcmr_res.append(cur_query_pred)

    res = dict(SVMR=svmr_res, VCMR=vcmr_res, VR=vr_res)
    return {k: v for k, v in res.items() if len(v) != 0}
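The VCMR scoring in the loop above combines start/end probabilities with the query-to-video score via einsum, flattens the result per query, and later recovers (video, start, end) indices with np.unravel_index. A toy sketch of that round trip with made-up small sizes:

import numpy as np
import torch

n_q, n_videos, L = 2, 3, 4
st = torch.softmax(torch.randn(n_q, n_videos, L), dim=-1)   # start probabilities
ed = torch.softmax(torch.randn(n_q, n_videos, L), dim=-1)   # end probabilities
q2v = torch.rand(n_q, n_videos)                             # query-to-video scores

st_ed = torch.einsum("qvm,qv,qvn->qvmn", st, q2v, ed)       # (n_q, n_videos, L, L)
flat = st_ed.reshape(n_q, -1)                               # (n_q, n_videos*L*L)
scores, flat_idx = torch.sort(flat, dim=1, descending=True)
v_idx, st_idx, ed_idx = np.unravel_index(flat_idx[0].numpy(), shape=(n_videos, L, L))
# the top prediction for query 0 spans clips [st_idx[0], ed_idx[0]] in local video v_idx[0]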
Example #4
def train_epoch(model, train_loader, optimizer, opt, epoch_i, training=True):
    logger.info("use train_epoch func for training: {}".format(training))
    model.train(mode=training)
    if opt.hard_negtiave_start_epoch != -1 and epoch_i >= opt.hard_negtiave_start_epoch:
        model.set_hard_negative(True, opt.hard_pool_size)
    if opt.train_span_start_epoch != -1 and epoch_i >= opt.train_span_start_epoch:
        model.set_train_st_ed(opt.lw_st_ed)

    # init meters
    dataloading_time = AverageMeter()
    prepare_inputs_time = AverageMeter()
    model_forward_time = AverageMeter()
    model_backward_time = AverageMeter()
    loss_meters = OrderedDict(loss_st_ed=AverageMeter(),
                              loss_neg_ctx=AverageMeter(),
                              loss_neg_q=AverageMeter(),
                              loss_overall=AverageMeter())

    num_training_examples = len(train_loader)
    timer_dataloading = time.time()
    for batch_idx, batch in tqdm(enumerate(train_loader),
                                 desc="Training Iteration",
                                 total=num_training_examples):
        global_step = epoch_i * num_training_examples + batch_idx
        dataloading_time.update(time.time() - timer_dataloading)

        # continue
        timer_start = time.time()
        model_inputs = prepare_batch_inputs(batch[1], opt.device, non_blocking=opt.pin_memory)
        prepare_inputs_time.update(time.time() - timer_start)
        # logger.info("model_inputs {}"
        #             .format({k: (type(k), v.shape if isinstance(v, torch.Tensor) else v)
        #                      for k, v in model_inputs.items()}))
        # logger.info("model_inputs \n{}".format({k: (type(v), v.shape, v.dtype) for k, v in model_inputs.items()}))
        timer_start = time.time()
        loss, loss_dict = model(**model_inputs)
        model_forward_time.update(time.time() - timer_start)
        timer_start = time.time()
        if training:
            optimizer.zero_grad()
            loss.backward()
            if opt.grad_clip != -1:
                nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
            optimizer.step()
            model_backward_time.update(time.time() - timer_start)

            opt.writer.add_scalar("Train/LR", float(optimizer.param_groups[0]["lr"]), global_step)
            for k, v in loss_dict.items():
                opt.writer.add_scalar("Train/{}".format(k), v, global_step)

        for k, v in loss_dict.items():
            loss_meters[k].update(float(v))

        timer_dataloading = time.time()
        if opt.debug and batch_idx == 3:
            break

    if training:
        to_write = opt.train_log_txt_formatter.format(
            time_str=time.strftime("%Y_%m_%d_%H_%M_%S"),
            epoch=epoch_i,
            loss_str=" ".join(["{} {:.4f}".format(k, v.avg) for k, v in loss_meters.items()]))
        with open(opt.train_log_filepath, "a") as f:
            f.write(to_write)
        print("Epoch time stats:")
        print("dataloading_time: max {dataloading_time.max} "
              "min {dataloading_time.min} avg {dataloading_time.avg}\n"
              "prepare_inputs_time: max {prepare_inputs_time.max} "
              "min {prepare_inputs_time.min} avg {prepare_inputs_time.avg}\n"
              "model_forward_time: max {model_forward_time.max} "
              "min {model_forward_time.min} avg {model_forward_time.avg}\n"
              "model_backward_time: max {model_backward_time.max} "
              "min {model_backward_time.min} avg {model_backward_time.avg}\n"
              "".format(dataloading_time=dataloading_time, prepare_inputs_time=prepare_inputs_time,
                        model_forward_time=model_forward_time, model_backward_time=model_backward_time))
    else:
        for k, v in loss_meters.items():
            opt.writer.add_scalar("Eval_Loss/{}".format(k), v.avg, epoch_i)
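The loop above depends on an AverageMeter utility that is not part of this snippet; a minimal sketch compatible with the .update()/.avg/.min/.max usage shown here (the project's own implementation may differ):

class AverageMeter(object):
    """Tracks the running average, min, and max of a stream of scalar values."""
    def __init__(self):
        self.avg, self.sum, self.count = 0.0, 0.0, 0
        self.max, self.min = float("-inf"), float("inf")

    def update(self, val, n=1):
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
        self.max = max(self.max, val)
        self.min = min(self.min, val)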