Example #1
 def __init__(self, opts, model, optimizer):
     if exists(f"{opts.output_dir}/log/args.json"):
         restore_opts = json.load(
             open(f'{opts.output_dir}/log/args.json', 'r'))
         with open(join(
                 opts.output_dir, 'log',
                 'restore_args.json'), 'w') as writer:
             json.dump(vars(opts), writer, indent=4)
         # assert opts == edict(restore_opts)
     # keep 2 checkpoints in case one is corrupted
     self.save_path = f'{opts.output_dir}/restore.pt'
     self.backup_path = f'{opts.output_dir}/restore_backup.pt'
     self.model = model
     self.optimizer = optimizer
     self.save_steps = int(opts.save_steps_ratio * opts.num_train_steps)
     self.amp = opts.fp16
     # since saving to or loading from azure blob fails sometimes
     self.max_save_load_trial = 10
     if exists(self.save_path) or exists(self.backup_path):
         LOGGER.info('found previous checkpoint. try to resume...')
         # with retrial, as azure blob fails occasionally.
         restore_trial = 0
         while restore_trial < self.max_save_load_trial:
             LOGGER.info(f"TrainingRestorer restore trial NO. {restore_trial}")
             try:
                 self.restore(opts)
                 break
             except Exception as e:
                 restore_trial += 1
     else:
         self.global_step = 0
Example #2
 def __init__(self, opts, **ckpt_dict):
     if exists(opts.output_dir):
         restore_opts = json.load(open(
             f'{opts.output_dir}/log/args.json', 'r'))
         assert opts == edict(restore_opts)
     # keep 2 checkpoints in case one is corrupted
     self.save_path = f'{opts.output_dir}/restore.pt'
     self.backup_path = f'{opts.output_dir}/restore_backup.pt'
     self.ckpt_dict = ckpt_dict
     self.save_steps = opts.save_steps
     self.amp = opts.fp16
     # since saving to or loading from azure blob fails sometimes
     self.max_save_load_trial = 10
     if exists(self.save_path) or exists(self.backup_path):
         LOGGER.info('found previous checkpoint. try to resume...')
         # with retrial, as azure blob fails occasionally.
         restore_trial = 0
         while restore_trial < self.max_save_load_trial:
             LOGGER.info(f"TrainingRestorer restore trial NO. {restore_trial}")
             try:
                 self.restore()
                 break
             except Exception as e:
                 restore_trial += 1
     else:
         self.global_step = 0
Example #3
    def __getitem__(self, index):
        if self.vis_format == "image":
            # one image/video with multiple examples
            vis_id, examples = self.datalist[index]
            img_array = self._load_img(vis_id)  # tensor, (T=1, C, H, W)
        else:  # video
            num_retries = 3  # skip error videos
            for _ in range(num_retries):
                vis_id, examples = self.datalist[index]
                img_array, _ = self._load_video(
                    vis_id)  # tensor, (T=num_frm, C, H, W)
                # Select a random video if the current video cannot be accessed.
                if img_array is None:
                    LOGGER.info(
                        f"Failed to load examples with video: {vis_id}. "
                        f"Will randomly sample an example as a replacement.")
                    index = random.randint(0, len(self) - 1)
                    continue
                else:
                    break
            else:
                raise RuntimeError(
                    f"Failed to fetch video after {num_retries} retries.")

        examples = [self._get_single_example(e, index) for e in examples]
        return dict(
            img=img_array,  # (T, C, H, W)
            examples=examples,
            n_examples=len(examples)  # used to create image feature copies.
        )
Example #4
def setup_dataloaders(cfg, tokenizer):
    LOGGER.info("Init. train_loader and val_loader...")
    train_loaders = {}
    for db in cfg.train_datasets:
        train_loaders[db.name] = mk_captions_pretrain_dataloader(
            dataset_name=db.name,
            vis_format=db.vis_format,
            anno_path=db.txt,
            img_lmdb_dir=db.img,
            cfg=cfg,
            tokenizer=tokenizer,
            is_train=True)
        if "ratio" in db:
            train_loaders[db.name] = (train_loaders[db.name], db.ratio)

    val_loaders = {}
    for db in cfg.val_datasets:
        val_loaders[db.name] = mk_captions_pretrain_dataloader(
            dataset_name=db.name,
            vis_format=db.vis_format,
            anno_path=db.txt,
            img_lmdb_dir=db.img,
            cfg=cfg,
            tokenizer=tokenizer,
            is_train=False)
    return train_loaders, val_loaders
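
A minimal sketch of the cfg layout that setup_dataloaders reads directly (field names inferred from the attribute accesses above; paths and values are placeholders, and the downstream dataloader factory reads additional fields not shown):

from easydict import EasyDict as edict

# Hypothetical config sketch; only fields read directly by setup_dataloaders are shown.
cfg = edict(
    train_datasets=[
        edict(name="coco_cap", vis_format="image",
              txt="/path/to/coco_cap.jsonl", img="/path/to/coco_img_lmdb",
              ratio=1.0),  # "ratio" is optional, see the check above
    ],
    val_datasets=[
        edict(name="vg_cap", vis_format="image",
              txt="/path/to/vg_cap.jsonl", img="/path/to/vg_img_lmdb"),
    ],
)
# train_loaders, val_loaders = setup_dataloaders(cfg, tokenizer)
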
Example #5
    def __init__(self):
        try:
            self.device = 'cuda' if config_model.use_cuda else 'cpu'
            LOGGER.info('using device: {}'.format(self.device))
            if self.device == 'cuda':
                os.environ["CUDA_VISIBLE_DEVICES"] = config_model.device_nums
            self.tokenizer = BertTokenizer(config_model.vocab_path)

            # dialogue model
            self.dialogue_model = GPT2LMHeadModel.from_pretrained(config_model.dialogue_model_path)
            self.dialogue_model.to(self.device)
            self.dialogue_model.eval()

            # mmi model
            self.mmi_model = GPT2LMHeadModel.from_pretrained(config_model.mmi_model_path)
            self.mmi_model.to(self.device)
            self.mmi_model.eval()

            self.max_sequence_len = config_model.max_len
            self.batch_size = config_model.batch_size
            self.repetition_penalty = config_model.repetition_penalty
            self.temperature = config_model.temperature
            self.debug = config_model.debug
            self.topk = config_model.topk
            self.topp = config_model.topp


        except Exception as e:
            LOGGER.error("FAIL INIT: {}".format(str(e)))
            traceback.print_exc()
            sys.exit(-1)
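
A hedged sketch of the config_model module read by this constructor; the field names all come from the attribute accesses above, and the values are illustrative placeholders:

# config_model.py (illustrative values only, not from the original source)
use_cuda = True
device_nums = "0"              # value assigned to CUDA_VISIBLE_DEVICES
vocab_path = "vocab/vocab.txt"
dialogue_model_path = "model/dialogue_model/"
mmi_model_path = "model/mmi_model/"
max_len = 25
batch_size = 5
repetition_penalty = 1.0
temperature = 1.0
debug = False
topk = 8
topp = 0.0
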
Example #6
def validate(model, val_loader, eval_loader, cfg, train_global_step,
             eval_filepath):
    """use eval_score=False when doing inference on test sets where answers are not available"""
    model.eval()

    loss = 0.
    n_ex = 0
    n_corrects = 0
    st = time.time()
    debug_step = 5
    for val_step, batch in enumerate(val_loader):
        # forward pass
        del batch["caption_ids"]
        outputs = forward_step(model, batch, cfg)
        targets = batch['labels']

        loss += outputs["loss"].sum().item() if isinstance(
            outputs["loss"], torch.Tensor) else 0
        n_ex += len(targets)

        if outputs["logits"].shape[1] == 2:
            n_corrects += (outputs["logits"].max(
                dim=-1)[1] == targets).sum().item()
        else:
            predictions = (torch.sigmoid(outputs["logits"]) > 0.5).long()
            predictions = predictions.view(outputs["loss"].shape[0], -1)
            targets = targets.view(outputs["loss"].shape[0], -1)
            matched = predictions[:, 0].squeeze() == targets[:, 0].squeeze()
            n_corrects += matched.sum().item()

        if cfg.debug and val_step >= debug_step:
            break

    loss = sum(all_gather_list(loss))
    n_ex = sum(all_gather_list(n_ex))
    n_corrects = sum(all_gather_list(n_corrects))

    _, retrieval_metrics = inference_retrieval(model, eval_loader,
                                               eval_filepath, cfg)

    model.train()

    if hvd.rank() == 0:
        # average loss for each example
        acc = float(n_corrects / n_ex)
        val_log = {'valid/loss': float(loss / n_ex), 'valid/acc': acc}
        for ret_type, ret_m in retrieval_metrics.items():
            val_log.update({
                f"valid/{ret_type}_{k}": round(v, 4)
                for k, v in ret_m.items()
            })

        TB_LOGGER.log_scalar_dict(val_log)
        LOGGER.info(f"validation finished in {int(time.time() - st)} seconds. "
                    f"itm_acc: {acc}. Retrieval res {retrieval_metrics}")
Example #7
 def restore(self):
     try:
         checkpoint = torch.load(self.save_path)
     except Exception:
         checkpoint = torch.load(self.backup_path)
     self.global_step = checkpoint['global_step']
     for k in self.ckpt_dict:
         self.ckpt_dict[k].load_state_dict(_to_cuda(checkpoint[k]))
     if self.amp:
         amp.load_state_dict(checkpoint['amp_state_dict'])
     LOGGER.info(f'resume training from step {self.global_step}')
Example #8
 def step(self):
     self.global_step += 1
     if self.global_step % self.save_steps == 0:
         # with retrial, as azure blob fails occasionally.
         save_trial = 0
         while save_trial < self.max_save_load_trial:
             LOGGER.info(f"TrainingRestorer save trial NO. {save_trial}")
             try:
                 self.save()
                 break
             except Exception as e:
                 save_trial += 1
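
The same retry-on-exception loop appears in restore and save above and in the result-saving code further below; a minimal standalone sketch of the pattern (not part of the original codebase; unlike the original loops, it raises after the last failed trial):

def with_retries(fn, max_trials=10, logger=None):
    """Call fn(), retrying up to max_trials times on any exception."""
    for trial in range(max_trials):
        try:
            return fn()
        except Exception as e:
            if logger is not None:
                logger.info(f"retry trial NO. {trial} failed: {e}")
    raise RuntimeError(f"still failing after {max_trials} trials")

# e.g. with_retries(restorer.save, max_trials=restorer.max_save_load_trial, logger=LOGGER)
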
Example #9
 def restore(self, opts):
     try:
         checkpoint = torch.load(self.save_path)
     except Exception:
         checkpoint = torch.load(self.backup_path)
     self.global_step = checkpoint['global_step']
     self.model.load_state_dict(_to_cuda(checkpoint['model_state_dict']))
     self.optimizer.load_state_dict(
         _to_cuda(checkpoint['optim_state_dict']))
     if self.amp:
         amp.load_state_dict(checkpoint['amp_state_dict'])
     LOGGER.info(f'resume training from step {self.global_step}')
Example #10
def main(args):

    LOGGER.info('start')

    app = tornado.web.Application(
            handlers=urls,
            debug=False)
    http_server = tornado.httpserver.HTTPServer(app)
    http_server.bind(args.port)
    http_server.start(args.threads)

    #http_server.listen(options.port)
    print("START")
    tornado.ioloop.IOLoop.instance().start()
    print("INSTANT START")
Example #11
def setup_dataloaders(cfg, tokenizer):
    LOGGER.info("Init. train_loader and val_loader...")
    train_loader = mk_video_ret_dataloader(anno_path=cfg.train_datasets[0].txt,
                                           lmdb_dir=cfg.train_datasets[0].img,
                                           cfg=cfg,
                                           tokenizer=tokenizer,
                                           is_train=True)
    val_loader = mk_video_ret_dataloader(anno_path=cfg.val_datasets[0].txt,
                                         lmdb_dir=cfg.val_datasets[0].img,
                                         cfg=cfg,
                                         tokenizer=tokenizer,
                                         is_train=False)
    img_norm = ImageNorm(mean=cfg.img_pixel_mean, std=cfg.img_pixel_std)
    train_loader = PrefetchLoader(train_loader, img_norm)
    val_loader = PrefetchLoader(val_loader, img_norm)
    return train_loader, val_loader
Example #12
    def post(self):
        response = {'status': 0, 'data': {}, 'message': 'fail'}
        try:
            session_id = self.get_argument("sessionId")
            input_text = self.get_argument("text")
        except Exception as e:
            LOGGER.error("FAIL receive args: {}".format(str(e)))
            response['message'] = str(e)
            self.finish(response)
            return

        try:
            st = time.time()
            session_id = int(session_id)
            keeper_partition = session_id % config_instance.num_keepers
            keepers[keeper_partition].update_history(session_id=session_id,
                                                     new_input_text=input_text)
            history = keepers[keeper_partition].get_history(
                session_id=session_id)
            generate_chars = worker.generate(history)
            print(generate_chars)
            if len(generate_chars) == 0:
                response['message'] = "fail generate response text"
                self.finish(response)
                return
            generate = "".join(generate_chars)
            keepers[keeper_partition].update_history(session_id=session_id,
                                                     new_input_text=generate)
            body_info = {
                'sessionId': session_id,
                'input': input_text,
                'output': generate
            }
            print(body_info)
            LOGGER.info(
                "receive: session_id: {}, input_text: {}, back: {}, cost: {} ms"
                .format(str(session_id), input_text, json.dumps(body_info),
                        (time.time() - st) * 1000))
            response['data'] = body_info
            response['status'] = 1
            response['message'] = 'success'
            self.finish(response)

        except Exception as e:
            LOGGER.error("FAIL make response: {}".format(str(e)))
            response['message'] = str(e)
            self.finish(response)
        return
Example #13
 def _get_random_negative_caption(self, gt_index):
     gt_img_id, _ = self.datalist[gt_index]
     max_trials = 5
     while max_trials > 0:
         neg_index = int(random.random() * len(self))
         neg_img_id, neg_examples = self.datalist[neg_index]
         if neg_img_id == gt_img_id:
             max_trials -= 1
             continue
         else:
             break
     if max_trials == 0:
         LOGGER.info(f"gt_filepath {gt_img_id} index {gt_index}, "
                     f"neg_data filepath {neg_examples} index {neg_index}")
         raise Warning(
             f"The negative sampler cannot sample a true negative within 5 trials"
         )
     neg_data = neg_examples[int(random.random() * len(neg_examples))]
     return neg_data["txt"]
Example #14
 def save(self, step, model, optimizer=None, prefix="model"):
     model_path = join(self.output_dir, f"{prefix}_step_{step}.pt")
     state_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v
                   for k, v in model.state_dict().items()}
     # with retrial, as azure blob fails occasionally.
     save_trial = 0
     while save_trial < self.max_save_load_trial:
         try:
             LOGGER.info(f"ModelSaver save trial NO. {save_trial}")
             torch.save(state_dict, model_path)
             if optimizer is not None:
                 optimizer_state_dict = \
                     {k: v.cpu() if isinstance(v, torch.Tensor) else v
                      for k, v in optimizer.state_dict().items()}
                 dump = {'step': step, 'optimizer': optimizer_state_dict}
                 torch.save(
                     dump,
                     f'{self.output_dir}/{prefix}_step_{step}_train_state.pt')
             break
         except Exception as e:
             save_trial += 1
Example #15
def mk_msrvtt_mc_datalist(raw_datalist, cfg):
    """
    Args:
        raw_datalist: list(dict)
        cfg: config object; only cfg.data_ratio is used here.

    Returns:
        list(dict), each dict with keys `id`, `vid_id`, `answer`, `options`.
    """
    LOGGER.info(f"Loaded data size {len(raw_datalist)}")
    if cfg.data_ratio != 1.0:
        random.shuffle(raw_datalist)
        raw_datalist = raw_datalist[:int(len(raw_datalist) * cfg.data_ratio)]
        LOGGER.info(
            f"Use {100 * cfg.data_ratio}% of the loaded data: {len(raw_datalist)}"
        )

    datalist = []
    for raw_d in raw_datalist:
        d = dict(
            id=raw_d["qid"],
            vid_id=raw_d["clip_name"],
            answer=raw_d["answer"],
            options=raw_d["options"],
        )
        datalist.append(d)
    LOGGER.info(f"datalist {len(datalist)}")
    return datalist
Example #16
def mk_video_ret_datalist(raw_datalist, cfg):
    """
    Args:
        raw_datalist: list(dict)
        cfg: config object; only cfg.data_ratio is used here.

    Returns:
        list(dict), each dict with keys `id`, `txt`, `vid_id`.
    """
    LOGGER.info(f"Loaded data size {len(raw_datalist)}")
    if cfg.data_ratio != 1.0:
        random.shuffle(raw_datalist)
        raw_datalist = raw_datalist[:int(len(raw_datalist) * cfg.data_ratio)]
        LOGGER.info(
            f"Use {100 * cfg.data_ratio}% of the loaded data: {len(raw_datalist)}"
        )

    datalist = []
    qid = 0
    for raw_d in raw_datalist:
        d = dict(id=qid, txt=raw_d["caption"], vid_id=raw_d["clip_name"])
        qid += 1
        datalist.append(d)
    LOGGER.info(f"datalist {len(datalist)}")
    return datalist
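
A hedged input/output sketch for mk_video_ret_datalist, assuming the function and LOGGER from the example above are in scope; the field names come from the code, the values are made up:

from easydict import EasyDict as edict

raw_datalist = [
    {"caption": "a man is cooking", "clip_name": "video0"},
    {"caption": "a cat plays piano", "clip_name": "video1"},
]
cfg = edict(data_ratio=1.0)  # only cfg.data_ratio is read here
datalist = mk_video_ret_datalist(raw_datalist, cfg)
# -> [{"id": 0, "txt": "a man is cooking", "vid_id": "video0"},
#     {"id": 1, "txt": "a cat plays piano", "vid_id": "video1"}]
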
Example #17
def compare_dict_difference(dict1, dict2, dict1_name="dict1",
                            dict2_name="dict2",
                            print_value_diff=True, verbose=False):
    """
    Args:
        dict1:
        dict2:
        dict1_name:
        dict2_name:
        print_value_diff: bool, output dict value difference within shared keys
            for dict1 and dict2. In effect only when verbose == True
        verbose:
    """
    keys1 = set(dict1.keys())
    keys2 = set(dict2.keys())
    shared_keys = keys1.intersection(keys2)
    keys1_unique = keys1.difference(shared_keys)
    keys2_unique = keys2.difference(shared_keys)
    key_diff_list = list(keys1_unique) + list(keys2_unique)

    # value difference in the shared keys in dict1 and dict2
    value_diff_dict = {}
    for k in shared_keys:
        if dict1[k] != dict2[k]:
            value_diff_dict[k] = [(dict1_name, dict1[k]), (dict2_name, dict2[k])]

    if verbose:
        LOGGER.info("=" * 30 + "key difference")
        LOGGER.info(f"keys in {dict1_name} but not in {dict2_name}: "
                    f"total {len(keys1_unique)}, {sorted(keys1_unique)}")
        LOGGER.info(f"keys in {dict2_name} but not in {dict1_name}: "
                    f"total {len(keys2_unique)}, {sorted(keys2_unique)}")

    if verbose and print_value_diff:

        LOGGER.info("=" * 30 + "value difference")
        LOGGER.info(f"{json.dumps(value_diff_dict, indent=4)}")

    return value_diff_dict, key_diff_list
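
A small usage sketch for compare_dict_difference; the dictionaries and values are illustrative, not from the original source:

stored_cfg = {"lr": 1e-4, "batch_size": 32, "fp16": True}
current_cfg = {"lr": 5e-5, "batch_size": 32, "num_frm": 2}
value_diff, key_diff = compare_dict_difference(
    stored_cfg, current_cfg,
    dict1_name="stored_cfg", dict2_name="current_cfg", verbose=False)
# value_diff == {"lr": [("stored_cfg", 0.0001), ("current_cfg", 5e-05)]}
# key_diff contains "fp16" and "num_frm" (keys unique to one of the dicts)
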
Example #18
def mk_video_ret_dataloader(anno_path,
                            lmdb_dir,
                            cfg,
                            tokenizer,
                            is_train=True):
    """"""
    raw_datalist = load_jsonl(anno_path)
    datalist = mk_video_ret_datalist(raw_datalist, cfg)
    grouped = defaultdict(list)  # examples grouped by image/video id
    for d in datalist:
        grouped[d["vid_id"]].append(d)
    LOGGER.info(f"grouped {len(grouped)}")

    # each group has a single image with multiple questions
    group_datalist = mk_input_group(
        grouped,
        max_n_example_per_group=cfg.max_n_example_per_group
        if is_train else 1,  # force 1 in eval,
        is_train=is_train)
    LOGGER.info(f"group_datalist {len(group_datalist)}")

    frm_sampling_strategy = cfg.frm_sampling_strategy
    if not is_train and frm_sampling_strategy == "rand":
        frm_sampling_strategy = "middle"
    dataset = ClipBertVideoRetrievalDataset(
        datalist=group_datalist,
        tokenizer=tokenizer,
        img_lmdb_dir=lmdb_dir,
        max_img_size=cfg.max_img_size,
        max_txt_len=cfg.max_txt_len,
        fps=cfg.fps,
        num_frm=cfg.num_frm,
        frm_sampling_strategy=frm_sampling_strategy,
        itm_neg_size=cfg.itm_neg_size,
        ensemble_n_clips=cfg.train_n_clips,
        random_sample_clips=cfg.random_sample_clips)
    LOGGER.info(f"is_train {is_train}, dataset size {len(dataset)} groups, "
                f"each group {cfg.max_n_example_per_group if is_train else 1}")
    if cfg.do_inference:
        batch_size = cfg.inference_batch_size
    else:
        batch_size = cfg.train_batch_size if is_train else cfg.val_batch_size
    sampler = DistributedSampler(dataset,
                                 num_replicas=hvd.size(),
                                 rank=hvd.rank(),
                                 shuffle=is_train)
    vqa_collator = VideoRetrievalCollator(tokenizer=tokenizer,
                                          max_length=cfg.max_txt_len)
    dataloader = DataLoader(dataset,
                            batch_size=batch_size,
                            shuffle=False,
                            sampler=sampler,
                            num_workers=cfg.n_workers,
                            pin_memory=cfg.pin_mem,
                            collate_fn=vqa_collator.collate_batch)
    return dataloader
Example #19
def save_training_meta(args):
    # args is an EasyDict object, treat it the same as a normal dict
    os.makedirs(join(args.output_dir, 'log'), exist_ok=True)
    os.makedirs(join(args.output_dir, 'ckpt'), exist_ok=True)

    # training args
    save_args_path = join(args.output_dir, 'log', 'args.json')
    save_json(args, save_args_path, save_pretty=True)

    # model args
    model_config = json.load(open(args.model_config))
    save_model_config_path = join(args.output_dir, 'log', 'model_config.json')
    save_json(model_config, save_model_config_path, save_pretty=True)

    # save a copy of the codebase. NOTE: do not store heavy files in the codebase when using this.
    code_dir = dirname(dirname(dirname(os.path.realpath(__file__))))
    code_zip_filename = os.path.join(args.output_dir, "code.zip")
    LOGGER.info(f"Saving code from {code_dir} to {code_zip_filename}...")
    make_zipfile(code_dir, code_zip_filename,
                 enclosing_dir="code",
                 exclude_dirs_substring="results",
                 exclude_dirs=["results", "debug_results", "__pycache__"],
                 exclude_extensions=[".pyc", ".ipynb", ".swap"])
    LOGGER.info(f"Saving code done.")
Example #20
def start_inference(cfg):
    set_random_seed(cfg.seed)
    n_gpu = hvd.size()
    device = torch.device("cuda", hvd.local_rank())
    torch.cuda.set_device(hvd.local_rank())
    if hvd.rank() != 0:
        LOGGER.disabled = True

    inference_res_dir = join(
        cfg.output_dir,
        f"mc_results_{os.path.splitext(os.path.basename(cfg.inference_txt_db))[0]}/"
        f"step_{cfg.inference_model_step}_{cfg.inference_n_clips}_{cfg.score_agg_func}"
    )

    if hvd.rank() == 0:
        os.makedirs(inference_res_dir, exist_ok=True)
        save_json(cfg,
                  join(inference_res_dir, "raw_args.json"),
                  save_pretty=True)

    LOGGER.info("device: {} n_gpu: {}, rank: {}, "
                "16-bits training: {}".format(device, n_gpu, hvd.rank(),
                                              bool(cfg.fp16)))

    # overwrite cfg with stored_cfg,
    # but skip keys containing the keyword 'inference'
    stored_cfg_path = join(cfg.output_dir, "log/args.json")
    stored_cfg = edict(load_json(stored_cfg_path))
    for k, v in cfg.items():
        if k in stored_cfg and "inference" not in k and "output_dir" not in k:
            setattr(cfg, k, stored_cfg[k])

    # setup models
    cfg.model_config = join(cfg.output_dir, "log/model_config.json")
    e2e_weights_path = join(cfg.output_dir,
                            f"ckpt/model_step_{cfg.inference_model_step}.pt")
    cfg.e2e_weights_path = e2e_weights_path
    model = setup_model(cfg, device=device)
    model.eval()

    # FIXME separate scaling for each loss
    model = amp.initialize(model, enabled=cfg.fp16, opt_level='O2')

    global_step = 0
    # prepare data
    tokenizer = BertTokenizerFast.from_pretrained(cfg.tokenizer_dir)
    cfg.data_ratio = 1.

    val_loader = mk_msrvtt_mc_eval_dataloader(
        anno_path=cfg.inference_txt_db,
        lmdb_dir=cfg.inference_img_db,
        cfg=cfg,
        tokenizer=tokenizer,
    )

    LOGGER.info(cfg)
    LOGGER.info("Starting inference...")
    LOGGER.info(f"***** Running inference with {n_gpu} GPUs *****")
    LOGGER.info(f"  Batch size = {cfg.inference_batch_size}")

    LOGGER.info(f'Step {global_step}: start validation')
    ret_results, ret_scores = inference_retrieval_mc(model, val_loader,
                                                     cfg.inference_txt_db, cfg)

    if hvd.rank() == 0:
        save_json(cfg,
                  join(inference_res_dir, "merged_args.json"),
                  save_pretty=True)
        save_json(ret_results,
                  join(inference_res_dir, "mc_test_results.json"),
                  save_pretty=True)
        save_json(ret_scores,
                  join(inference_res_dir, "mc_test_scores.json"),
                  save_pretty=True)
Example #21
def inference_retrieval_mc(model,
                           val_loader,
                           eval_file_path,
                           cfg,
                           n_options=5):
    model.eval()
    pred_id2ans = dict()
    st = time.time()
    LOGGER.info(f"Evaluate retrieval MC: {len(val_loader)}")
    if hvd.rank() == 0:
        pbar = tqdm(total=len(val_loader), desc="eval")

    for batch in val_loader:
        # compile shared text part
        question_ids = batch["question_ids"]
        bsz = len(question_ids)
        del batch["question_ids"]
        mini_batch = dict()
        for k, v in batch.items():
            if k not in ["visual_inputs", "meta"]:
                mini_batch[k] = v
        # multi-frame test, scores across frames of the same video will be pooled together
        # batch["visual_inputs"]  (B, T, C, H, W)
        pool_method = cfg.score_agg_func
        # could be 1, where only a single clip is evaluated
        num_clips = cfg.inference_n_clips
        num_frm = cfg.num_frm
        # (B, T=num_clips*num_frm, C, H, W) --> (B, num_clips, num_frm, C, H, W)
        new_visual_shape = (bsz, num_clips,
                            num_frm) + batch["visual_inputs"].shape[2:]
        visual_inputs = batch["visual_inputs"].view(*new_visual_shape)
        logits = []
        for clip_idx in range(num_clips):
            mini_batch["visual_inputs"] = visual_inputs[:, clip_idx]
            mini_batch["n_examples_list"] = batch["n_examples_list"]
            outputs = forward_step(model, mini_batch, cfg, n_options=n_options)
            logits.append(outputs["logits"].cpu())
        logits = torch.stack(logits)  # (num_frm, B, 1 or 2)
        if pool_method == "mean":
            logits = logits.mean(0)  # (B, 1 or 2)
        elif pool_method == "max":
            logits = logits.max(0)[0]  # (B, 1 or 2)
        elif pool_method == "lse":
            logits = logits.permute(
                1, 0,
                2).contiguous()  # (B, num_frm, 5), pooling will be done in CE
            logits = torch.logsumexp(
                logits,
                dim=1)  # torch.exp alone might be too large and unstable
        else:
            raise ValueError(
                f"Invalid value for pool_method, "
                f"got {pool_method}, expect one of [`mean`, `max`, `lse`]")

        if logits.shape[1] == 2:
            probs = F.softmax(logits, dim=1)[:, 1]
        else:
            probs = torch.sigmoid(logits.squeeze())  # B
        probs = probs.view(-1, n_options)  # (B, 5)
        pred_answers = probs.max(1)[1].tolist()  # (B, )
        for qid, pred_ans in zip(question_ids, pred_answers):
            pred_id2ans[qid] = int(pred_ans)

        if hvd.rank() == 0:
            pbar.update(1)

    # ###### Saving with Horovod ####################
    # dummy sync
    _ = None
    all_gather_list(_)
    n_gpu = hvd.size()
    eval_dir = join(
        cfg.output_dir,
        f"results_mc_{os.path.splitext(os.path.basename(eval_file_path))[0]}")
    os.makedirs(eval_dir, exist_ok=True)
    if n_gpu > 1:
        # with retrial, as azure blob fails occasionally.
        max_save_load_trial = 10
        save_trial = 0
        while save_trial < max_save_load_trial:
            try:
                LOGGER.info(f"Save results trial NO. {save_trial}")
                save_json(
                    pred_id2ans,
                    join(eval_dir, f"tmp_results_mc_rank{hvd.rank()}.json"))
                break
            except Exception as e:
                print(f"Saving exception: {e}")
                save_trial += 1

    # dummy sync
    _ = None
    all_gather_list(_)
    # join results
    if n_gpu > 1 and hvd.rank() == 0:
        pred_id2ans = []
        for rk in range(n_gpu):
            pred_id2ans.append(
                load_json(join(eval_dir, f"tmp_results_mc_rank{rk}.json")))
        pred_id2ans = merge_dicts(pred_id2ans)
        LOGGER.info('results joined')

    if hvd.rank() == 0:
        retrieval_qa_metrics = val_loader.dataset.evaluate_qa_accuracy(
            pred_id2ans, force_same=True)
        LOGGER.info(
            f"validation finished in {int(time.time() - st)} seconds. scores: {retrieval_qa_metrics}"
        )
    else:
        retrieval_qa_metrics = None

    model.train()
    return pred_id2ans, retrieval_qa_metrics
Example #22
def setup_model(cfg, device=None):
    LOGGER.info("Setup model...")
    # has to be a BertConfig instance
    model_cfg = load_json(cfg.model_config)
    model_cfg = BertConfig(**model_cfg)
    # add downstream model config
    add_attr_list = [
        "num_labels",
        "classifier",
        "cls_hidden_scale",
        "loss_type",
        "margin",
    ]
    for k in add_attr_list:
        setattr(model_cfg, k, cfg[k])
    transformer_model_cls = ClipBertForVideoTextRetrieval

    # we separate the CNN and the transformer in order to use a different optimizer for each
    # the transformer still has a CNN layer inside, used to downsample the grid.
    LOGGER.info("setup e2e model")
    model = ClipBert(model_cfg,
                     input_format=cfg.img_input_format,
                     detectron2_model_cfg=cfg.detectron2_model_cfg,
                     transformer_cls=transformer_model_cls)
    if cfg.e2e_weights_path:
        LOGGER.info(f"Loading e2e weights from {cfg.e2e_weights_path}")
        load_state_dict_with_mismatch(model, cfg.e2e_weights_path)
    else:
        LOGGER.info(f"Loading cnn weights from {cfg.detectron2_weights_path}")
        LOGGER.info(f"Loading bert weights from {cfg.bert_weights_path}")
        model.load_separate_ckpt(cnn_weights_path=cfg.detectron2_weights_path,
                                 bert_weights_path=cfg.bert_weights_path)
    if cfg.freeze_cnn:
        model.freeze_cnn_backbone()
    model.to(device)

    LOGGER.info("Setup model done!")
    return model
Example #23
def inference_retrieval(model, val_loader, eval_file_path, cfg):
    model.eval()
    retrieval_res = []  # list(dict): dict(vid_id=..., txt_id=..., score=...)
    st = time.time()
    eval_bsz = cfg.inference_batch_size if cfg.do_inference else cfg.eval_retrieval_batch_size
    LOGGER.info(f"Evaluate retrieval #video per GPU: {len(val_loader)}")
    if hvd.rank() == 0:
        pbar = tqdm(total=len(val_loader), desc="eval")

    for batch in val_loader:
        # each batch contains 1 video and N (=1000) captions
        n_mini_batches = math.ceil(len(batch["caption_ids"]) / eval_bsz)
        vid_id = batch["vid_id"]
        for idx in range(n_mini_batches):
            # compile shared text part
            mini_batch = dict()
            for k in ["text_input_ids", "text_input_mask", "labels"]:
                if batch[k] is not None:
                    mini_batch[k] = batch[k][idx * eval_bsz:(idx + 1) *
                                             eval_bsz]
                else:
                    mini_batch[k] = None
            caption_ids = batch["caption_ids"][idx * eval_bsz:(idx + 1) *
                                               eval_bsz]
            # bsz = len(caption_ids)
            mini_batch["n_examples_list"] = [len(caption_ids)]

            # multi-frame test, scores across frames of the same video will be pooled together
            pool_method = cfg.score_agg_func
            # could be 1, where only a single clip is evaluated
            num_clips = cfg.inference_n_clips
            num_frm = cfg.num_frm
            # (B, T=num_clips*num_frm, C, H, W) --> (B, num_clips, num_frm, C, H, W)
            new_visual_shape = (1, num_clips,
                                num_frm) + batch["visual_inputs"].shape[2:]
            visual_inputs = batch["visual_inputs"].view(*new_visual_shape)
            logits = []
            for clip_idx in range(num_clips):
                mini_batch["visual_inputs"] = visual_inputs[:, clip_idx]
                outputs = forward_step(model, mini_batch, cfg)
                logits.append(outputs["logits"].cpu())
            logits = torch.stack(logits)  # (num_frm, B, 1 or 2)
            if pool_method == "mean":
                logits = logits.mean(0)  # (B, 1 or 2)
            elif pool_method == "max":
                logits = logits.max(0)[0]  # (B, 1 or 2)
            elif pool_method == "lse":
                logits = logits.permute(1, 0, 2).contiguous(
                )  # (B, num_frm, 5), pooling will be done in CE
                logits = torch.logsumexp(
                    logits,
                    dim=1)  # torch.exp alone might be too large and unstable
            else:
                raise ValueError(
                    f"Invalid value for pool_method, "
                    f"got {pool_method}, expect one of [`mean`, `max`, `lse`]")

            if logits.shape[1] == 2:
                probs = F.softmax(logits, dim=1)[:, 1].tolist()
            else:
                probs = torch.sigmoid(logits.squeeze()).tolist()  # B
            for cap_id, score in zip(caption_ids, probs):
                retrieval_res.append(
                    dict(vid_id=vid_id, txt_id=cap_id, score=round(score, 4)))

        if hvd.rank() == 0:
            pbar.update(1)

    # ###### Saving with Horovod ####################
    # dummy sync
    _ = None
    all_gather_list(_)
    n_gpu = hvd.size()
    eval_dir = join(
        cfg.output_dir,
        f"results_{os.path.splitext(os.path.basename(eval_file_path))[0]}")
    os.makedirs(eval_dir, exist_ok=True)
    if n_gpu > 1:
        # with retrial, as azure blob fails occasionally.
        max_save_load_trial = 10
        save_trial = 0
        while save_trial < max_save_load_trial:
            try:
                LOGGER.info(f"Save results trial NO. {save_trial}")
                save_json(retrieval_res,
                          join(eval_dir, f"tmp_results_rank{hvd.rank()}.json"))
                break
            except Exception as e:
                print(f"Saving exception: {e}")
                save_trial += 1

    # dummy sync
    _ = None
    all_gather_list(_)
    # join results
    if n_gpu > 1 and hvd.rank() == 0:
        retrieval_res = []
        for rk in range(n_gpu):
            retrieval_res.extend(
                load_json(join(eval_dir, f"tmp_results_rank{rk}.json")))
        LOGGER.info('results joined')

    if hvd.rank() == 0:
        retrieval_metrics = eval_retrieval(retrieval_res,
                                           val_loader.dataset.gt_cap_id2vid_id,
                                           val_loader.dataset.id2data)
        LOGGER.info(
            f"validation finished in {int(time.time() - st)} seconds. scores: {retrieval_metrics}"
        )
    else:
        retrieval_metrics = None

    model.train()
    return retrieval_res, retrieval_metrics
Example #24
def start_training(cfg):
    set_random_seed(cfg.seed)

    n_gpu = hvd.size()
    cfg.n_gpu = n_gpu
    device = torch.device("cuda", hvd.local_rank())
    torch.cuda.set_device(hvd.local_rank())
    if hvd.rank() != 0:
        LOGGER.disabled = True
    LOGGER.info("device: {} n_gpu: {}, rank: {}, "
                "16-bits training: {}".format(device, n_gpu, hvd.rank(),
                                              bool(cfg.fp16)))

    model = setup_model(cfg, device=device)
    model.train()
    optimizer = setup_e2e_optimizer(model, cfg)

    # Horovod: (optional) compression algorithm.
    compression = hvd.Compression.none
    optimizer = hvd.DistributedOptimizer(
        optimizer,
        named_parameters=model.named_parameters(),
        compression=compression)

    #  Horovod: broadcast parameters & optimizer state.
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)
    hvd.broadcast_optimizer_state(optimizer, root_rank=0)

    model, optimizer = amp.initialize(model,
                                      optimizer,
                                      enabled=cfg.fp16,
                                      opt_level='O2',
                                      keep_batchnorm_fp32=True)

    # prepare data
    tokenizer = BertTokenizerFast.from_pretrained(cfg.tokenizer_dir)
    train_loader, val_loader = setup_dataloaders(cfg, tokenizer)
    eval_loader = mk_video_ret_eval_dataloader(
        anno_path=cfg.val_datasets[0].txt,
        lmdb_dir=cfg.val_datasets[0].img,
        cfg=cfg,
        tokenizer=tokenizer,
    )

    # compute the number of steps and update cfg
    total_n_examples = len(train_loader.dataset) * cfg.max_n_example_per_group
    total_train_batch_size = int(n_gpu * cfg.train_batch_size *
                                 cfg.gradient_accumulation_steps *
                                 cfg.max_n_example_per_group)
    cfg.num_train_steps = int(
        math.ceil(1. * cfg.num_train_epochs * total_n_examples /
                  total_train_batch_size))

    cfg.valid_steps = int(
        math.ceil(1. * cfg.num_train_steps / cfg.num_valid /
                  cfg.min_valid_steps)) * cfg.min_valid_steps
    actual_num_valid = int(
        math.floor(1. * cfg.num_train_steps / cfg.valid_steps)) + 1

    # restore
    restorer = TrainingRestorer(cfg, model, optimizer)
    global_step = restorer.global_step
    TB_LOGGER.global_step = global_step
    if hvd.rank() == 0:
        LOGGER.info("Saving training meta...")
        save_training_meta(cfg)
        path = join(cfg.output_dir, 'log', "detectron2_model_cfg.yaml")
        with open(path, "w") as f:
            f.write(model.cnn.config_file)
        LOGGER.info("Saving training done...")
        TB_LOGGER.create(join(cfg.output_dir, 'log'))
        pbar = tqdm(total=cfg.num_train_steps)
        model_saver = ModelSaver(join(cfg.output_dir, "ckpt"))
        add_log_to_file(join(cfg.output_dir, "log", "log.txt"))
    else:
        LOGGER.disabled = True
        pbar = NoOp()
        model_saver = NoOp()
        restorer = NoOp()

    if global_step > 0:
        pbar.update(global_step)

    LOGGER.info(cfg)
    LOGGER.info("Starting training...")
    LOGGER.info(f"***** Running training with {n_gpu} GPUs *****")
    LOGGER.info(
        f"  Single-GPU Non-Accumulated batch size = {cfg.train_batch_size}")
    LOGGER.info(f"  max_n_example_per_group = {cfg.max_n_example_per_group}")
    LOGGER.info(f"  Accumulate steps = {cfg.gradient_accumulation_steps}")
    LOGGER.info(
        f"  Total batch size = #GPUs * Single-GPU batch size * "
        f"max_n_example_per_group * Accumulate steps [Image] = {total_train_batch_size}"
    )
    LOGGER.info(f"  Total #epochs = {cfg.num_train_epochs}")
    LOGGER.info(f"  Total #steps = {cfg.num_train_steps}")
    LOGGER.info(
        f"  Validate every {cfg.valid_steps} steps, in total {actual_num_valid} times"
    )

    # quick hack for amp delay_unscale bug
    with optimizer.skip_synchronize():
        optimizer.zero_grad()
        if global_step == 0:
            optimizer.step()
    debug_step = 3
    running_loss = RunningMeter('train_loss')

    for step, batch in enumerate(InfiniteIterator(train_loader)):
        # forward pass
        del batch["caption_ids"]
        mini_batch = dict()
        for k, v in batch.items():
            if k != "visual_inputs":
                mini_batch[k] = v

        pool_method = cfg.score_agg_func
        # could be 1, where only a single clip is used
        num_clips = cfg.train_n_clips
        num_frm = cfg.num_frm
        # (B, T=num_clips*num_frm, C, H, W) --> (B, num_clips, num_frm, C, H, W)
        bsz = batch["visual_inputs"].shape[0]
        new_visual_shape = (bsz, num_clips,
                            num_frm) + batch["visual_inputs"].shape[2:]
        visual_inputs = batch["visual_inputs"].view(*new_visual_shape)
        logits = []
        for clip_idx in range(num_clips):
            # (B, num_frm, C, H, W)
            mini_batch["visual_inputs"] = visual_inputs[:, clip_idx]
            mini_batch["n_examples_list"] = batch["n_examples_list"]
            outputs = forward_step(model, mini_batch, cfg)
            logits.append(outputs["logits"])
            # the losses are cross entropy and mse, no need to * num_labels

        logits = torch.stack(logits)  # (num_frm, B, 5)
        if pool_method == "mean":
            logits = logits.mean(0)  # (B, 5)
        elif pool_method == "max":
            logits = logits.max(0)[0]  # (B, 5)
        elif pool_method == "lse":
            logits = logits.permute(
                1, 0,
                2).contiguous()  # (B, num_frm, 5), pooling will be done in CE
        else:
            raise ValueError(
                f"Invalid value for pool_method, "
                f"got {pool_method}, expect one of [`mean`, `max`, `lse`]")

        if pool_method == "lse":
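            # lse: logits are (B, num_frm, n_options). Pool over frames with
            # logsumexp per option, then take the negative log-softmax of the
            # ground-truth option, i.e. cross-entropy on frame-pooled scores.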
            out = torch.logsumexp(logits.view(logits.shape[0], -1), dim=-1, keepdim=True) \
                - torch.logsumexp(logits, dim=1)
            loss = torch.gather(out, -1, batch["labels"].view(-1, 1))
        else:
            _, loss = model.transformer.calc_loss(
                logits,
                batch["labels"],
                sample_size=len(batch["n_examples_list"]))
        loss = loss.mean()

        running_loss(loss.item())
        # backward pass
        delay_unscale = (step + 1) % cfg.gradient_accumulation_steps != 0
        with amp.scale_loss(loss, optimizer,
                            delay_unscale=delay_unscale) as scaled_loss:
            scaled_loss.backward()
            zero_none_grad(model)
            optimizer.synchronize()

        # optimizer
        if (step + 1) % cfg.gradient_accumulation_steps == 0:
            global_step += 1

            # learning rate scheduling
            n_epoch = int(1. * total_train_batch_size * global_step /
                          total_n_examples)
            # learning rate scheduling transformer
            lr_this_step_transformer = get_lr_sched(
                global_step,
                cfg.decay,
                cfg.learning_rate,
                cfg.num_train_steps,
                warmup_ratio=cfg.warmup_ratio,
                decay_epochs=cfg.step_decay_epochs,
                multi_step_epoch=n_epoch)

            # learning rate scheduling cnn
            lr_this_step_cnn = get_lr_sched(
                global_step,
                cfg.cnn_lr_decay,
                cfg.cnn_learning_rate,
                cfg.num_train_steps,
                warmup_ratio=cfg.warmup_ratio,
                decay_epochs=cfg.cnn_step_decay_epochs,
                multi_step_epoch=n_epoch)

            # Hardcoded param group length
            assert len(optimizer.param_groups) == 8
            for pg_n, param_group in enumerate(optimizer.param_groups):
                if pg_n in [0, 1]:
                    param_group['lr'] = (cfg.transformer_lr_mul *
                                         lr_this_step_transformer)
                elif pg_n in [2, 3]:
                    param_group['lr'] = lr_this_step_transformer
                elif pg_n in [4, 5]:
                    param_group['lr'] = (cfg.cnn_lr_mul * lr_this_step_cnn)
                else:
                    param_group['lr'] = lr_this_step_cnn
            TB_LOGGER.add_scalar("train/lr_transformer",
                                 lr_this_step_transformer, global_step)
            TB_LOGGER.add_scalar("train/lr_cnn", lr_this_step_cnn, global_step)

            TB_LOGGER.add_scalar('train/loss', running_loss.val, global_step)

            # update model params
            if cfg.grad_norm != -1:
                grad_norm = clip_grad_norm_(amp.master_params(optimizer),
                                            cfg.grad_norm)
                TB_LOGGER.add_scalar("train/grad_norm", grad_norm, global_step)
            TB_LOGGER.step()

            # Check if there is None grad
            none_grads = [
                p[0] for p in model.named_parameters()
                if p[1].requires_grad and p[1].grad is None
            ]

            assert len(none_grads) == 0, f"{none_grads}"

            with optimizer.skip_synchronize():
                optimizer.step()
                optimizer.zero_grad()
            restorer.step()
            pbar.update(1)

            # checkpoint
            if global_step % cfg.valid_steps == 0:
                LOGGER.info(f'Step {global_step}: start validation')
                validate(model,
                         val_loader,
                         eval_loader,
                         cfg,
                         global_step,
                         eval_filepath=cfg.val_datasets[0].txt)
                model_saver.save(step=global_step, model=model)
        if global_step >= cfg.num_train_steps:
            break

        if cfg.debug and global_step >= debug_step:
            break

    if global_step % cfg.valid_steps != 0:
        LOGGER.info(f'Step {global_step}: start validation')
        validate(model,
                 val_loader,
                 eval_loader,
                 cfg,
                 global_step,
                 eval_filepath=cfg.val_datasets[0].txt)
        model_saver.save(step=global_step, model=model)
Example #25
def mk_captions_pretrain_dataloader(dataset_name,
                                    vis_format,
                                    anno_path,
                                    img_lmdb_dir,
                                    cfg,
                                    tokenizer,
                                    is_train=True):
    # make a list(dict), where each dict {vis_id: int, txt: str}
    if dataset_name == "coco_cap":
        grouped = mk_vis_txt_pair_datalist(anno_path,
                                           data_ratio=cfg.data_ratio,
                                           vis_id_key="coco_id",
                                           txt_key="caption")
    elif dataset_name == "vg_cap":
        grouped = mk_vis_txt_pair_datalist(anno_path,
                                           data_ratio=cfg.data_ratio,
                                           vis_id_key="vg_id",
                                           txt_key="caption")
    else:
        raise ValueError("Invalid dataset_name")

    # each group has a single image with multiple questions
    max_n_example_per_group = cfg.max_n_example_per_group \
        if vis_format == "image" else 1  # single element group for video.
    group_datalist = mk_input_group(
        grouped,
        max_n_example_per_group=max_n_example_per_group,
        is_train=is_train)

    frm_sampling_strategy = cfg.frm_sampling_strategy
    if not is_train and frm_sampling_strategy == "rand":
        frm_sampling_strategy = "middle"
    dataset = ClipBertPretrainDataset(
        datalist=group_datalist,
        tokenizer=tokenizer,
        img_lmdb_dir=img_lmdb_dir,
        max_img_size=cfg.max_img_size,
        max_txt_len=cfg.max_txt_len,
        itm_neg_prob=cfg.itm_neg_prob,
        use_itm=cfg.use_itm,
        fps=cfg.fps,
        num_frm=cfg.num_frm,
        frm_sampling_strategy=frm_sampling_strategy,
        vis_format=vis_format)
    LOGGER.info(f"[{dataset_name}] is_train {is_train} "
                f"dataset size {len(dataset)}, "
                f"group size {max_n_example_per_group}")
    batch_size = cfg.train_batch_size if is_train else cfg.val_batch_size
    # hardcode the video batch size to 1/num_frm of the image batch size,
    # so that the total number of input frames per video batch stays
    # comparable to the number of images per image batch.
    batch_size = batch_size if vis_format == "image" else int(batch_size /
                                                              cfg.num_frm)
    sampler = DistributedSampler(dataset,
                                 num_replicas=hvd.size(),
                                 rank=hvd.rank(),
                                 shuffle=is_train)
    data_collator = PretrainCollator(tokenizer=tokenizer,
                                     mlm=cfg.use_mlm,
                                     mlm_probability=0.15,
                                     max_length=cfg.max_txt_len,
                                     is_train=is_train)
    dataloader = DataLoader(dataset,
                            batch_size=batch_size,
                            shuffle=False,
                            sampler=sampler,
                            num_workers=cfg.n_workers,
                            pin_memory=cfg.pin_mem,
                            collate_fn=data_collator.collate_batch)
    return dataloader
Example #26
def validate(model, val_loader, cfg, train_global_step, eval_score=True):
    """use eval_score=False when doing inference on test sets where answers are not available"""
    model.eval()

    loss = 0.
    n_ex = 0
    qa_results = []
    st = time.time()
    debug_step = 5
    pbar = tqdm(total=len(val_loader))
    for val_step, batch in enumerate(val_loader):
        # forward pass
        question_ids = batch["question_ids"]
        bsz = len(question_ids)
        # used to make visual feature copies
        del batch["question_ids"]
        # add visual part into the mini batch and perform inference
        mini_batch = dict()
        for k, v in batch.items():
            if k != "visual_inputs":
                mini_batch[k] = v

        n_ex += len(question_ids)
        # multi-frame test, scores across frames of the same video will be pooled together
        pool_method = cfg.score_agg_func
        # could be 1, where only a single clip is evaluated
        num_clips = cfg.inference_n_clips
        num_frm = cfg.num_frm
        # (B, T=num_clips*num_frm, C, H, W) --> (B, num_clips, num_frm, C, H, W)
        new_visual_shape = (bsz, num_clips, num_frm) + batch["visual_inputs"].shape[2:]
        visual_inputs = batch["visual_inputs"].view(*new_visual_shape)
        logits = []
        losses = []
        for clip_idx in range(num_clips):
            # (B, num_frm, C, H, W)
            mini_batch["visual_inputs"] = visual_inputs[:, clip_idx]
            mini_batch["n_examples_list"] = batch["n_examples_list"]
            outputs = forward_step(model, mini_batch, cfg)
            logits.append(outputs["logits"].cpu())
            _loss = outputs["loss"].sum().item() if isinstance(
                outputs["loss"], torch.Tensor) else 0
            losses.append(_loss)
        loss += (sum(losses) / num_clips)

        logits = torch.stack(logits)  # (num_frm, B, 5)
        if pool_method == "mean":
            logits = logits.mean(0)  # (B, 5)
        elif pool_method == "max":
            logits = logits.max(0)[0]  # (B, 5)
        elif pool_method == "lse":
            logits = logits.permute(1, 0, 2).contiguous()  # (B, num_frm, 5), pooling will be done in CE
            logits = torch.logsumexp(logits, dim=1)  # torch.exp alone might be too large and unstable
        else:
            raise ValueError(f"Invalid value for pool_method, "
                             f"got {pool_method}, expect one of [`mean`, `max`, `lse`]")

        if cfg.task in ["action", "transition", "frameqa", "msrvtt_qa"]:
            # cross entropy
            pred_labels = logits.max(dim=-1)[1].data.tolist()
        else:
            # mse
            preds = (logits + 0.5).long().clamp(min=1, max=10)
            pred_labels = preds.data.squeeze().tolist()
        for qid, pred_label in zip(question_ids, pred_labels):
            qa_results.append(dict(
                question_id=qid,
                answer=pred_label,
                data=val_loader.dataset.qid2data[qid]
            ))
        pbar.update(1)
        if cfg.debug and val_step >= debug_step:
            break

    if cfg.debug:
        LOGGER.info(qa_results[:10])
    n_ex_per_rank = all_gather_list(n_ex)
    loss = sum(all_gather_list(loss))
    n_ex = sum(all_gather_list(n_ex))
    # average loss for each example
    val_log = {f'valid/loss': float(loss / n_ex)}
    if eval_score:
        LOGGER.info(f"QA Task [{cfg.task}], "
                    f"{len(qa_results)} qa_results, "
                    f"3 examples here: {qa_results[:3]}")
        vqa_scores = val_loader.dataset.evaluate_tgif_qa(qa_results)
        # print(f"{hvd.rank()}: {vqa_scores}")

        # Gather scores
        scores_per_rank = all_gather_list(vqa_scores)
        gathered_scores = {}
        if "ratios" in scores_per_rank[0]:
            gathered_ratios = {
                k: [0, 0] for k, _ in scores_per_rank[0]["ratios"].items()}
            # Gather ratios
            for rank_id in range(len(n_ex_per_rank)):
                current_ratios = scores_per_rank[rank_id]["ratios"]
                for k, v in current_ratios.items():
                    gathered_ratios[k][1] += v[1]
            for k, v in gathered_ratios.items():
                gathered_ratios[k][0] = get_rounded_percentage(
                    1. * v[1] / n_ex)
            gathered_scores["ratios"] = gathered_ratios

        # FIXME: gathering scores becomes complicated due to np.mean and dict format.
        for scores_k, _ in vqa_scores.items():
            if "ratio" in scores_k:
                continue
            gathered_v = 0
            for rank_id, n in enumerate(n_ex_per_rank):
                curr_acc, curr_n_ex = 0, 0
                if "overall" in scores_k:
                    curr_acc = scores_per_rank[rank_id][scores_k] * n
                else:
                    if "ratios" in scores_per_rank[0]:
                        curr_n_ex = scores_per_rank[
                                rank_id]["ratios"][
                                    scores_k.replace("acc", "ratio")][1]
                        curr_acc = scores_per_rank[rank_id][
                            scores_k] * curr_n_ex
                gathered_v += curr_acc
            if "overall" in scores_k:
                gathered_v = gathered_v * 1. / n_ex
            else:
                if "ratios" in scores_per_rank[0]:
                    _num = gathered_ratios[
                        scores_k.replace("acc", "ratio")][1]
                    gathered_v = gathered_v * 1. / _num if _num != 0 else 0
            if cfg.task in ["action", "transition", "frameqa", "msrvtt_qa"]:
                gathered_scores[scores_k] = get_rounded_percentage(
                    gathered_v)
            else:
                gathered_scores[scores_k] = round(gathered_v, 2)

        for k, v in gathered_scores.items():
            if "ratio" not in k:
                val_log[f'valid/{k}'] = v
    else:
        LOGGER.info("eval_score = False, no scores are calculated.")
        gathered_scores = 0

    TB_LOGGER.log_scalar_dict(val_log)
    LOGGER.info(f"validation finished in {int(time.time() - st)} seconds. "
                f"{gathered_scores}")

    model.train()
    return qa_results, gathered_scores
Example #27
def mk_tgif_qa_dataloader(task_type, anno_path, lmdb_dir, cfg, tokenizer,
                          is_train=True, return_label=True):
    """
    Returns:
        list(dict), each dict is
            action and transition: {
                "gif_name": "tumblr_nk172bbdPI1u1lr18o1_250",
                "question": "What does the butterfly do 10 or more than 10 times ?",
                "options": ["stuff marshmallow", "holds a phone towards face",
                            "fall over", "talk", "flap wings"],
                "answer": 4
                }
            frameqa: {
                "gif_name": "tumblr_no73q2fm0I1uuf348o1_250",
                "question": "what is being placed in the white ice cream cone ?",
                "answer": "cookie",
                "answer_type": "object"
                }
            msrvtt_qa: {
                "answer": "couch",
                "question": "what are three people sitting on?",
                "video_id": "video6513",
                "answer_type": "what"
                }
    """
    raw_datalist = load_jsonl(anno_path)
    LOGGER.info(f"Loaded data size {len(raw_datalist)}")
    if cfg.data_ratio != 1.0:
        random.shuffle(raw_datalist)
        raw_datalist = raw_datalist[:int(len(raw_datalist) * cfg.data_ratio)]
        LOGGER.info(f"Use {100 * cfg.data_ratio}% of the loaded data: {len(raw_datalist)}")

    datalist = []
    qid = 0
    for raw_d in raw_datalist:
        d = dict(
            question=raw_d["question"],
            vid_id=raw_d["gif_name"] if "gif_name" in raw_d else raw_d["video_id"],
            answer=raw_d["answer"],  # int or str
            question_id=qid  # be careful, it is not unique across splits
        )
        qid += 1

        if task_type in ["action", "transition"]:
            d["options"] = raw_d["options"]
        elif task_type in ["frameqa", "msrvtt_qa"]:
            d["answer_type"] = raw_d["answer_type"]

        datalist.append(d)
    LOGGER.info(f"datalist {len(datalist)}")

    grouped = defaultdict(list)  # examples grouped by image/video id
    for d in datalist:
        grouped[d["vid_id"]].append(d)
    LOGGER.info(f"grouped {len(grouped)}")

    # each group has a single image with multiple questions
    group_datalist = mk_input_group(
        grouped,
        max_n_example_per_group=cfg.max_n_example_per_group if is_train else 1,  # force 1 in eval
        is_train=is_train
    )
    LOGGER.info(f"group_datalist {len(group_datalist)}")

    ans2label = load_json(cfg.ans2label_path)

    frm_sampling_strategy = cfg.frm_sampling_strategy
    if not is_train and frm_sampling_strategy == "rand":
        frm_sampling_strategy = "middle"
    dataset = ClipBertVideoQADataset(
        task_type=cfg.task,
        datalist=group_datalist,
        tokenizer=tokenizer,
        img_lmdb_dir=lmdb_dir,
        ans2label=ans2label,
        max_img_size=cfg.max_img_size,
        max_txt_len=cfg.max_txt_len,
        fps=cfg.fps,
        num_frm=cfg.num_frm,
        frm_sampling_strategy=frm_sampling_strategy,
        ensemble_n_clips=cfg.train_n_clips if is_train else cfg.inference_n_clips,
        return_label=return_label,
        is_train=is_train
    )
    LOGGER.info(f"is_train {is_train}, dataset size {len(dataset)} groups, "
                f"each group {cfg.max_n_example_per_group if is_train else 1}")
    if cfg.do_inference:
        batch_size = cfg.inference_batch_size
    else:
        batch_size = cfg.train_batch_size if is_train else cfg.val_batch_size
    sampler = DistributedSampler(
        dataset, num_replicas=hvd.size(), rank=hvd.rank(),
        shuffle=is_train)
    vqa_collator = VideoQACollator(tokenizer=tokenizer,
                                   max_length=cfg.max_txt_len,
                                   task_type=cfg.task)
    dataloader = DataLoader(dataset,
                            batch_size=batch_size,
                            shuffle=False,
                            sampler=sampler,
                            num_workers=cfg.n_workers,
                            pin_memory=cfg.pin_mem,
                            collate_fn=vqa_collator.collate_batch)
    return dataloader
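
A short, hedged usage sketch for the builder above; cfg.train_txt_db and cfg.train_img_db are assumed config keys (named by analogy with the inference_* keys used elsewhere), and all paths are placeholders:

# Hypothetical training-time usage of mk_tgif_qa_dataloader.
tokenizer = BertTokenizerFast.from_pretrained(cfg.tokenizer_dir)
train_loader = mk_tgif_qa_dataloader(
    task_type=cfg.task,          # "action", "transition", "frameqa" or "msrvtt_qa"
    anno_path=cfg.train_txt_db,  # assumed key, analogous to cfg.inference_txt_db
    lmdb_dir=cfg.train_img_db,   # assumed key, analogous to cfg.inference_img_db
    cfg=cfg, tokenizer=tokenizer,
    is_train=True, return_label=True)
img_norm = ImageNorm(mean=cfg.img_pixel_mean, std=cfg.img_pixel_std)
train_loader = PrefetchLoader(train_loader, img_norm)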
Example #28
def start_inference(cfg):
    set_random_seed(cfg.seed)
    n_gpu = hvd.size()
    device = torch.device("cuda", hvd.local_rank())
    torch.cuda.set_device(hvd.local_rank())
    if hvd.rank() != 0:
        LOGGER.disabled = True

    inference_res_dir = join(
        cfg.output_dir,
        f"results_{os.path.splitext(os.path.basename(cfg.inference_txt_db))[0]}/"
        f"step_{cfg.inference_model_step}_{cfg.inference_n_clips}_{cfg.score_agg_func}"
    )

    if hvd.rank() == 0:
        os.makedirs(inference_res_dir, exist_ok=True)
        save_json(cfg, join(inference_res_dir, "raw_args.json"),
                  save_pretty=True)

    LOGGER.info("device: {} n_gpu: {}, rank: {}, "
                "16-bits training: {}".format(
                    device, n_gpu, hvd.rank(), bool(cfg.fp16)))

    # overwrite cfg with stored_cfg,
    # but skip keys containing the keyword 'inference'
    stored_cfg_path = join(cfg.output_dir, "log/args.json")
    stored_cfg = edict(load_json(stored_cfg_path))
    for k, v in cfg.items():
        if k in stored_cfg and "inference" not in k:
            setattr(cfg, k, stored_cfg[k])

    # setup models
    cfg.model_config = join(cfg.output_dir, "log/model_config.json")
    e2e_weights_path = join(
        cfg.output_dir, f"ckpt/model_step_{cfg.inference_model_step}.pt")
    cfg.e2e_weights_path = e2e_weights_path
    model = setup_model(cfg, device=device)
    model.eval()

    # FIXME separate scaling for each loss
    model = amp.initialize(
        model, enabled=cfg.fp16, opt_level='O2')

    global_step = 0
    # prepare data
    tokenizer = BertTokenizerFast.from_pretrained(cfg.tokenizer_dir)
    cfg.data_ratio = 1.
    val_loader = mk_tgif_qa_dataloader(
        task_type=cfg.task,
        anno_path=cfg.inference_txt_db,
        lmdb_dir=cfg.inference_img_db,
        cfg=cfg, tokenizer=tokenizer,
        is_train=False,
        return_label=False
    )
    img_norm = ImageNorm(mean=cfg.img_pixel_mean, std=cfg.img_pixel_std)
    val_loader = PrefetchLoader(val_loader, img_norm)

    LOGGER.info(cfg)
    LOGGER.info("Starting inference...")
    LOGGER.info(f"***** Running inference with {n_gpu} GPUs *****")
    LOGGER.info(f"  Batch size = {cfg.inference_batch_size}")

    LOGGER.info(f'Step {global_step}: start validation')
    qa_results, qa_scores = validate(
        model, val_loader, cfg, global_step,
        eval_score=True)  # cfg.inference_split == "val"

    if hvd.rank() == 0:
        save_json(cfg, join(inference_res_dir, "merged_args.json"),
                  save_pretty=True)
        save_json(qa_scores, join(inference_res_dir, "scores.json"),
                  save_pretty=True)

    # ###### Saving with Horovod ####################
    # dummy sync
    _ = None
    all_gather_list(_)
    if n_gpu > 1:
        # with retrial, as azure blob fails occasionally.
        max_save_load_trial = 10
        save_trial = 0
        while save_trial < max_save_load_trial:
            try:
                LOGGER.info(f"Save results trial NO. {save_trial}")
                save_json(
                    qa_results,
                    join(inference_res_dir, f"results_rank{hvd.rank()}.json"))
                break
            except Exception as e:
                save_trial += 1
    # dummy sync
    _ = None
    all_gather_list(_)
    # join results
    if n_gpu > 1 and hvd.rank() == 0:
        qa_results = []
        for rk in range(n_gpu):
            qa_results.extend(load_json(
                join(inference_res_dir, f"results_rank{rk}.json")))
        LOGGER.info("results joined")

    if hvd.rank() == 0:
        save_json(
            qa_results,
            join(inference_res_dir, "results_all.json"))
        LOGGER.info("all results written")
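
The save-with-retry loop appears in several of these snippets; a small hypothetical helper that factors it out could look like the sketch below (the function name is made up, and it reuses the existing save_json and LOGGER):

def save_json_with_retry(obj, path, max_trial=10):
    """Hypothetical helper: retry saving, since remote blob writes occasionally fail."""
    for trial in range(max_trial):
        try:
            save_json(obj, path)
            return
        except Exception:
            LOGGER.info(f"Save trial NO. {trial} failed for {path}, retrying...")
    raise RuntimeError(f"Failed to save {path} after {max_trial} trials.")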
Example #29
File: main.py  Project: kosohae/NLP
SEED = 42
DATA = 'TREC'
CUDA = False
DEBUG = False
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
MODE = 'static'  # choices = [static, nonstatic]
WORD_VECTORS = 'rand'  # choices = [rand, word2vec]
DIM = 300

#  set seed
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)  # if use multi-GPU

LOGGER.info(MODE)
LOGGER.info(WORD_VECTORS)


#########################
# Custom class Setting
#########################
class GSST(SST):
    urls = ['http://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip']
    dirname = 'trees'
    name = 'sst'

    @staticmethod
    def sort_key(ex):
        return len(ex.text)
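
A minimal usage sketch for GSST, assuming the legacy torchtext Field/iterator API that this snippet appears to target; the field settings and batch size are illustrative only:

# Hypothetical usage with the legacy torchtext API (field settings are illustrative).
from torchtext import data

TEXT = data.Field(lower=True, batch_first=True)
LABEL = data.Field(sequential=False, unk_token=None)

# Downloads/extracts the SST trees (see `urls`/`dirname` above) and builds the splits.
train_set, dev_set, test_set = GSST.splits(TEXT, LABEL)
TEXT.build_vocab(train_set)
LABEL.build_vocab(train_set)

train_iter, dev_iter, test_iter = data.BucketIterator.splits(
    (train_set, dev_set, test_set), batch_size=64, device=DEVICE)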
Example #30
def load_state_dict_with_mismatch(model, loaded_state_dict_or_path):
    """operated in-place, no need to return `model`"""

    if isinstance(loaded_state_dict_or_path, str):
        loaded_state_dict = torch.load(
            loaded_state_dict_or_path, map_location="cpu")
    else:
        loaded_state_dict = loaded_state_dict_or_path
    model_state_dict = model.state_dict()
    model_keys = set(model_state_dict.keys())
    load_keys = set(loaded_state_dict.keys())

    toload = {}
    mismatched_shape_keys = []
    for k in model_keys:
        if k in load_keys:
            if model_state_dict[k].shape != loaded_state_dict[k].shape:
                mismatched_shape_keys.append(k)
            else:
                toload[k] = loaded_state_dict[k]

    LOGGER.info("You can ignore the keys with `num_batches_tracked` or from task heads")
    LOGGER.info("Keys in loaded but not in model:")
    diff_keys = load_keys.difference(model_keys)
    LOGGER.info(f"In total {len(diff_keys)}, {sorted(diff_keys)}")
    LOGGER.info("Keys in model but not in loaded:")
    diff_keys = model_keys.difference(load_keys)
    LOGGER.info(f"In total {len(diff_keys)}, {sorted(diff_keys)}")
    LOGGER.info("Keys in model and loaded, but shape mismatched:")
    LOGGER.info(f"In total {len(mismatched_shape_keys)}, {sorted(mismatched_shape_keys)}")
    model.load_state_dict(toload, strict=False)
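
A short usage sketch for the loader above; the module and checkpoint path are placeholders:

# Hypothetical usage: initialize a model from pretrained weights whose task-head
# shapes differ; mismatched or missing keys are reported and skipped.
import torch.nn as nn

model = nn.Linear(768, 10)  # placeholder module
load_state_dict_with_mismatch(model, "/path/to/pretrained_weights.pt")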