Example #1
File: itm.py Project: zmykevin/UC2
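A typical multi-GPU launch for this Horovod-based script might look like the following; the flag and config path are illustrative, not taken from the project:
horovodrun -np 8 python itm.py --config config/itm_flickr30k_ft.json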
def main(opts):
    hvd.init()
    n_gpu = hvd.size()
    device = torch.device("cuda", hvd.local_rank())
    torch.cuda.set_device(hvd.local_rank())
    rank = hvd.rank()
    opts.rank = rank
    LOGGER.info("device: {} n_gpu: {}, rank: {}, "
                "16-bits training: {}".format(device, n_gpu, hvd.rank(),
                                              opts.fp16))

    if opts.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, "
                         "should be >= 1".format(
                             opts.gradient_accumulation_steps))

    set_random_seed(opts.seed)

    if hvd.rank() == 0:
        save_training_meta(opts)
        TB_LOGGER.create(join(opts.output_dir, 'log'))
        pbar = tqdm(total=opts.num_train_steps)
        model_saver = ModelSaver(join(opts.output_dir, 'ckpt'))
        os.makedirs(join(opts.output_dir, 'ckpt'), exist_ok=True)

        add_log_to_file(join(opts.output_dir, 'log', 'log.txt'))
        # store ITM predictions
        os.makedirs(join(opts.output_dir, 'results_val'), exist_ok=True)
        os.makedirs(join(opts.output_dir, 'results_test'), exist_ok=True)
        os.makedirs(join(opts.output_dir, 'results_train'), exist_ok=True)
    else:
        LOGGER.disabled = True
        pbar = NoOp()
        model_saver = NoOp()

    # train_examples = None
    LOGGER.info(f"Loading Train Dataset {opts.train_txt_dbs}, "
                f"{opts.train_img_dbs}")
    # check multiple DBs
    assert len(opts.train_txt_dbs) == len(opts.train_img_dbs), \
        "train txt_db and img_db have different length"

    # load DBs and image dirs
    all_img_dbs = ImageLmdbGroup(opts.conf_th, opts.max_bb, opts.min_bb,
                                 opts.num_bb, opts.compressed_db)
    # train
    LOGGER.info(f"Loading Train Dataset "
                f"{opts.train_txt_dbs}, {opts.train_img_dbs}")
    train_datasets = []
    for txt_path, img_path in zip(opts.train_txt_dbs, opts.train_img_dbs):
        if "itm_coco_zh" not in txt_path:
            img_db = all_img_dbs[img_path]
            txt_db = TxtTokLmdb(txt_path, opts.max_txt_len)
            if opts.hard_neg_size > 0:
                train_datasets.append(
                    ItmRankDatasetHardNeg(txt_db, img_db, opts.negative_size,
                                          opts.hard_neg_size))
            else:
                train_datasets.append(
                    ItmRankDataset(txt_db, img_db, opts.negative_size))
        else:
            img_train_db = all_img_dbs[img_path[0]]
            img_val_db = all_img_dbs[img_path[1]]
            txt_db = TxtTokLmdb(txt_path, opts.max_txt_len)
            if opts.hard_neg_size > 0:
                # assumption: hard negatives for COCO-CN are mined from the
                # train-split image DB
                train_datasets.append(
                    ItmRankDatasetHardNeg(txt_db, img_train_db,
                                          opts.negative_size,
                                          opts.hard_neg_size))
            else:
                train_datasets.append(
                    ItmRankDataset_COCO_CN(txt_db, img_train_db, img_val_db,
                                           opts.negative_size))
    train_dataset = ConcatDataset(train_datasets)

    # hard negative
    # hn_datasets = []
    # for txt_path, img_path in zip(opts.train_txt_dbs, opts.train_img_dbs):
    #     img_db = all_img_dbs[img_path]
    #     txt_db = TxtTokLmdb(txt_path, opts.max_txt_len)
    #     hn_datasets.append(ItmHardNegDataset(txt_db, img_db,
    #                                          opts.inf_minibatch_size))
    # hn_dataset = ConcatDataset(hn_datasets)
    # hn_dataloader = build_dataloader(hn_dataset, itm_hn_collate, False, opts)
    # hard_neg_dir = f'{opts.output_dir}/results_train/'

    # val
    LOGGER.info(f"Loading Val Dataset {opts.val_txt_db}, {opts.val_img_db}")
    val_img_db = all_img_dbs[opts.val_img_db[0]]
    val_txt_db = TxtTokLmdb(opts.val_txt_db[0], -1)
    val_dataset = ItmValDataset(val_txt_db, val_img_db,
                                opts.inf_minibatch_size)
    val_dataloader = build_dataloader(val_dataset, itm_val_collate, False,
                                      opts)
    # eval
    LOGGER.info(f"Loading val, test Dataset for full evaluation: "
                f"{opts.val_txt_db}, {opts.val_img_db}"
                f"{opts.test_txt_db}, {opts.test_img_db}")
    eval_dataset_val = ItmEvalDataset(val_txt_db, val_img_db,
                                      opts.inf_minibatch_size)
    eval_loader_val = build_dataloader(eval_dataset_val, itm_eval_collate,
                                       False, opts)

    eval_loader_list = []
    assert len(opts.test_img_db) == len(opts.test_txt_db)
    for test_img_db_path, test_txt_db_path in zip(opts.test_img_db,
                                                  opts.test_txt_db):
        if "itm_coco_zh" not in test_txt_db_path:
            test_img_db = all_img_dbs[test_img_db_path]
            test_txt_db = TxtTokLmdb(test_txt_db_path, -1)
            eval_dataset_test = ItmEvalDataset(test_txt_db, test_img_db,
                                               opts.inf_minibatch_size)
        else:
            test_img_train_db = all_img_dbs[test_img_db_path[0]]
            test_img_val_db = all_img_dbs[test_img_db_path[1]]
            test_txt_db = TxtTokLmdb(test_txt_db_path, -1)
            eval_dataset_test = ItmEvalDataset_COCO_CN(test_txt_db,
                                                       test_img_train_db,
                                                       test_img_val_db,
                                                       opts.inf_minibatch_size)
        eval_loader_test = build_dataloader(eval_dataset_test,
                                            itm_eval_collate, False, opts)
        eval_loader_list.append(eval_loader_test)

    # Prepare model
    if opts.checkpoint:
        checkpoint = torch.load(opts.checkpoint)
    else:
        checkpoint = {}

    # Rename checkpoint keys if specified
    if opts.rename_checkpoints:
        rename_checkpoint(checkpoint)

    model = VLXLMRForImageTextRetrieval.from_pretrained(
        opts.model_config,
        state_dict=checkpoint,
        load_embedding_only=opts.load_embedding_only,
        load_layer=opts.load_layer,
        img_dim=IMG_DIM,
        margin=opts.margin)
    model.init_output()  # pretrain ITM head is different from ranking head
    model.to(device)
    # make sure every process has same model parameters in the beginning
    broadcast_tensors([p.data for p in model.parameters()], 0)
    set_dropout(model, opts.dropout)

    # Prepare optimizer
    if opts.separate_lr:
        optimizer = build_xlmr_optimizer(model, opts)
    else:
        optimizer = build_optimizer(model, opts)
    model, optimizer = amp.initialize(model,
                                      optimizer,
                                      enabled=opts.fp16,
                                      opt_level='O2')

    #global_step = 0
    LOGGER.info(f"***** Running training on {n_gpu} GPUs *****")
    LOGGER.info("  Num examples = %d", len(train_dataset) * hvd.size())
    LOGGER.info("  Batch size = %d", opts.train_batch_size)
    LOGGER.info("  Accumulate steps = %d", opts.gradient_accumulation_steps)
    LOGGER.info("  Num steps = %d", opts.num_train_steps)

    running_loss = RunningMeter('loss')
    model.train()

    if opts.steps_per_hard_neg != -1:
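        # NOTE: hn_dataloader and hard_neg_dir are defined only in the
        # hard-negative block commented out above; re-enable it before
        # setting steps_per_hard_neg != -1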
        compute_hard_neg(model, hn_dataloader, train_dataset,
                         opts.hard_neg_pool_size, hard_neg_dir)

    # Initialize the TrainingRestorer
    restorer = TrainingRestorer(opts, model, optimizer)
    global_step = restorer.global_step
    TB_LOGGER._global_step = global_step
    if hvd.rank() != 0:
        restorer = NoOp()  # non-zero ranks do not save/restore training state

    if global_step > 0:
        pbar.update(global_step)

    n_examples = 0
    start = time()
    # quick hack for amp delay_unscale bug
    optimizer.zero_grad()
    optimizer.step()
    while True:
        train_dataloader = build_dataloader(train_dataset,
                                            xlmr_itm_rank_collate, True, opts)
        for step, batch in enumerate(train_dataloader):
            #print(batch['input_ids'])
            n_examples += batch['input_ids'].size(0)
            loss = model(batch, compute_loss=True)
            loss = loss.mean()
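            # unscale and all-reduce gradients only on the last micro-batch of
            # each gradient-accumulation window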
            delay_unscale = (step + 1) % opts.gradient_accumulation_steps != 0
            with amp.scale_loss(loss, optimizer,
                                delay_unscale=delay_unscale) as scaled_loss:
                scaled_loss.backward()
                if not delay_unscale:
                    # gather gradients from every processes
                    # do this before unscaling to make sure every process uses
                    # the same gradient scale
                    grads = [
                        p.grad.data for p in model.parameters()
                        if p.requires_grad and p.grad is not None
                    ]
                    all_reduce_and_rescale_tensors(grads, float(1))

            running_loss(loss.item())
            # print("run the loss")
            if (step + 1) % opts.gradient_accumulation_steps == 0:
                global_step += 1

                # learning rate scheduling
                lr_this_step = get_lr_sched(global_step, opts)
                if opts.separate_lr:
                    #added by Mingyang
                    xlmr_lr_this_step = get_xlmr_lr_sched(global_step, opts)
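                    # assumes build_xlmr_optimizer puts the XLM-R encoder
                    # params in the first two param groups (decay / no-decay)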
                    for i, param_group in enumerate(optimizer.param_groups):
                        if i < 2:
                            param_group['lr'] = xlmr_lr_this_step
                        else:
                            param_group['lr'] = lr_this_step
                    TB_LOGGER.add_scalar('xlmr_lr', xlmr_lr_this_step,
                                         global_step)
                else:
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = lr_this_step

                TB_LOGGER.add_scalar('lr', lr_this_step, global_step)

                # log loss
                losses = all_gather_list(running_loss)
                running_loss = RunningMeter(
                    'loss',
                    sum(l.val for l in losses) / len(losses))
                TB_LOGGER.add_scalar('loss', running_loss.val, global_step)
                TB_LOGGER.step()

                # update model params
                if opts.grad_norm != -1:
                    grad_norm = clip_grad_norm_(amp.master_params(optimizer),
                                                opts.grad_norm)
                    TB_LOGGER.add_scalar('grad_norm', grad_norm, global_step)
                optimizer.step()
                optimizer.zero_grad()
                pbar.update(1)

                if global_step % 100 == 0:
                    # monitor training throughput
                    LOGGER.info(f'============Step {global_step}=============')
                    tot_ex = sum(all_gather_list(n_examples))
                    ex_per_sec = int(tot_ex / (time() - start))
                    LOGGER.info(f'{tot_ex} examples trained at '
                                f'{ex_per_sec} ex/s')
                    TB_LOGGER.add_scalar('perf/ex_per_s', ex_per_sec,
                                         global_step)
                    LOGGER.info(f'===========================================')

                if global_step % opts.valid_steps == 0 and global_step > 0:
                    # if global_step > 7000:
                    if opts.full_val:
                        val_log = evaluate(model, eval_loader_val)
                        TB_LOGGER.log_scaler_dict(
                            {f"valid/{k}": v
                             for k, v in val_log.items()})
                        #Log the information
                        # LOGGER.info(
                        #         f"========================= {split} ===========================\n"
                        #         f"image retrieval R1: {eval_log['img_r1']*100:.2f},\n"
                        #         f"image retrieval R5: {eval_log['img_r5']*100:.2f},\n"
                        #         f"image retrieval R10: {eval_log['img_r10']*100:.2f}\n"
                        #         f"text retrieval R1: {eval_log['txt_r1']*100:.2f},\n"
                        #         f"text retrieval R5: {eval_log['txt_r5']*100:.2f},\n"
                        #         f"text retrieval R10: {eval_log['txt_r10']*100:.2f}")
                        # LOGGER.info("=========================================================")
                    else:
                        val_log = validate(model, val_dataloader)
                        TB_LOGGER.log_scaler_dict(val_log)

                    model_saver.save(model, global_step)
                restorer.step()
                if (opts.steps_per_hard_neg != -1
                        and global_step % opts.steps_per_hard_neg == 0):
                    # sample hard negatives for training
                    compute_hard_neg(model, hn_dataloader, train_dataset,
                                     opts.hard_neg_pool_size, hard_neg_dir)
                    # break to reconstruct loader
                    # for potential multi-worker issue (not sure)
                    break

            if global_step >= opts.num_train_steps:
                break

        if global_step >= opts.num_train_steps:
            break
        # NOTE can no longer count epochs

    pbar.close()
    # final validation
    # val_log = validate(model, val_dataloader)
    # TB_LOGGER.log_scaler_dict(val_log)
    model_saver.save(model, f'{global_step}_final')

    for i, loader in enumerate(eval_loader_list):
        split = "test_{}".format(i)
        eval_log = evaluate(model, loader)
        TB_LOGGER.log_scaler_dict(
            {f"eval/{split}_{k}": v
             for k, v in eval_log.items()})
        if hvd.rank() != 0:
            continue
        LOGGER.info(
            f"========================= {split} ===========================\n"
            f"image retrieval R1: {eval_log['img_r1']*100:.2f},\n"
            f"image retrieval R5: {eval_log['img_r5']*100:.2f},\n"
            f"image retrieval R10: {eval_log['img_r10']*100:.2f}\n"
            f"text retrieval R1: {eval_log['txt_r1']*100:.2f},\n"
            f"text retrieval R5: {eval_log['txt_r5']*100:.2f},\n"
            f"text retrieval R10: {eval_log['txt_r10']*100:.2f}")
    LOGGER.info("=========================================================")
def main(opts):
    hvd.init()
    n_gpu = hvd.size()
    device = torch.device("cuda", hvd.local_rank())
    torch.cuda.set_device(hvd.local_rank())
    opts.n_gpu = n_gpu
    LOGGER.info("device: {} n_gpu: {}, rank: {}, "
                "16-bits training: {}".format(device, n_gpu, hvd.rank(),
                                              opts.fp16))
    if hvd.rank() != 0:
        LOGGER.disabled = True
    set_random_seed(opts.seed)

    # train_examples = None
    LOGGER.info(f"Loading the whole video dataset {opts.sub_txt_db}, "
                f"{opts.vfeat_db}")
    video_db = load_video_sub_dataset(opts.vfeat_db, opts.sub_txt_db,
                                      opts.vfeat_interval, opts)

    # data loaders
    # train
    LOGGER.info(f"Loading the train QA dataset {opts.train_query_txt_db}")
    video_ids = get_video_ids(opts.train_query_txt_db)
    train_q_txt_db = QaQueryTokLmdb(opts.train_query_txt_db, opts.max_txt_len)
    train_dataloaders = build_downstream_dataloaders([opts.task],
                                                     video_db,
                                                     video_ids,
                                                     True,
                                                     opts,
                                                     q_txt_db=train_q_txt_db,
                                                     shuffle=True)
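    # MetaLoader interleaves the per-task dataloaders and yields
    # (task, batch) pairs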
    meta_loader = MetaLoader(train_dataloaders,
                             accum_steps=opts.gradient_accumulation_steps,
                             distributed=n_gpu > 1)
    meta_loader = PrefetchLoader(meta_loader)

    # val
    LOGGER.info(f"Loading the val QA dataset {opts.val_query_txt_db}")
    video_ids = get_video_ids(opts.val_query_txt_db)
    val_q_txt_db = QaQueryTokLmdb(opts.val_query_txt_db, -1)
    val_dataloaders = build_downstream_dataloaders([opts.task],
                                                   video_db,
                                                   video_ids,
                                                   False,
                                                   opts,
                                                   q_txt_db=val_q_txt_db)
    if opts.test_query_txt_db:
        LOGGER.info(f"Loading the test QA dataset {opts.test_query_txt_db}")
        video_ids = get_video_ids(opts.test_query_txt_db)
        test_q_txt_db = QaQueryTokLmdb(opts.test_query_txt_db, -1)
        test_dataloaders = build_downstream_dataloaders([opts.task],
                                                        video_db,
                                                        video_ids,
                                                        False,
                                                        opts,
                                                        q_txt_db=test_q_txt_db)

    # Prepare model
    if opts.checkpoint:
        checkpoint = torch.load(opts.checkpoint)
    else:
        checkpoint = {}
    img_pos_embed_weight_key = "v_encoder.f_encoder.img_embeddings" +\
        ".position_embeddings.weight"
    if img_pos_embed_weight_key in checkpoint:
        max_frm_seq_len = len(checkpoint[img_pos_embed_weight_key])
    else:
        max_frm_seq_len = MAX_FRM_SEQ_LEN

    model = HeroForVideoQA.from_pretrained(opts.model_config,
                                           state_dict=checkpoint,
                                           vfeat_dim=VFEAT_DIM,
                                           max_frm_seq_len=max_frm_seq_len)

    model.to(device)
    # make sure every process has same model parameters in the beginning
    broadcast_tensors([p.data for p in model.parameters()], 0)
    set_dropout(model, opts.dropout)

    # Prepare optimizer
    optimizer = build_optimizer(model, opts)
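    # apex amp keeps a separate loss scale per task (num_losses below);
    # loss_id selects which scale to use in the backward pass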
    task2scaler = {t: i for i, t in enumerate(train_dataloaders.keys())}
    model, optimizer = amp.initialize(model,
                                      optimizer,
                                      num_losses=len(task2scaler),
                                      enabled=opts.fp16,
                                      opt_level='O2')
    restorer = TrainingRestorer(opts, model, optimizer)
    global_step = restorer.global_step
    TB_LOGGER.global_step = global_step
    if hvd.rank() == 0:
        save_training_meta(opts)
        TB_LOGGER.create(join(opts.output_dir, 'log'))
        pbar = tqdm(total=opts.num_train_steps)
        model_saver = ModelSaver(join(opts.output_dir, 'ckpt'))
        if not exists(join(opts.output_dir, 'results')):
            # store tvqa predictions
            os.makedirs(join(opts.output_dir, 'results'))
        add_log_to_file(join(opts.output_dir, 'log', 'log.txt'))
    else:
        LOGGER.disabled = True
        pbar = NoOp()
        model_saver = NoOp()
        restorer = NoOp()

    if global_step > 0:
        pbar.update(global_step)
    LOGGER.info(f"***** Running training with {n_gpu} GPUs *****")
    LOGGER.info("  Batch size = %d", opts.train_batch_size)
    LOGGER.info("  Accumulate steps = %d", opts.gradient_accumulation_steps)
    LOGGER.info("  Num steps = %d", opts.num_train_steps)

    task2loss = {
        task: RunningMeter(f'loss/{task}')
        for task in train_dataloaders.keys()
    }

    for obj in (f'{opts.task}_qa', f'{opts.task}_st_ed'):
        task2loss[obj] = RunningMeter(f'loss/{obj}')

    model.train()
    n_examples = defaultdict(int)
    start = time()
    # quick hack for amp delay_unscale bug
    optimizer.zero_grad()
    if global_step == 0:
        optimizer.step()
    for step, (task, batch) in enumerate(meta_loader):
        n_examples[task] += opts.train_batch_size

        loss = model(batch, task=task, compute_loss=True)

        loss_qa, loss_st_ed = loss
        loss = loss_qa + opts.lw_st_ed * loss_st_ed
        for n, ls in (('st_ed', loss_st_ed), ('qa', loss_qa)):
            ls = ls.item()
            task2loss[f'{task}_{n}'](ls)

        loss = loss.mean()
        task2loss[task](loss.item())

        delay_unscale = (step + 1) % opts.gradient_accumulation_steps != 0
        with amp.scale_loss(loss,
                            optimizer,
                            delay_unscale=delay_unscale,
                            loss_id=task2scaler[task]) as scaled_loss:
            scaled_loss.backward()
            if not delay_unscale:
                # gather gradients from every processes
                # do this before unscaling to make sure every process uses
                # the same gradient scale
                grads = [
                    p.grad.data for p in model.parameters()
                    if p.requires_grad and p.grad is not None
                ]
                all_reduce_and_rescale_tensors(grads, float(1))

        if (step + 1) % opts.gradient_accumulation_steps == 0:
            global_step += 1

            # learning rate scheduling
            lr_this_step = get_lr_sched(global_step, opts)
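            # groups 0/1 are assumed to get the scaled lr (lr * lr_mul) and
            # groups 2/3 the base lr, following build_optimizer's grouping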
            for i, param_group in enumerate(optimizer.param_groups):
                if i == 0 or i == 1:
                    param_group['lr'] = lr_this_step * opts.lr_mul
                elif i == 2 or i == 3:
                    param_group['lr'] = lr_this_step
                else:
                    raise ValueError()
            TB_LOGGER.add_scalar('lr', lr_this_step, global_step)

            TB_LOGGER.log_scaler_dict({
                temp_loss.name: temp_loss.val
                for temp_loss in task2loss.values()
                if temp_loss.val is not None
            })
            TB_LOGGER.step()

            # update model params
            if opts.grad_norm != -1:
                grad_norm = clip_grad_norm_(amp.master_params(optimizer),
                                            opts.grad_norm)
                TB_LOGGER.add_scalar('grad_norm', grad_norm, global_step)
            optimizer.step()
            optimizer.zero_grad()
            restorer.step()
            pbar.update(1)

            if global_step % 100 == 0:
                # monitor training throughput
                LOGGER.info('-------------------------------------------')
                LOGGER.info(f'Step {global_step}:')
                for t in train_dataloaders.keys():
                    tot_ex = sum(all_gather_list(n_examples[t]))
                    ex_per_sec = int(tot_ex / (time() - start))
                    LOGGER.info(f'{t}: {tot_ex} examples trained at '
                                f'{ex_per_sec} ex/s')
                    TB_LOGGER.add_scalar(f'perf/{t}_ex_per_s', ex_per_sec,
                                         global_step)

            if global_step % opts.valid_steps == 0:
                LOGGER.info('===========================================')
                LOGGER.info(f"Step {global_step}: start running validation")
                validate(model,
                         val_dataloaders,
                         "val",
                         opts,
                         global_step=global_step)
                if opts.test_query_txt_db:
                    validate(model,
                             test_dataloaders,
                             "test",
                             opts,
                             global_step=global_step)
                LOGGER.info('===========================================')
                model_saver.save(model, global_step)
        if global_step >= opts.num_train_steps:
            break

    LOGGER.info('===========================================')
    if global_step % opts.valid_steps != 0:
        LOGGER.info('===========================================')
        LOGGER.info(f"Step {global_step}: start running validation")
        validate(model, val_dataloaders, "val", opts, global_step=global_step)
        if opts.test_query_txt_db:
            validate(model,
                     test_dataloaders,
                     "test",
                     opts,
                     global_step=global_step)
        LOGGER.info('===========================================')
    model_saver.save(model, f'{global_step}_final')
Example #3
def main(opts):
    hvd.init()
    n_gpu = hvd.size()
    device = torch.device("cuda", hvd.local_rank())
    torch.cuda.set_device(hvd.local_rank())
    rank = hvd.rank()
    opts.rank = rank
    LOGGER.info("device: {} n_gpu: {}, rank: {}, "
                "16-bits training: {}".format(device, n_gpu, hvd.rank(),
                                              opts.fp16))

    if opts.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, "
                         "should be >= 1".format(
                             opts.gradient_accumulation_steps))

    set_random_seed(opts.seed)

    # train_examples = None
    LOGGER.info(f"Loading Train Dataset {opts.train_txt_db}, "
                f"{opts.train_img_dir}")
    if 'paired' in opts.model:
        DatasetCls = Nlvr2PairedDataset
        EvalDatasetCls = Nlvr2PairedEvalDataset
        collate_fn = nlvr2_paired_collate
        eval_collate_fn = nlvr2_paired_eval_collate
        if opts.model == 'paired':
            ModelCls = UniterForNlvr2Paired
        elif opts.model == 'paired-attn':
            ModelCls = UniterForNlvr2PairedAttn
        else:
            raise ValueError('unrecognized model type')
    elif opts.model == 'triplet':
        DatasetCls = Nlvr2TripletDataset
        EvalDatasetCls = Nlvr2TripletEvalDataset
        ModelCls = UniterForNlvr2Triplet
        collate_fn = nlvr2_triplet_collate
        eval_collate_fn = nlvr2_triplet_eval_collate
    else:
        raise ValueError('unrecognized model type')

    # data loaders
    train_dataloader = create_dataloader(opts.train_img_db, opts.train_txt_db,
                                         opts.train_batch_size, True,
                                         DatasetCls, collate_fn, opts)
    val_dataloader = create_dataloader(opts.val_img_db, opts.val_txt_db,
                                       opts.val_batch_size, False,
                                       EvalDatasetCls, eval_collate_fn, opts)
    test_dataloader = create_dataloader(opts.test_img_db, opts.test_txt_db,
                                        opts.val_batch_size, False,
                                        EvalDatasetCls, eval_collate_fn, opts)

    # Prepare model
    if opts.checkpoint:
        checkpoint = torch.load(opts.checkpoint)
    else:
        checkpoint = {}

    model = ModelCls.from_pretrained(opts.model_config,
                                     state_dict=checkpoint,
                                     img_dim=IMG_DIM)
    model.init_type_embedding()
    model.to(device)
    # make sure every process has same model parameters in the beginning
    broadcast_tensors([p.data for p in model.parameters()], 0)
    set_dropout(model, opts.dropout)

    # Prepare optimizer
    optimizer = build_optimizer(model, opts)
    model, optimizer = amp.initialize(model,
                                      optimizer,
                                      enabled=opts.fp16,
                                      opt_level='O2')

    global_step = 0
    if rank == 0:
        save_training_meta(opts)
        TB_LOGGER.create(join(opts.output_dir, 'log'))
        pbar = tqdm(total=opts.num_train_steps)
        model_saver = ModelSaver(join(opts.output_dir, 'ckpt'))
        os.makedirs(join(opts.output_dir, 'results'))  # store val predictions
        add_log_to_file(join(opts.output_dir, 'log', 'log.txt'))
    else:
        LOGGER.disabled = True
        pbar = NoOp()
        model_saver = NoOp()

    LOGGER.info(f"***** Running training with {n_gpu} GPUs *****")
    LOGGER.info("  Num examples = %d", len(train_dataloader.dataset))
    LOGGER.info("  Batch size = %d", opts.train_batch_size)
    LOGGER.info("  Accumulate steps = %d", opts.gradient_accumulation_steps)
    LOGGER.info("  Num steps = %d", opts.num_train_steps)

    running_loss = RunningMeter('loss')
    model.train()
    n_examples = 0
    n_epoch = 0
    start = time()
    # quick hack for amp delay_unscale bug
    optimizer.zero_grad()
    optimizer.step()
    while True:
        for step, batch in enumerate(train_dataloader):
            targets = batch['targets']
            n_examples += targets.size(0)

            loss = model(**batch, compute_loss=True)
            loss = loss.mean()
            delay_unscale = (step + 1) % opts.gradient_accumulation_steps != 0
            with amp.scale_loss(loss, optimizer,
                                delay_unscale=delay_unscale) as scaled_loss:
                scaled_loss.backward()
                if not delay_unscale:
                    # gather gradients from every processes
                    # do this before unscaling to make sure every process uses
                    # the same gradient scale
                    grads = [
                        p.grad.data for p in model.parameters()
                        if p.requires_grad and p.grad is not None
                    ]
                    all_reduce_and_rescale_tensors(grads, float(1))

            running_loss(loss.item())

            if (step + 1) % opts.gradient_accumulation_steps == 0:
                global_step += 1

                # learning rate scheduling
                lr_this_step = get_lr_sched(global_step, opts)
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr_this_step
                TB_LOGGER.add_scalar('lr', lr_this_step, global_step)

                # log loss
                losses = all_gather_list(running_loss)
                running_loss = RunningMeter(
                    'loss',
                    sum(l.val for l in losses) / len(losses))
                TB_LOGGER.add_scalar('loss', running_loss.val, global_step)
                TB_LOGGER.step()

                # update model params
                if opts.grad_norm != -1:
                    grad_norm = clip_grad_norm_(amp.master_params(optimizer),
                                                opts.grad_norm)
                    TB_LOGGER.add_scalar('grad_norm', grad_norm, global_step)
                optimizer.step()
                optimizer.zero_grad()
                pbar.update(1)

                if global_step % 100 == 0:
                    # monitor training throughput
                    tot_ex = sum(all_gather_list(n_examples))
                    ex_per_sec = int(tot_ex / (time() - start))
                    LOGGER.info(f'Step {global_step}: '
                                f'{tot_ex} examples trained at '
                                f'{ex_per_sec} ex/s')
                    TB_LOGGER.add_scalar('perf/ex_per_s', ex_per_sec,
                                         global_step)

                if global_step % opts.valid_steps == 0:
                    for split, loader in [('val', val_dataloader),
                                          ('test', test_dataloader)]:
                        LOGGER.info(f"Step {global_step}: start running "
                                    f"validation on {split} split...")
                        log, results = validate(model, loader, split)
                        with open(
                                f'{opts.output_dir}/results/'
                                f'{split}_results_{global_step}_'
                                f'rank{rank}.csv', 'w') as f:
                            for id_, ans in results:
                                f.write(f'{id_},{ans}\n')
                        TB_LOGGER.log_scaler_dict(log)
                    model_saver.save(model, global_step)
            if global_step >= opts.num_train_steps:
                break
        if global_step >= opts.num_train_steps:
            break
        n_epoch += 1
        LOGGER.info(f"Step {global_step}: finished {n_epoch} epochs")
    for split, loader in [('val', val_dataloader), ('test', test_dataloader)]:
        LOGGER.info(f"Step {global_step}: start running "
                    f"validation on {split} split...")
        log, results = validate(model, loader, split)
        with open(
                f'{opts.output_dir}/results/'
                f'{split}_results_{global_step}_'
                f'rank{rank}_final.csv', 'w') as f:
            for id_, ans in results:
                f.write(f'{id_},{ans}\n')
        TB_LOGGER.log_scaler_dict(log)
    model_saver.save(model, f'{global_step}_final')
Example #4
def main(opts):
    hvd.init()
    n_gpu = hvd.size()
    device = torch.device("cuda", hvd.local_rank())
    torch.cuda.set_device(hvd.local_rank())
    rank = hvd.rank()
    opts.rank = rank
    LOGGER.info("device: {} n_gpu: {}, rank: {}, "
                "16-bits training: {}".format(
                    device, n_gpu, hvd.rank(), opts.fp16))
    
    if opts.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, "
                         "should be >= 1".format(
                            opts.gradient_accumulation_steps))

    set_random_seed(opts.seed)

    if rank == 0:
        save_training_meta(opts)
        TB_LOGGER.create(join(opts.output_dir, 'log'))
        pbar = tqdm(total=opts.num_train_steps)
        model_saver = ModelSaver(join(opts.output_dir, 'ckpt'))
        add_log_to_file(join(opts.output_dir, 'log', 'log.txt'))
    else:
        LOGGER.disabled = True
        pbar = NoOp()
        model_saver = NoOp()

    all_dbs = [db for datasets in [opts.train_datasets, opts.val_datasets]
               for dset in datasets for db in dset['db']]

    tokenizer = json.load(open(f'{all_dbs[0]}/meta.json'))['bert']
    #print(tokenizer)
    # assert all(tokenizer == json.load(open(f'{db}/meta.json'))['bert']
    #            for db in all_dbs)

    # build data loaders
    train_dataloaders, all_img_dbs = create_dataloaders(
        opts.train_datasets, True, opts)
    val_dataloaders, _ = create_dataloaders(
        opts.val_datasets, False, opts, all_img_dbs)
    meta_loader = MetaLoader(train_dataloaders,
                             accum_steps=opts.gradient_accumulation_steps,
                             distributed=n_gpu > 1)
    meta_loader = PrefetchLoader(meta_loader)

    # Prepare model
    if opts.checkpoint:
        checkpoint = torch.load(opts.checkpoint)
    else:
        checkpoint = {}
    if opts.rename_checkpoints:
        rename_checkpoint(checkpoint)
    # Optionally initialize the visual-to-word linear projection from an
    # early-adaptation checkpoint
    if opts.early_adaptation:
        early_adaptation_checkpoint = torch.load(
            opts.early_adaptation_checkpoint)
        checkpoint['roberta.img_embeddings.img_linear.weight'] = \
            early_adaptation_checkpoint['v2w_linear.weight']
        checkpoint['roberta.img_embeddings.img_linear.bias'] = \
            early_adaptation_checkpoint['v2w_linear.bias']
    
    model = VLXLMRForPretraining.from_pretrained(
        opts.model_config, checkpoint,
        img_dim=IMG_DIM, img_label_dim=IMG_LABEL_DIM,
        nce_temp=opts.nce_temp, ot_pos_only=opts.ot_pos_only)

    # model = UniterForPretraining.from_pretrained(
    #     opts.model_config, checkpoint,
    #     img_dim=IMG_DIM, img_label_dim=IMG_LABEL_DIM,
    #     nce_temp=opts.nce_temp, ot_pos_only=opts.ot_pos_only)

    model.pad_vocab()  # tensor core padding for vocabulary
    model.to(device)
    model.train()
    # make sure every process has same model parameters in the beginning
    broadcast_tensors([p.data for p in model.parameters()], 0)
    set_dropout(model, opts.dropout)

    # Prepare optimizer
    optimizer = build_optimizer(model, opts)
    task2scaler = {t: i for i, t in enumerate(train_dataloaders.keys())}
    model, optimizer = amp.initialize(model, optimizer,
                                      num_losses=len(task2scaler),
                                      enabled=opts.fp16, opt_level='O2')

    #global_step = 0
    # Initialize the TrainingRestorer
    restorer = TrainingRestorer(opts, model, optimizer)
    global_step = restorer.global_step
    TB_LOGGER._global_step = global_step
    if hvd.rank() != 0:
        restorer = NoOp()  # non-zero ranks do not save/restore training state

    if global_step > 0:
        pbar.update(global_step)

    LOGGER.info(f"***** Running training with {n_gpu} GPUs *****")
    LOGGER.info("  Batch size = %d", opts.train_batch_size)
    LOGGER.info("  Accumulate steps = %d", opts.gradient_accumulation_steps)
    LOGGER.info("  Num steps = %d", opts.num_train_steps)
    
    # to compute training statistics
    task2loss = {task: RunningMeter(f'loss/{task}')
                 for task in train_dataloaders.keys()}
    # ITM w/ OT
    if opts.itm_ot_lambda > 0:
        for task in train_dataloaders.keys():
            if task.startswith('itm'):
                task2loss[f'{task}_xe'] = RunningMeter(f'loss/{task}_xe')
                task2loss[f'{task}_ot'] = RunningMeter(f'loss/{task}_ot')
                if not opts.ot_pos_only:
                    task2loss[f'{task}_ot_pos'] = RunningMeter(
                        f'loss/{task}_ot_pos')
                    task2loss[f'{task}_ot_neg'] = RunningMeter(
                        f'loss/{task}_ot_neg')
    
    n_examples = defaultdict(int)
    n_in_units = defaultdict(int)
    n_loss_units = defaultdict(int)
    n_neg_nce = defaultdict(int)
    grad_norm = 0

    start = time()
    #Added by Mingyang to debug the training procedure
    # debug_start = torch.cuda.Event(enable_timing=True)
    # debug_end = torch.cuda.Event(enable_timing=True)

    # quick hack for amp delay_unscale bug
    optimizer.zero_grad()
    optimizer.step()
    #Added by Mingyang Zhou
    # debug_start.record()
    for step, (name, batch) in enumerate(meta_loader):

        # forward pass
        assert all(name == n for n in all_gather_list(name))
        n_examples[name] += batch['input_ids'].size(0)
        n_in_units[name] += (batch['attn_masks'] == 1).sum().item()
        if 'nce' in name:
            n_neg_nce[name] += batch['neg_feats'].size(0)
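        # loader names are assumed to look like '<task>_<dataset>';
        # the prefix before the first underscore selects the task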
        task = name.split('_')[0]
        loss = model(batch, task=task, compute_loss=True)
        if task.startswith('itm'):
            # OT
            itm_loss, ot_loss = loss
            n_loss_units[name] += itm_loss.size(0)
            itm_loss = itm_loss.mean()
            if ot_loss is not None:
                if not opts.ot_pos_only:
                    ot_pos, ot_neg = ot_loss
                    ot_loss = (ot_pos.sum() - ot_neg.sum()
                               ) / (ot_pos.size(0) + ot_neg.size(0))

                    # NOTE: beware of empty tensors
                    # (mean() of an empty tensor is NaN)
                    ot_pos = ot_pos.mean().item()
                    if not math.isnan(ot_pos):
                        task2loss[f'{name}_ot_pos'](ot_pos)
                    ot_neg = ot_neg.mean().item()
                    if not math.isnan(ot_neg):
                        task2loss[f'{name}_ot_neg'](ot_neg)
                else:
                    ot_loss = ot_loss.mean()
                loss = itm_loss + opts.itm_ot_lambda * ot_loss
                task2loss[f'{name}_xe'](itm_loss.item())
                task2loss[f'{name}_ot'](ot_loss.item())
            else:
                loss = itm_loss
        elif task.startswith('vmlm-soft'):
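            # soft-label visual MLM: loss is up-weighted by a constant factor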
            loss = 1000 * loss.mean()
        else:
            n_loss_units[name] += loss.size(0)
            loss = loss.mean()  # loss is not normalized in model

        # backward pass
        delay_unscale = (step+1) % opts.gradient_accumulation_steps != 0
        with amp.scale_loss(loss, optimizer, delay_unscale=delay_unscale,
                            loss_id=task2scaler[name]) as scaled_loss:
            scaled_loss.backward()
            if not delay_unscale:
                # gather gradients from every processes
                # do this before unscaling to make sure every process uses
                # the same gradient scale
                grads = [p.grad.data for p in model.parameters()
                         if p.requires_grad and p.grad is not None]
                all_reduce_and_rescale_tensors(grads, float(1))
        task2loss[name](loss.item())

        # optimizer update and logging
        if (step + 1) % opts.gradient_accumulation_steps == 0:
            global_step += 1

            # learning rate scheduling
            lr_this_step = get_lr_sched(global_step, opts)
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr_this_step
            TB_LOGGER.add_scalar('lr', lr_this_step, global_step)

            # log loss
            # for t, l in task2loss.items():
            #     loss = sum(v for v in all_gather_list(l.val)
            #                if v is not None) / hvd.size()
            #     task2loss[t] = RunningMeter(f'loss/{t}', loss)
            
            TB_LOGGER.log_scaler_dict({l.name: l.val
                                       for l in task2loss.values()
                                       if l.val is not None})
            TB_LOGGER.step()

            # update model params
            if opts.grad_norm != -1:
                '''
                if global_step % 10 == 0 and not opts.fp16:
                    bias = model.bert.img_embeddings.img_linear.bias
                    weight = model.bert.img_embeddings.img_linear.weight
                    print(f"bnorm: {bias.norm()}")
                    print(f"wnorm: {weight.norm()}")
                    print(f"bgnorm: {bias.grad.norm()}")
                    print(f"wgnorm: {weight.grad.norm()}")

                    mask = model.bert.img_embeddings.mask_embedding.weight
                    print(f"mnorm: {mask.norm()}")
                    print(f"mgnorm: {mask.grad.norm()}")

                    print([(n, p.grad.norm().item())
                           for n, p in model.named_parameters()
                           if p.grad is not None
                              and p.grad.norm().item() > grad_norm/10])
                '''
                grad_norm = clip_grad_norm_(amp.master_params(optimizer),
                                            opts.grad_norm)
                TB_LOGGER.add_scalar('grad_norm', grad_norm, global_step)
            optimizer.step()
            optimizer.zero_grad()
            pbar.update(1)

            if global_step % 100 == 0:
                # monitor training throughput
                LOGGER.info(f'==============Step {global_step}===============')
                for t in train_dataloaders.keys():
                    assert all(tt == t for tt in all_gather_list(t))
                    tot_ex = sum(all_gather_list(n_examples[t]))
                    ex_per_sec = int(tot_ex / (time()-start))
                    tot_in = sum(all_gather_list(n_in_units[t]))
                    in_per_sec = int(tot_in / (time()-start))
                    tot_l = sum(all_gather_list(n_loss_units[t]))
                    l_per_sec = int(tot_l / (time()-start))
                    LOGGER.info(f'{t}: {tot_ex} examples trained at '
                                f'{ex_per_sec} ex/s')
                    TB_LOGGER.add_scalar(f'perf/{t}_ex_per_s', ex_per_sec,
                                         global_step)
                    TB_LOGGER.add_scalar(f'perf/{t}_in_per_s', in_per_sec,
                                         global_step)
                    TB_LOGGER.add_scalar(f'perf/{t}_loss_per_s', l_per_sec,
                                         global_step)
                    if 'nce' in t:
                        avg_neg = sum(all_gather_list(n_neg_nce[t])
                                      ) / hvd.size() // step
                        LOGGER.info(f'{t}: averaging '
                                    f'{avg_neg} negative samples')
                LOGGER.info(f'===============================================')

            if global_step % opts.valid_steps == 0:
                LOGGER.info(f'Step {global_step}: start validation')
                validate(model, val_dataloaders)
                model_saver.save(model, global_step, optimizer)
            restorer.step()
        if global_step >= opts.num_train_steps:
            break

    if global_step % opts.valid_steps != 0:
        LOGGER.info(f'Step {global_step}: start validation')
        validate(model, val_dataloaders)
        model_saver.save(model, global_step)
Example #5
def main(opts):

    LOGGER.info("16-bits training: {}".format(opts.fp16))

    if opts.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, "
                         "should be >= 1".format(
                             opts.gradient_accumulation_steps))

    set_random_seed(opts.seed)

    LOGGER.info("Loading data from {}".format(opts.data_dir))

    # create datasets
    train_set = VizWizDataset(opts.data_dir, split="train")
    val_set = VizWizDataset(opts.data_dir, split="val")

    # data loaders
    train_dataloader = DataLoader(train_set,
                                  shuffle=True,
                                  batch_size=opts.train_batch_size,
                                  num_workers=opts.n_workers,
                                  collate_fn=vizwiz_collate)
    val_dataloader = DataLoader(val_set,
                                shuffle=False,
                                batch_size=opts.train_batch_size,
                                num_workers=opts.n_workers,
                                collate_fn=vizwiz_collate)

    # Prepare model
    if opts.checkpoint:
        checkpoint = torch.load(opts.checkpoint)
    else:
        checkpoint = {}
    model = VizWizModel.from_pretrained(opts.model_config,
                                        state_dict=checkpoint,
                                        img_dim=IMG_DIM)
    model.init_type_embedding()
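    # DEVICE is assumed to be a module-level constant; unlike the Horovod
    # scripts above, this one runs as a single process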
    model.to(DEVICE)
    set_dropout(model, opts.dropout)

    # Prepare optimizer
    optimizer = build_optimizer(model, opts)

    # setup logging
    save_training_meta(opts)
    pbar = tqdm(total=opts.num_train_steps)
    model_saver = ModelSaver(join(opts.output_dir, 'ckpt'))
    if not os.path.exists(join(opts.output_dir, 'results')):
        os.makedirs(join(opts.output_dir, 'results'))
    add_log_to_file(join(opts.output_dir, 'log', 'log.txt'))

    LOGGER.info(f"***** Running training *****")
    LOGGER.info("  Num examples = %d", len(train_dataloader.dataset))
    LOGGER.info("  Batch size = %d", opts.train_batch_size)
    LOGGER.info("  Accumulate steps = %d", opts.gradient_accumulation_steps)
    LOGGER.info("  Num steps = %d", opts.num_train_steps)

    running_answerable_loss = RunningMeter('running_answerable_loss')
    running_answer_loss = RunningMeter('running_answer_loss')
    model.train()
    n_examples = 0
    n_epoch = 0
    start = time()
    global_step = 0

    # quick hack for amp delay_unscale bug
    optimizer.zero_grad()
    optimizer.step()

    while True:

        for step, batch in enumerate(train_dataloader):

            input_ids = batch['qs_tok'].to(DEVICE)
            img_feats = batch["img_feats"].to(DEVICE)
            attn_masks = batch["attn_masks"].to(DEVICE)
            position_ids = batch["position_ids"].to(DEVICE)
            answerable_targets = batch["answerables"].to(DEVICE)
            answer_targets = batch["answers_tok"].to(DEVICE)
            n_examples += input_ids.size(0)

            answerable_loss, answer_loss = model(
                input_ids=input_ids,
                position_ids=position_ids,
                img_feat=img_feats,
                attn_masks=attn_masks,
                gather_index=None,
                answerable_targets=answerable_targets,
                answer_targets=answer_targets,
                compute_loss=True)

            total_loss = answerable_loss + answer_loss
            total_loss.backward()

            running_answerable_loss(answerable_loss.item())
            running_answer_loss(answer_loss.item())

            if (step + 1) % opts.gradient_accumulation_steps == 0:
                global_step += 1

                # learning rate scheduling
                lr_this_step = get_lr_sched(global_step, opts)
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr_this_step

                optimizer.step()
                optimizer.zero_grad()
                pbar.update(1)

                if global_step % 100 == 0:
                    # monitor training throughput
                    tot_ex = n_examples
                    ex_per_sec = int(tot_ex / (time() - start))
                    LOGGER.info(
                        f'Step {global_step}: '
                        f'{tot_ex} examples trained at '
                        f'{ex_per_sec} ex/s '
                        f'answerable loss {running_answerable_loss.val:.4f} '
                        f'answer loss {running_answer_loss.val:.4f}')

                if global_step % opts.valid_steps == 0:

                    LOGGER.info(f"Step {global_step}: start running "
                                f"evaluation on val...")
                    validate(model, val_dataloader)
                    model_saver.save(model, global_step)

            if global_step >= opts.num_train_steps:
                break

        if global_step >= opts.num_train_steps:
            break

        n_epoch += 1
        LOGGER.info(f"Step {global_step}: finished {n_epoch} epochs")

    LOGGER.info(f"Step {global_step}: start running " f"evaluation on val...")
    validate(model, val_dataloader)
    model_saver.save(model, f'{global_step}_final')
Example #6
def main(opts):
    hvd.init()
    n_gpu = hvd.size()
    device = torch.device("cuda", hvd.local_rank())
    torch.cuda.set_device(hvd.local_rank())
    opts.n_gpu = n_gpu
    LOGGER.info("device: {} n_gpu: {}, rank: {}, "
                "16-bits training: {}".format(device, n_gpu, hvd.rank(),
                                              opts.fp16))

    if hvd.rank() != 0:
        LOGGER.disabled = True
    set_random_seed(opts.seed)

    # train_examples = None
    LOGGER.info(f"Loading the whole video dataset {opts.sub_txt_db}, "
                f"{opts.vfeat_db}")
    if opts.task != "didemo_video_only":
        video_db = load_video_sub_dataset(opts.vfeat_db, opts.sub_txt_db,
                                          opts.vfeat_interval, opts)
    else:
        txt_meta = load_json(join(opts.train_query_txt_db, "meta.json"))
        video_db = load_video_only_dataset(opts.vfeat_db, txt_meta,
                                           opts.vfeat_interval, opts)

    # data loaders
    # train
    video_ids = get_video_ids(opts.train_query_txt_db)
    train_q_txt_db = QueryTokLmdb(opts.train_query_txt_db, opts.max_txt_len)
    train_dataloaders = build_downstream_dataloaders([opts.task],
                                                     video_db,
                                                     video_ids,
                                                     True,
                                                     opts,
                                                     shuffle=True,
                                                     q_txt_db=train_q_txt_db)
    meta_loader = MetaLoader(train_dataloaders,
                             accum_steps=opts.gradient_accumulation_steps,
                             distributed=n_gpu > 1)
    meta_loader = PrefetchLoader(meta_loader)

    # val
    video_ids = get_video_ids(opts.val_query_txt_db)
    val_q_txt_db = QueryTokLmdb(opts.val_query_txt_db, -1)
    val_dataloaders = build_downstream_dataloaders([opts.task],
                                                   video_db,
                                                   video_ids,
                                                   False,
                                                   opts,
                                                   q_txt_db=val_q_txt_db)

    if opts.task != "didemo_video_only":
        inf_dataset = VcmrFullEvalDataset
    else:
        inf_dataset = VcmrVideoOnlyFullEvalDataset
    LOGGER.info(f"Loading Inference Dataset {opts.val_query_txt_db} (val)")
    val_dset = inf_dataset(video_ids,
                           video_db,
                           val_q_txt_db,
                           distributed=opts.distributed_eval)
    inf_loader_val = DataLoader(val_dset,
                                batch_size=opts.vcmr_eval_q_batch_size,
                                num_workers=opts.n_workers,
                                pin_memory=opts.pin_mem,
                                collate_fn=vcmr_full_eval_collate)
    inf_loader_val = PrefetchLoader(inf_loader_val)
    if opts.test_query_txt_db:
        LOGGER.info(
            f"Loading Inference Dataset {opts.test_query_txt_db} (test)")
        video_ids = get_video_ids(opts.test_query_txt_db)
        test_q_txt_db = QueryTokLmdb(opts.test_query_txt_db, -1)
        test_dset = inf_dataset(video_ids,
                                video_db,
                                test_q_txt_db,
                                distributed=opts.distributed_eval)
        inf_loader_test = DataLoader(test_dset,
                                     batch_size=opts.vcmr_eval_q_batch_size,
                                     num_workers=opts.n_workers,
                                     pin_memory=opts.pin_mem,
                                     collate_fn=vcmr_full_eval_collate)
        inf_loader_test = PrefetchLoader(inf_loader_test)

    # Prepare model
    if opts.checkpoint:
        checkpoint = torch.load(opts.checkpoint)
    else:
        checkpoint = {}
    img_pos_embed_weight_key = "v_encoder.f_encoder.img_embeddings" +\
        ".position_embeddings.weight"
    if img_pos_embed_weight_key in checkpoint:
        max_frm_seq_len = len(checkpoint[img_pos_embed_weight_key])
    else:
        max_frm_seq_len = MAX_FRM_SEQ_LEN

    model = HeroForVcmr.from_pretrained(
        opts.model_config,
        state_dict=checkpoint,
        vfeat_dim=VFEAT_DIM,
        max_frm_seq_len=max_frm_seq_len,
        lw_neg_ctx=opts.lw_neg_ctx,
        lw_neg_q=opts.lw_neg_q,
        lw_st_ed=0,
        ranking_loss_type=opts.ranking_loss_type,
        use_hard_negative=False,
        hard_pool_size=opts.hard_pool_size,
        margin=opts.margin,
        use_all_neg=opts.use_all_neg,
        drop_svmr_prob=opts.drop_svmr_prob)

    model.to(device)
    # make sure every process has same model parameters in the beginning
    broadcast_tensors([p.data for p in model.parameters()], 0)
    set_dropout(model, opts.dropout)

    # Prepare optimizer
    optimizer = build_optimizer(model, opts)
    task2scaler = {t: i for i, t in enumerate(train_dataloaders.keys())}
    model, optimizer = amp.initialize(model,
                                      optimizer,
                                      num_losses=len(task2scaler),
                                      enabled=opts.fp16,
                                      opt_level='O2')
    restorer = TrainingRestorer(opts, model, optimizer)
    global_step = restorer.global_step
    TB_LOGGER.global_step = global_step
    if hvd.rank() == 0:
        save_training_meta(opts)
        TB_LOGGER.create(join(opts.output_dir, 'log'))
        pbar = tqdm(total=opts.num_train_steps)
        model_saver = ModelSaver(join(opts.output_dir, 'ckpt'))
        if not exists(join(opts.output_dir, 'results')):
            # store tvr predictions
            os.makedirs(join(opts.output_dir, 'results'))
        if opts.nms_thd != -1:
            # store tvr-nms predictions
            if not exists(join(opts.output_dir, 'results_nms')):
                os.makedirs(join(opts.output_dir, 'results_nms'))
        add_log_to_file(join(opts.output_dir, 'log', 'log.txt'))
    else:
        pbar = NoOp()
        model_saver = NoOp()
        restorer = NoOp()

    if global_step > 0:
        pbar.update(global_step)
    LOGGER.info(f"***** Running training with {n_gpu} GPUs *****")
    LOGGER.info("  Batch size = %d", opts.train_batch_size)
    LOGGER.info("  Accumulate steps = %d", opts.gradient_accumulation_steps)
    LOGGER.info("  Num steps = %d", opts.num_train_steps)

    task2loss = {
        task: RunningMeter(f'loss/{task}')
        for task in train_dataloaders.keys()
    }

    for obj in (f'{opts.task}_st_ed', f'{opts.task}_neg_ctx',
                f'{opts.task}_neg_q'):
        task2loss[obj] = RunningMeter(f'loss/{obj}')
    model.train()
    n_examples = defaultdict(int)
    start = time()
    # quick hack for amp delay_unscale bug
    optimizer.zero_grad()
    if global_step == 0:
        optimizer.step()
    for step, (task, batch) in enumerate(meta_loader):
        if len(opts.hard_negtiave_start_step) > 0:
            for i, hn_step in enumerate(opts.hard_negtiave_start_step):
                if global_step >= hn_step and hn_step != -1:
                    model.set_hard_negative(True, opts.hard_pool_size[i],
                                            opts.hard_neg_weights[i])
        if opts.train_span_start_step != -1 and\
                global_step >= opts.train_span_start_step:
            model.set_train_st_ed(opts.lw_st_ed)

        n_examples[task] += opts.train_batch_size

        loss = model(batch, task=task, compute_loss=True)

        loss_st_ed, loss_neg_ctx, loss_neg_q = loss
        loss = loss_st_ed + loss_neg_ctx + loss_neg_q
        for n, ls, w in (('st_ed', loss_st_ed, opts.lw_st_ed),
                         ('neg_ctx', loss_neg_ctx, opts.lw_neg_ctx),
                         ('neg_q', loss_neg_q, opts.lw_neg_q)):
            ls = ls.item()
            if w:
                ls /= w
            task2loss[f'{task}_{n}'](ls)

        loss = loss.mean()
        task2loss[task](loss.item())

        delay_unscale = (step + 1) % opts.gradient_accumulation_steps != 0
        with amp.scale_loss(loss,
                            optimizer,
                            delay_unscale=delay_unscale,
                            loss_id=task2scaler[task]) as scaled_loss:
            scaled_loss.backward()
            if not delay_unscale:
                # gather gradients from all processes
                # do this before unscaling to make sure every process uses
                # the same gradient scale
                grads = [
                    p.grad.data for p in model.parameters()
                    if p.requires_grad and p.grad is not None
                ]
                all_reduce_and_rescale_tensors(grads, float(1))

        if (step + 1) % opts.gradient_accumulation_steps == 0:
            global_step += 1

            # learning rate scheduling
            lr_this_step = get_lr_sched(global_step, opts)
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr_this_step
            TB_LOGGER.add_scalar('lr', lr_this_step, global_step)

            # log loss
            TB_LOGGER.log_scaler_dict({
                temp_loss.name: temp_loss.val
                for temp_loss in task2loss.values()
                if temp_loss.val is not None
            })
            TB_LOGGER.step()

            # update model params
            if opts.grad_norm != -1:
                grad_norm = clip_grad_norm_(amp.master_params(optimizer),
                                            opts.grad_norm)
                TB_LOGGER.add_scalar('grad_norm', grad_norm, global_step)
            optimizer.step()
            optimizer.zero_grad()
            pbar.update(1)

            if global_step % 100 == 0:
                # monitor training throughput
                LOGGER.info('-------------------------------------------')
                LOGGER.info(f'Step {global_step}:')
                for t in train_dataloaders.keys():
                    tot_ex = sum(all_gather_list(n_examples[t]))
                    ex_per_sec = int(tot_ex / (time() - start))
                    LOGGER.info(f'{t}: {tot_ex} examples trained at '
                                f'{ex_per_sec} ex/s')
                    TB_LOGGER.add_scalar(f'perf/{t}_ex_per_s', ex_per_sec,
                                         global_step)

            if global_step % opts.valid_steps == 0:
                LOGGER.info('===========================================')
                LOGGER.info(f"Step {global_step}: start running validation")
                validate(model, val_dataloaders, opts)
                if hvd.rank() == 0 or opts.distributed_eval:
                    log, results = validate_full_vcmr(model,
                                                      inf_loader_val,
                                                      'val',
                                                      opts,
                                                      model_opts=opts)
                    save_json(
                        results, f'{opts.output_dir}/results/'
                        f'val_results_{global_step}_rank{hvd.rank()}.json')
                    TB_LOGGER.log_scaler_dict(log)
                    if opts.test_query_txt_db:
                        log, results = validate_full_vcmr(model,
                                                          inf_loader_test,
                                                          'test',
                                                          opts,
                                                          model_opts=opts)
                        save_json(
                            results, f'{opts.output_dir}/results/'
                            f'test_results_{global_step}_rank{hvd.rank()}.json'
                        )
                        TB_LOGGER.log_scaler_dict(log)
                LOGGER.info('===========================================')
                model_saver.save(model, global_step)

            # step restorer at the end so a validation checkpoint is not missed
            restorer.step()
        if global_step >= opts.num_train_steps:
            break

    LOGGER.info('===========================================')
    if global_step % opts.valid_steps != 0:
        if hvd.rank() == 0 or opts.distributed_eval:
            log, results = validate_full_vcmr(model,
                                              inf_loader_val,
                                              'val',
                                              opts,
                                              model_opts=opts)
            save_json(
                results, f'{opts.output_dir}/results/'
                f'val_results_{global_step}'
                f'_rank{hvd.rank()}_final.json')
            TB_LOGGER.log_scaler_dict(log)
            if opts.test_query_txt_db:
                log, results = validate_full_vcmr(model,
                                                  inf_loader_test,
                                                  'test',
                                                  opts,
                                                  model_opts=opts)
                save_json(
                    results, f'{opts.output_dir}/results/'
                    f'test_results_{global_step}_rank{hvd.rank()}.json')
                TB_LOGGER.log_scaler_dict(log)
    model_saver.save(model, f'{global_step}_final')
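
The training loop above performs gradient accumulation: gradients from several micro-batches are summed (with the all-reduce deferred via delay_unscale), and the optimizer only steps every gradient_accumulation_steps iterations. A minimal sketch of that core pattern in plain PyTorch, stripped of the Horovod/apex specifics and using an invented accum_steps in place of opts.gradient_accumulation_steps, looks like this:

import torch
from torch import nn

# Minimal gradient-accumulation sketch (plain PyTorch, no apex/Horovod).
model = nn.Linear(16, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()
accum_steps = 4  # plays the role of opts.gradient_accumulation_steps

data = [(torch.randn(8, 16), torch.randint(0, 2, (8,))) for _ in range(20)]

optimizer.zero_grad()
for step, (x, y) in enumerate(data):
    # scale the loss so the accumulated gradient matches one large batch
    loss = loss_fn(model(x), y) / accum_steps
    loss.backward()  # gradients accumulate in p.grad across micro-batches
    if (step + 1) % accum_steps == 0:
        nn.utils.clip_grad_norm_(model.parameters(), 2.0)
        optimizer.step()       # one "global step" per accum_steps micro-batches
        optimizer.zero_grad()
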
Example no. 7
def main(opts):
    hvd.init()
    n_gpu = hvd.size()
    device = torch.device("cuda", hvd.local_rank())
    torch.cuda.set_device(hvd.local_rank())
    rank = hvd.rank()
    opts.rank = rank
    LOGGER.info("device: {} n_gpu: {}, rank: {}, "
                "16-bits training: {}".format(device, n_gpu, hvd.rank(),
                                              opts.fp16))

    if opts.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, "
                         "should be >= 1".format(
                             opts.gradient_accumulation_steps))

    set_random_seed(opts.seed)

    # train_examples = None
    LOGGER.info(f"Loading Train Dataset {opts.train_txt_db}, "
                f"{opts.train_img_db}")
    train_dataloader = create_dataloader(
        opts.train_img_db,
        opts.train_txt_db,
        opts.train_batch_size,
        True,
        VeDataset,
        ve_collate,
        opts,
    )
    val_dataloader = create_dataloader(
        opts.val_img_db,
        opts.val_txt_db,
        opts.val_batch_size,
        False,
        VeEvalDataset,
        ve_eval_collate,
        opts,
    )
    test_dataloader = create_dataloader(
        opts.test_img_db,
        opts.test_txt_db,
        opts.val_batch_size,
        False,
        VeEvalDataset,
        ve_eval_collate,
        opts,
    )

    # Prepare model
    if opts.checkpoint:
        checkpoint = torch.load(opts.checkpoint)
    else:
        checkpoint = {}
    bert_model = json.load(open(f"{opts.train_txt_db}/meta.json"))["bert"]
    if "bert" not in bert_model:
        bert_model = "bert-large-cased"  # quick hack for glove exp
    model = UniterForVisualEntailment.from_pretrained(opts.model_config,
                                                      state_dict=checkpoint,
                                                      img_dim=IMG_DIM)
    model.to(device)
    # make sure every process has same model parameters in the beginning
    broadcast_tensors([p.data for p in model.parameters()], 0)
    set_dropout(model, opts.dropout)

    # Prepare optimizer
    optimizer = build_optimizer(model, opts)
    model, optimizer = amp.initialize(model,
                                      optimizer,
                                      enabled=opts.fp16,
                                      opt_level="O2")

    global_step = 0
    if rank == 0:
        save_training_meta(opts)
        TB_LOGGER.create(join(opts.output_dir, "log"))
        pbar = tqdm(total=opts.num_train_steps)
        model_saver = ModelSaver(join(opts.output_dir, "ckpt"))
        pickle.dump(ans2label,
                    open(join(opts.output_dir, "ckpt", "ans2label.pkl"), "wb"))
        os.makedirs(join(opts.output_dir, "results"))  # store VE predictions
        add_log_to_file(join(opts.output_dir, "log", "log.txt"))
    else:
        LOGGER.disabled = True
        pbar = NoOp()
        model_saver = NoOp()

    LOGGER.info(f"***** Running training with {n_gpu} GPUs *****")
    LOGGER.info("  Num examples = %d", len(train_dataloader.dataset))
    LOGGER.info("  Batch size = %d", opts.train_batch_size)
    LOGGER.info("  Accumulate steps = %d", opts.gradient_accumulation_steps)
    LOGGER.info("  Num steps = %d", opts.num_train_steps)

    running_loss = RunningMeter("loss")
    model.train()
    n_examples = 0
    n_epoch = 0
    start = time()
    # quick hack for amp delay_unscale bug
    optimizer.zero_grad()
    optimizer.step()
    while True:
        for step, batch in enumerate(train_dataloader):
            n_examples += batch["input_ids"].size(0)

            loss = model(batch, compute_loss=True)
            loss = loss.mean() * batch["targets"].size(1)  # instance-level BCE
            delay_unscale = (step + 1) % opts.gradient_accumulation_steps != 0
            with amp.scale_loss(loss, optimizer,
                                delay_unscale=delay_unscale) as scaled_loss:
                scaled_loss.backward()
                if not delay_unscale:
                    # gather gradients from all processes
                    # do this before unscaling to make sure every process uses
                    # the same gradient scale
                    grads = [
                        p.grad.data for p in model.parameters()
                        if p.requires_grad and p.grad is not None
                    ]
                    all_reduce_and_rescale_tensors(grads, float(1))

            running_loss(loss.item())

            if (step + 1) % opts.gradient_accumulation_steps == 0:
                global_step += 1

                # learning rate scheduling
                lr_this_step = get_lr_sched(global_step, opts)
                for param_group in optimizer.param_groups:
                    param_group["lr"] = lr_this_step
                TB_LOGGER.add_scalar("lr", lr_this_step, global_step)

                # log loss
                # NOTE: not gathered across GPUs for efficiency
                TB_LOGGER.add_scalar("loss", running_loss.val, global_step)
                TB_LOGGER.step()

                # update model params
                if opts.grad_norm != -1:
                    grad_norm = clip_grad_norm_(amp.master_params(optimizer),
                                                opts.grad_norm)
                    TB_LOGGER.add_scalar("grad_norm", grad_norm, global_step)
                optimizer.step()
                optimizer.zero_grad()
                pbar.update(1)

                if global_step % 100 == 0:
                    # monitor training throughput
                    LOGGER.info(f"============Step {global_step}=============")
                    tot_ex = sum(all_gather_list(n_examples))
                    ex_per_sec = int(tot_ex / (time() - start))
                    LOGGER.info(f"{tot_ex} examples trained at "
                                f"{ex_per_sec} ex/s")
                    TB_LOGGER.add_scalar("perf/ex_per_s", ex_per_sec,
                                         global_step)
                    LOGGER.info(f"===========================================")

                if global_step % opts.valid_steps == 0:
                    for split, loader in [
                        ("val", val_dataloader),
                        ("test", test_dataloader),
                    ]:
                        LOGGER.info(f"Step {global_step}: start running "
                                    f"validation on {split} split...")
                        val_log, results = validate(model, loader, label2ans,
                                                    split)
                        with open(
                                f"{opts.output_dir}/results/"
                                f"{split}_results_{global_step}_"
                                f"rank{rank}.json",
                                "w",
                        ) as f:
                            json.dump(results, f)
                        TB_LOGGER.log_scaler_dict(val_log)
                    model_saver.save(model, global_step)
            if global_step >= opts.num_train_steps:
                break
        if global_step >= opts.num_train_steps:
            break
        n_epoch += 1
        LOGGER.info(f"Step {global_step}: finished {n_epoch} epochs")
    if opts.num_train_steps % opts.valid_steps != 0:
        for split, loader in [("val", val_dataloader),
                              ("test", test_dataloader)]:
            LOGGER.info(f"Step {global_step}: start running "
                        f"validation on {split} split...")
            val_log, results = validate(model, loader, label2ans, split)
            with open(
                    f"{opts.output_dir}/results/"
                    f"{split}_results_{global_step}_"
                    f"rank{rank}_final.json",
                    "w",
            ) as f:
                json.dump(results, f)
            TB_LOGGER.log_scaler_dict(val_log)
        model_saver.save(model, global_step)
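
create_dataloader is referenced above but not shown in this listing. As a rough illustration, assuming it only builds the dataset and wraps it in a standard torch DataLoader (the real helper likely also handles sampling and prefetching, and the dataset constructor's argument order here is a guess), it might look like:

from torch.utils.data import DataLoader

def create_dataloader(img_db, txt_db, batch_size, is_train, dset_cls,
                      collate_fn, opts):
    # Hypothetical reconstruction for illustration only.
    dset = dset_cls(txt_db, img_db)  # argument order is an assumption
    return DataLoader(dset,
                      batch_size=batch_size,
                      shuffle=is_train,
                      drop_last=is_train,
                      num_workers=opts.n_workers,
                      pin_memory=opts.pin_mem,
                      collate_fn=collate_fn)
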
Example no. 8
class TrainerTemplate():
    def __init__(self, config):
        (self.preds_list, self.probs_list, self.labels_list, self.loss_list,
         self.short_loss_list, self.id_list) = [], [], [], [], [], []
        self.best_val_metrics, self.train_metrics = defaultdict(int), {}
        self.best_auc = 0
        self.not_improved = 0
        self.best_val_loss = 1000
        self.total_iters = 0
        self.terminate_training = False
        self.model_file = os.path.join(config['model_path'],
                                       config['model_save_name'])
        self.pretrained_model_file = None
        if config['pretrained_model_file'] is not None:
            self.pretrained_model_file = os.path.join(
                config['model_path'], config['pretrained_model_file'])
        self.start_epoch = 1
        self.config = config
        self.device = get_device()

        if not isinstance(self.config['test_loader'], list):
            self.config['test_loader'] = [self.config['test_loader']]

        # Initialize the model, optimizer and loss function
        self.init_training_params()

    def init_training_params(self):
        self.init_model()
        self.model_saver = ModelSaver(self.model_file)

        if self.config['parallel_computing']:
            self.model = nn.DataParallel(self.model)

        self.init_optimizer()
        self.init_scheduler()

        if self.config['loss_func'] == 'bce_logits':
            self.criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(
                [self.config['pos_wt']]).to(self.device))
        elif self.config['loss_func'] == 'bce':
            self.criterion = nn.BCELoss()
        else:
            self.criterion = nn.CrossEntropyLoss()

    def init_scheduler(self):
        if self.config['scheduler'] == 'step':
            self.scheduler = torch.optim.lr_scheduler.StepLR(
                self.optimizer,
                step_size=self.config['lr_decay_step'],
                gamma=self.config['lr_decay_factor'])
        elif self.config['scheduler'] == 'multi_step':
            self.scheduler = torch.optim.lr_scheduler.MultiStepLR(
                self.optimizer,
                milestones=[5, 10, 15, 25, 40],
                gamma=self.config['lr_decay_factor'])
        elif self.config['scheduler'] == 'warmup':
            self.scheduler = get_linear_schedule_with_warmup(
                self.optimizer,
                num_warmup_steps=self.config['warmup_steps'],
                num_training_steps=len(self.config['train_loader']) *
                self.config['max_epoch'])
        elif self.config['scheduler'] == 'warmup_cosine':
            self.scheduler = get_cosine_schedule_with_warmup(
                self.optimizer,
                num_warmup_steps=self.config['warmup_steps'],
                num_training_steps=len(self.config['train_loader']) *
                self.config['max_epoch'])

    def init_optimizer(self):
        self.optimizer = get_optimizer(self.model, self.config)

    def average_gradients(self, steps):
        for param in self.model.parameters():
            if param.requires_grad and param.grad is not None:
                param.grad = param.grad / steps

    def calculate_loss(self, preds, batch_label, grad_step):
        if self.config['loss_func'] == 'bce':
            preds = torch.sigmoid(preds)
        if self.config['loss_func'] == 'bce_logits':
            preds = preds.squeeze(1).to(self.device)
        else:
            preds = preds.to(self.device)
        loss = self.criterion(
            preds,
            batch_label.to(self.device) if self.config['loss_func'] == 'ce'
            else batch_label.float().to(self.device))

        if grad_step and self.iters % self.config['gradient_accumulation'] == 0:
            loss.backward()
            self.average_gradients(steps=self.config['gradient_accumulation'])
            torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                           self.config['max_grad_norm'])
            self.optimizer.step()
            self.scheduler.step()
            self.optimizer.zero_grad()
        elif grad_step:
            loss.backward()

        if self.config['loss_func'] == 'bce':
            probs = preds
            preds = (preds > 0.5).type(torch.FloatTensor)
        elif self.config['loss_func'] == 'ce':
            probs = F.softmax(preds, dim=1)
            preds = torch.argmax(probs, dim=1)
        elif self.config['loss_func'] == 'bce_logits':
            probs = torch.sigmoid(preds)
            preds = (probs > 0.5).type(torch.FloatTensor)

        self.probs_list.append(probs.cpu().detach().numpy())
        self.preds_list.append(preds.cpu().detach().numpy())
        self.labels_list.append(batch_label.cpu().detach().numpy())
        self.loss_list.append(loss.detach().item())
        if grad_step:
            self.short_loss_list.append(loss.detach().item())

    def eval_model(self, test=False, test_idx=0):
        self.model.eval()
        (self.preds_list, self.probs_list, self.labels_list,
         self.loss_list, self.id_list) = [], [], [], [], []
        batch_loader = (self.config['test_loader'][test_idx] if test
                        else self.config['val_loader'])
        with torch.no_grad():
            for iters, batch in enumerate(batch_loader):
                batch = self.batch_to_device(batch)
                if batch_loader.dataset.return_ids:
                    self.id_list.append(batch['ids'])
                self.eval_iter_step(iters, batch, test=test)

            self.probs_list = [
                prob for batch_prob in self.probs_list for prob in batch_prob
            ]
            self.preds_list = [
                pred for batch_pred in self.preds_list for pred in batch_pred
            ]
            self.labels_list = [
                label for batch_labels in self.labels_list
                for label in batch_labels
            ]
            self.id_list = [
                data_id for batch_id in self.id_list for data_id in batch_id
            ]

            val_loss = sum(self.loss_list) / len(self.loss_list)
            eval_metrics = standard_metrics(torch.tensor(self.probs_list),
                                            torch.tensor(self.labels_list),
                                            add_optimal_acc=True)
            # if test:
            # 	print(classification_report(np.array(self.labels_list), np.array(self.preds_list)))

        return eval_metrics, val_loss

    @torch.no_grad()
    def export_test_predictions(self, test_idx=0, threshold=0.5):
        self.model.eval()

        ## Step 2: Run model on the test set (no loss computed; Step 1, finding
        ## the decision threshold, happens on the validation set beforehand)
        # Ensure that ids are actually returned
        assert self.config['test_loader'][test_idx].dataset.return_ids, \
            "Can only export test results if the IDs are returned in the test dataset."
        test_name = self.config['test_loader'][test_idx].dataset.name

        prob_list = []
        id_list = []
        for iters, batch in enumerate(self.config['test_loader'][test_idx]):
            batch = self.batch_to_device(batch)
            id_list.append(batch['ids'].cpu())
            probs = self.test_iter_step(batch)
            if self.config['loss_func'] == 'bce_logits':
                probs = torch.sigmoid(probs)
            prob_list.append(probs.detach().cpu())

        probs = torch.cat(prob_list, dim=0)
        ids = torch.cat(id_list, dim=0)
        preds = (probs > threshold).long()

        ## Step 3: Export predictions
        self._export_preds(ids,
                           probs,
                           preds,
                           file_postfix="_%s_preds.csv" % test_name)

        LOGGER.info("Finished export of test predictions")

    @torch.no_grad()
    def export_val_predictions(self, test=False, test_idx=0, threshold=0.5):
        batch_loader = (self.config['test_loader'][test_idx] if test
                        else self.config['val_loader'])
        test_name = batch_loader.dataset.name
        LOGGER.info("Exporting %s predictions..." % (test_name))
        self.model.eval()

        ## Step 1: Find the optimal threshold on validation set
        _, _ = self.eval_model(test=test, test_idx=test_idx)
        val_probs = torch.tensor(self.probs_list)
        val_labels = torch.tensor(self.labels_list)
        if len(self.id_list) != 0:
            val_ids = torch.tensor(self.id_list)
        else:
            val_ids = torch.zeros_like(val_labels) - 1
        val_preds = (val_probs > threshold).long()

        self._export_preds(val_ids,
                           val_probs,
                           val_preds,
                           labels=val_labels,
                           file_postfix="_%s_preds.csv" % test_name)

        LOGGER.info("Finished export of %s predictions" % test_name)

    def _export_preds(self,
                      ids,
                      probs,
                      preds,
                      labels=None,
                      file_postfix="_preds.csv"):
        file_string = "id,proba,label%s\n" % (",gt"
                                              if labels is not None else "")
        for i in range(ids.shape[0]):
            file_string += "%i,%f,%i" % (ids[i].item(), probs[i].item(),
                                         preds[i].item())
            if labels is not None:
                file_string += ",%i" % labels[i].item()
            file_string += "\n"
        filepath = os.path.join(
            self.config['model_path'],
            self.config['model_save_name'].rsplit(".", 1)[0] + file_postfix)
        with open(filepath, "w") as f:
            f.write(file_string)

    def check_early_stopping(self):
        optimize_for_loss = self.config['optimize_for'] == 'loss'
        self.this_metric = (self.val_loss if optimize_for_loss
                            else self.val_metrics[self.config['optimize_for']])
        self.current_best = (self.best_val_loss if optimize_for_loss
                             else self.best_val_metrics[self.config['optimize_for']])

        new_best = (self.this_metric < self.current_best if optimize_for_loss
                    else self.this_metric > self.current_best)
        if new_best:
            LOGGER.info("New High Score! Saving model...")
            self.best_val_metrics = self.val_metrics
            self.best_val_loss = self.val_loss
            if not self.config["no_model_checkpoints"]:
                self.model_saver.save(self.model)

        ### Stopping Criteria based on patience and change-in-metric-threshold ###
        diff = (self.current_best - self.this_metric
                if self.config['optimize_for'] == 'loss'
                else self.this_metric - self.current_best)
        if diff < self.config['early_stop_thresh']:
            self.not_improved += 1
            if self.not_improved >= self.config['patience']:
                self.terminate_training = True
        else:
            self.not_improved = 0
        LOGGER.info("current patience: {}".format(self.not_improved))

    def train_epoch_step(self):
        self.model.train()
        lr = self.scheduler.get_last_lr()
        self.total_iters += self.iters + 1
        self.probs_list = [
            pred for batch_pred in self.probs_list for pred in batch_pred
        ]
        self.labels_list = [
            label for batch_labels in self.labels_list
            for label in batch_labels
        ]

        # Evaluate on train set
        self.train_metrics = standard_metrics(torch.tensor(self.probs_list),
                                              torch.tensor(self.labels_list),
                                              add_optimal_acc=True)
        log_tensorboard(self.config,
                        self.config['writer'],
                        self.model,
                        self.epoch,
                        self.iters,
                        self.total_iters,
                        self.loss_list,
                        self.train_metrics,
                        lr[0],
                        loss_only=False,
                        val=False)
        self.train_loss = self.loss_list[:]

        # Evaluate on dev set
        val_time = time.time()
        self.val_metrics, self.val_loss = self.eval_model()
        self.config['writer'].add_scalar("Stats/time_validation",
                                         time.time() - val_time,
                                         self.total_iters)

        # print stats
        print_stats(self.config, self.epoch, self.train_metrics,
                    self.train_loss, self.val_metrics, self.val_loss,
                    self.start, lr[0])

        # log validation stats in tensorboard
        log_tensorboard(self.config,
                        self.config['writer'],
                        self.model,
                        self.epoch,
                        self.iters,
                        self.total_iters,
                        self.val_loss,
                        self.val_metrics,
                        lr[0],
                        loss_only=False,
                        val=True)

        # Check for early stopping criteria
        self.check_early_stopping()
        self.probs_list = []
        self.preds_list = []
        self.labels_list = []
        self.loss_list = []
        self.id_list = []

        self.train_loss = sum(self.train_loss) / len(self.train_loss)
        del self.val_metrics
        del self.val_loss

    def end_training(self):
        # Termination message
        print("\n" + "-" * 100)
        if self.terminate_training:
            LOGGER.info(
                "Training terminated early because the Validation {} did not improve for   {}   epochs"
                .format(self.config['optimize_for'], self.config['patience']))
        else:
            LOGGER.info(
                "Maximum epochs of {} reached. Finished training !!".format(
                    self.config['max_epoch']))

        print_test_stats(self.best_val_metrics, test=False)

        print("-" * 50 + "\n\t\tEvaluating on test set\n" + "-" * 50)
        if not self.config["no_model_checkpoints"]:
            if os.path.isfile(self.model_file):
                self.load_model()
                self.model.to(self.device)
            else:
                raise ValueError(
                    "No saved model state_dict found for the chosen model "
                    "'{}'!\nAborting evaluation on test set...".format(
                        self.config['model_name']))

            # runs evaluation internally, no need to run it again here
            self.export_val_predictions()
            val_probs = torch.tensor(self.probs_list)
            val_labels = torch.tensor(self.labels_list)
            threshold = find_optimal_threshold(val_probs,
                                               val_labels,
                                               metric="accuracy",
                                               show_plot=False)
            best_val_metrics = standard_metrics(val_probs,
                                                val_labels,
                                                threshold=threshold,
                                                add_aucroc=False)
            LOGGER.info(
                "Optimal threshold on validation dataset: %.4f (accuracy=%4.2f%%)"
                % (threshold, 100.0 * best_val_metrics["accuracy"]))

            ## Testing is in the standard form not possible, as we do not have any labels (gives an error in standard_metrics)
            ## Instead, we should write out the predictions in the form of the leaderboard
            self.test_metrics = dict()
            for test_idx in range(len(self.config['test_loader'])):
                test_name = self.config['test_loader'][test_idx].dataset.name
                LOGGER.info("Export and testing on %s..." % test_name)
                if hasattr(self.config['test_loader'][test_idx].dataset, "data") and \
                   hasattr(self.config['test_loader'][test_idx].dataset.data, "labels") and \
                   self.config['test_loader'][test_idx].dataset.data.labels[0] == -1:
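                    # no ground-truth labels available: only export predictions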
                    self.export_test_predictions(test_idx=test_idx,
                                                 threshold=threshold)
                    self.test_metrics[test_name] = dict()
                else:
                    test_idx_metrics, _ = self.eval_model(test=True,
                                                          test_idx=test_idx)
                    self.test_metrics[test_name] = test_idx_metrics
                    print_test_stats(test_idx_metrics, test=True)
                    self.export_val_predictions(test=True,
                                                test_idx=test_idx,
                                                threshold=threshold)
        else:
            LOGGER.info(
                "No model checkpoints were saved. Hence, testing will be skipped."
            )
            self.test_metrics = dict()

        self.export_metrics()

        self.config['writer'].close()

        if self.config['remove_checkpoints']:
            LOGGER.info("Removing checkpoint %s..." % self.model_file)
            os.remove(self.model_file)

    def export_metrics(self):
        metric_export_file = os.path.join(
            self.config['model_path'],
            self.config['model_save_name'].rsplit(".", 1)[0] + "_metrics.json")
        metric_dict = {}
        metric_dict["dev"] = self.best_val_metrics
        metric_dict["dev"]["loss"] = self.best_val_loss
        metric_dict["train"] = self.train_metrics
        metric_dict["train"]["loss"] = sum(
            self.train_loss) / len(self.train_loss) if isinstance(
                self.train_loss, list) else self.train_loss
        if hasattr(self, "test_metrics") and len(self.test_metrics) > 0:
            metric_dict["test"] = self.test_metrics

        with open(metric_export_file, "w") as f:
            json.dump(metric_dict, f, indent=4)

    def train_main(self, cache=False):
        print("\n\n" + "=" * 100 + "\n\t\t\t\t\t Training Network\n" +
              "=" * 100)

        self.start = time.time()
        print("\nBeginning training at:  {} \n".format(
            datetime.datetime.now()))

        self.model.to(self.device)

        for self.epoch in range(self.start_epoch,
                                self.config['max_epoch'] + 1):
            train_times = []
            for self.iters, self.batch in enumerate(
                    self.config['train_loader']):
                self.model.train()

                iter_time = time.time()
                self.batch = self.batch_to_device(self.batch)
                self.train_iter_step()
                train_times.append(time.time() - iter_time)

                # Loss only
                if (self.total_iters + self.iters +
                        1) % self.config['log_every'] == 0:
                    # extra logging detail when running with --debug
                    if self.config['debug']:
                        LOGGER.info(
                            "Logging tensorboard at step %i with %i values" %
                            (self.iters + self.total_iters + 1,
                             len(self.short_loss_list)))
                    log_tensorboard(self.config,
                                    self.config['writer'],
                                    self.model,
                                    self.epoch,
                                    self.iters,
                                    self.total_iters,
                                    self.short_loss_list,
                                    loss_only=True,
                                    val=False)
                    self.config['writer'].add_scalar(
                        'Stats/time_per_train_iter', mean(train_times),
                        (self.iters + self.total_iters + 1))
                    self.config['writer'].add_scalar(
                        'Stats/learning_rate',
                        self.scheduler.get_last_lr()[0],
                        (self.iters + self.total_iters + 1))
                    train_times = []
                    self.short_loss_list = []
            self.train_epoch_step()

            if self.terminate_training:
                break

        self.end_training()
        return self.best_val_metrics, self.test_metrics

    def batch_to_device(self, batch):
        batch = {
            k: (v.to(self.device) if isinstance(v, torch.Tensor) else v)
            for k, v in batch.items()
        }
        return batch

    def init_model(self):
        raise NotImplementedError

    def load_model(self):
        raise NotImplementedError

    def train_iter_step(self):
        raise NotImplementedError

    def eval_iter_step(self, iters, batch, test):
        raise NotImplementedError

    def test_iter_step(self, batch):
        raise NotImplementedError

    @staticmethod
    def add_default_argparse(parser, defaults=dict()):
        # Required Paths
        parser.add_argument(
            '--data_path',
            type=str,
            default='./dataset',
            help='path to dataset folder that contains the processed data files'
        )
        parser.add_argument(
            '--model_path',
            type=str,
            default='./model_checkpoints',
            help='Directory for saving trained model checkpoints')
        parser.add_argument(
            '--vis_path',
            type=str,
            default='./vis_checkpoints',
            help='Directory for saving tensorboard checkpoints')
        parser.add_argument("--model_save_name",
                            type=str,
                            default='best_model.pt',
                            help='saved model name')
        parser.add_argument(
            "--no_model_checkpoints",
            action="store_true",
            help=
            'If selected, no model checkpoints will be created and no testing performed (for grid searches etc.)'
        )
        parser.add_argument(
            "--remove_checkpoints",
            action="store_true",
            help=
            'If selected, model checkpoints will be deleted after finishing testing.'
        )
        parser.add_argument(
            '--debug',
            action="store_true",
            help=
            'Intended for test runs on local machines; enables more verbose output.'
        )

        # Load pretrained model
        parser.add_argument('--pretrained_model_file',
                            type=str,
                            help='Name of the pretrained model')

        ## Training parameters

        # Named parameters
        parser.add_argument(
            '--optimizer',
            type=str,
            default=defaults.get('optimizer', 'adam'),
            help='Optimizer to use for training: adam / adamx / adamw')
        ## Not sure whether we should have this here. For a multi-task setup, we need our own loss functions
        parser.add_argument(
            '--loss_func',
            type=str,
            default=defaults.get('loss_func', 'bce_logits'),
            help='Loss function to use for optimization: bce / bce_logits / ce'
        )
        parser.add_argument(
            '--optimize_for',
            type=str,
            default=defaults.get('optimize_for', 'aucroc'),
            help=
            'Metric to optimize during training and early stopping: loss / F1 / aucroc / accuracy'
        )
        parser.add_argument(
            '--scheduler',
            type=str,
            default=defaults.get('scheduler', 'warmup_cosine'),
            help=
            'The type of lr scheduler used to anneal the learning rate: step / multi_step / warmup / warmup_cosine'
        )

        # Numerical parameters
        parser.add_argument(
            '--confounder_repeat',
            type=int,
            default=defaults.get('confounder_repeat', 1),
            help=
            "Factor with which we should repeat the (hard) text confounding examples during training."
        )
        parser.add_argument(
            '--object_conf_thresh',
            type=float,
            default=defaults.get('object_conf_thresh', 0.0),
            help=
            "Confidence threshold for object bounding boxes. Boxes with lower confidence are ignored."
        )
        parser.add_argument(
            '--num_folds',
            type=int,
            default=defaults.get('num_folds', 0),
            help=
            'Number of folds to use during training. 0 means the default split, -1 means all splits'
        )
        parser.add_argument(
            '--crossval_dev_size',
            type=int,
            default=defaults.get('crossval_dev_size', 300),
            help=
            'Size of the development folds used in cross validation. Default: 300 (150 positive, 150 negative).'
        )
        parser.add_argument(
            '--crossval_use_dev',
            action="store_true",
            help=
            'If selected, the dev_seen set is incorporated into the cross validation splits.'
        )
        parser.add_argument('--beta1',
                            type=float,
                            default=defaults.get('beta1', 0.9),
                            help='beta1 parameter in Adam optimizer')
        parser.add_argument('--beta2',
                            type=float,
                            default=defaults.get('beta2', 0.999),
                            help='beta2 parameter in Adam optimizer')
        parser.add_argument('--batch_size',
                            type=int,
                            default=defaults.get('batch_size', 8),
                            help='batch size for training')
        parser.add_argument('--num_workers',
                            type=int,
                            default=defaults.get('num_workers', 0),
                            help='Number of workers to start per dataset')
        parser.add_argument(
            '--gradient_accumulation',
            type=int,
            default=defaults.get('gradient_accumulation', 1),
            help=
            'No. of batches over which to accumulate gradients before performing an optimizer step'
        )
        parser.add_argument('--max_grad_norm',
                            type=int,
                            default=defaults.get('max_grad_norm', 5),
                            help='max gradient norm for gradient clipping')
        parser.add_argument(
            '--pos_wt',
            type=float,
            default=defaults.get('pos_wt', 1),
            help=
            'Loss reweighting for the positive class to deal with class imbalance'
        )
        parser.add_argument('--lr',
                            type=float,
                            default=defaults.get('lr', 1e-4),
                            help='Learning rate for training')
        parser.add_argument(
            '--warmup_steps',
            type=int,
            default=defaults.get('warmup_steps', 50),
            help='No. of steps to perform linear lr warmup for')
        parser.add_argument('--weight_decay',
                            type=float,
                            default=defaults.get('weight_decay', 1e-3),
                            help='weight decay for optimizer')
        parser.add_argument('--max_epoch',
                            type=int,
                            default=defaults.get('max_epoch', 20),
                            help='Max epochs to train for')
        parser.add_argument(
            '--lr_decay_step',
            type=float,
            default=defaults.get('lr_decay_step', 3),
            help='No. of epochs after which learning rate should be decreased')
        parser.add_argument(
            '--lr_decay_factor',
            type=float,
            default=defaults.get('lr_decay_factor', 0.8),
            help=
            'Decay the learning rate of the optimizer by this multiplicative amount'
        )
        parser.add_argument('--patience',
                            type=float,
                            default=defaults.get('patience', 5),
                            help='Patience no. of epochs for early stopping')
        parser.add_argument('--early_stop_thresh',
                            type=float,
                            default=defaults.get('early_stop_thresh', 1e-3),
                            help='Minimum improvement required to reset the early-stopping patience counter')
        parser.add_argument('--seed',
                            type=int,
                            default=defaults.get('seed', 42),
                            help='set seed for reproducibility')
        parser.add_argument(
            '--log_every',
            type=int,
            default=defaults.get('log_every', 2000),
            help=
            'Log stats in Tensorboard every x iterations (not epochs) of training'
        )

        # Options params
        parser.add_argument('--parallel_computing',
                            type=bool,
                            default=defaults.get('parallel_computing', False),
                            help='To run the model on multiple GPUs')

    @staticmethod
    def preprocess_args(config):
        config['device'] = get_device()
        config['n_classes'] = 2 if config['loss_func'] == 'ce' else 1

        # Check all provided paths:
        if not os.path.exists(config['data_path']):
            raise ValueError("[!] ERROR: Dataset path does not exist")
        else:
            LOGGER.info("Data path checked..")
        if not os.path.exists(config['model_path']):
            LOGGER.warning(
                "Creating checkpoint path for saved models at:  {}\n".format(
                    config['model_path']))
            os.makedirs(config['model_path'])
        else:
            LOGGER.info("Model save path checked..")
        if 'config' in config:
            if not os.path.isfile(config['config']):
                raise ValueError("[!] ERROR: config JSON path does not exist")
            else:
                LOGGER.info("config JSON path checked..")
        if not os.path.exists(config['vis_path']):
            LOGGER.warning(
                "Creating checkpoint path for Tensorboard visualizations at:  {}\n"
                .format(config['vis_path']))
            os.makedirs(config['vis_path'])
        else:
            LOGGER.info("Tensorboard Visualization path checked..")
            LOGGER.info(
                "Cleaning Visualization path of older tensorboard files...\n")
            # shutil.rmtree(config['vis_path'])

        # Print args
        print("\n" + "x" * 50 +
              "\n\nRunning training with the following parameters: \n")
        for key, value in config.items():
            if not key.endswith('transf'):
                print(key + ' : ' + str(value))
        print("\n" + "x" * 50)

        # config['vis_path'] = os.path.join(config['vis_path'], '{}_conf{}'.format(config['pretrained_model_file'], config['confounder_repeat']))
        config['writer'] = SummaryWriter(config['vis_path'])

        set_seed(config['seed'])
        return config
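
TrainerTemplate leaves model construction, checkpoint loading, and the per-iteration train/eval/test steps abstract. As a sketch only, a hypothetical subclass that plugs a single linear classifier over pre-extracted features into the template (the config key 'feat_dim' and the batch keys 'feats'/'labels' are invented for this example and are not names used by the repository) could look like:

import torch
import torch.nn as nn

class SimpleTrainer(TrainerTemplate):
    def init_model(self):
        # single logit per example, matching the default 'bce_logits' loss
        self.model = nn.Linear(self.config['feat_dim'], 1)

    def load_model(self):
        self.model.load_state_dict(
            torch.load(self.model_file, map_location=self.device))

    def train_iter_step(self):
        logits = self.model(self.batch['feats'])
        self.calculate_loss(logits, self.batch['labels'], grad_step=True)

    def eval_iter_step(self, iters, batch, test):
        logits = self.model(batch['feats'])
        self.calculate_loss(logits, batch['labels'], grad_step=False)

    def test_iter_step(self, batch):
        # return raw logits; export_test_predictions applies sigmoid for
        # the 'bce_logits' loss before thresholding
        return self.model(batch['feats'])
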
Example no. 9
def main(opts):
    hvd.init()
    n_gpu = hvd.size()
    device = torch.device("cuda", hvd.local_rank())
    torch.cuda.set_device(hvd.local_rank())
    rank = hvd.rank()
    opts.rank = rank
    LOGGER.info("device: {} n_gpu: {}, rank: {}, "
                "16-bits training: {}".format(
                    device, n_gpu, hvd.rank(), opts.fp16))

    if opts.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, "
                         "should be >= 1".format(
                            opts.gradient_accumulation_steps))

    set_random_seed(opts.seed)

    # load DBs and image dirs
    all_img_dbs = ImageLmdbGroup(opts.conf_th, opts.max_bb, opts.min_bb,
                                 opts.num_bb, opts.compressed_db)
    # train
    LOGGER.info(f"Loading Train Dataset "
                f"{opts.train_txt_dbs}, {opts.train_img_dbs}")
    train_datasets = []
    for txt_path, img_path in zip(opts.train_txt_dbs, opts.train_img_dbs):
        img_db, img_db_gt = load_img_feat(img_path, all_img_dbs, opts)
        qa_txt_db = VcrTxtTokLmdb(txt_path, opts.max_txt_len, task="qa")
        qar_txt_db = VcrTxtTokLmdb(txt_path, opts.max_txt_len, task="qar")
        train_datasets.append(
            VcrDataset(qa_txt_db, img_db_gt=img_db_gt, img_db=img_db))
        train_datasets.append(
            VcrDataset(qar_txt_db, img_db_gt=img_db_gt, img_db=img_db))
    train_dataset = ConcatDatasetWithLens(train_datasets)
    train_dataloader = build_dataloader(train_dataset, vcr_collate, True, opts)
    # val
    LOGGER.info(f"Loading Val Dataset {opts.val_txt_db}, {opts.val_img_db}")
    val_img_db, val_img_db_gt = load_img_feat(opts.val_img_db, all_img_dbs, opts)
    val_txt_db = VcrTxtTokLmdb(opts.val_txt_db, -1, task="qa")
    val_dataset = VcrEvalDataset(
        "val", val_txt_db, img_db=val_img_db, img_db_gt=val_img_db_gt)
    val_final_dataset = VcrEvalDataset(
        ##"test"
        "val", val_txt_db, img_db=val_img_db, img_db_gt=val_img_db_gt)
    val_dataloader = build_dataloader(val_dataset, vcr_eval_collate,
                                      False, opts)
    val_final_dataloader = build_dataloader(
        val_final_dataset, vcr_eval_collate,
        False, opts)

    # Prepare model
    if opts.checkpoint and opts.checkpoint_from == "pretrain":
        checkpoint = torch.load(opts.checkpoint)
    else:
        checkpoint = {}

    all_dbs = opts.train_txt_dbs + [opts.val_txt_db]
    toker = json.load(open(f'{all_dbs[0]}/meta.json'))['bert']
    assert all(toker == json.load(open(f'{db}/meta.json'))['bert']
               for db in all_dbs)
    model = UniterForVisualCommonsenseReasoning.from_pretrained(
        opts.model_config, checkpoint, img_dim=IMG_DIM)
    model.init_type_embedding()
    model.init_type_embedding_know()
    model.init_word_embedding(NUM_SPECIAL_TOKENS)
    if opts.checkpoint_from == "vcr_pretrain":
        checkpoint = torch.load(opts.checkpoint)
        state_dict = checkpoint.get('model_state', checkpoint)
        matched_state_dict = {}
        unexpected_keys = set()
        missing_keys = set()
        for name, param in model.named_parameters():
            missing_keys.add(name)
        for key, data in state_dict.items():
            if key in missing_keys:
                matched_state_dict[key] = data
                missing_keys.remove(key)
            else:
                unexpected_keys.add(key)
        print("Unexpected_keys:", list(unexpected_keys))
        print("Missing_keys:", list(missing_keys))
        model.load_state_dict(matched_state_dict, strict=False)
    del checkpoint
    model.to(device)
    # make sure every process has same model parameters in the beginning
    broadcast_tensors([p.data for p in model.parameters()], 0)
    set_dropout(model, opts.dropout)

    # Prepare optimizer
    optimizer = build_optimizer(model, opts)
    model, optimizer = amp.initialize(model, optimizer,
                                      enabled=opts.fp16, opt_level='O2')
    global_step = 0
    if rank == 0:
        save_training_meta(opts)
        TB_LOGGER.create(join(opts.output_dir, 'log'))
        pbar = tqdm(total=opts.num_train_steps)
        model_saver = ModelSaver(join(opts.output_dir, 'ckpt'))
        os.makedirs(join(opts.output_dir, 'results'))  # store VCR predictions
        add_log_to_file(join(opts.output_dir, 'log', 'log.txt'))
    else:
        LOGGER.disabled = True
        pbar = NoOp()
        model_saver = NoOp()

    LOGGER.info(f"***** Running training with {n_gpu} GPUs *****")
    LOGGER.info("  Num examples = %d", len(train_dataset) * hvd.size())
    LOGGER.info("  Batch size = %d", opts.train_batch_size)
    LOGGER.info("  Accumulate steps = %d", opts.gradient_accumulation_steps)
    LOGGER.info("  Num steps = %d", opts.num_train_steps)

    running_loss = RunningMeter('loss')
    model.train()
    n_examples = 0
    n_epoch = 0
    start = time()
    # quick hack for amp delay_unscale bug
    optimizer.zero_grad()
    optimizer.step()
    while True:
        for step, batch in enumerate(train_dataloader):
            n_examples += batch['input_ids'].size(0)

            loss = model(batch, compute_loss=True)
            loss = loss.mean()
            delay_unscale = (step+1) % opts.gradient_accumulation_steps != 0
            with amp.scale_loss(loss, optimizer, delay_unscale=delay_unscale
                                ) as scaled_loss:
                scaled_loss.backward()
                if not delay_unscale:
                    # gather gradients from all processes
                    # do this before unscaling to make sure every process uses
                    # the same gradient scale
                    grads = [p.grad.data for p in model.parameters()
                             if p.requires_grad and p.grad is not None]
                    all_reduce_and_rescale_tensors(grads, float(1))

            running_loss(loss.item())

            if (step + 1) % opts.gradient_accumulation_steps == 0:
                global_step += 1

                # learning rate scheduling
                lr_this_step = get_lr_sched(global_step, opts)
                for i, param_group in enumerate(optimizer.param_groups):
                    if i == 0 or i == 1:
                        param_group['lr'] = lr_this_step * opts.lr_mul
                    elif i == 2 or i == 3:
                        param_group['lr'] = lr_this_step
                    else:
                        raise ValueError()
                TB_LOGGER.add_scalar('lr', lr_this_step, global_step)

                # log loss
                # NOTE: not gathered across GPUs for efficiency
                TB_LOGGER.add_scalar('loss', running_loss.val, global_step)
                TB_LOGGER.step()

                # update model params
                if opts.grad_norm != -1:
                    grad_norm = clip_grad_norm_(amp.master_params(optimizer),
                                                opts.grad_norm)
                    TB_LOGGER.add_scalar('grad_norm', grad_norm, global_step)
                optimizer.step()
                optimizer.zero_grad()
                pbar.update(1)

                if global_step % 100 == 0:
                    # monitor training throughput
                    LOGGER.info(f'============Step {global_step}=============')
                    tot_ex = sum(all_gather_list(n_examples))
                    ex_per_sec = int(tot_ex / (time()-start))
                    LOGGER.info(f'{tot_ex} examples trained at '
                                f'{ex_per_sec} ex/s')
                    TB_LOGGER.add_scalar('perf/ex_per_s',
                                         ex_per_sec, global_step)
                    LOGGER.info('===========================================')

                if global_step % opts.valid_steps == 0:
                    val_log, results = validate(
                        model, val_dataloader)
                    TB_LOGGER.log_scaler_dict(val_log)
                    model_saver.save(model, global_step)
            if global_step >= opts.num_train_steps:
                break
        if global_step >= opts.num_train_steps:
            break
        n_epoch += 1
        LOGGER.info(f"finished {n_epoch} epochs")
    if global_step % opts.valid_steps != 0:
        val_log, results = validate(
            model, val_dataloader)
        TB_LOGGER.log_scaler_dict(val_log)
    val_log, results = validate(model, val_final_dataloader)
    with open(f'{opts.output_dir}/results/'
              f'results_{global_step}_final_qa_qar_'
              f'rank{rank}.json', 'w') as f:
        json.dump(results, f)
    TB_LOGGER.log_scaler_dict(val_log)
    model_saver.save(model, global_step)
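
get_lr_sched is used in all of these training loops but is not part of the listing. A plausible sketch, assuming a linear warmup to opts.learning_rate followed by a linear decay to zero at opts.num_train_steps (the attribute names learning_rate and warmup_steps are assumptions based on how the function is called), would be:

def get_lr_sched(global_step, opts):
    # sketch only; the repository's actual schedule may differ
    if global_step < opts.warmup_steps:
        # linear warmup from 0 to the peak learning rate
        return opts.learning_rate * global_step / opts.warmup_steps
    # linear decay from the peak learning rate down to zero
    decay_steps = max(1, opts.num_train_steps - opts.warmup_steps)
    remaining = opts.num_train_steps - global_step
    return max(0.0, opts.learning_rate * remaining / decay_steps)
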
Example no. 10
def main(opts):
    hvd.init()
    n_gpu = hvd.size()
    device = torch.device("cuda", hvd.local_rank())
    torch.cuda.set_device(hvd.local_rank())
    rank = hvd.rank()
    opts.rank = rank
    LOGGER.info("device: {} n_gpu: {}, rank: {}, "
                "16-bits training: {}".format(device, n_gpu, hvd.rank(),
                                              opts.fp16))

    if opts.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, "
                         "should be >= 1".format(
                             opts.gradient_accumulation_steps))

    set_random_seed(opts.seed)

    # load DBs and image dirs
    all_img_dbs = ImageLmdbGroup(opts.conf_th, opts.max_bb, opts.min_bb,
                                 opts.num_bb, opts.compressed_db)
    # train
    LOGGER.info(f"Loading Train Dataset "
                f"{opts.train_txt_dbs}, {opts.train_img_dbs}")
    train_datasets = []
    for txt_path, img_path in zip(opts.train_txt_dbs, opts.train_img_dbs):
        img_db, img_db_gt = load_img_feat(img_path, all_img_dbs, opts)
        qa_txt_db = VcrTxtTokLmdb(txt_path, opts.max_txt_len, task="qa")
        qar_txt_db = VcrTxtTokLmdb(txt_path, opts.max_txt_len, task="qar")
        train_datasets.append(
            VcrDataset(qa_txt_db, img_db_gt=img_db_gt, img_db=img_db))
        train_datasets.append(
            VcrDataset(qar_txt_db, img_db_gt=img_db_gt, img_db=img_db))
    train_dataset = ConcatDatasetWithLens(train_datasets)
    train_dataloader = build_dataloader(train_dataset, vcr_collate, True, opts)
    # val
    LOGGER.info(f"Loading Val Dataset {opts.val_txt_db}, {opts.val_img_db}")
    val_img_db, val_img_db_gt = load_img_feat(opts.val_img_db, all_img_dbs,
                                              opts)
    val_txt_db = VcrTxtTokLmdb(opts.val_txt_db, -1)
    val_dataset = VcrEvalDataset("val",
                                 val_txt_db,
                                 img_db=val_img_db,
                                 img_db_gt=val_img_db_gt)
    val_final_dataset = VcrEvalDataset("test",
                                       val_txt_db,
                                       img_db=val_img_db,
                                       img_db_gt=val_img_db_gt)
    val_dataloader = build_dataloader(val_dataset, vcr_eval_collate, False,
                                      opts)
    val_final_dataloader = build_dataloader(val_final_dataset,
                                            vcr_eval_collate, False, opts)

    # Prepare model
    if opts.checkpoint and opts.checkpoint_from == "pretrain":
        ckpt = torch.load(opts.checkpoint)
        checkpoint = {k.replace('bert', 'uniter'): v for k, v in ckpt.items()}
    else:
        checkpoint = {}

    all_dbs = opts.train_txt_dbs + [opts.val_txt_db]
    toker = json.load(open(f'{all_dbs[0]}/meta.json'))['bert']
    assert all(toker == json.load(open(f'{db}/meta.json'))['bert']
               for db in all_dbs)
    model = UniterForVisualCommonsenseReasoning.from_pretrained(
        opts.model_config, checkpoint, img_dim=IMG_DIM)
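    # VCR adds extra token-type ids and a few new special tokens on top of the
    # pretrained UNITER vocabulary, so their embedding tables are initialized here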
    model.init_type_embedding()
    model.init_word_embedding(NUM_SPECIAL_TOKENS)
    if opts.checkpoint_from == "vcr_pretrain":
        ckpt = torch.load(opts.checkpoint)
        checkpoint = {k.replace('bert', 'uniter'): v for k, v in ckpt.items()}
        state_dict = checkpoint.get('model_state', checkpoint)
        matched_state_dict = {}
        unexpected_keys = set()
        missing_keys = set()
        for name, param in model.named_parameters():
            missing_keys.add(name)
        for key, data in state_dict.items():
            if key in missing_keys:
                matched_state_dict[key] = data
                missing_keys.remove(key)
            else:
                unexpected_keys.add(key)
        print("Unexpected_keys:", list(unexpected_keys))
        print("Missing_keys:", list(missing_keys))
        model.load_state_dict(matched_state_dict, strict=False)
    del checkpoint
    model.to(device)
    # make sure every process has same model parameters in the beginning
    broadcast_tensors([p.data for p in model.parameters()], 0)
    set_dropout(model, opts.dropout)

    # Prepare optimizer
    optimizer = build_optimizer(model, opts)
    model, optimizer = amp.initialize(model,
                                      optimizer,
                                      enabled=opts.fp16,
                                      opt_level='O2')
    global_step = 0
    if rank == 0:
        save_training_meta(opts)
        TB_LOGGER.create(join(opts.output_dir, 'log'))
        pbar = tqdm(total=opts.num_train_steps)
        model_saver = ModelSaver(join(opts.output_dir, 'ckpt'))
        os.makedirs(join(opts.output_dir, 'results'))  # store VCR predictions
        add_log_to_file(join(opts.output_dir, 'log', 'log.txt'))
    else:
        LOGGER.disabled = True
        pbar = NoOp()
        model_saver = NoOp()

    LOGGER.info(f"***** Running training with {n_gpu} GPUs *****")
    LOGGER.info("  Num examples = %d", len(train_dataset) * hvd.size())
    LOGGER.info("  Batch size = %d", opts.train_batch_size)
    LOGGER.info("  Accumulate steps = %d", opts.gradient_accumulation_steps)
    LOGGER.info("  Num steps = %d", opts.num_train_steps)

    running_loss = RunningMeter('loss')
    model.train()
    n_examples = 0
    n_epoch = 0
    start = time()
    # quick hack for amp delay_unscale bug
    optimizer.zero_grad()
    optimizer.step()
    while True:
        for step, batch in enumerate(train_dataloader):
            n_examples += batch['input_ids'].size(0)

            # ============= Code for adversarial training =============
            if opts.adv_training:
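                # Adversarial training in embedding space (as in VILLA): learn additive
                # perturbations (deltas) on the text and/or image embeddings, and train
                # on the cross-entropy loss for the perturbed inputs plus a symmetric KL
                # term that keeps the perturbed predictions close to the clean ones.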
                # initialize delta
                txt_embeds_init = model.uniter.embeddings.word_embeddings(
                    batch['input_ids'])
                img_embeds_init = batch['img_feat']

                # for simplicity, we initialize the deltas as zero vectors, which performs
                # very similarly to initializing them randomly from normal or uniform distributions
                txt_delta = torch.zeros_like(txt_embeds_init)
                img_delta = torch.zeros_like(img_embeds_init)

                # calculate the prob. scores for clean samples
                gt_answer_scores = model(batch, compute_loss=False)
                gt_answer_prob = F.softmax(gt_answer_scores, dim=1)
                gt_answer_logprob = F.log_softmax(gt_answer_scores, dim=1)

                # the main loop
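                # each adversarial step: forward with the current deltas, backprop
                # 1/adv_steps of the loss into the model parameters, then take a
                # normalized gradient-ascent step on the deltas themselves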
                for astep in range(opts.adv_steps):
                    # (0) forward
                    if opts.adv_modality == ["text"]:
                        txt_delta.requires_grad_()
                        img_delta = torch.zeros_like(img_embeds_init)
                    elif opts.adv_modality == ["image"]:
                        img_delta.requires_grad_()
                        txt_delta = torch.zeros_like(txt_embeds_init)
                    else:
                        txt_delta.requires_grad_()
                        img_delta.requires_grad_()

                    if "alter" not in opts.adv_modality:
                        answer_scores = model(batch,
                                              adv_training=True,
                                              adv_modality=opts.adv_modality,
                                              adv_delta_txt=txt_delta,
                                              adv_delta_img=img_delta,
                                              compute_loss=False)

                        # CE loss
                        ce_loss = F.cross_entropy(answer_scores,
                                                  batch['targets'].squeeze(-1),
                                                  reduction='mean')

                        # KL loss
                        answer_prob = F.softmax(answer_scores, dim=1)
                        answer_logprob = F.log_softmax(answer_scores, dim=1)
                        kl_loss = F.kl_div(
                            answer_logprob, gt_answer_prob, reduction='none') + \
                            F.kl_div(
                                gt_answer_logprob, answer_prob,
                                reduction='none')
                        kl_loss = kl_loss.mean()

                        # (1) backward
                        loss = (ce_loss +
                                opts.adv_kl_weight * kl_loss) / opts.adv_steps
                    else:
                        answer_scores_1 = model(batch,
                                                adv_training=True,
                                                adv_modality=["text"],
                                                adv_delta_txt=txt_delta,
                                                adv_delta_img=None,
                                                compute_loss=False)

                        # CE loss
                        ce_loss_1 = F.cross_entropy(
                            answer_scores_1,
                            batch['targets'].squeeze(-1),
                            reduction='mean')

                        answer_scores_2 = model(batch,
                                                adv_training=True,
                                                adv_modality=["image"],
                                                adv_delta_txt=None,
                                                adv_delta_img=img_delta,
                                                compute_loss=False)

                        # CE loss
                        ce_loss_2 = F.cross_entropy(
                            answer_scores_2,
                            batch['targets'].squeeze(-1),
                            reduction='mean')

                        # KL loss
                        answer_prob_1 = F.softmax(answer_scores_1, dim=1)
                        answer_logprob_1 = F.log_softmax(answer_scores_1,
                                                         dim=1)
                        answer_prob_2 = F.softmax(answer_scores_2, dim=1)
                        answer_logprob_2 = F.log_softmax(answer_scores_2,
                                                         dim=1)
                        kl_loss_1 = F.kl_div(
                            answer_logprob_1, gt_answer_prob, reduction='none') + \
                            F.kl_div(
                                gt_answer_logprob, answer_prob_1,
                                reduction='none')
                        kl_loss_1 = kl_loss_1.mean()
                        kl_loss_2 = F.kl_div(
                            answer_logprob_2, gt_answer_prob, reduction='none') + \
                            F.kl_div(
                                gt_answer_logprob, answer_prob_2,
                                reduction='none')
                        kl_loss_2 = kl_loss_2.mean()

                        # (1) backward
                        loss = (ce_loss_1 + ce_loss_2 + opts.adv_kl_weight *
                                (kl_loss_1 + kl_loss_2)) / (opts.adv_steps * 2)

                    delay_unscale = (
                        (step + 1) % opts.gradient_accumulation_steps !=
                        0) or ((astep + 1) % opts.adv_steps != 0)
                    with amp.scale_loss(
                            loss, optimizer,
                            delay_unscale=delay_unscale) as scaled_loss:
                        scaled_loss.backward(retain_graph=True)
                        if not delay_unscale:
                            # gather gradients from every process
                            # do this before unscaling
                            # to make sure every process uses
                            # the same gradient scale
                            grads = [
                                p.grad.data for p in model.parameters()
                                if p.requires_grad and p.grad is not None
                            ]
                            all_reduce_and_rescale_tensors(grads, float(1))

                    running_loss(loss.item())

                    if astep == opts.adv_steps - 1:
                        # no further updates on delta are needed
                        break

                    # (2) get gradient on delta
                    # fix fp16 problem
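                    # recover the amp loss scale as the ratio of scaled to unscaled
                    # loss; amp only unscales parameter gradients, so the delta
                    # gradients have to be divided by the scale manually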
                    amp_scale = scaled_loss.item() // loss.item()
                    if "text" in opts.adv_modality:
                        txt_delta_grad = txt_delta.grad.clone().detach()
                        txt_delta_grad = txt_delta_grad.float() / amp_scale
                    if "image" in opts.adv_modality:
                        img_delta_grad = img_delta.grad.clone().detach()
                        img_delta_grad = img_delta_grad.float() / amp_scale

                    # (3) update and clip for txt delta
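                    # PGD-style update: step along the delta gradient normalized by its
                    # L2 (or L-inf) norm, then project the delta back into a ball of
                    # radius adv_max_norm when adv_max_norm > 0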
                    if "text" in opts.adv_modality:
                        if opts.norm_type == "l2":
                            denorm = torch.norm(txt_delta_grad.view(
                                txt_delta_grad.size(0), -1),
                                                dim=1).view(-1, 1, 1)
                            denorm = torch.clamp(denorm, min=1e-8)
                            txt_delta_step = (opts.adv_lr_txt *
                                              txt_delta_grad /
                                              denorm).to(txt_delta)
                            txt_delta = (txt_delta + txt_delta_step).detach()
                            if opts.adv_max_norm > 0:
                                delta_norm = torch.norm(txt_delta.view(
                                    txt_delta.size(0), -1),
                                                        p=2,
                                                        dim=1).detach()
                                exceed_mask = (delta_norm > opts.adv_max_norm
                                               ).to(txt_embeds_init)
                                reweights = (opts.adv_max_norm / delta_norm *
                                             exceed_mask +
                                             (1 - exceed_mask)).view(-1, 1, 1)
                                txt_delta = (txt_delta * reweights).detach()
                        elif opts.norm_type == "linf":
                            denorm = torch.norm(txt_delta_grad.view(
                                txt_delta_grad.size(0), -1),
                                                dim=1,
                                                p=float("inf")).view(-1, 1, 1)
                            denorm = torch.clamp(denorm, min=1e-8)
                            txt_delta_step = (opts.adv_lr_txt *
                                              txt_delta_grad /
                                              denorm).to(txt_delta)
                            txt_delta = (txt_delta + txt_delta_step).detach()
                            if opts.adv_max_norm > 0:
                                txt_delta = torch.clamp(
                                    txt_delta, -opts.adv_max_norm,
                                    opts.adv_max_norm).detach()

                    # (4) update and clip for image delta
                    if "image" in opts.adv_modality:
                        if opts.norm_type == "l2":
                            denorm = torch.norm(img_delta_grad.view(
                                img_delta_grad.size(0), -1),
                                                dim=1).view(-1, 1, 1)
                            denorm = torch.clamp(denorm, min=1e-8)
                            img_delta_step = (opts.adv_lr_img *
                                              img_delta_grad /
                                              denorm).to(img_delta)
                            img_delta = (img_delta + img_delta_step).detach()
                            if opts.adv_max_norm > 0:
                                delta_norm = torch.norm(img_delta.view(
                                    img_delta.size(0), -1),
                                                        p=2,
                                                        dim=1).detach()
                                exceed_mask = (delta_norm > opts.adv_max_norm
                                               ).to(img_embeds_init)
                                reweights = (opts.adv_max_norm / delta_norm *
                                             exceed_mask +
                                             (1 - exceed_mask)).view(-1, 1, 1)
                                img_delta = (img_delta * reweights).detach()
                        elif opts.norm_type == "linf":
                            denorm = torch.norm(img_delta_grad.view(
                                img_delta_grad.size(0), -1),
                                                dim=1,
                                                p=float("inf")).view(-1, 1, 1)
                            denorm = torch.clamp(denorm, min=1e-8)
                            img_delta_step = (opts.adv_lr_img *
                                              img_delta_grad /
                                              denorm).to(img_delta)
                            img_delta = (img_delta + img_delta_step).detach()
                            if opts.adv_max_norm > 0:
                                img_delta = torch.clamp(
                                    img_delta, -opts.adv_max_norm,
                                    opts.adv_max_norm).detach()

            else:
                loss = model(batch, compute_loss=True)
                loss = loss.mean()
                delay_unscale = ((step + 1) % opts.gradient_accumulation_steps
                                 != 0)
                with amp.scale_loss(
                        loss, optimizer,
                        delay_unscale=delay_unscale) as scaled_loss:
                    scaled_loss.backward()
                    if not delay_unscale:
                        # gather gradients from every process
                        # do this before unscaling to make sure every process uses
                        # the same gradient scale
                        grads = [
                            p.grad.data for p in model.parameters()
                            if p.requires_grad and p.grad is not None
                        ]
                        all_reduce_and_rescale_tensors(grads, float(1))

                running_loss(loss.item())

            # ============================ End ==========================

            if (step + 1) % opts.gradient_accumulation_steps == 0:
                global_step += 1

                # learning rate scheduling
                lr_this_step = get_lr_sched(global_step, opts)
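                # param groups 0/1 are assumed to hold the newly-initialized VCR head
                # (trained at lr * lr_mul); groups 2/3 hold the pretrained backbone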
                for i, param_group in enumerate(optimizer.param_groups):
                    if i == 0 or i == 1:
                        param_group['lr'] = lr_this_step * opts.lr_mul
                    elif i == 2 or i == 3:
                        param_group['lr'] = lr_this_step
                    else:
                        raise ValueError()
                TB_LOGGER.add_scalar('lr', lr_this_step, global_step)

                # log loss
                # NOTE: not gathered across GPUs for efficiency
                TB_LOGGER.add_scalar('loss', running_loss.val, global_step)
                TB_LOGGER.step()

                # update model params
                if opts.grad_norm != -1:
                    grad_norm = clip_grad_norm_(amp.master_params(optimizer),
                                                opts.grad_norm)
                    TB_LOGGER.add_scalar('grad_norm', grad_norm, global_step)
                optimizer.step()
                optimizer.zero_grad()
                pbar.update(1)

                if global_step % 100 == 0:
                    # monitor training throughput
                    LOGGER.info(f'============Step {global_step}=============')
                    tot_ex = sum(all_gather_list(n_examples))
                    ex_per_sec = int(tot_ex / (time() - start))
                    LOGGER.info(f'{tot_ex} examples trained at '
                                f'{ex_per_sec} ex/s')
                    TB_LOGGER.add_scalar('perf/ex_per_s', ex_per_sec,
                                         global_step)
                    LOGGER.info('===========================================')

                if global_step % opts.valid_steps == 0:
                    val_log, results = validate(model, val_dataloader)
                    TB_LOGGER.log_scaler_dict(val_log)
                    model_saver.save(model, global_step)
            if global_step >= opts.num_train_steps:
                break
        if global_step >= opts.num_train_steps:
            break
        n_epoch += 1
        LOGGER.info(f"finished {n_epoch} epochs")
    if global_step % opts.valid_steps != 0:
        val_log, results = validate(model, val_dataloader)
        TB_LOGGER.log_scaler_dict(val_log)
    val_log, results = validate(model, val_final_dataloader)
    with open(
            f'{opts.output_dir}/results/'
            f'results_{global_step}_final_qa_qar_'
            f'rank{rank}.json', 'w') as f:
        json.dump(results, f)
    TB_LOGGER.log_scaler_dict(val_log)
    model_saver.save(model, global_step)
Example No. 11
def main(opts):
    hvd.init()
    n_gpu = hvd.size()
    device = torch.device("cuda", hvd.local_rank())
    torch.cuda.set_device(hvd.local_rank())
    opts.n_gpu = n_gpu
    LOGGER.info("device: {} n_gpu: {}, rank: {}, "
                "16-bits training: {}".format(device, n_gpu, hvd.rank(),
                                              opts.fp16))
    if hvd.rank() != 0:
        LOGGER.disabled = True

    set_random_seed(opts.seed)

    # data loaders
    train_dataloaders = {}
    val_dataloaders = {}
    for target, t_r in zip(opts.targets, opts.targets_ratio):
        train_loaders, val_loaders = build_target_loaders(
            target, t_r,
            opts)  # -> choose which task and get corresponding task dataloader
        train_dataloaders.update(train_loaders)
        val_dataloaders.update(val_loaders)
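    # MetaLoader interleaves batches from the per-task loaders (respecting the
    # gradient-accumulation steps); PrefetchLoader copies the next batch to GPU
    # asynchronously while the current one is being processed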
    meta_loader = MetaLoader(train_dataloaders,
                             accum_steps=opts.gradient_accumulation_steps,
                             distributed=n_gpu > 1)
    meta_loader = PrefetchLoader(meta_loader)

    # Prepare model
    if opts.checkpoint:
        checkpoint = torch.load(opts.checkpoint)
    else:
        checkpoint = {}
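    # infer the max video-frame sequence length from the checkpoint's frame
    # position-embedding table if present, otherwise fall back to the default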
    img_pos_embed_weight_key = "v_encoder.f_encoder.img_embeddings" +\
        ".position_embeddings.weight"
    if img_pos_embed_weight_key in checkpoint:
        max_frm_seq_len = len(checkpoint[img_pos_embed_weight_key])
    else:
        max_frm_seq_len = MAX_FRM_SEQ_LEN

    if opts.load_partial_pretrained:
        # from roberta
        model = HeroForPretraining(VideoModelConfig(opts.model_config),
                                   vfeat_dim=VFEAT_DIM,
                                   max_frm_seq_len=max_frm_seq_len,
                                   lw_neg_ctx=opts.lw_neg_ctx,
                                   lw_neg_q=opts.lw_neg_q,
                                   lw_st_ed=0,
                                   ranking_loss_type=opts.ranking_loss_type,
                                   use_hard_negative=False,
                                   hard_pool_size=opts.hard_pool_size,
                                   margin=opts.margin,
                                   use_all_neg=opts.use_all_neg,
                                   drop_svmr_prob=opts.drop_svmr_prob)
        model.load_partial_pretrained(checkpoint,
                                      VFEAT_DIM,
                                      max_frm_seq_len,
                                      skip_layers=opts.skip_layer_loading)
    else:
        # continue training
        model = HeroForPretraining.from_pretrained(
            opts.model_config,
            state_dict=checkpoint,
            vfeat_dim=VFEAT_DIM,
            max_frm_seq_len=max_frm_seq_len,
            lw_neg_ctx=opts.lw_neg_ctx,
            lw_neg_q=opts.lw_neg_q,
            lw_st_ed=0,
            ranking_loss_type=opts.ranking_loss_type,
            use_hard_negative=False,
            hard_pool_size=opts.hard_pool_size,
            margin=opts.margin,
            use_all_neg=opts.use_all_neg,
            drop_svmr_prob=opts.drop_svmr_prob)

    model.to(device)
    # make sure every process has same model parameters in the beginning
    broadcast_tensors([p.data for p in model.parameters()], 0)
    set_dropout(model, opts.dropout)

    # Prepare optimizer
    optimizer = build_optimizer(model, opts)
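    # use one amp loss scaler per pretraining task (selected via loss_id when
    # scaling the loss below) so the tasks' dynamic loss scales stay independent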
    task2scaler = {t: i for i, t in enumerate(train_dataloaders.keys())}
    model, optimizer = amp.initialize(model,
                                      optimizer,
                                      num_losses=len(task2scaler),
                                      enabled=opts.fp16,
                                      opt_level='O2')
    restorer = TrainingRestorer(opts, model, optimizer)
    all_gather_list(None)  # sync to prevent slower ranks from reading training meta
    global_step = restorer.global_step
    TB_LOGGER.global_step = global_step
    if hvd.rank() == 0:
        save_training_meta(opts)
        TB_LOGGER.create(join(opts.output_dir, 'log'))
        pbar = tqdm(total=opts.num_train_steps)
        model_saver = ModelSaver(join(opts.output_dir, 'ckpt'))
        add_log_to_file(join(opts.output_dir, 'log', 'log.txt'))
    else:
        pbar = NoOp()
        model_saver = NoOp()
        restorer = NoOp()

    if global_step > 0:
        pbar.update(global_step)
    LOGGER.info(f"***** Running training with {n_gpu} GPUs *****")
    LOGGER.info("  Batch size = %d", opts.train_batch_size)
    LOGGER.info("  Accumulate steps = %d", opts.gradient_accumulation_steps)
    LOGGER.info("  Num steps = %d", opts.num_train_steps)

    task2loss = {
        task: RunningMeter(f'loss/{task}')
        for task in train_dataloaders.keys()
    }
    for task in train_dataloaders.keys():
        if task.startswith('vsm'):
            for obj in ('st_ed', 'neg_ctx', 'neg_q'):
                task2loss[f"{task}_{obj}"] = RunningMeter(f'loss/{task}_{obj}')
    model.train()
    n_examples = defaultdict(int)
    start = time()
    # quick hack for amp delay_unscale bug
    optimizer.zero_grad()
    if global_step == 0:
        optimizer.step()
    assert all(global_step == s for s in all_gather_list(global_step))
    for step, (task, batch) in enumerate(meta_loader):
        LOGGER.debug(f"Task: {task}")

        # hard negative in VSM
        if len(opts.hard_negtiave_start_step) > 0:
            for i, hn_step in enumerate(opts.hard_negtiave_start_step):
                if global_step >= hn_step and hn_step != -1:
                    model.set_hard_negative(True, opts.hard_pool_size[i],
                                            opts.hard_neg_weights[i])

        # start-end loss
        if opts.train_span_start_step != -1 and\
                global_step >= opts.train_span_start_step:
            model.set_train_st_ed(opts.lw_st_ed)

        train_task = task.split('_')[0]
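        # task keys look like '<objective>_<dataset>', so the prefix selects the objective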
        n_examples[task] += opts.train_batch_size

        loss = model(batch, task=train_task, compute_loss=True)
        if train_task == 'vsm':
            loss_st_ed, loss_neg_ctx, loss_neg_q = loss
            loss = loss_st_ed + loss_neg_ctx + loss_neg_q
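            # for logging, divide each VSM sub-loss by its loss weight (when non-zero)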
            for n, ls, w in (('st_ed', loss_st_ed, opts.lw_st_ed),
                             ('neg_ctx', loss_neg_ctx, opts.lw_neg_ctx),
                             ('neg_q', loss_neg_q, opts.lw_neg_q)):
                ls = ls.item()
                if w:
                    ls /= w
                task2loss[f'{task}_{n}'](ls)
        elif train_task == "mffr":
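            # MFFR (masked frame feature regression): presumably per-dimension squared
            # errors, so summing over dim 1 and taking the sqrt gives an L2 distance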
            loss = torch.sqrt(loss.sum(dim=1))

        loss = loss.mean()
        task2loss[task](loss.item())

        delay_unscale = (step + 1) % opts.gradient_accumulation_steps != 0
        with amp.scale_loss(loss,
                            optimizer,
                            delay_unscale=delay_unscale,
                            loss_id=task2scaler[task]) as scaled_loss:
            scaled_loss.backward()
            if not delay_unscale:
                # gather gradients from every process
                # do this before unscaling to make sure every process uses
                # the same gradient scale
                grads = [
                    p.grad.data for p in model.parameters()
                    if p.requires_grad and p.grad is not None
                ]
                LOGGER.debug("before reduce grad")
                all_reduce_and_rescale_tensors(grads, float(1))
                LOGGER.debug("after reduce grad")

        if (step + 1) % opts.gradient_accumulation_steps == 0:
            global_step += 1

            # learning rate scheduling
            lr_this_step = get_lr_sched(global_step, opts)
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr_this_step
            TB_LOGGER.add_scalar('lr', lr_this_step, global_step)

            # log loss
            # NOTE: only consider rank 0 for speed
            TB_LOGGER.log_scaler_dict({
                ll.name: ll.val
                for ll in task2loss.values() if ll.val is not None
            })
            TB_LOGGER.step()

            LOGGER.debug("before norm grad")
            # update model params
            if opts.grad_norm != -1:
                grad_norm = clip_grad_norm_(amp.master_params(optimizer),
                                            opts.grad_norm)
                TB_LOGGER.add_scalar('grad_norm', grad_norm, global_step)
            LOGGER.debug("after norm grad")
            LOGGER.debug("before optim step")
            optimizer.step()
            optimizer.zero_grad()
            pbar.update(1)
            LOGGER.debug("after optim step")

            if global_step % 100 == 0:
                LOGGER.debug("after gather stats")
                # monitor training throughput
                LOGGER.info('-------------------------------------------')
                LOGGER.info(f'Step {global_step}:')
                for t in train_dataloaders.keys():
                    tot_ex = sum(all_gather_list(n_examples[t]))
                    ex_per_sec = int(tot_ex / (time() - start))
                    LOGGER.info(f'{t}: {tot_ex} examples trained at '
                                f'{ex_per_sec} ex/s')
                    TB_LOGGER.add_scalar(f'perf/{t}_ex_per_s', ex_per_sec,
                                         global_step)
                LOGGER.debug("after gather stats")

            if global_step % opts.valid_steps == 0:
                LOGGER.info('===========================================')
                LOGGER.info(f"Step {global_step}: start running validation")
                validate(model, val_dataloaders, opts)
                LOGGER.info('===========================================')
                model_saver.save(model, global_step)

            # step restorer in the end to prevent missing validation checkpoint
            restorer.step()
        if global_step >= opts.num_train_steps:
            break

    LOGGER.info('===========================================')
    if global_step % opts.valid_steps != 0:
        LOGGER.info('===========================================')
        LOGGER.info(f"Step {global_step}: start running validation")
        validate(model, val_dataloaders, opts)
        LOGGER.info('===========================================')
        model_saver.save(model, global_step)
Example No. 12
def main(opts):
    hvd.init()
    n_gpu = hvd.size()
    device = torch.device("cuda", hvd.local_rank())
    torch.cuda.set_device(hvd.local_rank())
    rank = hvd.rank()
    opts.rank = rank
    LOGGER.info("device: {} n_gpu: {}, rank: {}, "
                "16-bits training: {}".format(device, n_gpu, hvd.rank(),
                                              opts.fp16))

    set_random_seed(opts.seed)

    if hvd.rank() == 0:
        save_training_meta(opts)
        TB_LOGGER.create(join(opts.output_dir, "log"))
        pbar = tqdm(total=opts.num_train_steps)
        model_saver = ModelSaver(join(opts.output_dir, "ckpt"))
        add_log_to_file(join(opts.output_dir, "log", "log.txt"))
        # store ITM predictions
        os.makedirs(join(opts.output_dir, "results_val"))
        os.makedirs(join(opts.output_dir, "results_test"))
        os.makedirs(join(opts.output_dir, "results_train"))
    else:
        LOGGER.disabled = True
        pbar = NoOp()
        model_saver = NoOp()

    # train_examples = None
    LOGGER.info(f"Loading Train Dataset {opts.train_txt_dbs}, "
                f"{opts.train_img_dbs}")
    # check multiple DBs
    assert len(opts.train_txt_dbs) == len(
        opts.train_img_dbs), "train txt_db and img_db have different length"

    # load DBs and image dirs
    all_img_dbs = ImageLmdbGroup(opts.conf_th, opts.max_bb, opts.min_bb,
                                 opts.num_bb, opts.compressed_db)
    # train
    LOGGER.info(f"Loading Train Dataset "
                f"{opts.train_txt_dbs}, {opts.train_img_dbs}")
    train_datasets_t = []
    train_datasets_i = []
    for txt_path, img_path in zip(opts.train_txt_dbs, opts.train_img_dbs):
        img_db = all_img_dbs[img_path]
        txt_db = TxtTokLmdb(txt_path, opts.max_txt_len)
        train_datasets_t.append(
            ItmRankDatasetHardNegFromText(txt_db, img_db, opts.negative_size))
        train_datasets_i.append(
            ItmRankDatasetHardNegFromImage(txt_db, img_db, opts.negative_size))
    train_dataset_t = ConcatDataset(train_datasets_t)
    train_dataset_i = ConcatDataset(train_datasets_i)
    train_dataloader_t = build_dataloader(train_dataset_t, itm_rank_hn_collate,
                                          True, opts)
    train_dataloader_i = build_dataloader(train_dataset_i, itm_rank_hn_collate,
                                          True, opts)

    # val
    LOGGER.info(f"Loading Val Dataset {opts.val_txt_db}, {opts.val_img_db}")
    val_img_db = all_img_dbs[opts.val_img_db]
    val_txt_db = TxtTokLmdb(opts.val_txt_db, -1)
    val_dataset = ItmValDataset(val_txt_db, val_img_db,
                                opts.inf_minibatch_size)
    val_dataloader = build_dataloader(val_dataset, itm_val_collate, False,
                                      opts)
    # eval
    LOGGER.info(f"Loading val, test Dataset for full evaluation: "
                f"{opts.val_txt_db}, {opts.val_img_db}"
                f"{opts.test_txt_db}, {opts.test_img_db}")
    eval_dataset_val = ItmEvalDataset(val_txt_db, val_img_db,
                                      opts.inf_minibatch_size)
    eval_loader_val = build_dataloader(eval_dataset_val, itm_eval_collate,
                                       False, opts)
    test_img_db = all_img_dbs[opts.test_img_db]
    test_txt_db = TxtTokLmdb(opts.test_txt_db, -1)
    eval_dataset_test = ItmEvalDataset(test_txt_db, test_img_db,
                                       opts.inf_minibatch_size)
    eval_loader_test = build_dataloader(eval_dataset_test, itm_eval_collate,
                                        False, opts)

    # Prepare model
    if opts.checkpoint:
        checkpoint = torch.load(opts.checkpoint)
    else:
        checkpoint = {}

    model = UniterForImageTextRetrievalHardNeg.from_pretrained(
        opts.model_config,
        state_dict=checkpoint,
        img_dim=IMG_DIM,
        margin=opts.margin,
        hard_size=opts.hard_neg_size,
    )
    model.init_output()  # the pretrained ITM head differs from the ranking head
    model.to(device)
    # make sure every process has same model parameters in the beginning
    broadcast_tensors([p.data for p in model.parameters()], 0)
    set_dropout(model, opts.dropout)

    # Prepare optimizer
    optimizer = build_optimizer(model, opts)
    model, optimizer = amp.initialize(model,
                                      optimizer,
                                      enabled=opts.fp16,
                                      opt_level="O2")

    LOGGER.info(f"***** Running training on {n_gpu} GPUs *****")
    LOGGER.info("  Num examples = %d",
                sum(all_gather_list(len(train_dataset_t))))
    LOGGER.info("  Batch size = %d", opts.train_batch_size)
    LOGGER.info("  Num steps = %d", opts.num_train_steps)

    running_loss = RunningMeter("loss")
    model.train()

    global_step = 0
    step = 0
    n_examples = 0
    n_hard_ex = 0
    start = time()
    train_iter_i = iter(train_dataloader_i)
    # quick hack for amp delay_unscale bug
    optimizer.zero_grad()
    optimizer.step()
    while True:
        for batch in train_dataloader_t:

            # hard text from image
            try:
                batch_i = next(train_iter_i)
            except StopIteration:
                train_iter_i = iter(train_dataloader_i)
                batch_i = next(train_iter_i)
            n_examples += batch_i["attn_masks"].size(0)
            loss = model(batch_i, sample_from="i", compute_loss=True)
            n_hard_ex += loss.numel()
            loss = loss.mean() / opts.train_batch_size
            with amp.scale_loss(loss, optimizer,
                                delay_unscale=True) as scaled_loss:
                scaled_loss.backward()

            # hard image from text
            n_examples += batch["attn_masks"].size(0)
            loss = model(batch, sample_from="t", compute_loss=True)
            n_hard_ex += loss.numel()
            # NOTE: we use gradient accumulation to implement train_batch_size
            loss = loss.mean() / opts.train_batch_size

            step += 1
            delay_unscale = step % opts.train_batch_size != 0
            with amp.scale_loss(loss, optimizer,
                                delay_unscale=delay_unscale) as scaled_loss:
                scaled_loss.backward()
                if not delay_unscale:
                    # gather gradients from every process
                    # do this before unscaling to make sure every process uses
                    # the same gradient scale
                    grads = [
                        p.grad.data for p in model.parameters()
                        if p.requires_grad and p.grad is not None
                    ]
                    all_reduce_and_rescale_tensors(grads, float(1))

            running_loss(loss.item())
            if step % opts.train_batch_size == 0:
                global_step += 1

                # learning rate scheduling
                lr_this_step = get_lr_sched(global_step, opts)
                for param_group in optimizer.param_groups:
                    param_group["lr"] = lr_this_step
                TB_LOGGER.add_scalar("lr", lr_this_step, global_step)

                # log loss
                # NOTE: not gathered across GPUs for efficiency
                TB_LOGGER.add_scalar("loss", running_loss.val, global_step)
                TB_LOGGER.step()

                # update model params
                if opts.grad_norm != -1:
                    grad_norm = clip_grad_norm_(amp.master_params(optimizer),
                                                opts.grad_norm)
                    TB_LOGGER.add_scalar("grad_norm", grad_norm, global_step)
                optimizer.step()
                optimizer.zero_grad()
                pbar.update(1)

                if global_step % 100 == 0:
                    # monitor training throughput
                    LOGGER.info(f"------------Step {global_step}-------------")
                    tot_ex = sum(all_gather_list(n_examples))
                    ex_per_sec = int(tot_ex / (time() - start))
                    tot_hn = sum(all_gather_list(n_hard_ex))
                    hn_per_sec = int(tot_hn / (time() - start))
                    LOGGER.info(f"{tot_ex} ({tot_hn}) examples (hard) "
                                f"trained at {ex_per_sec} ({hn_per_sec}) ex/s")
                    TB_LOGGER.add_scalar("perf/ex_per_s", ex_per_sec,
                                         global_step)
                    TB_LOGGER.add_scalar("perf/hn_per_s", hn_per_sec,
                                         global_step)
                    LOGGER.info(f"-------------------------------------------")

                if global_step % opts.valid_steps == 0:
                    if opts.full_val:
                        LOGGER.info(
                            f"========================== Step {global_step} "
                            f"==========================")
                        val_log = evaluate(model, eval_loader_val)
                        TB_LOGGER.log_scaler_dict(
                            {f"valid/{k}": v
                             for k, v in val_log.items()})
                        LOGGER.info(f"image retrieval R1: "
                                    f"{val_log['img_r1']*100:.2f},\n"
                                    f"image retrieval R5: "
                                    f"{val_log['img_r5']*100:.2f},\n"
                                    f"image retrieval R10: "
                                    f"{val_log['img_r10']*100:.2f}\n"
                                    f"text retrieval R1: "
                                    f"{val_log['txt_r1']*100:.2f},\n"
                                    f"text retrieval R5: "
                                    f"{val_log['txt_r5']*100:.2f},\n"
                                    f"text retrieval R10: "
                                    f"{val_log['txt_r10']*100:.2f}")
                        LOGGER.info("================================="
                                    "=================================")
                    else:
                        val_log = validate(model, val_dataloader)
                        TB_LOGGER.log_scaler_dict(val_log)
                    model_saver.save(model, global_step)

            if global_step >= opts.num_train_steps:
                break

        if global_step >= opts.num_train_steps:
            break

    pbar.close()
    # final validation
    val_log = validate(model, val_dataloader)
    TB_LOGGER.log_scaler_dict(val_log)
    model_saver.save(model, f"{global_step}_final")

    # evaluation
    for split, loader in [("val", eval_loader_val),
                          ("test", eval_loader_test)]:
        eval_log = evaluate(model, loader)
        TB_LOGGER.log_scaler_dict(
            {f"eval/{split}_{k}": v
             for k, v in eval_log.items()})
        if hvd.rank() != 0:
            continue
        LOGGER.info(
            f"========================= {split} ===========================\n"
            f"image retrieval R1: {eval_log['img_r1']*100:.2f},\n"
            f"image retrieval R5: {eval_log['img_r5']*100:.2f},\n"
            f"image retrieval R10: {eval_log['img_r10']*100:.2f}\n"
            f"text retrieval R1: {eval_log['txt_r1']*100:.2f},\n"
            f"text retrieval R5: {eval_log['txt_r5']*100:.2f},\n"
            f"text retrieval R10: {eval_log['txt_r10']*100:.2f}")
    LOGGER.info("=========================================================")
Example No. 13
def main(opts):
    hvd.init()
    n_gpu = hvd.size()
    device = torch.device("cuda", hvd.local_rank())
    torch.cuda.set_device(hvd.local_rank())
    rank = hvd.rank()
    opts.rank = rank
    LOGGER.info("device: {} n_gpu: {}, rank: {}, "
                "16-bits training: {}".format(
                    device, n_gpu, hvd.rank(), opts.fp16))

    if opts.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, "
                         "should be >= 1".format(
                            opts.gradient_accumulation_steps))

    set_random_seed(opts.seed)

    # train_examples = None
    LOGGER.info(f"Loading Train Dataset {opts.train_txt_db}, "
                f"{opts.train_img_db}")
    if 'paired' in opts.model:
        DatasetCls = Nlvr2PairedDataset
        EvalDatasetCls = Nlvr2PairedEvalDataset
        collate_fn = nlvr2_paired_collate
        eval_collate_fn = nlvr2_paired_eval_collate
        if opts.model == 'paired':
            ModelCls = UniterForNlvr2Paired
        elif opts.model == 'paired-attn':
            ModelCls = UniterForNlvr2PairedAttn
        else:
            raise ValueError('unrecognized model type')
    elif opts.model == 'triplet':
        DatasetCls = Nlvr2TripletDataset
        EvalDatasetCls = Nlvr2TripletEvalDataset
        ModelCls = UniterForNlvr2Triplet
        collate_fn = nlvr2_triplet_collate
        eval_collate_fn = nlvr2_triplet_eval_collate
    else:
        raise ValueError('unrecognized model type')

    # data loaders
    train_dataloader = create_dataloader(opts.train_img_db, opts.train_txt_db,
                                         opts.train_batch_size, True,
                                         DatasetCls, collate_fn, opts)
    val_dataloader = create_dataloader(opts.val_img_db, opts.val_txt_db,
                                       opts.val_batch_size, False,
                                       EvalDatasetCls, eval_collate_fn, opts)
    test_dataloader = create_dataloader(opts.test_img_db, opts.test_txt_db,
                                        opts.val_batch_size, False,
                                        EvalDatasetCls, eval_collate_fn, opts)

    # Prepare model
    if opts.checkpoint:
        ckpt = torch.load(opts.checkpoint)
        checkpoint = {k.replace('bert', 'uniter'): v for k, v in ckpt.items()} 
    else:
        checkpoint = {}

    model = ModelCls.from_pretrained(opts.model_config, state_dict=checkpoint,
                                     img_dim=IMG_DIM)
    model.init_type_embedding()
    model.to(device)
    # make sure every process has same model parameters in the beginning
    broadcast_tensors([p.data for p in model.parameters()], 0)
    set_dropout(model, opts.dropout)

    # Prepare optimizer
    optimizer = build_optimizer(model, opts)
    model, optimizer = amp.initialize(model, optimizer,
                                      enabled=opts.fp16, opt_level='O2')

    global_step = 0
    if rank == 0:
        save_training_meta(opts)
        TB_LOGGER.create(join(opts.output_dir, 'log'))
        pbar = tqdm(total=opts.num_train_steps)
        model_saver = ModelSaver(join(opts.output_dir, 'ckpt'))
        os.makedirs(join(opts.output_dir, 'results'))  # store val predictions
        add_log_to_file(join(opts.output_dir, 'log', 'log.txt'))
    else:
        LOGGER.disabled = True
        pbar = NoOp()
        model_saver = NoOp()

    LOGGER.info(f"***** Running training with {n_gpu} GPUs *****")
    LOGGER.info("  Num examples = %d", len(train_dataloader.dataset))
    LOGGER.info("  Batch size = %d", opts.train_batch_size)
    LOGGER.info("  Accumulate steps = %d", opts.gradient_accumulation_steps)
    LOGGER.info("  Num steps = %d", opts.num_train_steps)

    running_loss = RunningMeter('loss')
    model.train()
    n_examples = 0
    n_epoch = 0
    start = time()
    # quick hack for amp delay_unscale bug
    optimizer.zero_grad()
    optimizer.step()
    while True:
        for step, batch in enumerate(train_dataloader):
            targets = batch['targets']
            n_examples += targets.size(0)

            # ============================ Code for adversarial training =============
            if opts.adv_training:
                # initialize delta
                txt_embeds_init = model.uniter.embeddings.word_embeddings(
                    batch['input_ids'])
                img_embeds_init = batch['img_feat']

                # for simplicity, we initialize the deltas as zero vectors, which performs
                # very similarly to initializing them randomly from normal or uniform distributions
                txt_delta = torch.zeros_like(txt_embeds_init)
                img_delta = torch.zeros_like(img_embeds_init)

                # calculate the prob. scores for clean samples
                gt_answer_scores = model(batch, compute_loss=False)
                gt_answer_prob = F.softmax(gt_answer_scores, dim=1)
                gt_answer_logprob = F.log_softmax(gt_answer_scores, dim=1)

                # the main loop
                for astep in range(opts.adv_steps):
                    # (0) forward
                    if opts.adv_modality == ["text"]:
                        txt_delta.requires_grad_()
                        img_delta = torch.zeros_like(img_embeds_init)
                    elif opts.adv_modality == ["image"]:
                        img_delta.requires_grad_()
                        txt_delta = torch.zeros_like(txt_embeds_init)
                    else:
                        txt_delta.requires_grad_()
                        img_delta.requires_grad_()

                    if "alter" not in opts.adv_modality:
                        answer_scores = model(batch, adv_training = True,
                            adv_modality = opts.adv_modality, adv_delta_txt = txt_delta,
                            adv_delta_img = img_delta, compute_loss=False)

                        # BCE loss
                        bce_loss = F.cross_entropy(answer_scores, 
                            batch['targets'], reduction='none')
                        bce_loss = bce_loss.mean()

                        # KL loss
                        answer_prob = F.softmax(answer_scores, dim=1)
                        answer_logprob = F.log_softmax(answer_scores, dim=1)
                        kl_loss = F.kl_div(answer_logprob,gt_answer_prob,reduction='none') + \
                                    F.kl_div(gt_answer_logprob,answer_prob,reduction='none')
                        kl_loss = kl_loss.sum(1).mean()
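                        # here the symmetric KL is summed over classes before the batch
                        # mean (the VCR example above takes a plain mean instead)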

                        # (1) backward
                        loss = (bce_loss + opts.adv_kl_weight * kl_loss) / opts.adv_steps
                    else:
                        answer_scores_1 = model(batch, adv_training = True, 
                            adv_modality = ["text"], adv_delta_txt = txt_delta, 
                            adv_delta_img = None, compute_loss=False)

                        bce_loss_1 = F.cross_entropy(answer_scores_1, 
                            batch['targets'], reduction='none')
                        bce_loss_1 = bce_loss_1.mean()

                        answer_scores_2 = model(batch, adv_training = True, 
                            adv_modality = ["image"], adv_delta_txt = None, 
                            adv_delta_img = img_delta, compute_loss=False)

                        bce_loss_2 = F.cross_entropy(answer_scores_2, 
                            batch['targets'], reduction='none')
                        bce_loss_2 = bce_loss_2.mean()

                        # KL loss
                        answer_prob_1 = F.softmax(answer_scores_1, dim=1)
                        answer_logprob_1 = F.log_softmax(answer_scores_1, dim=1)
                        answer_prob_2 = F.softmax(answer_scores_2, dim=1)
                        answer_logprob_2 = F.log_softmax(answer_scores_2, dim=1)
                    
                        kl_loss_1 = F.kl_div(answer_logprob_1,gt_answer_prob,reduction='none') + \
                                    F.kl_div(gt_answer_logprob,answer_prob_1,reduction='none')
                        kl_loss_1 = kl_loss_1.sum(1).mean()

                        kl_loss_2 = F.kl_div(answer_logprob_2,gt_answer_prob,reduction='none') + \
                                    F.kl_div(gt_answer_logprob,answer_prob_2,reduction='none')
                        kl_loss_2 = kl_loss_2.sum(1).mean()
                    
                        # (1) backward
                        loss = (bce_loss_1 + bce_loss_2 + opts.adv_kl_weight * (kl_loss_1+kl_loss_2)) / (opts.adv_steps*2)

                    delay_unscale = ((step+1) % opts.gradient_accumulation_steps != 0) or ((astep+1) % opts.adv_steps != 0)
                    with amp.scale_loss(loss, optimizer, delay_unscale=delay_unscale
                                        ) as scaled_loss:
                        scaled_loss.backward(retain_graph=True)
                        if not delay_unscale:
                            # gather gradients from every process
                            # do this before unscaling to make sure every process uses
                            # the same gradient scale
                            grads = [p.grad.data for p in model.parameters()
                                     if p.requires_grad and p.grad is not None]
                            all_reduce_and_rescale_tensors(grads, float(1))

                    running_loss(loss.item())

                    if astep == opts.adv_steps - 1:
                        # no further updates on delta are needed
                        break

                    # (2) get gradient on delta
                    # fix fp16 problem
                    amp_scale = scaled_loss.item() // loss.item()
                    if "text" in opts.adv_modality:
                        txt_delta_grad = txt_delta.grad.clone().detach().float() / amp_scale
                    if "image" in opts.adv_modality:
                        img_delta_grad = img_delta.grad.clone().detach().float() / amp_scale

                    # (3) update and clip for txt delta
                    if "text" in opts.adv_modality:
                        if opts.norm_type == "l2":
                            denorm = torch.norm(txt_delta_grad.view(txt_delta_grad.size(0), -1), dim=1).view(-1, 1, 1)
                            denorm = torch.clamp(denorm, min=1e-8)
                            txt_delta_step = (opts.adv_lr_txt * txt_delta_grad / denorm).to(txt_delta)
                            txt_delta = (txt_delta + txt_delta_step).detach()
                            if opts.adv_max_norm > 0:
                                delta_norm = torch.norm(txt_delta.view(txt_delta.size(0), -1), p=2, dim=1).detach()
                                exceed_mask = (delta_norm > opts.adv_max_norm).to(txt_embeds_init)
                                reweights = (opts.adv_max_norm / delta_norm * exceed_mask + (1-exceed_mask)).view(-1, 1, 1)
                                txt_delta = (txt_delta * reweights).detach()
                        elif opts.norm_type == "linf":
                            denorm = torch.norm(txt_delta_grad.view(txt_delta_grad.size(0), -1), dim=1, p=float("inf")).view(-1, 1, 1)
                            denorm = torch.clamp(denorm, min=1e-8)
                            txt_delta_step = (opts.adv_lr_txt * txt_delta_grad / denorm).to(txt_delta)
                            txt_delta = (txt_delta + txt_delta_step).detach()
                            if opts.adv_max_norm > 0:
                                txt_delta = torch.clamp(txt_delta, -opts.adv_max_norm, opts.adv_max_norm).detach()

                    # (4) update and clip for image delta
                    if "image" in opts.adv_modality:
                        if opts.norm_type == "l2":
                            denorm = torch.norm(img_delta_grad.view(img_delta_grad.size(0), -1), dim=1).view(-1, 1, 1)
                            denorm = torch.clamp(denorm, min=1e-8)
                            img_delta_step = (opts.adv_lr_img * img_delta_grad / denorm).to(img_delta)
                            img_delta = (img_delta + img_delta_step).detach()
                            if opts.adv_max_norm > 0:
                                delta_norm = torch.norm(img_delta.view(img_delta.size(0), -1), p=2, dim=1).detach()
                                exceed_mask = (delta_norm > opts.adv_max_norm).to(img_embeds_init)
                                reweights = (opts.adv_max_norm / delta_norm * exceed_mask + (1-exceed_mask)).view(-1, 1, 1)
                                img_delta = (img_delta * reweights).detach()
                        elif opts.norm_type == "linf":
                            denorm = torch.norm(img_delta_grad.view(img_delta_grad.size(0), -1), dim=1, p=float("inf")).view(-1, 1, 1)
                            denorm = torch.clamp(denorm, min=1e-8)
                            img_delta_step = (opts.adv_lr_img * img_delta_grad / denorm).to(img_delta)
                            img_delta = (img_delta + img_delta_step).detach()
                            if opts.adv_max_norm > 0:
                                img_delta = torch.clamp(img_delta, -opts.adv_max_norm, opts.adv_max_norm).detach()

            else:
                loss = model(batch, compute_loss=True)
                loss = loss.mean()
                delay_unscale = (step+1) % opts.gradient_accumulation_steps != 0
                with amp.scale_loss(loss, optimizer, delay_unscale=delay_unscale
                                    ) as scaled_loss:
                    scaled_loss.backward()
                    if not delay_unscale:
                        # gather gradients from every process
                        # do this before unscaling to make sure every process uses
                        # the same gradient scale
                        grads = [p.grad.data for p in model.parameters()
                                 if p.requires_grad and p.grad is not None]
                        all_reduce_and_rescale_tensors(grads, float(1))

                running_loss(loss.item())

            # ============================ End ==========================

            if (step + 1) % opts.gradient_accumulation_steps == 0:
                global_step += 1

                # learning rate scheduling
                lr_this_step = get_lr_sched(global_step, opts)
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr_this_step
                TB_LOGGER.add_scalar('lr', lr_this_step, global_step)

                # log loss
                # NOTE: not gathered across GPUs for efficiency
                TB_LOGGER.add_scalar('loss', running_loss.val, global_step)
                TB_LOGGER.step()

                # update model params
                if opts.grad_norm != -1:
                    grad_norm = clip_grad_norm_(amp.master_params(optimizer),
                                                opts.grad_norm)
                    TB_LOGGER.add_scalar('grad_norm', grad_norm, global_step)
                optimizer.step()
                optimizer.zero_grad()
                pbar.update(1)

                if global_step % 100 == 0:
                    # monitor training throughput
                    tot_ex = sum(all_gather_list(n_examples))
                    ex_per_sec = int(tot_ex / (time()-start))
                    LOGGER.info(f'Step {global_step}: '
                                f'{tot_ex} examples trained at '
                                f'{ex_per_sec} ex/s')
                    TB_LOGGER.add_scalar('perf/ex_per_s',
                                         ex_per_sec, global_step)

                if global_step % opts.valid_steps == 0:
                    for split, loader in [('val', val_dataloader),
                                          ('test', test_dataloader)]:
                        LOGGER.info(f"Step {global_step}: start running "
                                    f"validation on {split} split...")
                        log, results = validate(model, loader, split)
                        with open(f'{opts.output_dir}/results/'
                                  f'{split}_results_{global_step}_'
                                  f'rank{rank}.csv', 'w') as f:
                            for id_, ans in results:
                                f.write(f'{id_},{ans}\n')
                        TB_LOGGER.log_scaler_dict(log)
                    model_saver.save(model, global_step)
            if global_step >= opts.num_train_steps:
                break
        if global_step >= opts.num_train_steps:
            break
        n_epoch += 1
        LOGGER.info(f"Step {global_step}: finished {n_epoch} epochs")
    if opts.num_train_steps % opts.valid_steps != 0:
        for split, loader in [('val', val_dataloader),
                              ('test', test_dataloader)]:
            LOGGER.info(f"Step {global_step}: start running "
                        f"validation on {split} split...")
            log, results = validate(model, loader, split)
            with open(f'{opts.output_dir}/results/'
                      f'{split}_results_{global_step}_'
                      f'rank{rank}.csv', 'w') as f:
                for id_, ans in results:
                    f.write(f'{id_},{ans}\n')
            TB_LOGGER.log_scaler_dict(log)
        model_saver.save(model, global_step)
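Note: the adversarial branch above repeats the same normalize-step-project update for the text and image deltas under both L2 and L-inf norms. A minimal standalone sketch of that update follows, reusing the adv_lr/adv_max_norm hyperparameters from above; the helper name is illustrative and not part of the original code, and the delta norm is clamped to avoid a divide-by-zero.

import torch

def pgd_delta_step(delta, grad, adv_lr, adv_max_norm, norm_type="l2"):
    # one normalized ascent step on the perturbation, then a projection back
    # into the adv_max_norm ball (mirrors the in-loop updates above)
    bsz = grad.size(0)
    if norm_type == "l2":
        denorm = torch.clamp(grad.view(bsz, -1).norm(p=2, dim=1), min=1e-8).view(-1, 1, 1)
        delta = (delta + adv_lr * grad / denorm).detach()
        if adv_max_norm > 0:
            delta_norm = torch.clamp(delta.view(bsz, -1).norm(p=2, dim=1), min=1e-8)
            exceed = (delta_norm > adv_max_norm).float()
            reweight = (adv_max_norm / delta_norm * exceed + (1 - exceed)).view(-1, 1, 1)
            delta = (delta * reweight).detach()
    elif norm_type == "linf":
        denorm = torch.clamp(grad.view(bsz, -1).norm(p=float("inf"), dim=1), min=1e-8).view(-1, 1, 1)
        delta = (delta + adv_lr * grad / denorm).detach()
        if adv_max_norm > 0:
            delta = torch.clamp(delta, -adv_max_norm, adv_max_norm).detach()
    return delta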
Exemplo n.º 14
0
def main(opts, checkpoint_dir=None, tuning=False):
    from utils.logger import LOGGER, TB_LOGGER, RunningMeter, add_log_to_file
    with logger.catch(reraise=True):
        logger.info(f"{opts}")
        if isinstance(opts, dict):
            opts = edict(opts)

        hvd.init()
        n_gpu = hvd.size()
        device = torch.device("cuda", hvd.local_rank())
        torch.cuda.set_device(hvd.local_rank())
        rank = hvd.rank()
        opts.rank = rank
        LOGGER.info("device: {} n_gpu: {}, rank: {}, "
                    "16-bits training: {}".format(device, n_gpu, hvd.rank(),
                                                  opts.fp16))

        if opts.gradient_accumulation_steps < 1:
            raise ValueError(
                "Invalid gradient_accumulation_steps parameter: {}, "
                "should be >= 1".format(opts.gradient_accumulation_steps))

        set_random_seed(opts.seed)
        """
        # load DBs and image dirs
        """
        all_img_dbs = ImageLmdbGroup(opts.conf_th, opts.max_bb, opts.min_bb,
                                     opts.num_bb, opts.compressed_db)

        # train
        LOGGER.info(f"Loading Train Dataset "
                    f"{opts.train_txt_dbs}, {opts.train_img_dbs}")
        train_datasets = []
        for txt_path, img_path in zip(opts.train_txt_dbs, opts.train_img_dbs):
            img_db = all_img_dbs[img_path]
            txt_db = TxtTokLmdb(txt_path, opts.max_txt_len)
            train_datasets.append(MemeDataset(1, txt_db, img_db))
        train_dataset = ConcatDatasetWithLens(train_datasets)
        train_dataloader = build_dataloader(train_dataset, meme_collate, True,
                                            opts)

        # val
        LOGGER.info(
            f"Loading Train Dataset {opts.val_txt_db}, {opts.val_img_db}")
        val_img_db = all_img_dbs[opts.val_img_db]
        val_txt_db = TxtTokLmdb(opts.val_txt_db, -1)
        val_dataset = MemeEvalDataset(1, val_txt_db, val_img_db)
        val_dataloader = build_dataloader(val_dataset,
                                          meme_eval_itm_ot_collate, False,
                                          opts)

        # test_img_db = val_img_db
        # test_txt_db = TxtTokLmdb(opts.test_txt_db, -1)
        # test_dataset = MemeEvalDataset(1, test_txt_db, test_img_db)
        # test_dataloader = build_dataloader(test_dataset, meme_eval_collate,
        #                                 False, opts)
        """
        # Prepare model
        """
        if opts.checkpoint:
            checkpoint = torch.load(opts.checkpoint)
        else:
            checkpoint = {}

        all_dbs = opts.train_txt_dbs + [opts.val_txt_db]

        model = UniterForITM.from_pretrained(opts.model_config,
                                             checkpoint,
                                             img_dim=IMG_DIM,
                                             num_answer=1)
        model.to(device)
        # make sure every process has same model parameters in the beginning
        broadcast_tensors([p.data for p in model.parameters()], 0)
        set_dropout(model, opts.dropout)
        """
        # Prepare optimizer
        """
        optimizer = build_optimizer(model, opts)
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          enabled=opts.fp16,
                                          opt_level='O2')
        global_step = 0
        if rank == 0:
            save_training_meta(opts)
            TB_LOGGER.create(join(opts.output_dir, 'log'))
            pbar = tqdm(total=opts.num_train_steps)
            model_saver = ModelSaver(join(opts.output_dir, 'ckpt'))
            # json.dump(ans2label,
            #           open(join(opts.output_dir, 'ckpt', 'ans2label.json'), 'w'))
            os.makedirs(join(opts.output_dir, 'results'),
                        exist_ok=tuning)  # store VQA predictions
            add_log_to_file(join(opts.output_dir, 'log', 'log.txt'))
        else:
            LOGGER.disabled = True
            pbar = NoOp()
            model_saver = NoOp()

        LOGGER.info(f"***** Running training with {n_gpu} GPUs *****")
        LOGGER.info("  Num examples = %d", len(train_dataset) * hvd.size())
        LOGGER.info("  Batch size = %d", opts.train_batch_size)
        LOGGER.info("  Accumulate steps = %d",
                    opts.gradient_accumulation_steps)
        LOGGER.info("  Num steps = %d", opts.num_train_steps)

        running_loss = RunningMeter('loss')
        model.train()
        n_examples = 0
        n_epoch = 0

        if checkpoint_dir is not None and tuning:
            checkpoint = os.path.join(checkpoint_dir, "checkpoint")
            (model_state, optimizer_state, n_epoch,
             n_examples) = torch.load(checkpoint)
            model.load_state_dict(model_state)
            optimizer.load_state_dict(optimizer_state)

            LOGGER.info(
                f"***** Resume from ray tune checkpoint : {checkpoint_dir} *****"
            )
            LOGGER.info("  n_examples = %d", n_examples)
            LOGGER.info("  n_epoch = %d", n_epoch)

            # shutil.rmtree(checkpoint_dir)

        start = time()
        # quick hack for amp delay_unscale bug
        optimizer.zero_grad()
        optimizer.step()
        while True:
            for step, batch in enumerate(train_dataloader):
                if global_step > 2000:
                    logger.error('Force stop at global step 2000')
                    sys.exit(0)
                n_examples += batch['input_ids'].size(0)

                if opts.adv_training:
                    # NOTE: reverse the labels, as is done in UniterForITM
                    targets = batch['targets']
                    targets = (targets > 0.5).long()
                    targets = torch.abs(targets - 1)
                    batch['targets'] = targets

                    # initialize delta
                    txt_embeds_init = model.uniter.embeddings.word_embeddings(
                        batch['input_ids'])
                    img_embeds_init = batch['img_feat']

                    # for simplicity, we initialize the deltas as zero vectors, which performs
                    # very similarly to initializing them randomly from a normal or uniform distribution
                    txt_delta = torch.zeros_like(txt_embeds_init)
                    img_delta = torch.zeros_like(img_embeds_init)

                    # calculate the prob. scores for clean samples
                    gt_answer_scores = model(batch, compute_loss=False)
                    gt_answer_prob = F.softmax(gt_answer_scores, dim=1)
                    gt_answer_logprob = F.log_softmax(gt_answer_scores, dim=1)

                    # the main loop
                    for astep in range(opts.adv_steps):
                        # (0) forward
                        if opts.adv_modality == ["text"]:
                            txt_delta.requires_grad_()
                            img_delta = torch.zeros_like(img_embeds_init)
                        elif opts.adv_modality == ["image"]:
                            img_delta.requires_grad_()
                            txt_delta = torch.zeros_like(txt_embeds_init)
                        else:
                            txt_delta.requires_grad_()
                            img_delta.requires_grad_()

                        if "alter" not in opts.adv_modality:
                            answer_scores = model(
                                batch,
                                adv_training=True,
                                adv_modality=opts.adv_modality,
                                adv_delta_txt=txt_delta,
                                adv_delta_img=img_delta,
                                compute_loss=False)

                            # CE loss
                            ce_loss = F.cross_entropy(
                                answer_scores,
                                batch['targets'].squeeze(-1),
                                reduction='mean')

                            # KL loss
                            answer_prob = F.softmax(answer_scores, dim=1)
                            answer_logprob = F.log_softmax(answer_scores,
                                                           dim=1)
                            kl_loss = F.kl_div(
                                answer_logprob, gt_answer_prob, reduction='none') + \
                                F.kl_div(
                                    gt_answer_logprob, answer_prob,
                                    reduction='none')
                            kl_loss = kl_loss.mean()

                            # (1) backward
                            loss = (ce_loss + opts.adv_kl_weight *
                                    kl_loss) / opts.adv_steps
                        else:
                            answer_scores_1 = model(batch,
                                                    adv_training=True,
                                                    adv_modality=["text"],
                                                    adv_delta_txt=txt_delta,
                                                    adv_delta_img=None,
                                                    compute_loss=False)

                            # CE loss
                            ce_loss_1 = F.cross_entropy(
                                answer_scores_1,
                                batch['targets'].squeeze(-1),
                                reduction='mean')

                            answer_scores_2 = model(batch,
                                                    adv_training=True,
                                                    adv_modality=["image"],
                                                    adv_delta_txt=None,
                                                    adv_delta_img=img_delta,
                                                    compute_loss=False)

                            # CE loss
                            ce_loss_2 = F.cross_entropy(
                                answer_scores_2,
                                batch['targets'].squeeze(-1),
                                reduction='mean')

                            # KL loss
                            answer_prob_1 = F.softmax(answer_scores_1, dim=1)
                            answer_logprob_1 = F.log_softmax(answer_scores_1,
                                                             dim=1)
                            answer_prob_2 = F.softmax(answer_scores_2, dim=1)
                            answer_logprob_2 = F.log_softmax(answer_scores_2,
                                                             dim=1)
                            kl_loss_1 = F.kl_div(
                                answer_logprob_1, gt_answer_prob, reduction='none') + \
                                F.kl_div(
                                    gt_answer_logprob, answer_prob_1,
                                    reduction='none')
                            kl_loss_1 = kl_loss_1.mean()
                            kl_loss_2 = F.kl_div(
                                answer_logprob_2, gt_answer_prob, reduction='none') + \
                                F.kl_div(
                                    gt_answer_logprob, answer_prob_2,
                                    reduction='none')
                            kl_loss_2 = kl_loss_2.mean()

                            # (1) backward
                            loss = (
                                ce_loss_1 + ce_loss_2 + opts.adv_kl_weight *
                                (kl_loss_1 + kl_loss_2)) / (opts.adv_steps * 2)

                        delay_unscale = (
                            (step + 1) % opts.gradient_accumulation_steps !=
                            0) or ((astep + 1) % opts.adv_steps != 0)
                        with amp.scale_loss(
                                loss, optimizer,
                                delay_unscale=delay_unscale) as scaled_loss:
                            scaled_loss.backward(retain_graph=True)
                            if not delay_unscale:
                                # gather gradients from every process
                                # do this before unscaling
                                # to make sure every process uses
                                # the same gradient scale
                                grads = [
                                    p.grad.data for p in model.parameters()
                                    if p.requires_grad and p.grad is not None
                                ]
                                all_reduce_and_rescale_tensors(grads, float(1))

                        running_loss(loss.item())

                        if astep == opts.adv_steps - 1:
                            # no further delta updates needed after the last step
                            break

                        # (2) get gradient on delta
                        # recover the amp loss scale so the fp16 delta gradients
                        # can be unscaled back to their true magnitude
                        amp_scale = scaled_loss.item() // loss.item()
                        if "text" in opts.adv_modality:
                            txt_delta_grad = txt_delta.grad.clone().detach()
                            txt_delta_grad = txt_delta_grad.float() / amp_scale
                        if "image" in opts.adv_modality:
                            img_delta_grad = img_delta.grad.clone().detach()
                            img_delta_grad = img_delta_grad.float() / amp_scale

                        # (3) update and clip for txt delta
                        if "text" in opts.adv_modality:
                            if opts.norm_type == "l2":
                                denorm = torch.norm(txt_delta_grad.view(
                                    txt_delta_grad.size(0), -1),
                                                    dim=1).view(-1, 1, 1)
                                denorm = torch.clamp(denorm, min=1e-8)
                                txt_delta_step = (opts.adv_lr_txt *
                                                  txt_delta_grad /
                                                  denorm).to(txt_delta)
                                txt_delta = (txt_delta +
                                             txt_delta_step).detach()
                                if opts.adv_max_norm > 0:
                                    delta_norm = torch.norm(txt_delta.view(
                                        txt_delta.size(0), -1),
                                                            p=2,
                                                            dim=1).detach()
                                    exceed_mask = (
                                        delta_norm >
                                        opts.adv_max_norm).to(txt_embeds_init)
                                    reweights = (opts.adv_max_norm /
                                                 delta_norm * exceed_mask +
                                                 (1 - exceed_mask)).view(
                                                     -1, 1, 1)
                                    txt_delta = (txt_delta *
                                                 reweights).detach()
                            elif opts.norm_type == "linf":
                                denorm = torch.norm(txt_delta_grad.view(
                                    txt_delta_grad.size(0), -1),
                                                    dim=1,
                                                    p=float("inf")).view(
                                                        -1, 1, 1)
                                denorm = torch.clamp(denorm, min=1e-8)
                                txt_delta_step = (opts.adv_lr_txt *
                                                  txt_delta_grad /
                                                  denorm).to(txt_delta)
                                txt_delta = (txt_delta +
                                             txt_delta_step).detach()
                                if opts.adv_max_norm > 0:
                                    txt_delta = torch.clamp(
                                        txt_delta, -opts.adv_max_norm,
                                        opts.adv_max_norm).detach()

                        # (4) update and clip for image delta
                        if "image" in opts.adv_modality:
                            if opts.norm_type == "l2":
                                denorm = torch.norm(img_delta_grad.view(
                                    img_delta_grad.size(0), -1),
                                                    dim=1).view(-1, 1, 1)
                                denorm = torch.clamp(denorm, min=1e-8)
                                img_delta_step = (opts.adv_lr_img *
                                                  img_delta_grad /
                                                  denorm).to(img_delta)
                                img_delta = (img_delta +
                                             img_delta_step).detach()
                                if opts.adv_max_norm > 0:
                                    delta_norm = torch.norm(img_delta.view(
                                        img_delta.size(0), -1),
                                                            p=2,
                                                            dim=1).detach()
                                    exceed_mask = (
                                        delta_norm >
                                        opts.adv_max_norm).to(img_embeds_init)
                                    reweights = (opts.adv_max_norm /
                                                 delta_norm * exceed_mask +
                                                 (1 - exceed_mask)).view(
                                                     -1, 1, 1)
                                    img_delta = (img_delta *
                                                 reweights).detach()
                            elif opts.norm_type == "linf":
                                denorm = torch.norm(img_delta_grad.view(
                                    img_delta_grad.size(0), -1),
                                                    dim=1,
                                                    p=float("inf")).view(
                                                        -1, 1, 1)
                                denorm = torch.clamp(denorm, min=1e-8)
                                img_delta_step = (opts.adv_lr_img *
                                                  img_delta_grad /
                                                  denorm).to(img_delta)
                                img_delta = (img_delta +
                                             img_delta_step).detach()
                                if opts.adv_max_norm > 0:
                                    img_delta = torch.clamp(
                                        img_delta, -opts.adv_max_norm,
                                        opts.adv_max_norm).detach()
                else:
                    loss = model(batch, compute_loss=True)
                    loss = loss.mean()
                    delay_unscale = (step +
                                     1) % opts.gradient_accumulation_steps != 0
                    with amp.scale_loss(
                            loss, optimizer,
                            delay_unscale=delay_unscale) as scaled_loss:
                        scaled_loss.backward()
                        if not delay_unscale:
                            # gather gradients from every process
                            # do this before unscaling to make sure every process uses
                            # the same gradient scale
                            grads = [
                                p.grad.data for p in model.parameters()
                                if p.requires_grad and p.grad is not None
                            ]
                            all_reduce_and_rescale_tensors(grads, float(1))

                    running_loss(loss.item())
                """
                loss compute end
                log & step start
                """

                if (step + 1) % opts.gradient_accumulation_steps == 0:
                    global_step += 1

                    # learning rate scheduling
                    lr_this_step = get_lr_sched(global_step, opts)
                    for i, param_group in enumerate(optimizer.param_groups):
                        if i == 0 or i == 1:
                            param_group['lr'] = lr_this_step * opts.lr_mul
                        elif i == 2 or i == 3:
                            param_group['lr'] = lr_this_step
                        else:
                            raise ValueError()
                    TB_LOGGER.add_scalar('lr', lr_this_step, global_step)

                    # log loss
                    # NOTE: not gathered across GPUs for efficiency
                    TB_LOGGER.add_scalar('loss', running_loss.val, global_step)
                    TB_LOGGER.step()

                    # update model params
                    if opts.grad_norm != -1:
                        grad_norm = clip_grad_norm_(
                            amp.master_params(optimizer), opts.grad_norm)
                        TB_LOGGER.add_scalar('grad_norm', grad_norm,
                                             global_step)
                    optimizer.step()
                    optimizer.zero_grad()
                    pbar.update(1)

                    if global_step % 100 == 0:
                        # monitor training throughput
                        LOGGER.info(
                            f'============Step {global_step}=============')
                        tot_ex = sum(all_gather_list(n_examples))
                        ex_per_sec = int(tot_ex / (time() - start))
                        LOGGER.info(f'{tot_ex} examples trained at '
                                    f'{ex_per_sec} ex/s')
                        TB_LOGGER.add_scalar('perf/ex_per_s', ex_per_sec,
                                             global_step)
                        LOGGER.info(
                            f'===========================================')

                    if global_step % opts.valid_steps == 0:
                        val_log, results = validate(model, val_dataloader,
                                                    None)

                        with open(
                                f'{opts.output_dir}/results/'
                                f'results_{global_step}_'
                                f'rank{rank}.json', 'w') as f:
                            json.dump(results, f)
                        pd.DataFrame.from_dict(results).to_csv(
                            f'{opts.output_dir}/results/'
                            f'results_{global_step}_'
                            f'rank{rank}.csv',
                            index=False)

                        # _, test_results = test(model, test_dataloader, global_step)
                        # pd.DataFrame.from_dict(test_results).to_csv(
                        #     f'{opts.output_dir}/results/'
                        #     f'test_{global_step}.csv',
                        #     index=False)

                        TB_LOGGER.log_scaler_dict(val_log)
                        model_saver.save(model, global_step)

                        if tuning:
                            with tune.checkpoint_dir(
                                    step=n_epoch) as checkpoint_dir:
                                logger.info(
                                    f'***** Save tune ckpt: {checkpoint_dir} *****'
                                )
                                path = os.path.join(checkpoint_dir,
                                                    "checkpoint")
                                torch.save((model.state_dict(),
                                            optimizer.state_dict(), n_epoch,
                                            n_examples), path)
                            tune.report(
                                loss=(val_log['valid/loss']),
                                accuracy=val_log['valid/acc'],
                                auroc=val_log['valid/auroc'],
                            )
                if global_step >= opts.num_train_steps:
                    break
            if global_step >= opts.num_train_steps:
                break
            n_epoch += 1
            LOGGER.info(f"finished {n_epoch} epochs")
            """
            END of training loop
            """

        if opts.num_train_steps % opts.valid_steps != 0:
            val_log, results = validate(model, val_dataloader, None)
            with open(
                    f'{opts.output_dir}/results/'
                    f'results_{global_step}_'
                    f'rank{rank}.json', 'w') as f:
                json.dump(results, f)
            pd.DataFrame.from_dict(results).to_csv(
                f'{opts.output_dir}/results/'
                f'results_{global_step}_'
                f'rank{rank}.csv',
                index=False)
            TB_LOGGER.log_scaler_dict(val_log)
            model_saver.save(model, global_step)
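Note: in the adversarial loop above, the training loss pairs cross-entropy on the perturbed inputs with a symmetric KL term that keeps the perturbed predictions close to the clean ones. A minimal sketch of that consistency term in isolation (the helper name is illustrative, not part of the original code):

import torch.nn.functional as F

def symmetric_kl(adv_logits, clean_logits):
    # symmetric KL between the perturbed and clean output distributions,
    # matching the paired F.kl_div calls used in the loop above
    adv_prob = F.softmax(adv_logits, dim=1)
    adv_logprob = F.log_softmax(adv_logits, dim=1)
    clean_prob = F.softmax(clean_logits, dim=1)
    clean_logprob = F.log_softmax(clean_logits, dim=1)
    kl = (F.kl_div(adv_logprob, clean_prob, reduction='none')
          + F.kl_div(clean_logprob, adv_prob, reduction='none'))
    return kl.mean()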
Exemplo n.º 15
0
def main(opts):
    hvd.init()
    n_gpu = hvd.size()
    device = torch.device("cuda", hvd.local_rank())
    torch.cuda.set_device(hvd.local_rank())
    rank = hvd.rank()
    opts.rank = rank

    
    LOGGER.info("device: {} n_gpu: {}, rank: {}, "
                "16-bits training: {}".format(
                    device, n_gpu, hvd.rank(), opts.fp16))
    if opts.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, "
                         "should be >= 1".format(
                            opts.gradient_accumulation_steps))

    set_random_seed(opts.seed)

    if hvd.rank() == 0:
        TB_LOGGER.create(join(opts.output_dir, 'log'))
        os.makedirs(join(opts.output_dir, 'ckpt'))
        save_training_meta(opts)
        # TB_LOGGER.create(join(opts.output_dir, 'log'))
        model_saver = ModelSaver(join(opts.output_dir, 'ckpt'))
        add_log_to_file(join(opts.output_dir, 'log', 'log.txt'))
        # store ITM predictions
        os.makedirs(join(opts.output_dir, 'results_val'))
        os.makedirs(join(opts.output_dir, 'results_test'))
        os.makedirs(join(opts.output_dir, 'results_train'))
    else:
        LOGGER.disabled = True
        model_saver = NoOp()

    # load DBs and image dirs
    all_img_dbs = ImageLmdbGroup(opts.conf_th, opts.max_bb, opts.min_bb,
                                 opts.num_bb, opts.compressed_db)
    # train
    train_dataset = MemeAIDataset(json_path='/home/data/meme_json/train.json',
                                  npz_folder='/home/data/faster_cnn_feature/',
                                  mode='train')
    train_loader = DataLoader(train_dataset,
                              batch_size=opts.train_batch_size,
                              shuffle=True,
                              num_workers=opts.n_workers,
                              collate_fn=collate_fn)
    train_loader = PrefetchLoader(train_loader)

    # val
    val_dataset = MemeAIDataset(json_path='/home/data/meme_json/dev.json',
                                npz_folder='/home/data/faster_cnn_feature/',
                                mode='val')
    val_loader = DataLoader(val_dataset,
                            batch_size=opts.inf_minibatch_size,
                            shuffle=False,
                            num_workers=opts.n_workers,
                            collate_fn=collate_fn)
    val_loader = PrefetchLoader(val_loader)

    # Prepare model
    if opts.checkpoint:
        checkpoint = torch.load(opts.checkpoint)
    else:
        checkpoint = {}

    model = Meme.from_pretrained(
        opts.model_config, state_dict=checkpoint,
        img_dim=IMG_DIM)
    model.init_output()  # pretrain ITM head is different from ranking head
    model.to(device)

    # make sure every process has same model parameters in the beginning
    # broadcast_tensors([p.data for p in model.parameters()], 0)
    # set_dropout(model, opts.dropout)

    # Prepare optimizer
    optimizer = build_optimizer(model, opts)
    model, optimizer = amp.initialize(model, optimizer,
                                      enabled=opts.fp16, opt_level='O2')

    global_step = 0
    # LOGGER.info(f"***** Running training on {n_gpu} GPUs *****")
    # LOGGER.info("  Num examples = %d", len(train_dataset) * hvd.size())
    LOGGER.info("  Batch size = %d", opts.train_batch_size)
    LOGGER.info("  Accumulate steps = %d", opts.gradient_accumulation_steps)
    LOGGER.info("  Num steps = %d", opts.num_train_steps)

    running_loss = RunningMeter('loss')
    model.train()

    n_examples = 0
    n_epoch = 0
    start = time()
    # quick hack for amp delay_unscale bug
    optimizer.zero_grad()
    optimizer.step()
    # while True:
    for epoch in range(opts.epoch):
        print('epoch {}/ {}'.format(epoch, opts.epoch))
        pbar = tqdm(total=len(train_loader))

        model.train()
        preds = None
        gt = None

        for step, batch in enumerate(train_loader):
            x = batch[0]
            y = batch[1]
            n_examples += x['input_ids'].size(0)

            pred = model(x)

            # detach so the accumulated predictions don't retain the graph and
            # can be converted to numpy for the epoch metrics below
            if preds is None:
                preds = torch.sigmoid(pred).detach()
                gt = y
            else:
                preds = torch.cat((preds, torch.sigmoid(pred).detach()), dim=0)
                gt = torch.cat((gt, y), dim=0)

            loss = F.binary_cross_entropy(torch.sigmoid(pred), y)

            delay_unscale = (step+1) % opts.gradient_accumulation_steps != 0
            with amp.scale_loss(loss, optimizer, delay_unscale=delay_unscale
                                ) as scaled_loss:
                scaled_loss.backward()
                if not delay_unscale:
                    # gather gradients from every process
                    # do this before unscaling to make sure every process uses
                    # the same gradient scale
                    grads = [p.grad.data for p in model.parameters()
                             if p.requires_grad and p.grad is not None]
                    all_reduce_and_rescale_tensors(grads, float(1))

            running_loss(loss.item())

            if (step + 1) % opts.gradient_accumulation_steps == 0:
                global_step += 1

                # learning rate scheduling
                lr_this_step = get_lr_sched(global_step, opts)
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr_this_step
                TB_LOGGER.add_scalar('lr', lr_this_step, global_step)

                # log loss
                # NOTE: not gathered across GPUs for efficiency
                TB_LOGGER.add_scalar('loss', running_loss.val, global_step)
                TB_LOGGER.step()

                # update model params
                if opts.grad_norm != -1:
                    grad_norm = clip_grad_norm_(amp.master_params(optimizer),
                                                opts.grad_norm)
                    TB_LOGGER.add_scalar('grad_norm', grad_norm, global_step)
                optimizer.step()
                optimizer.zero_grad()


        global_step += 1

        # learning rate scheduling
        lr_this_step = get_lr_sched(global_step, opts)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr_this_step
        TB_LOGGER.add_scalar('lr', lr_this_step, global_step)

        # log loss
        # NOTE: not gathered across GPUs for efficiency
        TB_LOGGER.add_scalar('loss', running_loss.val, global_step)
        TB_LOGGER.step()

        # update model params
        if opts.grad_norm != -1:
            grad_norm = clip_grad_norm_(amp.master_params(optimizer),
                                        opts.grad_norm)
            TB_LOGGER.add_scalar('grad_norm', grad_norm, global_step)
        optimizer.step()
        optimizer.zero_grad()

        with torch.no_grad():
            preds = preds.cpu().numpy().reshape(len(preds))
            gt = gt.cpu().numpy()
            roc = roc_auc_score(gt, preds)
            acc = accuracy_score(gt, np.around(preds))
        train_log = {'train/roc': roc, 'train/acc': acc}
        TB_LOGGER.log_scaler_dict(train_log)  # keys already carry the train/ prefix

        # monitor training throughput

        val_log = validate(model, val_loader)
        TB_LOGGER.log_scaler_dict({f"valid/{k}": v for k, v in val_log.items()})

        LOGGER.info(train_log)
        LOGGER.info(val_log)

        model_saver.save(model, global_step)

        pbar.close()
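Note: the loop above applies torch.sigmoid and then F.binary_cross_entropy. Under apex O2 mixed precision, the fused F.binary_cross_entropy_with_logits is the numerically safer equivalent and is what fp16 training generally calls for. A minimal sketch with toy tensors (shapes and values are illustrative only):

import torch
import torch.nn.functional as F

logits = torch.randn(8, 1)   # raw model outputs (toy example)
targets = torch.rand(8, 1)   # labels in [0, 1]

# equivalent to F.binary_cross_entropy(torch.sigmoid(logits), targets),
# but computed in a single, numerically stable op
loss = F.binary_cross_entropy_with_logits(logits, targets)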
Exemplo n.º 16
0
def main(opts):
    hvd.init()
    n_gpu = hvd.size()
    device = torch.device("cuda", hvd.local_rank())
    torch.cuda.set_device(hvd.local_rank())
    rank = hvd.rank()
    opts.rank = rank
    LOGGER.info("device: {} n_gpu: {}, rank: {}, "
                "16-bits training: {}".format(device, n_gpu, hvd.rank(),
                                              opts.fp16))

    if opts.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, "
                         "should be >= 1".format(
                             opts.gradient_accumulation_steps))

    set_random_seed(opts.seed)

    if hasattr(opts, 'ans2label_path'):
        ans2label = json.load(open(opts.ans2label_path, 'r', encoding='utf-8'))
    else:
        ans2label = json.load(
            open(f'{dirname(abspath(__file__))}'
                 f'/utils/ans2label.json'))
    label2ans = {label: ans for ans, label in ans2label.items()}

    # load DBs and image dirs
    all_img_dbs = ImageLmdbGroup(opts.conf_th, opts.max_bb, opts.min_bb,
                                 opts.num_bb, opts.compressed_db)
    # train
    LOGGER.info(f"Loading Train Dataset "
                f"{opts.train_txt_dbs}, {opts.train_img_dbs}")
    train_datasets = []
    for txt_path, img_path in zip(opts.train_txt_dbs, opts.train_img_dbs):
        img_db = all_img_dbs[img_path]
        txt_db = TxtTokLmdb(txt_path, opts.max_txt_len)
        train_datasets.append(VqaDataset(len(ans2label), txt_db, img_db))
    train_dataset = ConcatDatasetWithLens(train_datasets)
    train_dataloader = build_dataloader(train_dataset, vqa_collate, True, opts)
    # val
    LOGGER.info(f"Loading Train Dataset {opts.val_txt_db}, {opts.val_img_db}")
    val_img_db = all_img_dbs[opts.val_img_db]
    val_txt_db = TxtTokLmdb(opts.val_txt_db, -1)
    val_dataset = VqaEvalDataset(len(ans2label), val_txt_db, val_img_db)
    val_dataloader = build_dataloader(val_dataset, vqa_eval_collate, False,
                                      opts)

    # Prepare model
    if opts.checkpoint:
        checkpoint = torch.load(opts.checkpoint, map_location=device)
    else:
        checkpoint = {}

    all_dbs = opts.train_txt_dbs + [opts.val_txt_db]
    toker = json.load(open(f'{all_dbs[0]}/meta.json'))['bert']
    assert all(toker == json.load(open(f'{db}/meta.json'))['bert']
               for db in all_dbs)
    model = UniterForVisualQuestionAnswering.from_pretrained(
        opts.model_config,
        checkpoint,
        img_dim=IMG_DIM,
        num_answer=len(ans2label))
    model.to(device)
    # make sure every process has same model parameters in the beginning
    broadcast_tensors([p.data for p in model.parameters()], 0)
    set_dropout(model, opts.dropout)

    # Prepare optimizer
    optimizer = build_optimizer(model, opts)
    model, optimizer = amp.initialize(model,
                                      optimizer,
                                      enabled=opts.fp16,
                                      opt_level='O2')
    global_step = 0
    if rank == 0:
        save_training_meta(opts)
        TB_LOGGER.create(join(opts.output_dir, 'log'))
        pbar = tqdm(total=opts.num_train_steps)
        model_saver = ModelSaver(join(opts.output_dir, 'ckpt'))
        json.dump(ans2label,
                  open(join(opts.output_dir, 'ckpt', 'ans2label.json'), 'w'))
        os.makedirs(join(opts.output_dir, 'results'))  # store VQA predictions
        add_log_to_file(join(opts.output_dir, 'log', 'log.txt'))
    else:
        LOGGER.disabled = True
        pbar = NoOp()
        model_saver = NoOp()

    LOGGER.info(f"***** Running training with {n_gpu} GPUs *****")
    LOGGER.info("  Num examples = %d", len(train_dataset) * hvd.size())
    LOGGER.info("  Batch size = %d", opts.train_batch_size)
    LOGGER.info("  Accumulate steps = %d", opts.gradient_accumulation_steps)
    LOGGER.info("  Num steps = %d", opts.num_train_steps)

    running_loss = RunningMeter('loss')
    model.train()
    n_examples = 0
    n_epoch = 0
    start = time()
    # quick hack for amp delay_unscale bug
    optimizer.zero_grad()
    optimizer.step()
    while True:
        for step, batch in enumerate(train_dataloader):
            n_examples += batch['input_ids'].size(0)

            loss = model(batch, compute_loss=True)
            loss = loss.mean() * batch['targets'].size(1)  # instance-level bce
            delay_unscale = (step + 1) % opts.gradient_accumulation_steps != 0
            with amp.scale_loss(loss, optimizer,
                                delay_unscale=delay_unscale) as scaled_loss:
                scaled_loss.backward()
                if not delay_unscale:
                    # gather gradients from every process
                    # do this before unscaling to make sure every process uses
                    # the same gradient scale
                    grads = [
                        p.grad.data for p in model.parameters()
                        if p.requires_grad and p.grad is not None
                    ]
                    all_reduce_and_rescale_tensors(grads, float(1))

            running_loss(loss.item())

            if (step + 1) % opts.gradient_accumulation_steps == 0:
                global_step += 1

                # learning rate scheduling
                lr_this_step = get_lr_sched(global_step, opts)
                for i, param_group in enumerate(optimizer.param_groups):
                    if i == 0 or i == 1:
                        param_group['lr'] = lr_this_step * opts.lr_mul
                    elif i == 2 or i == 3:
                        param_group['lr'] = lr_this_step
                    else:
                        raise ValueError()
                TB_LOGGER.add_scalar('lr', lr_this_step, global_step)

                # log loss
                # NOTE: not gathered across GPUs for efficiency
                TB_LOGGER.add_scalar('loss', running_loss.val, global_step)
                TB_LOGGER.step()

                # update model params
                if opts.grad_norm != -1:
                    grad_norm = clip_grad_norm_(amp.master_params(optimizer),
                                                opts.grad_norm)
                    TB_LOGGER.add_scalar('grad_norm', grad_norm, global_step)
                optimizer.step()
                optimizer.zero_grad()
                pbar.update(1)

                if global_step % 100 == 0:
                    # monitor training throughput
                    LOGGER.info(f'============Step {global_step}=============')
                    tot_ex = sum(all_gather_list(n_examples))
                    ex_per_sec = int(tot_ex / (time() - start))
                    LOGGER.info(f'{tot_ex} examples trained at '
                                f'{ex_per_sec} ex/s')
                    TB_LOGGER.add_scalar('perf/ex_per_s', ex_per_sec,
                                         global_step)
                    LOGGER.info(f'===========================================')

                if global_step % opts.valid_steps == 0:
                    val_log, results = validate(model, val_dataloader,
                                                label2ans)
                    with open(
                            f'{opts.output_dir}/results/'
                            f'results_{global_step}_'
                            f'rank{rank}.json', 'w') as f:
                        json.dump(results, f)
                    TB_LOGGER.log_scaler_dict(val_log)
                    model_saver.save(model, global_step)
            if global_step >= opts.num_train_steps:
                break
        if global_step >= opts.num_train_steps:
            break
        n_epoch += 1
        LOGGER.info(f"finished {n_epoch} epochs")
    if opts.num_train_steps % opts.valid_steps != 0:
        val_log, results = validate(model, val_dataloader, label2ans)
        with open(
                f'{opts.output_dir}/results/'
                f'results_{global_step}_'
                f'rank{rank}.json', 'w') as f:
            json.dump(results, f)
        TB_LOGGER.log_scaler_dict(val_log)
        model_saver.save(model, global_step)
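Note: every example drives the learning rate from get_lr_sched(global_step, opts), whose implementation is not shown here. A plausible warmup-then-linear-decay sketch follows, assuming opts.warmup_steps and opts.learning_rate fields that do not appear above; the project's actual schedule may differ.

def get_lr_sched_sketch(global_step, opts):
    # linear warmup to opts.learning_rate, then linear decay to zero at
    # opts.num_train_steps (an assumption about the real helper)
    if global_step < opts.warmup_steps:
        return opts.learning_rate * global_step / opts.warmup_steps
    decay = (opts.num_train_steps - global_step) / max(
        1, opts.num_train_steps - opts.warmup_steps)
    return max(0.0, opts.learning_rate * decay)

The per-group loop above then multiplies this base rate by opts.lr_mul for the first two parameter groups (presumably the task-specific head) and applies it unchanged to the remaining groups.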
Exemplo n.º 17
0
def main(opts):
    hvd.init()
    n_gpu = hvd.size()
    device = torch.device("cuda", hvd.local_rank())
    torch.cuda.set_device(hvd.local_rank())
    rank = hvd.rank()
    opts.rank = rank
    LOGGER.info("device: {} n_gpu: {}, rank: {}, "
                "16-bits training: {}".format(
                    device, n_gpu, hvd.rank(), opts.fp16))

    if opts.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, "
                         "should be >= 1".format(
                            opts.gradient_accumulation_steps))

    set_random_seed(opts.seed)

    if rank == 0:
        save_training_meta(opts)
        TB_LOGGER.create(join(opts.output_dir, 'log'))
        pbar = tqdm(total=opts.num_train_steps)
        model_saver = ModelSaver(join(opts.output_dir, 'ckpt'))
        add_log_to_file(join(opts.output_dir, 'log', 'log.txt'))
    else:
        LOGGER.disabled = True
        pbar = NoOp()
        model_saver = NoOp()

    all_dbs = [db for datasets in [opts.train_datasets, opts.val_datasets]
               for dset in datasets for db in dset['db']]

    tokenizer = json.load(open(f'{all_dbs[0]}/meta.json'))['bert']
    assert all(tokenizer == json.load(open(f'{db}/meta.json'))['bert']
               for db in all_dbs)

    # build data loaders
    train_dataloaders, all_img_dbs = create_dataloaders(
        opts.train_datasets, True, opts)
    val_dataloaders, _ = create_dataloaders(
        opts.val_datasets, False, opts, all_img_dbs)
    meta_loader = MetaLoader(train_dataloaders,
                             accum_steps=opts.gradient_accumulation_steps,
                             distributed=n_gpu > 1)
    meta_loader = PrefetchLoader(meta_loader)

    # Prepare model
    if opts.checkpoint:
        checkpoint = torch.load(opts.checkpoint)
    else:
        checkpoint = {}
    model = UniterForPretraining.from_pretrained(
        opts.model_config, checkpoint,
        img_dim=IMG_DIM, img_label_dim=IMG_LABEL_DIM)
    model.to(device)
    model.train()
    # make sure every process has same model parameters in the beginning
    broadcast_tensors([p.data for p in model.parameters()], 0)
    set_dropout(model, opts.dropout)

    # Prepare optimizer
    optimizer = build_optimizer(model, opts)
    task2scaler = {t: i for i, t in enumerate(train_dataloaders.keys())}
    model, optimizer = amp.initialize(model, optimizer,
                                      num_losses=len(task2scaler),
                                      enabled=opts.fp16, opt_level='O2')

    global_step = 0
    LOGGER.info(f"***** Running training with {n_gpu} GPUs *****")
    LOGGER.info("  Batch size = %d", opts.train_batch_size)
    LOGGER.info("  Accumulate steps = %d", opts.gradient_accumulation_steps)
    LOGGER.info("  Num steps = %d", opts.num_train_steps)

    # to compute training statistics
    task2loss = {task: RunningMeter(f'loss/{task}')
                 for task in train_dataloaders.keys()}
    # ITM w/ OT
    if opts.itm_ot_lambda > 0:
        for task in train_dataloaders.keys():
            if task.startswith('itm'):
                task2loss[f'{task}_xe'] = RunningMeter(f'loss/{task}_xe')
                task2loss[f'{task}_ot'] = RunningMeter(f'loss/{task}_ot')
                task2loss[f'{task}_ot_pos'] = RunningMeter(
                    f'loss/{task}_ot_pos')
                task2loss[f'{task}_ot_neg'] = RunningMeter(
                    f'loss/{task}_ot_neg')

    n_examples = defaultdict(int)
    n_in_units = defaultdict(int)
    n_loss_units = defaultdict(int)
    grad_norm = 0

    start = time()
    # quick hack for amp delay_unscale bug
    optimizer.zero_grad()
    optimizer.step()
    for step, (name, batch) in enumerate(meta_loader):
        # forward pass
        n_examples[name] += batch['input_ids'].size(0)
        n_in_units[name] += (batch['attn_masks'] == 1).sum().item()
        task = name.split('_')[0]
        loss = model(batch, task=task, compute_loss=True)
        if task.startswith('itm'):
            # OT
            itm_loss, ot_loss = loss
            n_loss_units[name] += itm_loss.size(0)
            itm_loss = itm_loss.mean()
            if ot_loss is not None:
                ot_pos, ot_neg = ot_loss
                ot_loss = (ot_pos.sum() - ot_neg.sum()
                           ) / (ot_pos.size(0) + ot_neg.size(0))

                # NOTE: beware of empty tensors
                ot_pos = ot_pos.mean().item()
                if not math.isnan(ot_pos):
                    task2loss[f'{name}_ot_pos'](ot_pos)
                ot_neg = ot_neg.mean().item()
                if not math.isnan(ot_neg):
                    task2loss[f'{name}_ot_neg'](ot_neg)

                loss = itm_loss + opts.itm_ot_lambda * ot_loss
                task2loss[f'{name}_xe'](itm_loss.item())
                task2loss[f'{name}_ot'](ot_loss.item())
            else:
                loss = itm_loss
        else:
            n_loss_units[name] += loss.size(0)
            loss = loss.mean()  # loss is not normalized in model

        # backward pass
        delay_unscale = (step+1) % opts.gradient_accumulation_steps != 0
        with amp.scale_loss(loss, optimizer, delay_unscale=delay_unscale,
                            loss_id=task2scaler[name]) as scaled_loss:
            scaled_loss.backward()
            if not delay_unscale:
                # gather gradients from every process
                # do this before unscaling to make sure every process uses
                # the same gradient scale
                grads = [p.grad.data for p in model.parameters()
                         if p.requires_grad and p.grad is not None]
                all_reduce_and_rescale_tensors(grads, float(1))
        task2loss[name](loss.item())

        # optimizer update and logging
        if (step + 1) % opts.gradient_accumulation_steps == 0:
            global_step += 1

            # learning rate scheduling
            lr_this_step = get_lr_sched(global_step, opts)
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr_this_step
            TB_LOGGER.add_scalar('lr', lr_this_step, global_step)

            # log loss
            # NOTE: not gathered across GPUs for efficiency
            TB_LOGGER.log_scaler_dict({ll.name: ll.val
                                       for ll in task2loss.values()
                                       if ll.val is not None})
            TB_LOGGER.step()

            # update model params
            if opts.grad_norm != -1:
                grad_norm = clip_grad_norm_(amp.master_params(optimizer),
                                            opts.grad_norm)
                TB_LOGGER.add_scalar('grad_norm', grad_norm, global_step)
            optimizer.step()
            optimizer.zero_grad()
            pbar.update(1)

            if global_step % 100 == 0:
                # monitor training throughput
                LOGGER.info(f'==============Step {global_step}===============')
                for t in train_dataloaders.keys():
                    assert all(tt == t for tt in all_gather_list(t))
                    tot_ex = sum(all_gather_list(n_examples[t]))
                    ex_per_sec = int(tot_ex / (time()-start))
                    tot_in = sum(all_gather_list(n_in_units[t]))
                    in_per_sec = int(tot_in / (time()-start))
                    tot_l = sum(all_gather_list(n_loss_units[t]))
                    l_per_sec = int(tot_l / (time()-start))
                    LOGGER.info(f'{t}: {tot_ex} examples trained at '
                                f'{ex_per_sec} ex/s')
                    TB_LOGGER.add_scalar(f'perf/{t}_ex_per_s', ex_per_sec,
                                         global_step)
                    TB_LOGGER.add_scalar(f'perf/{t}_in_per_s', in_per_sec,
                                         global_step)
                    TB_LOGGER.add_scalar(f'perf/{t}_loss_per_s', l_per_sec,
                                         global_step)
                LOGGER.info('===============================================')

            if global_step % opts.valid_steps == 0:
                LOGGER.info(f'Step {global_step}: start validation')
                validate(model, val_dataloaders)
                model_saver.save(model, global_step)
        if global_step >= opts.num_train_steps:
            break
    if global_step % opts.valid_steps != 0:
        LOGGER.info(f'Step {global_step}: start validation')
        validate(model, val_dataloaders)
        model_saver.save(model, global_step)
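Note: the multi-task pretraining loop above registers one apex loss scaler per task (num_losses=len(task2scaler) at amp.initialize, loss_id=task2scaler[name] at amp.scale_loss), so an fp16 overflow in one task lowers only that task's scale. A toy, self-contained illustration of the same pattern, assuming apex and a GPU are available; the model and losses are placeholders.

import torch
from apex import amp

model = torch.nn.Linear(4, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# two independent loss scalers, one per "task"
model, optimizer = amp.initialize(model, optimizer, enabled=True,
                                  opt_level='O2', num_losses=2)

x = torch.randn(8, 4).cuda()
loss_a = model(x).pow(2).mean()
with amp.scale_loss(loss_a, optimizer, loss_id=0) as scaled:
    scaled.backward()
loss_b = model(x).abs().mean()
with amp.scale_loss(loss_b, optimizer, loss_id=1) as scaled:
    scaled.backward()
optimizer.step()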
Exemplo n.º 18
0
def main(opts):
    hvd.init()
    n_gpu = hvd.size()
    device = torch.device("cuda", hvd.local_rank())
    torch.cuda.set_device(hvd.local_rank())
    rank = hvd.rank()
    opts.rank = rank
    LOGGER.info("device: {} n_gpu: {}, rank: {}, "
                "16-bits training: {}".format(device, n_gpu, hvd.rank(),
                                              opts.fp16))

    if opts.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, "
                         "should be >= 1".format(
                             opts.gradient_accumulation_steps))

    set_random_seed(opts.seed)

    # train_examples = None
    LOGGER.info(f"Loading Train Dataset {opts.train_txt_db}, "
                f"{opts.train_img_db}")
    train_dataloader = create_dataloader(opts.train_img_db, opts.train_txt_db,
                                         opts.train_batch_size, True,
                                         ReDataset, re_collate, opts)
    val_dataloader = create_dataloader(opts.val_img_db, opts.val_txt_db,
                                       opts.val_batch_size, False,
                                       ReEvalDataset, re_eval_collate, opts)

    # Prepare model
    if opts.checkpoint:
        checkpoint = torch.load(opts.checkpoint)
    else:
        checkpoint = {}

    all_dbs = [opts.train_txt_db, opts.val_txt_db]
    toker = json.load(open(f'{all_dbs[0]}/meta.json'))['toker']
    assert all(toker == json.load(open(f'{db}/meta.json'))['toker']
               for db in all_dbs)
    model = UniterForReferringExpressionComprehension.from_pretrained(
        opts.model_config,
        checkpoint,
        img_dim=IMG_DIM,
        loss=opts.train_loss,
        margin=opts.margin,
        hard_ratio=opts.hard_ratio,
        mlp=opts.mlp,
    )
    model.to(device)
    model.train()
    # make sure every process has same model parameters in the beginning
    broadcast_tensors([p.data for p in model.parameters()], 0)
    set_dropout(model, opts.dropout)

    optimizer = build_optimizer(model, opts)

    #  Apex
    model, optimizer = amp.initialize(model,
                                      optimizer,
                                      enabled=opts.fp16,
                                      opt_level='O2')

    global_step = 0
    if rank == 0:
        save_training_meta(opts)
        TB_LOGGER.create(join(opts.output_dir, 'log'))
        pbar = tqdm(total=opts.num_train_steps)
        model_saver = ModelSaver(join(opts.output_dir, 'ckpt'), 'model_epoch')
        os.makedirs(join(opts.output_dir, 'results'))  # store RE predictions
        add_log_to_file(join(opts.output_dir, 'log', 'log.txt'))
    else:
        LOGGER.disabled = True
        pbar = NoOp()
        model_saver = NoOp()

    LOGGER.info(f"***** Running training with {n_gpu} GPUs *****")
    LOGGER.info("  Num examples = %d", len(train_dataloader.dataset))
    LOGGER.info("  Batch size = %d", opts.train_batch_size)
    LOGGER.info("  Accumulate steps = %d", opts.gradient_accumulation_steps)
    LOGGER.info("  Num steps = %d", opts.num_train_steps)

    running_loss = RunningMeter('loss')
    model.train()
    n_examples = 0
    n_epoch = 0
    best_val_acc, best_epoch = None, None
    start = time()
    # quick hack for amp delay_unscale bug
    optimizer.zero_grad()
    if global_step == 0:
        optimizer.step()

    while True:
        for step, batch in enumerate(train_dataloader):
            if global_step >= opts.num_train_steps:
                break

            n_examples += batch['input_ids'].size(0)

            loss = model(batch, compute_loss=True)
            loss = loss.sum()  # sum over vectorized loss TODO: investigate
            delay_unscale = (step + 1) % opts.gradient_accumulation_steps != 0
            with amp.scale_loss(loss, optimizer,
                                delay_unscale=delay_unscale) as scaled_loss:
                scaled_loss.backward()
                if not delay_unscale:
                    # gather gradients from every process
                    # do this before unscaling to make sure every process uses
                    # the same gradient scale
                    grads = [
                        p.grad.data for p in model.parameters()
                        if p.requires_grad and p.grad is not None
                    ]
                    all_reduce_and_rescale_tensors(grads, float(1))

            running_loss(loss.item())

            if (step + 1) % opts.gradient_accumulation_steps == 0:
                global_step += 1

                # learning rate scheduling
                lr_this_step = get_lr_sched(global_step, opts)
                for i, param_group in enumerate(optimizer.param_groups):
                    if i == 0 or i == 1:
                        param_group['lr'] = lr_this_step * opts.lr_mul
                    elif i == 2 or i == 3:
                        param_group['lr'] = lr_this_step
                    else:
                        raise ValueError()
                TB_LOGGER.add_scalar('lr', lr_this_step, global_step)

                # log loss
                # NOTE: not gathered across GPUs for efficiency
                TB_LOGGER.add_scalar('loss', running_loss.val, global_step)
                TB_LOGGER.step()

                # update model params
                if opts.grad_norm != -1:
                    grad_norm = clip_grad_norm_(amp.master_params(optimizer),
                                                opts.grad_norm)
                    TB_LOGGER.add_scalar('grad_norm', grad_norm, global_step)
                optimizer.step()
                optimizer.zero_grad()
                pbar.update(1)

                if global_step % 100 == 0:
                    # monitor training throughput
                    LOGGER.info(f'============Step {global_step}=============')
                    tot_ex = sum(all_gather_list(n_examples))
                    ex_per_sec = int(tot_ex / (time() - start))
                    LOGGER.info(f'{tot_ex} examples trained at '
                                f'{ex_per_sec} ex/s')
                    TB_LOGGER.add_scalar('perf/ex_per_s', ex_per_sec,
                                         global_step)
                    LOGGER.info('===========================================')

        # evaluate after each epoch
        val_log, _ = validate(model, val_dataloader)
        TB_LOGGER.log_scaler_dict(val_log)

        # save model
        n_epoch += 1
        model_saver.save(model, n_epoch)
        LOGGER.info(f"finished {n_epoch} epochs")

        # save best model
        if best_val_acc is None or val_log['valid/acc'] > best_val_acc:
            best_val_acc = val_log['valid/acc']
            best_epoch = n_epoch
            model_saver.save(model, 'best')

        # shuffle training data for the next epoch
        train_dataloader.loader.dataset.shuffle()

        # is training finished?
        if global_step >= opts.num_train_steps:
            break

    val_log, results = validate(model, val_dataloader)
    with open(
            f'{opts.output_dir}/results/'
            f'results_{global_step}_'
            f'rank{rank}_final.json', 'w') as f:
        json.dump(results, f)
    TB_LOGGER.log_scaler_dict(val_log)
    model_saver.save(model, f'{global_step}_final')

    # print best model
    LOGGER.info(
        f'best_val_acc = {best_val_acc*100:.2f}% at epoch {best_epoch}.')
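
The UNITER/HERO fine-tuning loops in these examples share the same accumulation pattern: backward() runs on every mini-batch, but gradients are only all-reduced, clipped and applied once per gradient_accumulation_steps batches. A minimal single-process sketch of that pattern in plain PyTorch (no Horovod or apex; function and argument names are illustrative, not from the repo):

import torch
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm_

def train_with_accumulation(model, loader, optimizer, scheduler,
                            accum_steps=4, max_grad_norm=2.0):
    model.train()
    optimizer.zero_grad()
    for step, (inputs, targets) in enumerate(loader):
        loss = F.cross_entropy(model(inputs), targets)
        (loss / accum_steps).backward()     # accumulate; divide so the total acts as an average
        if (step + 1) % accum_steps == 0:   # one "global step" per accum_steps mini-batches
            clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
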
Exemplo n.º 19
0
class TrainerUniter():
    def __init__(self, config):
        self.preds_list, self.probs_list, self.labels_list, self.loss_list, self.short_loss_list, self.id_list = [], [], [], [], [], []
        self.best_val_metrics, self.train_metrics = defaultdict(int), {}
        self.best_auc = 0
        self.not_improved = 0
        self.best_val_loss = 1000
        self.total_iters = 0
        self.terminate_training = False
        self.model_file = os.path.join(config['model_path'],
                                       config['model_save_name'])
        self.pretrained_model_file = None
        if config['pretrained_model_file'] is not None:
            self.pretrained_model_file = os.path.join(
                config['model_path'], config['pretrained_model_file'])
        self.start_epoch = 1
        self.config = config
        self.device = get_device()

        if not isinstance(self.config['test_loader'], list):
            self.config['test_loader'] = [self.config['test_loader']]

        # Initialize the model, optimizer and loss function
        self.init_training_params()

    def init_training_params(self):
        self.init_model()
        wandb.watch(self.model)
        self.model_saver = ModelSaver(self.model_file)

        self.init_optimizer()
        self.init_scheduler()

        if self.config['loss_func'] == 'bce_logits':
            self.criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(
                [self.config['pos_wt']]).to(self.device))
        elif self.config['loss_func'] == 'bce':
            self.criterion = nn.BCELoss()
        else:
            self.criterion = nn.CrossEntropyLoss()

    def init_scheduler(self):
        if self.config['scheduler'] == 'step':
            self.scheduler = torch.optim.lr_scheduler.StepLR(
                self.optimizer,
                step_size=self.config['lr_decay_step'],
                gamma=self.config['lr_decay_factor'])
        elif self.config['scheduler'] == 'multi_step':
            self.scheduler = torch.optim.lr_scheduler.MultiStepLR(
                self.optimizer,
                milestones=[5, 10, 15, 25, 40],
                gamma=self.config['lr_decay_factor'])
        elif self.config['scheduler'] == 'warmup':
            self.scheduler = get_linear_schedule_with_warmup(
                self.optimizer,
                num_warmup_steps=self.config['warmup_steps'],
                num_training_steps=len(self.config['train_loader']) *
                self.config['max_epoch'])
        elif self.config['scheduler'] == 'warmup_cosine':
            self.scheduler = get_cosine_schedule_with_warmup(
                self.optimizer,
                num_warmup_steps=self.config['warmup_steps'],
                num_training_steps=len(self.config['train_loader']) *
                self.config['max_epoch'])

    def init_optimizer(self):
        self.optimizer = get_optimizer(self.model, self.config)

    def init_model(self):
        # `pretrained_model_file` points at the original pretrained UNITER weights
        # and is used as the starting point for fine-tuning. If it is not set, the
        # model file saved after your own fine-tuning run is loaded instead.
        if self.pretrained_model_file:
            checkpoint = torch.load(self.pretrained_model_file)
            LOGGER.info('Using pretrained UNITER base model {}'.format(
                self.pretrained_model_file))
            base_model = UniterForPretraining.from_pretrained(
                self.config['config'],
                state_dict=checkpoint['model_state_dict'],
                img_dim=IMG_DIM,
                img_label_dim=IMG_LABEL_DIM)
            self.model = MemeUniter(
                uniter_model=base_model.uniter,
                hidden_size=base_model.uniter.config.hidden_size +
                self.config["race_gender_hidden_size"],
                n_classes=self.config['n_classes'])
        else:
            self.load_model()

    def load_model(self):
        # Load a previously saved (fine-tuned) UNITER checkpoint
        if self.model_file:
            checkpoint = torch.load(self.model_file)
            LOGGER.info('Using UNITER model {}'.format(self.model_file))
        else:
            checkpoint = {}

        uniter_config = UniterConfig.from_json_file(self.config['config'])
        uniter_model = UniterModel(uniter_config, img_dim=IMG_DIM)

        self.model = MemeUniter(uniter_model=uniter_model,
                                hidden_size=uniter_model.config.hidden_size +
                                self.config["race_gender_hidden_size"],
                                n_classes=self.config['n_classes'])
        self.model.load_state_dict(checkpoint['model_state_dict'])

    def average_gradients(self, steps):
        # Used when grad_accumulation > 1
        for param in self.model.parameters():
            if param.requires_grad and param.grad is not None:
                param.grad = param.grad / steps

    def calculate_loss(self, preds, batch_label, grad_step):
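        """Compute the loss for one batch and record probs/preds/labels.

        When grad_step is True the loss is back-propagated; every
        `gradient_accumulation` iterations the accumulated gradients are also
        averaged, clipped and applied with an optimizer + scheduler step.
        """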
        if self.config['loss_func'] == 'bce':
            preds = torch.sigmoid(preds)
        preds = preds.squeeze(1).to(
            self.device
        ) if self.config['loss_func'] == 'bce_logits' else preds.to(
            self.device)
        loss = self.criterion(
            preds,
            batch_label.to(self.device) if self.config['loss_func'] == 'ce'
            else batch_label.float().to(self.device))

        if grad_step and self.iters % self.config['gradient_accumulation'] == 0:
            loss.backward()
            self.average_gradients(steps=self.config['gradient_accumulation'])
            torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                           self.config['max_grad_norm'])
            self.optimizer.step()
            self.scheduler.step()
            self.optimizer.zero_grad()
        elif grad_step:
            loss.backward()

        if self.config['loss_func'] == 'bce':
            probs = preds
            preds = (preds > 0.5).type(torch.FloatTensor)
        elif self.config['loss_func'] == 'ce':
            probs = F.softmax(preds, dim=1)
            preds = torch.argmax(probs, dim=1)
        elif self.config['loss_func'] == 'bce_logits':
            probs = torch.sigmoid(preds)
            preds = (probs > 0.5).type(torch.FloatTensor)

        self.probs_list.append(probs.cpu().detach().numpy())
        self.preds_list.append(preds.cpu().detach().numpy())
        self.labels_list.append(batch_label.cpu().detach().numpy())
        self.loss_list.append(loss.detach().item())
        if grad_step:
            self.short_loss_list.append(loss.detach().item())

    def eval_model(self, test=False, test_idx=0):
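        """Run the model over the validation loader (or the test_idx-th test
        loader when test=True) and return (metrics dict, mean loss); per-example
        probs/preds/labels/ids are kept in the corresponding self.*_list buffers.
        """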
        self.model.eval()
        self.preds_list, self.probs_list, self.labels_list, self.loss_list, self.id_list = [], [], [], [], []
        batch_loader = self.config['val_loader'] if not test else self.config[
            'test_loader'][test_idx]
        with torch.no_grad():
            for iters, batch in enumerate(batch_loader):
                batch = self.batch_to_device(batch)
                if batch_loader.dataset.return_ids:
                    self.id_list.append(batch['ids'])
                self.eval_iter_step(iters, batch, test=test)

            self.probs_list = [
                prob for batch_prob in self.probs_list for prob in batch_prob
            ]
            self.preds_list = [
                pred for batch_pred in self.preds_list for pred in batch_pred
            ]
            self.labels_list = [
                label for batch_labels in self.labels_list
                for label in batch_labels
            ]
            self.id_list = [
                data_id for batch_id in self.id_list for data_id in batch_id
            ]

            val_loss = sum(self.loss_list) / len(self.loss_list)
            eval_metrics = standard_metrics(torch.tensor(self.probs_list),
                                            torch.tensor(self.labels_list),
                                            add_optimal_acc=True)
            # if test:
            # 	print(classification_report(np.array(self.labels_list), np.array(self.preds_list)))
        return eval_metrics, val_loss

    @torch.no_grad()
    def export_test_predictions(self, test_idx=0, threshold=0.5):
        self.model.eval()

        # Step 2: Run model on the test set (no loss!)
        # Ensure that ids are actually returned
        assert self.config['test_loader'][
            test_idx].dataset.return_ids, "Can only export test results if the IDs are returned in the test dataset."
        test_name = self.config['test_loader'][test_idx].dataset.name

        prob_list = []
        id_list = []
        for iters, batch in enumerate(self.config['test_loader'][test_idx]):
            batch = self.batch_to_device(batch)
            id_list.append(batch['ids'].cpu())
            probs = self.test_iter_step(batch)
            if self.config['loss_func'] == 'bce_logits':
                probs = torch.sigmoid(probs)
            prob_list.append(probs.detach().cpu())

        probs = torch.cat(prob_list, dim=0)
        ids = torch.cat(id_list, dim=0)
        preds = (probs > threshold).long()

        # Step 3: Export predictions
        self._export_preds(ids,
                           probs,
                           preds,
                           file_postfix="_%s_preds.csv" % test_name)

        LOGGER.info("Finished export of test predictions")

    @torch.no_grad()
    def export_val_predictions(self, test=False, test_idx=0, threshold=0.5):
        batch_loader = self.config['val_loader'] if not test else self.config[
            'test_loader'][test_idx]
        test_name = batch_loader.dataset.name
        LOGGER.info("Exporting %s predictions..." % (test_name))
        self.model.eval()

        # Step 1: Find the optimal threshold on validation set
        _, _ = self.eval_model(test=test, test_idx=test_idx)
        val_probs = torch.tensor(self.probs_list)
        val_labels = torch.tensor(self.labels_list)
        if len(self.id_list) != 0:
            val_ids = torch.tensor(self.id_list)
        else:
            val_ids = torch.zeros_like(val_labels) - 1
        val_preds = (val_probs > threshold).long()

        self._export_preds(val_ids,
                           val_probs,
                           val_preds,
                           labels=val_labels,
                           file_postfix="_%s_preds.csv" % test_name)

        LOGGER.info("Finished export of %s predictions" % test_name)

    def _export_preds(self,
                      ids,
                      probs,
                      preds,
                      labels=None,
                      file_postfix="_preds.csv"):
        file_string = "id,proba,label%s\n" % (",gt"
                                              if labels is not None else "")
        for i in range(ids.shape[0]):
            file_string += "%i,%f,%i" % (ids[i].item(), probs[i].item(),
                                         preds[i].item())
            if labels is not None:
                file_string += ",%i" % labels[i].item()
            file_string += "\n"
        filepath = os.path.join(
            self.config['model_path'],
            self.config['model_save_name'].rsplit(".", 1)[0] + file_postfix)
        with open(filepath, "w") as f:
            f.write(file_string)
        wandb.save(filepath)  # upload the exported predictions file to wandb

    def check_early_stopping(self):
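        """Compare the monitored value (val loss, or a chosen val metric) against
        the best seen so far, save the model on improvement, and set
        self.terminate_training once `patience` epochs pass without improving by
        more than `early_stop_thresh`.
        """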
        self.this_metric = self.val_loss if self.config[
            'optimize_for'] == 'loss' else self.val_metrics[
                self.config['optimize_for']]
        self.current_best = self.best_val_loss if self.config[
            'optimize_for'] == 'loss' else self.best_val_metrics[
                self.config['optimize_for']]

        new_best = self.this_metric < self.current_best if self.config[
            'optimize_for'] == 'loss' else self.this_metric > self.current_best
        if new_best:
            LOGGER.info("New High Score! Saving model...")
            self.best_val_metrics = self.val_metrics
            self.best_val_loss = self.val_loss
            wandb.log({
                'Best val metrics': self.best_val_metrics,
                'Best val loss': self.best_val_loss
            })

            if not self.config["no_model_checkpoints"]:
                self.model_saver.save(self.model)

        ### Stopping Criteria based on patience and change-in-metric-threshold ###
        diff = self.current_best - \
            self.this_metric if self.config['optimize_for'] == 'loss' else self.this_metric - \
            self.current_best
        if diff < self.config['early_stop_thresh']:
            self.not_improved += 1
            if self.not_improved >= self.config['patience']:
                self.terminate_training = True
        else:
            self.not_improved = 0
        LOGGER.info("current patience: {}".format(self.not_improved))

    def train_epoch_step(self):
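        """End-of-epoch bookkeeping: compute training metrics, run validation,
        log both to tensorboard, check early stopping and reset the per-epoch
        prediction/loss buffers.
        """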
        self.model.train()
        lr = self.scheduler.get_last_lr()
        self.total_iters += self.iters + 1
        self.probs_list = [
            pred for batch_pred in self.probs_list for pred in batch_pred
        ]
        self.labels_list = [
            label for batch_labels in self.labels_list
            for label in batch_labels
        ]

        # Evaluate on train set
        self.train_metrics = standard_metrics(torch.tensor(self.probs_list),
                                              torch.tensor(self.labels_list),
                                              add_optimal_acc=True)
        log_tensorboard(self.config,
                        self.config['writer'],
                        self.model,
                        self.epoch,
                        self.iters,
                        self.total_iters,
                        self.loss_list,
                        self.train_metrics,
                        lr[0],
                        loss_only=False,
                        val=False)
        self.train_loss = self.loss_list[:]

        # Evaluate on dev set
        val_time = time.time()
        self.val_metrics, self.val_loss = self.eval_model()
        self.config['writer'].add_scalar("Stats/time_validation",
                                         time.time() - val_time,
                                         self.total_iters)

        # print stats
        print_stats(self.config, self.epoch, self.train_metrics,
                    self.train_loss, self.val_metrics, self.val_loss,
                    self.start, lr[0])

        # log validation stats in tensorboard
        log_tensorboard(self.config,
                        self.config['writer'],
                        self.model,
                        self.epoch,
                        self.iters,
                        self.total_iters,
                        self.val_loss,
                        self.val_metrics,
                        lr[0],
                        loss_only=False,
                        val=True)

        # Check for early stopping criteria
        self.check_early_stopping()
        self.probs_list = []
        self.preds_list = []
        self.labels_list = []
        self.loss_list = []
        self.id_list = []

        self.train_loss = sum(self.train_loss) / len(self.train_loss)
        del self.val_metrics
        del self.val_loss

    def end_training(self):
        # Termination message
        print("\n" + "-" * 100)
        if self.terminate_training:
            LOGGER.info(
                "Training terminated early because the validation {} did not "
                "improve for {} epochs.".format(self.config['optimize_for'],
                                                self.config['patience']))
        else:
            LOGGER.info(
                "Maximum number of epochs ({}) reached. Finished training.".format(
                    self.config['max_epoch']))

        print_test_stats(self.best_val_metrics, test=False)

        print("-" * 50 + "\n\t\tEvaluating on test set\n" + "-" * 50)
        if not self.config["no_model_checkpoints"]:
            if os.path.isfile(self.model_file):
                self.load_model()
                self.model.to(self.device)
            else:
                raise ValueError(
                    "No saved model state_dict found for model '{}'. "
                    "Aborting evaluation on the test set.".format(
                        self.config['model_name']))

            # export_val_predictions runs eval_model internally, so no separate
            # evaluation call is needed here
            self.export_val_predictions()
            val_probs = torch.tensor(self.probs_list)
            val_labels = torch.tensor(self.labels_list)
            threshold = 0.5  # the default threshold for binary classification
            # Uncomment below line if you have implemented this optional feature
            # threshold = find_optimal_threshold(val_probs, val_labels, metric="accuracy")
            best_val_metrics = standard_metrics(val_probs,
                                                val_labels,
                                                threshold=threshold,
                                                add_aucroc=False)
            LOGGER.info(
                "Optimal threshold on validation dataset: %.4f (accuracy=%4.2f%%)"
                % (threshold, 100.0 * best_val_metrics["accuracy"]))

            # Testing in the standard form is not possible since we have no labels
            # (standard_metrics would raise an error); instead, write out the
            # predictions in the leaderboard format
            self.test_metrics = dict()
            for test_idx in range(len(self.config['test_loader'])):
                test_name = self.config['test_loader'][test_idx].dataset.name
                LOGGER.info("Export and testing on %s..." % test_name)
                if hasattr(self.config['test_loader'][test_idx].dataset, "data") and \
                   hasattr(self.config['test_loader'][test_idx].dataset.data, "labels") and \
                   self.config['test_loader'][test_idx].dataset.data.labels[0] == -1:  # unlabeled test set: export predictions only
                    self.export_test_predictions(test_idx=test_idx,
                                                 threshold=threshold)
                    self.test_metrics[test_name] = dict()
                else:
                    test_idx_metrics, _ = self.eval_model(test=True,
                                                          test_idx=test_idx)
                    self.test_metrics[test_name] = test_idx_metrics
                    print_test_stats(test_idx_metrics, test=True)
                    self.export_val_predictions(test=True,
                                                test_idx=test_idx,
                                                threshold=threshold)
        else:
            LOGGER.info(
                "No model checkpoints were saved. Hence, testing will be skipped."
            )
            self.test_metrics = dict()

        self.export_metrics()

        self.config['writer'].close()

        if self.config['remove_checkpoints']:
            LOGGER.info("Removing checkpoint %s..." % self.model_file)
            os.remove(self.model_file)

    def export_metrics(self):
        metric_export_file = os.path.join(
            self.config['model_path'],
            self.config['model_save_name'].rsplit(".", 1)[0] + "_metrics.json")
        metric_dict = {}
        metric_dict["dev"] = self.best_val_metrics
        metric_dict["dev"]["loss"] = self.best_val_loss
        metric_dict["train"] = self.train_metrics
        metric_dict["train"]["loss"] = sum(
            self.train_loss) / len(self.train_loss) if isinstance(
                self.train_loss, list) else self.train_loss
        if hasattr(self, "test_metrics") and len(self.test_metrics) > 0:
            metric_dict["test"] = self.test_metrics

        with open(metric_export_file, "w") as f:
            json.dump(metric_dict, f, indent=4)

    def train_main(self, cache=False):
        print("\n\n" + "=" * 100 + "\n\t\t\t\t\t Training Network\n" +
              "=" * 100)

        self.start = time.time()
        print("\nBeginning training at:  {} \n".format(
            datetime.datetime.now()))

        self.model.to(self.device)

        for self.epoch in range(self.start_epoch,
                                self.config['max_epoch'] + 1):
            train_times = []
            for self.iters, self.batch in enumerate(
                    self.config['train_loader']):
                self.model.train()

                iter_time = time.time()
                self.batch = self.batch_to_device(self.batch)
                self.train_iter_step()
                train_times.append(time.time() - iter_time)

                # Loss only logging
                if (self.total_iters + self.iters +
                        1) % self.config['log_every'] == 0:
                    log_tensorboard(self.config,
                                    self.config['writer'],
                                    self.model,
                                    self.epoch,
                                    self.iters,
                                    self.total_iters,
                                    self.short_loss_list,
                                    loss_only=True,
                                    val=False)
                    self.config['writer'].add_scalar(
                        'Stats/time_per_train_iter', mean(train_times),
                        (self.iters + self.total_iters + 1))
                    self.config['writer'].add_scalar(
                        'Stats/learning_rate',
                        self.scheduler.get_last_lr()[0],
                        (self.iters + self.total_iters + 1))
                    train_times = []
                    self.short_loss_list = []
            self.train_epoch_step()

            if self.terminate_training:
                break

        self.end_training()
        return self.best_val_metrics, self.test_metrics

    def batch_to_device(self, batch):
        batch = {
            k: (v.to(self.device) if isinstance(v, torch.Tensor) else v)
            for k, v in batch.items()
        }
        return batch

    def eval_iter_step(self, iters, batch, test):
        # Forward pass
        preds = self.model(img_feat=batch['img_feat'],
                           img_pos_feat=batch['img_pos_feat'],
                           input_ids=batch['input_ids'],
                           position_ids=batch['position_ids'],
                           attention_mask=batch['attn_mask'],
                           gather_index=batch['gather_index'],
                           output_all_encoded_layers=False,
                           gender_race_probs=batch['gender_race_probs'])
        self.calculate_loss(preds, batch['labels'], grad_step=False)

    def train_iter_step(self):
        # Forward pass
        self.preds = self.model(
            img_feat=self.batch['img_feat'],
            img_pos_feat=self.batch['img_pos_feat'],
            input_ids=self.batch['input_ids'],
            position_ids=self.batch['position_ids'],
            attention_mask=self.batch['attn_mask'],
            gather_index=self.batch['gather_index'],
            output_all_encoded_layers=False,
            gender_race_probs=self.batch['gender_race_probs'])
        self.calculate_loss(self.preds, self.batch['labels'], grad_step=True)

    def test_iter_step(self, batch):
        # Forward pass
        preds = self.model(img_feat=batch['img_feat'],
                           img_pos_feat=batch['img_pos_feat'],
                           input_ids=batch['input_ids'],
                           position_ids=batch['position_ids'],
                           attention_mask=batch['attn_mask'],
                           gather_index=batch['gather_index'],
                           output_all_encoded_layers=False,
                           gender_race_probs=batch['gender_race_probs'])
        return preds.squeeze()
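
TrainerUniter is driven entirely by its config dictionary; the keys below are the ones the class itself reads. A rough wiring sketch (paths and values are placeholders, the loaders are assumed to be already-built DataLoaders, get_optimizer may read further keys such as the learning rate, and wandb.init() is assumed to have been called beforehand):

from torch.utils.tensorboard import SummaryWriter

config = {
    # checkpoints / model definition
    'model_path': 'checkpoints', 'model_save_name': 'uniter_meme.pt',
    'pretrained_model_file': 'uniter-base.pt',  # None -> load a fine-tuned checkpoint instead
    'config': 'config/uniter-base.json',
    'n_classes': 1, 'race_gender_hidden_size': 0,
    # loss / optimization
    'loss_func': 'bce_logits', 'pos_wt': 1.0,
    'scheduler': 'warmup', 'warmup_steps': 500,
    'lr_decay_step': 5, 'lr_decay_factor': 0.8,
    'gradient_accumulation': 1, 'max_grad_norm': 1.0,
    # training loop / early stopping
    'max_epoch': 30, 'log_every': 100,
    'optimize_for': 'loss', 'early_stop_thresh': 1e-4, 'patience': 5,
    'no_model_checkpoints': False, 'remove_checkpoints': False,
    # data and logging (placeholders)
    'train_loader': train_loader, 'val_loader': val_loader,
    'test_loader': [test_loader],
    'writer': SummaryWriter('runs/meme'),
}
trainer = TrainerUniter(config)
best_val_metrics, test_metrics = trainer.train_main()
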
Exemplo n.º 20
0
def main(opts):
    hvd.init()
    n_gpu = hvd.size()
    device = torch.device("cuda", hvd.local_rank())
    torch.cuda.set_device(hvd.local_rank())
    rank = hvd.rank()
    opts.rank = rank
    LOGGER.info("device: {} n_gpu: {}, rank: {}, "
                "16-bits training: {}".format(device, n_gpu, hvd.rank(),
                                              opts.fp16))

    if opts.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, "
                         "should be >= 1".format(
                             opts.gradient_accumulation_steps))

    set_random_seed(opts.seed)

    if hvd.rank() == 0:
        save_training_meta(opts)
        TB_LOGGER.create(join(opts.output_dir, 'log'))
        pbar = tqdm(total=opts.num_train_steps)
        model_saver = ModelSaver(join(opts.output_dir, 'ckpt'))
        add_log_to_file(join(opts.output_dir, 'log', 'log.txt'))
        # store ITM predictions
        os.makedirs(join(opts.output_dir, 'results_val'))
        os.makedirs(join(opts.output_dir, 'results_test'))
        os.makedirs(join(opts.output_dir, 'results_train'))
    else:
        LOGGER.disabled = True
        pbar = NoOp()
        model_saver = NoOp()

    # train_examples = None
    LOGGER.info(f"Loading Train Dataset {opts.train_txt_dbs}, "
                f"{opts.train_img_dbs}")
    # check multiple DBs
    assert len(opts.train_txt_dbs) == len(opts.train_img_dbs), \
        "train txt_db and img_db have different length"

    # load DBs and image dirs
    all_img_dbs = ImageLmdbGroup(opts.conf_th, opts.max_bb, opts.min_bb,
                                 opts.num_bb, opts.compressed_db)
    # train
    LOGGER.info(f"Loading Train Dataset "
                f"{opts.train_txt_dbs}, {opts.train_img_dbs}")
    train_datasets = []
    for txt_path, img_path in zip(opts.train_txt_dbs, opts.train_img_dbs):
        img_db = all_img_dbs[img_path]
        txt_db = TxtTokLmdb(txt_path, opts.max_txt_len)
        train_datasets.append(
            ItmRankDataset(txt_db, img_db, opts.negative_size))
    train_dataset = ConcatDataset(train_datasets)

    # val
    LOGGER.info(f"Loading Val Dataset {opts.val_txt_db}, {opts.val_img_db}")
    val_img_db = all_img_dbs[opts.val_img_db]
    val_txt_db = TxtTokLmdb(opts.val_txt_db, -1)
    val_dataset = ItmValDataset(val_txt_db, val_img_db,
                                opts.inf_minibatch_size)
    val_dataloader = build_dataloader(val_dataset, itm_val_collate, False,
                                      opts)
    # eval
    LOGGER.info(f"Loading val, test Dataset for full evaluation: "
                f"{opts.val_txt_db}, {opts.val_img_db}, "
                f"{opts.test_txt_db}, {opts.test_img_db}")
    eval_dataset_val = ItmEvalDataset(val_txt_db, val_img_db,
                                      opts.inf_minibatch_size)
    eval_loader_val = build_dataloader(eval_dataset_val, itm_eval_collate,
                                       False, opts)
    test_img_db = all_img_dbs[opts.test_img_db]
    test_txt_db = TxtTokLmdb(opts.test_txt_db, -1)
    eval_dataset_test = ItmEvalDataset(test_txt_db, test_img_db,
                                       opts.inf_minibatch_size)
    eval_loader_test = build_dataloader(eval_dataset_test, itm_eval_collate,
                                        False, opts)

    # Prepare model
    if opts.checkpoint:
        checkpoint = torch.load(opts.checkpoint)
    else:
        checkpoint = {}

    model = UniterForImageTextRetrieval.from_pretrained(opts.model_config,
                                                        state_dict=checkpoint,
                                                        img_dim=IMG_DIM,
                                                        margin=opts.margin)
    model.init_output()  # pretrain ITM head is different from ranking head
    model.to(device)
    # make sure every process has same model parameters in the beginning
    broadcast_tensors([p.data for p in model.parameters()], 0)
    set_dropout(model, opts.dropout)

    # Prepare optimizer
    optimizer = build_optimizer(model, opts)
    model, optimizer = amp.initialize(model,
                                      optimizer,
                                      enabled=opts.fp16,
                                      opt_level='O2')

    global_step = 0
    LOGGER.info(f"***** Running training on {n_gpu} GPUs *****")
    LOGGER.info("  Num examples = %d", len(train_dataset) * hvd.size())
    LOGGER.info("  Batch size = %d", opts.train_batch_size)
    LOGGER.info("  Accumulate steps = %d", opts.gradient_accumulation_steps)
    LOGGER.info("  Num steps = %d", opts.num_train_steps)

    running_loss = RunningMeter('loss')
    model.train()

    n_examples = 0
    n_epoch = 0
    start = time()
    # quick hack for amp delay_unscale bug
    optimizer.zero_grad()
    optimizer.step()
    while True:
        train_dataloader = build_dataloader(train_dataset, itm_rank_collate,
                                            True, opts)
        for step, batch in enumerate(train_dataloader):
            n_examples += batch['input_ids'].size(0)
            loss = model(batch, compute_loss=True)
            loss = loss.mean()
            delay_unscale = (step + 1) % opts.gradient_accumulation_steps != 0
            with amp.scale_loss(loss, optimizer,
                                delay_unscale=delay_unscale) as scaled_loss:
                scaled_loss.backward()
                if not delay_unscale:
                    # gather gradients from all processes; do this before
                    # unscaling so that every process uses the same gradient
                    # scale
                    grads = [
                        p.grad.data for p in model.parameters()
                        if p.requires_grad and p.grad is not None
                    ]
                    all_reduce_and_rescale_tensors(grads, float(1))

            running_loss(loss.item())
            if (step + 1) % opts.gradient_accumulation_steps == 0:
                global_step += 1

                # learning rate scheduling
                lr_this_step = get_lr_sched(global_step, opts)
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr_this_step
                TB_LOGGER.add_scalar('lr', lr_this_step, global_step)

                # log loss
                # NOTE: not gathered across GPUs for efficiency
                TB_LOGGER.add_scalar('loss', running_loss.val, global_step)
                TB_LOGGER.step()

                # update model params
                if opts.grad_norm != -1:
                    grad_norm = clip_grad_norm_(amp.master_params(optimizer),
                                                opts.grad_norm)
                    TB_LOGGER.add_scalar('grad_norm', grad_norm, global_step)
                optimizer.step()
                optimizer.zero_grad()
                pbar.update(1)

                if global_step % 100 == 0:
                    # monitor training throughput
                    LOGGER.info(f'------------Step {global_step}-------------')
                    tot_ex = sum(all_gather_list(n_examples))
                    ex_per_sec = int(tot_ex / (time() - start))
                    LOGGER.info(f'{tot_ex} examples trained at '
                                f'{ex_per_sec} ex/s')
                    TB_LOGGER.add_scalar('perf/ex_per_s', ex_per_sec,
                                         global_step)
                    LOGGER.info(f'-------------------------------------------')

                if global_step % opts.valid_steps == 0:
                    if opts.full_val:
                        LOGGER.info(
                            f"========================== Step {global_step} "
                            f"==========================")
                        val_log = evaluate(model, eval_loader_val)
                        TB_LOGGER.log_scaler_dict(
                            {f"valid/{k}": v
                             for k, v in val_log.items()})
                        LOGGER.info(f"image retrieval R1: "
                                    f"{val_log['img_r1']*100:.2f},\n"
                                    f"image retrieval R5: "
                                    f"{val_log['img_r5']*100:.2f},\n"
                                    f"image retrieval R10: "
                                    f"{val_log['img_r10']*100:.2f}\n"
                                    f"text retrieval R1: "
                                    f"{val_log['txt_r1']*100:.2f},\n"
                                    f"text retrieval R5: "
                                    f"{val_log['txt_r5']*100:.2f},\n"
                                    f"text retrieval R10: "
                                    f"{val_log['txt_r10']*100:.2f}")
                        LOGGER.info("================================="
                                    "=================================")
                    else:
                        val_log = validate(model, val_dataloader)
                        TB_LOGGER.log_scaler_dict(val_log)
                    model_saver.save(model, global_step)

            if global_step >= opts.num_train_steps:
                break

        if global_step >= opts.num_train_steps:
            break
        n_epoch += 1
        LOGGER.info(f"finished {n_epoch} epochs")

    pbar.close()
    if opts.num_train_steps % opts.valid_steps != 0:
        # final validation
        val_log = validate(model, val_dataloader)
        TB_LOGGER.log_scaler_dict(val_log)
        model_saver.save(model, global_step)

    # evaluation
    for split, loader in [('val', eval_loader_val),
                          ('test', eval_loader_test)]:
        eval_log = evaluate(model, loader)
        TB_LOGGER.log_scaler_dict(
            {f"eval/{split}_{k}": v
             for k, v in eval_log.items()})
        if hvd.rank() != 0:
            continue
        LOGGER.info(
            f"========================= {split} ===========================\n"
            f"image retrieval R1: {eval_log['img_r1']*100:.2f},\n"
            f"image retrieval R5: {eval_log['img_r5']*100:.2f},\n"
            f"image retrieval R10: {eval_log['img_r10']*100:.2f}\n"
            f"text retrieval R1: {eval_log['txt_r1']*100:.2f},\n"
            f"text retrieval R5: {eval_log['txt_r5']*100:.2f},\n"
            f"text retrieval R10: {eval_log['txt_r10']*100:.2f}")
    LOGGER.info("=========================================================")
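
The full evaluation above reports image and text retrieval recall at 1, 5 and 10. A generic recall@K over a query-by-candidate similarity matrix can be computed as below (an illustrative helper, not the repo's evaluate()):

import torch

def recall_at_k(scores, gt_idx, ks=(1, 5, 10)):
    """scores: (num_queries, num_candidates); gt_idx: (num_queries,) true column."""
    pos = scores.gather(1, gt_idx[:, None])  # score of the true match for each query
    ranks = (scores >= pos).sum(dim=1)       # 1-based rank, ties counted pessimistically
    return {f'r{k}': (ranks <= k).float().mean().item() for k in ks}

# e.g. recall_at_k(torch.randn(100, 1000), torch.arange(100))
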
Exemplo n.º 21
0
def main(opts):
    hvd.init()
    n_gpu = hvd.size()
    device = torch.device("cuda", hvd.local_rank())
    torch.cuda.set_device(hvd.local_rank())
    rank = hvd.rank()
    opts.rank = rank
    LOGGER.info("device: {} n_gpu: {}, rank: {}, "
                "16-bits training: {}".format(device, n_gpu, hvd.rank(),
                                              opts.fp16))

    if hvd.rank() != 0:
        LOGGER.disabled = True

    if opts.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, "
                         "should be >= 1".format(
                             opts.gradient_accumulation_steps))

    set_random_seed(opts.seed)
    opts.task = 'tvc'

    # train_examples = None
    LOGGER.info(f"Loading the whole video dataset {opts.sub_txt_db}, "
                f"{opts.vfeat_db}")
    video_db = load_video_sub_dataset(opts.vfeat_db, opts.sub_txt_db,
                                      opts.vfeat_interval, opts)

    # data loaders
    # train
    LOGGER.info(f"Loading train dataset {opts.train_db}")
    train_cap = CaptionTokLmdb(opts.train_db, opts.max_txt_len)
    train_dset = TvcTrainDataset(video_db, train_cap, opts.max_cap_per_vid)
    LOGGER.info(f"{sum(all_gather_list(len(train_dset)))} samples loaded")
    train_loader = build_dataloader(train_dset, opts.train_batch_size,
                                    TvcTrainDataset.collate, True, opts)

    # val
    LOGGER.info(f"Loading val dataset {opts.val_db}")
    val_cap = CaptionTokLmdb(opts.val_db, -1)
    val_dset = TvcValDataset(video_db, val_cap, -1)
    val_loader = build_dataloader(val_dset, opts.val_batch_size,
                                  TvcValDataset.collate, False, opts)
    if hvd.rank() == 0:
        evaluator = TVCEval(opts.val_ref)
    else:
        evaluator = NoOp()

    # Prepare model
    if opts.checkpoint:
        checkpoint = torch.load(opts.checkpoint)
    else:
        checkpoint = {}

    img_pos_embed_weight_key = "v_encoder.f_encoder.img_embeddings" +\
        ".position_embeddings.weight"
    if img_pos_embed_weight_key in checkpoint:
        max_frm_seq_len = len(checkpoint[img_pos_embed_weight_key])
    else:
        max_frm_seq_len = MAX_FRM_SEQ_LEN

    model = HeroForTvc.from_pretrained(opts.model_config,
                                       state_dict=checkpoint,
                                       vfeat_dim=VFEAT_DIM,
                                       max_frm_seq_len=max_frm_seq_len,
                                       lsr=opts.lsr)

    model.to(device)
    # make sure every process has same model parameters in the beginning
    broadcast_tensors([p.data for p in model.parameters()], 0)
    set_dropout(model, opts.dropout)

    # Prepare optimizer
    optimizer = build_optimizer(model, opts)
    model, optimizer = amp.initialize(model,
                                      optimizer,
                                      enabled=opts.fp16,
                                      opt_level='O2')

    # assumes roberta tokenizer only
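    # all_gather_list(None) acts as a barrier here: local rank 0 downloads the
    # tokenizer first while the other ranks wait, then they load it from cache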
    if hvd.local_rank() == 0:
        # quick hack to prevent multi-process download collision
        toker = RobertaTokenizer.from_pretrained('roberta-base')
        all_gather_list(None)
    else:
        all_gather_list(None)
        toker = RobertaTokenizer.from_pretrained('roberta-base')
    bos = toker.convert_tokens_to_ids(['<s>'])[0]
    eos = toker.convert_tokens_to_ids(['</s>'])[0]
    generator = TvcGenerator(model, opts.max_gen_step, bos, eos, opts.fp16)

    global_step = 0
    if rank == 0:
        save_training_meta(opts)
        TB_LOGGER.create(join(opts.output_dir, 'log'))
        pbar = tqdm(total=opts.num_train_steps)
        model_saver = ModelSaver(join(opts.output_dir, 'ckpt'))
        os.makedirs(join(opts.output_dir, 'results'))  # store val predictions
        add_log_to_file(join(opts.output_dir, 'log', 'log.txt'))
    else:
        LOGGER.disabled = True
        pbar = NoOp()
        model_saver = NoOp()

    LOGGER.info(f"***** Running training with {n_gpu} GPUs *****")
    LOGGER.info("  Batch size = %d", opts.train_batch_size)
    LOGGER.info("  Accumulate steps = %d", opts.gradient_accumulation_steps)
    LOGGER.info("  Num steps = %d", opts.num_train_steps)

    train_loss = RunningMeter('loss')
    n_vid = 0
    n_cap = 0
    n_epoch = 0
    start = time()
    # quick hack for amp delay_unscale bug
    optimizer.zero_grad()
    optimizer.step()
    model.train()
    while True:
        for step, batch in enumerate(train_loader):
            n_vid += opts.train_batch_size
            n_cap += batch['cap_input_ids'].size(0)

            loss = model(batch, compute_loss=True)
            loss = loss.mean()
            train_loss(loss.item())

            delay_unscale = (step + 1) % opts.gradient_accumulation_steps != 0
            with amp.scale_loss(loss, optimizer,
                                delay_unscale=delay_unscale) as scaled_loss:
                scaled_loss.backward()
                if not delay_unscale:
                    # gather gradients from all processes; do this before
                    # unscaling so that every process uses the same gradient
                    # scale
                    grads = [
                        p.grad.data for p in model.parameters()
                        if p.requires_grad and p.grad is not None
                    ]
                    all_reduce_and_rescale_tensors(grads, float(1))

            if (step + 1) % opts.gradient_accumulation_steps == 0:
                global_step += 1

                # learning rate scheduling
                lr_this_step = get_lr_sched(global_step, opts)
                for i, param_group in enumerate(optimizer.param_groups):
                    if i == 0 or i == 1:
                        param_group['lr'] = lr_this_step * opts.lr_mul
                    elif i == 2 or i == 3:
                        param_group['lr'] = lr_this_step
                    else:
                        raise ValueError()
                TB_LOGGER.add_scalar('lr', lr_this_step, global_step)

                # log loss
                TB_LOGGER.add_scalar(train_loss.name, train_loss.val,
                                     global_step)
                TB_LOGGER.step()

                # update model params
                if opts.grad_norm != -1:
                    grad_norm = clip_grad_norm_(amp.master_params(optimizer),
                                                opts.grad_norm)
                    TB_LOGGER.add_scalar('grad_norm', grad_norm, global_step)
                optimizer.step()
                optimizer.zero_grad()
                pbar.update(1)

                if global_step % 100 == 0:
                    # monitor training throughput
                    LOGGER.info('-------------------------------------------')
                    LOGGER.info(f'Step {global_step}:')
                    tot_vid = sum(all_gather_list(n_vid))
                    vid_per_sec = int(tot_vid / (time() - start))
                    LOGGER.info(f'{tot_vid} videos trained at '
                                f'{vid_per_sec} vid/s')
                    tot_cap = sum(all_gather_list(n_cap))
                    cap_per_sec = int(tot_cap / (time() - start))
                    TB_LOGGER.add_scalar(f'perf/vid_per_s', vid_per_sec,
                                         global_step)
                    TB_LOGGER.add_scalar(f'perf/cap_per_s', cap_per_sec,
                                         global_step)

                if global_step % opts.valid_steps == 0:
                    LOGGER.info('===========================================')
                    LOGGER.info(f"Step {global_step}: start validation")
                    val_log, results = validate(val_loader, generator, toker,
                                                evaluator)
                    if hvd.rank() == 0:
                        save_jsonl(
                            results, f"{opts.output_dir}/results/"
                            f"results_{global_step}.jsonl")
                    TB_LOGGER.log_scaler_dict(val_log)
                    LOGGER.info('===========================================')
                    model_saver.save(model, global_step)
            if global_step >= opts.num_train_steps:
                break
        n_epoch += 1
        LOGGER.info(f"finished {n_epoch} epochs")
        if global_step >= opts.num_train_steps:
            break

    LOGGER.info('===========================================')
    if global_step % opts.valid_steps != 0:
        val_log, results = validate(val_loader, generator, toker, evaluator)
        if hvd.rank() == 0:
            save_jsonl(
                results, f"{opts.output_dir}/results/"
                f"results_{global_step}.jsonl")
        TB_LOGGER.log_scaler_dict(val_log)
        model_saver.save(model, global_step)
Exemplo n.º 22
0
def main(opts):

    # `device` must be defined for the .to(device) calls below; default to the
    # first available GPU, falling back to CPU
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    os.makedirs(opts.output_dir)
    os.makedirs(join(opts.output_dir, 'ckpt'))
    model_saver = ModelSaver(join(opts.output_dir, 'ckpt'))

    # train
    train_dataset = MemeAIDataset(json_path='/home/data/meme_json/train.json',
                                  npz_folder='/home/data/faster_cnn_feature/',
                                  mode='train')
    train_loader = DataLoader(train_dataset,
                              batch_size=opts.train_batch_size,
                              shuffle=True,
                              num_workers=opts.n_workers,
                              collate_fn=collate_fn)
    train_loader = PrefetchLoader(train_loader)

    # val
    val_dataset = MemeAIDataset(json_path='/home/data/meme_json/dev.json',
                                npz_folder='/home/data/faster_cnn_feature/',
                                mode='val')
    val_loader = DataLoader(val_dataset,
                            batch_size=opts.inf_minibatch_size,
                            shuffle=False,
                            num_workers=opts.n_workers,
                            collate_fn=collate_fn)
    val_loader = PrefetchLoader(val_loader)

    # Prepare model
    if opts.checkpoint:
        checkpoint = torch.load(opts.checkpoint)
    else:
        checkpoint = {}

    model = Meme.from_pretrained(opts.model_config,
                                 state_dict=checkpoint,
                                 img_dim=IMG_DIM)
    model.init_output()  # re-initialize the output head (the pretraining head differs from this task's head)
    model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=opts.learning_rate)

    for epoch in range(opts.epoch):
        print('epoch {}/ {}'.format(epoch, opts.epoch))
        pbar = tqdm(total=len(train_loader))

        model.train()
        preds = None
        gt = None

        for step, batch in enumerate(train_loader):
            x = batch[0]
            x['input_ids'] = x['input_ids'].to(device)
            x['position_ids'] = x['position_ids'].to(device)
            x['img_feat'] = x['img_feat'].to(device)
            x['img_pos_feat'] = x['img_pos_feat'].to(device)
            x['attn_masks'] = x['attn_masks'].to(device)
            x['gather_index'] = x['gather_index'].to(device)
            y = batch[1].to(device)

            pred = model(x)

            if preds is None:
                preds = torch.sigmoid(pred)
                gt = y
            else:
                preds = torch.cat((preds, torch.sigmoid(pred)), dim=0)
                gt = torch.cat((gt, y), dim=0)

            # F.binary_cross_entropy_with_logits(pred, y) is the numerically
            # stabler equivalent of sigmoid followed by binary_cross_entropy
            loss = F.binary_cross_entropy(torch.sigmoid(pred), y)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            pbar.update(1)

        model.eval()
        with torch.no_grad():
            preds = preds.detach().cpu().numpy().reshape(len(preds), )
            gt = gt.cpu().numpy()
            roc = roc_auc_score(gt, preds)
            acc = accuracy_score(gt, np.around(preds))

        train_log = {'train/roc': roc, 'train/acc': acc}
        val_log = validate(model, val_loader)

        LOGGER.info(train_log)
        LOGGER.info(val_log)

        model_saver.save(model, epoch)
        pbar.close()
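
The validate() used at the end of each epoch is defined elsewhere; a minimal stand-in that mirrors the training metrics computed above could look like the following (the signature and the returned keys are assumptions, not taken from the original code):

import numpy as np
import torch
from sklearn.metrics import accuracy_score, roc_auc_score

@torch.no_grad()
def validate(model, val_loader):
    model.eval()
    device = next(model.parameters()).device
    probs, gts = [], []
    for x, y in val_loader:
        x = {k: v.to(device) if isinstance(v, torch.Tensor) else v
             for k, v in x.items()}
        probs.append(torch.sigmoid(model(x)).cpu())
        gts.append(y.cpu())
    probs = torch.cat(probs).numpy().reshape(-1)
    gts = torch.cat(gts).numpy().reshape(-1)
    return {'val/roc': roc_auc_score(gts, probs),
            'val/acc': accuracy_score(gts, np.around(probs))}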