Example #1
0
def main(opts):
    """Run VCR inference with a fine-tuned UNITER checkpoint under Horovod.

    Reads the hyper-parameters (``hps.json``) and model config
    (``model.json``) saved under ``{opts.output_dir}/log``, restores weights
    from ``opts.checkpoint`` (an existing file path, otherwise interpreted as
    a step number resolved to ``{opts.output_dir}/ckpt/model_step_{step}.pt``),
    evaluates on ``opts.img_db``/``opts.txt_db``, gathers the per-rank
    results and, on rank 0, writes them to
    ``{opts.output_dir}/results_{opts.split}/results_{checkpoint}_all.json``
    and the matching ``.csv`` produced by ``save_for_submission``.

    Args:
        opts: parsed options; this function reads ``output_dir``, ``split``,
            ``img_db``, ``txt_db``, ``checkpoint``, ``fp16``, ``batch_size``,
            ``n_workers`` and ``pin_mem``.
    """
    hvd.init()
    n_gpu = hvd.size()
    device = torch.device("cuda", hvd.local_rank())
    torch.cuda.set_device(hvd.local_rank())
    rank = hvd.rank()
    LOGGER.info("device: {} n_gpu: {}, rank: {}, "
                "16-bits training: {}".format(
                    device, n_gpu, rank, opts.fp16))
    if rank != 0:
        LOGGER.disabled = True

    hps_file = f'{opts.output_dir}/log/hps.json'
    # use a context manager so the file handle is closed after parsing
    with open(hps_file) as hps_f:
        model_opts = Struct(json.load(hps_f))

    # sanity check: both DB paths must mention the requested split
    assert opts.split in opts.img_db and opts.split in opts.txt_db
    # load DBs and image dirs
    eval_img_db, eval_img_db_gt = load_img_feat(opts.img_db, model_opts)
    eval_txt_db = VcrTxtTokLmdb(opts.txt_db, -1)
    eval_dataset = VcrEvalDataset(
        "test", eval_txt_db, img_db=eval_img_db,
        img_db_gt=eval_img_db_gt)

    # Prepare model
    model = UniterForVisualCommonsenseReasoning.from_pretrained(
        f'{opts.output_dir}/log/model.json', state_dict={},
        img_dim=IMG_DIM)
    model.init_type_embedding()
    model.init_word_embedding(NUM_SPECIAL_TOKENS)
    if exists(opts.checkpoint):
        ckpt_file = opts.checkpoint
    else:
        ckpt_file = f'{opts.output_dir}/ckpt/model_step_{opts.checkpoint}.pt'
    # Deserialize on CPU: without map_location every rank would restore the
    # tensors onto the device they were saved from. The model is still on
    # CPU here and is moved to `device` below.
    checkpoint = torch.load(ckpt_file, map_location='cpu')
    state_dict = checkpoint.get('model_state', checkpoint)
    # keep only checkpoint entries whose names match a model parameter
    missing_keys = {name for name, _ in model.named_parameters()}
    matched_state_dict = {}
    unexpected_keys = set()
    for key, data in state_dict.items():
        if key in missing_keys:
            matched_state_dict[key] = data
            missing_keys.remove(key)
        else:
            unexpected_keys.add(key)
    LOGGER.info(f"Unexpected_keys: {list(unexpected_keys)}")
    LOGGER.info(f"Missing_keys: {list(missing_keys)}")
    model.load_state_dict(matched_state_dict, strict=False)
    # the weights have been copied into the model; release the checkpoint
    del checkpoint, state_dict, matched_state_dict
    model.to(device)
    if opts.fp16:
        model = amp.initialize(model, enabled=True, opt_level='O2')

    eval_dataloader = DataLoader(eval_dataset,
                                 batch_size=opts.batch_size,
                                 num_workers=opts.n_workers,
                                 pin_memory=opts.pin_mem,
                                 shuffle=False,
                                 collate_fn=vcr_eval_collate)
    eval_dataloader = PrefetchLoader(eval_dataloader)

    _, results = evaluate(model, eval_dataloader)
    result_dir = f'{opts.output_dir}/results_{opts.split}'
    if rank == 0:
        os.makedirs(result_dir, exist_ok=True)

    # merge the per-rank result dicts gathered from every process
    all_results = {}
    for id2res in all_gather_list(results):
        all_results.update(id2res)
    if rank == 0:
        result_file = f'{result_dir}/results_{opts.checkpoint}_all.json'
        with open(result_file, 'w') as f:
            json.dump(all_results, f)
        probs_df = save_for_submission(result_file)
        probs_df.to_csv(f'{result_dir}/results_{opts.checkpoint}_all.csv')
Example #2
0
def main(opts):
    """Evaluate a UNITER VCR checkpoint on the "val" split and dump results.

    Reads the hyper-parameters (``hps.json``) and model config
    (``model.json``) saved under ``{opts.output_dir}/log`` and restores weights
    from ``opts.checkpoint``. That value is either an existing file path or,
    otherwise, a step number resolved to
    ``{opts.output_dir}/ckpt/model_step_{step}.pt``. The function then runs
    ``evaluate`` and writes each result item as one JSON line to a fixed
    output file.

    Args:
        opts: parsed options; this function reads ``output_dir``, ``split``,
            ``img_db``, ``txt_db``, ``checkpoint``, ``fp16``, ``batch_size``,
            ``n_workers`` and ``pin_mem``.
    """
    hvd.init()
    n_gpu = hvd.size()
    device = torch.device("cuda", hvd.local_rank())
    torch.cuda.set_device(hvd.local_rank())
    rank = hvd.rank()
    LOGGER.info("device: {} n_gpu: {}, rank: {}, "
                "16-bits training: {}".format(device, n_gpu, rank,
                                              opts.fp16))
    if rank != 0:
        LOGGER.disabled = True

    hps_file = f'{opts.output_dir}/log/hps.json'
    # use a context manager so the file handle is closed after parsing
    with open(hps_file) as hps_f:
        model_opts = Struct(json.load(hps_f))

    # sanity check: both DB paths must mention the requested split
    assert opts.split in opts.img_db and opts.split in opts.txt_db
    # load DBs and image dirs
    eval_img_db, eval_img_db_gt = load_img_feat(opts.img_db, model_opts)
    eval_txt_db = VcrTxtTokLmdb(opts.txt_db, -1)
    eval_dataset = VcrEvalDataset("val",
                                  eval_txt_db,
                                  img_db=eval_img_db,
                                  img_db_gt=eval_img_db_gt)

    # Prepare model
    model = UniterForVisualCommonsenseReasoning.from_pretrained(
        f'{opts.output_dir}/log/model.json', state_dict={}, img_dim=IMG_DIM)
    model.init_type_embedding()
    model.init_type_embedding_know()
    model.init_word_embedding(NUM_SPECIAL_TOKENS)
    if exists(opts.checkpoint):
        ckpt_file = opts.checkpoint
    else:
        ckpt_file = f'{opts.output_dir}/ckpt/model_step_{opts.checkpoint}.pt'
    # Deserialize on CPU: without map_location every rank would restore the
    # tensors onto the device they were saved from. The model is still on
    # CPU here and is moved to `device` below.
    checkpoint = torch.load(ckpt_file, map_location='cpu')
    state_dict = checkpoint.get('model_state', checkpoint)
    # keep only checkpoint entries whose names match a model parameter
    missing_keys = {name for name, _ in model.named_parameters()}
    matched_state_dict = {}
    unexpected_keys = set()
    for key, data in state_dict.items():
        if key in missing_keys:
            matched_state_dict[key] = data
            missing_keys.remove(key)
        else:
            unexpected_keys.add(key)
    LOGGER.info(f"Unexpected_keys: {list(unexpected_keys)}")
    LOGGER.info(f"Missing_keys: {list(missing_keys)}")
    model.load_state_dict(matched_state_dict, strict=False)
    # the weights have been copied into the model; release the checkpoint
    del checkpoint, state_dict, matched_state_dict
    model.to(device)
    if opts.fp16:
        model = amp.initialize(model, enabled=True, opt_level='O2')

    eval_dataloader = DataLoader(eval_dataset,
                                 batch_size=opts.batch_size,
                                 num_workers=opts.n_workers,
                                 pin_memory=opts.pin_mem,
                                 shuffle=False,
                                 collate_fn=vcr_eval_collate)
    eval_dataloader = PrefetchLoader(eval_dataloader)

    results = evaluate(model, eval_dataloader)

    # NOTE(review): the output path is hard-coded, and every rank writes the
    # same file, so with several ranks the last writer wins. Confirm that this
    # is intended.
    output = '/src/vlkaf.json'
    # one JSON document per line; the file is closed even if dumping fails
    with open(output, "w") as f:
        f.writelines(json.dumps(item) + '\n' for item in results)
    '''
Example #3
0
def main(opts):
    """Fine-tune UNITER on VCR (Q->A and QA->R) with Horovod and apex AMP.

    Builds one Q->A and one QA->R training dataset per train text DB. The
    model is initialised from ``opts.model_config`` and optionally from
    ``opts.checkpoint``, according to ``opts.checkpoint_from``. Training runs
    for ``opts.num_train_steps`` optimizer steps with gradient accumulation.
    Every ``opts.valid_steps`` steps the model is validated and a checkpoint
    is saved. At the end, each rank dumps its predictions on the final
    validation set to
    ``{output_dir}/results/results_{step}_final_qa_qar_rank{rank}.json``.

    Args:
        opts: parsed options (attribute access). Data paths, batch and step
            settings, ``checkpoint``, ``checkpoint_from``, ``fp16`` and
            ``output_dir`` are read here.

    Raises:
        ValueError: if ``opts.gradient_accumulation_steps`` < 1, or if the
            optimizer has more than four parameter groups.
    """
    hvd.init()
    n_gpu = hvd.size()
    device = torch.device("cuda", hvd.local_rank())
    torch.cuda.set_device(hvd.local_rank())
    rank = hvd.rank()
    opts.rank = rank
    LOGGER.info("device: {} n_gpu: {}, rank: {}, "
                "16-bits training: {}".format(
                    device, n_gpu, rank, opts.fp16))

    if opts.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, "
                         "should be >= 1".format(
                            opts.gradient_accumulation_steps))

    set_random_seed(opts.seed)

    # load DBs and image dirs
    all_img_dbs = ImageLmdbGroup(opts.conf_th, opts.max_bb, opts.min_bb,
                                 opts.num_bb, opts.compressed_db)
    # train: every text DB contributes a Q->A and a QA->R dataset that share
    # the same image features
    LOGGER.info(f"Loading Train Dataset "
                f"{opts.train_txt_dbs}, {opts.train_img_dbs}")
    train_datasets = []
    for txt_path, img_path in zip(opts.train_txt_dbs, opts.train_img_dbs):
        img_db, img_db_gt = load_img_feat(img_path, all_img_dbs, opts)
        qa_txt_db = VcrTxtTokLmdb(txt_path, opts.max_txt_len, task="qa")
        qar_txt_db = VcrTxtTokLmdb(txt_path, opts.max_txt_len, task="qar")
        train_datasets.append(
            VcrDataset(qa_txt_db, img_db_gt=img_db_gt, img_db=img_db))
        train_datasets.append(
            VcrDataset(qar_txt_db, img_db_gt=img_db_gt, img_db=img_db))
    train_dataset = ConcatDatasetWithLens(train_datasets)
    train_dataloader = build_dataloader(train_dataset, vcr_collate, True, opts)
    # val
    LOGGER.info(f"Loading Val Dataset {opts.val_txt_db}, {opts.val_img_db}")
    val_img_db, val_img_db_gt = load_img_feat(opts.val_img_db, all_img_dbs, opts)
    val_txt_db = VcrTxtTokLmdb(opts.val_txt_db, -1, task="qa")
    val_dataset = VcrEvalDataset(
        "val", val_txt_db, img_db=val_img_db, img_db_gt=val_img_db_gt)
    # NOTE: the "final" evaluation set also uses the "val" split. A "test"
    # split was previously used here and left commented out.
    val_final_dataset = VcrEvalDataset(
        "val", val_txt_db, img_db=val_img_db, img_db_gt=val_img_db_gt)
    val_dataloader = build_dataloader(val_dataset, vcr_eval_collate,
                                      False, opts)
    val_final_dataloader = build_dataloader(
        val_final_dataset, vcr_eval_collate,
        False, opts)

    # Prepare model
    if opts.checkpoint and opts.checkpoint_from == "pretrain":
        # Deserialize on CPU: without map_location every rank would restore
        # the tensors onto the device they were saved from. The model is
        # moved to `device` below.
        checkpoint = torch.load(opts.checkpoint, map_location='cpu')
    else:
        checkpoint = {}

    def _bert_toker(db):
        """Return the 'bert' entry of ``{db}/meta.json``, closing the file."""
        with open(f'{db}/meta.json') as meta_f:
            return json.load(meta_f)['bert']

    # all text DBs must have been tokenized with the same tokenizer
    all_dbs = opts.train_txt_dbs + [opts.val_txt_db]
    toker = _bert_toker(all_dbs[0])
    assert all(toker == _bert_toker(db) for db in all_dbs)
    model = UniterForVisualCommonsenseReasoning.from_pretrained(
        opts.model_config, checkpoint, img_dim=IMG_DIM)
    model.init_type_embedding()
    model.init_type_embedding_know()
    model.init_word_embedding(NUM_SPECIAL_TOKENS)
    if opts.checkpoint_from == "vcr_pretrain":
        checkpoint = torch.load(opts.checkpoint, map_location='cpu')
        state_dict = checkpoint.get('model_state', checkpoint)
        # keep only checkpoint entries whose names match a model parameter
        missing_keys = {name for name, _ in model.named_parameters()}
        matched_state_dict = {}
        unexpected_keys = set()
        for key, data in state_dict.items():
            if key in missing_keys:
                matched_state_dict[key] = data
                missing_keys.remove(key)
            else:
                unexpected_keys.add(key)
        LOGGER.info(f"Unexpected_keys: {list(unexpected_keys)}")
        LOGGER.info(f"Missing_keys: {list(missing_keys)}")
        model.load_state_dict(matched_state_dict, strict=False)
        # Drop these extra references. Otherwise `del checkpoint` below would
        # not release the loaded tensors for the rest of training.
        del state_dict, matched_state_dict
    del checkpoint
    model.to(device)
    # make sure every process has same model parameters in the beginning
    broadcast_tensors([p.data for p in model.parameters()], 0)
    set_dropout(model, opts.dropout)

    # Prepare optimizer
    optimizer = build_optimizer(model, opts)
    model, optimizer = amp.initialize(model, optimizer,
                                      enabled=opts.fp16, opt_level='O2')
    global_step = 0
    if rank == 0:
        save_training_meta(opts)
        TB_LOGGER.create(join(opts.output_dir, 'log'))
        pbar = tqdm(total=opts.num_train_steps)
        model_saver = ModelSaver(join(opts.output_dir, 'ckpt'))
        os.makedirs(join(opts.output_dir, 'results'))  # store VQA predictions
        add_log_to_file(join(opts.output_dir, 'log', 'log.txt'))
    else:
        LOGGER.disabled = True
        pbar = NoOp()
        model_saver = NoOp()

    LOGGER.info(f"***** Running training with {n_gpu} GPUs *****")
    LOGGER.info("  Num examples = %d", len(train_dataset) * hvd.size())
    LOGGER.info("  Batch size = %d", opts.train_batch_size)
    LOGGER.info("  Accumulate steps = %d", opts.gradient_accumulation_steps)
    LOGGER.info("  Num steps = %d", opts.num_train_steps)

    running_loss = RunningMeter('loss')
    model.train()
    n_examples = 0
    n_epoch = 0
    start = time()
    # quick hack for amp delay_unscale bug
    optimizer.zero_grad()
    optimizer.step()
    while True:
        for step, batch in enumerate(train_dataloader):
            n_examples += batch['input_ids'].size(0)

            loss = model(batch, compute_loss=True)
            loss = loss.mean()
            # only unscale/all-reduce on the last micro-batch of an
            # accumulation window
            delay_unscale = (step+1) % opts.gradient_accumulation_steps != 0
            with amp.scale_loss(loss, optimizer, delay_unscale=delay_unscale
                                ) as scaled_loss:
                scaled_loss.backward()
                if not delay_unscale:
                    # gather gradients from every processes
                    # do this before unscaling to make sure every process uses
                    # the same gradient scale
                    grads = [p.grad.data for p in model.parameters()
                             if p.requires_grad and p.grad is not None]
                    all_reduce_and_rescale_tensors(grads, float(1))

            running_loss(loss.item())

            if (step + 1) % opts.gradient_accumulation_steps == 0:
                global_step += 1

                # learning rate scheduling: groups 0/1 use the lr multiplier,
                # groups 2/3 use the base schedule
                lr_this_step = get_lr_sched(global_step, opts)
                for i, param_group in enumerate(optimizer.param_groups):
                    if i == 0 or i == 1:
                        param_group['lr'] = lr_this_step * opts.lr_mul
                    elif i == 2 or i == 3:
                        param_group['lr'] = lr_this_step
                    else:
                        raise ValueError(
                            f"unexpected optimizer param group index {i}; "
                            "expected at most 4 groups")
                TB_LOGGER.add_scalar('lr', lr_this_step, global_step)

                # log loss
                # NOTE: not gathered across GPUs for efficiency
                TB_LOGGER.add_scalar('loss', running_loss.val, global_step)
                TB_LOGGER.step()

                # update model params
                if opts.grad_norm != -1:
                    grad_norm = clip_grad_norm_(amp.master_params(optimizer),
                                                opts.grad_norm)
                    TB_LOGGER.add_scalar('grad_norm', grad_norm, global_step)
                optimizer.step()
                optimizer.zero_grad()
                pbar.update(1)

                if global_step % 100 == 0:
                    # monitor training throughput
                    LOGGER.info(f'============Step {global_step}=============')
                    tot_ex = sum(all_gather_list(n_examples))
                    ex_per_sec = int(tot_ex / (time()-start))
                    LOGGER.info(f'{tot_ex} examples trained at '
                                f'{ex_per_sec} ex/s')
                    TB_LOGGER.add_scalar('perf/ex_per_s',
                                         ex_per_sec, global_step)
                    LOGGER.info(f'===========================================')

                if global_step % opts.valid_steps == 0:
                    val_log, results = validate(
                        model, val_dataloader)
                    TB_LOGGER.log_scaler_dict(val_log)
                    model_saver.save(model, global_step)
            if global_step >= opts.num_train_steps:
                break
        if global_step >= opts.num_train_steps:
            break
        n_epoch += 1
        LOGGER.info(f"finished {n_epoch} epochs")
    # validate once more unless the last step already triggered validation
    if global_step % opts.valid_steps != 0:
        val_log, results = validate(
            model, val_dataloader)
        TB_LOGGER.log_scaler_dict(val_log)
    val_log, results = validate(model, val_final_dataloader)
    # each rank writes its own prediction file
    with open(f'{opts.output_dir}/results/'
              f'results_{global_step}_final_qa_qar_'
              f'rank{rank}.json', 'w') as f:
        json.dump(results, f)
    TB_LOGGER.log_scaler_dict(val_log)
    model_saver.save(model, global_step)
Example #4
0
def main(opts):
    hvd.init()
    n_gpu = hvd.size()
    device = torch.device("cuda", hvd.local_rank())
    torch.cuda.set_device(hvd.local_rank())
    rank = hvd.rank()
    opts.rank = rank
    LOGGER.info("device: {} n_gpu: {}, rank: {}, "
                "16-bits training: {}".format(device, n_gpu, hvd.rank(),
                                              opts.fp16))

    if opts.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, "
                         "should be >= 1".format(
                             opts.gradient_accumulation_steps))

    set_random_seed(opts.seed)

    # load DBs and image dirs
    all_img_dbs = ImageLmdbGroup(opts.conf_th, opts.max_bb, opts.min_bb,
                                 opts.num_bb, opts.compressed_db)
    # train
    LOGGER.info(f"Loading Train Dataset "
                f"{opts.train_txt_dbs}, {opts.train_img_dbs}")
    train_datasets = []
    for txt_path, img_path in zip(opts.train_txt_dbs, opts.train_img_dbs):
        img_db, img_db_gt = load_img_feat(img_path, all_img_dbs, opts)
        qa_txt_db = VcrTxtTokLmdb(txt_path, opts.max_txt_len, task="qa")
        qar_txt_db = VcrTxtTokLmdb(txt_path, opts.max_txt_len, task="qar")
        train_datasets.append(
            VcrDataset(qa_txt_db, img_db_gt=img_db_gt, img_db=img_db))
        train_datasets.append(
            VcrDataset(qar_txt_db, img_db_gt=img_db_gt, img_db=img_db))
    train_dataset = ConcatDatasetWithLens(train_datasets)
    train_dataloader = build_dataloader(train_dataset, vcr_collate, True, opts)
    # val
    LOGGER.info(f"Loading Val Dataset {opts.val_txt_db}, {opts.val_img_db}")
    val_img_db, val_img_db_gt = load_img_feat(opts.val_img_db, all_img_dbs,
                                              opts)
    val_txt_db = VcrTxtTokLmdb(opts.val_txt_db, -1)
    val_dataset = VcrEvalDataset("val",
                                 val_txt_db,
                                 img_db=val_img_db,
                                 img_db_gt=val_img_db_gt)
    val_final_dataset = VcrEvalDataset("test",
                                       val_txt_db,
                                       img_db=val_img_db,
                                       img_db_gt=val_img_db_gt)
    val_dataloader = build_dataloader(val_dataset, vcr_eval_collate, False,
                                      opts)
    val_final_dataloader = build_dataloader(val_final_dataset,
                                            vcr_eval_collate, False, opts)

    # Prepare model
    if opts.checkpoint and opts.checkpoint_from == "pretrain":
        ckpt = torch.load(opts.checkpoint)
        checkpoint = {k.replace('bert', 'uniter'): v for k, v in ckpt.items()}
    else:
        checkpoint = {}

    all_dbs = opts.train_txt_dbs + [opts.val_txt_db]
    toker = json.load(open(f'{all_dbs[0]}/meta.json'))['bert']
    assert all(toker == json.load(open(f'{db}/meta.json'))['bert']
               for db in all_dbs)
    model = UniterForVisualCommonsenseReasoning.from_pretrained(
        opts.model_config, checkpoint, img_dim=IMG_DIM)
    model.init_type_embedding()
    model.init_word_embedding(NUM_SPECIAL_TOKENS)
    if opts.checkpoint_from == "vcr_pretrain":
        ckpt = torch.load(opts.checkpoint)
        checkpoint = {k.replace('bert', 'uniter'): v for k, v in ckpt.items()}
        state_dict = checkpoint.get('model_state', checkpoint)
        matched_state_dict = {}
        unexpected_keys = set()
        missing_keys = set()
        for name, param in model.named_parameters():
            missing_keys.add(name)
        for key, data in state_dict.items():
            if key in missing_keys:
                matched_state_dict[key] = data
                missing_keys.remove(key)
            else:
                unexpected_keys.add(key)
        print("Unexpected_keys:", list(unexpected_keys))
        print("Missing_keys:", list(missing_keys))
        model.load_state_dict(matched_state_dict, strict=False)
    del checkpoint
    model.to(device)
    # make sure every process has same model parameters in the beginning
    broadcast_tensors([p.data for p in model.parameters()], 0)
    set_dropout(model, opts.dropout)

    # Prepare optimizer
    optimizer = build_optimizer(model, opts)
    model, optimizer = amp.initialize(model,
                                      optimizer,
                                      enabled=opts.fp16,
                                      opt_level='O2')
    global_step = 0
    if rank == 0:
        save_training_meta(opts)
        TB_LOGGER.create(join(opts.output_dir, 'log'))
        pbar = tqdm(total=opts.num_train_steps)
        model_saver = ModelSaver(join(opts.output_dir, 'ckpt'))
        os.makedirs(join(opts.output_dir, 'results'))  # store VQA predictions
        add_log_to_file(join(opts.output_dir, 'log', 'log.txt'))
    else:
        LOGGER.disabled = True
        pbar = NoOp()
        model_saver = NoOp()

    LOGGER.info(f"***** Running training with {n_gpu} GPUs *****")
    LOGGER.info("  Num examples = %d", len(train_dataset) * hvd.size())
    LOGGER.info("  Batch size = %d", opts.train_batch_size)
    LOGGER.info("  Accumulate steps = %d", opts.gradient_accumulation_steps)
    LOGGER.info("  Num steps = %d", opts.num_train_steps)

    running_loss = RunningMeter('loss')
    model.train()
    n_examples = 0
    n_epoch = 0
    start = time()
    # quick hack for amp delay_unscale bug
    optimizer.zero_grad()
    optimizer.step()
    while True:
        for step, batch in enumerate(train_dataloader):
            n_examples += batch['input_ids'].size(0)

            # ============= Code for adversarial training =============
            if opts.adv_training:
                # initialize delta
                txt_embeds_init = model.uniter.embeddings.word_embeddings(
                    batch['input_ids'])
                img_embeds_init = batch['img_feat']

                # for simplicity, we initialize the delta as zero vectors, which performs
                # very simliar as initializing randomly using norm or uniform distributions
                txt_delta = torch.zeros_like(txt_embeds_init)
                img_delta = torch.zeros_like(img_embeds_init)

                # calculate the prob. scores for clean samples
                gt_answer_scores = model(batch, compute_loss=False)
                gt_answer_prob = F.softmax(gt_answer_scores, dim=1)
                gt_answer_logprob = F.log_softmax(gt_answer_scores, dim=1)

                # the main loop
                for astep in range(opts.adv_steps):
                    # (0) forward
                    if opts.adv_modality == ["text"]:
                        txt_delta.requires_grad_()
                        img_delta = torch.zeros_like(img_embeds_init)
                    elif opts.adv_modality == ["image"]:
                        img_delta.requires_grad_()
                        txt_delta = torch.zeros_like(txt_embeds_init)
                    else:
                        txt_delta.requires_grad_()
                        img_delta.requires_grad_()

                    if "alter" not in opts.adv_modality:
                        answer_scores = model(batch,
                                              adv_training=True,
                                              adv_modality=opts.adv_modality,
                                              adv_delta_txt=txt_delta,
                                              adv_delta_img=img_delta,
                                              compute_loss=False)

                        # CE loss
                        ce_loss = F.cross_entropy(answer_scores,
                                                  batch['targets'].squeeze(-1),
                                                  reduction='mean')

                        # KL loss
                        answer_prob = F.softmax(answer_scores, dim=1)
                        answer_logprob = F.log_softmax(answer_scores, dim=1)
                        kl_loss = F.kl_div(
                            answer_logprob, gt_answer_prob, reduction='none') + \
                            F.kl_div(
                                gt_answer_logprob, answer_prob,
                                reduction='none')
                        kl_loss = kl_loss.mean()

                        # (1) backward
                        loss = (ce_loss +
                                opts.adv_kl_weight * kl_loss) / opts.adv_steps
                    else:
                        answer_scores_1 = model(batch,
                                                adv_training=True,
                                                adv_modality=["text"],
                                                adv_delta_txt=txt_delta,
                                                adv_delta_img=None,
                                                compute_loss=False)

                        # CE loss
                        ce_loss_1 = F.cross_entropy(
                            answer_scores,
                            batch['targets'].squeeze(-1),
                            reduction='mean')

                        answer_scores_2 = model(batch,
                                                adv_training=True,
                                                adv_modality=["image"],
                                                adv_delta_txt=None,
                                                adv_delta_img=img_delta,
                                                compute_loss=False)

                        # CE loss
                        ce_loss_2 = F.cross_entropy(
                            answer_scores,
                            batch['targets'].squeeze(-1),
                            reduction='mean')

                        # KL loss
                        answer_prob_1 = F.softmax(answer_scores_1, dim=1)
                        answer_logprob_1 = F.log_softmax(answer_scores_1,
                                                         dim=1)
                        answer_prob_2 = F.softmax(answer_scores_2, dim=1)
                        answer_logprob_2 = F.log_softmax(answer_scores_2,
                                                         dim=1)
                        kl_loss_1 = F.kl_div(
                            answer_logprob_1, gt_answer_prob, reduction='none') + \
                            F.kl_div(
                                gt_answer_logprob, answer_prob_1,
                                reduction='none')
                        kl_loss_1 = kl_loss_1.mean()
                        kl_loss_2 = F.kl_div(
                            answer_logprob_2, gt_answer_prob, reduction='none') + \
                            F.kl_div(
                                gt_answer_logprob, answer_prob_2,
                                reduction='none')
                        kl_loss_2 = kl_loss_2.mean()

                        # (1) backward
                        loss = (ce_loss_1 + ce_loss_2 + opts.adv_kl_weight *
                                (kl_loss_1 + kl_loss_2)) / (opts.adv_steps * 2)

                    delay_unscale = (
                        (step + 1) % opts.gradient_accumulation_steps !=
                        0) or ((astep + 1) % opts.adv_steps != 0)
                    with amp.scale_loss(
                            loss, optimizer,
                            delay_unscale=delay_unscale) as scaled_loss:
                        scaled_loss.backward(retain_graph=True)
                        if not delay_unscale:
                            # gather gradients from every processes
                            # do this before unscaling
                            # to make sure every process uses
                            # the same gradient scale
                            grads = [
                                p.grad.data for p in model.parameters()
                                if p.requires_grad and p.grad is not None
                            ]
                            all_reduce_and_rescale_tensors(grads, float(1))

                    running_loss(loss.item())

                    if astep == opts.adv_steps - 1:
                        # further updates on delta
                        break

                    # (2) get gradient on delta
                    # fix fp16 problem
                    amp_scale = scaled_loss.item() // loss.item()
                    if "text" in opts.adv_modality:
                        txt_delta_grad = txt_delta.grad.clone().detach()
                        txt_delta_grad = txt_delta_grad.float() / amp_scale
                    if "image" in opts.adv_modality:
                        img_delta_grad = img_delta.grad.clone().detach()
                        img_delta_grad = img_delta_grad.float() / amp_scale

                    # (3) update and clip for txt delta
                    if "text" in opts.adv_modality:
                        if opts.norm_type == "l2":
                            denorm = torch.norm(txt_delta_grad.view(
                                txt_delta_grad.size(0), -1),
                                                dim=1).view(-1, 1, 1)
                            denorm = torch.clamp(denorm, min=1e-8)
                            txt_delta_step = (opts.adv_lr_txt *
                                              txt_delta_grad /
                                              denorm).to(txt_delta)
                            txt_delta = (txt_delta + txt_delta_step).detach()
                            if opts.adv_max_norm > 0:
                                delta_norm = torch.norm(txt_delta.view(
                                    txt_delta.size(0), -1),
                                                        p=2,
                                                        dim=1).detach()
                                exceed_mask = (delta_norm > opts.adv_max_norm
                                               ).to(txt_embeds_init)
                                reweights = (opts.adv_max_norm / delta_norm *
                                             exceed_mask +
                                             (1 - exceed_mask)).view(-1, 1, 1)
                                txt_delta = (txt_delta * reweights).detach()
                        elif opts.norm_type == "linf":
                            denorm = torch.norm(txt_delta_grad.view(
                                txt_delta_grad.size(0), -1),
                                                dim=1,
                                                p=float("inf")).view(-1, 1, 1)
                            denorm = torch.clamp(denorm, min=1e-8)
                            txt_delta_step = (opts.adv_lr_txt *
                                              txt_delta_grad /
                                              denorm).to(txt_delta)
                            txt_delta = (txt_delta + txt_delta_step).detach()
                            if opts.adv_max_norm > 0:
                                txt_delta = torch.clamp(
                                    txt_delta, -opts.adv_max_norm,
                                    opts.adv_max_norm).detach()

                    # (4) update and clip for image delta
                    if "image" in opts.adv_modality:
                        if opts.norm_type == "l2":
                            denorm = torch.norm(img_delta_grad.view(
                                img_delta_grad.size(0), -1),
                                                dim=1).view(-1, 1, 1)
                            denorm = torch.clamp(denorm, min=1e-8)
                            img_delta_step = (opts.adv_lr_img *
                                              img_delta_grad /
                                              denorm).to(img_delta)
                            img_delta = (img_delta + img_delta_step).detach()
                            if opts.adv_max_norm > 0:
                                delta_norm = torch.norm(img_delta.view(
                                    img_delta.size(0), -1),
                                                        p=2,
                                                        dim=1).detach()
                                exceed_mask = (delta_norm > opts.adv_max_norm
                                               ).to(img_embeds_init)
                                reweights = (opts.adv_max_norm / delta_norm *
                                             exceed_mask +
                                             (1 - exceed_mask)).view(-1, 1, 1)
                                img_delta = (img_delta * reweights).detach()
                        elif opts.norm_type == "linf":
                            denorm = torch.norm(img_delta_grad.view(
                                img_delta_grad.size(0), -1),
                                                dim=1,
                                                p=float("inf")).view(-1, 1, 1)
                            denorm = torch.clamp(denorm, min=1e-8)
                            img_delta_step = (opts.adv_lr_img *
                                              img_delta_grad /
                                              denorm).to(img_delta)
                            img_delta = (img_delta + img_delta_step).detach()
                            if opts.adv_max_norm > 0:
                                img_delta = torch.clamp(
                                    img_delta, -opts.adv_max_norm,
                                    opts.adv_max_norm).detach()

            else:
                loss = model(batch, compute_loss=True)
                loss = loss.mean()
                delay_unscale = ((step + 1) % opts.gradient_accumulation_steps
                                 != 0)
                with amp.scale_loss(
                        loss, optimizer,
                        delay_unscale=delay_unscale) as scaled_loss:
                    scaled_loss.backward()
                    if not delay_unscale:
                        # gather gradients from every processes
                        # do this before unscaling to make sure every process uses
                        # the same gradient scale
                        grads = [
                            p.grad.data for p in model.parameters()
                            if p.requires_grad and p.grad is not None
                        ]
                        all_reduce_and_rescale_tensors(grads, float(1))

                running_loss(loss.item())

            # ============================ End ==========================

            if (step + 1) % opts.gradient_accumulation_steps == 0:
                global_step += 1

                # learning rate scheduling
                lr_this_step = get_lr_sched(global_step, opts)
                for i, param_group in enumerate(optimizer.param_groups):
                    if i == 0 or i == 1:
                        param_group['lr'] = lr_this_step * opts.lr_mul
                    elif i == 2 or i == 3:
                        param_group['lr'] = lr_this_step
                    else:
                        raise ValueError()
                TB_LOGGER.add_scalar('lr', lr_this_step, global_step)

                # log loss
                # NOTE: not gathered across GPUs for efficiency
                TB_LOGGER.add_scalar('loss', running_loss.val, global_step)
                TB_LOGGER.step()

                # update model params
                if opts.grad_norm != -1:
                    grad_norm = clip_grad_norm_(amp.master_params(optimizer),
                                                opts.grad_norm)
                    TB_LOGGER.add_scalar('grad_norm', grad_norm, global_step)
                optimizer.step()
                optimizer.zero_grad()
                pbar.update(1)

                if global_step % 100 == 0:
                    # monitor training throughput
                    LOGGER.info(f'============Step {global_step}=============')
                    tot_ex = sum(all_gather_list(n_examples))
                    ex_per_sec = int(tot_ex / (time() - start))
                    LOGGER.info(f'{tot_ex} examples trained at '
                                f'{ex_per_sec} ex/s')
                    TB_LOGGER.add_scalar('perf/ex_per_s', ex_per_sec,
                                         global_step)
                    LOGGER.info('===========================================')

                if global_step % opts.valid_steps == 0:
                    val_log, results = validate(model, val_dataloader)
                    TB_LOGGER.log_scaler_dict(val_log)
                    model_saver.save(model, global_step)
            if global_step >= opts.num_train_steps:
                break
        if global_step >= opts.num_train_steps:
            break
        n_epoch += 1
        LOGGER.info(f"finished {n_epoch} epochs")
    if global_step % opts.valid_steps != 0:
        val_log, results = validate(model, val_dataloader)
        TB_LOGGER.log_scaler_dict(val_log)
    val_log, results = validate(model, val_final_dataloader)
    with open(
            f'{opts.output_dir}/results/'
            f'results_{global_step}_final_qa_qar_'
            f'rank{rank}.json', 'w') as f:
        json.dump(results, f)
    TB_LOGGER.log_scaler_dict(val_log)
    model_saver.save(model, global_step)