def StreamInferenceDoc(args, model, fn, prefix, f, is_query_inference=True):
    inference_batch_size = args.per_gpu_eval_batch_size  # * max(1, args.n_gpu)
    inference_dataset = StreamingDataset(f, fn)
    inference_dataloader = DataLoader(
        inference_dataset,
        batch_size=inference_batch_size)

    if args.local_rank != -1:
        dist.barrier()  # directory created

    _embedding, _embedding2id = InferenceEmbeddingFromStreamDataLoader(
        args, model, inference_dataloader, is_query_inference=is_query_inference, prefix=prefix)

    logger.info("merging embeddings")

    # merge per-rank partial embeddings (kept in memory on the master rank)
    full_embedding = barrier_array_merge(
        args,
        _embedding,
        prefix=prefix + "_emb_p_",
        load_cache=False,
        only_load_in_master=True)
    full_embedding2id = barrier_array_merge(
        args,
        _embedding2id,
        prefix=prefix + "_embid_p_",
        load_cache=False,
        only_load_in_master=True)

    return full_embedding, full_embedding2id
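A minimal usage sketch for the helper above. It assumes the surrounding script provides `args`, `model`, `tokenizer`, and a `configObj` whose `process_fn` tokenizes one raw line; the file name and prefix are placeholders, not the project's actual values.

def passage_fn(line, i):
    # hypothetical per-line tokenization callback
    return configObj.process_fn(line, i, tokenizer, args)

with open(os.path.join(args.data_dir, "passages.tsv"), encoding="utf-8") as f:  # hypothetical file
    passage_embedding, passage_embedding2id = StreamInferenceDoc(
        args, model, passage_fn, "passage_", f, is_query_inference=False)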
Example #2
def embedding_inference(args, path, model, fn, bz, num_workers=2, is_query=True):
    f = open(path, encoding="utf-8")
    model = model.module if hasattr(model, "module") else model
    sds = StreamingDataset(f, fn)
    # the loader is kept single-process here; the num_workers argument is currently unused
    loader = DataLoader(sds, batch_size=bz, num_workers=0)
    emb_list, id_list = [], []
    model.eval()
    for i, batch in tqdm(enumerate(loader), desc="Eval", disable=args.local_rank not in [-1, 0]):
        batch = tuple(t.to(args.device) for t in batch)
        with torch.no_grad():
            inputs = {"input_ids": batch[0].long(
            ), "attention_mask": batch[1].long()}
            idx = batch[3].long()
            if is_query:
                embs = model.query_emb(**inputs)
            else:
                embs = model.body_emb(**inputs)
            if len(embs.shape) == 3:
                B, C, E = embs.shape
                # [b1c1, b1c2, b1c3, b1c4, b2c1 ....]
                embs = embs.view(B*C, -1)
                idx = idx.repeat_interleave(C)

            assert embs.shape[0] == idx.shape[0]
            emb_list.append(embs.detach().cpu().numpy())
            id_list.append(idx.detach().cpu().numpy())
    f.close()
    emb_arr = np.concatenate(emb_list, axis=0)
    id_arr = np.concatenate(id_list, axis=0)

    return emb_arr, id_arr
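A hedged usage sketch for `embedding_inference`, assuming `args` carries `device`, `local_rank`, and `per_gpu_eval_batch_size`, and that the processing function produces the (input_ids, attention_mask, _, id) tuples the loop above indexes into; the query file path is a placeholder.

query_emb, query_ids = embedding_inference(
    args,
    path=os.path.join(args.data_dir, "dev_queries.tsv"),  # hypothetical file
    model=model,
    fn=lambda line, i: configObj.process_fn(line, i, tokenizer, args),
    bz=args.per_gpu_eval_batch_size,
    is_query=True)
print(query_emb.shape, query_ids.shape)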
Example #3
def StreamInferenceDoc(args, model, fn, prefix, f, is_query_inference=True, load_cache=False):
    inference_batch_size = args.per_gpu_eval_batch_size  # * max(1, args.n_gpu)
    # inference_dataloader = StreamingDataLoader(f, fn, batch_size=inference_batch_size, num_workers=1)
    inference_dataset = StreamingDataset(f, fn)
    inference_dataloader = DataLoader(inference_dataset, batch_size=inference_batch_size)

    if args.local_rank != -1:
        dist.barrier() # directory created

    if (args.emb_file_multi_split_num > 0) and ("passage" in prefix):
        # extra handling of memory pressure by splitting the embedding file on disk
        _, _ = InferenceEmbeddingFromStreamDataLoader(
            args, model, inference_dataloader,
            is_query_inference=is_query_inference, prefix=prefix)
        # dist.barrier()
        full_embedding = None
        full_embedding2id = None  # TODO: loading ids for first_worker()
    else:
        if load_cache:
            _embedding = None
            _embedding2id = None
        else:
            _embedding, _embedding2id = InferenceEmbeddingFromStreamDataLoader(
                args, model, inference_dataloader,
                is_query_inference=is_query_inference, prefix=prefix)

        not_loading = args.split_ann_search and ("passage" in prefix)
        # merge per-rank partial embeddings (kept in memory on the master rank)
        full_embedding = barrier_array_merge(
            args, _embedding, prefix=prefix + "_emb_p_",
            load_cache=load_cache, only_load_in_master=True,
            not_loading=not_loading)
        del _embedding
        full_embedding2id = barrier_array_merge(
            args, _embedding2id, prefix=prefix + "_embid_p_",
            load_cache=load_cache, only_load_in_master=True,
            not_loading=not_loading)
        logger.info(f"finished saving embeddings of {prefix}, not loading into memory: {not_loading}")
        del _embedding2id

    return full_embedding, full_embedding2id
Example #4
def main():
    args = get_arguments()
    set_env(args)

    config, tokenizer, model, configObj = load_stuff(args.train_model_type,
                                                     args)

    # Training
    if args.do_train:
        logger.info("Training/evaluation parameters %s", args)

        def train_fn(line, i):
            return configObj.process_fn(line, i, tokenizer, args)

        train_path = os.path.join(args.data_dir, args.train_file)
        with open(train_path, encoding="utf-8-sig") as f:
            train_batch_size = args.per_gpu_train_batch_size * \
                max(1, args.n_gpu)
            sds = StreamingDataset(f, train_fn)
            train_dataloader = DataLoader(sds,
                                          batch_size=train_batch_size,
                                          num_workers=1)
            global_step, tr_loss = train(args, model, tokenizer,
                                         train_dataloader)
            logger.info(" global_step = %s, average loss = %s", global_step,
                        tr_loss)

    save_checkpoint(args, model, tokenizer)

    results = evaluation(args, model, tokenizer)
    return results
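The `train_fn` closure above simply forwards each raw line to `configObj.process_fn`. As a rough illustration of the contract such a callback has to satisfy (return an iterable of fixed-size tensor tuples per line, with the id in the fourth slot as indexed by the loops in these examples), here is a sketch assuming tab-separated `id<TAB>text` input; the real `process_fn` in this codebase may use a different line format and tuple layout.

def example_process_fn(line, i, tokenizer, args):
    # Hypothetical: "<id>\t<text>" lines, padded/truncated to a fixed length.
    qid, text = line.rstrip("\n").split("\t")[:2]
    enc = tokenizer.encode_plus(text,
                                max_length=128,
                                padding="max_length",
                                truncation=True)
    return [(torch.tensor(enc["input_ids"], dtype=torch.long),
             torch.tensor(enc["attention_mask"], dtype=torch.long),
             torch.tensor(0, dtype=torch.long),            # unused slot
             torch.tensor(int(qid), dtype=torch.long))]    # id ends up in batch[3]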
Example #5
def evaluate_dev(args, model, passage_cache, source=""):

    dev_query_collection_path = os.path.join(args.data_dir,
                                             "dev-query{}".format(source))
    dev_query_cache = EmbeddingCache(dev_query_collection_path)

    logger.info('NLL validation ...')

    model.eval()

    log_result_step = 100
    batches = 0
    total_loss = 0.0
    total_correct_predictions = 0

    with dev_query_cache:
        dev_data_path = os.path.join(args.data_dir,
                                     "dev-data{}".format(source))
        with open(dev_data_path, 'r') as f:
            dev_data = f.readlines()
        dev_dataset = StreamingDataset(
            dev_data,
            GetTrainingDataProcessingFn(args,
                                        dev_query_cache,
                                        passage_cache,
                                        shuffle=False))
        dev_dataloader = DataLoader(dev_dataset,
                                    batch_size=args.train_batch_size * 2)

        for i, batch in enumerate(dev_dataloader):
            loss, correct_cnt = do_biencoder_fwd_pass(args, model, batch)
            loss.backward()  # backward only to free graph buffers; skipping it caused CUDA OOM
            model.zero_grad()
            total_loss += loss.item()
            total_correct_predictions += correct_cnt
            batches += 1
            if (i + 1) % log_result_step == 0:
                logger.info('Eval step: %d, loss=%f', i, loss.item())

    total_loss = total_loss / batches
    total_samples = batches * args.train_batch_size * torch.distributed.get_world_size()
    correct_ratio = float(total_correct_predictions / total_samples)
    logger.info(
        'NLL Validation: loss = %f. correct prediction ratio %d/%d ~ %f',
        total_loss, total_correct_predictions, total_samples, correct_ratio)

    model.train()
    return total_loss, correct_ratio
Example #6
def StreamInferenceDoc(args, model, fn, prefix, f, is_query_inference=True):
    inference_batch_size = args.per_gpu_eval_batch_size  # * max(1, args.n_gpu)
    inference_dataset = StreamingDataset(f, fn)
    inference_dataloader = DataLoader(inference_dataset,
                                      batch_size=inference_batch_size)

    if args.local_rank != -1:
        dist.barrier()  # directory created

    _embedding, _embedding2id = InferenceEmbeddingFromStreamDataLoader(
        args,
        model,
        inference_dataloader,
        is_query_inference=is_query_inference,
        prefix=prefix)

    logger.info("merging embeddings")

    not_loading = args.split_ann_search and ("passage" in prefix)

    # merge per-rank partial embeddings (kept in memory on the master rank)
    full_embedding = barrier_array_merge(args,
                                         _embedding,
                                         prefix=prefix + "_emb_p_",
                                         load_cache=False,
                                         only_load_in_master=True,
                                         not_loading=not_loading)
    del _embedding
    logger.info(
        f"finished saving embeddings of {prefix}, not loading into MEM: {not_loading}"
    )

    full_embedding2id = barrier_array_merge(args,
                                            _embedding2id,
                                            prefix=prefix + "_embid_p_",
                                            load_cache=False,
                                            only_load_in_master=True)

    return full_embedding, full_embedding2id
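Both `StreamInferenceDoc` variants lean on a `barrier_array_merge` helper that is not shown in these examples. The sketch below illustrates the pattern its call sites imply: every rank pickles its partial array under a per-rank prefix, all ranks synchronize, then the partials are reloaded and concatenated (only on the master, or not at all, depending on the flags). The body is an assumption for illustration, not the project's actual implementation.

def barrier_array_merge_sketch(args, data_array, prefix="", load_cache=False,
                               only_load_in_master=False, not_loading=False):
    # local imports keep the sketch self-contained
    import os
    import pickle
    import numpy as np
    import torch.distributed as dist

    if args.local_rank == -1:
        return data_array  # single process: nothing to merge

    rank = dist.get_rank()
    if not load_cache:
        # 1) every rank dumps its partial result under a rank-specific name
        with open(os.path.join(args.output_dir, "{}{}.pb".format(prefix, rank)), "wb") as h:
            pickle.dump(data_array, h, protocol=4)
    dist.barrier()  # 2) wait until all partial files exist

    if not_loading or (only_load_in_master and rank != 0):
        dist.barrier()
        return None  # caller reads the files from disk later

    # 3) reload every rank's partial file and concatenate
    parts = []
    for r in range(dist.get_world_size()):
        with open(os.path.join(args.output_dir, "{}{}.pb".format(prefix, r)), "rb") as h:
            parts.append(pickle.load(h))
    dist.barrier()
    return np.concatenate(parts, axis=0)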
Example #7
def train(args, model, tokenizer, query_cache, passage_cache):
    """ Train the model """
    logger.info("Training/evaluation parameters %s", args)
    tb_writer = None
    if is_first_worker():
        tb_writer = SummaryWriter(log_dir=args.log_dir)

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    real_batch_size = args.train_batch_size * args.gradient_accumulation_steps * \
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1)

    optimizer_grouped_parameters = []
    layer_optim_params = set()
    for layer_name in [
            "roberta.embeddings", "score_out", "downsample1", "downsample2",
            "downsample3"
    ]:
        layer = getattr_recursive(model, layer_name)
        if layer is not None:
            optimizer_grouped_parameters.append({"params": layer.parameters()})
            for p in layer.parameters():
                layer_optim_params.add(p)
    if getattr_recursive(model, "roberta.encoder.layer") is not None:
        for layer in model.roberta.encoder.layer:
            optimizer_grouped_parameters.append({"params": layer.parameters()})
            for p in layer.parameters():
                layer_optim_params.add(p)

    optimizer_grouped_parameters.append({
        "params":
        [p for p in model.parameters() if p not in layer_optim_params]
    })

    if args.optimizer.lower() == "lamb":
        optimizer = Lamb(optimizer_grouped_parameters,
                         lr=args.learning_rate,
                         eps=args.adam_epsilon)
    elif args.optimizer.lower() == "adamw":
        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=args.learning_rate,
                          eps=args.adam_epsilon)
    else:
        raise Exception(
            "optimizer {0} not recognized! Can only be lamb or adamW".format(
                args.optimizer))

    # Check if saved optimizer or scheduler states exist
    if os.path.isfile(
            os.path.join(args.model_name_or_path,
                         "optimizer.pt")) and args.load_optimizer_scheduler:
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )

    # Train
    logger.info("***** Running training *****")
    logger.info("  Max steps = %d", args.max_steps)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)

    global_step = 0
    # Check if continuing training from a checkpoint
    if os.path.exists(args.model_name_or_path):
        # set global_step to the global_step of the last saved checkpoint
        # from the model path
        if "-" in args.model_name_or_path:
            global_step = int(
                args.model_name_or_path.split("-")[-1].split("/")[0])
        else:
            global_step = 0
        logger.info(
            "  Continuing training from checkpoint, will skip to saved global_step"
        )
        logger.info("  Continuing training from global step %d", global_step)

    tr_loss = 0.0
    model.zero_grad()
    model.train()
    set_seed(args)  # Added here for reproducibility

    last_ann_no = -1
    train_dataloader = None
    train_dataloader_iter = None
    dev_ndcg = 0
    step = 0

    save_no = 0

    if args.single_warmup:
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=args.warmup_steps,
            num_training_steps=args.max_steps)

    while global_step < args.max_steps:

        if step % args.gradient_accumulation_steps == 0 and global_step % args.logging_steps == 0 and global_step % args.save_steps < args.save_steps / 20:
            # check if new ann training data is available
            ann_no, ann_path, ndcg_json = get_latest_ann_data(args.ann_dir)
            if ann_path is not None and ann_no != last_ann_no:
                logger.info("Training on new ann data at %s", ann_path)
                with open(ann_path, 'r') as f:
                    ann_training_data = f.readlines()
                dev_ndcg = ndcg_json['ndcg']
                ann_checkpoint_path = ndcg_json['checkpoint']
                ann_checkpoint_no = get_checkpoint_no(ann_checkpoint_path)

                aligned_size = (len(ann_training_data) //
                                args.world_size) * args.world_size
                ann_training_data = ann_training_data[:aligned_size]

                logger.info("Total ann queries: %d", len(ann_training_data))
                if args.triplet:
                    train_dataset = StreamingDataset(
                        ann_training_data,
                        GetTripletTrainingDataProcessingFn(
                            args, query_cache, passage_cache))
                else:
                    train_dataset = StreamingDataset(
                        ann_training_data,
                        GetTrainingDataProcessingFn(args, query_cache,
                                                    passage_cache))
                train_dataloader = DataLoader(train_dataset,
                                              batch_size=args.train_batch_size)
                train_dataloader_iter = iter(train_dataloader)

                # re-warmup
                if not args.single_warmup:
                    scheduler = get_linear_schedule_with_warmup(
                        optimizer,
                        num_warmup_steps=args.warmup_steps,
                        num_training_steps=len(ann_training_data))

                if args.local_rank != -1:
                    dist.barrier()

                if is_first_worker():
                    # add ndcg at checkpoint step used instead of current step
                    tb_writer.add_scalar("dev_ndcg", dev_ndcg,
                                         ann_checkpoint_no)
                    if last_ann_no != -1:
                        tb_writer.add_scalar("epoch", last_ann_no,
                                             global_step - 1)
                    tb_writer.add_scalar("epoch", ann_no, global_step)
                last_ann_no = ann_no

        try:
            batch = next(train_dataloader_iter)
        except StopIteration:
            logger.info("Finished iterating current dataset, begin reiterate")
            train_dataloader_iter = iter(train_dataloader)
            batch = next(train_dataloader_iter)

        batch = tuple(t.to(args.device) for t in batch)
        step += 1

        if args.triplet:
            inputs = {
                "query_ids": batch[0].long(),
                "attention_mask_q": batch[1].long(),
                "input_ids_a": batch[3].long(),
                "attention_mask_a": batch[4].long(),
                "input_ids_b": batch[6].long(),
                "attention_mask_b": batch[7].long()
            }
        else:
            inputs = {
                "input_ids_a": batch[0].long(),
                "attention_mask_a": batch[1].long(),
                "input_ids_b": batch[3].long(),
                "attention_mask_b": batch[4].long(),
                "labels": batch[6]
            }

        # sync gradients only at gradient accumulation step
        if step % args.gradient_accumulation_steps == 0:
            outputs = model(**inputs)
        else:
            with model.no_sync():
                outputs = model(**inputs)
        # model outputs are always tuple in transformers (see doc)
        loss = outputs[0]

        if args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training
        if args.gradient_accumulation_steps > 1:
            loss = loss / args.gradient_accumulation_steps

        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            if step % args.gradient_accumulation_steps == 0:
                loss.backward()
            else:
                with model.no_sync():
                    loss.backward()

        tr_loss += loss.item()
        if step % args.gradient_accumulation_steps == 0:
            if args.fp16:
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                               args.max_grad_norm)
            else:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.max_grad_norm)

            optimizer.step()
            scheduler.step()  # Update learning rate schedule
            model.zero_grad()
            global_step += 1

            if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                logs = {}
                loss_scalar = tr_loss / args.logging_steps
                learning_rate_scalar = scheduler.get_lr()[0]
                logs["learning_rate"] = learning_rate_scalar
                logs["loss"] = loss_scalar
                tr_loss = 0

                if is_first_worker():
                    for key, value in logs.items():
                        tb_writer.add_scalar(key, value, global_step)
                    logger.info(json.dumps({**logs, **{"step": global_step}}))

            if is_first_worker() and args.save_steps > 0 and global_step % args.save_steps == 0:
                # Save model checkpoint
                output_dir = os.path.join(args.output_dir,
                                          "checkpoint-{}".format(global_step))
                if not os.path.exists(output_dir):
                    os.makedirs(output_dir)
                model_to_save = (
                    model.module if hasattr(model, "module") else model
                )  # Take care of distributed/parallel training
                model_to_save.save_pretrained(output_dir)
                tokenizer.save_pretrained(output_dir)

                torch.save(args, os.path.join(output_dir, "training_args.bin"))
                logger.info("Saving model checkpoint to %s", output_dir)

                torch.save(optimizer.state_dict(),
                           os.path.join(output_dir, "optimizer.pt"))
                torch.save(scheduler.state_dict(),
                           os.path.join(output_dir, "scheduler.pt"))
                logger.info("Saving optimizer and scheduler states to %s",
                            output_dir)

                save_no += 1

                if save_no > 1:
                    ann_no, ann_path, ndcg_json = get_latest_ann_data(
                        args.ann_dir)
                    while ann_no == last_ann_no:
                        logger.info("Waiting for new ann data; sleeping for 1 hour")
                        time.sleep(3600)
                        ann_no, ann_path, ndcg_json = get_latest_ann_data(
                            args.ann_dir)

            dist.barrier()

    if args.local_rank == -1 or torch.distributed.get_rank() == 0:
        tb_writer.close()

    return global_step
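The layer-wise optimizer setup above relies on a `getattr_recursive` helper that never appears in these snippets. A small sketch of the behavior its call sites assume (walk a dotted attribute path, returning None when any segment is missing) follows; treat it as an assumption about the helper, not its actual source.

def getattr_recursive_sketch(obj, name):
    # Resolve a dotted attribute path such as "roberta.encoder.layer";
    # return None if any segment along the path does not exist, so callers
    # can skip optimizer groups for layers a given model variant lacks.
    for attr in name.split("."):
        if not hasattr(obj, attr):
            return None
        obj = getattr(obj, attr)
    return obj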
Example #8
def train(args, model, tokenizer, query_cache, passage_cache):
    """ Train the model """
    logger.info("Training/evaluation parameters %s", args)
    tb_writer = None
    if is_first_worker():
        tb_writer = SummaryWriter(log_dir=args.log_dir)

    args.train_batch_size = args.per_gpu_train_batch_size * max(
        1, args.n_gpu)  #nll loss for query
    real_batch_size = args.train_batch_size * args.gradient_accumulation_steps * (
        torch.distributed.get_world_size() if args.local_rank != -1 else 1)

    optimizer = get_optimizer(
        args,
        model,
        weight_decay=args.weight_decay,
    )

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Max steps = %d", args.max_steps)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)

    tr_loss = 0.0
    model.zero_grad()
    model.train()
    set_seed(args)  # Added here for reproducibility

    last_ann_no = -1
    train_dataloader = None
    train_dataloader_iter = None
    dev_ndcg = 0
    step = 0
    iter_count = 0

    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=args.max_steps)

    global_step = 0
    if args.model_name_or_path != "bert-base-uncased":
        saved_state = load_states_from_checkpoint(args.model_name_or_path)
        global_step = _load_saved_state(model, optimizer, scheduler,
                                        saved_state)
        logger.info(
            "  Continuing training from checkpoint, will skip to saved global_step"
        )
        logger.info("  Continuing training from global step %d", global_step)

        #nq_dev_nll_loss, nq_correct_ratio = evaluate_dev(args, model, passage_cache)
        #dev_nll_loss_trivia, correct_ratio_trivia = evaluate_dev(args, model, passage_cache, "-trivia")
        #if is_first_worker():
        #    tb_writer.add_scalar("dev_nll_loss/dev_nll_loss", nq_dev_nll_loss, global_step)
        #    tb_writer.add_scalar("dev_nll_loss/correct_ratio", nq_correct_ratio, global_step)
        #    tb_writer.add_scalar("dev_nll_loss/dev_nll_loss_trivia", dev_nll_loss_trivia, global_step)
        #    tb_writer.add_scalar("dev_nll_loss/correct_ratio_trivia", correct_ratio_trivia, global_step)
    logger.info("num_epoch: %s", args.num_epoch)
    # step = global_step
    logger.info("step=%d, max_steps=%d, global_step=%d", step, args.max_steps, global_step)

    global_step = 0
    while global_step < args.max_steps:

        if step % args.gradient_accumulation_steps == 0 and global_step % args.logging_steps == 0:

            if args.num_epoch == 0:
                # check if new ann training data is available
                ann_no, ann_path, ndcg_json = get_latest_ann_data(args.ann_dir)
                #print(ann_path)
                #print(ann_no)
                #print(ndcg_json)
                if ann_path is not None and ann_no != last_ann_no:
                    logger.info("Training on new add data at %s", ann_path)
                    time.sleep(180)
                    with open(ann_path, 'r') as f:
                        #print(ann_path)
                        ann_training_data = f.readlines()
                    logger.info("Training data line count: %d",
                                len(ann_training_data))
                    ann_training_data = [
                        l for l in ann_training_data
                        if len(l.split('\t')[2].split(',')) > 1
                    ]
                    logger.info("Filtered training data line count: %d",
                                len(ann_training_data))
                    #ann_checkpoint_path = ndcg_json['checkpoint']
                    #ann_checkpoint_no = get_checkpoint_no(ann_checkpoint_path)

                    aligned_size = (len(ann_training_data) //
                                    args.world_size) * args.world_size
                    ann_training_data = ann_training_data[:aligned_size]

                    logger.info("Total ann queries: %d",
                                len(ann_training_data))
                    if args.triplet:
                        train_dataset = StreamingDataset(
                            ann_training_data,
                            GetTripletTrainingDataProcessingFn(
                                args, query_cache, passage_cache))
                        train_dataloader = DataLoader(
                            train_dataset, batch_size=args.train_batch_size)
                    else:
                        train_dataset = StreamingDataset(
                            ann_training_data,
                            GetTrainingDataProcessingFn(
                                args, query_cache, passage_cache))
                        train_dataloader = DataLoader(
                            train_dataset,
                            batch_size=args.train_batch_size * 2)
                    train_dataloader_iter = iter(train_dataloader)

                    # re-warmup
                    if not args.single_warmup:
                        scheduler = get_linear_schedule_with_warmup(
                            optimizer,
                            num_warmup_steps=args.warmup_steps,
                            num_training_steps=len(ann_training_data))

                    if args.local_rank != -1:
                        dist.barrier()

                    if is_first_worker():
                        # add ndcg at checkpoint step used instead of current step
                        #tb_writer.add_scalar("retrieval_accuracy/top20_nq", ndcg_json['top20'], ann_checkpoint_no)
                        #tb_writer.add_scalar("retrieval_accuracy/top100_nq", ndcg_json['top100'], ann_checkpoint_no)
                        #if 'top20_trivia' in ndcg_json:
                        #    tb_writer.add_scalar("retrieval_accuracy/top20_trivia", ndcg_json['top20_trivia'], ann_checkpoint_no)
                        #    tb_writer.add_scalar("retrieval_accuracy/top100_trivia", ndcg_json['top100_trivia'], ann_checkpoint_no)
                        if last_ann_no != -1:
                            tb_writer.add_scalar("epoch", last_ann_no,
                                                 global_step - 1)
                        tb_writer.add_scalar("epoch", ann_no, global_step)
                    last_ann_no = ann_no
            elif step == 0:

                train_data_path = os.path.join(args.data_dir, "train-data")
                with open(train_data_path, 'r') as f:
                    training_data = f.readlines()
                if args.triplet:
                    train_dataset = StreamingDataset(
                        training_data,
                        GetTripletTrainingDataProcessingFn(
                            args, query_cache, passage_cache))
                    train_dataloader = DataLoader(
                        train_dataset, batch_size=args.train_batch_size)
                else:
                    train_dataset = StreamingDataset(
                        training_data,
                        GetTrainingDataProcessingFn(args, query_cache,
                                                    passage_cache))
                    train_dataloader = DataLoader(
                        train_dataset, batch_size=args.train_batch_size * 2)
                all_batch = [b for b in train_dataloader]
                logger.info("Total batch count: %d", len(all_batch))
                train_dataloader_iter = iter(train_dataloader)

        try:
            batch = next(train_dataloader_iter)
        except StopIteration:
            logger.info("Finished iterating current dataset, begin reiterate")
            if args.num_epoch != 0:
                iter_count += 1
                if is_first_worker():
                    tb_writer.add_scalar("epoch", iter_count - 1,
                                         global_step - 1)
                    tb_writer.add_scalar("epoch", iter_count, global_step)
            #nq_dev_nll_loss, nq_correct_ratio = evaluate_dev(args, model, passage_cache)
            #dev_nll_loss_trivia, correct_ratio_trivia = evaluate_dev(args, model, passage_cache, "-trivia")
            #if is_first_worker():
            #    tb_writer.add_scalar("dev_nll_loss/dev_nll_loss", nq_dev_nll_loss, global_step)
            #    tb_writer.add_scalar("dev_nll_loss/correct_ratio", nq_correct_ratio, global_step)
            #    tb_writer.add_scalar("dev_nll_loss/dev_nll_loss_trivia", dev_nll_loss_trivia, global_step)
            #    tb_writer.add_scalar("dev_nll_loss/correct_ratio_trivia", correct_ratio_trivia, global_step)
            ann_no, ann_path, ndcg_json = get_latest_ann_data(args.ann_dir)
            if ann_path is not None:
                with open(ann_path, 'r') as f:
                    logger.info("Reloading ann data from %s", ann_path)
                    ann_training_data = f.readlines()
                logger.info("Training data line count: %d",
                            len(ann_training_data))
                ann_training_data = [
                    l for l in ann_training_data
                    if len(l.split('\t')[2].split(',')) > 1
                ]
                logger.info("Filtered training data line count: %d",
                            len(ann_training_data))

                aligned_size = (len(ann_training_data) //
                                args.world_size) * args.world_size
                ann_training_data = ann_training_data[:aligned_size]
                train_dataset = StreamingDataset(
                    ann_training_data,
                    GetTrainingDataProcessingFn(args, query_cache,
                                                passage_cache))
                train_dataloader = DataLoader(
                    train_dataset, batch_size=args.train_batch_size * 2)

            train_dataloader_iter = iter(train_dataloader)
            batch = next(train_dataloader_iter)
            dist.barrier()

        if args.num_epoch != 0 and iter_count > args.num_epoch:
            break

        step += 1
        if args.triplet:
            loss = triplet_fwd_pass(args, model, batch)
        else:
            loss, correct_cnt = do_biencoder_fwd_pass(args, model, batch)

        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            if step % args.gradient_accumulation_steps == 0:
                loss.backward()
            else:
                with model.no_sync():
                    loss.backward()

        tr_loss += loss.item()
        if step % args.gradient_accumulation_steps == 0:
            if args.fp16:
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                               args.max_grad_norm)
            else:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.max_grad_norm)

            optimizer.step()
            scheduler.step()  # Update learning rate schedule
            model.zero_grad()
            global_step += 1

            if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                logs = {}
                loss_scalar = tr_loss / args.logging_steps
                learning_rate_scalar = scheduler.get_lr()[0]
                logs["learning_rate"] = learning_rate_scalar
                logs["loss"] = loss_scalar
                tr_loss = 0

                if is_first_worker():
                    for key, value in logs.items():
                        tb_writer.add_scalar(key, value, global_step)
                    logger.info(json.dumps({**logs, **{"step": global_step}}))

            if is_first_worker() and args.save_steps > 0 and global_step % args.save_steps == 0:
                _save_checkpoint(args, model, optimizer, scheduler,
                                 global_step)

    if args.local_rank == -1 or torch.distributed.get_rank() == 0:
        tb_writer.close()

    return global_step
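Both training loops poll `get_latest_ann_data(args.ann_dir)` for freshly generated ANN training data, but its implementation is not part of these examples. The sketch below shows one plausible layout, under the assumption that the ANN generator drops `ann_training_data_<id>` and `ann_ndcg_<id>` files into the directory; the naming convention is an assumption.

def get_latest_ann_data_sketch(ann_data_dir):
    import json
    import os

    # Returns (ann_no, training_data_path, ndcg_json), or (-1, None, None)
    # when no finished ANN generation is available yet.
    if not os.path.exists(ann_data_dir):
        return -1, None, None
    ids = [int(name.split("_")[-1]) for name in os.listdir(ann_data_dir)
           if name.startswith("ann_ndcg_")]
    if not ids:
        return -1, None, None
    ann_no = max(ids)
    with open(os.path.join(ann_data_dir, "ann_ndcg_" + str(ann_no)), "r") as f:
        ndcg_json = json.load(f)
    return ann_no, os.path.join(ann_data_dir, "ann_training_data_" + str(ann_no)), ndcg_json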
Example #9
def train(args, model, tokenizer, f, train_fn):
    """ Train the model """
    tb_writer = None
    if is_first_worker():
        tb_writer = SummaryWriter(log_dir=args.log_dir)

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    real_batch_size = args.train_batch_size * args.gradient_accumulation_steps * \
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1)

    if args.max_steps > 0:
        t_total = args.max_steps
    else:
        t_total = args.expected_train_size // real_batch_size * args.num_train_epochs

    logger.info("t_total: %s", t_total)
    # layerwise optimization for lamb
    optimizer_grouped_parameters = []
    layer_optim_params = set()
    for layer_name in [
            "roberta.embeddings", "score_out", "downsample1", "downsample2",
            "downsample3", "embeddingHead"
    ]:
        layer = getattr_recursive(model, layer_name)
        if layer is not None:
            optimizer_grouped_parameters.append({"params": layer.parameters()})
            for p in layer.parameters():
                layer_optim_params.add(p)
    if getattr_recursive(model, "roberta.encoder.layer") is not None:
        for layer in model.roberta.encoder.layer:
            optimizer_grouped_parameters.append({"params": layer.parameters()})
            for p in layer.parameters():
                layer_optim_params.add(p)

    optimizer_grouped_parameters.append({
        "params":
        [p for p in model.parameters() if p not in layer_optim_params]
    })

    if args.optimizer.lower() == "lamb":
        optimizer = Lamb(optimizer_grouped_parameters,
                         lr=args.learning_rate,
                         eps=args.adam_epsilon)
    elif args.optimizer.lower() == "adamw":
        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=args.learning_rate,
                          eps=args.adam_epsilon)
    else:
        raise Exception(
            "optimizer {0} not recognized! Can only be lamb or adamW".format(
                args.optimizer))

    if args.scheduler.lower() == "linear":
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=args.warmup_steps,
            num_training_steps=t_total)
    elif args.scheduler.lower() == "cosine":
        scheduler = CosineAnnealingLR(optimizer, t_total, 1e-8)
    else:
        raise Exception(
            "Scheduler {0} not recognized! Can only be linear or cosine".
            format(args.scheduler))

    # Check if saved optimizer or scheduler states exist
    if os.path.isfile(os.path.join(
            args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
                os.path.join(
                    args.model_name_or_path,
                    "scheduler.pt")) and args.load_optimizer_scheduler:
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if os.path.exists(args.model_name_or_path):
        # set global_step to the global_step of the last saved checkpoint from the model path
        try:
            global_step = int(
                args.model_name_or_path.split("-")[-1].split("/")[0])
            epochs_trained = global_step // (args.expected_train_size //
                                             args.gradient_accumulation_steps)
            steps_trained_in_current_epoch = global_step % (
                args.expected_train_size // args.gradient_accumulation_steps)

            logger.info(
                "  Continuing training from checkpoint, will skip to saved global_step"
            )
            logger.info("  Continuing training from epoch %d", epochs_trained)
            logger.info("  Continuing training from global step %d",
                        global_step)
            logger.info("  Will skip the first %d steps in the first epoch",
                        steps_trained_in_current_epoch)
        except Exception:
            logger.info("  Start training from a pretrained model")

    tr_loss, logging_loss = 0.0, 0.0
    tr_acc, logging_acc = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(
        epochs_trained,
        int(args.num_train_epochs),
        desc="Epoch",
        disable=args.local_rank not in [-1, 0],
    )
    set_seed(args)  # Added here for reproducibility
    for m_epoch in train_iterator:
        f.seek(0)
        sds = StreamingDataset(f, train_fn)
        epoch_iterator = DataLoader(sds,
                                    batch_size=args.per_gpu_train_batch_size,
                                    num_workers=1)
        for step, batch in tqdm(enumerate(epoch_iterator),
                                desc="Iteration",
                                disable=args.local_rank not in [-1, 0]):
            # Skip past any already trained steps if resuming training
            if not args.reset_iter:
                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
                    continue

            model.train()
            batch = tuple(t.to(args.device).long() for t in batch)

            if (step + 1) % args.gradient_accumulation_steps == 0:
                outputs = model(*batch)
            else:
                with model.no_sync():
                    outputs = model(*batch)
            # model outputs are always tuple in transformers (see doc)
            loss = outputs[0]
            acc = outputs[1]

            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel training
                acc = acc.float().mean()
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
                acc = acc / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    loss.backward()
                else:
                    with model.no_sync():
                        loss.backward()

            tr_loss += loss.item()
            tr_acc += acc.item()

            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if is_first_worker() and args.save_steps > 0 and global_step % args.save_steps == 0:
                    # Save model checkpoint
                    output_dir = os.path.join(
                        args.output_dir, "checkpoint-{}".format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    if 'fairseq' not in args.train_model_type:
                        model_to_save = (
                            model.module if hasattr(model, "module") else model
                        )  # Take care of distributed/parallel training
                        model_to_save.save_pretrained(output_dir)
                        tokenizer.save_pretrained(output_dir)
                    else:
                        torch.save(model.state_dict(),
                                   os.path.join(output_dir, 'model.pt'))

                    torch.save(args,
                               os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

                    torch.save(optimizer.state_dict(),
                               os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(),
                               os.path.join(output_dir, "scheduler.pt"))
                    logger.info("Saving optimizer and scheduler states to %s",
                                output_dir)
                dist.barrier()

                if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    logs = {}
                    if args.evaluate_during_training and global_step % (
                            args.logging_steps_per_eval *
                            args.logging_steps) == 0:
                        model.eval()
                        reranking_mrr, full_ranking_mrr = passage_dist_eval(
                            args, model, tokenizer)
                        if is_first_worker():
                            print("Reranking/Full ranking mrr: {0}/{1}".format(
                                str(reranking_mrr), str(full_ranking_mrr)))
                            mrr_dict = {
                                "reranking": float(reranking_mrr),
                                "full_ranking": float(full_ranking_mrr)
                            }
                            tb_writer.add_scalars("mrr", mrr_dict, global_step)
                            print(args.output_dir)

                    loss_scalar = (tr_loss - logging_loss) / args.logging_steps
                    learning_rate_scalar = scheduler.get_lr()[0]
                    logs["learning_rate"] = learning_rate_scalar
                    logs["loss"] = loss_scalar
                    logging_loss = tr_loss

                    acc_scalar = (tr_acc - logging_acc) / args.logging_steps
                    logs["acc"] = acc_scalar
                    logging_acc = tr_acc

                    if is_first_worker():
                        for key, value in logs.items():
                            print(key, type(value))
                            tb_writer.add_scalar(key, value, global_step)
                        tb_writer.add_scalar("epoch", m_epoch, global_step)
                        print(json.dumps({**logs, **{"step": global_step}}))
                    dist.barrier()

        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank == -1 or torch.distributed.get_rank() == 0:
        tb_writer.close()

    return global_step, tr_loss / global_step
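`is_first_worker()` guards every TensorBoard write and checkpoint save above but is defined elsewhere; a minimal sketch, assuming it simply means "single-process run or global rank 0":

def is_first_worker_sketch():
    import torch.distributed as dist
    # True for single-process runs and for global rank 0 in distributed runs.
    return (not dist.is_available()) or (not dist.is_initialized()) or dist.get_rank() == 0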
Example #10
def train(args, model, tokenizer, f, train_fn):
    """ Train the model """
    tb_writer = None
    if is_first_worker():
        tb_writer = SummaryWriter(log_dir=args.log_dir)

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    real_batch_size = args.train_batch_size * args.gradient_accumulation_steps * \
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1)

    if args.max_steps > 0:
        t_total = args.max_steps
    else:
        t_total = args.expected_train_size // real_batch_size * args.num_train_epochs

    logger.info("t_total: %s", t_total)
    # layerwise optimization for lamb
    optimizer_grouped_parameters = []
    layer_optim_params = set()
    for layer_name in [
            "roberta.embeddings", "score_out", "downsample1", "downsample2",
            "downsample3", "embeddingHead"
    ]:
        layer = getattr_recursive(model, layer_name)
        if layer is not None:
            optimizer_grouped_parameters.append({"params": layer.parameters()})
            for p in layer.parameters():
                layer_optim_params.add(p)
    if getattr_recursive(model, "roberta.encoder.layer") is not None:
        for layer in model.roberta.encoder.layer:
            optimizer_grouped_parameters.append({"params": layer.parameters()})
            for p in layer.parameters():
                layer_optim_params.add(p)

    optimizer_grouped_parameters.append({
        "params":
        [p for p in model.parameters() if p not in layer_optim_params]
    })

    if args.optimizer.lower() == "lamb":
        optimizer = Lamb(optimizer_grouped_parameters,
                         lr=args.learning_rate,
                         eps=args.adam_epsilon)
    elif args.optimizer.lower() == "adamw":
        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=args.learning_rate,
                          eps=args.adam_epsilon)
    else:
        raise Exception(
            "optimizer {0} not recognized! Can only be lamb or adamW".format(
                args.optimizer))

    if args.scheduler.lower() == "linear":
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=args.warmup_steps,
            num_training_steps=t_total)
    elif args.scheduler.lower() == "cosine":
        scheduler = CosineAnnealingLR(optimizer, t_total, 1e-8)
    else:
        raise Exception(
            "Scheduler {0} not recognized! Can only be linear or cosine".
            format(args.scheduler))

    # Check if saved optimizer or scheduler states exist
    if os.path.isfile(os.path.join(
            args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
                os.path.join(
                    args.model_name_or_path,
                    "scheduler.pt")) and args.load_optimizer_scheduler:
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if os.path.exists(args.model_name_or_path):
        # set global_step to the global_step of the last saved checkpoint from the model path
        try:
            global_step = int(
                args.model_name_or_path.split("-")[-1].split("/")[0])
            epochs_trained = global_step // (args.expected_train_size //
                                             args.gradient_accumulation_steps)
            steps_trained_in_current_epoch = global_step % (
                args.expected_train_size // args.gradient_accumulation_steps)

            logger.info(
                "  Continuing training from checkpoint, will skip to saved global_step"
            )
            logger.info("  Continuing training from epoch %d", epochs_trained)
            logger.info("  Continuing training from global step %d",
                        global_step)
            logger.info("  Will skip the first %d steps in the first epoch",
                        steps_trained_in_current_epoch)
        except Exception:
            logger.info("  Start training from a pretrained model")

    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(
        epochs_trained,
        int(args.num_train_epochs),
        desc="Epoch",
        disable=args.local_rank not in [-1, 0],
    )
    set_seed(args)  # Added here for reproducibility
    for m_epoch in train_iterator:
        f.seek(0)
        sds = StreamingDataset(f, train_fn)
        epoch_iterator = DataLoader(sds,
                                    batch_size=args.per_gpu_train_batch_size,
                                    num_workers=1)
        count = 0
        avg_cls_norm = 0
        loss_avg = 0
        for step, batch in tqdm(enumerate(epoch_iterator),
                                desc="Iteration",
                                disable=args.local_rank not in [-1, 0]):
            # Skip past any already trained steps if resuming training
            # (step skipping is disabled in this variant)
            # if not args.reset_iter:
            #     if steps_trained_in_current_epoch > 0:
            #         steps_trained_in_current_epoch -= 1
            #         continue

            model.train()
            batch = tuple(t.to(args.device).long() for t in batch)
            with torch.no_grad():
                outputs = model(*batch)
                cls_norm = outputs[1]
                loss = outputs[0]

                count += 1
                avg_cls_norm += float(cls_norm.cpu().data)
                loss_avg += float(loss.cpu().data)
                print("SEED-Encoder norm: ", cls_norm)

            if count == 1024:
                # print('avg_cls_norm: ',float(avg_cls_norm)/count)
                print('avg_cls_sim: ', float(avg_cls_norm) / count)
                print('avg_loss: ', float(loss_avg) / count)
                return

        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank == -1 or torch.distributed.get_rank() == 0:
        tb_writer.close()

    return global_step, tr_loss / global_step
Example #11
def train(args, model, tokenizer, query_cache, passage_cache):
    """ Train the model """
    logger.info("Training/evaluation parameters %s", args)
    tb_writer = None
    if is_first_worker():
        tb_writer = SummaryWriter(log_dir=args.log_dir)

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    real_batch_size = args.train_batch_size * args.gradient_accumulation_steps * \
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1)

    optimizer_grouped_parameters = []
    layer_optim_params = set()
    for layer_name in [
            "roberta.embeddings", "score_out", "downsample1", "downsample2",
            "downsample3"
    ]:
        layer = getattr_recursive(model, layer_name)
        if layer is not None:
            optimizer_grouped_parameters.append({"params": layer.parameters()})
            for p in layer.parameters():
                layer_optim_params.add(p)
    if getattr_recursive(model, "roberta.encoder.layer") is not None:
        for layer in model.roberta.encoder.layer:
            optimizer_grouped_parameters.append({"params": layer.parameters()})
            for p in layer.parameters():
                layer_optim_params.add(p)

    optimizer_grouped_parameters.append({
        "params":
        [p for p in model.parameters() if p not in layer_optim_params]
    })

    if args.optimizer.lower() == "lamb":
        optimizer = Lamb(optimizer_grouped_parameters,
                         lr=args.learning_rate,
                         eps=args.adam_epsilon)
    elif args.optimizer.lower() == "adamw":
        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=args.learning_rate,
                          eps=args.adam_epsilon)
    else:
        raise Exception(
            "optimizer {0} not recognized! Can only be lamb or adamW".format(
                args.optimizer))

    def optimizer_to(optim, device):
        for param in optim.state.values():
            # optimizer state values are tensors or dicts of tensors; move any tensors to the device
            if isinstance(param, torch.Tensor):
                param.data = param.data.to(device)
                if param._grad is not None:
                    param._grad.data = param._grad.data.to(device)
            elif isinstance(param, dict):
                for subparam in param.values():
                    if isinstance(subparam, torch.Tensor):
                        subparam.data = subparam.data.to(device)
                        if subparam._grad is not None:
                            subparam._grad.data = subparam._grad.data.to(
                                device)

    torch.cuda.empty_cache()

    # Check if saved optimizer or scheduler states exist
    if os.path.isfile(
            os.path.join(args.model_name_or_path,
                         "optimizer.pt")) and args.load_optimizer_scheduler:
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "optimizer.pt"),
                       map_location='cpu'))
    optimizer_to(optimizer, args.device)

    model.to(args.device)

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )

    # Train
    logger.info("***** Running training *****")
    logger.info("  Max steps = %d", args.max_steps)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)

    global_step = 0
    # Check if continuing training from a checkpoint
    if os.path.exists(args.model_name_or_path):
        # recover global_step from the checkpoint suffix in the model path
        # (e.g. ".../checkpoint-5000" -> 5000)
        if "-" in args.model_name_or_path:
            global_step = int(
                args.model_name_or_path.split("-")[-1].split("/")[0])
        else:
            global_step = 0
        logger.info(
            "  Continuing training from checkpoint, will skip to saved global_step"
        )
        logger.info("  Continuing training from global step %d", global_step)

    is_hypersphere_training = (args.hyper_align_weight > 0
                               or args.hyper_unif_weight > 0)
    if is_hypersphere_training:
        logger.info(
            f"training with hypersphere property regularization, align weight {args.hyper_align_weight}, unif weight {args.hyper_unif_weight}"
        )
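    # the dual loss term only contributes when --dual_training is enabled; otherwise zero its weight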
    if not args.dual_training:
        args.dual_loss_weight = 0.0

    tr_loss_dict = {}

    model.zero_grad()
    model.train()
    set_seed(args)  # Added here for reproducibility

    last_ann_no = -1
    train_dataloader = None
    train_dataloader_iter = None
    # dev_ndcg = 0
    step = 0

    if args.single_warmup:
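        # single warmup: one linear-decay schedule spans all of args.max_steps, optionally
        # restored from scheduler.pt; otherwise the scheduler is re-created (re-warmed)
        # each time a new ANN data file is picked up below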
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=args.warmup_steps,
            num_training_steps=args.max_steps)
        if os.path.isfile(os.path.join(
                args.model_name_or_path,
                "scheduler.pt")) and args.load_optimizer_scheduler:
            # Load in optimizer and scheduler states
            scheduler.load_state_dict(
                torch.load(
                    os.path.join(args.model_name_or_path, "scheduler.pt")))

    while global_step < args.max_steps:
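        # main loop: poll for freshly generated ANN training data, rebuild the streaming
        # dataloader when a new file appears, then keep consuming batches until max_steps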

        if step % args.gradient_accumulation_steps == 0 and global_step % args.logging_steps == 0:
            # check if new ann training data is available
            ann_no, ann_path, ndcg_json = get_latest_ann_data(
                args.ann_dir, is_grouped=(args.grouping_ann_data > 0))
            if ann_path is not None and ann_no != last_ann_no:
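                # a new ANN episode is available: reload the training lines and rebuild the dataloader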
                logger.info("Training on new add data at %s", ann_path)

                time.sleep(30)  # wait until transmission finished

                with open(ann_path, 'r') as f:
                    ann_training_data = f.readlines()
                # marcodev_ndcg = ndcg_json['marcodev_ndcg']
                logging.info(f"loading:\n{ndcg_json}")
                ann_checkpoint_path = ndcg_json['checkpoint']
                ann_checkpoint_no = get_checkpoint_no(ann_checkpoint_path)

                aligned_size = (len(ann_training_data) //
                                args.world_size) * args.world_size
                ann_training_data = ann_training_data[:aligned_size]

                logger.info(
                    "Total ann queries: %d",
                    len(ann_training_data) if args.grouping_ann_data < 0 else
                    len(ann_training_data) * args.grouping_ann_data)

                if args.grouping_ann_data > 0:
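                    # grouped format: each line packs several training examples; unpack with
                    # either the polling or the original grouped processing fn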
                    if args.polling_loaded_data_batch_from_group:
                        train_dataset = StreamingDataset(
                            ann_training_data,
                            GetGroupedTrainingDataProcessingFn_polling(
                                args, query_cache, passage_cache))
                    else:
                        train_dataset = StreamingDataset(
                            ann_training_data,
                            GetGroupedTrainingDataProcessingFn_origin(
                                args, query_cache, passage_cache))
                else:
                    if not args.dual_training:
                        if args.triplet:
                            train_dataset = StreamingDataset(
                                ann_training_data,
                                GetTripletTrainingDataProcessingFn(
                                    args, query_cache, passage_cache))
                        else:
                            train_dataset = StreamingDataset(
                                ann_training_data,
                                GetTrainingDataProcessingFn(
                                    args, query_cache, passage_cache))
                    else:
                        # return quadruplet
                        train_dataset = StreamingDataset(
                            ann_training_data,
                            GetQuadrapuletTrainingDataProcessingFn(
                                args, query_cache, passage_cache))

                train_dataloader = DataLoader(train_dataset,
                                              batch_size=args.train_batch_size)
                train_dataloader_iter = iter(train_dataloader)

                # re-warmup
                if not args.single_warmup:
                    scheduler = get_linear_schedule_with_warmup(
                        optimizer,
                        num_warmup_steps=args.warmup_steps,
                        num_training_steps=len(ann_training_data)
                        if args.grouping_ann_data < 0 else
                        len(ann_training_data) * args.grouping_ann_data)

                if args.local_rank != -1:
                    dist.barrier()

                if is_first_worker():
                    # log ndcg against the checkpoint step it was measured at, not the current step
                    for key in ndcg_json:
                        if "marcodev" in key:
                            tb_writer.add_scalar(key, ndcg_json[key],
                                                 ann_checkpoint_no)

                    if 'trec2019_ndcg' in ndcg_json:
                        tb_writer.add_scalar("trec2019_ndcg",
                                             ndcg_json['trec2019_ndcg'],
                                             ann_checkpoint_no)

                    if last_ann_no != -1:
                        tb_writer.add_scalar("epoch", last_ann_no,
                                             global_step - 1)
                    tb_writer.add_scalar("epoch", ann_no, global_step)
                last_ann_no = ann_no

        try:
            batch = next(train_dataloader_iter)
        except StopIteration:
            logger.info("Finished iterating current dataset, begin reiterate")
            train_dataloader_iter = iter(train_dataloader)
            batch = next(train_dataloader_iter)

        # original (ungrouped) format: the batch is a flat tuple of tensors
        if args.grouping_ann_data <= 0:
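            # flat-tuple layout from the per-example processing fns: in the triplet case,
            # [0]/[1] are query ids+mask, [3]/[4] the positive passage, [6]/[7] the negative
            # passage, and [9]/[10] a negative query under --dual_training; the pairwise case
            # uses [0]/[1], [3]/[4] and a label at [6]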
            batch = tuple(t.to(args.device) for t in batch)
            if args.triplet:
                inputs = {
                    "query_ids": batch[0].long(),
                    "attention_mask_q": batch[1].long(),
                    "input_ids_a": batch[3].long(),
                    "attention_mask_a": batch[4].long(),
                    "input_ids_b": batch[6].long(),
                    "attention_mask_b": batch[7].long()
                }
                if args.dual_training:
                    inputs["neg_query_ids"] = batch[9].long()
                    inputs["attention_mask_neg_query"] = batch[10].long()
                    inputs["prime_loss_weight"] = args.prime_loss_weight
                    inputs["dual_loss_weight"] = args.dual_loss_weight
            else:
                inputs = {
                    "input_ids_a": batch[0].long(),
                    "attention_mask_a": batch[1].long(),
                    "input_ids_b": batch[3].long(),
                    "attention_mask_b": batch[4].long(),
                    "labels": batch[6]
                }
        else:
            # the default collate_fn stacks each item's "q_pos"/"d_pos"/"d_neg" (and "q_neg")
            # entries into batched tensors
            inputs = {
                "query_ids": batch["q_pos"][0].to(args.device).long(),
                "attention_mask_q": batch["q_pos"][1].to(args.device).long(),
                "input_ids_a": batch["d_pos"][0].to(args.device).long(),
                "attention_mask_a": batch["d_pos"][1].to(args.device).long(),
                "input_ids_b": batch["d_neg"][0].to(args.device).long(),
                "attention_mask_b": batch["d_neg"][1].to(args.device).long(),
            }
            if args.dual_training:
                inputs["neg_query_ids"] = batch["q_neg"][0].to(
                    args.device).long()
                inputs["attention_mask_neg_query"] = batch["q_neg"][1].to(
                    args.device).long()
                inputs["prime_loss_weight"] = args.prime_loss_weight
                inputs["dual_loss_weight"] = args.dual_loss_weight

        inputs["temperature"] = args.temperature
        inputs["loss_objective"] = args.loss_objective_function

        if is_hypersphere_training:
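            # forward the hypersphere regularization weights (alignment / uniformity) to the loss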
            inputs["alignment_weight"] = args.hyper_align_weight
            inputs["uniformity_weight"] = args.hyper_unif_weight

        step += 1

        if args.local_rank != -1:
            # sync gradients only at gradient accumulation step
            if step % args.gradient_accumulation_steps == 0:
                outputs = model(**inputs)
            else:
                with model.no_sync():
                    outputs = model(**inputs)
        else:
            outputs = model(**inputs)
        # model outputs are always tuple in transformers (see doc)
        loss = outputs[0]

        loss_item_dict = outputs[1]

        if args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training
            for k in loss_item_dict:
                loss_item_dict[k] = loss_item_dict[k].mean()

        if args.gradient_accumulation_steps > 1:
            loss = loss / args.gradient_accumulation_steps
            for k in loss_item_dict:
                loss_item_dict[k] = loss_item_dict[k] / args.gradient_accumulation_steps

        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            if args.local_rank != -1:
                if step % args.gradient_accumulation_steps == 0:
                    loss.backward()
                else:
                    with model.no_sync():
                        loss.backward()
            else:
                loss.backward()

        def incremental_tr_loss(tr_loss_dict, loss_item_dict, total_loss):
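            # accumulate each loss component plus the total so they can be averaged over logging_steps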
            for k in loss_item_dict:
                if k not in tr_loss_dict:
                    tr_loss_dict[k] = loss_item_dict[k].item()
                else:
                    tr_loss_dict[k] += loss_item_dict[k].item()
            if "loss_total" not in tr_loss_dict:
                tr_loss_dict["loss_total"] = total_loss.item()
            else:
                tr_loss_dict["loss_total"] += total_loss.item()
            return tr_loss_dict

        tr_loss_dict = incremental_tr_loss(tr_loss_dict,
                                           loss_item_dict,
                                           total_loss=loss)

        if step % args.gradient_accumulation_steps == 0:
            if args.fp16:
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                               args.max_grad_norm)
            else:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.max_grad_norm)

            optimizer.step()
            scheduler.step()  # Update learning rate schedule
            model.zero_grad()
            global_step += 1

            if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                logs = {}
                learning_rate_scalar = scheduler.get_lr()[0]

                logs["learning_rate"] = learning_rate_scalar

                for k in tr_loss_dict:
                    logs[k] = tr_loss_dict[k] / args.logging_steps
                tr_loss_dict = {}

                if is_first_worker():
                    for key, value in logs.items():
                        tb_writer.add_scalar(key, value, global_step)
                    logger.info(json.dumps({**logs, **{"step": global_step}}))

            if is_first_worker() and args.save_steps > 0 and global_step % args.save_steps == 0:
                # Save model checkpoint
                output_dir = os.path.join(args.output_dir,
                                          "checkpoint-{}".format(global_step))
                if not os.path.exists(output_dir):
                    os.makedirs(output_dir)
                model_to_save = (
                    model.module if hasattr(model, "module") else model
                )  # Take care of distributed/parallel training
                model_to_save.save_pretrained(output_dir)
                tokenizer.save_pretrained(output_dir)

                torch.save(args, os.path.join(output_dir, "training_args.bin"))
                logger.info("Saving model checkpoint to %s", output_dir)

                torch.save(optimizer.state_dict(),
                           os.path.join(output_dir, "optimizer.pt"))
                torch.save(scheduler.state_dict(),
                           os.path.join(output_dir, "scheduler.pt"))
                logger.info("Saving optimizer and scheduler states to %s",
                            output_dir)

    if args.local_rank == -1 or torch.distributed.get_rank() == 0:
        tb_writer.close()

    return global_step