Example #1
0
def load_model(args, checkpoint_path):
    """Build the configured model and restore its weights from a checkpoint.

    Args:
        args: Run configuration. ``model_type`` is lower-cased in place and
            ``model_name_or_path`` is overwritten with ``checkpoint_path``;
            ``device`` and ``local_rank`` are read.
        checkpoint_path: Location of the saved checkpoint to restore.

    Returns:
        The model moved to ``args.device``. When ``args.local_rank != -1``
        it is wrapped in ``torch.nn.parallel.DistributedDataParallel``.
    """
    # Normalise the model type before the lookup; the mutation of ``args``
    # is kept because callers may read these fields afterwards.
    args.model_type = args.model_type.lower()
    args.model_name_or_path = checkpoint_path
    config = MSMarcoConfigDict[args.model_type]

    model = config.model_class(args)

    # Restore weights into the object returned by get_model_obj(model).
    saved_state = load_states_from_checkpoint(checkpoint_path)
    model_to_load = get_model_obj(model)
    logger.info('Loading saved model state ...')
    model_to_load.load_state_dict(saved_state.model_dict)

    model.to(args.device)
    logger.info("Inference parameters %s", args)

    # local_rank == -1 is treated as "not distributed" throughout this code.
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )
    return model
Example #2
0
def load_model(args, checkpoint_path, load_flag=False):
    """Build the configured model and restore its weights from a checkpoint.

    Two checkpoint layouts are handled:

    * when ``args.fp16`` and ``load_flag`` are both true, the file is read
      with ``torch.load`` and the weights are taken from its ``'model'``
      entry;
    * otherwise the checkpoint is read through
      ``load_states_from_checkpoint`` and its ``model_dict`` is loaded.

    Args:
        args: Run configuration. ``model_type`` is lower-cased in place and
            ``model_name_or_path`` is overwritten with ``checkpoint_path``.
            Also read: ``init_from_fp16_ckpt``, ``pretrained_checkpoint_dir``,
            ``fp16``, ``representation_l2_normalization``, ``device`` and
            ``local_rank``.
        checkpoint_path: Location of the checkpoint to restore. When
            ``args.init_from_fp16_ckpt`` is set, the text after its last
            ``'-'`` is interpreted as the training step.
        load_flag: Selects the ``torch.load`` layout (together with
            ``args.fp16``). Overridden when ``args.init_from_fp16_ckpt`` is
            set.

    Returns:
        The model moved to ``args.device``. When ``args.local_rank != -1``
        it is wrapped in ``torch.nn.parallel.DistributedDataParallel``.
    """
    args.model_type = args.model_type.lower()
    args.model_name_or_path = checkpoint_path
    config = MSMarcoConfigDict[args.model_type]

    model = config.model_class(args)

    if args.init_from_fp16_ckpt:
        checkpoint_step = checkpoint_path.split('-')[-1].replace('/', '')
        init_step = args.pretrained_checkpoint_dir.split('-')[-1].replace('/', '')
        try:
            # Compare numerically: as strings, "9000" > "10000" is True,
            # which would pick the wrong checkpoint layout.
            load_flag = int(checkpoint_step) > int(init_step)
        except ValueError:
            # Non-numeric suffix: fall back to the textual comparison.
            load_flag = checkpoint_step > init_step

    if args.fp16 and load_flag:
        checkpoint = torch.load(
            checkpoint_path,
            map_location=lambda s, l: default_restore_location(s, 'cpu'))
        # Drop the first 7 characters of every key -- presumably the
        # "module." prefix added by a DataParallel/DDP wrapper at save time
        # (TODO confirm against the code that writes these checkpoints).
        new_state_dict = OrderedDict(
            (key[7:], value) for key, value in checkpoint['model'].items())
        model.load_state_dict(new_state_dict)
    else:
        saved_state = load_states_from_checkpoint(checkpoint_path)
        model_to_load = get_model_obj(model)
        logger.info('Loading saved model state ...')
        model_to_load.load_state_dict(saved_state.model_dict)

    model.is_representation_l2_normalization = args.representation_l2_normalization

    model.to(args.device)
    logger.info("Inference parameters %s", args)

    # local_rank == -1 is treated as "not distributed" throughout this code.
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )
    return model
Example #3
0
def _read_ann_training_data(ann_path, world_size):
    """Read ANN training lines from ``ann_path``, then filter and align them.

    Lines whose third tab-separated field has fewer than two comma-separated
    entries are dropped, and the remainder is truncated to a multiple of
    ``world_size``.
    """
    with open(ann_path, 'r') as f:
        lines = f.readlines()
    logger.info("Training data line count: %d", len(lines))

    lines = [
        line for line in lines
        if len(line.split('\t')[2].split(',')) > 1
    ]
    logger.info("Filtered training data line count: %d", len(lines))

    aligned_size = (len(lines) // world_size) * world_size
    return lines[:aligned_size]


def _build_train_dataloader(args, training_lines, query_cache, passage_cache):
    """Wrap ``training_lines`` in a StreamingDataset/DataLoader pair.

    ``args.triplet`` selects the processing function and the batch size
    (the non-triplet path uses twice ``args.train_batch_size``).
    """
    if args.triplet:
        processing_fn = GetTripletTrainingDataProcessingFn(
            args, query_cache, passage_cache)
        batch_size = args.train_batch_size
    else:
        processing_fn = GetTrainingDataProcessingFn(
            args, query_cache, passage_cache)
        batch_size = args.train_batch_size * 2
    train_dataset = StreamingDataset(training_lines, processing_fn)
    return DataLoader(train_dataset, batch_size=batch_size)


def train(args, model, tokenizer, query_cache, passage_cache):
    """Train the model.

    Two data modes are supported, selected by ``args.num_epoch``:

    * ``num_epoch == 0``: training data is taken from the newest file
      reported by ``get_latest_ann_data(args.ann_dir)`` and is re-checked
      every ``args.logging_steps`` optimizer updates;
    * otherwise: the file ``<args.data_dir>/train-data`` is loaded once and
      iterated until more than ``args.num_epoch`` passes have completed.

    In both modes the loop also stops once ``args.max_steps`` optimizer
    updates have been made.

    Args:
        args: Run configuration. ``args.train_batch_size`` is set here.
        model: Model to optimise. It may be wrapped by apex amp,
            ``DataParallel`` and/or ``DistributedDataParallel`` inside this
            function.
        tokenizer: Not used by this function.
        query_cache: Passed through to the training-data processing function.
        passage_cache: Passed through to the training-data processing
            function.

    Returns:
        The number of optimizer updates performed by this call.

    Raises:
        ImportError: If ``args.fp16`` is set and apex is not installed.
    """
    logger.info("Training/evaluation parameters %s", args)
    tb_writer = None
    if is_first_worker():
        tb_writer = SummaryWriter(log_dir=args.log_dir)

    # local_rank == -1 is treated as "not distributed" throughout.
    distributed = args.local_rank != -1

    args.train_batch_size = args.per_gpu_train_batch_size * max(
        1, args.n_gpu)  # nll loss for query
    total_batch_size = (
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if distributed else 1))

    optimizer = get_optimizer(
        args,
        model,
        weight_decay=args.weight_decay,
    )

    if args.fp16:
        try:
            from apex import amp
        except ImportError as err:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            ) from err
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Max steps = %d", args.max_steps)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        total_batch_size,
    )
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)

    tr_loss = 0.0
    model.zero_grad()
    model.train()
    set_seed(args)  # Added here for reproducibility

    last_ann_no = -1
    train_dataloader = None
    train_dataloader_iter = None
    step = 0  # counts forward/backward passes (micro-batches)
    iter_count = 0  # completed passes over the dataset (num_epoch != 0 mode)

    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=args.max_steps)

    global_step = 0
    if args.model_name_or_path != "bert-base-uncased":
        saved_state = load_states_from_checkpoint(args.model_name_or_path)
        global_step = _load_saved_state(model, optimizer, scheduler,
                                        saved_state)
        logger.info(
            "  Continuing training from checkpoint, will skip to saved global_step"
        )
        logger.info("  Continuing training from global step %d", global_step)

    print(args.num_epoch)
    print(step, args.max_steps, global_step)

    # NOTE(review): the step counter restored from the checkpoint is
    # discarded here, so ``max_steps`` counts only the updates made by this
    # call. Kept as in the original; confirm this is intended.
    global_step = 0
    while global_step < args.max_steps:

        # Data (re)loading is only considered on an accumulation boundary
        # that also falls on a logging boundary.
        if step % args.gradient_accumulation_steps == 0 and global_step % args.logging_steps == 0:

            if args.num_epoch == 0:
                # check if new ann training data is available
                ann_no, ann_path, _ = get_latest_ann_data(args.ann_dir)
                if ann_path is not None and ann_no != last_ann_no:
                    logger.info("Training on new add data at %s", ann_path)
                    # Presumably gives the producer time to finish writing
                    # the file before it is read -- TODO confirm.
                    time.sleep(180)
                    ann_training_data = _read_ann_training_data(
                        ann_path, args.world_size)
                    logger.info("Total ann queries: %d",
                                len(ann_training_data))

                    train_dataloader = _build_train_dataloader(
                        args, ann_training_data, query_cache, passage_cache)
                    train_dataloader_iter = iter(train_dataloader)

                    # re-warmup
                    if not args.single_warmup:
                        scheduler = get_linear_schedule_with_warmup(
                            optimizer,
                            num_warmup_steps=args.warmup_steps,
                            num_training_steps=len(ann_training_data))

                    if distributed:
                        dist.barrier()

                    if is_first_worker():
                        # Record the switch between ANN data generations.
                        if last_ann_no != -1:
                            tb_writer.add_scalar("epoch", last_ann_no,
                                                 global_step - 1)
                        tb_writer.add_scalar("epoch", ann_no, global_step)
                    last_ann_no = ann_no
            elif step == 0:
                train_data_path = os.path.join(args.data_dir, "train-data")
                with open(train_data_path, 'r') as f:
                    training_data = f.readlines()
                train_dataloader = _build_train_dataloader(
                    args, training_data, query_cache, passage_cache)
                # Count batches while streaming instead of holding every
                # batch in memory just to take its length.
                batch_count = sum(1 for _ in train_dataloader)
                logger.info("Total batch count: %d", batch_count)
                train_dataloader_iter = iter(train_dataloader)

        try:
            batch = next(train_dataloader_iter)
        except StopIteration:
            logger.info("Finished iterating current dataset, begin reiterate")
            if args.num_epoch != 0:
                iter_count += 1
                if is_first_worker():
                    tb_writer.add_scalar("epoch", iter_count - 1,
                                         global_step - 1)
                    tb_writer.add_scalar("epoch", iter_count, global_step)

            # Pick up the newest ANN file if there is one; otherwise the
            # current dataloader is simply restarted.
            _, ann_path, _ = get_latest_ann_data(args.ann_dir)
            if ann_path is not None:
                print(ann_path)
                ann_training_data = _read_ann_training_data(
                    ann_path, args.world_size)
                # Honour args.triplet here as well, so the reloaded batches
                # match the forward pass chosen below.
                train_dataloader = _build_train_dataloader(
                    args, ann_training_data, query_cache, passage_cache)

            train_dataloader_iter = iter(train_dataloader)
            batch = next(train_dataloader_iter)
            # Only synchronise when running distributed, matching the
            # guarded barrier in the data-loading branch above.
            if distributed:
                dist.barrier()

        if args.num_epoch != 0 and iter_count > args.num_epoch:
            break

        step += 1
        if args.triplet:
            loss = triplet_fwd_pass(args, model, batch)
        else:
            loss, _ = do_biencoder_fwd_pass(args, model, batch)

        at_accumulation_boundary = (
            step % args.gradient_accumulation_steps == 0)

        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        elif at_accumulation_boundary or not hasattr(model, "no_sync"):
            # no_sync() is provided by DistributedDataParallel; a model that
            # lacks it always takes the plain backward pass.
            loss.backward()
        else:
            # Intermediate accumulation step: skip gradient synchronisation.
            with model.no_sync():
                loss.backward()

        tr_loss += loss.item()
        if at_accumulation_boundary:
            if args.fp16:
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                               args.max_grad_norm)
            else:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.max_grad_norm)

            optimizer.step()
            scheduler.step()  # Update learning rate schedule
            model.zero_grad()
            global_step += 1

            if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                logs = {
                    "learning_rate": scheduler.get_lr()[0],
                    "loss": tr_loss / args.logging_steps,
                }
                tr_loss = 0

                if is_first_worker():
                    for key, value in logs.items():
                        tb_writer.add_scalar(key, value, global_step)
                    logger.info(json.dumps({**logs, "step": global_step}))

            if is_first_worker(
            ) and args.save_steps > 0 and global_step % args.save_steps == 0:
                _save_checkpoint(args, model, optimizer, scheduler,
                                 global_step)

    # Close the writer exactly when one was created above.
    if tb_writer is not None:
        tb_writer.close()

    return global_step