Example #1
0
File: itm.py Project: zmykevin/UC2
def validate(model, val_loader):
    """Run the quick image-retrieval validation and return recall metrics.

    Each batch from ``val_loader`` is scored by ``model``. The candidate at
    index 0 of the score vector is treated as the correct one, and its
    position among the top-scoring candidates decides whether the batch
    counts towards recall@1, recall@5 and recall@10. Counts are summed over
    all workers with ``all_gather_list`` before being normalised.

    Args:
        model: called as ``model(batch, compute_loss=False)``; put into eval
            mode for the run and back into train mode before returning.
        val_loader: iterable of batches that also supports ``len()``.
            Presumably every batch holds one query whose ground-truth
            candidate comes first -- TODO confirm against the dataset code.

    Returns:
        dict with the keys ``valid/ex_per_s``, ``valid/recall_1``,
        ``valid/recall_5`` and ``valid/recall_10``.
    """
    if hvd.rank() == 0:
        pbar = tqdm(total=len(val_loader))
    else:
        pbar = NoOp()
    LOGGER.info("start running Image Retrieval validation ...")
    model.eval()
    n_ex = 0
    st = time()

    recall_at_1, recall_at_5, recall_at_10 = 0, 0, 0
    # Validation never back-propagates, so skip autograd bookkeeping.
    with torch.no_grad():
        for batch in val_loader:
            scores = model(batch, compute_loss=False).squeeze(1)
            # topk raises when k exceeds the number of candidates, so clamp
            # k for batches that carry fewer than 10 of them.
            k = min(10, scores.size(0))
            _, indices = scores.topk(k, dim=0)
            rank = (indices == 0).nonzero()
            if rank.numel():
                rank = rank.item()
                if rank < 1:
                    recall_at_1 += 1
                if rank < 5:
                    recall_at_5 += 1
                if rank < 10:
                    recall_at_10 += 1
            n_ex += 1
            pbar.update(1)
    n_ex = sum(all_gather_list(n_ex))
    # Guard the normalisation so an empty loader reports 0 recall instead of
    # raising ZeroDivisionError.
    denom = max(n_ex, 1)
    recall_at_1 = sum(all_gather_list(recall_at_1)) / denom
    recall_at_5 = sum(all_gather_list(recall_at_5)) / denom
    recall_at_10 = sum(all_gather_list(recall_at_10)) / denom
    tot_time = time() - st
    val_log = {
        'valid/ex_per_s': n_ex / tot_time,
        'valid/recall_1': recall_at_1,
        'valid/recall_5': recall_at_5,
        'valid/recall_10': recall_at_10
    }
    model.train()
    LOGGER.info(f"validation finished in {int(tot_time)} seconds, "
                f"recall_1: {recall_at_1*100:.2f}, "
                f"recall_5: {recall_at_5*100:.2f}, "
                f"recall_10: {recall_at_10*100:.2f}")
    pbar.close()
    return val_log
Example #2
0
File: itm.py Project: zmykevin/UC2
def inference(model, eval_loader):
    """Score every item of ``eval_loader`` against all images.

    Builds a half-precision CUDA matrix with one row per dataset item and
    one column per entry of ``eval_loader.dataset.all_img_ids``. The i-th
    element yielded by the loader is itself an iterable of mini-batches;
    their scores are written left to right into row ``i``.

    Args:
        model: called as ``model(batch, compute_loss=False)``; put into eval
            mode for the run and back into train mode before returning.
        eval_loader: loader supporting ``len()`` whose ``dataset`` exposes
            ``__len__`` and ``all_img_ids``.

    Returns:
        torch.Tensor of dtype float16 on the current CUDA device with shape
        ``(len(eval_loader.dataset), len(eval_loader.dataset.all_img_ids))``.

    Raises:
        AssertionError: if the mini-batches of one row do not add up to
            exactly one score per image.
    """
    model.eval()
    if hvd.rank() == 0:
        pbar = tqdm(total=len(eval_loader))
    else:
        pbar = NoOp()
    score_matrix = torch.zeros(len(eval_loader.dataset),
                               len(eval_loader.dataset.all_img_ids),
                               device=torch.device("cuda"),
                               dtype=torch.float16)
    # Pure inference: disable autograd so no graph is built per mini-batch.
    with torch.no_grad():
        for i, mini_batches in enumerate(eval_loader):
            j = 0
            for batch in mini_batches:
                scores = model(batch, compute_loss=False)
                bs = scores.size(0)
                score_matrix[i, j:j + bs] = scores.squeeze(1).half()
                j += bs
            # every row must be filled completely, one score per image
            assert j == score_matrix.size(1), (
                f"row {i}: got {j} scores, "
                f"expected {score_matrix.size(1)}")
            pbar.update(1)
    model.train()
    pbar.close()
    return score_matrix
Example #3
0
File: itm.py Project: zmykevin/UC2
def main(opts):
    """Fine-tune and evaluate an image-text retrieval model under Horovod.

    Steps, in order: initialise Horovod and pick the local CUDA device,
    create output directories and loggers on rank 0, build the train / val /
    test datasets from the LMDB paths in ``opts``, load
    ``VLXLMRForImageTextRetrieval`` (optionally from ``opts.checkpoint``),
    wrap model and optimizer with ``amp``, train with gradient accumulation
    until ``opts.num_train_steps`` optimizer steps have been taken
    (validating and saving every ``opts.valid_steps`` steps), save a final
    checkpoint and run ``evaluate`` on every test loader.

    Args:
        opts: namespace of run options (paths, hyper-parameters, flags) read
            throughout this function; ``opts.rank`` is set here.

    Raises:
        ValueError: if ``opts.gradient_accumulation_steps`` is below 1.
    """
    hvd.init()
    n_gpu = hvd.size()
    device = torch.device("cuda", hvd.local_rank())
    torch.cuda.set_device(hvd.local_rank())
    rank = hvd.rank()
    opts.rank = rank
    LOGGER.info("device: {} n_gpu: {}, rank: {}, "
                "16-bits training: {}".format(device, n_gpu, hvd.rank(),
                                              opts.fp16))

    if opts.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, "
                         "should be >= 1".format(
                             opts.gradient_accumulation_steps))

    set_random_seed(opts.seed)

    # Only rank 0 creates directories, the progress bar and the checkpoint
    # saver; every other rank gets NoOp() placeholders and a disabled logger.
    if hvd.rank() == 0:
        save_training_meta(opts)
        TB_LOGGER.create(join(opts.output_dir, 'log'))
        pbar = tqdm(total=opts.num_train_steps)
        model_saver = ModelSaver(join(opts.output_dir, 'ckpt'))
        os.makedirs(join(opts.output_dir, 'ckpt'), exist_ok=True)

        add_log_to_file(join(opts.output_dir, 'log', 'log.txt'))
        # store ITM predictions
        os.makedirs(join(opts.output_dir, 'results_val'), exist_ok=True)
        os.makedirs(join(opts.output_dir, 'results_test'), exist_ok=True)
        os.makedirs(join(opts.output_dir, 'results_train'), exist_ok=True)
    else:
        LOGGER.disabled = True
        pbar = NoOp()
        model_saver = NoOp()

    # train_examples = None
    LOGGER.info(f"Loading Train Dataset {opts.train_txt_dbs}, "
                f"{opts.train_img_dbs}")
    # check multiple DBs
    assert len(opts.train_txt_dbs) == len(opts.train_img_dbs), \
        "train txt_db and img_db have different length"

    # load DBs and image dirs
    all_img_dbs = ImageLmdbGroup(opts.conf_th, opts.max_bb, opts.min_bb,
                                 opts.num_bb, opts.compressed_db)
    # train
    LOGGER.info(f"Loading Train Dataset "
                f"{opts.train_txt_dbs}, {opts.train_img_dbs}")
    train_datasets = []
    for txt_path, img_path in zip(opts.train_txt_dbs, opts.train_img_dbs):
        if "itm_coco_zh" not in txt_path:
            # regular dataset: a single image DB per text DB
            img_db = all_img_dbs[img_path]
            txt_db = TxtTokLmdb(txt_path, opts.max_txt_len)
            if opts.hard_neg_size > 0:
                train_datasets.append(
                    ItmRankDatasetHardNeg(txt_db, img_db, opts.negative_size,
                                          opts.hard_neg_size))
            else:
                train_datasets.append(
                    ItmRankDataset(txt_db, img_db, opts.negative_size))
        else:
            # "itm_coco_zh" text DBs come with a pair of image DBs, indexed
            # here as img_path[0] and img_path[1]
            img_train_db = all_img_dbs[img_path[0]]
            img_val_db = all_img_dbs[img_path[1]]
            txt_db = TxtTokLmdb(txt_path, opts.max_txt_len)
            if opts.hard_neg_size > 0:
                # NOTE(review): this branch never assigns ``img_db``. The
                # call below therefore reuses the value left by an earlier
                # non-"itm_coco_zh" iteration, or raises UnboundLocalError
                # if there was none -- confirm the intended image DB.
                train_datasets.append(
                    ItmRankDatasetHardNeg(txt_db, img_db, opts.negative_size,
                                          opts.hard_neg_size))
            else:
                train_datasets.append(
                    ItmRankDataset_COCO_CN(txt_db, img_train_db, img_val_db,
                                           opts.negative_size))
    train_dataset = ConcatDataset(train_datasets)

    # hard negative
    # hn_datasets = []
    # for txt_path, img_path in zip(opts.train_txt_dbs, opts.train_img_dbs):
    #     img_db = all_img_dbs[img_path]
    #     txt_db = TxtTokLmdb(txt_path, opts.max_txt_len)
    #     hn_datasets.append(ItmHardNegDataset(txt_db, img_db,
    #                                          opts.inf_minibatch_size))
    # hn_dataset = ConcatDataset(hn_datasets)
    # hn_dataloader = build_dataloader(hn_dataset, itm_hn_collate, False, opts)
    # hard_neg_dir = f'{opts.output_dir}/results_train/'

    # val
    # Only the first entry of the val DB lists is used.
    LOGGER.info(f"Loading Val Dataset {opts.val_txt_db}, {opts.val_img_db}")
    val_img_db = all_img_dbs[opts.val_img_db[0]]
    val_txt_db = TxtTokLmdb(opts.val_txt_db[0], -1)
    val_dataset = ItmValDataset(val_txt_db, val_img_db,
                                opts.inf_minibatch_size)
    val_dataloader = build_dataloader(val_dataset, itm_val_collate, False,
                                      opts)
    # eval
    LOGGER.info(f"Loading val, test Dataset for full evaluation: "
                f"{opts.val_txt_db}, {opts.val_img_db}"
                f"{opts.test_txt_db}, {opts.test_img_db}")
    eval_dataset_val = ItmEvalDataset(val_txt_db, val_img_db,
                                      opts.inf_minibatch_size)
    eval_loader_val = build_dataloader(eval_dataset_val, itm_eval_collate,
                                       False, opts)

    # One evaluation loader per (test image DB, test text DB) pair.
    eval_loader_list = []
    assert len(opts.test_img_db) == len(opts.test_txt_db)
    for test_img_db_path, test_txt_db_path in zip(opts.test_img_db,
                                                  opts.test_txt_db):
        if "itm_coco_zh" not in test_txt_db_path:
            test_img_db = all_img_dbs[test_img_db_path]
            test_txt_db = TxtTokLmdb(test_txt_db_path, -1)
            eval_dataset_test = ItmEvalDataset(test_txt_db, test_img_db,
                                               opts.inf_minibatch_size)
        else:
            test_img_train_db = all_img_dbs[test_img_db_path[0]]
            test_img_val_db = all_img_dbs[test_img_db_path[1]]
            test_txt_db = TxtTokLmdb(test_txt_db_path, -1)
            eval_dataset_test = ItmEvalDataset_COCO_CN(test_txt_db,
                                                       test_img_train_db,
                                                       test_img_val_db,
                                                       opts.inf_minibatch_size)
        eval_loader_test = build_dataloader(eval_dataset_test,
                                            itm_eval_collate, False, opts)
        eval_loader_list.append(eval_loader_test)

    # Prepare model
    if opts.checkpoint:
        checkpoint = torch.load(opts.checkpoint)
    else:
        checkpoint = {}

    #Rename the key if specified
    if opts.rename_checkpoints:
        rename_checkpoint(checkpoint)

    model = VLXLMRForImageTextRetrieval.from_pretrained(
        opts.model_config,
        state_dict=checkpoint,
        load_embedding_only=opts.load_embedding_only,
        load_layer=opts.load_layer,
        img_dim=IMG_DIM,
        margin=opts.margin)
    model.init_output()  # pretrain ITM head is different from ranking head
    model.to(device)
    # make sure every process has same model parameters in the beginning
    broadcast_tensors([p.data for p in model.parameters()], 0)
    set_dropout(model, opts.dropout)

    # Prepare optimizer
    if opts.separate_lr:
        optimizer = build_xlmr_optimizer(model, opts)
    else:
        optimizer = build_optimizer(model, opts)
    model, optimizer = amp.initialize(model,
                                      optimizer,
                                      enabled=opts.fp16,
                                      opt_level='O2')

    #global_step = 0
    LOGGER.info(f"***** Running training on {n_gpu} GPUs *****")
    LOGGER.info("  Num examples = %d", len(train_dataset) * hvd.size())
    LOGGER.info("  Batch size = %d", opts.train_batch_size)
    LOGGER.info("  Accumulate steps = %d", opts.gradient_accumulation_steps)
    LOGGER.info("  Num steps = %d", opts.num_train_steps)

    running_loss = RunningMeter('loss')
    model.train()

    if opts.steps_per_hard_neg != -1:
        # NOTE(review): ``hn_dataloader`` and ``hard_neg_dir`` are assigned
        # only in the commented-out block above, so this call raises
        # NameError unless they exist at module level -- confirm.
        compute_hard_neg(model, hn_dataloader, train_dataset,
                         opts.hard_neg_pool_size, hard_neg_dir)

    #Initialize the TrainingRestorer
    # The step counter is taken from the restorer, so a resumed run continues
    # from the restored step rather than from 0.
    restorer = TrainingRestorer(opts, model, optimizer)
    global_step = restorer.global_step
    TB_LOGGER._global_step = global_step
    if hvd.rank() != 0:
        restorer = NoOp()  #Added for Restoring the Checkpoints

    if global_step > 0:
        pbar.update(global_step)

    n_examples = 0
    start = time()
    # quick hack for amp delay_unscale bug
    optimizer.zero_grad()
    optimizer.step()
    # Outer loop rebuilds the dataloader each pass; it only exits once
    # global_step reaches opts.num_train_steps.
    while True:
        train_dataloader = build_dataloader(train_dataset,
                                            xlmr_itm_rank_collate, True, opts)
        for step, batch in enumerate(train_dataloader):
            #print(batch['input_ids'])
            n_examples += batch['input_ids'].size(0)
            loss = model(batch, compute_loss=True)
            loss = loss.mean()
            # True on every micro-step except the last one of an
            # accumulation window.
            delay_unscale = (step + 1) % opts.gradient_accumulation_steps != 0
            with amp.scale_loss(loss, optimizer,
                                delay_unscale=delay_unscale) as scaled_loss:
                scaled_loss.backward()
                if not delay_unscale:
                    # gather gradients from every processes
                    # do this before unscaling to make sure every process uses
                    # the same gradient scale
                    grads = [
                        p.grad.data for p in model.parameters()
                        if p.requires_grad and p.grad is not None
                    ]
                    all_reduce_and_rescale_tensors(grads, float(1))

            running_loss(loss.item())
            # print("run the loss")
            # End of an accumulation window: schedule LR, log, step.
            if (step + 1) % opts.gradient_accumulation_steps == 0:
                global_step += 1

                # learning rate scheduling
                lr_this_step = get_lr_sched(global_step, opts)
                if opts.separate_lr:
                    #added by Mingyang
                    # The first two param groups use the XLM-R schedule and
                    # the remaining groups use the default schedule.
                    xlmr_lr_this_step = get_xlmr_lr_sched(global_step, opts)
                    for i, param_group in enumerate(optimizer.param_groups):
                        if i < 2:
                            param_group['lr'] = xlmr_lr_this_step
                        else:
                            param_group['lr'] = lr_this_step
                    TB_LOGGER.add_scalar('xlmr_lr', xlmr_lr_this_step,
                                         global_step)
                else:
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = lr_this_step

                TB_LOGGER.add_scalar('lr', lr_this_step, global_step)

                # log loss
                # Gather every worker's meter and restart the local meter
                # from the mean of their ``val`` fields.
                losses = all_gather_list(running_loss)
                running_loss = RunningMeter(
                    'loss',
                    sum(l.val for l in losses) / len(losses))
                TB_LOGGER.add_scalar('loss', running_loss.val, global_step)
                TB_LOGGER.step()

                # update model params
                # opts.grad_norm == -1 disables gradient clipping
                if opts.grad_norm != -1:
                    grad_norm = clip_grad_norm_(amp.master_params(optimizer),
                                                opts.grad_norm)
                    TB_LOGGER.add_scalar('grad_norm', grad_norm, global_step)
                optimizer.step()
                optimizer.zero_grad()
                pbar.update(1)

                if global_step % 100 == 0:
                    # monitor training throughput
                    LOGGER.info(f'============Step {global_step}=============')
                    tot_ex = sum(all_gather_list(n_examples))
                    ex_per_sec = int(tot_ex / (time() - start))
                    LOGGER.info(f'{tot_ex} examples trained at '
                                f'{ex_per_sec} ex/s')
                    TB_LOGGER.add_scalar('perf/ex_per_s', ex_per_sec,
                                         global_step)
                    LOGGER.info(f'===========================================')

                # Periodic validation followed by a checkpoint save.
                if global_step % opts.valid_steps == 0 and global_step > 0:
                    # if global_step > 7000:
                    if opts.full_val:
                        val_log = evaluate(model, eval_loader_val)
                        TB_LOGGER.log_scaler_dict(
                            {f"valid/{k}": v
                             for k, v in val_log.items()})
                        #Log the information
                        # LOGGER.info(
                        #         f"========================= {split} ===========================\n"
                        #         f"image retrieval R1: {eval_log['img_r1']*100:.2f},\n"
                        #         f"image retrieval R5: {eval_log['img_r5']*100:.2f},\n"
                        #         f"image retrieval R10: {eval_log['img_r10']*100:.2f}\n"
                        #         f"text retrieval R1: {eval_log['txt_r1']*100:.2f},\n"
                        #         f"text retrieval R5: {eval_log['txt_r5']*100:.2f},\n"
                        #         f"text retrieval R10: {eval_log['txt_r10']*100:.2f}")
                        # LOGGER.info("=========================================================")
                    else:
                        val_log = validate(model, val_dataloader)
                        TB_LOGGER.log_scaler_dict(val_log)

                    model_saver.save(model, global_step)
                restorer.step()
                if (opts.steps_per_hard_neg != -1
                        and global_step % opts.steps_per_hard_neg == 0):
                    # sample hard negatives for training
                    # NOTE(review): same undefined-name concern as the
                    # compute_hard_neg call before the training loop.
                    compute_hard_neg(model, hn_dataloader, train_dataset,
                                     opts.hard_neg_pool_size, hard_neg_dir)
                    # break to reconstruct loader
                    # for potential multi-worker issue (not sure)
                    break

            if global_step >= opts.num_train_steps:
                break

        if global_step >= opts.num_train_steps:
            break
        # NOTE can no longer count epochs

    pbar.close()
    # final validation
    # val_log = validate(model, val_dataloader)
    # TB_LOGGER.log_scaler_dict(val_log)
    model_saver.save(model, f'{global_step}_final')

    # Final evaluation: every rank calls ``evaluate``, but only rank 0 prints
    # the per-split summary.
    for i, loader in enumerate(eval_loader_list):
        split = "test_{}".format(i)
        eval_log = evaluate(model, loader)
        TB_LOGGER.log_scaler_dict(
            {f"eval/{split}_{k}": v
             for k, v in eval_log.items()})
        if hvd.rank() != 0:
            continue
        LOGGER.info(
            f"========================= {split} ===========================\n"
            f"image retrieval R1: {eval_log['img_r1']*100:.2f},\n"
            f"image retrieval R5: {eval_log['img_r5']*100:.2f},\n"
            f"image retrieval R10: {eval_log['img_r10']*100:.2f}\n"
            f"text retrieval R1: {eval_log['txt_r1']*100:.2f},\n"
            f"text retrieval R5: {eval_log['txt_r5']*100:.2f},\n"
            f"text retrieval R10: {eval_log['txt_r10']*100:.2f}")
    LOGGER.info("=========================================================")
Example #4
0
File: itm.py Project: zmykevin/UC2
def get_hard_negs(model, loader, hard_negative_num=20):
    """Mine hard negatives in both directions from model scores.

    First pass: for every batch, the ``hard_negative_num`` highest-scoring
    entries of ``batch['neg_img_ids']`` become the hard images of
    ``batch['gt_txt_id']``, and every (score, text) pair is also recorded
    under its image id. Second pass: for each image id, the pairs recorded
    by all workers are gathered and rank 0 keeps the ``hard_negative_num``
    highest-scoring texts.

    Args:
        model: called as ``model(batch, compute_loss=False)``; put into eval
            mode for the run and back into train mode before returning.
        loader: loader supporting ``len()`` whose ``dataset.datasets`` each
            expose ``all_img_ids``.
        hard_negative_num: maximum number of hard negatives kept per text
            and per image.

    Returns:
        ``(txt2hardimgs, img2hardtxts)``. ``txt2hardimgs`` covers the texts
        seen by this worker. ``img2hardtxts`` is filled on rank 0 only and
        is an empty dict on every other rank.
    """
    LOGGER.info("start running hard negative extraction")
    st = time()
    if hvd.rank() == 0:
        pbar = tqdm(total=len(loader))
    else:
        pbar = NoOp()
    model.eval()

    txt2hardimgs = {}
    img_to_score_txts = defaultdict(list)
    # Scoring only: no gradients are needed.
    with torch.no_grad():
        for batch in loader:
            scores = model(batch, compute_loss=False).squeeze(-1)
            txt = batch['gt_txt_id']
            imgs = batch['neg_img_ids']
            # record hard images; clamp k because topk raises when asked for
            # more entries than the batch has candidates
            k = min(hard_negative_num, scores.size(-1))
            hard_indices = scores.topk(k, sorted=False)[1].tolist()
            txt2hardimgs[txt] = [imgs[i] for i in hard_indices]
            # record img2txts; convert once instead of one .item() call
            # (and device sync) per score
            for img, score in zip(imgs, scores.tolist()):
                img_to_score_txts[img].append((score, txt))
            pbar.update(1)
    pbar.close()

    LOGGER.info("start computing hard texts from images...")
    n_less_neg = 0
    tot_text = 0
    img2hardtxts = {}
    # need to gather hard texts from all GPUs
    all_img_ids = [
        i for dset in loader.dataset.datasets for i in dset.all_img_ids
    ]
    all_img_ids = any_broadcast(all_img_ids, 0)
    for img in all_img_ids:
        score_txts = img_to_score_txts[img]
        # Collective call: every rank must take part for every image, so it
        # happens before the rank check below.
        gathered = all_gather_list(score_txts)
        if hvd.rank() != 0:
            # only rank 0 needs to compute
            continue
        all_pairs = [pair for worker_pairs in gathered
                     for pair in worker_pairs]
        if all_pairs:
            scores, txts = map(list, zip(*all_pairs))
        else:
            # An image that no worker scored has nothing to unpack; treat it
            # as having zero candidate texts instead of failing.
            scores, txts = [], []
        tot_text += len(txts)
        if len(txts) < hard_negative_num:
            # not enough negatives
            hard_indices = range(len(txts))
            n_less_neg += 1
        else:
            hard_indices = torch.tensor(scores).topk(hard_negative_num,
                                                     sorted=False)[1].tolist()
        img2hardtxts[img] = [txts[i] for i in hard_indices]

    n_less_neg = sum(all_gather_list(n_less_neg))
    if n_less_neg:
        LOGGER.info(f"Warning: {n_less_neg} images did not "
                    f"sample enough negatives")
    # max(..., 1) keeps the summary from dividing by zero on an empty run
    LOGGER.info(f"hard negative extraction finished "
                f"in {int(time() - st)} seconds "
                f"({tot_text//max(len(img_to_score_txts), 1)} texts per images)")

    model.train()
    return txt2hardimgs, img2hardtxts
Example #5
0
def main(opts):
    """Fine-tune and evaluate ``UniterForImageTextRetrieval`` under Horovod.

    Steps, in order: initialise Horovod and pick the local CUDA device,
    create output directories and loggers on rank 0, build the train / val /
    test datasets from the LMDB paths in ``opts``, load the model
    (optionally from ``opts.checkpoint``), wrap model and optimizer with
    ``amp``, train with gradient accumulation until ``opts.num_train_steps``
    optimizer steps have been taken (validating and saving every
    ``opts.valid_steps`` steps), then run ``evaluate`` on the val and test
    loaders.

    Args:
        opts: namespace of run options (paths, hyper-parameters, flags) read
            throughout this function; ``opts.rank`` is set here.

    Raises:
        ValueError: if ``opts.gradient_accumulation_steps`` is below 1.
    """
    hvd.init()
    n_gpu = hvd.size()
    device = torch.device("cuda", hvd.local_rank())
    torch.cuda.set_device(hvd.local_rank())
    rank = hvd.rank()
    opts.rank = rank
    LOGGER.info("device: {} n_gpu: {}, rank: {}, "
                "16-bits training: {}".format(device, n_gpu, hvd.rank(),
                                              opts.fp16))

    if opts.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, "
                         "should be >= 1".format(
                             opts.gradient_accumulation_steps))

    set_random_seed(opts.seed)

    # Only rank 0 creates directories, the progress bar and the checkpoint
    # saver; every other rank gets NoOp() placeholders and a disabled logger.
    if hvd.rank() == 0:
        save_training_meta(opts)
        TB_LOGGER.create(join(opts.output_dir, 'log'))
        pbar = tqdm(total=opts.num_train_steps)
        model_saver = ModelSaver(join(opts.output_dir, 'ckpt'))
        add_log_to_file(join(opts.output_dir, 'log', 'log.txt'))
        # store ITM predictions
        # NOTE(review): called without exist_ok, so these raise
        # FileExistsError when the directories already exist.
        os.makedirs(join(opts.output_dir, 'results_val'))
        os.makedirs(join(opts.output_dir, 'results_test'))
        os.makedirs(join(opts.output_dir, 'results_train'))
    else:
        LOGGER.disabled = True
        pbar = NoOp()
        model_saver = NoOp()

    # train_examples = None
    LOGGER.info(f"Loading Train Dataset {opts.train_txt_dbs}, "
                f"{opts.train_img_dbs}")
    # check multiple DBs
    assert len(opts.train_txt_dbs) == len(opts.train_img_dbs), \
        "train txt_db and img_db have different length"

    # load DBs and image dirs
    all_img_dbs = ImageLmdbGroup(opts.conf_th, opts.max_bb, opts.min_bb,
                                 opts.num_bb, opts.compressed_db)
    # train
    LOGGER.info(f"Loading Train Dataset "
                f"{opts.train_txt_dbs}, {opts.train_img_dbs}")
    # One ItmRankDataset per (text DB, image DB) pair, concatenated below.
    train_datasets = []
    for txt_path, img_path in zip(opts.train_txt_dbs, opts.train_img_dbs):
        img_db = all_img_dbs[img_path]
        txt_db = TxtTokLmdb(txt_path, opts.max_txt_len)
        train_datasets.append(
            ItmRankDataset(txt_db, img_db, opts.negative_size))
    train_dataset = ConcatDataset(train_datasets)

    # val
    LOGGER.info(f"Loading Val Dataset {opts.val_txt_db}, {opts.val_img_db}")
    val_img_db = all_img_dbs[opts.val_img_db]
    val_txt_db = TxtTokLmdb(opts.val_txt_db, -1)
    val_dataset = ItmValDataset(val_txt_db, val_img_db,
                                opts.inf_minibatch_size)
    val_dataloader = build_dataloader(val_dataset, itm_val_collate, False,
                                      opts)
    # eval
    LOGGER.info(f"Loading val, test Dataset for full evaluation: "
                f"{opts.val_txt_db}, {opts.val_img_db}"
                f"{opts.test_txt_db}, {opts.test_img_db}")
    eval_dataset_val = ItmEvalDataset(val_txt_db, val_img_db,
                                      opts.inf_minibatch_size)
    eval_loader_val = build_dataloader(eval_dataset_val, itm_eval_collate,
                                       False, opts)
    test_img_db = all_img_dbs[opts.test_img_db]
    test_txt_db = TxtTokLmdb(opts.test_txt_db, -1)
    eval_dataset_test = ItmEvalDataset(test_txt_db, test_img_db,
                                       opts.inf_minibatch_size)
    eval_loader_test = build_dataloader(eval_dataset_test, itm_eval_collate,
                                        False, opts)

    # Prepare model
    if opts.checkpoint:
        checkpoint = torch.load(opts.checkpoint)
    else:
        checkpoint = {}

    model = UniterForImageTextRetrieval.from_pretrained(opts.model_config,
                                                        state_dict=checkpoint,
                                                        img_dim=IMG_DIM,
                                                        margin=opts.margin)
    model.init_output()  # pretrain ITM head is different from ranking head
    model.to(device)
    # make sure every process has same model parameters in the beginning
    broadcast_tensors([p.data for p in model.parameters()], 0)
    set_dropout(model, opts.dropout)

    # Prepare optimizer
    optimizer = build_optimizer(model, opts)
    model, optimizer = amp.initialize(model,
                                      optimizer,
                                      enabled=opts.fp16,
                                      opt_level='O2')

    global_step = 0
    LOGGER.info(f"***** Running training on {n_gpu} GPUs *****")
    LOGGER.info("  Num examples = %d", len(train_dataset) * hvd.size())
    LOGGER.info("  Batch size = %d", opts.train_batch_size)
    LOGGER.info("  Accumulate steps = %d", opts.gradient_accumulation_steps)
    LOGGER.info("  Num steps = %d", opts.num_train_steps)

    running_loss = RunningMeter('loss')
    model.train()

    n_examples = 0
    n_epoch = 0
    start = time()
    # quick hack for amp delay_unscale bug
    optimizer.zero_grad()
    optimizer.step()
    # Each pass of the outer loop is one epoch over a freshly built loader;
    # the loop only exits once global_step reaches opts.num_train_steps.
    while True:
        train_dataloader = build_dataloader(train_dataset, itm_rank_collate,
                                            True, opts)
        for step, batch in enumerate(train_dataloader):
            n_examples += batch['input_ids'].size(0)
            loss = model(batch, compute_loss=True)
            loss = loss.mean()
            # True on every micro-step except the last one of an
            # accumulation window.
            delay_unscale = (step + 1) % opts.gradient_accumulation_steps != 0
            with amp.scale_loss(loss, optimizer,
                                delay_unscale=delay_unscale) as scaled_loss:
                scaled_loss.backward()
                if not delay_unscale:
                    # gather gradients from every processes
                    # do this before unscaling to make sure every process uses
                    # the same gradient scale
                    grads = [
                        p.grad.data for p in model.parameters()
                        if p.requires_grad and p.grad is not None
                    ]
                    all_reduce_and_rescale_tensors(grads, float(1))

            running_loss(loss.item())
            # End of an accumulation window: schedule LR, log, step.
            if (step + 1) % opts.gradient_accumulation_steps == 0:
                global_step += 1

                # learning rate scheduling
                lr_this_step = get_lr_sched(global_step, opts)
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr_this_step
                TB_LOGGER.add_scalar('lr', lr_this_step, global_step)

                # log loss
                # NOTE: not gathered across GPUs for efficiency
                TB_LOGGER.add_scalar('loss', running_loss.val, global_step)
                TB_LOGGER.step()

                # update model params
                # opts.grad_norm == -1 disables gradient clipping
                if opts.grad_norm != -1:
                    grad_norm = clip_grad_norm_(amp.master_params(optimizer),
                                                opts.grad_norm)
                    TB_LOGGER.add_scalar('grad_norm', grad_norm, global_step)
                optimizer.step()
                optimizer.zero_grad()
                pbar.update(1)

                if global_step % 100 == 0:
                    # monitor training throughput
                    LOGGER.info(f'------------Step {global_step}-------------')
                    tot_ex = sum(all_gather_list(n_examples))
                    ex_per_sec = int(tot_ex / (time() - start))
                    LOGGER.info(f'{tot_ex} examples trained at '
                                f'{ex_per_sec} ex/s')
                    TB_LOGGER.add_scalar('perf/ex_per_s', ex_per_sec,
                                         global_step)
                    LOGGER.info(f'-------------------------------------------')

                # Periodic validation followed by a checkpoint save.
                if global_step % opts.valid_steps == 0:
                    if opts.full_val:
                        LOGGER.info(
                            f"========================== Step {global_step} "
                            f"==========================")
                        val_log = evaluate(model, eval_loader_val)
                        TB_LOGGER.log_scaler_dict(
                            {f"valid/{k}": v
                             for k, v in val_log.items()})
                        LOGGER.info(f"image retrieval R1: "
                                    f"{val_log['img_r1']*100:.2f},\n"
                                    f"image retrieval R5: "
                                    f"{val_log['img_r5']*100:.2f},\n"
                                    f"image retrieval R10: "
                                    f"{val_log['img_r10']*100:.2f}\n"
                                    f"text retrieval R1: "
                                    f"{val_log['txt_r1']*100:.2f},\n"
                                    f"text retrieval R5: "
                                    f"{val_log['txt_r5']*100:.2f},\n"
                                    f"text retrieval R10: "
                                    f"{val_log['txt_r10']*100:.2f}")
                        LOGGER.info("================================="
                                    "=================================")
                    else:
                        val_log = validate(model, val_dataloader)
                        TB_LOGGER.log_scaler_dict(val_log)
                    model_saver.save(model, global_step)

            if global_step >= opts.num_train_steps:
                break

        if global_step >= opts.num_train_steps:
            break
        n_epoch += 1
        LOGGER.info(f"finished {n_epoch} epochs")

    pbar.close()
    # Skip the extra validation/save when the last training step already
    # landed on a validation step.
    if opts.num_train_steps % opts.valid_steps != 0:
        # final validation
        val_log = validate(model, val_dataloader)
        TB_LOGGER.log_scaler_dict(val_log)
        model_saver.save(model, global_step)

    # evaluation
    # Every rank calls ``evaluate``, but only rank 0 prints the summary.
    for split, loader in [('val', eval_loader_val),
                          ('test', eval_loader_test)]:
        eval_log = evaluate(model, loader)
        TB_LOGGER.log_scaler_dict(
            {f"eval/{split}_{k}": v
             for k, v in eval_log.items()})
        if hvd.rank() != 0:
            continue
        LOGGER.info(
            f"========================= {split} ===========================\n"
            f"image retrieval R1: {eval_log['img_r1']*100:.2f},\n"
            f"image retrieval R5: {eval_log['img_r5']*100:.2f},\n"
            f"image retrieval R10: {eval_log['img_r10']*100:.2f}\n"
            f"text retrieval R1: {eval_log['txt_r1']*100:.2f},\n"
            f"text retrieval R5: {eval_log['txt_r5']*100:.2f},\n"
            f"text retrieval R10: {eval_log['txt_r10']*100:.2f}")
    LOGGER.info("=========================================================")
Example #6
0
def main(opts):
    """Train an image-text retrieval model with hard negatives, then evaluate.

    The function runs, in order:

    1. Horovod / CUDA device setup and random seeding.
    2. Rank-0-only setup of the output directory, TensorBoard logger,
       progress bar, checkpoint saver and log file; other ranks get no-op
       stand-ins and have ``LOGGER`` disabled.
    3. Construction of the train, validation and full-evaluation
       dataloaders.
    4. Model / optimizer creation and ``amp`` initialization.
    5. The training loop: every iteration back-propagates one
       "hard text from image" batch and one "hard image from text" batch;
       gradients are accumulated and the optimizer steps once every
       ``opts.train_batch_size`` iterations.
    6. A final validation + checkpoint, followed by a full evaluation on
       the val and test splits.

    Args:
        opts: option namespace. Attributes read directly here include
            ``fp16``, ``seed``, ``output_dir``, ``num_train_steps``,
            ``train_txt_dbs``, ``train_img_dbs``, ``conf_th``, ``max_bb``,
            ``min_bb``, ``num_bb``, ``compressed_db``, ``max_txt_len``,
            ``negative_size``, ``val_txt_db``, ``val_img_db``,
            ``test_txt_db``, ``test_img_db``, ``inf_minibatch_size``,
            ``checkpoint``, ``model_config``, ``margin``,
            ``hard_neg_size``, ``dropout``, ``train_batch_size``,
            ``grad_norm``, ``valid_steps`` and ``full_val``. ``opts`` is
            also forwarded to helpers (``build_dataloader``,
            ``build_optimizer``, ``get_lr_sched``, ``save_training_meta``)
            that may read further attributes. ``opts.rank`` is set here.

    Returns:
        None.

    Raises:
        AssertionError: if ``opts.train_txt_dbs`` and
            ``opts.train_img_dbs`` have different lengths.
    """
    hvd.init()
    n_gpu = hvd.size()
    device = torch.device("cuda", hvd.local_rank())
    torch.cuda.set_device(hvd.local_rank())
    rank = hvd.rank()
    opts.rank = rank
    LOGGER.info("device: {} n_gpu: {}, rank: {}, "
                "16-bits training: {}".format(device, n_gpu, hvd.rank(),
                                              opts.fp16))

    set_random_seed(opts.seed)

    if hvd.rank() == 0:
        save_training_meta(opts)
        TB_LOGGER.create(join(opts.output_dir, "log"))
        pbar = tqdm(total=opts.num_train_steps)
        model_saver = ModelSaver(join(opts.output_dir, "ckpt"))
        add_log_to_file(join(opts.output_dir, "log", "log.txt"))
        # store ITM predictions
        # exist_ok=True: a pre-existing results directory must not raise
        # here, because only rank 0 runs this branch and the other ranks
        # would keep going without it.
        for results_dir in ("results_val", "results_test", "results_train"):
            os.makedirs(join(opts.output_dir, results_dir), exist_ok=True)
    else:
        # only rank 0 logs, shows progress and writes checkpoints
        LOGGER.disabled = True
        pbar = NoOp()
        model_saver = NoOp()

    # check multiple DBs
    assert len(opts.train_txt_dbs) == len(
        opts.train_img_dbs), "train txt_db and img_db have different length"

    # load DBs and image dirs
    all_img_dbs = ImageLmdbGroup(opts.conf_th, opts.max_bb, opts.min_bb,
                                 opts.num_bb, opts.compressed_db)
    # train
    LOGGER.info(f"Loading Train Dataset "
                f"{opts.train_txt_dbs}, {opts.train_img_dbs}")
    train_datasets_t = []
    train_datasets_i = []
    for txt_path, img_path in zip(opts.train_txt_dbs, opts.train_img_dbs):
        img_db = all_img_dbs[img_path]
        txt_db = TxtTokLmdb(txt_path, opts.max_txt_len)
        # one dataset per sampling direction, built over the same DB pair
        train_datasets_t.append(
            ItmRankDatasetHardNegFromText(txt_db, img_db, opts.negative_size))
        train_datasets_i.append(
            ItmRankDatasetHardNegFromImage(txt_db, img_db, opts.negative_size))
    train_dataset_t = ConcatDataset(train_datasets_t)
    train_dataset_i = ConcatDataset(train_datasets_i)
    train_dataloader_t = build_dataloader(train_dataset_t, itm_rank_hn_collate,
                                          True, opts)
    train_dataloader_i = build_dataloader(train_dataset_i, itm_rank_hn_collate,
                                          True, opts)

    # val
    LOGGER.info(f"Loading Val Dataset {opts.val_txt_db}, {opts.val_img_db}")
    val_img_db = all_img_dbs[opts.val_img_db]
    val_txt_db = TxtTokLmdb(opts.val_txt_db, -1)
    val_dataset = ItmValDataset(val_txt_db, val_img_db,
                                opts.inf_minibatch_size)
    val_dataloader = build_dataloader(val_dataset, itm_val_collate, False,
                                      opts)
    # eval
    LOGGER.info(f"Loading val, test Dataset for full evaluation: "
                f"{opts.val_txt_db}, {opts.val_img_db}, "
                f"{opts.test_txt_db}, {opts.test_img_db}")
    eval_dataset_val = ItmEvalDataset(val_txt_db, val_img_db,
                                      opts.inf_minibatch_size)
    eval_loader_val = build_dataloader(eval_dataset_val, itm_eval_collate,
                                       False, opts)
    test_img_db = all_img_dbs[opts.test_img_db]
    test_txt_db = TxtTokLmdb(opts.test_txt_db, -1)
    eval_dataset_test = ItmEvalDataset(test_txt_db, test_img_db,
                                       opts.inf_minibatch_size)
    eval_loader_test = build_dataloader(eval_dataset_test, itm_eval_collate,
                                        False, opts)

    # Prepare model
    if opts.checkpoint:
        checkpoint = torch.load(opts.checkpoint)
    else:
        checkpoint = {}

    model = UniterForImageTextRetrievalHardNeg.from_pretrained(
        opts.model_config,
        state_dict=checkpoint,
        img_dim=IMG_DIM,
        margin=opts.margin,
        hard_size=opts.hard_neg_size,
    )
    model.init_output()  # pretrain ITM head is different from ranking head
    model.to(device)
    # make sure every process has same model parameters in the beginning
    broadcast_tensors([p.data for p in model.parameters()], 0)
    set_dropout(model, opts.dropout)

    # Prepare optimizer
    optimizer = build_optimizer(model, opts)
    model, optimizer = amp.initialize(model,
                                      optimizer,
                                      enabled=opts.fp16,
                                      opt_level="O2")

    LOGGER.info(f"***** Running training on {n_gpu} GPUs *****")
    LOGGER.info("  Num examples = %d",
                sum(all_gather_list(len(train_dataset_t))))
    LOGGER.info("  Batch size = %d", opts.train_batch_size)
    LOGGER.info("  Num steps = %d", opts.num_train_steps)

    running_loss = RunningMeter("loss")
    model.train()

    global_step = 0  # number of optimizer updates performed
    step = 0  # number of loop iterations (accumulation micro-steps)
    n_examples = 0
    n_hard_ex = 0
    start = time()
    train_iter_i = iter(train_dataloader_i)
    # quick hack for amp delay_unscale bug
    optimizer.zero_grad()
    optimizer.step()
    while True:
        for batch in train_dataloader_t:

            # hard text from image
            try:
                batch_i = next(train_iter_i)
            except StopIteration:
                # the image-side loader is exhausted independently of the
                # text-side one; restart it and keep going
                train_iter_i = iter(train_dataloader_i)
                batch_i = next(train_iter_i)
            n_examples += batch_i["attn_masks"].size(0)
            loss = model(batch_i, sample_from="i", compute_loss=True)
            n_hard_ex += loss.numel()
            loss = loss.mean() / opts.train_batch_size
            # always delay unscaling here: the text-side backward below
            # decides whether this iteration closes the accumulation window
            with amp.scale_loss(loss, optimizer,
                                delay_unscale=True) as scaled_loss:
                scaled_loss.backward()

            # hard image from text
            n_examples += batch["attn_masks"].size(0)
            loss = model(batch, sample_from="t", compute_loss=True)
            n_hard_ex += loss.numel()
            # NOTE we use gradient accumulation to implement train_batch_size
            loss = loss.mean() / opts.train_batch_size

            step += 1
            delay_unscale = step % opts.train_batch_size != 0
            with amp.scale_loss(loss, optimizer,
                                delay_unscale=delay_unscale) as scaled_loss:
                scaled_loss.backward()
                if not delay_unscale:
                    # gather gradients from every processes
                    # do this before unscaling to make sure every process uses
                    # the same gradient scale
                    grads = [
                        p.grad.data for p in model.parameters()
                        if p.requires_grad and p.grad is not None
                    ]
                    all_reduce_and_rescale_tensors(grads, float(1))

            # only the text-side loss of this iteration is tracked
            running_loss(loss.item())
            if step % opts.train_batch_size == 0:
                global_step += 1

                # learning rate scheduling
                lr_this_step = get_lr_sched(global_step, opts)
                for param_group in optimizer.param_groups:
                    param_group["lr"] = lr_this_step
                TB_LOGGER.add_scalar("lr", lr_this_step, global_step)

                # log loss
                # NOTE: not gathered across GPUs for efficiency
                TB_LOGGER.add_scalar("loss", running_loss.val, global_step)
                TB_LOGGER.step()

                # update model params
                if opts.grad_norm != -1:
                    grad_norm = clip_grad_norm_(amp.master_params(optimizer),
                                                opts.grad_norm)
                    TB_LOGGER.add_scalar("grad_norm", grad_norm, global_step)
                optimizer.step()
                optimizer.zero_grad()
                pbar.update(1)

                if global_step % 100 == 0:
                    # monitor training throughput
                    LOGGER.info(f"------------Step {global_step}-------------")
                    tot_ex = sum(all_gather_list(n_examples))
                    ex_per_sec = int(tot_ex / (time() - start))
                    tot_hn = sum(all_gather_list(n_hard_ex))
                    hn_per_sec = int(tot_hn / (time() - start))
                    LOGGER.info(f"{tot_ex} ({tot_hn}) examples (hard) "
                                f"trained at {ex_per_sec} ({hn_per_sec}) ex/s")
                    TB_LOGGER.add_scalar("perf/ex_per_s", ex_per_sec,
                                         global_step)
                    TB_LOGGER.add_scalar("perf/hn_per_s", hn_per_sec,
                                         global_step)
                    LOGGER.info("-------------------------------------------")

                if global_step % opts.valid_steps == 0:
                    if opts.full_val:
                        LOGGER.info(
                            f"========================== Step {global_step} "
                            f"==========================")
                        val_log = evaluate(model, eval_loader_val)
                        TB_LOGGER.log_scaler_dict(
                            {f"valid/{k}": v
                             for k, v in val_log.items()})
                        LOGGER.info(f"image retrieval R1: "
                                    f"{val_log['img_r1']*100:.2f},\n"
                                    f"image retrieval R5: "
                                    f"{val_log['img_r5']*100:.2f},\n"
                                    f"image retrieval R10: "
                                    f"{val_log['img_r10']*100:.2f}\n"
                                    f"text retrieval R1: "
                                    f"{val_log['txt_r1']*100:.2f},\n"
                                    f"text retrieval R5: "
                                    f"{val_log['txt_r5']*100:.2f},\n"
                                    f"text retrieval R10: "
                                    f"{val_log['txt_r10']*100:.2f}")
                        LOGGER.info("================================="
                                    "=================================")
                    else:
                        val_log = validate(model, val_dataloader)
                        TB_LOGGER.log_scaler_dict(val_log)
                    model_saver.save(model, global_step)

            if global_step >= opts.num_train_steps:
                break

        # leave the outer (epoch) loop as well once the step budget is used
        if global_step >= opts.num_train_steps:
            break

    pbar.close()
    # final validation
    val_log = validate(model, val_dataloader)
    TB_LOGGER.log_scaler_dict(val_log)
    model_saver.save(model, f"{global_step}_final")

    # evaluation
    for split, loader in [("val", eval_loader_val),
                          ("test", eval_loader_test)]:
        # evaluate() is called on every rank; only rank 0 prints the summary
        eval_log = evaluate(model, loader)
        TB_LOGGER.log_scaler_dict(
            {f"eval/{split}_{k}": v
             for k, v in eval_log.items()})
        if hvd.rank() != 0:
            continue
        LOGGER.info(
            f"========================= {split} ===========================\n"
            f"image retrieval R1: {eval_log['img_r1']*100:.2f},\n"
            f"image retrieval R5: {eval_log['img_r5']*100:.2f},\n"
            f"image retrieval R10: {eval_log['img_r10']*100:.2f}\n"
            f"text retrieval R1: {eval_log['txt_r1']*100:.2f},\n"
            f"text retrieval R5: {eval_log['txt_r5']*100:.2f},\n"
            f"text retrieval R10: {eval_log['txt_r10']*100:.2f}")
    LOGGER.info("=========================================================")