Example #1
def setup_model(cfg, device=None):
    LOGGER.info("Setup model...")
    # has to be a BertConfig instance
    model_cfg = load_json(cfg.model_config)
    model_cfg = BertConfig(**model_cfg)
    # add downstream model config
    add_attr_list = [
        "num_labels", "classifier", "cls_hidden_scale", "loss_type"
    ]
    for k in add_attr_list:
        setattr(model_cfg, k, cfg[k])

    # we separate the CNN and the transformer so that each can use a different optimizer;
    # the transformer still has a CNN layer inside, used to downsample the grid.
    LOGGER.info("setup e2e model")
    model = ClipBert(model_cfg,
                     input_format=cfg.img_input_format,
                     detectron2_model_cfg=cfg.detectron2_model_cfg,
                     transformer_cls=ClipBertForSequenceClassification)
    if cfg.e2e_weights_path:
        LOGGER.info(f"Loading e2e weights from {cfg.e2e_weights_path}")
        load_state_dict_with_mismatch(model, cfg.e2e_weights_path)
    else:
        LOGGER.info(f"Loading cnn weights from {cfg.detectron2_weights_path}")
        LOGGER.info(f"Loading bert weights from {cfg.bert_weights_path}")
        model.load_separate_ckpt(cnn_weights_path=cfg.detectron2_weights_path,
                                 bert_weights_path=cfg.bert_weights_path)

    if cfg.freeze_cnn:
        model.freeze_cnn_backbone()

    model.to(device)

    LOGGER.info("Setup model done!")
    return model
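The setattr loop above forwards downstream task settings from the experiment config onto the transformer config. A minimal self-contained sketch of that pattern, with SimpleNamespace standing in for BertConfig and a plain dict standing in for cfg (both hypothetical stand-ins):

from types import SimpleNamespace

# stand-in for the BertConfig instance
model_cfg = SimpleNamespace(hidden_size=768)
# stand-in for the experiment cfg
cfg = {"num_labels": 2, "classifier": "mlp",
       "cls_hidden_scale": 2, "loss_type": "ce"}

# copy the downstream task keys onto the model config
for k in ["num_labels", "classifier", "cls_hidden_scale", "loss_type"]:
    setattr(model_cfg, k, cfg[k])

print(model_cfg.num_labels)  # 2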
Example #2
def setup_model(cfg, device=None):
    LOGGER.info("Setup model...")
    # has to be a BertConfig instance
    model_cfg = load_json(cfg.model_config)
    model_cfg = BertConfig(**model_cfg)
    # add model-specific config: pixel random sampling, used only for pre-training
    add_attr_list = [
        "pixel_random_sampling_size",
    ]
    for k in add_attr_list:
        setattr(model_cfg, k, cfg[k])
    LOGGER.info(f"model_cfg {pprint.pformat(model_cfg.to_dict())}")

    # we separate the CNN and the transformer so that each can use a different optimizer;
    # the transformer still has a CNN layer inside, used to downsample the grid.
    LOGGER.info("setup e2e model")
    model = ClipBert(model_cfg,
                     input_format=cfg.img_input_format,
                     detectron2_model_cfg=cfg.detectron2_model_cfg,
                     transformer_cls=ClipBertForPreTraining)
    if cfg.e2e_weights_path:
        LOGGER.info(f"Loading e2e weights from {cfg.e2e_weights_path}")
        load_state_dict_with_mismatch(model, cfg.e2e_weights_path)
    else:
        LOGGER.info(f"Loading cnn weights from {cfg.detectron2_weights_path}")
        LOGGER.info(f"Loading bert weights from {cfg.bert_weights_path}")
        model.load_separate_ckpt(cnn_weights_path=cfg.detectron2_weights_path,
                                 bert_weights_path=cfg.bert_weights_path)

    if cfg.freeze_cnn:
        model.freeze_cnn_backbone()
    model.to(device)

    LOGGER.info("Setup model done!")
    return model
Example #3
def start_inference(cfg):
    set_random_seed(cfg.seed)
    n_gpu = hvd.size()
    device = torch.device("cuda", hvd.local_rank())
    torch.cuda.set_device(hvd.local_rank())
    if hvd.rank() != 0:
        LOGGER.disabled = True

    inference_res_dir = join(
        cfg.output_dir,
        f"mc_results_{os.path.splitext(os.path.basename(cfg.inference_txt_db))[0]}/"
        f"step_{cfg.inference_model_step}_{cfg.inference_n_clips}_{cfg.score_agg_func}"
    )

    if hvd.rank() == 0:
        os.makedirs(inference_res_dir, exist_ok=True)
        save_json(cfg,
                  join(inference_res_dir, "raw_args.json"),
                  save_pretty=True)

    LOGGER.info("device: {} n_gpu: {}, rank: {}, "
                "16-bits training: {}".format(device, n_gpu, hvd.rank(),
                                              bool(cfg.fp16)))

    # overwrite cfg with stored_cfg,
    # but skip keys containing 'inference' or 'output_dir'
    stored_cfg_path = join(cfg.output_dir, "log/args.json")
    stored_cfg = edict(load_json(stored_cfg_path))
    for k, v in cfg.items():
        if k in stored_cfg and "inference" not in k and "output_dir" not in k:
            setattr(cfg, k, stored_cfg[k])

    # setup models
    cfg.model_config = join(cfg.output_dir, "log/model_config.json")
    e2e_weights_path = join(cfg.output_dir,
                            f"ckpt/model_step_{cfg.inference_model_step}.pt")
    cfg.e2e_weights_path = e2e_weights_path
    model = setup_model(cfg, device=device)
    model.eval()

    # FIXME separate scaling for each loss
    model = amp.initialize(model, enabled=cfg.fp16, opt_level='O2')

    global_step = 0
    # prepare data
    tokenizer = BertTokenizerFast.from_pretrained(cfg.tokenizer_dir)
    cfg.data_ratio = 1.

    val_loader = mk_msrvtt_mc_eval_dataloader(
        anno_path=cfg.inference_txt_db,
        lmdb_dir=cfg.inference_img_db,
        cfg=cfg,
        tokenizer=tokenizer,
    )

    LOGGER.info(cfg)
    LOGGER.info("Starting inference...")
    LOGGER.info(f"***** Running inference with {n_gpu} GPUs *****")
    LOGGER.info(f"  Batch size = {cfg.inference_batch_size}")

    LOGGER.info(f'Step {global_step}: start validation')
    ret_results, ret_scores = inference_retrieval_mc(model, val_loader,
                                                     cfg.inference_txt_db, cfg)

    if hvd.rank() == 0:
        save_json(cfg,
                  join(inference_res_dir, "merged_args.json"),
                  save_pretty=True)
        save_json(ret_results,
                  join(inference_res_dir, "mc_test_results.json"),
                  save_pretty=True)
        save_json(ret_scores,
                  join(inference_res_dir, "mc_test_scores.json"),
                  save_pretty=True)
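The cfg-merging step above restores the training-time settings while protecting the current run's inference options. A minimal sketch of that rule with plain dicts in place of the edict objects (all values hypothetical):

cfg = {"lr": 1e-4, "inference_batch_size": 32, "output_dir": "/new/run"}
stored_cfg = {"lr": 5e-5, "inference_batch_size": 8, "output_dir": "/old/run"}

for k in list(cfg):
    # restore training-time values, but keep inference-* keys and
    # output_dir from the current run
    if k in stored_cfg and "inference" not in k and "output_dir" not in k:
        cfg[k] = stored_cfg[k]

print(cfg)  # only "lr" is overwritten with the stored value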
Example #4
def inference_retrieval_mc(model,
                           val_loader,
                           eval_file_path,
                           cfg,
                           n_options=5):
    model.eval()
    pred_id2ans = dict()
    st = time.time()
    LOGGER.info(f"Evaluate retrieval MC: {len(val_loader)}")
    if hvd.rank() == 0:
        pbar = tqdm(total=len(val_loader), desc="eval")

    for batch in val_loader:
        # compile shared text part
        question_ids = batch["question_ids"]
        bsz = len(question_ids)
        del batch["question_ids"]
        mini_batch = dict()
        for k, v in batch.items():
            if k not in ["visual_inputs", "meta"]:
                mini_batch[k] = v
        # multi-frame test, scores across frames of the same video will be pooled together
        # batch["visual_inputs"]  (B, T, C, H, W)
        pool_method = cfg.score_agg_func
        # could be 1, where only a single clip is evaluated
        num_clips = cfg.inference_n_clips
        num_frm = cfg.num_frm
        # (B, T=num_clips*num_frm, C, H, W) --> (B, num_clips, num_frm, C, H, W)
        new_visual_shape = (bsz, num_clips,
                            num_frm) + batch["visual_inputs"].shape[2:]
        visual_inputs = batch["visual_inputs"].view(*new_visual_shape)
        logits = []
        for clip_idx in range(num_clips):
            mini_batch["visual_inputs"] = visual_inputs[:, clip_idx]
            mini_batch["n_examples_list"] = batch["n_examples_list"]
            outputs = forward_step(model, mini_batch, cfg, n_options=n_options)
            logits.append(outputs["logits"].cpu())
        logits = torch.stack(logits)  # (num_clips, B, 1 or 2)
        if pool_method == "mean":
            logits = logits.mean(0)  # (B, 1 or 2)
        elif pool_method == "max":
            logits = logits.max(0)[0]  # (B, 1 or 2)
        elif pool_method == "lse":
            logits = logits.permute(1, 0, 2).contiguous()  # (B, num_clips, 1 or 2)
            # logsumexp pooling over clips; summing torch.exp alone
            # might be too large and unstable
            logits = torch.logsumexp(logits, dim=1)
        else:
            raise ValueError(
                f"Invalid value for pool_method, "
                f"got {pool_method}, expect one of [`mean`, `max`, `lse`]")

        if logits.shape[1] == 2:
            probs = F.softmax(logits, dim=1)[:, 1]
        else:
            probs = torch.sigmoid(logits.squeeze())  # B
        probs = probs.view(-1, n_options)  # (B, 5)
        pred_answers = probs.max(1)[1].tolist()  # (B, )
        for qid, pred_ans in zip(question_ids, pred_answers):
            pred_id2ans[qid] = int(pred_ans)

        if hvd.rank() == 0:
            pbar.update(1)

    # ###### Saving with Horovod ####################
    # dummy sync
    _ = None
    all_gather_list(_)
    n_gpu = hvd.size()
    eval_dir = join(
        cfg.output_dir,
        f"results_mc_{os.path.splitext(os.path.basename(eval_file_path))[0]}")
    os.makedirs(eval_dir, exist_ok=True)
    if n_gpu > 1:
        # with retrial, as azure blob fails occasionally.
        max_save_load_trial = 10
        save_trial = 0
        while save_trial < max_save_load_trial:
            try:
                LOGGER.info(f"Save results trial NO. {save_trial}")
                save_json(
                    pred_id2ans,
                    join(eval_dir, f"tmp_results_mc_rank{hvd.rank()}.json"))
                break
            except Exception as e:
                print(f"Saving exception: {e}")
                save_trial += 1

    # dummy sync
    _ = None
    all_gather_list(_)
    # join results
    if n_gpu > 1 and hvd.rank() == 0:
        pred_id2ans = []
        for rk in range(n_gpu):
            pred_id2ans.append(
                load_json(join(eval_dir, f"tmp_results_mc_rank{rk}.json")))
        pred_id2ans = merge_dicts(pred_id2ans)
        LOGGER.info('results joined')

    if hvd.rank() == 0:
        retrieval_qa_metrics = val_loader.dataset.evaluate_qa_accuracy(
            pred_id2ans, force_same=True)
        LOGGER.info(
            f"validation finished in {int(time.time() - st)} seconds. scores: {retrieval_qa_metrics}"
        )
    else:
        retrieval_qa_metrics = None

    model.train()
    return pred_id2ans, retrieval_qa_metrics
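The mean/max/lse score aggregation over clips is the core of the multi-clip test above. A self-contained sketch of the three pooling modes with hypothetical shapes (num_clips=4, batch of 3, a 2-way head), runnable with only torch installed:

import torch

logits = torch.randn(4, 3, 2)  # (num_clips, B, 2), hypothetical values

mean_pooled = logits.mean(0)   # (B, 2)
max_pooled = logits.max(0)[0]  # (B, 2)
# move clips onto dim=1, then pool with logsumexp, which is more
# numerically stable than summing torch.exp directly
lse_pooled = torch.logsumexp(logits.permute(1, 0, 2).contiguous(), dim=1)

for name, pooled in [("mean", mean_pooled), ("max", max_pooled),
                     ("lse", lse_pooled)]:
    probs = torch.softmax(pooled, dim=1)[:, 1]  # P(positive), shape (B,)
    print(name, probs.shape)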
Example #5
def inference_retrieval(model, val_loader, eval_file_path, cfg):
    model.eval()
    retrieval_res = []  # list(dict): dict(vid_id=..., txt_id=..., score=...)
    st = time.time()
    eval_bsz = cfg.inference_batch_size if cfg.do_inference else cfg.eval_retrieval_batch_size
    LOGGER.info(f"Evaluate retrieval #video per GPU: {len(val_loader)}")
    if hvd.rank() == 0:
        pbar = tqdm(total=len(val_loader), desc="eval")

    for batch in val_loader:
        # each batch contains 1 video and N (=1000) captions
        n_mini_batches = math.ceil(len(batch["caption_ids"]) / eval_bsz)
        vid_id = batch["vid_id"]
        for idx in range(n_mini_batches):
            # compile shared text part
            mini_batch = dict()
            for k in ["text_input_ids", "text_input_mask", "labels"]:
                if batch[k] is not None:
                    mini_batch[k] = batch[k][idx * eval_bsz:(idx + 1) *
                                             eval_bsz]
                else:
                    mini_batch[k] = None
            caption_ids = batch["caption_ids"][idx * eval_bsz:(idx + 1) *
                                               eval_bsz]
            # bsz = len(caption_ids)
            mini_batch["n_examples_list"] = [len(caption_ids)]

            # multi-frame test, scores across frames of the same video will be pooled together
            pool_method = cfg.score_agg_func
            # could be 1, where only a single clip is evaluated
            num_clips = cfg.inference_n_clips
            num_frm = cfg.num_frm
            # (B, T=num_clips*num_frm, C, H, W) --> (B, num_clips, num_frm, C, H, W)
            new_visual_shape = (1, num_clips,
                                num_frm) + batch["visual_inputs"].shape[2:]
            visual_inputs = batch["visual_inputs"].view(*new_visual_shape)
            logits = []
            for clip_idx in range(num_clips):
                mini_batch["visual_inputs"] = visual_inputs[:, clip_idx]
                outputs = forward_step(model, mini_batch, cfg)
                logits.append(outputs["logits"].cpu())
            logits = torch.stack(logits)  # (num_clips, B, 1 or 2)
            if pool_method == "mean":
                logits = logits.mean(0)  # (B, 1 or 2)
            elif pool_method == "max":
                logits = logits.max(0)[0]  # (B, 1 or 2)
            elif pool_method == "lse":
                logits = logits.permute(1, 0, 2).contiguous()  # (B, num_clips, 1 or 2)
                # logsumexp pooling over clips; summing torch.exp alone
                # might be too large and unstable
                logits = torch.logsumexp(logits, dim=1)
            else:
                raise ValueError(
                    f"Invalid value for pool_method, "
                    f"got {pool_method}, expect one of [`mean`, `max`, `lse`]")

            if logits.shape[1] == 2:
                probs = F.softmax(logits, dim=1)[:, 1].tolist()
            else:
                probs = torch.sigmoid(logits.squeeze()).tolist()  # B
            for cap_id, score in zip(caption_ids, probs):
                retrieval_res.append(
                    dict(vid_id=vid_id, txt_id=cap_id, score=round(score, 4)))

        if hvd.rank() == 0:
            pbar.update(1)

    # ###### Saving with Horovod ####################
    # dummy sync
    _ = None
    all_gather_list(_)
    n_gpu = hvd.size()
    eval_dir = join(
        cfg.output_dir,
        f"results_{os.path.splitext(os.path.basename(eval_file_path))[0]}")
    os.makedirs(eval_dir, exist_ok=True)
    if n_gpu > 1:
        # with retrial, as azure blob fails occasionally.
        max_save_load_trial = 10
        save_trial = 0
        while save_trial < max_save_load_trial:
            try:
                LOGGER.info(f"Save results trial NO. {save_trial}")
                save_json(retrieval_res,
                          join(eval_dir, f"tmp_results_rank{hvd.rank()}.json"))
                break
            except Exception as e:
                print(f"Saving exception: {e}")
                save_trial += 1

    # dummy sync
    _ = None
    all_gather_list(_)
    # join results
    if n_gpu > 1 and hvd.rank() == 0:
        retrieval_res = []
        for rk in range(n_gpu):
            retrieval_res.extend(
                load_json(join(eval_dir, f"tmp_results_rank{rk}.json")))
        LOGGER.info('results joined')

    if hvd.rank() == 0:
        retrieval_metrics = eval_retrieval(retrieval_res,
                                           val_loader.dataset.gt_cap_id2vid_id,
                                           val_loader.dataset.id2data)
        LOGGER.info(
            f"validation finished in {int(time.time() - st)} seconds. scores: {retrieval_metrics}"
        )
    else:
        retrieval_metrics = None

    model.train()
    return retrieval_res, retrieval_metrics
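The per-rank shard save followed by a rank-0 merge is how results are combined here without collective tensor ops. A self-contained sketch that simulates the ranks with a loop (no Horovod required; file names mirror the tmp_results_rank{k}.json convention above):

import json
import os
import tempfile

n_gpu = 3
eval_dir = tempfile.mkdtemp()

# each "rank" writes its partial results
for rank in range(n_gpu):
    shard = [{"vid_id": f"v{rank}", "txt_id": rank, "score": 0.5}]
    with open(os.path.join(eval_dir, f"tmp_results_rank{rank}.json"), "w") as f:
        json.dump(shard, f)

# rank 0 joins the shards after the barrier
merged = []
for rank in range(n_gpu):
    with open(os.path.join(eval_dir, f"tmp_results_rank{rank}.json")) as f:
        merged.extend(json.load(f))
print(len(merged))  # 3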
Example #6
def start_inference(cfg):
    set_random_seed(cfg.seed)
    n_gpu = hvd.size()
    device = torch.device("cuda", hvd.local_rank())
    torch.cuda.set_device(hvd.local_rank())
    if hvd.rank() != 0:
        LOGGER.disabled = True

    inference_res_dir = join(
        cfg.output_dir,
        f"results_{os.path.splitext(os.path.basename(cfg.inference_txt_db))[0]}/"
        f"step_{cfg.inference_model_step}_{cfg.inference_n_clips}_{cfg.score_agg_func}"
    )

    if hvd.rank() == 0:
        os.makedirs(inference_res_dir, exist_ok=True)
        save_json(cfg, join(inference_res_dir, "raw_args.json"),
                  save_pretty=True)

    LOGGER.info("device: {} n_gpu: {}, rank: {}, "
                "16-bits training: {}".format(
                    device, n_gpu, hvd.rank(), bool(cfg.fp16)))

    # overwrite cfg with stored_cfg,
    # but skip keys containing the keyword 'inference'
    stored_cfg_path = join(cfg.output_dir, "log/args.json")
    stored_cfg = edict(load_json(stored_cfg_path))
    for k, v in cfg.items():
        if k in stored_cfg and "inference" not in k:
            setattr(cfg, k, stored_cfg[k])

    # setup models
    cfg.model_config = join(cfg.output_dir, "log/model_config.json")
    e2e_weights_path = join(
        cfg.output_dir, f"ckpt/model_step_{cfg.inference_model_step}.pt")
    cfg.e2e_weights_path = e2e_weights_path
    model = setup_model(cfg, device=device)
    model.eval()

    # FIXME separate scaling for each loss
    model = amp.initialize(
        model, enabled=cfg.fp16, opt_level='O2')

    global_step = 0
    # prepare data
    tokenizer = BertTokenizerFast.from_pretrained(cfg.tokenizer_dir)
    cfg.data_ratio = 1.
    val_loader = mk_tgif_qa_dataloader(
        task_type=cfg.task,
        anno_path=cfg.inference_txt_db,
        lmdb_dir=cfg.inference_img_db,
        cfg=cfg, tokenizer=tokenizer,
        is_train=False,
        return_label=False
    )
    img_norm = ImageNorm(mean=cfg.img_pixel_mean, std=cfg.img_pixel_std)
    val_loader = PrefetchLoader(val_loader, img_norm)

    LOGGER.info(cfg)
    LOGGER.info("Starting inference...")
    LOGGER.info(f"***** Running inference with {n_gpu} GPUs *****")
    LOGGER.info(f"  Batch size = {cfg.inference_batch_size}")

    LOGGER.info(f'Step {global_step}: start validation')
    qa_results, qa_scores = validate(
        model, val_loader, cfg, global_step,
        eval_score=True)  # cfg.inference_split == "val"

    if hvd.rank() == 0:
        save_json(cfg, join(inference_res_dir, "merged_args.json"),
                  save_pretty=True)
        save_json(qa_scores, join(inference_res_dir, "scores.json"),
                  save_pretty=True)

    # ###### Saving with Horovod ####################
    # dummy sync
    _ = None
    all_gather_list(_)
    if n_gpu > 1:
        # with retrial, as azure blob fails occasionally.
        max_save_load_trial = 10
        save_trial = 0
        while save_trial < max_save_load_trial:
            try:
                LOGGER.info(f"Save results trial NO. {save_trial}")
                save_json(
                    qa_results,
                    join(inference_res_dir, f"results_rank{hvd.rank()}.json"))
                break
            except Exception:
                save_trial += 1
    # dummy sync
    _ = None
    all_gather_list(_)
    # join results
    if n_gpu > 1 and hvd.rank() == 0:
        qa_results = []
        for rk in range(n_gpu):
            qa_results.extend(load_json(
                join(inference_res_dir, f"results_rank{rk}.json")))
        LOGGER.info('results joined')

    if hvd.rank() == 0:
        save_json(
            qa_results,
            join(inference_res_dir, "results_all.json"))
        LOGGER.info('all results written')
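The save-with-retry loop above (also used in Examples #4, #5, and #8) guards against occasional blob-storage failures. A self-contained sketch of the same loop factored into a helper (save_json_with_retry is a hypothetical name; the examples inline this logic):

import json
import logging

logging.basicConfig(level=logging.INFO)
LOGGER = logging.getLogger(__name__)

def save_json_with_retry(obj, path, max_trials=10):
    for trial in range(max_trials):
        try:
            LOGGER.info("Save results trial NO. %d", trial)
            with open(path, "w") as f:
                json.dump(obj, f)
            return True
        except OSError as e:
            LOGGER.warning("Saving exception: %s", e)
    return False

save_json_with_retry({"acc": 0.7}, "/tmp/results.json")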
Example #7
def mk_tgif_qa_dataloader(task_type, anno_path, lmdb_dir, cfg, tokenizer,
                          is_train=True, return_label=True):
    """
    Raw annotations, one dict per example:
            action and transition: {
                "gif_name": "tumblr_nk172bbdPI1u1lr18o1_250",
                "question": "What does the butterfly do 10 or more than 10 times ?",
                "options": ["stuff marshmallow", "holds a phone towards face",
                            "fall over", "talk", "flap wings"],
                "answer": 4
                }
            frameqa: {
                "gif_name": "tumblr_no73q2fm0I1uuf348o1_250",
                "question": "what is being placed in the white ice cream cone ?",
                "answer": "cookie",
                "answer_type": "object"
                }
            msrvtt_qa: {
                "answer": "couch",
                "question": "what are three people sitting on?",
                "video_id": "video6513",
                "answer_type": "what"
                }
    """
    raw_datalist = load_jsonl(anno_path)
    LOGGER.info(f"Loaded data size {len(raw_datalist)}")
    if cfg.data_ratio != 1.0:
        random.shuffle(raw_datalist)
        raw_datalist = raw_datalist[:int(len(raw_datalist) * cfg.data_ratio)]
        LOGGER.info(f"Use {100 * cfg.data_ratio}% of the loaded data: {len(raw_datalist)}")

    datalist = []
    qid = 0
    for raw_d in raw_datalist:
        d = dict(
            question=raw_d["question"],
            vid_id=raw_d["gif_name"] if "gif_name" in raw_d else raw_d["video_id"],
            answer=raw_d["answer"],  # int or str
            question_id=qid  # be careful, it is not unique across splits
        )
        qid += 1

        if task_type in ["action", "transition"]:
            d["options"] = raw_d["options"]
        elif task_type in ["frameqa", "msrvtt_qa"]:
            d["answer_type"] = raw_d["answer_type"]

        datalist.append(d)
    LOGGER.info(f"datalist {len(datalist)}")

    grouped = defaultdict(list)  # examples grouped by image/video id
    for d in datalist:
        grouped[d["vid_id"]].append(d)
    LOGGER.info(f"grouped {len(grouped)}")

    # each group has a single image with multiple questions
    group_datalist = mk_input_group(
        grouped,
        max_n_example_per_group=cfg.max_n_example_per_group if is_train else 1,  # force 1 in eval
        is_train=is_train
    )
    LOGGER.info(f"group_datalist {len(group_datalist)}")

    ans2label = load_json(cfg.ans2label_path)

    frm_sampling_strategy = cfg.frm_sampling_strategy
    if not is_train and frm_sampling_strategy == "rand":
        frm_sampling_strategy = "middle"
    dataset = ClipBertVideoQADataset(
        task_type=cfg.task,
        datalist=group_datalist,
        tokenizer=tokenizer,
        img_lmdb_dir=lmdb_dir,
        ans2label=ans2label,
        max_img_size=cfg.max_img_size,
        max_txt_len=cfg.max_txt_len,
        fps=cfg.fps,
        num_frm=cfg.num_frm,
        frm_sampling_strategy=frm_sampling_strategy,
        ensemble_n_clips=cfg.train_n_clips if is_train else cfg.inference_n_clips,
        return_label=return_label,
        is_train=is_train
    )
    LOGGER.info(f"is_train {is_train}, dataset size {len(dataset)} groups, "
                f"each group {cfg.max_n_example_per_group if is_train else 1}")
    if cfg.do_inference:
        batch_size = cfg.inference_batch_size
    else:
        batch_size = cfg.train_batch_size if is_train else cfg.val_batch_size
    sampler = DistributedSampler(
        dataset, num_replicas=hvd.size(), rank=hvd.rank(),
        shuffle=is_train)
    vqa_collator = VideoQACollator(tokenizer=tokenizer,
                                   max_length=cfg.max_txt_len,
                                   task_type=cfg.task)
    dataloader = DataLoader(dataset,
                            batch_size=batch_size,
                            shuffle=False,
                            sampler=sampler,
                            num_workers=cfg.n_workers,
                            pin_memory=cfg.pin_mem,
                            collate_fn=vqa_collator.collate_batch)
    return dataloader
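Grouping questions by video id before batching is what lets one decoded video serve several questions. A toy sketch of the grouping and the eval-time cap (records are hypothetical; the real mk_input_group may chunk oversized groups rather than truncate):

from collections import defaultdict

datalist = [
    {"vid_id": "v1", "question": "q1"},
    {"vid_id": "v1", "question": "q2"},
    {"vid_id": "v2", "question": "q3"},
]

grouped = defaultdict(list)  # examples grouped by video id
for d in datalist:
    grouped[d["vid_id"]].append(d)

# force one example per group at eval time (max_n_example_per_group=1)
max_n = 1
capped = {vid: examples[:max_n] for vid, examples in grouped.items()}
print({vid: len(ex) for vid, ex in capped.items()})  # {'v1': 1, 'v2': 1}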
Example #8
def start_inference(cfg):
    set_random_seed(cfg.seed)
    n_gpu = hvd.size()
    device = torch.device("cuda", hvd.local_rank())
    torch.cuda.set_device(hvd.local_rank())
    if hvd.rank() != 0:
        LOGGER.disabled = True

    inference_res_dir = join(
        cfg.output_dir, f"results_{cfg.inference_split}_"
        f"step_{cfg.inference_model_step}")
    if hvd.rank() == 0:
        os.makedirs(inference_res_dir, exist_ok=True)
        save_json(cfg,
                  join(inference_res_dir, "raw_args.json"),
                  save_pretty=True)

    LOGGER.info("device: {} n_gpu: {}, rank: {}, "
                "16-bits training: {}".format(device, n_gpu, hvd.rank(),
                                              bool(cfg.fp16)))

    # overwrite cfg with stored_cfg,
    # but skip keys containing 'inference' or 'output_dir'
    stored_cfg_path = join(cfg.output_dir, "log/args.json")
    stored_cfg = edict(load_json(stored_cfg_path))
    for k, v in cfg.items():
        if (k in stored_cfg and "inference" not in k and k != "output_dir"):
            value = stored_cfg[k]
            # FIXME hardcode changes
            if isinstance(value, str) and value.startswith("/data"):
                value = value.replace("/data", "/storage")
            setattr(cfg, k, value)

    # setup models
    cfg.model_config = join(cfg.output_dir, "log/model_config.json")
    cfg.detectron2_model_cfg = join(cfg.output_dir,
                                    "log/detectron2_model_cfg.yaml")
    e2e_weights_path = join(cfg.output_dir,
                            f"ckpt/model_step_{cfg.inference_model_step}.pt")
    if exists(e2e_weights_path):
        cfg.e2e_weights_path = e2e_weights_path
    else:
        cfg.bert_weights_path = join(
            f"{cfg.output_dir}/ckpt",
            f"transformer_step_{cfg.inference_model_step}.pt")
        cfg.cnn_weights_path = join(
            cfg.output_dir, f"ckpt/cnn_step_{cfg.inference_model_step}.pt")
    model = setup_model(cfg, device=device)
    model.eval()

    # FIXME separate scaling for each loss
    model = amp.initialize(model, enabled=cfg.fp16, opt_level='O2')

    global_step = 0
    # prepare data
    tokenizer = BertTokenizerFast.from_pretrained(cfg.tokenizer_dir)
    cfg.data_ratio = 1.
    val_loader = mk_vqa_dataloader(anno_path=cfg.inference_txt_db,
                                   img_lmdb_dir=cfg.inference_img_db,
                                   cfg=cfg,
                                   tokenizer=tokenizer,
                                   is_train=False)
    img_norm = ImageNorm(mean=cfg.img_pixel_mean, std=cfg.img_pixel_std)
    val_loader = PrefetchLoader(val_loader, img_norm)

    LOGGER.info(cfg)
    LOGGER.info("Starting inference...")
    LOGGER.info(f"***** Running inference with {n_gpu} GPUs *****")
    LOGGER.info(f"  Batch size = {cfg.inference_batch_size}")

    LOGGER.info(f'Step {global_step}: start validation')
    vqa_results = validate(model,
                           val_loader,
                           cfg,
                           global_step,
                           eval_score=cfg.inference_split == "val")

    if hvd.rank() == 0:
        save_json(cfg,
                  join(inference_res_dir, "merged_args.json"),
                  save_pretty=True)

    # ###### Saving with Horovod ####################
    # dummy sync
    _ = None
    all_gather_list(_)
    if n_gpu > 1:
        # with retrial, as azure blob fails occasionally.
        max_save_load_trial = 10
        save_trial = 0
        while save_trial < max_save_load_trial:
            try:
                LOGGER.info(f"Save results trial NO. {save_trial}")
                save_json(
                    vqa_results,
                    join(inference_res_dir, f"results_rank{hvd.rank()}.json"))
                break
            except Exception:
                save_trial += 1
    # dummy sync
    _ = None
    all_gather_list(_)
    # join results
    if n_gpu > 1 and hvd.rank() == 0:
        vqa_results = []
        for rk in range(n_gpu):
            vqa_results.extend(
                load_json(join(inference_res_dir, f"results_rank{rk}.json")))
        LOGGER.info('results joined')

    if hvd.rank() == 0:
        save_json(vqa_results, join(inference_res_dir, "results_all.json"))
        LOGGER.info('all results written')
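Example #8 picks the end-to-end checkpoint when it exists and otherwise falls back to separate CNN and transformer checkpoints. A minimal sketch of that selection (paths hypothetical):

import os.path

output_dir = "/tmp/run"
step = 1000

e2e_path = os.path.join(output_dir, f"ckpt/model_step_{step}.pt")
if os.path.exists(e2e_path):
    weights = {"e2e_weights_path": e2e_path}
else:
    weights = {
        "bert_weights_path": os.path.join(
            output_dir, f"ckpt/transformer_step_{step}.pt"),
        "cnn_weights_path": os.path.join(
            output_dir, f"ckpt/cnn_step_{step}.pt"),
    }
print(weights)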
Example #9
def mk_vqa_dataloader(anno_path, img_lmdb_dir, cfg, tokenizer, is_train=True):
    """
    Returns:
        list(dict), each dict is
        {
            "filepath": str,
            "txt": str,
        }
    """
    if isinstance(anno_path, str):
        raw_datalist = load_jsonl(anno_path)
    else:
        raw_datalist = flat_list_of_lists([load_jsonl(p) for p in anno_path])

    if cfg.data_ratio != 1.0:
        random.shuffle(raw_datalist)
        raw_datalist = raw_datalist[:int(len(raw_datalist) * cfg.data_ratio)]

    datalist = []
    for raw_d in raw_datalist:
        d = dict(
            txt=raw_d["question"],
            img_id=raw_d["image_id"],
            question_id=raw_d["question_id"],
        )
        if "labels" in raw_d:  # deal with test sets
            d["labels"] = raw_d["labels"]
        if "answer_type" in raw_d:
            d["answer_type"] = raw_d["answer_type"]
        datalist.append(d)

    grouped = defaultdict(list)  # examples grouped by image/video id
    for d in datalist:
        grouped[d["img_id"]].append(d)

    # each group has a single image with multiple questions
    group_datalist = mk_input_group(
        grouped,
        max_n_example_per_group=cfg.max_n_example_per_group
        if is_train else 1,  # force 1 in eval
        is_train=is_train,
        example_unique_key="question_id")

    ans2label = load_json(cfg.ans2label_path)
    dataset = ClipBertVQADataset(datalist=group_datalist,
                                 tokenizer=tokenizer,
                                 img_lmdb_dir=img_lmdb_dir,
                                 ans2label=ans2label,
                                 max_img_size=cfg.max_img_size,
                                 max_txt_len=cfg.max_txt_len)
    LOGGER.info(f"is_train {is_train}, dataset size {len(dataset)} groups, "
                f"each group {cfg.max_n_example_per_group if is_train else 1}")
    if cfg.do_inference:
        batch_size = cfg.inference_batch_size
    else:
        batch_size = cfg.train_batch_size if is_train else cfg.val_batch_size
    sampler = DistributedSampler(dataset,
                                 num_replicas=hvd.size(),
                                 rank=hvd.rank(),
                                 shuffle=is_train)
    vqa_collator = VQACollator(tokenizer=tokenizer, max_length=cfg.max_txt_len)
    dataloader = DataLoader(dataset,
                            batch_size=batch_size,
                            shuffle=False,
                            sampler=sampler,
                            num_workers=cfg.n_workers,
                            pin_memory=cfg.pin_mem,
                            collate_fn=vqa_collator.collate_batch)
    return dataloader
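DistributedSampler is what shards the dataset across Horovod workers in these loaders; passing num_replicas and rank explicitly (as done above with hvd.size() and hvd.rank()) means no initialized process group is needed. A toy sketch showing the resulting per-rank split (ToyDataset is a hypothetical stand-in for ClipBertVQADataset):

from torch.utils.data import DataLoader, Dataset, DistributedSampler

class ToyDataset(Dataset):
    def __len__(self):
        return 10
    def __getitem__(self, idx):
        return idx

dataset = ToyDataset()
for rank in range(2):
    # shuffle=False mirrors the eval setting above
    sampler = DistributedSampler(dataset, num_replicas=2, rank=rank,
                                 shuffle=False)
    loader = DataLoader(dataset, batch_size=3, shuffle=False, sampler=sampler)
    print(rank, [batch.tolist() for batch in loader])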