Exemple #1
0
def validate(model, val_loader, eval_loader, cfg, train_global_step,
             eval_filepath):
    """Run matching-accuracy validation plus retrieval evaluation and log both.

    The model is switched to eval mode, summed loss / example count /
    correct-prediction count are accumulated over ``val_loader`` and summed
    across workers with ``all_gather_list``, ``inference_retrieval`` is run on
    ``eval_loader``, and the model is switched back to train mode. Only the
    worker with ``hvd.rank() == 0`` writes to ``TB_LOGGER`` / ``LOGGER``.

    Args:
        model: model handed to ``forward_step`` and ``inference_retrieval``.
        val_loader: iterable of batch dicts. Every batch must contain
            ``"caption_ids"`` (deleted in place before the forward pass) and
            ``"labels"``.
        eval_loader: loader forwarded to ``inference_retrieval``.
        cfg: config object; ``cfg.debug`` is read here to stop after a few
            batches, and ``cfg`` is forwarded to the helpers.
        train_global_step: not used by this function; kept so existing call
            sites keep working.
        eval_filepath: forwarded to ``inference_retrieval``.

    Returns:
        None. Metrics are only logged.
    """
    model.eval()

    loss = 0.
    n_ex = 0
    n_corrects = 0
    st = time.time()
    debug_step = 5
    for val_step, batch in enumerate(val_loader):
        # forward pass
        del batch["caption_ids"]
        outputs = forward_step(model, batch, cfg)
        targets = batch['labels']

        # a non-tensor loss contributes nothing to the running sum
        loss += outputs["loss"].sum().item() if isinstance(
            outputs["loss"], torch.Tensor) else 0
        n_ex += len(targets)

        if outputs["logits"].shape[1] == 2:
            # two-way logits: the prediction is the arg-max column
            n_corrects += (outputs["logits"].max(
                dim=-1)[1] == targets).sum().item()
        else:
            # otherwise threshold sigmoid(logit) at 0.5 and, after reshaping
            # to one row per loss entry, score only the first column
            predictions = (torch.sigmoid(outputs["logits"]) > 0.5).long()
            predictions = predictions.view(outputs["loss"].shape[0], -1)
            targets = targets.view(outputs["loss"].shape[0], -1)
            matched = predictions[:, 0].squeeze() == targets[:, 0].squeeze()
            n_corrects += matched.sum().item()

        if cfg.debug and val_step >= debug_step:
            break

    # sum the per-worker totals
    loss = sum(all_gather_list(loss))
    n_ex = sum(all_gather_list(n_ex))
    n_corrects = sum(all_gather_list(n_corrects))

    # executed on every worker; only rank 0 consumes the metrics below
    _, retrieval_metrics = inference_retrieval(model, eval_loader,
                                               eval_filepath, cfg)

    model.train()

    if hvd.rank() == 0:
        # average loss for each example; report zeros instead of raising
        # ZeroDivisionError when no example was seen at all
        if n_ex != 0:
            acc = float(n_corrects / n_ex)
            avg_loss = float(loss / n_ex)
        else:
            acc = 0.
            avg_loss = 0.
        val_log = {'valid/loss': avg_loss, 'valid/acc': acc}
        for ret_type, ret_m in retrieval_metrics.items():
            val_log.update({
                f"valid/{ret_type}_{k}": round(v, 4)
                for k, v in ret_m.items()
            })

        TB_LOGGER.log_scalar_dict(val_log)
        LOGGER.info(f"validation finished in {int(time.time() - st)} seconds."
                    f"itm_acc: {acc}. Retrieval res {retrieval_metrics}")
Exemple #2
0
def start_training(cfg):
    """Train the model with Horovod (``hvd``) and ``amp``, validating periodically.

    Sequence of work, as visible in this function:
      1. seed RNGs, pick the CUDA device from ``hvd.local_rank()``;
      2. build model and optimizer, wrap the optimizer in
         ``hvd.DistributedOptimizer``, broadcast parameters and optimizer
         state from rank 0, then call ``amp.initialize``;
      3. build train / val loaders and a retrieval eval loader;
      4. derive ``cfg.num_train_steps`` and ``cfg.valid_steps``;
      5. create a ``TrainingRestorer`` and resume from its ``global_step``;
      6. run the training loop with gradient accumulation, per-param-group
         learning rates, optional gradient clipping, and a
         validate-then-checkpoint step every ``cfg.valid_steps`` steps.

    Args:
        cfg: config object. It is mutated here: ``n_gpu``,
            ``num_train_steps`` and ``valid_steps`` are written.

    Returns:
        None.

    Raises:
        ValueError: if ``cfg.score_agg_func`` is not ``mean``, ``max`` or
            ``lse``.
        AssertionError: if the optimizer does not have exactly 8 param
            groups, or if a trainable parameter has no gradient at an
            optimizer step.
    """
    set_random_seed(cfg.seed)

    n_gpu = hvd.size()
    cfg.n_gpu = n_gpu
    device = torch.device("cuda", hvd.local_rank())
    torch.cuda.set_device(hvd.local_rank())
    # only rank 0 emits log records
    if hvd.rank() != 0:
        LOGGER.disabled = True
    LOGGER.info("device: {} n_gpu: {}, rank: {}, "
                "16-bits training: {}".format(device, n_gpu, hvd.rank(),
                                              bool(cfg.fp16)))

    model = setup_model(cfg, device=device)
    model.train()
    optimizer = setup_e2e_optimizer(model, cfg)

    # Horovod: (optional) compression algorithm; none is used here.
    compression = hvd.Compression.none
    optimizer = hvd.DistributedOptimizer(
        optimizer,
        named_parameters=model.named_parameters(),
        compression=compression)

    # Horovod: broadcast parameters & optimizer state from rank 0.
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)
    hvd.broadcast_optimizer_state(optimizer, root_rank=0)

    model, optimizer = amp.initialize(model,
                                      optimizer,
                                      enabled=cfg.fp16,
                                      opt_level='O2',
                                      keep_batchnorm_fp32=True)

    # prepare data
    tokenizer = BertTokenizerFast.from_pretrained(cfg.tokenizer_dir)
    train_loader, val_loader = setup_dataloaders(cfg, tokenizer)
    # retrieval evaluation only uses the first validation dataset
    eval_loader = mk_video_ret_eval_dataloader(
        anno_path=cfg.val_datasets[0].txt,
        lmdb_dir=cfg.val_datasets[0].img,
        cfg=cfg,
        tokenizer=tokenizer,
    )

    # compute the number of steps and update cfg
    total_n_examples = len(train_loader.dataset) * cfg.max_n_example_per_group
    # examples consumed per optimizer step across all workers
    total_train_batch_size = int(n_gpu * cfg.train_batch_size *
                                 cfg.gradient_accumulation_steps *
                                 cfg.max_n_example_per_group)
    cfg.num_train_steps = int(
        math.ceil(1. * cfg.num_train_epochs * total_n_examples /
                  total_train_batch_size))

    # validation interval, rounded up to a multiple of cfg.min_valid_steps
    cfg.valid_steps = int(
        math.ceil(1. * cfg.num_train_steps / cfg.num_valid /
                  cfg.min_valid_steps)) * cfg.min_valid_steps
    # only used in the log line below; the +1 presumably accounts for the
    # final validation after the loop -- TODO confirm
    actual_num_valid = int(
        math.floor(1. * cfg.num_train_steps / cfg.valid_steps)) + 1

    # restore
    restorer = TrainingRestorer(cfg, model, optimizer)
    global_step = restorer.global_step
    TB_LOGGER.global_step = global_step
    if hvd.rank() == 0:
        LOGGER.info("Saving training meta...")
        save_training_meta(cfg)
        path = join(cfg.output_dir, 'log', "detectron2_model_cfg.yaml")
        with open(path, "w") as f:
            f.write(model.cnn.config_file)
        LOGGER.info("Saving training done...")
        TB_LOGGER.create(join(cfg.output_dir, 'log'))
        pbar = tqdm(total=cfg.num_train_steps)
        model_saver = ModelSaver(join(cfg.output_dir, "ckpt"))
        add_log_to_file(join(cfg.output_dir, "log", "log.txt"))
    else:
        # non-zero ranks get NoOp stand-ins for the progress bar, checkpoint
        # saver and restorer (global_step was already read above)
        LOGGER.disabled = True
        pbar = NoOp()
        model_saver = NoOp()
        restorer = NoOp()

    # when resuming, fast-forward the progress bar
    if global_step > 0:
        pbar.update(global_step)

    LOGGER.info(cfg)
    LOGGER.info("Starting training...")
    LOGGER.info(f"***** Running training with {n_gpu} GPUs *****")
    LOGGER.info(
        f"  Single-GPU Non-Accumulated batch size = {cfg.train_batch_size}")
    LOGGER.info(f"  max_n_example_per_group = {cfg.max_n_example_per_group}")
    LOGGER.info(f"  Accumulate steps = {cfg.gradient_accumulation_steps}")
    LOGGER.info(
        f"  Total batch size = #GPUs * Single-GPU batch size * "
        f"max_n_example_per_group * Accumulate steps [Image] = {total_train_batch_size}"
    )
    LOGGER.info(f"  Total #epochs = {cfg.num_train_epochs}")
    LOGGER.info(f"  Total #steps = {cfg.num_train_steps}")
    LOGGER.info(
        f"  Validate every {cfg.valid_steps} steps, in total {actual_num_valid} times"
    )

    # quick hack for amp delay_unscale bug
    with optimizer.skip_synchronize():
        optimizer.zero_grad()
        if global_step == 0:
            optimizer.step()
    # in cfg.debug mode training stops once global_step reaches this value
    debug_step = 3
    running_loss = RunningMeter('train_loss')

    for step, batch in enumerate(InfiniteIterator(train_loader)):
        # forward pass
        del batch["caption_ids"]
        # mini_batch carries every field except the raw frames; the frames
        # are filled in clip by clip below
        mini_batch = dict()
        for k, v in batch.items():
            if k != "visual_inputs":
                mini_batch[k] = v

        pool_method = cfg.score_agg_func
        # could be 1, where only a single clip is used
        num_clips = cfg.train_n_clips
        num_frm = cfg.num_frm
        # (B, T=num_clips*num_frm, C, H, W) --> (B, num_clips, num_frm, C, H, W)
        bsz = batch["visual_inputs"].shape[0]
        new_visual_shape = (bsz, num_clips,
                            num_frm) + batch["visual_inputs"].shape[2:]
        visual_inputs = batch["visual_inputs"].view(*new_visual_shape)
        logits = []
        for clip_idx in range(num_clips):
            # (B, num_frm, C, H, W)
            mini_batch["visual_inputs"] = visual_inputs[:, clip_idx]
            mini_batch["n_examples_list"] = batch["n_examples_list"]
            outputs = forward_step(model, mini_batch, cfg)
            logits.append(outputs["logits"])
            # the losses are cross entropy and mse, no need to * num_labels

        # one entry per clip was appended, so the stacked axis is num_clips
        logits = torch.stack(logits)  # (num_clips, B, 5)
        if pool_method == "mean":
            logits = logits.mean(0)  # (B, 5)
        elif pool_method == "max":
            logits = logits.max(0)[0]  # (B, 5)
        elif pool_method == "lse":
            logits = logits.permute(
                1, 0,
                2).contiguous()  # (B, num_clips, 5), pooled inside the loss below
        else:
            raise ValueError(
                f"Invalid value for pool_method, "
                f"got {pool_method}, expect one of [`mean`, `max`, `lse`]")

        if pool_method == "lse":
            # out[b, c] = logsumexp over all (clip, class) entries of example b
            #             - logsumexp over clips for class c;
            # gather then picks the column of the ground-truth label.
            out = torch.logsumexp(logits.view(logits.shape[0], -1), dim=-1, keepdim=True) \
                - torch.logsumexp(logits, dim=1)
            loss = torch.gather(out, -1, batch["labels"].view(-1, 1))
        else:
            # loss is recomputed on the pooled logits; the per-clip losses in
            # `outputs` are not used for training
            _, loss = model.transformer.calc_loss(
                logits,
                batch["labels"],
                sample_size=len(batch["n_examples_list"]))
        loss = loss.mean()

        running_loss(loss.item())
        # backward pass
        # delay_unscale is True on every micro-step except the one that
        # completes an accumulation window
        delay_unscale = (step + 1) % cfg.gradient_accumulation_steps != 0
        with amp.scale_loss(loss, optimizer,
                            delay_unscale=delay_unscale) as scaled_loss:
            scaled_loss.backward()
            zero_none_grad(model)
            optimizer.synchronize()

        # optimizer
        if (step + 1) % cfg.gradient_accumulation_steps == 0:
            global_step += 1

            # learning rate scheduling
            # integer number of epochs' worth of examples consumed so far
            n_epoch = int(1. * total_train_batch_size * global_step /
                          total_n_examples)
            # learning rate scheduling transformer
            lr_this_step_transformer = get_lr_sched(
                global_step,
                cfg.decay,
                cfg.learning_rate,
                cfg.num_train_steps,
                warmup_ratio=cfg.warmup_ratio,
                decay_epochs=cfg.step_decay_epochs,
                multi_step_epoch=n_epoch)

            # learning rate scheduling cnn
            lr_this_step_cnn = get_lr_sched(
                global_step,
                cfg.cnn_lr_decay,
                cfg.cnn_learning_rate,
                cfg.num_train_steps,
                warmup_ratio=cfg.warmup_ratio,
                decay_epochs=cfg.cnn_step_decay_epochs,
                multi_step_epoch=n_epoch)

            # Hardcoded param group length. Group index -> learning rate:
            #   0, 1: transformer lr * cfg.transformer_lr_mul
            #   2, 3: transformer lr
            #   4, 5: cnn lr * cfg.cnn_lr_mul
            #   6, 7: cnn lr
            assert len(optimizer.param_groups) == 8
            for pg_n, param_group in enumerate(optimizer.param_groups):
                if pg_n in [0, 1]:
                    param_group['lr'] = (cfg.transformer_lr_mul *
                                         lr_this_step_transformer)
                elif pg_n in [2, 3]:
                    param_group['lr'] = lr_this_step_transformer
                elif pg_n in [4, 5]:
                    param_group['lr'] = (cfg.cnn_lr_mul * lr_this_step_cnn)
                else:
                    param_group['lr'] = lr_this_step_cnn
            TB_LOGGER.add_scalar("train/lr_transformer",
                                 lr_this_step_transformer, global_step)
            TB_LOGGER.add_scalar("train/lr_cnn", lr_this_step_cnn, global_step)

            TB_LOGGER.add_scalar('train/loss', running_loss.val, global_step)

            # update model params
            # cfg.grad_norm == -1 disables gradient clipping
            if cfg.grad_norm != -1:
                grad_norm = clip_grad_norm_(amp.master_params(optimizer),
                                            cfg.grad_norm)
                TB_LOGGER.add_scalar("train/grad_norm", grad_norm, global_step)
            TB_LOGGER.step()

            # Check if there is None grad
            none_grads = [
                p[0] for p in model.named_parameters()
                if p[1].requires_grad and p[1].grad is None
            ]

            assert len(none_grads) == 0, f"{none_grads}"

            # gradients were already synchronized inside the scale_loss block
            with optimizer.skip_synchronize():
                optimizer.step()
                optimizer.zero_grad()
            restorer.step()
            pbar.update(1)

            # checkpoint
            if global_step % cfg.valid_steps == 0:
                LOGGER.info(f'Step {global_step}: start validation')
                validate(model,
                         val_loader,
                         eval_loader,
                         cfg,
                         global_step,
                         eval_filepath=cfg.val_datasets[0].txt)
                model_saver.save(step=global_step, model=model)
        # the train iterator is infinite, so these two checks end the loop
        if global_step >= cfg.num_train_steps:
            break

        if cfg.debug and global_step >= debug_step:
            break

    # final validation + checkpoint, unless the last step already did both
    if global_step % cfg.valid_steps != 0:
        LOGGER.info(f'Step {global_step}: start validation')
        validate(model,
                 val_loader,
                 eval_loader,
                 cfg,
                 global_step,
                 eval_filepath=cfg.val_datasets[0].txt)
        model_saver.save(step=global_step, model=model)
Exemple #3
0
def validate(model, val_loader, cfg, train_global_step, eval_score=True):
    """Run QA validation (or label-free inference) over ``val_loader``.

    Use ``eval_score=False`` when doing inference on test sets where answers
    are not available.

    For every batch the frames are split into ``cfg.inference_n_clips``
    clips, each clip is run through ``forward_step``, and the per-clip logits
    are pooled with ``cfg.score_agg_func`` before the prediction is taken.
    Loss and example counts are summed across workers with
    ``all_gather_list``; when ``eval_score`` is true the per-worker scores
    from ``val_loader.dataset.evaluate_tgif_qa`` are re-weighted by example
    counts and merged into a single score dict.

    Args:
        model: model handed to ``forward_step``; toggled eval -> train.
        val_loader: loader of batch dicts containing ``"question_ids"``
            (deleted in place), ``"visual_inputs"`` and ``"n_examples_list"``.
            Its dataset must expose ``qid2data`` and ``evaluate_tgif_qa``.
        cfg: config; reads ``score_agg_func``, ``inference_n_clips``,
            ``num_frm``, ``task`` and ``debug``.
        train_global_step: not used by this function; kept so existing call
            sites keep working.
        eval_score: when false, skip score computation.

    Returns:
        tuple: ``(qa_results, gathered_scores)``. ``qa_results`` is the list
        of prediction dicts built on this worker (keys ``question_id``,
        ``answer``, ``data``). ``gathered_scores`` is the merged score dict,
        or ``0`` when ``eval_score`` is false.

    Raises:
        ValueError: if ``cfg.score_agg_func`` is not ``mean``, ``max`` or
            ``lse``.
    """
    model.eval()

    loss = 0.
    n_ex = 0
    qa_results = []
    st = time.time()
    debug_step = 5
    pbar = tqdm(total=len(val_loader))
    for val_step, batch in enumerate(val_loader):
        # forward pass
        question_ids = batch["question_ids"]
        bsz = len(question_ids)
        # used to make visual feature copies
        del batch["question_ids"]
        # every field except the raw frames is shared by all clips; the
        # frames are filled in clip by clip below
        mini_batch = {k: v for k, v in batch.items() if k != "visual_inputs"}

        n_ex += bsz
        # multi-frame test, scores across frames of the same video will be pooled together
        pool_method = cfg.score_agg_func
        # could be 1, where only a single clip is evaluated
        num_clips = cfg.inference_n_clips
        num_frm = cfg.num_frm
        # (B, T=num_clips*num_frm, C, H, W) --> (B, num_clips, num_frm, C, H, W)
        new_visual_shape = (bsz, num_clips, num_frm) + batch["visual_inputs"].shape[2:]
        visual_inputs = batch["visual_inputs"].view(*new_visual_shape)
        logits = []
        losses = []
        for clip_idx in range(num_clips):
            # (B, num_frm, C, H, W)
            mini_batch["visual_inputs"] = visual_inputs[:, clip_idx]
            mini_batch["n_examples_list"] = batch["n_examples_list"]
            outputs = forward_step(model, mini_batch, cfg)
            logits.append(outputs["logits"].cpu())
            # a non-tensor loss contributes nothing
            _loss = outputs["loss"].sum().item() if isinstance(
                outputs["loss"], torch.Tensor) else 0
            losses.append(_loss)
        # the batch loss is the average of the per-clip summed losses
        loss += (sum(losses) / num_clips)

        # one entry per clip was appended, so the stacked axis is num_clips
        logits = torch.stack(logits)  # (num_clips, B, C)
        if pool_method == "mean":
            logits = logits.mean(0)  # (B, C)
        elif pool_method == "max":
            logits = logits.max(0)[0]  # (B, C)
        elif pool_method == "lse":
            logits = logits.permute(1, 0, 2).contiguous()  # (B, num_clips, C)
            logits = torch.logsumexp(logits, dim=1)  # torch.exp alone might be too large and unstable
        else:
            raise ValueError(f"Invalid value for pool_method, "
                             f"got {pool_method}, expect one of [`mean`, `max`, `lse`]")

        if cfg.task in ["action", "transition", "frameqa", "msrvtt_qa"]:
            # cross entropy
            pred_labels = logits.max(dim=-1)[1].data.tolist()
        else:
            # mse: round to the nearest integer and clamp into [1, 10]
            preds = (logits + 0.5).long().clamp(min=1, max=10)
            pred_labels = preds.data.squeeze().tolist()
            # With a single example squeeze() yields a 0-dim tensor, whose
            # tolist() is a bare int; re-wrap it so the zip below still works.
            if not isinstance(pred_labels, list):
                pred_labels = [pred_labels]
        for qid, pred_label in zip(question_ids, pred_labels):
            qa_results.append(dict(
                question_id=qid,
                answer=pred_label,
                data=val_loader.dataset.qid2data[qid]
            ))
        pbar.update(1)
        if cfg.debug and val_step >= debug_step:
            break
    pbar.close()

    if cfg.debug:
        LOGGER.info(qa_results[:10])
    n_ex_per_rank = all_gather_list(n_ex)
    loss = sum(all_gather_list(loss))
    # reuse the gathered per-rank counts instead of gathering a second time
    n_ex = sum(n_ex_per_rank)
    # average loss for each example
    val_log = {'valid/loss': float(loss / n_ex)}
    if eval_score:
        LOGGER.info(f"QA Task [{cfg.task}], "
                    f"{len(qa_results)} qa_results,"
                    f"3 examples here: {qa_results[:3]}")
        # scores for the predictions made on this worker only
        vqa_scores = val_loader.dataset.evaluate_tgif_qa(qa_results)

        # Gather scores
        scores_per_rank = all_gather_list(vqa_scores)
        gathered_scores = {}
        if "ratios" in scores_per_rank[0]:
            # each entry is [percentage, count]; counts are summed over ranks
            # and the percentage is recomputed from the global total
            gathered_ratios = {
                k: [0, 0] for k in scores_per_rank[0]["ratios"]}
            # Gather ratios
            for rank_id in range(len(n_ex_per_rank)):
                current_ratios = scores_per_rank[rank_id]["ratios"]
                for k, v in current_ratios.items():
                    gathered_ratios[k][1] += v[1]
            for k, v in gathered_ratios.items():
                gathered_ratios[k][0] = get_rounded_percentage(
                    1. * v[1] / n_ex)
            gathered_scores["ratios"] = gathered_ratios

        # FIXME: Gather scores become complicated due to np.mean and dict format.
        for scores_k in vqa_scores:
            if "ratio" in scores_k:
                continue
            # example-weighted sum of the per-rank score
            gathered_v = 0
            for rank_id, n in enumerate(n_ex_per_rank):
                curr_acc, curr_n_ex = 0, 0
                if "overall" in scores_k:
                    curr_acc = scores_per_rank[rank_id][scores_k] * n
                else:
                    if "ratios" in scores_per_rank[0]:
                        # weight by the rank's count for the matching ratio key
                        curr_n_ex = scores_per_rank[
                                rank_id]["ratios"][
                                    scores_k.replace("acc", "ratio")][1]
                        curr_acc = scores_per_rank[rank_id][
                            scores_k] * curr_n_ex
                gathered_v += curr_acc
            # normalise by the matching global count
            if "overall" in scores_k:
                gathered_v = gathered_v * 1. / n_ex
            else:
                if "ratios" in scores_per_rank[0]:
                    _num = gathered_ratios[
                        scores_k.replace("acc", "ratio")][1]
                    gathered_v = gathered_v * 1. / _num if _num != 0 else 0
            if cfg.task in ["action", "transition", "frameqa", "msrvtt_qa"]:
                gathered_scores[scores_k] = get_rounded_percentage(
                    gathered_v)
            else:
                gathered_scores[scores_k] = round(gathered_v, 2)

        for k, v in gathered_scores.items():
            if "ratio" not in k:
                val_log[f'valid/{k}'] = v
    else:
        LOGGER.info("eval_score = False, no scores are calculated.")
        gathered_scores = 0

    TB_LOGGER.log_scalar_dict(val_log)
    LOGGER.info(f"validation finished in {int(time.time() - st)} seconds."
                f"{gathered_scores}")

    model.train()
    return qa_results, gathered_scores
Exemple #4
0
def start_training():
    """Run pre-training (MLM and/or ITM) with Horovod (``hvd``) and ``amp``.

    Sequence of work, as visible in this function:
      1. read the config from ``shared_configs.get_pretraining_args()``;
      2. build model and optimizer, wrap the optimizer in
         ``hvd.DistributedOptimizer``, broadcast parameters and optimizer
         state from rank 0, then call ``amp.initialize``;
      3. build the train loaders (combined through ``MetaLoader`` and wrapped
         in ``PrefetchLoader``) and the validation loaders;
      4. derive ``cfg.num_train_steps`` and ``cfg.valid_steps``;
      5. create a ``TrainingRestorer`` and resume from its ``global_step``;
      6. run the training loop with gradient accumulation, per-param-group
         learning rates, optional gradient clipping, and a
         validate-then-checkpoint step every ``cfg.valid_steps`` steps.

    Returns:
        None.

    Raises:
        AssertionError: if the optimizer does not have exactly 8 param
            groups, or if a trainable parameter has no gradient at an
            optimizer step.
    """
    cfg = shared_configs.get_pretraining_args()
    set_random_seed(cfg.seed)

    n_gpu = hvd.size()
    device = torch.device("cuda", hvd.local_rank())
    torch.cuda.set_device(hvd.local_rank())
    # only rank 0 emits log records
    if hvd.rank() != 0:
        LOGGER.disabled = True
    LOGGER.info(f"device: {device} n_gpu: {n_gpu}, "
                f"rank: {hvd.rank()}, 16-bits training: {cfg.fp16}")

    model = setup_model(cfg, device=device)
    model.train()

    optimizer = setup_e2e_optimizer(model, cfg)

    # Horovod: (optional) compression algorithm; none is used here.
    compression = hvd.Compression.none
    optimizer = hvd.DistributedOptimizer(
        optimizer,
        named_parameters=model.named_parameters(),
        compression=compression)

    # Horovod: broadcast parameters & optimizer state from rank 0.
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)
    hvd.broadcast_optimizer_state(optimizer, root_rank=0)

    model, optimizer = amp.initialize(model,
                                      optimizer,
                                      enabled=cfg.fp16,
                                      opt_level='O2',
                                      keep_batchnorm_fp32=True)

    # prepare data
    tokenizer = BertTokenizerFast.from_pretrained(cfg.tokenizer_dir)
    train_loaders, val_loaders = setup_dataloaders(cfg, tokenizer)
    # MetaLoader yields (task, batch) pairs, as unpacked in the loop below
    train_loader = MetaLoader(train_loaders,
                              accum_steps=cfg.gradient_accumulation_steps,
                              distributed=n_gpu > 1)
    img_norm = ImageNorm(mean=cfg.img_pixel_mean, std=cfg.img_pixel_std)
    train_loader = PrefetchLoader(train_loader, img_norm)
    val_loaders = {
        k: PrefetchLoader(v, img_norm)
        for k, v in val_loaders.items()
    }

    # compute the number of steps and update cfg
    # total_train_batch_size is only used in the log line below
    total_train_batch_size = int(n_gpu * cfg.train_batch_size *
                                 cfg.gradient_accumulation_steps *
                                 cfg.max_n_example_per_group)
    total_n_epochs = cfg.num_train_epochs
    cfg.num_train_steps = int(
        math.ceil(1. * train_loader.n_batches_in_epoch * total_n_epochs /
                  (n_gpu * cfg.gradient_accumulation_steps)))
    # validation interval, rounded up to a multiple of cfg.min_valid_steps
    cfg.valid_steps = int(
        math.ceil(1. * cfg.num_train_steps / cfg.num_valid /
                  cfg.min_valid_steps)) * cfg.min_valid_steps
    # only used in the log line below; the +1 presumably accounts for the
    # final validation after the loop -- TODO confirm
    actual_num_valid = int(
        math.floor(1. * cfg.num_train_steps / cfg.valid_steps)) + 1

    # restore
    restorer = TrainingRestorer(cfg, model, optimizer)
    global_step = restorer.global_step
    TB_LOGGER.global_step = global_step
    if hvd.rank() == 0:
        LOGGER.info("Saving training meta...")
        save_training_meta(cfg)
        path = join(cfg.output_dir, 'log', "detectron2_model_cfg.yaml")
        with open(path, "w") as f:
            f.write(model.cnn.config_file)
        LOGGER.info("Saving training done...")
        TB_LOGGER.create(join(cfg.output_dir, 'log'))
        pbar = tqdm(total=cfg.num_train_steps)
        model_saver = ModelSaver(join(cfg.output_dir, "ckpt"))
        add_log_to_file(join(cfg.output_dir, "log", "log.txt"))
    else:
        # non-zero ranks get NoOp stand-ins for the progress bar, checkpoint
        # saver and restorer (global_step was already read above)
        LOGGER.disabled = True
        pbar = NoOp()
        model_saver = NoOp()
        restorer = NoOp()

    # when resuming, fast-forward the progress bar
    if global_step > 0:
        pbar.update(global_step)

    LOGGER.info(cfg)
    LOGGER.info("Starting training...")
    LOGGER.info(f"***** Running training with {n_gpu} GPUs *****")
    LOGGER.info(
        f"  Single-GPU Non-Accumulated batch size = {cfg.train_batch_size}")
    LOGGER.info(f"  max_n_example_per_group = {cfg.max_n_example_per_group}")
    LOGGER.info(f"  Accumulate steps = {cfg.gradient_accumulation_steps}")
    LOGGER.info(
        f"  Total batch size = #GPUs * Single-GPU batch size * "
        f"max_n_example_per_group * Accumulate steps [Image] = {total_train_batch_size}"
    )
    LOGGER.info(
        f"  Total #batches - single epoch = {train_loader.n_batches_in_epoch}."
    )
    LOGGER.info(f"  Total #steps = {cfg.num_train_steps}")
    LOGGER.info(f"  Total #epochs = {total_n_epochs}.")
    LOGGER.info(
        f"  Validate every {cfg.valid_steps} steps, in total {actual_num_valid} times"
    )

    # quick hack for amp delay_unscale bug
    with optimizer.skip_synchronize():
        optimizer.zero_grad()
        if global_step == 0:
            optimizer.step()
    # in cfg.debug mode training stops once global_step reaches this value
    debug_step = 5

    # one running meter per enabled task, plus one for the combined loss
    tasks = []
    for name, flag in zip(["mlm", "itm"], [cfg.use_mlm, cfg.use_itm]):
        if flag:
            tasks.append(name)
    task2loss = {t: RunningMeter(f'train_loss/{t}') for t in tasks}
    task2loss["loss"] = RunningMeter('train_loss/loss')
    # `task` is unpacked but not used inside the loop
    for step, (task, batch) in enumerate(train_loader):
        # forward pass
        outputs = forward_step(cfg, model, batch)
        mlm_loss, itm_loss = 0, 0
        if cfg.use_mlm:
            mlm_loss = outputs["mlm_loss"].mean()
            task2loss["mlm"](mlm_loss.item())
        if cfg.use_itm:
            itm_loss = outputs["itm_loss"].mean()
            task2loss["itm"](itm_loss.item())

        # NOTE(review): if both cfg.use_mlm and cfg.use_itm are false, `loss`
        # is the plain int 0 and the `.item()` call below fails.
        loss = mlm_loss + itm_loss
        task2loss["loss"](loss.item())

        # delay_unscale is True on every micro-step except the one that
        # completes an accumulation window
        delay_unscale = (step + 1) % cfg.gradient_accumulation_steps != 0
        with amp.scale_loss(loss, optimizer,
                            delay_unscale=delay_unscale) as scaled_loss:
            scaled_loss.backward()
            zero_none_grad(model)
            optimizer.synchronize()

        # optimizer
        if (step + 1) % cfg.gradient_accumulation_steps == 0:
            global_step += 1
            TB_LOGGER.log_scalar_dict({
                l.name: l.val
                for l in task2loss.values() if l.val is not None
            })
            # integer number of epochs' worth of batches consumed so far
            n_epoch = int(1. * n_gpu * cfg.gradient_accumulation_steps *
                          global_step / train_loader.n_batches_in_epoch)
            # learning rate scheduling transformer
            lr_this_step_transformer = get_lr_sched(
                global_step,
                cfg.decay,
                cfg.learning_rate,
                cfg.num_train_steps,
                warmup_ratio=cfg.warmup_ratio,
                decay_epochs=cfg.step_decay_epochs,
                multi_step_epoch=n_epoch)

            # learning rate scheduling cnn
            lr_this_step_cnn = get_lr_sched(
                global_step,
                cfg.cnn_lr_decay,
                cfg.cnn_learning_rate,
                cfg.num_train_steps,
                warmup_ratio=cfg.warmup_ratio,
                decay_epochs=cfg.cnn_step_decay_epochs,
                multi_step_epoch=n_epoch)

            # Hardcoded param group length. Group index -> learning rate:
            #   0, 1: transformer lr * cfg.transformer_lr_mul
            #   2, 3: transformer lr
            #   4, 5: cnn lr * cfg.cnn_lr_mul
            #   6, 7: cnn lr
            assert len(optimizer.param_groups) == 8
            for pg_n, param_group in enumerate(optimizer.param_groups):
                if pg_n in [0, 1]:
                    param_group['lr'] = (cfg.transformer_lr_mul *
                                         lr_this_step_transformer)
                elif pg_n in [2, 3]:
                    param_group['lr'] = lr_this_step_transformer
                elif pg_n in [4, 5]:
                    param_group['lr'] = (cfg.cnn_lr_mul * lr_this_step_cnn)
                else:
                    param_group['lr'] = lr_this_step_cnn
            TB_LOGGER.add_scalar("train/lr_transformer",
                                 lr_this_step_transformer, global_step)
            TB_LOGGER.add_scalar("train/lr_cnn", lr_this_step_cnn, global_step)

            # update model params
            # cfg.grad_norm == -1 disables gradient clipping
            if cfg.grad_norm != -1:
                grad_norm = clip_grad_norm_(amp.master_params(optimizer),
                                            cfg.grad_norm)
                TB_LOGGER.add_scalar("train/grad_norm", grad_norm, global_step)
            TB_LOGGER.step()

            # Check if there is None grad
            none_grads = [
                p[0] for p in model.named_parameters()
                if p[1].requires_grad and p[1].grad is None
            ]

            assert len(none_grads) == 0, f"{none_grads}"

            # gradients were already synchronized inside the scale_loss block
            with optimizer.skip_synchronize():
                optimizer.step()
                optimizer.zero_grad()
            restorer.step()
            pbar.update(1)

            # checkpoint
            if global_step % cfg.valid_steps == 0:
                LOGGER.info(f'Step {global_step}: start validation')
                validate(model, val_loaders, cfg)
                model_saver.save(step=global_step, model=model)
        # stop once the step budget (or the debug budget) is reached
        if global_step >= cfg.num_train_steps:
            break

        if cfg.debug and global_step >= debug_step:
            break

    # final validation + checkpoint, unless the last step already did both
    if global_step % cfg.valid_steps != 0:
        LOGGER.info(f'Step {global_step}: start validation')
        validate(model, val_loaders, cfg)
        model_saver.save(step=global_step, model=model)
Exemple #5
0
def validate(model, val_loader, cfg):
    """Evaluate the MLM / ITM pre-training objectives and log the averages.

    Args:
        model: model handed to ``forward_step``; toggled eval -> train.
        val_loader: a single loader, or a dict mapping a name to a loader.
            A single loader is treated as a one-entry dict.
        cfg: config; reads ``use_mlm``, ``use_itm`` and ``debug``.

    Returns:
        dict: ``valid/mlm_loss``, ``valid/mlm_acc``, ``valid/itm_loss`` and
        ``valid/itm_acc``. MLM values are averaged per masked token, ITM
        values per example; a metric stays 0 when its count is 0.
    """
    model.eval()

    # per-worker running totals
    mlm_loss, n_mlm_tokens, n_mlm_corrects = 0, 0, 0
    itm_loss, n_itm_ex, n_itm_corrects = 0, 0, 0
    st = time.time()
    val_log = {
        'valid/mlm_loss': 0,
        'valid/mlm_acc': 0,
        'valid/itm_loss': 0,
        'valid/itm_acc': 0
    }
    debug_step = 5

    if isinstance(val_loader, dict):
        val_loaders = val_loader
    else:
        val_loaders = {"unnamed_val_loader": val_loader}
    LOGGER.info(f"In total {len(val_loaders)} val loaders")

    for loader_name, loader in val_loaders.items():
        LOGGER.info(f"Loop val_loader {loader_name}.")
        for val_step, batch in enumerate(loader):
            outputs = forward_step(cfg, model, batch)

            # masked language modelling
            mlm_labels = outputs["mlm_labels"]
            if cfg.use_mlm:
                mlm_loss += outputs["mlm_loss"].sum().item()
                # -100 is the ignored label for cross entropy, so only the
                # remaining positions are scored
                mlm_mask = mlm_labels != -100
                n_mlm_tokens += mlm_mask.sum().item()
                _, mlm_preds = outputs["mlm_scores"][mlm_mask].max(dim=-1)
                n_mlm_corrects += (
                    mlm_preds == mlm_labels[mlm_mask]).sum().item()

            # image-text matching
            if cfg.use_itm:
                itm_labels = outputs["itm_labels"]
                itm_loss += outputs["itm_loss"].sum().item()
                n_itm_ex += len(itm_labels)
                _, itm_preds = outputs["itm_scores"].max(dim=-1)
                n_itm_corrects += (itm_preds == itm_labels).sum().item()

            # debug mode: stop this loader after a handful of batches
            if cfg.debug and val_step >= debug_step:
                break

    # Sum every total across all processes
    mlm_loss = sum(all_gather_list(mlm_loss))
    n_mlm_corrects = sum(all_gather_list(n_mlm_corrects))
    n_mlm_tokens = sum(all_gather_list(n_mlm_tokens))
    itm_loss = sum(all_gather_list(itm_loss))
    n_itm_corrects = sum(all_gather_list(n_itm_corrects))
    n_itm_ex = sum(all_gather_list(n_itm_ex))

    # a metric keeps its default of 0 when nothing was counted for it
    if n_mlm_tokens != 0:
        val_log['valid/mlm_loss'] = float(mlm_loss / n_mlm_tokens)
        val_log['valid/mlm_acc'] = float(n_mlm_corrects / n_mlm_tokens)
    if n_itm_ex != 0:
        val_log['valid/itm_loss'] = float(itm_loss / n_itm_ex)
        val_log['valid/itm_acc'] = float(n_itm_corrects / n_itm_ex)

    TB_LOGGER.log_scalar_dict(val_log)
    LOGGER.info(
        f"validation finished in {int(time.time() - st)} seconds, "
        f"[mlm_acc (per token)]: {val_log['valid/mlm_acc'] * 100:.2f} "
        f"[itm_acc (per example)]: {val_log['valid/itm_acc'] * 100:.2f} ")
    model.train()
    return val_log
Exemple #6
0
def start_training(cfg):
    """Run the distributed (Horovod + apex amp) training loop described by ``cfg``.

    The function builds the model and optimizer, wraps the optimizer for
    Horovod, broadcasts the initial state from rank 0, builds the data
    loaders, derives the step schedule, optionally resumes from a restorer
    checkpoint, and then trains with gradient accumulation.  Validation and
    checkpointing happen every ``cfg.valid_steps`` optimizer steps and once
    more at the end if the last step was not a validation step.

    Args:
        cfg: configuration object read through attribute access.  It is also
            mutated here: ``cfg.n_gpu``, ``cfg.num_train_steps`` and
            ``cfg.valid_steps`` are written before training starts.

    Returns:
        None.  Results are emitted through ``LOGGER``, ``TB_LOGGER`` and
        ``model_saver``.
    """
    set_random_seed(cfg.seed)

    n_gpu = hvd.size()
    cfg.n_gpu = n_gpu
    device = torch.device("cuda", hvd.local_rank())
    torch.cuda.set_device(hvd.local_rank())
    # Only rank 0 emits log output.
    if hvd.rank() != 0:
        LOGGER.disabled = True
    LOGGER.info("device: {} n_gpu: {}, rank: {}, "
                "16-bits training: {}".format(device, n_gpu, hvd.rank(),
                                              bool(cfg.fp16)))

    model = setup_model(cfg, device=device)
    model.train()
    optimizer = setup_e2e_optimizer(model, cfg)

    # Horovod: (optional) compression algorithm.
    compression = hvd.Compression.none
    optimizer = hvd.DistributedOptimizer(
        optimizer,
        named_parameters=model.named_parameters(),
        compression=compression)

    #  Horovod: broadcast parameters & optimizer state.
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)
    hvd.broadcast_optimizer_state(optimizer, root_rank=0)

    # amp is initialized after the Horovod wrapping/broadcast above; it is a
    # pass-through when cfg.fp16 is falsy (enabled=cfg.fp16).
    model, optimizer = amp.initialize(model,
                                      optimizer,
                                      enabled=cfg.fp16,
                                      opt_level='O2',
                                      keep_batchnorm_fp32=True)

    # prepare data
    tokenizer = BertTokenizerFast.from_pretrained(cfg.tokenizer_dir)
    train_loader, val_loader = setup_dataloaders(cfg, tokenizer)

    # compute the number of steps and update cfg
    # Each dataset item counts as cfg.max_n_example_per_group examples, both
    # in the dataset size and in the effective batch size.
    total_n_examples = len(train_loader.dataset) * cfg.max_n_example_per_group
    total_train_batch_size = int(n_gpu * cfg.train_batch_size *
                                 cfg.gradient_accumulation_steps *
                                 cfg.max_n_example_per_group)
    cfg.num_train_steps = int(
        math.ceil(1. * cfg.num_train_epochs * total_n_examples /
                  total_train_batch_size))
    # Validation interval, rounded up to a multiple of cfg.min_valid_steps.
    cfg.valid_steps = int(
        math.ceil(1. * cfg.num_train_steps / cfg.num_valid /
                  cfg.min_valid_steps)) * cfg.min_valid_steps
    # Only used for the informational log line below.
    actual_num_valid = int(
        math.floor(1. * cfg.num_train_steps / cfg.valid_steps)) + 1

    # restore
    restorer = TrainingRestorer(cfg, model, optimizer)
    global_step = restorer.global_step
    TB_LOGGER.global_step = global_step
    if hvd.rank() == 0:
        # Rank 0 owns every on-disk artifact: training meta, the model's CNN
        # config file, tensorboard logs, checkpoints and the text log.
        LOGGER.info("Saving training meta...")
        save_training_meta(cfg)
        path = join(cfg.output_dir, 'log', "detectron2_model_cfg.yaml")
        with open(path, "w") as f:
            f.write(model.cnn.config_file)
        LOGGER.info("Saving training done...")
        TB_LOGGER.create(join(cfg.output_dir, 'log'))
        model_saver = ModelSaver(join(cfg.output_dir, "ckpt"))
        add_log_to_file(join(cfg.output_dir, "log", "log.txt"))
        pbar = tqdm(total=cfg.num_train_steps)
    else:
        # Other ranks get NoOp stand-ins (presumably objects that ignore any
        # call -- confirm against NoOp's definition).  Note the restorer is
        # replaced only after global_step has been read from it above.
        LOGGER.disabled = True
        model_saver = NoOp()
        restorer = NoOp()
        pbar = NoOp()

    # When resuming, advance the progress bar to the restored step.
    if global_step > 0:
        pbar.update(global_step)

    LOGGER.info(cfg)
    LOGGER.info("Starting training...")
    LOGGER.info(f"***** Running training with {n_gpu} GPUs *****")
    LOGGER.info(
        f"  Single-GPU Non-Accumulated batch size = {cfg.train_batch_size}")
    LOGGER.info(f"  max_n_example_per_group = {cfg.max_n_example_per_group}")
    LOGGER.info(f"  Accumulate steps = {cfg.gradient_accumulation_steps}")
    LOGGER.info(
        f"  Total batch size = #GPUs * Single-GPU batch size * "
        f"max_n_example_per_group * Accumulate steps [Image] = {total_train_batch_size}"
    )
    LOGGER.info(f"  Total #epochs = {cfg.num_train_epochs}")
    LOGGER.info(f"  Total #steps = {cfg.num_train_steps}")
    LOGGER.info(
        f"  Validate every {cfg.valid_steps} steps, in total {actual_num_valid} times"
    )

    # quick hack for amp delay_unscale bug
    # The dummy optimizer.step() is only taken on a fresh run (global_step == 0).
    with optimizer.skip_synchronize():
        optimizer.zero_grad()
        if global_step == 0:
            optimizer.step()
    debug_step = 3
    running_loss = RunningMeter('train_loss')
    for step, batch in enumerate(InfiniteIterator(train_loader)):
        # forward pass
        outputs, question_ids = forward_step(model, batch)
        loss = outputs["loss"].mean()
        # NOTE(review): the loss is rescaled by cfg.num_labels; presumably this
        # undoes a per-label averaging inside the model's loss -- confirm.
        loss = loss.float() * cfg.num_labels
        running_loss(loss.item())
        # backward pass
        # delay_unscale is True on every micro-step except the last one of an
        # accumulation window.
        delay_unscale = (step + 1) % cfg.gradient_accumulation_steps != 0
        with amp.scale_loss(loss, optimizer,
                            delay_unscale=delay_unscale) as scaled_loss:
            scaled_loss.backward()
            # presumably replaces missing (None) grads with zeros so that
            # synchronize() covers every parameter -- verify in zero_none_grad
            zero_none_grad(model)
            optimizer.synchronize()

        # optimizer
        # Everything below runs once per accumulation window, i.e. once per
        # global (optimizer) step.
        if (step + 1) % cfg.gradient_accumulation_steps == 0:
            global_step += 1
            TB_LOGGER.add_scalar('train/loss', running_loss.val, global_step)

            # Number of whole epochs completed so far, derived from the number
            # of examples consumed; fed to the LR schedulers below.
            n_epoch = int(1. * total_train_batch_size * global_step /
                          total_n_examples)
            # learning rate scheduling transformer
            lr_this_step_transformer = get_lr_sched(
                global_step,
                cfg.decay,
                cfg.learning_rate,
                cfg.num_train_steps,
                warmup_ratio=cfg.warmup_ratio,
                decay_epochs=cfg.step_decay_epochs,
                multi_step_epoch=n_epoch)

            # learning rate scheduling cnn
            lr_this_step_cnn = get_lr_sched(
                global_step,
                cfg.cnn_lr_decay,
                cfg.cnn_learning_rate,
                cfg.num_train_steps,
                warmup_ratio=cfg.warmup_ratio,
                decay_epochs=cfg.cnn_step_decay_epochs,
                multi_step_epoch=n_epoch)

            # Hardcoded param group length
            # Group layout relied on here: 0-1 transformer LR scaled by
            # cfg.transformer_lr_mul, 2-3 plain transformer LR, 4-5 CNN LR
            # scaled by cfg.cnn_lr_mul, 6-7 plain CNN LR.
            assert len(optimizer.param_groups) == 8
            for pg_n, param_group in enumerate(optimizer.param_groups):
                if pg_n in [0, 1]:
                    param_group['lr'] = (cfg.transformer_lr_mul *
                                         lr_this_step_transformer)
                elif pg_n in [2, 3]:
                    param_group['lr'] = lr_this_step_transformer
                elif pg_n in [4, 5]:
                    param_group['lr'] = (cfg.cnn_lr_mul * lr_this_step_cnn)
                else:
                    param_group['lr'] = lr_this_step_cnn
            TB_LOGGER.add_scalar("train/lr_transformer",
                                 lr_this_step_transformer, global_step)
            TB_LOGGER.add_scalar("train/lr_cnn", lr_this_step_cnn, global_step)

            # update model params
            # cfg.grad_norm == -1 disables gradient clipping.
            if cfg.grad_norm != -1:
                grad_norm = clip_grad_norm_(amp.master_params(optimizer),
                                            cfg.grad_norm)
                TB_LOGGER.add_scalar("train/grad_norm", grad_norm, global_step)
            TB_LOGGER.step()

            # Check if there is None grad
            none_grads = [
                p[0] for p in model.named_parameters()
                if p[1].requires_grad and p[1].grad is None
            ]

            assert len(none_grads) == 0, f"{none_grads}"

            # Gradients were already synchronized inside the scale_loss block,
            # so the step itself skips Horovod's synchronization.
            with optimizer.skip_synchronize():
                optimizer.step()
                optimizer.zero_grad()
            restorer.step()
            pbar.update(1)

            # checkpoint
            if global_step % cfg.valid_steps == 0:
                LOGGER.info(f'Step {global_step}: start validation')
                # The returned predictions are not used further in this function.
                vqa_results = validate(model, val_loader, cfg, global_step)
                model_saver.save(step=global_step, model=model)
        if global_step >= cfg.num_train_steps:
            break

        # In debug mode, stop after a handful of optimizer steps.
        if cfg.debug and global_step >= debug_step:
            break

    # Final validation + checkpoint, unless the last step already triggered one.
    if global_step % cfg.valid_steps != 0:
        LOGGER.info(f'Step {global_step}: start validation')
        vqa_results = validate(model, val_loader, cfg, global_step)
        model_saver.save(step=global_step, model=model)
Exemple #7
0
def validate(model, val_loader, cfg, train_global_step, eval_score=True):
    """Evaluate ``model`` on ``val_loader`` and log the aggregated metrics.

    Use ``eval_score=False`` when doing inference on test sets where answers
    are not available.

    The model is put into eval mode for the duration of the call and switched
    back to train mode before returning.  The summed loss and the example
    count are gathered across all processes; when ``eval_score`` is true the
    per-rank VQA scores are gathered too and combined into example-weighted
    averages.

    Args:
        model: model passed to ``forward_step``.
        val_loader: iterable of batches; its ``dataset`` must expose
            ``label2ans`` and, when ``eval_score`` is true, ``evaluate_vqa``.
        cfg: configuration; only ``cfg.debug`` is read here.
        train_global_step: not used inside this function; kept for the
            callers' signature.
        eval_score: whether to compute and log VQA scores.

    Returns:
        The list of ``{"question_id": ..., "answer": ...}`` predictions
        produced by *this* process (predictions are not gathered across
        ranks).
    """
    model.eval()
    loss = 0.
    n_ex = 0
    vqa_results = []
    st = time.time()
    debug_step = 5
    # Inference only: disabling autograd avoids building a graph for every
    # validation batch, which would otherwise waste memory.
    with torch.no_grad():
        for val_step, batch in enumerate(val_loader):
            # forward pass
            outputs, question_ids = forward_step(model, batch)

            # The loss may be a non-tensor placeholder (e.g. when labels are
            # missing); it then contributes nothing to the running sum.
            batch_loss = outputs["loss"]
            if isinstance(batch_loss, torch.Tensor):
                loss += batch_loss.sum().item()
            n_ex += len(question_ids)
            pred_labels = outputs["logits"].max(dim=-1)[1].tolist()
            vqa_results.extend(
                dict(question_id=qid,
                     answer=val_loader.dataset.label2ans[pred_label])
                for qid, pred_label in zip(question_ids, pred_labels))

            if cfg.debug and val_step >= debug_step:
                break

    if cfg.debug:
        LOGGER.info(vqa_results[:10])
    # One gather of the per-rank example counts serves both the per-rank
    # weights used below and the global total.
    n_ex_per_rank = all_gather_list(n_ex)
    loss = sum(all_gather_list(loss))
    n_ex = sum(n_ex_per_rank)
    val_log = {'valid/loss': float(loss / n_ex)}
    if eval_score:
        LOGGER.info(f"Evaluate VQA scores for {len(vqa_results)} vqa_results,"
                    f"3 examples here: {vqa_results[:3]}")
        vqa_scores = val_loader.dataset.evaluate_vqa(vqa_results)

        # Gather scores
        scores_per_rank = all_gather_list(vqa_scores)
        gathered_scores = {}
        # Each entry is [percentage of all examples, example count].
        gathered_ratios = {k: [0, 0] for k in scores_per_rank[0]["ratios"]}
        # Gather ratios
        for rank_scores in scores_per_rank:
            for k, v in rank_scores["ratios"].items():
                gathered_ratios[k][1] += v[1]
        for v in gathered_ratios.values():
            v[0] = get_rounded_percentage(1. * v[1] / n_ex)

        # FIXME: Gather scores become complicated due to np.mean and dict format.
        # Every per-rank accuracy is weighted by the number of examples it was
        # computed on, then normalized by the corresponding global count.
        for scores_k in vqa_scores:
            if "ratio" in scores_k:
                continue
            is_overall = "overall" in scores_k
            ratio_k = scores_k.replace("acc", "ratio")
            gathered_v = 0
            for rank_scores, n in zip(scores_per_rank, n_ex_per_rank):
                if is_overall:
                    weight = n
                else:
                    weight = rank_scores["ratios"][ratio_k][1]
                gathered_v += rank_scores[scores_k] * weight
            if is_overall:
                gathered_v = gathered_v * 1. / n_ex
            else:
                gathered_v = gathered_v * 1. / gathered_ratios[ratio_k][1]
            gathered_scores[scores_k] = get_rounded_percentage(gathered_v)
        gathered_scores["ratios"] = gathered_ratios

        for k, v in gathered_scores.items():
            if "ratio" not in k:
                val_log[f'valid/{k}'] = v
    else:
        LOGGER.info("Seems you are doing inference on test set,"
                    "no scores are calculated.")
        gathered_scores = 0

    TB_LOGGER.log_scalar_dict(val_log)
    LOGGER.info(f"validation finished in {int(time.time() - st)} seconds."
                f"{gathered_scores}")
    model.train()
    return vqa_results