Esempio n. 1
0
def create_dataloaders(datasets, is_train, opts, all_img_dbs=None):
    """Build one dataloader per (task, dataset) pair.

    Args:
        datasets: iterable of dataset specs; each spec is indexed with the
            keys ``'name'``, ``'db'``, ``'img'``, ``'tasks'`` and, when
            training, ``'mix_ratio'``.
        is_train: when true, every text/image DB of a spec is loaded and the
            result maps task -> ``(loader, mix_ratio)``; otherwise exactly one
            text DB and one image DB are expected and the result maps
            task -> ``PrefetchLoader``.
        opts: option namespace forwarded to the DB / dataset / loader builders.
        all_img_dbs: optional image DB group to reuse; a new
            ``ImageLmdbGroup`` is created from ``opts`` when omitted.

    Returns:
        ``(dataloaders, all_img_dbs)`` where ``dataloaders`` is keyed by
        ``"<task>_<dataset name>"``.

    Raises:
        ValueError: if a task name starts with none of ``mlm``, ``mrfr``,
            ``mrc`` or ``itm``.
    """
    if all_img_dbs is None:
        all_img_dbs = ImageLmdbGroup(opts.conf_th, opts.max_bb, opts.min_bb,
                                     opts.num_bb, opts.compressed_db)

    loaders_by_task = {}
    for spec in datasets:
        if not is_train:
            # validation uses a single text DB / image DB pair
            assert len(spec['db']) == len(spec['img']) == 1
            img_db = all_img_dbs[spec['img'][0]]
        else:
            assert len(spec['db']) == len(spec['img'])
            assert len(spec['tasks']) == len(spec['mix_ratio'])
            img_db = [all_img_dbs[img_path] for img_path in spec['img']]

        for idx, task_prefix in enumerate(spec['tasks']):
            task = f'{task_prefix}_{spec["name"]}'

            if not is_train:
                LOGGER.info(f"Loading {task} validation dataset, "
                            f"{spec['db']}, {img_db.img_dir}")
                txt_db = TxtTokLmdb(spec['db'][0], -1)
            else:
                LOGGER.info(f"Loading {task} train dataset "
                            f"{spec['db']}, {[img.img_dir for img in img_db]}")
                txt_db = [
                    TxtTokLmdb(txt_path, opts.max_txt_len)
                    for txt_path in spec['db']
                ]

            # pick the dataset builder matching the task prefix
            if task.startswith('mlm'):
                build_dataset = build_mlm_dataset
            elif task.startswith('mrfr'):
                build_dataset = build_mrfr_dataset
            elif task.startswith('mrc'):
                build_dataset = build_mrc_dataset
            elif task.startswith('itm'):
                build_dataset = build_itm_dataset
            else:
                raise ValueError(f'Undefined task {task}')
            dataset = build_dataset(txt_db, img_db, is_train, opts)

            LOGGER.info(f"{len(dataset[0])*hvd.size()} samples loaded")

            # itm handles distributed training in dset not sampler
            make_loader = (build_dataloader_itm
                           if task.startswith('itm') else build_dataloader)
            loader = make_loader(*dataset, is_train, opts)

            if is_train:
                loaders_by_task[task] = (loader, spec['mix_ratio'][idx])
            else:
                loaders_by_task[task] = PrefetchLoader(loader)

    return loaders_by_task, all_img_dbs
Esempio n. 2
0
def main(opts):
    """Run NLVR2 inference with a trained checkpoint and write a CSV.

    The training hyper-parameters are read from
    ``{opts.train_dir}/log/hps.json`` and decide which model, evaluation
    dataset and collate function are used. The checkpoint
    ``{opts.train_dir}/ckpt/model_step_{opts.ckpt}.pt`` is loaded, evaluated
    on ``opts.txt_db`` / ``opts.img_db`` and the ``(id, answer)`` pairs are
    written to ``{opts.output_dir}/results.csv``.

    Args:
        opts: option namespace; reads ``train_dir``, ``ckpt``, ``img_db``,
            ``txt_db``, ``compressed_db``, ``batch_size``, ``n_workers``,
            ``pin_mem``, ``fp16`` and ``output_dir``.

    Raises:
        ValueError: if the stored model type is not ``'paired'``,
            ``'paired-attn'`` or ``'triplet'``.
    """
    hvd.init()
    device = torch.device("cuda")  # support single GPU only
    # use a context manager so the hyper-parameter file is closed promptly
    with open(f'{opts.train_dir}/log/hps.json') as hps_file:
        train_opts = Struct(json.load(hps_file))

    if 'paired' in train_opts.model:
        EvalDatasetCls = Nlvr2PairedEvalDataset
        eval_collate_fn = nlvr2_paired_eval_collate
        if train_opts.model == 'paired':
            ModelCls = UniterForNlvr2Paired
        elif train_opts.model == 'paired-attn':
            ModelCls = UniterForNlvr2PairedAttn
        else:
            raise ValueError('unrecognized model type')
    elif train_opts.model == 'triplet':
        EvalDatasetCls = Nlvr2TripletEvalDataset
        ModelCls = UniterForNlvr2Triplet
        eval_collate_fn = nlvr2_triplet_eval_collate
    else:
        raise ValueError('unrecognized model type')

    img_db = DetectFeatLmdb(opts.img_db, train_opts.conf_th, train_opts.max_bb,
                            train_opts.min_bb, train_opts.num_bb,
                            opts.compressed_db)
    txt_db = TxtTokLmdb(opts.txt_db, -1)
    dset = EvalDatasetCls(txt_db, img_db, train_opts.use_img_type)
    # fall back to the validation batch size used during training
    batch_size = (train_opts.val_batch_size
                  if opts.batch_size is None else opts.batch_size)
    sampler = TokenBucketSampler(dset.lens,
                                 bucket_size=BUCKET_SIZE,
                                 batch_size=batch_size,
                                 droplast=False)
    eval_dataloader = DataLoader(dset,
                                 batch_sampler=sampler,
                                 num_workers=opts.n_workers,
                                 pin_memory=opts.pin_mem,
                                 collate_fn=eval_collate_fn)
    eval_dataloader = PrefetchLoader(eval_dataloader)

    # Prepare model
    ckpt_file = f'{opts.train_dir}/ckpt/model_step_{opts.ckpt}.pt'
    checkpoint = torch.load(ckpt_file)
    model_config = UniterConfig.from_json_file(
        f'{opts.train_dir}/log/model.json')
    model = ModelCls(model_config, img_dim=IMG_DIM)
    model.init_type_embedding()
    model.load_state_dict(checkpoint, strict=False)
    model.to(device)
    model = amp.initialize(model, enabled=opts.fp16, opt_level='O2')

    results = evaluate(model, eval_dataloader, device)
    # write results; exist_ok avoids the check-then-create race
    os.makedirs(opts.output_dir, exist_ok=True)
    with open(f'{opts.output_dir}/results.csv', 'w') as f:
        for id_, ans in results:
            f.write(f'{id_},{ans}\n')
    print('all results written')
Esempio n. 3
0
def create_dataloader(img_path, txt_path, batch_size, is_train,
                      dset_cls, collate_fn, opts):
    """Create a prefetching, token-bucketed dataloader for one DB pair.

    Args:
        img_path: path handed to ``DetectFeatLmdb``.
        txt_path: path handed to ``TxtTokLmdb``.
        batch_size: batch size given to the ``TokenBucketSampler``.
        is_train: when true the text DB is limited to ``opts.max_txt_len``
            and the sampler drops the last batch; otherwise the limit is
            ``-1`` and the last batch is kept.
        dset_cls: dataset class called as
            ``dset_cls(txt_db, img_db, opts.use_img_type)``.
        collate_fn: collate function for the ``DataLoader``.
        opts: option namespace with the DB and loader settings.

    Returns:
        A ``PrefetchLoader`` wrapping the ``DataLoader``.
    """
    image_db = DetectFeatLmdb(img_path, opts.conf_th, opts.max_bb,
                              opts.min_bb, opts.num_bb, opts.compressed_db)

    if is_train:
        txt_len_limit = opts.max_txt_len
    else:
        txt_len_limit = -1
    text_db = TxtTokLmdb(txt_path, txt_len_limit)

    dataset = dset_cls(text_db, image_db, opts.use_img_type)
    batch_sampler = TokenBucketSampler(dataset.lens,
                                       bucket_size=BUCKET_SIZE,
                                       batch_size=batch_size,
                                       droplast=is_train)
    return PrefetchLoader(
        DataLoader(dataset,
                   batch_sampler=batch_sampler,
                   num_workers=opts.n_workers,
                   pin_memory=opts.pin_mem,
                   collate_fn=collate_fn))
Esempio n. 4
0
def main(opts):
    """Evaluate an image-text retrieval checkpoint on one evaluation DB.

    Initialises Horovod, optionally overrides the image-feature settings
    (``conf_th``, ``max_bb``, ``min_bb``, ``num_bb``) on ``opts`` with those
    stored in ``opts.train_config``, builds the evaluation dataset and model
    and runs ``evaluate``.

    Args:
        opts: option namespace; reads ``train_config``, ``img_db``,
            ``txt_db``, ``compressed_db``, ``batch_size``, ``checkpoint``,
            ``model_config``, ``fp16``, ``n_workers`` and ``pin_mem``.
    """
    hvd.init()
    n_gpu = hvd.size()
    device = torch.device("cuda", hvd.local_rank())
    torch.cuda.set_device(hvd.local_rank())
    rank = hvd.rank()
    LOGGER.info("device: {} n_gpu: {}, rank: {}, "
                "16-bits training: {}".format(device, n_gpu, rank,
                                              opts.fp16))

    if opts.train_config is not None:
        # close the config file promptly instead of leaking the handle
        with open(opts.train_config) as config_file:
            train_opts = Struct(json.load(config_file))
        opts.conf_th = train_opts.conf_th
        opts.max_bb = train_opts.max_bb
        opts.min_bb = train_opts.min_bb
        opts.num_bb = train_opts.num_bb

    # load DBs and image dirs
    eval_img_db = DetectFeatLmdb(opts.img_db, opts.conf_th, opts.max_bb,
                                 opts.min_bb, opts.num_bb, opts.compressed_db)
    eval_txt_db = TxtTokLmdb(opts.txt_db, -1)
    eval_dataset = ItmEvalDataset(eval_txt_db, eval_img_db, opts.batch_size)

    # Prepare model
    checkpoint = torch.load(opts.checkpoint)
    model = UniterForImageTextRetrieval.from_pretrained(opts.model_config,
                                                        checkpoint,
                                                        img_dim=IMG_DIM)
    if 'rank_output' not in checkpoint:
        model.init_output()  # zero shot setting

    model.to(device)
    model = amp.initialize(model, enabled=opts.fp16, opt_level='O2')

    # batch_size=1 here; opts.batch_size was already given to the dataset
    eval_dataloader = DataLoader(eval_dataset,
                                 batch_size=1,
                                 num_workers=opts.n_workers,
                                 pin_memory=opts.pin_mem,
                                 collate_fn=itm_eval_collate)
    eval_dataloader = PrefetchLoader(eval_dataloader)

    # NOTE(review): the evaluation outputs are neither saved nor returned
    # in this function; confirm whether that is intended.
    eval_log, results = evaluate(model, eval_dataloader)
Esempio n. 5
0
def main(opts):
    """Evaluate an image-text retrieval checkpoint and save the results.

    Initialises Horovod, optionally overrides the image-feature settings
    (``conf_th``, ``max_bb``, ``min_bb``, ``num_bb``) on ``opts`` with those
    stored in ``opts.train_config``, runs ``evaluate`` and, on rank 0 only,
    writes ``config.json``, ``results.bin`` (pickle) and ``scores.json``
    into ``opts.output_dir`` and logs the recall scores.

    Args:
        opts: option namespace; reads ``train_config``, ``img_db``,
            ``txt_db``, ``compressed_db``, ``batch_size``, ``checkpoint``,
            ``model_config``, ``fp16``, ``n_workers``, ``pin_mem`` and
            ``output_dir``.
    """
    hvd.init()
    n_gpu = hvd.size()
    device = torch.device("cuda", hvd.local_rank())
    torch.cuda.set_device(hvd.local_rank())
    rank = hvd.rank()
    LOGGER.info("device: {} n_gpu: {}, rank: {}, "
                "16-bits training: {}".format(device, n_gpu, rank,
                                              opts.fp16))

    if opts.train_config is not None:
        # close the config file promptly instead of leaking the handle
        with open(opts.train_config) as config_file:
            train_opts = Struct(json.load(config_file))
        opts.conf_th = train_opts.conf_th
        opts.max_bb = train_opts.max_bb
        opts.min_bb = train_opts.min_bb
        opts.num_bb = train_opts.num_bb

    # load DBs and image dirs
    eval_img_db = DetectFeatLmdb(opts.img_db, opts.conf_th, opts.max_bb,
                                 opts.min_bb, opts.num_bb, opts.compressed_db)
    eval_txt_db = TxtTokLmdb(opts.txt_db, -1)
    eval_dataset = ItmEvalDataset(eval_txt_db, eval_img_db, opts.batch_size)

    # Prepare model
    checkpoint = torch.load(opts.checkpoint)
    model = UniterForImageTextRetrieval.from_pretrained(opts.model_config,
                                                        checkpoint,
                                                        img_dim=IMG_DIM)
    if 'rank_output' not in checkpoint:
        model.init_output()  # zero shot setting

    model.to(device)
    model = amp.initialize(model, enabled=opts.fp16, opt_level='O2')

    # batch_size=1 here; opts.batch_size was already given to the dataset
    eval_dataloader = DataLoader(eval_dataset,
                                 batch_size=1,
                                 num_workers=opts.n_workers,
                                 pin_memory=opts.pin_mem,
                                 collate_fn=itm_eval_collate)
    eval_dataloader = PrefetchLoader(eval_dataloader)

    eval_log, results = evaluate(model, eval_dataloader)
    if rank == 0:
        # only rank 0 writes; exist_ok avoids the check-then-create race
        os.makedirs(opts.output_dir, exist_ok=True)
        with open(f'{opts.output_dir}/config.json', 'w') as f:
            json.dump(vars(opts), f)
        with open(f'{opts.output_dir}/results.bin', 'wb') as f:
            pickle.dump(results, f)
        with open(f'{opts.output_dir}/scores.json', 'w') as f:
            json.dump(eval_log, f)
        LOGGER.info('evaluation finished')
        LOGGER.info(
            f"======================== Results =========================\n"
            f"image retrieval R1: {eval_log['img_r1']*100:.2f},\n"
            f"image retrieval R5: {eval_log['img_r5']*100:.2f},\n"
            f"image retrieval R10: {eval_log['img_r10']*100:.2f}\n"
            f"text retrieval R1: {eval_log['txt_r1']*100:.2f},\n"
            f"text retrieval R5: {eval_log['txt_r5']*100:.2f},\n"
            f"text retrieval R10: {eval_log['txt_r10']*100:.2f}")
        LOGGER.info("========================================================")
Esempio n. 6
0
def main(opts):
    """Run VQA inference with a trained checkpoint and save the answers.

    Reads the hyper-parameters and the answer vocabulary from
    ``opts.output_dir``, evaluates the checkpoint on ``opts.txt_db`` /
    ``opts.img_db``, gathers the per-process results and, on rank 0, writes
    ``results_{opts.checkpoint}_all.json`` (plus
    ``logits_{opts.checkpoint}_all.npz`` when ``opts.save_logits`` is set)
    into ``{opts.output_dir}/results_test``.

    Args:
        opts: option namespace; reads ``output_dir``, ``img_db``, ``txt_db``,
            ``compressed_db``, ``checkpoint``, ``fp16``, ``batch_size``,
            ``n_workers``, ``pin_mem`` and ``save_logits``.
            ``opts.checkpoint`` is used as a file path if it exists, otherwise
            as the step in ``{output_dir}/ckpt/model_step_<step>.pt``.
    """
    hvd.init()
    n_gpu = hvd.size()
    device = torch.device("cuda", hvd.local_rank())
    torch.cuda.set_device(hvd.local_rank())
    rank = hvd.rank()
    LOGGER.info("device: {} n_gpu: {}, rank: {}, "
                "16-bits training: {}".format(device, n_gpu, rank,
                                              opts.fp16))

    # context managers close both JSON files instead of leaking the handles
    hps_file = f"{opts.output_dir}/log/hps.json"
    with open(hps_file) as f:
        model_opts = Struct(json.load(f))

    # train_examples = None
    ans2label_file = f"{opts.output_dir}/ckpt/ans2label.json"
    with open(ans2label_file) as f:
        ans2label = json.load(f)
    label2ans = {label: ans for ans, label in ans2label.items()}

    # load DBs and image dirs
    eval_img_db = DetectFeatLmdb(
        opts.img_db,
        model_opts.conf_th,
        model_opts.max_bb,
        model_opts.min_bb,
        model_opts.num_bb,
        opts.compressed_db,
    )
    eval_txt_db = TxtTokLmdb(opts.txt_db, -1)
    eval_dataset = VqaEvalDataset(len(ans2label), eval_txt_db, eval_img_db)

    # Prepare model
    if exists(opts.checkpoint):
        ckpt_file = opts.checkpoint
    else:
        ckpt_file = f"{opts.output_dir}/ckpt/model_step_{opts.checkpoint}.pt"
    checkpoint = torch.load(ckpt_file)
    model = UniterForVisualQuestionAnswering.from_pretrained(
        f"{opts.output_dir}/log/model.json",
        checkpoint,
        img_dim=IMG_DIM,
        num_answer=len(ans2label),
    )
    model.to(device)
    if opts.fp16:
        model = amp.initialize(model, enabled=True, opt_level="O2")

    sampler = TokenBucketSampler(
        eval_dataset.lens,
        bucket_size=BUCKET_SIZE,
        batch_size=opts.batch_size,
        droplast=False,
    )
    eval_dataloader = DataLoader(
        eval_dataset,
        batch_sampler=sampler,
        num_workers=opts.n_workers,
        pin_memory=opts.pin_mem,
        collate_fn=vqa_eval_collate,
    )
    eval_dataloader = PrefetchLoader(eval_dataloader)

    # the first return value (validation log) is not used here
    _, results, logits = evaluate(model, eval_dataloader, label2ans,
                                  opts.save_logits)
    result_dir = f"{opts.output_dir}/results_test"
    if rank == 0:
        # exist_ok avoids the check-then-create race
        os.makedirs(result_dir, exist_ok=True)

    # the gathers run on every rank, not only on the rank that writes
    all_results = list(concat(all_gather_list(results)))
    if opts.save_logits:
        all_logits = {}
        for id2logit in all_gather_list(logits):
            all_logits.update(id2logit)
    if rank == 0:
        with open(f"{result_dir}/"
                  f"results_{opts.checkpoint}_all.json", "w") as f:
            json.dump(all_results, f)
        if opts.save_logits:
            np.savez(f"{result_dir}/logits_{opts.checkpoint}_all.npz",
                     **all_logits)
Esempio n. 7
0
File: itm.py Progetto: zmykevin/UC2
def main(opts):
    """Train an image-text retrieval ranking model and evaluate it.

    Steps performed, in order:
      1. initialise Horovod, pin the CUDA device and seed the RNGs;
      2. on rank 0 create the output directories, loggers and model saver
         (other ranks get ``NoOp`` stand-ins and a disabled logger);
      3. build the train / val / test datasets and dataloaders;
      4. build ``VLXLMRForImageTextRetrieval`` from ``opts.checkpoint``,
         the optimizer, and wrap both with ``amp.initialize``;
      5. run the training loop until ``opts.num_train_steps`` is reached,
         validating and saving every ``opts.valid_steps`` steps;
      6. save the final model and evaluate every test loader.

    Args:
        opts: option namespace holding all training settings read below.

    Raises:
        ValueError: if ``opts.gradient_accumulation_steps`` is < 1.
    """
    hvd.init()
    n_gpu = hvd.size()
    device = torch.device("cuda", hvd.local_rank())
    torch.cuda.set_device(hvd.local_rank())
    rank = hvd.rank()
    opts.rank = rank
    LOGGER.info("device: {} n_gpu: {}, rank: {}, "
                "16-bits training: {}".format(device, n_gpu, hvd.rank(),
                                              opts.fp16))

    if opts.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, "
                         "should be >= 1".format(
                             opts.gradient_accumulation_steps))

    set_random_seed(opts.seed)

    # only rank 0 owns the progress bar, checkpoints, logs and result dirs
    if hvd.rank() == 0:
        save_training_meta(opts)
        TB_LOGGER.create(join(opts.output_dir, 'log'))
        pbar = tqdm(total=opts.num_train_steps)
        model_saver = ModelSaver(join(opts.output_dir, 'ckpt'))
        os.makedirs(join(opts.output_dir, 'ckpt'), exist_ok=True)

        add_log_to_file(join(opts.output_dir, 'log', 'log.txt'))
        # store ITM predictions
        os.makedirs(join(opts.output_dir, 'results_val'), exist_ok=True)
        os.makedirs(join(opts.output_dir, 'results_test'), exist_ok=True)
        os.makedirs(join(opts.output_dir, 'results_train'), exist_ok=True)
    else:
        LOGGER.disabled = True
        pbar = NoOp()
        model_saver = NoOp()

    # train_examples = None
    LOGGER.info(f"Loading Train Dataset {opts.train_txt_dbs}, "
                f"{opts.train_img_dbs}")
    # check multiple DBs
    assert len(opts.train_txt_dbs) == len(opts.train_img_dbs), \
        "train txt_db and img_db have different length"

    # load DBs and image dirs
    all_img_dbs = ImageLmdbGroup(opts.conf_th, opts.max_bb, opts.min_bb,
                                 opts.num_bb, opts.compressed_db)
    # train
    LOGGER.info(f"Loading Train Dataset "
                f"{opts.train_txt_dbs}, {opts.train_img_dbs}")
    train_datasets = []
    for txt_path, img_path in zip(opts.train_txt_dbs, opts.train_img_dbs):
        if "itm_coco_zh" not in txt_path:
            img_db = all_img_dbs[img_path]
            txt_db = TxtTokLmdb(txt_path, opts.max_txt_len)
            if opts.hard_neg_size > 0:
                train_datasets.append(
                    ItmRankDatasetHardNeg(txt_db, img_db, opts.negative_size,
                                          opts.hard_neg_size))
            else:
                train_datasets.append(
                    ItmRankDataset(txt_db, img_db, opts.negative_size))
        else:
            # for "itm_coco_zh" text DBs img_path is indexed with [0] and [1]
            # — presumably a (train, val) pair of image DB paths; TODO confirm
            img_train_db = all_img_dbs[img_path[0]]
            img_val_db = all_img_dbs[img_path[1]]
            txt_db = TxtTokLmdb(txt_path, opts.max_txt_len)
            if opts.hard_neg_size > 0:
                # NOTE(review): img_db is not assigned in this branch; it is
                # either unbound (NameError) or left over from a previous
                # loop iteration. Looks like a bug — confirm intended DB.
                train_datasets.append(
                    ItmRankDatasetHardNeg(txt_db, img_db, opts.negative_size,
                                          opts.hard_neg_size))
            else:
                train_datasets.append(
                    ItmRankDataset_COCO_CN(txt_db, img_train_db, img_val_db,
                                           opts.negative_size))
    train_dataset = ConcatDataset(train_datasets)

    # hard negative
    # NOTE(review): the block below is disabled, but hn_dataloader and
    # hard_neg_dir are still referenced by the compute_hard_neg calls further
    # down whenever opts.steps_per_hard_neg != -1.
    # hn_datasets = []
    # for txt_path, img_path in zip(opts.train_txt_dbs, opts.train_img_dbs):
    #     img_db = all_img_dbs[img_path]
    #     txt_db = TxtTokLmdb(txt_path, opts.max_txt_len)
    #     hn_datasets.append(ItmHardNegDataset(txt_db, img_db,
    #                                          opts.inf_minibatch_size))
    # hn_dataset = ConcatDataset(hn_datasets)
    # hn_dataloader = build_dataloader(hn_dataset, itm_hn_collate, False, opts)
    # hard_neg_dir = f'{opts.output_dir}/results_train/'

    # val (only the first val text / image DB is used)
    LOGGER.info(f"Loading Val Dataset {opts.val_txt_db}, {opts.val_img_db}")
    val_img_db = all_img_dbs[opts.val_img_db[0]]
    val_txt_db = TxtTokLmdb(opts.val_txt_db[0], -1)
    val_dataset = ItmValDataset(val_txt_db, val_img_db,
                                opts.inf_minibatch_size)
    val_dataloader = build_dataloader(val_dataset, itm_val_collate, False,
                                      opts)
    # eval
    LOGGER.info(f"Loading val, test Dataset for full evaluation: "
                f"{opts.val_txt_db}, {opts.val_img_db}"
                f"{opts.test_txt_db}, {opts.test_img_db}")
    eval_dataset_val = ItmEvalDataset(val_txt_db, val_img_db,
                                      opts.inf_minibatch_size)
    eval_loader_val = build_dataloader(eval_dataset_val, itm_eval_collate,
                                       False, opts)

    # one evaluation loader per test DB pair
    eval_loader_list = []
    assert len(opts.test_img_db) == len(opts.test_txt_db)
    for test_img_db_path, test_txt_db_path in zip(opts.test_img_db,
                                                  opts.test_txt_db):
        if "itm_coco_zh" not in test_txt_db_path:
            test_img_db = all_img_dbs[test_img_db_path]
            test_txt_db = TxtTokLmdb(test_txt_db_path, -1)
            eval_dataset_test = ItmEvalDataset(test_txt_db, test_img_db,
                                               opts.inf_minibatch_size)
        else:
            # same two-element image path convention as the train branch
            test_img_train_db = all_img_dbs[test_img_db_path[0]]
            test_img_val_db = all_img_dbs[test_img_db_path[1]]
            test_txt_db = TxtTokLmdb(test_txt_db_path, -1)
            eval_dataset_test = ItmEvalDataset_COCO_CN(test_txt_db,
                                                       test_img_train_db,
                                                       test_img_val_db,
                                                       opts.inf_minibatch_size)
        eval_loader_test = build_dataloader(eval_dataset_test,
                                            itm_eval_collate, False, opts)
        eval_loader_list.append(eval_loader_test)

    # Prepare model
    if opts.checkpoint:
        checkpoint = torch.load(opts.checkpoint)
    else:
        checkpoint = {}

    # Rename the key if specified
    if opts.rename_checkpoints:
        rename_checkpoint(checkpoint)

    model = VLXLMRForImageTextRetrieval.from_pretrained(
        opts.model_config,
        state_dict=checkpoint,
        load_embedding_only=opts.load_embedding_only,
        load_layer=opts.load_layer,
        img_dim=IMG_DIM,
        margin=opts.margin)
    model.init_output()  # pretrain ITM head is different from ranking head
    model.to(device)
    # make sure every process has same model parameters in the beginning
    broadcast_tensors([p.data for p in model.parameters()], 0)
    set_dropout(model, opts.dropout)

    # Prepare optimizer
    if opts.separate_lr:
        optimizer = build_xlmr_optimizer(model, opts)
    else:
        optimizer = build_optimizer(model, opts)
    model, optimizer = amp.initialize(model,
                                      optimizer,
                                      enabled=opts.fp16,
                                      opt_level='O2')

    #global_step = 0
    LOGGER.info(f"***** Running training on {n_gpu} GPUs *****")
    LOGGER.info("  Num examples = %d", len(train_dataset) * hvd.size())
    LOGGER.info("  Batch size = %d", opts.train_batch_size)
    LOGGER.info("  Accumulate steps = %d", opts.gradient_accumulation_steps)
    LOGGER.info("  Num steps = %d", opts.num_train_steps)

    running_loss = RunningMeter('loss')
    model.train()

    if opts.steps_per_hard_neg != -1:
        # NOTE(review): hn_dataloader / hard_neg_dir are not assigned anywhere
        # in this function (their definitions are commented out above);
        # unless they exist at module level this raises NameError.
        compute_hard_neg(model, hn_dataloader, train_dataset,
                         opts.hard_neg_pool_size, hard_neg_dir)

    # Initialize the TrainingRestorer; global_step comes from it, so training
    # starts at whatever step the restorer reports
    restorer = TrainingRestorer(opts, model, optimizer)
    global_step = restorer.global_step
    TB_LOGGER._global_step = global_step
    if hvd.rank() != 0:
        restorer = NoOp()  #Added for Restoring the Checkpoints

    if global_step > 0:
        pbar.update(global_step)

    n_examples = 0
    start = time()
    # quick hack for amp delay_unscale bug
    optimizer.zero_grad()
    optimizer.step()
    while True:
        # the loader is rebuilt every time the inner loop is left
        train_dataloader = build_dataloader(train_dataset,
                                            xlmr_itm_rank_collate, True, opts)
        for step, batch in enumerate(train_dataloader):
            #print(batch['input_ids'])
            n_examples += batch['input_ids'].size(0)
            loss = model(batch, compute_loss=True)
            loss = loss.mean()
            # unscale only on the last micro-batch of an accumulation window
            delay_unscale = (step + 1) % opts.gradient_accumulation_steps != 0
            with amp.scale_loss(loss, optimizer,
                                delay_unscale=delay_unscale) as scaled_loss:
                scaled_loss.backward()
                if not delay_unscale:
                    # gather gradients from every processes
                    # do this before unscaling to make sure every process uses
                    # the same gradient scale
                    grads = [
                        p.grad.data for p in model.parameters()
                        if p.requires_grad and p.grad is not None
                    ]
                    all_reduce_and_rescale_tensors(grads, float(1))

            running_loss(loss.item())
            # print("run the loss")
            if (step + 1) % opts.gradient_accumulation_steps == 0:
                global_step += 1

                # learning rate scheduling
                lr_this_step = get_lr_sched(global_step, opts)
                if opts.separate_lr:
                    #added by Mingyang
                    # the first two param groups get the XLM-R schedule,
                    # all remaining groups the default schedule
                    xlmr_lr_this_step = get_xlmr_lr_sched(global_step, opts)
                    for i, param_group in enumerate(optimizer.param_groups):
                        if i < 2:
                            param_group['lr'] = xlmr_lr_this_step
                        else:
                            param_group['lr'] = lr_this_step
                    TB_LOGGER.add_scalar('xlmr_lr', xlmr_lr_this_step,
                                         global_step)
                else:
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = lr_this_step

                TB_LOGGER.add_scalar('lr', lr_this_step, global_step)

                # log loss (averaged over the meters gathered from all ranks)
                losses = all_gather_list(running_loss)
                running_loss = RunningMeter(
                    'loss',
                    sum(l.val for l in losses) / len(losses))
                TB_LOGGER.add_scalar('loss', running_loss.val, global_step)
                TB_LOGGER.step()

                # update model params
                if opts.grad_norm != -1:
                    grad_norm = clip_grad_norm_(amp.master_params(optimizer),
                                                opts.grad_norm)
                    TB_LOGGER.add_scalar('grad_norm', grad_norm, global_step)
                optimizer.step()
                optimizer.zero_grad()
                pbar.update(1)

                if global_step % 100 == 0:
                    # monitor training throughput
                    LOGGER.info(f'============Step {global_step}=============')
                    tot_ex = sum(all_gather_list(n_examples))
                    ex_per_sec = int(tot_ex / (time() - start))
                    LOGGER.info(f'{tot_ex} examples trained at '
                                f'{ex_per_sec} ex/s')
                    TB_LOGGER.add_scalar('perf/ex_per_s', ex_per_sec,
                                         global_step)
                    LOGGER.info(f'===========================================')

                if global_step % opts.valid_steps == 0 and global_step > 0:
                    # if global_step > 7000:
                    # full_val runs evaluate() on the val eval loader,
                    # otherwise validate() on the val loader
                    if opts.full_val:
                        val_log = evaluate(model, eval_loader_val)
                        TB_LOGGER.log_scaler_dict(
                            {f"valid/{k}": v
                             for k, v in val_log.items()})
                        #Log the information
                        # LOGGER.info(
                        #         f"========================= {split} ===========================\n"
                        #         f"image retrieval R1: {eval_log['img_r1']*100:.2f},\n"
                        #         f"image retrieval R5: {eval_log['img_r5']*100:.2f},\n"
                        #         f"image retrieval R10: {eval_log['img_r10']*100:.2f}\n"
                        #         f"text retrieval R1: {eval_log['txt_r1']*100:.2f},\n"
                        #         f"text retrieval R5: {eval_log['txt_r5']*100:.2f},\n"
                        #         f"text retrieval R10: {eval_log['txt_r10']*100:.2f}")
                        # LOGGER.info("=========================================================")
                    else:
                        val_log = validate(model, val_dataloader)
                        TB_LOGGER.log_scaler_dict(val_log)

                    model_saver.save(model, global_step)
                restorer.step()
                if (opts.steps_per_hard_neg != -1
                        and global_step % opts.steps_per_hard_neg == 0):
                    # sample hard negatives for training
                    # NOTE(review): same undefined hn_dataloader /
                    # hard_neg_dir concern as before the training loop.
                    compute_hard_neg(model, hn_dataloader, train_dataset,
                                     opts.hard_neg_pool_size, hard_neg_dir)
                    # break to reconstruct loader
                    # for potential multi-worker issue (not sure)
                    break

            if global_step >= opts.num_train_steps:
                break

        if global_step >= opts.num_train_steps:
            break
        # NOTE can no longer count epochs

    pbar.close()
    # final validation
    # val_log = validate(model, val_dataloader)
    # TB_LOGGER.log_scaler_dict(val_log)
    model_saver.save(model, f'{global_step}_final')

    # evaluate() is called on every rank; only rank 0 logs the summary
    for i, loader in enumerate(eval_loader_list):
        split = "test_{}".format(i)
        eval_log = evaluate(model, loader)
        TB_LOGGER.log_scaler_dict(
            {f"eval/{split}_{k}": v
             for k, v in eval_log.items()})
        if hvd.rank() != 0:
            continue
        LOGGER.info(
            f"========================= {split} ===========================\n"
            f"image retrieval R1: {eval_log['img_r1']*100:.2f},\n"
            f"image retrieval R5: {eval_log['img_r5']*100:.2f},\n"
            f"image retrieval R10: {eval_log['img_r10']*100:.2f}\n"
            f"text retrieval R1: {eval_log['txt_r1']*100:.2f},\n"
            f"text retrieval R5: {eval_log['txt_r5']*100:.2f},\n"
            f"text retrieval R10: {eval_log['txt_r10']*100:.2f}")
    # closing separator, emitted once after all test splits
    LOGGER.info("=========================================================")
Esempio n. 8
0
def main(opts):
    """Fine-tune ``UniterForImageTextRetrieval`` and report retrieval recalls.

    The function initialises Horovod, builds the ranking train set, the
    validation loader and the full val/test evaluation loaders from ``opts``,
    then trains for ``opts.num_train_steps`` optimizer steps using apex
    ``amp`` with gradient accumulation.  Every ``opts.valid_steps`` steps it
    validates (full evaluation when ``opts.full_val`` is set) and saves a
    checkpoint; after training it logs the recalls on the val and test
    splits.

    Args:
        opts: option namespace.  ``opts.rank`` is written by this function;
            the remaining fields (DB paths, batch sizes, schedule settings,
            ``output_dir``, ``checkpoint``, ...) are only read.

    Raises:
        ValueError: if ``opts.gradient_accumulation_steps`` is smaller than 1.
        AssertionError: if ``opts.train_txt_dbs`` and ``opts.train_img_dbs``
            have different lengths.
    """
    hvd.init()
    n_gpu = hvd.size()
    device = torch.device("cuda", hvd.local_rank())
    torch.cuda.set_device(hvd.local_rank())
    rank = hvd.rank()
    opts.rank = rank
    LOGGER.info("device: {} n_gpu: {}, rank: {}, "
                "16-bits training: {}".format(device, n_gpu, hvd.rank(),
                                              opts.fp16))

    if opts.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, "
                         "should be >= 1".format(
                             opts.gradient_accumulation_steps))

    set_random_seed(opts.seed)

    if hvd.rank() == 0:
        # rank 0 owns the metadata, TensorBoard log, progress bar,
        # checkpoints and the result folders
        save_training_meta(opts)
        TB_LOGGER.create(join(opts.output_dir, 'log'))
        pbar = tqdm(total=opts.num_train_steps)
        model_saver = ModelSaver(join(opts.output_dir, 'ckpt'))
        add_log_to_file(join(opts.output_dir, 'log', 'log.txt'))
        # store ITM predictions
        os.makedirs(join(opts.output_dir, 'results_val'))
        os.makedirs(join(opts.output_dir, 'results_test'))
        os.makedirs(join(opts.output_dir, 'results_train'))
    else:
        # every other rank stays silent and uses no-op stand-ins
        LOGGER.disabled = True
        pbar = NoOp()
        model_saver = NoOp()

    # check multiple DBs
    assert len(opts.train_txt_dbs) == len(opts.train_img_dbs), \
        "train txt_db and img_db have different length"

    # load DBs and image dirs
    all_img_dbs = ImageLmdbGroup(opts.conf_th, opts.max_bb, opts.min_bb,
                                 opts.num_bb, opts.compressed_db)
    # train
    LOGGER.info(f"Loading Train Dataset "
                f"{opts.train_txt_dbs}, {opts.train_img_dbs}")
    train_datasets = []
    for txt_path, img_path in zip(opts.train_txt_dbs, opts.train_img_dbs):
        img_db = all_img_dbs[img_path]
        txt_db = TxtTokLmdb(txt_path, opts.max_txt_len)
        train_datasets.append(
            ItmRankDataset(txt_db, img_db, opts.negative_size))
    train_dataset = ConcatDataset(train_datasets)

    # val
    LOGGER.info(f"Loading Val Dataset {opts.val_txt_db}, {opts.val_img_db}")
    val_img_db = all_img_dbs[opts.val_img_db]
    val_txt_db = TxtTokLmdb(opts.val_txt_db, -1)
    val_dataset = ItmValDataset(val_txt_db, val_img_db,
                                opts.inf_minibatch_size)
    val_dataloader = build_dataloader(val_dataset, itm_val_collate, False,
                                      opts)
    # eval
    # (the val and test paths are separated explicitly so the log line
    # stays readable)
    LOGGER.info(f"Loading val, test Dataset for full evaluation: "
                f"{opts.val_txt_db}, {opts.val_img_db}, "
                f"{opts.test_txt_db}, {opts.test_img_db}")
    eval_dataset_val = ItmEvalDataset(val_txt_db, val_img_db,
                                      opts.inf_minibatch_size)
    eval_loader_val = build_dataloader(eval_dataset_val, itm_eval_collate,
                                       False, opts)
    test_img_db = all_img_dbs[opts.test_img_db]
    test_txt_db = TxtTokLmdb(opts.test_txt_db, -1)
    eval_dataset_test = ItmEvalDataset(test_txt_db, test_img_db,
                                       opts.inf_minibatch_size)
    eval_loader_test = build_dataloader(eval_dataset_test, itm_eval_collate,
                                        False, opts)

    # Prepare model
    if opts.checkpoint:
        # map onto this process' own GPU; without map_location every rank
        # restores the tensors onto the device they were saved from
        checkpoint = torch.load(opts.checkpoint, map_location=device)
    else:
        checkpoint = {}

    model = UniterForImageTextRetrieval.from_pretrained(opts.model_config,
                                                        state_dict=checkpoint,
                                                        img_dim=IMG_DIM,
                                                        margin=opts.margin)
    model.init_output()  # pretrain ITM head is different from ranking head
    model.to(device)
    # make sure every process has same model parameters in the beginning
    broadcast_tensors([p.data for p in model.parameters()], 0)
    set_dropout(model, opts.dropout)

    # Prepare optimizer
    optimizer = build_optimizer(model, opts)
    model, optimizer = amp.initialize(model,
                                      optimizer,
                                      enabled=opts.fp16,
                                      opt_level='O2')

    global_step = 0
    LOGGER.info(f"***** Running training on {n_gpu} GPUs *****")
    LOGGER.info("  Num examples = %d", len(train_dataset) * hvd.size())
    LOGGER.info("  Batch size = %d", opts.train_batch_size)
    LOGGER.info("  Accumulate steps = %d", opts.gradient_accumulation_steps)
    LOGGER.info("  Num steps = %d", opts.num_train_steps)

    running_loss = RunningMeter('loss')
    model.train()

    n_examples = 0
    n_epoch = 0
    start = time()
    # quick hack for amp delay_unscale bug
    optimizer.zero_grad()
    optimizer.step()
    while True:
        # the train loader is rebuilt at the start of every epoch
        train_dataloader = build_dataloader(train_dataset, itm_rank_collate,
                                            True, opts)
        for step, batch in enumerate(train_dataloader):
            n_examples += batch['input_ids'].size(0)
            loss = model(batch, compute_loss=True)
            loss = loss.mean()
            # unscaling is delayed on every micro-batch except the one that
            # closes an accumulation window
            delay_unscale = (step + 1) % opts.gradient_accumulation_steps != 0
            with amp.scale_loss(loss, optimizer,
                                delay_unscale=delay_unscale) as scaled_loss:
                scaled_loss.backward()
                if not delay_unscale:
                    # gather gradients from every processes
                    # do this before unscaling to make sure every process uses
                    # the same gradient scale
                    grads = [
                        p.grad.data for p in model.parameters()
                        if p.requires_grad and p.grad is not None
                    ]
                    all_reduce_and_rescale_tensors(grads, float(1))

            running_loss(loss.item())
            if (step + 1) % opts.gradient_accumulation_steps == 0:
                global_step += 1

                # learning rate scheduling
                lr_this_step = get_lr_sched(global_step, opts)
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr_this_step
                TB_LOGGER.add_scalar('lr', lr_this_step, global_step)

                # log loss
                # NOTE: not gathered across GPUs for efficiency
                TB_LOGGER.add_scalar('loss', running_loss.val, global_step)
                TB_LOGGER.step()

                # update model params
                if opts.grad_norm != -1:
                    grad_norm = clip_grad_norm_(amp.master_params(optimizer),
                                                opts.grad_norm)
                    TB_LOGGER.add_scalar('grad_norm', grad_norm, global_step)
                optimizer.step()
                optimizer.zero_grad()
                pbar.update(1)

                if global_step % 100 == 0:
                    # monitor training throughput
                    LOGGER.info(f'------------Step {global_step}-------------')
                    tot_ex = sum(all_gather_list(n_examples))
                    ex_per_sec = int(tot_ex / (time() - start))
                    LOGGER.info(f'{tot_ex} examples trained at '
                                f'{ex_per_sec} ex/s')
                    TB_LOGGER.add_scalar('perf/ex_per_s', ex_per_sec,
                                         global_step)
                    LOGGER.info('-------------------------------------------')

                if global_step % opts.valid_steps == 0:
                    if opts.full_val:
                        LOGGER.info(
                            f"========================== Step {global_step} "
                            f"==========================")
                        val_log = evaluate(model, eval_loader_val)
                        TB_LOGGER.log_scaler_dict(
                            {f"valid/{k}": v
                             for k, v in val_log.items()})
                        LOGGER.info(f"image retrieval R1: "
                                    f"{val_log['img_r1']*100:.2f},\n"
                                    f"image retrieval R5: "
                                    f"{val_log['img_r5']*100:.2f},\n"
                                    f"image retrieval R10: "
                                    f"{val_log['img_r10']*100:.2f}\n"
                                    f"text retrieval R1: "
                                    f"{val_log['txt_r1']*100:.2f},\n"
                                    f"text retrieval R5: "
                                    f"{val_log['txt_r5']*100:.2f},\n"
                                    f"text retrieval R10: "
                                    f"{val_log['txt_r10']*100:.2f}")
                        LOGGER.info("================================="
                                    "=================================")
                    else:
                        val_log = validate(model, val_dataloader)
                        TB_LOGGER.log_scaler_dict(val_log)
                    model_saver.save(model, global_step)

            if global_step >= opts.num_train_steps:
                break

        if global_step >= opts.num_train_steps:
            break
        n_epoch += 1
        LOGGER.info(f"finished {n_epoch} epochs")

    pbar.close()
    if opts.num_train_steps % opts.valid_steps != 0:
        # final validation (skipped when the last training step already
        # triggered the periodic validation above)
        val_log = validate(model, val_dataloader)
        TB_LOGGER.log_scaler_dict(val_log)
        model_saver.save(model, global_step)

    # evaluation
    for split, loader in [('val', eval_loader_val),
                          ('test', eval_loader_test)]:
        eval_log = evaluate(model, loader)
        TB_LOGGER.log_scaler_dict(
            {f"eval/{split}_{k}": v
             for k, v in eval_log.items()})
        if hvd.rank() != 0:
            continue
        LOGGER.info(
            f"========================= {split} ===========================\n"
            f"image retrieval R1: {eval_log['img_r1']*100:.2f},\n"
            f"image retrieval R5: {eval_log['img_r5']*100:.2f},\n"
            f"image retrieval R10: {eval_log['img_r10']*100:.2f}\n"
            f"text retrieval R1: {eval_log['txt_r1']*100:.2f},\n"
            f"text retrieval R5: {eval_log['txt_r5']*100:.2f},\n"
            f"text retrieval R10: {eval_log['txt_r10']*100:.2f}")
    LOGGER.info("=========================================================")
Esempio n. 9
0
def main(opts):
    """Fine-tune ``UniterForVisualQuestionAnswering`` with Horovod.

    Loads the answer vocabulary, builds the train and validation loaders
    from ``opts``, trains for ``opts.num_train_steps`` optimizer steps using
    apex ``amp`` with gradient accumulation, and every ``opts.valid_steps``
    steps validates, dumps the per-rank predictions to
    ``{opts.output_dir}/results/`` and saves a checkpoint.

    Args:
        opts: option namespace.  ``opts.rank`` is written by this function;
            the remaining fields (DB paths, batch sizes, schedule settings,
            ``output_dir``, ``checkpoint``, optional ``ans2label_path``, ...)
            are only read.

    Raises:
        ValueError: if ``opts.gradient_accumulation_steps`` is smaller than
            1, or if the optimizer exposes more than four parameter groups.
        AssertionError: if the text DBs were not built with the same
            ``'bert'`` entry in their ``meta.json``.
    """
    hvd.init()
    n_gpu = hvd.size()
    device = torch.device("cuda", hvd.local_rank())
    torch.cuda.set_device(hvd.local_rank())
    rank = hvd.rank()
    opts.rank = rank
    LOGGER.info("device: {} n_gpu: {}, rank: {}, "
                "16-bits training: {}".format(device, n_gpu, hvd.rank(),
                                              opts.fp16))

    if opts.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, "
                         "should be >= 1".format(
                             opts.gradient_accumulation_steps))

    set_random_seed(opts.seed)

    # answer vocabulary: explicit path if given, else the bundled default;
    # the files are opened through context managers so they get closed
    if hasattr(opts, 'ans2label_path'):
        with open(opts.ans2label_path, 'r', encoding='utf-8') as ans_file:
            ans2label = json.load(ans_file)
    else:
        with open(f'{dirname(abspath(__file__))}'
                  f'/utils/ans2label.json') as ans_file:
            ans2label = json.load(ans_file)
    label2ans = {label: ans for ans, label in ans2label.items()}

    # load DBs and image dirs
    all_img_dbs = ImageLmdbGroup(opts.conf_th, opts.max_bb, opts.min_bb,
                                 opts.num_bb, opts.compressed_db)
    # train
    LOGGER.info(f"Loading Train Dataset "
                f"{opts.train_txt_dbs}, {opts.train_img_dbs}")
    train_datasets = []
    for txt_path, img_path in zip(opts.train_txt_dbs, opts.train_img_dbs):
        img_db = all_img_dbs[img_path]
        txt_db = TxtTokLmdb(txt_path, opts.max_txt_len)
        train_datasets.append(VqaDataset(len(ans2label), txt_db, img_db))
    train_dataset = ConcatDatasetWithLens(train_datasets)
    train_dataloader = build_dataloader(train_dataset, vqa_collate, True, opts)
    # val
    LOGGER.info(f"Loading Val Dataset {opts.val_txt_db}, {opts.val_img_db}")
    val_img_db = all_img_dbs[opts.val_img_db]
    val_txt_db = TxtTokLmdb(opts.val_txt_db, -1)
    val_dataset = VqaEvalDataset(len(ans2label), val_txt_db, val_img_db)
    val_dataloader = build_dataloader(val_dataset, vqa_eval_collate, False,
                                      opts)

    # Prepare model
    if opts.checkpoint:
        checkpoint = torch.load(opts.checkpoint, map_location=device)
    else:
        checkpoint = {}

    # all text DBs must agree on the 'bert' entry of their meta.json
    all_dbs = opts.train_txt_dbs + [opts.val_txt_db]
    tokers = []
    for db in all_dbs:
        with open(f'{db}/meta.json') as meta_file:
            tokers.append(json.load(meta_file)['bert'])
    toker = tokers[0]
    assert all(toker == other for other in tokers)
    model = UniterForVisualQuestionAnswering.from_pretrained(
        opts.model_config,
        checkpoint,
        img_dim=IMG_DIM,
        num_answer=len(ans2label))
    model.to(device)
    # make sure every process has same model parameters in the beginning
    broadcast_tensors([p.data for p in model.parameters()], 0)
    set_dropout(model, opts.dropout)

    # Prepare optimizer
    optimizer = build_optimizer(model, opts)
    model, optimizer = amp.initialize(model,
                                      optimizer,
                                      enabled=opts.fp16,
                                      opt_level='O2')
    global_step = 0
    if rank == 0:
        # rank 0 owns the metadata, TensorBoard log, progress bar,
        # checkpoints and the results folder
        save_training_meta(opts)
        TB_LOGGER.create(join(opts.output_dir, 'log'))
        pbar = tqdm(total=opts.num_train_steps)
        model_saver = ModelSaver(join(opts.output_dir, 'ckpt'))
        with open(join(opts.output_dir, 'ckpt', 'ans2label.json'),
                  'w') as ans_file:
            json.dump(ans2label, ans_file)
        os.makedirs(join(opts.output_dir, 'results'))  # store VQA predictions
        add_log_to_file(join(opts.output_dir, 'log', 'log.txt'))
    else:
        # every other rank stays silent and uses no-op stand-ins
        LOGGER.disabled = True
        pbar = NoOp()
        model_saver = NoOp()

    LOGGER.info(f"***** Running training with {n_gpu} GPUs *****")
    LOGGER.info("  Num examples = %d", len(train_dataset) * hvd.size())
    LOGGER.info("  Batch size = %d", opts.train_batch_size)
    LOGGER.info("  Accumulate steps = %d", opts.gradient_accumulation_steps)
    LOGGER.info("  Num steps = %d", opts.num_train_steps)

    running_loss = RunningMeter('loss')
    model.train()
    n_examples = 0
    n_epoch = 0
    start = time()
    # quick hack for amp delay_unscale bug
    optimizer.zero_grad()
    optimizer.step()
    while True:
        for step, batch in enumerate(train_dataloader):
            n_examples += batch['input_ids'].size(0)

            loss = model(batch, compute_loss=True)
            loss = loss.mean() * batch['targets'].size(1)  # instance-level bce
            # unscaling is delayed on every micro-batch except the one that
            # closes an accumulation window
            delay_unscale = (step + 1) % opts.gradient_accumulation_steps != 0
            with amp.scale_loss(loss, optimizer,
                                delay_unscale=delay_unscale) as scaled_loss:
                scaled_loss.backward()
                if not delay_unscale:
                    # gather gradients from every processes
                    # do this before unscaling to make sure every process uses
                    # the same gradient scale
                    grads = [
                        p.grad.data for p in model.parameters()
                        if p.requires_grad and p.grad is not None
                    ]
                    all_reduce_and_rescale_tensors(grads, float(1))

            running_loss(loss.item())

            if (step + 1) % opts.gradient_accumulation_steps == 0:
                global_step += 1

                # learning rate scheduling: groups 0-1 use the scheduled
                # rate times opts.lr_mul, groups 2-3 the plain rate
                lr_this_step = get_lr_sched(global_step, opts)
                for i, param_group in enumerate(optimizer.param_groups):
                    if i == 0 or i == 1:
                        param_group['lr'] = lr_this_step * opts.lr_mul
                    elif i == 2 or i == 3:
                        param_group['lr'] = lr_this_step
                    else:
                        raise ValueError(
                            f'unexpected optimizer param group index {i}')
                TB_LOGGER.add_scalar('lr', lr_this_step, global_step)

                # log loss
                # NOTE: not gathered across GPUs for efficiency
                TB_LOGGER.add_scalar('loss', running_loss.val, global_step)
                TB_LOGGER.step()

                # update model params
                if opts.grad_norm != -1:
                    grad_norm = clip_grad_norm_(amp.master_params(optimizer),
                                                opts.grad_norm)
                    TB_LOGGER.add_scalar('grad_norm', grad_norm, global_step)
                optimizer.step()
                optimizer.zero_grad()
                pbar.update(1)

                if global_step % 100 == 0:
                    # monitor training throughput
                    LOGGER.info(f'============Step {global_step}=============')
                    tot_ex = sum(all_gather_list(n_examples))
                    ex_per_sec = int(tot_ex / (time() - start))
                    LOGGER.info(f'{tot_ex} examples trained at '
                                f'{ex_per_sec} ex/s')
                    TB_LOGGER.add_scalar('perf/ex_per_s', ex_per_sec,
                                         global_step)
                    LOGGER.info('===========================================')

                if global_step % opts.valid_steps == 0:
                    val_log, results = validate(model, val_dataloader,
                                                label2ans)
                    # every rank dumps its own share of the predictions
                    with open(
                            f'{opts.output_dir}/results/'
                            f'results_{global_step}_'
                            f'rank{rank}.json', 'w') as f:
                        json.dump(results, f)
                    TB_LOGGER.log_scaler_dict(val_log)
                    model_saver.save(model, global_step)
            if global_step >= opts.num_train_steps:
                break
        if global_step >= opts.num_train_steps:
            break
        n_epoch += 1
        LOGGER.info(f"finished {n_epoch} epochs")
    if opts.num_train_steps % opts.valid_steps != 0:
        # final validation (skipped when the last training step already
        # triggered the periodic validation above)
        val_log, results = validate(model, val_dataloader, label2ans)
        with open(
                f'{opts.output_dir}/results/'
                f'results_{global_step}_'
                f'rank{rank}.json', 'w') as f:
            json.dump(results, f)
        TB_LOGGER.log_scaler_dict(val_log)
        model_saver.save(model, global_step)
def main(opts, checkpoint_dir=None, tuning=False):
    """Run validation and test inference with ``UniterForITM`` and save results.

    Builds the validation and test loaders from ``opts``, restores the model
    from ``opts.checkpoint`` (and ``opts.tune_checkpoint`` when present),
    then writes the validation predictions (JSON and CSV, per rank) and the
    test predictions (JSON) under ``{opts.output_dir}/results/`` plus a
    re-ordered ``{opts.output_dir}/test.csv``.  When a checkpoint was given
    it is finally copied to ``{opts.output_dir}/final.pt``.

    Args:
        opts: option namespace or plain ``dict`` (a dict is wrapped with
            ``edict``).  ``opts.rank`` is written by this function.
        checkpoint_dir: accepted for interface compatibility; not used here.
        tuning: accepted for interface compatibility; not used here.

    Raises:
        ValueError: if ``opts.gradient_accumulation_steps`` is smaller than 1.
    """
    from utils.logger import LOGGER, TB_LOGGER, RunningMeter, add_log_to_file
    with logger.catch(reraise=True):
        logger.info(f"{opts}")
        if isinstance(opts, dict):
            opts = edict(opts)

        hvd.init()
        n_gpu = hvd.size()
        device = torch.device("cuda", hvd.local_rank())
        torch.cuda.set_device(hvd.local_rank())
        rank = hvd.rank()
        opts.rank = rank
        LOGGER.info("device: {} n_gpu: {}, rank: {}, "
                    "16-bits training: {}".format(device, n_gpu, hvd.rank(),
                                                  opts.fp16))

        if opts.gradient_accumulation_steps < 1:
            raise ValueError(
                "Invalid gradient_accumulation_steps parameter: {}, "
                "should be >= 1".format(opts.gradient_accumulation_steps))

        set_random_seed(opts.seed)

        # load DBs and image dirs
        all_img_dbs = ImageLmdbGroup(opts.conf_th, opts.max_bb, opts.min_bb,
                                     opts.num_bb, opts.compressed_db)

        # val
        LOGGER.info(
            f"Loading Val Dataset {opts.val_txt_db}, {opts.val_img_db}")
        val_img_db = all_img_dbs[opts.val_img_db]
        val_txt_db = TxtTokLmdb(opts.val_txt_db, -1)
        val_dataset = MemeEvalDataset(1, val_txt_db, val_img_db)
        val_dataloader = build_dataloader(val_dataset, meme_eval_collate,
                                          False, opts)
        # NOTE(review): this loader is built but never consumed below
        val_itm_dataloader = build_dataloader(val_dataset,
                                              meme_eval_itm_ot_collate, False,
                                              opts)

        # test: shares the image DB with the validation split
        test_img_db = val_img_db
        test_txt_db = TxtTokLmdb(opts.test_txt_db, -1)
        test_dataset = MemeEvalDataset(1, test_txt_db, test_img_db)
        test_dataloader = build_dataloader(test_dataset, meme_eval_collate,
                                           False, opts)

        # Prepare model
        if opts.checkpoint:
            logger.info(f"Load checkpoint: {opts.checkpoint}")
            checkpoint = torch.load(opts.checkpoint)
        else:
            checkpoint = {}

        model = UniterForITM.from_pretrained(opts.model_config,
                                             checkpoint,
                                             img_dim=IMG_DIM,
                                             num_answer=1)
        model.to(device)

        if hasattr(opts, 'tune_checkpoint') and isinstance(
                model, UniterForITM):
            # the tuning checkpoint stores the model state as its first item
            model_state = torch.load(opts.tune_checkpoint)[0]
            model.load_state_dict(model_state)
        # make sure every process has same model parameters in the beginning
        broadcast_tensors([p.data for p in model.parameters()], 0)
        set_dropout(model, opts.dropout)

        # Prepare optimizer (no training step is taken here; the optimizer
        # is only handed to amp.initialize together with the model)
        optimizer = build_optimizer(model, opts)
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          enabled=opts.fp16,
                                          opt_level='O2')
        global_step = 0

        LOGGER.info(f"***** Running training with {n_gpu} GPUs *****")
        LOGGER.info("  Num examples = %d", len(val_dataset) * hvd.size())
        LOGGER.info("  Batch size = %d", opts.train_batch_size)
        LOGGER.info("  Accumulate steps = %d",
                    opts.gradient_accumulation_steps)
        LOGGER.info("  Num steps = %d", opts.num_train_steps)

        model.eval()

        # create the results folder BEFORE the first file is written into it
        results_dir = f'{opts.output_dir}/results/'
        os.makedirs(results_dir, exist_ok=True)

        val_log, results = validate(model, val_dataloader, None)

        with open(
                f'{opts.output_dir}/results/'
                f'results_{global_step}_'
                f'rank{rank}.json', 'w') as f:
            json.dump(results, f)
        pd.DataFrame.from_dict(results).to_csv(
            f'{opts.output_dir}/results/'
            f'results_{global_step}_'
            f'rank{rank}.csv',
            index=False)

        test_log, results = test(model, test_dataloader, None)

        with open(
                f'{opts.output_dir}/results/'
                f'results_{global_step}_'
                f'test.json', 'w') as f:
            json.dump(results, f)

        test_csv = pd.DataFrame.from_dict(results)[['id', 'proba', 'label']]
        test_csv = reorder_csv_rows(
            os.path.join(HERE, 'asset', 'test_unseen.jsonl'),
            test_csv,
        )
        output_path = f'{opts.output_dir}/test.csv'
        test_csv.to_csv(output_path, index=False)
        print('Save test predict to: ', output_path)
        if opts.checkpoint:
            try:
                shutil.copy(opts.checkpoint,
                            os.path.join(opts.output_dir, 'final.pt'))
            except shutil.SameFileError:
                # opts.checkpoint already is <output_dir>/final.pt
                logger.info(
                    'Rerun of the same chekcpoint, not re-copy it as final.pt')
Esempio n. 11
0
def main(opts):
    """Fine-tune ``UniterForImageTextRetrievalHardNeg`` and report recalls.

    Builds two ranking train sets from the same DBs (hard negatives sampled
    from text and from image), the validation loader and the full val/test
    evaluation loaders, then trains for ``opts.num_train_steps`` optimizer
    steps.  Each loop iteration back-propagates one image-side and one
    text-side batch; gradients are accumulated over
    ``opts.train_batch_size`` iterations before every optimizer step.  Every
    ``opts.valid_steps`` steps the model is validated (full evaluation when
    ``opts.full_val`` is set) and checkpointed; after training a final
    validation, a ``"{global_step}_final"`` checkpoint and the val/test
    recalls are produced.

    Args:
        opts: option namespace.  ``opts.rank`` is written by this function;
            the remaining fields (DB paths, batch sizes, schedule settings,
            ``output_dir``, ``checkpoint``, ...) are only read.

    Raises:
        AssertionError: if ``opts.train_txt_dbs`` and ``opts.train_img_dbs``
            have different lengths.
    """
    hvd.init()
    n_gpu = hvd.size()
    device = torch.device("cuda", hvd.local_rank())
    torch.cuda.set_device(hvd.local_rank())
    rank = hvd.rank()
    opts.rank = rank
    LOGGER.info("device: {} n_gpu: {}, rank: {}, "
                "16-bits training: {}".format(device, n_gpu, hvd.rank(),
                                              opts.fp16))

    set_random_seed(opts.seed)

    if hvd.rank() == 0:
        # rank 0 owns the metadata, TensorBoard log, progress bar,
        # checkpoints and the result folders
        save_training_meta(opts)
        TB_LOGGER.create(join(opts.output_dir, "log"))
        pbar = tqdm(total=opts.num_train_steps)
        model_saver = ModelSaver(join(opts.output_dir, "ckpt"))
        add_log_to_file(join(opts.output_dir, "log", "log.txt"))
        # store ITM predictions
        os.makedirs(join(opts.output_dir, "results_val"))
        os.makedirs(join(opts.output_dir, "results_test"))
        os.makedirs(join(opts.output_dir, "results_train"))
    else:
        # every other rank stays silent and uses no-op stand-ins
        LOGGER.disabled = True
        pbar = NoOp()
        model_saver = NoOp()

    # check multiple DBs
    assert len(opts.train_txt_dbs) == len(
        opts.train_img_dbs), "train txt_db and img_db have different length"

    # load DBs and image dirs
    all_img_dbs = ImageLmdbGroup(opts.conf_th, opts.max_bb, opts.min_bb,
                                 opts.num_bb, opts.compressed_db)
    # train
    LOGGER.info(f"Loading Train Dataset "
                f"{opts.train_txt_dbs}, {opts.train_img_dbs}")
    train_datasets_t = []
    train_datasets_i = []
    for txt_path, img_path in zip(opts.train_txt_dbs, opts.train_img_dbs):
        img_db = all_img_dbs[img_path]
        txt_db = TxtTokLmdb(txt_path, opts.max_txt_len)
        train_datasets_t.append(
            ItmRankDatasetHardNegFromText(txt_db, img_db, opts.negative_size))
        train_datasets_i.append(
            ItmRankDatasetHardNegFromImage(txt_db, img_db, opts.negative_size))
    train_dataset_t = ConcatDataset(train_datasets_t)
    train_dataset_i = ConcatDataset(train_datasets_i)
    train_dataloader_t = build_dataloader(train_dataset_t, itm_rank_hn_collate,
                                          True, opts)
    train_dataloader_i = build_dataloader(train_dataset_i, itm_rank_hn_collate,
                                          True, opts)

    # val
    LOGGER.info(f"Loading Val Dataset {opts.val_txt_db}, {opts.val_img_db}")
    val_img_db = all_img_dbs[opts.val_img_db]
    val_txt_db = TxtTokLmdb(opts.val_txt_db, -1)
    val_dataset = ItmValDataset(val_txt_db, val_img_db,
                                opts.inf_minibatch_size)
    val_dataloader = build_dataloader(val_dataset, itm_val_collate, False,
                                      opts)
    # eval
    # (the val and test paths are separated explicitly so the log line
    # stays readable)
    LOGGER.info(f"Loading val, test Dataset for full evaluation: "
                f"{opts.val_txt_db}, {opts.val_img_db}, "
                f"{opts.test_txt_db}, {opts.test_img_db}")
    eval_dataset_val = ItmEvalDataset(val_txt_db, val_img_db,
                                      opts.inf_minibatch_size)
    eval_loader_val = build_dataloader(eval_dataset_val, itm_eval_collate,
                                       False, opts)
    test_img_db = all_img_dbs[opts.test_img_db]
    test_txt_db = TxtTokLmdb(opts.test_txt_db, -1)
    eval_dataset_test = ItmEvalDataset(test_txt_db, test_img_db,
                                       opts.inf_minibatch_size)
    eval_loader_test = build_dataloader(eval_dataset_test, itm_eval_collate,
                                        False, opts)

    # Prepare model
    if opts.checkpoint:
        # map onto this process' own GPU; without map_location every rank
        # restores the tensors onto the device they were saved from
        checkpoint = torch.load(opts.checkpoint, map_location=device)
    else:
        checkpoint = {}

    model = UniterForImageTextRetrievalHardNeg.from_pretrained(
        opts.model_config,
        state_dict=checkpoint,
        img_dim=IMG_DIM,
        margin=opts.margin,
        hard_size=opts.hard_neg_size,
    )
    model.init_output()  # pretrain ITM head is different from ranking head
    model.to(device)
    # make sure every process has same model parameters in the beginning
    broadcast_tensors([p.data for p in model.parameters()], 0)
    set_dropout(model, opts.dropout)

    # Prepare optimizer
    optimizer = build_optimizer(model, opts)
    model, optimizer = amp.initialize(model,
                                      optimizer,
                                      enabled=opts.fp16,
                                      opt_level="O2")

    LOGGER.info(f"***** Running training on {n_gpu} GPUs *****")
    LOGGER.info("  Num examples = %d",
                sum(all_gather_list(len(train_dataset_t))))
    LOGGER.info("  Batch size = %d", opts.train_batch_size)
    LOGGER.info("  Num steps = %d", opts.num_train_steps)

    running_loss = RunningMeter("loss")
    model.train()

    global_step = 0
    step = 0
    n_examples = 0
    n_hard_ex = 0
    start = time()
    train_iter_i = iter(train_dataloader_i)
    # quick hack for amp delay_unscale bug
    optimizer.zero_grad()
    optimizer.step()
    while True:
        for batch in train_dataloader_t:

            # hard text from image
            try:
                batch_i = next(train_iter_i)
            except StopIteration:
                # the image-side loader is exhausted: restart it
                train_iter_i = iter(train_dataloader_i)
                batch_i = next(train_iter_i)
            n_examples += batch_i["attn_masks"].size(0)
            loss = model(batch_i, sample_from="i", compute_loss=True)
            n_hard_ex += loss.numel()
            loss = loss.mean() / opts.train_batch_size
            # the image-side backward never unscales; that is left to the
            # text-side backward below
            with amp.scale_loss(loss, optimizer,
                                delay_unscale=True) as scaled_loss:
                scaled_loss.backward()

            # hard image from text
            n_examples += batch["attn_masks"].size(0)
            loss = model(batch, sample_from="t", compute_loss=True)
            n_hard_ex += loss.numel()
            # NOTE we use gradient accumulation to implement train_batch_size
            loss = loss.mean() / opts.train_batch_size

            step += 1
            delay_unscale = step % opts.train_batch_size != 0
            with amp.scale_loss(loss, optimizer,
                                delay_unscale=delay_unscale) as scaled_loss:
                scaled_loss.backward()
                if not delay_unscale:
                    # gather gradients from every processes
                    # do this before unscaling to make sure every process uses
                    # the same gradient scale
                    grads = [
                        p.grad.data for p in model.parameters()
                        if p.requires_grad and p.grad is not None
                    ]
                    all_reduce_and_rescale_tensors(grads, float(1))

            running_loss(loss.item())
            if step % opts.train_batch_size == 0:
                global_step += 1

                # learning rate scheduling
                lr_this_step = get_lr_sched(global_step, opts)
                for param_group in optimizer.param_groups:
                    param_group["lr"] = lr_this_step
                TB_LOGGER.add_scalar("lr", lr_this_step, global_step)

                # log loss
                # NOTE: not gathered across GPUs for efficiency
                TB_LOGGER.add_scalar("loss", running_loss.val, global_step)
                TB_LOGGER.step()

                # update model params
                if opts.grad_norm != -1:
                    grad_norm = clip_grad_norm_(amp.master_params(optimizer),
                                                opts.grad_norm)
                    TB_LOGGER.add_scalar("grad_norm", grad_norm, global_step)
                optimizer.step()
                optimizer.zero_grad()
                pbar.update(1)

                if global_step % 100 == 0:
                    # monitor training throughput
                    LOGGER.info(f"------------Step {global_step}-------------")
                    tot_ex = sum(all_gather_list(n_examples))
                    ex_per_sec = int(tot_ex / (time() - start))
                    tot_hn = sum(all_gather_list(n_hard_ex))
                    hn_per_sec = int(tot_hn / (time() - start))
                    LOGGER.info(f"{tot_ex} ({tot_hn}) examples (hard) "
                                f"trained at {ex_per_sec} ({hn_per_sec}) ex/s")
                    TB_LOGGER.add_scalar("perf/ex_per_s", ex_per_sec,
                                         global_step)
                    TB_LOGGER.add_scalar("perf/hn_per_s", hn_per_sec,
                                         global_step)
                    LOGGER.info("-------------------------------------------")

                if global_step % opts.valid_steps == 0:
                    if opts.full_val:
                        LOGGER.info(
                            f"========================== Step {global_step} "
                            f"==========================")
                        val_log = evaluate(model, eval_loader_val)
                        TB_LOGGER.log_scaler_dict(
                            {f"valid/{k}": v
                             for k, v in val_log.items()})
                        LOGGER.info(f"image retrieval R1: "
                                    f"{val_log['img_r1']*100:.2f},\n"
                                    f"image retrieval R5: "
                                    f"{val_log['img_r5']*100:.2f},\n"
                                    f"image retrieval R10: "
                                    f"{val_log['img_r10']*100:.2f}\n"
                                    f"text retrieval R1: "
                                    f"{val_log['txt_r1']*100:.2f},\n"
                                    f"text retrieval R5: "
                                    f"{val_log['txt_r5']*100:.2f},\n"
                                    f"text retrieval R10: "
                                    f"{val_log['txt_r10']*100:.2f}")
                        LOGGER.info("================================="
                                    "=================================")
                    else:
                        val_log = validate(model, val_dataloader)
                        TB_LOGGER.log_scaler_dict(val_log)
                    model_saver.save(model, global_step)

            if global_step >= opts.num_train_steps:
                break

        if global_step >= opts.num_train_steps:
            break

    pbar.close()
    # final validation
    val_log = validate(model, val_dataloader)
    TB_LOGGER.log_scaler_dict(val_log)
    model_saver.save(model, f"{global_step}_final")

    # evaluation
    for split, loader in [("val", eval_loader_val),
                          ("test", eval_loader_test)]:
        eval_log = evaluate(model, loader)
        TB_LOGGER.log_scaler_dict(
            {f"eval/{split}_{k}": v
             for k, v in eval_log.items()})
        if hvd.rank() != 0:
            continue
        LOGGER.info(
            f"========================= {split} ===========================\n"
            f"image retrieval R1: {eval_log['img_r1']*100:.2f},\n"
            f"image retrieval R5: {eval_log['img_r5']*100:.2f},\n"
            f"image retrieval R10: {eval_log['img_r10']*100:.2f}\n"
            f"text retrieval R1: {eval_log['txt_r1']*100:.2f},\n"
            f"text retrieval R5: {eval_log['txt_r5']*100:.2f},\n"
            f"text retrieval R10: {eval_log['txt_r10']*100:.2f}")
    LOGGER.info("=========================================================")
def _dump_validation_results(results, output_dir, global_step, rank):
    """Write one rank's validation results as JSON and CSV.

    Both files are written to ``<output_dir>/results/`` and are named
    ``results_<global_step>_rank<rank>`` with a ``.json`` / ``.csv`` suffix.
    """
    stem = f'{output_dir}/results/results_{global_step}_rank{rank}'
    with open(f'{stem}.json', 'w') as f:
        json.dump(results, f)
    pd.DataFrame.from_dict(results).to_csv(f'{stem}.csv', index=False)


def _ascend_and_clip_delta(delta, delta_grad, step_size, embeds_init, opts):
    """Take one normalized ascent step on an adversarial delta and clip it.

    The gradient is normalized per sample with the norm selected by
    ``opts.norm_type`` ("l2" or "linf"), scaled by ``step_size`` and added to
    ``delta``. When ``opts.adv_max_norm > 0`` the result is constrained:
    rescaled onto the L2 ball for "l2", clamped element-wise for "linf".
    Any other ``norm_type`` leaves ``delta`` untouched.

    ``embeds_init`` is only used as the dtype/device target for the L2
    exceed mask. The ``view(-1, 1, 1)`` broadcasts assume 3-D, batch-first
    tensors.

    Returns the updated (detached) delta.
    """
    if opts.norm_type == "l2":
        denorm = torch.norm(delta_grad.view(delta_grad.size(0), -1),
                            dim=1).view(-1, 1, 1)
        # avoid dividing by a zero gradient norm
        denorm = torch.clamp(denorm, min=1e-8)
        delta_step = (step_size * delta_grad / denorm).to(delta)
        delta = (delta + delta_step).detach()
        if opts.adv_max_norm > 0:
            delta_norm = torch.norm(delta.view(delta.size(0), -1),
                                    p=2,
                                    dim=1).detach()
            exceed_mask = (delta_norm > opts.adv_max_norm).to(embeds_init)
            # samples over the limit are scaled back to adv_max_norm,
            # the others keep a weight of 1
            # NOTE(review): a sample whose delta norm is exactly 0 yields
            # inf * 0 = nan here -- confirm whether that can occur.
            reweights = (opts.adv_max_norm / delta_norm * exceed_mask +
                         (1 - exceed_mask)).view(-1, 1, 1)
            delta = (delta * reweights).detach()
    elif opts.norm_type == "linf":
        denorm = torch.norm(delta_grad.view(delta_grad.size(0), -1),
                            dim=1,
                            p=float("inf")).view(-1, 1, 1)
        denorm = torch.clamp(denorm, min=1e-8)
        delta_step = (step_size * delta_grad / denorm).to(delta)
        delta = (delta + delta_step).detach()
        if opts.adv_max_norm > 0:
            delta = torch.clamp(delta, -opts.adv_max_norm,
                                opts.adv_max_norm).detach()
    return delta


def main(opts, checkpoint_dir=None, tuning=False):
    """Train a ``UniterForITM`` model on the meme datasets with Horovod.

    Builds the train/val dataloaders, the model, and the (apex amp wrapped)
    optimizer, then runs the training loop until ``opts.num_train_steps``
    optimizer steps have been taken. Every ``opts.valid_steps`` steps the
    model is validated, per-rank results are dumped to
    ``<opts.output_dir>/results`` and a checkpoint is saved. When
    ``opts.adv_training`` is set, each batch is trained with an inner loop of
    ``opts.adv_steps`` adversarial perturbation steps on the text and/or
    image embeddings.

    Args:
        opts: training options; a plain ``dict`` is converted to an ``edict``
            so the rest of the function can use attribute access.
        checkpoint_dir: directory holding a ``checkpoint`` file to resume
            from. Only used when ``tuning`` is true.
        tuning: when true, the function saves a tune checkpoint and reports
            validation metrics through ``tune`` after every validation, and
            tolerates an already existing results directory.

    Raises:
        ValueError: if ``opts.gradient_accumulation_steps < 1`` or the
            optimizer has more than four parameter groups.
    """
    from utils.logger import LOGGER, TB_LOGGER, RunningMeter, add_log_to_file
    with logger.catch(reraise=True):
        logger.info(f"{opts}")
        if isinstance(opts, dict):
            opts = edict(opts)

        hvd.init()
        n_gpu = hvd.size()
        device = torch.device("cuda", hvd.local_rank())
        torch.cuda.set_device(hvd.local_rank())
        rank = hvd.rank()
        opts.rank = rank
        LOGGER.info("device: {} n_gpu: {}, rank: {}, "
                    "16-bits training: {}".format(device, n_gpu, hvd.rank(),
                                                  opts.fp16))

        if opts.gradient_accumulation_steps < 1:
            raise ValueError(
                "Invalid gradient_accumulation_steps parameter: {}, "
                "should be >= 1".format(opts.gradient_accumulation_steps))

        set_random_seed(opts.seed)

        # ---- load DBs and image dirs ----
        all_img_dbs = ImageLmdbGroup(opts.conf_th, opts.max_bb, opts.min_bb,
                                     opts.num_bb, opts.compressed_db)

        # train
        LOGGER.info(f"Loading Train Dataset "
                    f"{opts.train_txt_dbs}, {opts.train_img_dbs}")
        train_datasets = []
        for txt_path, img_path in zip(opts.train_txt_dbs, opts.train_img_dbs):
            img_db = all_img_dbs[img_path]
            txt_db = TxtTokLmdb(txt_path, opts.max_txt_len)
            train_datasets.append(MemeDataset(1, txt_db, img_db))
        train_dataset = ConcatDatasetWithLens(train_datasets)
        train_dataloader = build_dataloader(train_dataset, meme_collate, True,
                                            opts)

        # val
        LOGGER.info(
            f"Loading Val Dataset {opts.val_txt_db}, {opts.val_img_db}")
        val_img_db = all_img_dbs[opts.val_img_db]
        val_txt_db = TxtTokLmdb(opts.val_txt_db, -1)
        val_dataset = MemeEvalDataset(1, val_txt_db, val_img_db)
        val_dataloader = build_dataloader(val_dataset,
                                          meme_eval_itm_ot_collate, False,
                                          opts)

        # test_img_db = val_img_db
        # test_txt_db = TxtTokLmdb(opts.test_txt_db, -1)
        # test_dataset = MemeEvalDataset(1, test_txt_db, test_img_db)
        # test_dataloader = build_dataloader(test_dataset, meme_eval_collate,
        #                                 False, opts)

        # ---- Prepare model ----
        if opts.checkpoint:
            checkpoint = torch.load(opts.checkpoint)
        else:
            checkpoint = {}

        model = UniterForITM.from_pretrained(opts.model_config,
                                             checkpoint,
                                             img_dim=IMG_DIM,
                                             num_answer=1)
        model.to(device)
        # make sure every process has same model parameters in the beginning
        broadcast_tensors([p.data for p in model.parameters()], 0)
        set_dropout(model, opts.dropout)

        # ---- Prepare optimizer ----
        optimizer = build_optimizer(model, opts)
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          enabled=opts.fp16,
                                          opt_level='O2')
        global_step = 0
        if rank == 0:
            save_training_meta(opts)
            TB_LOGGER.create(join(opts.output_dir, 'log'))
            pbar = tqdm(total=opts.num_train_steps)
            model_saver = ModelSaver(join(opts.output_dir, 'ckpt'))
            # json.dump(ans2label,
            #           open(join(opts.output_dir, 'ckpt', 'ans2label.json'), 'w'))
            # an existing results dir is only tolerated in tuning mode
            os.makedirs(join(opts.output_dir, 'results'),
                        exist_ok=tuning)  # store VQA predictions
            add_log_to_file(join(opts.output_dir, 'log', 'log.txt'))
        else:
            # only rank 0 logs, shows progress and saves checkpoints
            LOGGER.disabled = True
            pbar = NoOp()
            model_saver = NoOp()

        LOGGER.info(f"***** Running training with {n_gpu} GPUs *****")
        LOGGER.info("  Num examples = %d", len(train_dataset) * hvd.size())
        LOGGER.info("  Batch size = %d", opts.train_batch_size)
        LOGGER.info("  Accumulate steps = %d",
                    opts.gradient_accumulation_steps)
        LOGGER.info("  Num steps = %d", opts.num_train_steps)

        running_loss = RunningMeter('loss')
        model.train()
        n_examples = 0
        n_epoch = 0

        if checkpoint_dir is not None and tuning:
            # resume model/optimizer state and counters from a tune checkpoint
            resume_path = os.path.join(checkpoint_dir, "checkpoint")
            (model_state, optimizer_state, n_epoch,
             n_examples) = torch.load(resume_path)
            model.load_state_dict(model_state)
            optimizer.load_state_dict(optimizer_state)

            LOGGER.info(
                f"***** Resume from ray tune checkpoint : {checkpoint_dir} *****"
            )
            LOGGER.info("  n_examples = %d", n_examples)
            LOGGER.info("  n_epoch = %d", n_epoch)

            # shutil.rmtree(checkpoint_dir)

        start = time()
        # quick hack for amp delay_unscale bug
        optimizer.zero_grad()
        optimizer.step()
        while True:
            for step, batch in enumerate(train_dataloader):
                # hard stop once global_step exceeds 2000, independent of
                # opts.num_train_steps; the process exits with status 0
                if global_step > 2000:
                    logger.error('Force stop at global step 2000')
                    sys.exit(0)
                n_examples += batch['input_ids'].size(0)

                if opts.adv_training:
                    # NOTE: reverse label like what we do in UniterForITM
                    targets = batch['targets']
                    targets = (targets > 0.5).long()
                    targets = torch.abs(targets - 1)
                    batch['targets'] = targets

                    # initialize delta
                    txt_embeds_init = model.uniter.embeddings.word_embeddings(
                        batch['input_ids'])
                    img_embeds_init = batch['img_feat']

                    # for simplicity, we initialize the delta as zero vectors, which performs
                    # very similar as initializing randomly using norm or uniform distributions
                    txt_delta = torch.zeros_like(txt_embeds_init)
                    img_delta = torch.zeros_like(img_embeds_init)

                    # calculate the prob. scores for clean samples
                    gt_answer_scores = model(batch, compute_loss=False)
                    gt_answer_prob = F.softmax(gt_answer_scores, dim=1)
                    gt_answer_logprob = F.log_softmax(gt_answer_scores, dim=1)

                    # the main loop
                    for astep in range(opts.adv_steps):
                        # (0) forward
                        if opts.adv_modality == ["text"]:
                            txt_delta.requires_grad_()
                            img_delta = torch.zeros_like(img_embeds_init)
                        elif opts.adv_modality == ["image"]:
                            img_delta.requires_grad_()
                            txt_delta = torch.zeros_like(txt_embeds_init)
                        else:
                            txt_delta.requires_grad_()
                            img_delta.requires_grad_()

                        if "alter" not in opts.adv_modality:
                            answer_scores = model(
                                batch,
                                adv_training=True,
                                adv_modality=opts.adv_modality,
                                adv_delta_txt=txt_delta,
                                adv_delta_img=img_delta,
                                compute_loss=False)

                            # CE loss
                            ce_loss = F.cross_entropy(
                                answer_scores,
                                batch['targets'].squeeze(-1),
                                reduction='mean')

                            # KL loss (symmetric: both directions are summed)
                            answer_prob = F.softmax(answer_scores, dim=1)
                            answer_logprob = F.log_softmax(answer_scores,
                                                           dim=1)
                            kl_loss = F.kl_div(
                                answer_logprob, gt_answer_prob, reduction='none') + \
                                F.kl_div(
                                    gt_answer_logprob, answer_prob,
                                    reduction='none')
                            kl_loss = kl_loss.mean()

                            # (1) backward
                            loss = (ce_loss + opts.adv_kl_weight *
                                    kl_loss) / opts.adv_steps
                        else:
                            # "alter": perturb text and image in two
                            # separate forward passes
                            answer_scores_1 = model(batch,
                                                    adv_training=True,
                                                    adv_modality=["text"],
                                                    adv_delta_txt=txt_delta,
                                                    adv_delta_img=None,
                                                    compute_loss=False)

                            # CE loss on the text-perturbed scores
                            ce_loss_1 = F.cross_entropy(
                                answer_scores_1,
                                batch['targets'].squeeze(-1),
                                reduction='mean')

                            answer_scores_2 = model(batch,
                                                    adv_training=True,
                                                    adv_modality=["image"],
                                                    adv_delta_txt=None,
                                                    adv_delta_img=img_delta,
                                                    compute_loss=False)

                            # CE loss on the image-perturbed scores
                            ce_loss_2 = F.cross_entropy(
                                answer_scores_2,
                                batch['targets'].squeeze(-1),
                                reduction='mean')

                            # KL loss
                            answer_prob_1 = F.softmax(answer_scores_1, dim=1)
                            answer_logprob_1 = F.log_softmax(answer_scores_1,
                                                             dim=1)
                            answer_prob_2 = F.softmax(answer_scores_2, dim=1)
                            answer_logprob_2 = F.log_softmax(answer_scores_2,
                                                             dim=1)
                            kl_loss_1 = F.kl_div(
                                answer_logprob_1, gt_answer_prob, reduction='none') + \
                                F.kl_div(
                                    gt_answer_logprob, answer_prob_1,
                                    reduction='none')
                            kl_loss_1 = kl_loss_1.mean()
                            kl_loss_2 = F.kl_div(
                                answer_logprob_2, gt_answer_prob, reduction='none') + \
                                F.kl_div(
                                    gt_answer_logprob, answer_prob_2,
                                    reduction='none')
                            kl_loss_2 = kl_loss_2.mean()

                            # (1) backward
                            loss = (
                                ce_loss_1 + ce_loss_2 + opts.adv_kl_weight *
                                (kl_loss_1 + kl_loss_2)) / (opts.adv_steps * 2)

                        # unscale only on the last adversarial step of the
                        # last micro-batch of an accumulation window
                        delay_unscale = (
                            (step + 1) % opts.gradient_accumulation_steps !=
                            0) or ((astep + 1) % opts.adv_steps != 0)
                        with amp.scale_loss(
                                loss, optimizer,
                                delay_unscale=delay_unscale) as scaled_loss:
                            scaled_loss.backward(retain_graph=True)
                            if not delay_unscale:
                                # gather gradients from every processes
                                # do this before unscaling
                                # to make sure every process uses
                                # the same gradient scale
                                grads = [
                                    p.grad.data for p in model.parameters()
                                    if p.requires_grad and p.grad is not None
                                ]
                                all_reduce_and_rescale_tensors(grads, float(1))

                        running_loss(loss.item())

                        if astep == opts.adv_steps - 1:
                            # no further updates on delta after the last step
                            break

                        # (2) get gradient on delta
                        # fix fp16 problem: recover the loss scale amp applied
                        # so the delta gradients can be unscaled by hand
                        amp_scale = scaled_loss.item() // loss.item()
                        if "text" in opts.adv_modality:
                            txt_delta_grad = txt_delta.grad.clone().detach()
                            txt_delta_grad = txt_delta_grad.float() / amp_scale
                        if "image" in opts.adv_modality:
                            img_delta_grad = img_delta.grad.clone().detach()
                            img_delta_grad = img_delta_grad.float() / amp_scale

                        # (3) update and clip for txt delta
                        if "text" in opts.adv_modality:
                            txt_delta = _ascend_and_clip_delta(
                                txt_delta, txt_delta_grad, opts.adv_lr_txt,
                                txt_embeds_init, opts)

                        # (4) update and clip for image delta
                        if "image" in opts.adv_modality:
                            img_delta = _ascend_and_clip_delta(
                                img_delta, img_delta_grad, opts.adv_lr_img,
                                img_embeds_init, opts)
                else:
                    loss = model(batch, compute_loss=True)
                    loss = loss.mean()
                    delay_unscale = (step +
                                     1) % opts.gradient_accumulation_steps != 0
                    with amp.scale_loss(
                            loss, optimizer,
                            delay_unscale=delay_unscale) as scaled_loss:
                        scaled_loss.backward()
                        if not delay_unscale:
                            # gather gradients from every processes
                            # do this before unscaling to make sure every process uses
                            # the same gradient scale
                            grads = [
                                p.grad.data for p in model.parameters()
                                if p.requires_grad and p.grad is not None
                            ]
                            all_reduce_and_rescale_tensors(grads, float(1))

                    running_loss(loss.item())

                # ---- loss compute end; log & step start ----

                if (step + 1) % opts.gradient_accumulation_steps == 0:
                    global_step += 1

                    # learning rate scheduling
                    lr_this_step = get_lr_sched(global_step, opts)
                    for i, param_group in enumerate(optimizer.param_groups):
                        # the first two groups use the lr scaled by lr_mul,
                        # the next two the plain scheduled lr
                        if i == 0 or i == 1:
                            param_group['lr'] = lr_this_step * opts.lr_mul
                        elif i == 2 or i == 3:
                            param_group['lr'] = lr_this_step
                        else:
                            raise ValueError()
                    TB_LOGGER.add_scalar('lr', lr_this_step, global_step)

                    # log loss
                    # NOTE: not gathered across GPUs for efficiency
                    TB_LOGGER.add_scalar('loss', running_loss.val, global_step)
                    TB_LOGGER.step()

                    # update model params
                    if opts.grad_norm != -1:
                        grad_norm = clip_grad_norm_(
                            amp.master_params(optimizer), opts.grad_norm)
                        TB_LOGGER.add_scalar('grad_norm', grad_norm,
                                             global_step)
                    optimizer.step()
                    optimizer.zero_grad()
                    pbar.update(1)

                    if global_step % 100 == 0:
                        # monitor training throughput
                        LOGGER.info(
                            f'============Step {global_step}=============')
                        tot_ex = sum(all_gather_list(n_examples))
                        ex_per_sec = int(tot_ex / (time() - start))
                        LOGGER.info(f'{tot_ex} examples trained at '
                                    f'{ex_per_sec} ex/s')
                        TB_LOGGER.add_scalar('perf/ex_per_s', ex_per_sec,
                                             global_step)
                        LOGGER.info(
                            '===========================================')

                    if global_step % opts.valid_steps == 0:
                        val_log, results = validate(model, val_dataloader,
                                                    None)
                        _dump_validation_results(results, opts.output_dir,
                                                 global_step, rank)

                        # _, test_results = test(model, test_dataloader, global_step)
                        # pd.DataFrame.from_dict(test_results).to_csv(
                        #     f'{opts.output_dir}/results/'
                        #     f'test_{global_step}.csv',
                        #     index=False)

                        TB_LOGGER.log_scaler_dict(val_log)
                        model_saver.save(model, global_step)

                        if tuning:
                            with tune.checkpoint_dir(
                                    step=n_epoch) as checkpoint_dir:
                                logger.info(
                                    f'***** Save tune ckpt: {checkpoint_dir} *****'
                                )
                                path = os.path.join(checkpoint_dir,
                                                    "checkpoint")
                                torch.save((model.state_dict(),
                                            optimizer.state_dict(), n_epoch,
                                            n_examples), path)
                            tune.report(
                                loss=(val_log['valid/loss']),
                                accuracy=val_log['valid/acc'],
                                auroc=val_log['valid/auroc'],
                            )
                if global_step >= opts.num_train_steps:
                    break
            if global_step >= opts.num_train_steps:
                break
            n_epoch += 1
            LOGGER.info(f"finished {n_epoch} epochs")
            # ---- END of training loop ----

        # the last step was not a validation step: validate and save once more
        if opts.num_train_steps % opts.valid_steps != 0:
            val_log, results = validate(model, val_dataloader, None)
            _dump_validation_results(results, opts.output_dir, global_step,
                                     rank)
            TB_LOGGER.log_scaler_dict(val_log)
            model_saver.save(model, global_step)
Esempio n. 13
0
def create_dataloaders(datasets, is_train, opts, all_img_dbs=None):
    """Build one dataloader per ``<task>_<dataset name>`` combination.

    Args:
        datasets: iterable of dataset specs (mappings). Each spec provides
            'name', 'db' (text DB paths), 'img' (image DB paths) and 'tasks';
            training specs also provide 'mix_ratio', and soft-label vmlm
            tasks need 'img_token_soft_label'.
        is_train: selects training mode (lists of DBs, loaders paired with
            their mix ratio) or validation mode (exactly one DB, loaders
            wrapped in ``PrefetchLoader``).
        opts: option namespace read for the image DB settings, max text
            length and the multilingual flag, and passed on to the builders.
        all_img_dbs: image DB group to reuse; created from ``opts`` when
            ``None``.

    Returns:
        ``(dataloaders, all_img_dbs)`` where ``dataloaders`` maps the task
        name to ``(loader, ratio)`` in training mode or to a
        ``PrefetchLoader`` in validation mode.

    Raises:
        ValueError: if a task name matches none of the known prefixes.
    """
    if all_img_dbs is None:
        all_img_dbs = ImageLmdbGroup(opts.conf_th, opts.max_bb, opts.min_bb,
                                     opts.num_bb, opts.compressed_db)

    dataloaders = {}
    for dset in datasets:
        # sanity-check the spec and resolve the image DB(s)
        if is_train:
            assert len(dset['db']) == len(dset['img'])
            assert len(dset['tasks']) == len(dset['mix_ratio'])
            img_db = [all_img_dbs[img_path] for img_path in dset['img']]
        else:
            assert len(dset['db']) == len(dset['img']) == 1
            img_db = all_img_dbs[dset['img'][0]]

        for task_idx, task_prefix in enumerate(dset['tasks']):
            task = f'{task_prefix}_{dset["name"]}'
            language_list = []

            if is_train:
                LOGGER.info(f"Loading {task} train dataset "
                            f"{dset['db']}, {[img.img_dir for img in img_db]}")
                txt_db = [TxtTokLmdb(txt_path, opts.max_txt_len)
                          for txt_path in dset['db']]
                # languages are only collected for the 'cc' dataset; the
                # language code is taken from the second-to-last '_' field
                # of each DB path (hacky, needs a better mechanism)
                if (dset['name'] == 'cc' and opts.multilingual_vmlm
                        and task.startswith('vmlm')):
                    language_list = [txt_path.split('_')[-2]
                                     for txt_path in dset['db']]
            else:
                LOGGER.info(f"Loading {task} validation dataset, "
                            f"{dset['db']}, {img_db.img_dir}")
                txt_db = TxtTokLmdb(dset['db'][0], -1)
                # in validation the language is the last '_' field of the name
                if opts.multilingual_vmlm and task.startswith('vmlm'):
                    language_list = [dset["name"].split('_')[-1]]

            # pick the dataset builder from the task prefix
            if task.startswith('mlm'):
                dataset = build_mlm_dataset(txt_db, img_db,
                                            'blind' in task, is_train, opts)
            elif task.startswith('tlm'):
                dataset = build_tlm_dataset(txt_db, img_db, 'blind' in task,
                                            is_train, opts, "ni" in task)
            elif task.startswith('mmxlm'):
                dataset = build_mmxlm_dataset(txt_db, img_db, is_train, opts,
                                              'soft' in task)
            elif task.startswith('vmlm'):
                soft = 'soft' in task
                if not soft:
                    img_token_sl_db = None
                else:
                    # soft variants need the image-token soft-label DBs
                    assert dset.get('img_token_soft_label', None) is not None
                    sl_paths = dset['img_token_soft_label']
                    if is_train:
                        assert len(dset['db']) == len(sl_paths)
                        img_token_sl_db = [Img_SoftLabel_Lmdb(sl_path)
                                           for sl_path in sl_paths]
                    else:
                        assert len(dset['db']) == len(sl_paths) == 1
                        img_token_sl_db = Img_SoftLabel_Lmdb(sl_paths[0])
                dataset = build_vmlm_dataset(txt_db, img_db, img_token_sl_db,
                                             is_train, opts, soft,
                                             language_list=language_list)
            elif task.startswith('mrfr'):
                dataset = build_mrfr_dataset(txt_db, img_db,
                                             'only_i' in task, is_train, opts)
            elif task.startswith('mrm-nce'):
                dataset = build_mrm_nce_dataset(txt_db, img_db,
                                                'only_i' in task, is_train,
                                                opts)
            elif task.startswith('mrc'):
                dataset = build_mrc_dataset(txt_db, img_db,
                                            'only_i' in task, is_train, opts)
            elif task.startswith('itm'):
                dataset = build_itm_dataset(txt_db, img_db, is_train, opts)
            else:
                raise ValueError(f'Undefined task {task}')

            LOGGER.info(f"{len(dataset[0])*hvd.size()} samples loaded")

            # itm handles distributed training in dset not sampler
            make_loader = (build_dataloader_itm if task.startswith('itm')
                           else build_dataloader)
            loader = make_loader(*dataset, is_train, opts)

            if is_train:
                dataloaders[task] = (loader, dset['mix_ratio'][task_idx])
            else:
                dataloaders[task] = PrefetchLoader(loader)
    return dataloaders, all_img_dbs
Esempio n. 14
0
def _load_json(path):
    """Parse the JSON file at ``path`` and close the file handle afterwards."""
    with open(path) as f:
        return json.load(f)


def _adv_bce_and_kl(scores, targets, clean_prob, clean_logprob):
    """Return ``(bce, kl)`` losses for the logits of a perturbed forward pass.

    ``bce`` is the binary cross-entropy against ``targets``; ``kl`` is the sum
    of the two directed KL terms between the perturbed distribution and the
    clean one (``clean_prob`` / ``clean_logprob``).  Both are averaged over
    all elements and multiplied by ``targets.size(1)``.
    """
    bce = F.binary_cross_entropy_with_logits(scores, targets,
                                             reduction='none')
    bce = bce.mean() * targets.size(1)  # instance-level bce

    prob = F.softmax(scores, dim=1)
    logprob = F.log_softmax(scores, dim=1)
    kl = F.kl_div(logprob, clean_prob, reduction='none') + \
        F.kl_div(clean_logprob, prob, reduction='none')
    kl = kl.mean() * targets.size(1)  # same scaling as the bce term
    return bce, kl


def _ascend_and_project(delta, delta_grad, step_size, norm_type, max_norm):
    """Take one normalized gradient-ascent step on ``delta`` and clip it.

    The gradient is normalized per sample (first dimension) by its L2 or
    L-infinity norm, scaled by ``step_size`` and added to ``delta``.  When
    ``max_norm > 0`` the result is projected back: rescaled to L2 norm
    ``max_norm`` for ``"l2"``, clamped element-wise to ``[-max_norm,
    max_norm]`` for ``"linf"``.  Any other ``norm_type`` returns ``delta``
    unchanged.  The returned tensor is detached from the graph.
    """
    if norm_type == "l2":
        denorm = torch.norm(delta_grad.view(delta_grad.size(0), -1),
                            dim=1).view(-1, 1, 1)
        denorm = torch.clamp(denorm, min=1e-8)
        delta_step = (step_size * delta_grad / denorm).to(delta)
        delta = (delta + delta_step).detach()
        if max_norm > 0:
            delta_norm = torch.norm(delta.view(delta.size(0), -1),
                                    p=2,
                                    dim=1).detach()
            exceed_mask = delta_norm > max_norm
            # Select the factor with torch.where instead of the arithmetic
            # mask `ratio * mask + (1 - mask)`: for a sample whose norm is 0
            # the ratio is inf and `inf * 0` would turn the delta into NaN.
            reweights = torch.where(exceed_mask, max_norm / delta_norm,
                                    torch.ones_like(delta_norm)).view(
                                        -1, 1, 1)
            delta = (delta * reweights).detach()
    elif norm_type == "linf":
        denorm = torch.norm(delta_grad.view(delta_grad.size(0), -1),
                            dim=1,
                            p=float("inf")).view(-1, 1, 1)
        denorm = torch.clamp(denorm, min=1e-8)
        delta_step = (step_size * delta_grad / denorm).to(delta)
        delta = (delta + delta_step).detach()
        if max_norm > 0:
            delta = torch.clamp(delta, -max_norm, max_norm).detach()
    return delta


def main(opts):
    """Train a UNITER VQA model, optionally with adversarial training.

    Initializes Horovod, builds the train/val dataloaders, the model and an
    amp-wrapped optimizer, then loops over the training data until
    ``opts.num_train_steps`` optimizer updates have been made.  Every
    ``opts.valid_steps`` updates (and once more at the end if the last update
    was not a validation step) the model is validated, the predictions are
    written to ``<output_dir>/results`` and a checkpoint is saved.

    When ``opts.adv_training`` is set, each batch is trained with
    ``opts.adv_steps`` forward/backward passes on perturbed text and/or image
    embeddings; the perturbations are updated by normalized gradient ascent
    between passes and gradients are accumulated across all passes.

    Args:
        opts: option namespace; ``opts.rank`` is set here as a side effect.

    Raises:
        ValueError: if ``opts.gradient_accumulation_steps < 1``, if a task
            is misconfigured, or if the optimizer has more than four
            parameter groups.
    """
    hvd.init()
    n_gpu = hvd.size()
    device = torch.device("cuda", hvd.local_rank())
    torch.cuda.set_device(hvd.local_rank())
    rank = hvd.rank()
    opts.rank = rank
    LOGGER.info("device: {} n_gpu: {}, rank: {}, "
                "16-bits training: {}".format(device, n_gpu, hvd.rank(),
                                              opts.fp16))

    if opts.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, "
                         "should be >= 1".format(
                             opts.gradient_accumulation_steps))

    set_random_seed(opts.seed)

    # answer vocabulary shipped next to this script
    ans2label = _load_json(f'{dirname(abspath(__file__))}'
                           f'/utils/ans2label.json')
    label2ans = {label: ans for ans, label in ans2label.items()}

    # load DBs and image dirs
    all_img_dbs = ImageLmdbGroup(opts.conf_th, opts.max_bb, opts.min_bb,
                                 opts.num_bb, opts.compressed_db)
    # train
    LOGGER.info(f"Loading Train Dataset "
                f"{opts.train_txt_dbs}, {opts.train_img_dbs}")
    train_datasets = []
    for txt_path, img_path in zip(opts.train_txt_dbs, opts.train_img_dbs):
        img_db = all_img_dbs[img_path]
        txt_db = TxtTokLmdb(txt_path, opts.max_txt_len)
        train_datasets.append(VqaDataset(len(ans2label), txt_db, img_db))
    train_dataset = ConcatDatasetWithLens(train_datasets)
    train_dataloader = build_dataloader(train_dataset, vqa_collate, True, opts)
    # val
    LOGGER.info(f"Loading Val Dataset {opts.val_txt_db}, {opts.val_img_db}")
    val_img_db = all_img_dbs[opts.val_img_db]
    val_txt_db = TxtTokLmdb(opts.val_txt_db, -1)
    val_dataset = VqaEvalDataset(len(ans2label), val_txt_db, val_img_db)
    val_dataloader = build_dataloader(val_dataset, vqa_eval_collate, False,
                                      opts)

    # Prepare model
    if opts.checkpoint:
        ckpt = torch.load(opts.checkpoint)
        # checkpoint keys use the 'bert' prefix; the model expects 'uniter'
        checkpoint = {k.replace('bert', 'uniter'): v for k, v in ckpt.items()}
    else:
        checkpoint = {}

    # every text DB must have been tokenized with the same 'bert' setting
    all_dbs = opts.train_txt_dbs + [opts.val_txt_db]
    toker = _load_json(f'{all_dbs[0]}/meta.json')['bert']
    assert all(toker == _load_json(f'{db}/meta.json')['bert']
               for db in all_dbs)
    model = UniterForVisualQuestionAnswering.from_pretrained(
        opts.model_config,
        checkpoint,
        img_dim=IMG_DIM,
        num_answer=len(ans2label))
    model.to(device)
    # make sure every process has same model parameters in the beginning
    broadcast_tensors([p.data for p in model.parameters()], 0)
    set_dropout(model, opts.dropout)

    # Prepare optimizer
    optimizer = build_optimizer(model, opts)
    model, optimizer = amp.initialize(model,
                                      optimizer,
                                      enabled=opts.fp16,
                                      opt_level='O2')
    global_step = 0
    if rank == 0:
        # only rank 0 writes metadata, logs, progress and checkpoints
        save_training_meta(opts)
        TB_LOGGER.create(join(opts.output_dir, 'log'))
        pbar = tqdm(total=opts.num_train_steps)
        model_saver = ModelSaver(join(opts.output_dir, 'ckpt'))
        with open(join(opts.output_dir, 'ckpt', 'ans2label.json'),
                  'w') as f:
            json.dump(ans2label, f)
        os.makedirs(join(opts.output_dir, 'results'))  # store VQA predictions
        add_log_to_file(join(opts.output_dir, 'log', 'log.txt'))
    else:
        LOGGER.disabled = True
        pbar = NoOp()
        model_saver = NoOp()

    LOGGER.info(f"***** Running training with {n_gpu} GPUs *****")
    LOGGER.info("  Num examples = %d", len(train_dataset) * hvd.size())
    LOGGER.info("  Batch size = %d", opts.train_batch_size)
    LOGGER.info("  Accumulate steps = %d", opts.gradient_accumulation_steps)
    LOGGER.info("  Num steps = %d", opts.num_train_steps)

    running_loss = RunningMeter('loss')
    model.train()
    n_examples = 0
    n_epoch = 0
    start = time()
    # quick hack for amp delay_unscale bug
    optimizer.zero_grad()
    optimizer.step()
    while True:
        for step, batch in enumerate(train_dataloader):
            n_examples += batch['input_ids'].size(0)

            # ============================ Code for adversarial training =============
            if opts.adv_training:
                # initialize delta
                txt_embeds_init = model.uniter.embeddings.word_embeddings(
                    batch['input_ids'])
                img_embeds_init = batch['img_feat']

                # for simplicity, we initialize the delta as zero vectors, which performs
                # very similar to initializing randomly using norm or uniform distributions
                txt_delta = torch.zeros_like(txt_embeds_init)
                img_delta = torch.zeros_like(img_embeds_init)

                # calculate the prob. scores for clean samples
                gt_answer_scores = model(batch, compute_loss=False)
                gt_answer_prob = F.softmax(gt_answer_scores, dim=1)
                gt_answer_logprob = F.log_softmax(gt_answer_scores, dim=1)

                # the main loop
                for astep in range(opts.adv_steps):
                    # (0) forward
                    if opts.adv_modality == ["text"]:
                        txt_delta.requires_grad_()
                        img_delta = torch.zeros_like(img_embeds_init)
                    elif opts.adv_modality == ["image"]:
                        img_delta.requires_grad_()
                        txt_delta = torch.zeros_like(txt_embeds_init)
                    else:
                        txt_delta.requires_grad_()
                        img_delta.requires_grad_()

                    if "alter" not in opts.adv_modality:
                        # one forward pass with the configured perturbations
                        answer_scores = model(batch,
                                              adv_training=True,
                                              adv_modality=opts.adv_modality,
                                              adv_delta_txt=txt_delta,
                                              adv_delta_img=img_delta,
                                              compute_loss=False)
                        bce_loss, kl_loss = _adv_bce_and_kl(
                            answer_scores, batch['targets'], gt_answer_prob,
                            gt_answer_logprob)

                        # (1) backward
                        loss = (bce_loss +
                                opts.adv_kl_weight * kl_loss) / opts.adv_steps
                    else:
                        # "alter": perturb text and image in two separate
                        # forward passes and average their losses
                        answer_scores_1 = model(batch,
                                                adv_training=True,
                                                adv_modality=["text"],
                                                adv_delta_txt=txt_delta,
                                                adv_delta_img=None,
                                                compute_loss=False)
                        bce_loss_1, kl_loss_1 = _adv_bce_and_kl(
                            answer_scores_1, batch['targets'],
                            gt_answer_prob, gt_answer_logprob)

                        answer_scores_2 = model(batch,
                                                adv_training=True,
                                                adv_modality=["image"],
                                                adv_delta_txt=None,
                                                adv_delta_img=img_delta,
                                                compute_loss=False)
                        bce_loss_2, kl_loss_2 = _adv_bce_and_kl(
                            answer_scores_2, batch['targets'],
                            gt_answer_prob, gt_answer_logprob)

                        # (1) backward
                        loss = (bce_loss_1 + bce_loss_2 + opts.adv_kl_weight *
                                (kl_loss_1 + kl_loss_2)) / (opts.adv_steps * 2)

                    # only unscale on the last adversarial pass of the last
                    # accumulation step, i.e. right before optimizer.step()
                    delay_unscale = (
                        (step + 1) % opts.gradient_accumulation_steps !=
                        0) or ((astep + 1) % opts.adv_steps != 0)
                    with amp.scale_loss(
                            loss, optimizer,
                            delay_unscale=delay_unscale) as scaled_loss:
                        # the clean forward pass is shared by every
                        # adversarial pass, so its graph must be kept
                        scaled_loss.backward(retain_graph=True)
                        if not delay_unscale:
                            # gather gradients from every processes
                            # do this before unscaling to make sure every process uses
                            # the same gradient scale
                            grads = [
                                p.grad.data for p in model.parameters()
                                if p.requires_grad and p.grad is not None
                            ]
                            all_reduce_and_rescale_tensors(grads, float(1))

                    running_loss(loss.item())

                    if astep == opts.adv_steps - 1:
                        # no further updates on delta after the last pass
                        break

                    # (2) get gradient on delta
                    # fix fp16 problem: delta.grad was computed from the
                    # scaled loss, so divide the loss scale back out
                    amp_scale = scaled_loss.item() // loss.item()

                    # (3) update and clip for txt delta
                    if "text" in opts.adv_modality:
                        txt_delta_grad = txt_delta.grad.clone().detach().float(
                        ) / amp_scale
                        txt_delta = _ascend_and_project(
                            txt_delta, txt_delta_grad, opts.adv_lr_txt,
                            opts.norm_type, opts.adv_max_norm)

                    # (4) update and clip for image delta
                    if "image" in opts.adv_modality:
                        img_delta_grad = img_delta.grad.clone().detach().float(
                        ) / amp_scale
                        img_delta = _ascend_and_project(
                            img_delta, img_delta_grad, opts.adv_lr_img,
                            opts.norm_type, opts.adv_max_norm)

            else:
                loss = model(batch, compute_loss=True)
                loss = loss.mean() * batch['targets'].size(
                    1)  # instance-level bce
                delay_unscale = (step +
                                 1) % opts.gradient_accumulation_steps != 0
                with amp.scale_loss(
                        loss, optimizer,
                        delay_unscale=delay_unscale) as scaled_loss:
                    scaled_loss.backward()
                    if not delay_unscale:
                        # gather gradients from every processes
                        # do this before unscaling to make sure every process uses
                        # the same gradient scale
                        grads = [
                            p.grad.data for p in model.parameters()
                            if p.requires_grad and p.grad is not None
                        ]
                        all_reduce_and_rescale_tensors(grads, float(1))

                running_loss(loss.item())

            # ============================ End ==========================

            if (step + 1) % opts.gradient_accumulation_steps == 0:
                global_step += 1

                # learning rate scheduling
                # NOTE(review): groups 0-1 get the lr_mul-scaled rate and
                # groups 2-3 the base rate; the grouping itself comes from
                # build_optimizer -- confirm there which params are which.
                lr_this_step = get_lr_sched(global_step, opts)
                for i, param_group in enumerate(optimizer.param_groups):
                    if i == 0 or i == 1:
                        param_group['lr'] = lr_this_step * opts.lr_mul
                    elif i == 2 or i == 3:
                        param_group['lr'] = lr_this_step
                    else:
                        raise ValueError()
                TB_LOGGER.add_scalar('lr', lr_this_step, global_step)

                # log loss
                # NOTE: not gathered across GPUs for efficiency
                TB_LOGGER.add_scalar('loss', running_loss.val, global_step)
                TB_LOGGER.step()

                # update model params
                if opts.grad_norm != -1:
                    grad_norm = clip_grad_norm_(amp.master_params(optimizer),
                                                opts.grad_norm)
                    TB_LOGGER.add_scalar('grad_norm', grad_norm, global_step)
                optimizer.step()
                optimizer.zero_grad()
                pbar.update(1)

                if global_step % 100 == 0:
                    # monitor training throughput
                    LOGGER.info(f'============Step {global_step}=============')
                    tot_ex = sum(all_gather_list(n_examples))
                    ex_per_sec = int(tot_ex / (time() - start))
                    LOGGER.info(f'{tot_ex} examples trained at '
                                f'{ex_per_sec} ex/s')
                    TB_LOGGER.add_scalar('perf/ex_per_s', ex_per_sec,
                                         global_step)
                    LOGGER.info(f'===========================================')

                if global_step % opts.valid_steps == 0:
                    val_log, results = validate(model, val_dataloader,
                                                label2ans)
                    # every rank writes its own predictions file
                    with open(
                            f'{opts.output_dir}/results/'
                            f'results_{global_step}_'
                            f'rank{rank}.json', 'w') as f:
                        json.dump(results, f)
                    TB_LOGGER.log_scaler_dict(val_log)
                    model_saver.save(model, global_step)
            if global_step >= opts.num_train_steps:
                break
        if global_step >= opts.num_train_steps:
            break
        n_epoch += 1
        LOGGER.info(f"finished {n_epoch} epochs")
    # final validation/checkpoint unless the last update already produced one
    if opts.num_train_steps % opts.valid_steps != 0:
        val_log, results = validate(model, val_dataloader, label2ans)
        with open(
                f'{opts.output_dir}/results/'
                f'results_{global_step}_'
                f'rank{rank}.json', 'w') as f:
            json.dump(results, f)
        TB_LOGGER.log_scaler_dict(val_log)
        model_saver.save(model, global_step)