Example #1
def train(args, model, tokenizer, query_cache, passage_cache):
    """ Train the model """
    #if args.local_rank in [-1, 0]:
    tb_writer = None
    if is_first_worker():
        tb_writer = SummaryWriter(log_dir=args.log_dir)

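    # Effective batch size = per-GPU batch size * number of GPUs * gradient-accumulation steps * number of distributed workers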
    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    real_batch_size = args.train_batch_size * args.gradient_accumulation_steps * (
        torch.distributed.get_world_size() if args.local_rank != -1 else 1)

    # layerwise optimization for lamb
    optimizer_grouped_parameters = []
    for layer_name in [
            "roberta.embeddings", "score_out", "downsample1", "downsample2",
            "downsample3"
    ]:
        layer = getattr_recursive(model, layer_name)
        if layer is not None:
            optimizer_grouped_parameters.append({"params": layer.parameters()})
    if getattr_recursive(model, "roberta.encoder.layer") is not None:
        for layer in model.roberta.encoder.layer:
            optimizer_grouped_parameters.append({"params": layer.parameters()})

    if len(optimizer_grouped_parameters) == 0:
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [
                    p for n, p in model.named_parameters()
                    if not any(nd in n for nd in no_decay)
                ],
                "weight_decay":
                args.weight_decay,
            },
            {
                "params": [
                    p for n, p in model.named_parameters()
                    if any(nd in n for nd in no_decay)
                ],
                "weight_decay":
                0.0
            },
        ]

    if args.optimizer.lower() == "lamb":
        optimizer = Lamb(optimizer_grouped_parameters,
                         lr=args.learning_rate,
                         eps=args.adam_epsilon)
    elif args.optimizer.lower() == "adamw":
        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=args.learning_rate,
                          eps=args.adam_epsilon)
    else:
        raise Exception(
            "optimizer {0} not recognized! Can only be lamb or adamW".format(
                args.optimizer))

    # Check if saved optimizer or scheduler states exist
    if os.path.isfile(
            os.path.join(args.model_name_or_path,
                         "optimizer.pt")) and args.load_optimizer_scheduler:
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )

    # Train!
    logger.info("***** Running training *****")
    #logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Max steps = %d", args.max_steps)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)

    global_step = 0
    # Check if continuing training from a checkpoint
    if os.path.exists(args.model_name_or_path):
        # set global_step to the global_step of the last saved checkpoint in the model path
        if "-" in args.model_name_or_path:
            global_step = int(
                args.model_name_or_path.split("-")[-1].split("/")[0])
        else:
            global_step = 0
        logger.info(
            "  Continuing training from checkpoint, will skip to saved global_step"
        )
        logger.info("  Continuing training from global step %d", global_step)

    tr_loss = 0.0
    model.zero_grad()
    model.train()
    set_seed(args)  # Added here for reproducibility

    last_ann_no = -1
    train_dataloader = None
    train_dataloader_iter = None
    dev_ndcg = 0
    step = 0

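    # With a single warmup, one linear schedule spans all of max_steps;
    # otherwise the schedule is re-created ("re-warmed") each time new ANN data is loaded below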
    if args.single_warmup:
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=args.warmup_steps,
            num_training_steps=args.max_steps)

    while global_step < args.max_steps:

        if step % args.gradient_accumulation_steps == 0 and global_step % args.logging_steps == 0:
            # check if new ANN training data is available in args.ann_dir
            ann_no, ann_path, ndcg_json = get_latest_ann_data(args.ann_dir)
            if ann_path is not None and ann_no != last_ann_no:
                logger.info("Training on new ann data at %s", ann_path)
                with open(ann_path, 'r') as f:
                    ann_training_data = f.readlines()
                dev_ndcg = ndcg_json['ndcg']
                ann_checkpoint_path = ndcg_json['checkpoint']
                ann_checkpoint_no = get_checkpoint_no(ann_checkpoint_path)

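                # Truncate the data so it divides evenly across distributed workers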
                aligned_size = (len(ann_training_data) //
                                args.world_size) * args.world_size
                ann_training_data = ann_training_data[:aligned_size]

                logger.info("Total ann queries: %d", len(ann_training_data))
                if args.triplet:
                    train_dataset = StreamingDataset(
                        ann_training_data,
                        GetTripletTrainingDataProcessingFn(
                            args, query_cache, passage_cache))
                else:
                    train_dataset = StreamingDataset(
                        ann_training_data,
                        GetTrainingDataProcessingFn(args, query_cache,
                                                    passage_cache))
                train_dataloader = DataLoader(train_dataset,
                                              batch_size=args.train_batch_size)
                train_dataloader_iter = iter(train_dataloader)

                # re-warmup
                if not args.single_warmup:
                    scheduler = get_linear_schedule_with_warmup(
                        optimizer,
                        num_warmup_steps=args.warmup_steps,
                        num_training_steps=len(ann_training_data))

                if args.local_rank != -1:
                    dist.barrier()

                if is_first_worker():
                    # add ndcg at checkpoint step used instead of current step
                    tb_writer.add_scalar("dev_ndcg", dev_ndcg,
                                         ann_checkpoint_no)
                    if last_ann_no != -1:
                        tb_writer.add_scalar("epoch", last_ann_no,
                                             global_step - 1)
                    tb_writer.add_scalar("epoch", ann_no, global_step)
                last_ann_no = ann_no

        try:
            batch = next(train_dataloader_iter)
        except StopIteration:
            logger.info("Finished iterating current dataset, begin reiterate")
            train_dataloader_iter = iter(train_dataloader)
            batch = next(train_dataloader_iter)

        batch = tuple(t.to(args.device) for t in batch)
        step += 1

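        # Triplet batches carry query / passage-a / passage-b token ids and masks;
        # pairwise batches carry two sequences plus a label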
        if args.triplet:
            inputs = {
                "query_ids": batch[0].long(),
                "attention_mask_q": batch[1].long(),
                "input_ids_a": batch[3].long(),
                "attention_mask_a": batch[4].long(),
                "input_ids_b": batch[6].long(),
                "attention_mask_b": batch[7].long()
            }
        else:
            inputs = {
                "input_ids_a": batch[0].long(),
                "attention_mask_a": batch[1].long(),
                "input_ids_b": batch[3].long(),
                "attention_mask_b": batch[4].long(),
                "labels": batch[6]
            }

        # sync gradients only at gradient accumulation step
        if step % args.gradient_accumulation_steps == 0:
            outputs = model(**inputs)
        else:
            with model.no_sync():
                outputs = model(**inputs)
        loss = outputs[
            0]  # model outputs are always tuple in transformers (see doc)

        if args.n_gpu > 1:
            loss = loss.mean(
            )  # mean() to average on multi-gpu parallel training
        if args.gradient_accumulation_steps > 1:
            loss = loss / args.gradient_accumulation_steps

        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            if step % args.gradient_accumulation_steps == 0:
                loss.backward()
            else:
                with model.no_sync():
                    loss.backward()

        tr_loss += loss.item()
        if step % args.gradient_accumulation_steps == 0:
            if args.fp16:
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                               args.max_grad_norm)
            else:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.max_grad_norm)

            optimizer.step()
            scheduler.step()  # Update learning rate schedule
            model.zero_grad()
            global_step += 1

            if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                logs = {}
                loss_scalar = tr_loss / args.logging_steps
                learning_rate_scalar = scheduler.get_lr()[0]
                logs["learning_rate"] = learning_rate_scalar
                logs["loss"] = loss_scalar
                tr_loss = 0

                if is_first_worker():
                    for key, value in logs.items():
                        tb_writer.add_scalar(key, value, global_step)
                    logger.info(json.dumps({**logs, **{"step": global_step}}))

            if is_first_worker(
            ) and args.save_steps > 0 and global_step % args.save_steps == 0:
                # Save model checkpoint
                output_dir = os.path.join(args.output_dir,
                                          "checkpoint-{}".format(global_step))
                if not os.path.exists(output_dir):
                    os.makedirs(output_dir)
                model_to_save = (
                    model.module if hasattr(model, "module") else model
                )  # Take care of distributed/parallel training
                model_to_save.save_pretrained(output_dir)
                tokenizer.save_pretrained(output_dir)

                torch.save(args, os.path.join(output_dir, "training_args.bin"))
                logger.info("Saving model checkpoint to %s", output_dir)

                torch.save(optimizer.state_dict(),
                           os.path.join(output_dir, "optimizer.pt"))
                torch.save(scheduler.state_dict(),
                           os.path.join(output_dir, "scheduler.pt"))
                logger.info("Saving optimizer and scheduler states to %s",
                            output_dir)

    if args.local_rank == -1 or torch.distributed.get_rank() == 0:
        tb_writer.close()

    return global_step
Example #2
    for epoch in range(epochs):
        print(f"Epoch {epoch + 1}")

        start_time = time.time()
        train_loss = train(args, epoch, writer, model, train_dataset)
        valid_loss, em, f1 = valid(model, valid_dataset, writer, epoch)
        end_time = time.time()

        epoch_mins, epoch_secs = epoch_time(start_time, end_time)

        metrics['train_losses'].append(train_loss)
        metrics['valid_losses'].append(valid_loss)
        metrics['ems'].append(em)
        metrics['f1s'].append(f1)

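        # Save a checkpoint when validation loss improves over the previous epoch;
        # otherwise spend one "life" and stop early when none remain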
        if valid_loss < valid_loss_prev:
            state = {'epoch': epoch, 'model_state_dict': model.module.state_dict(),
                     'optimizer_state_dict': optimizer.state_dict()}
            fname = os.path.join(ckpt_dir, 'best_weights.pt')
            torch.save(state, fname)
        else:
            lives -= 1
            if lives == 0:
                break
        valid_loss_prev = valid_loss
        pickle.dump(metrics, open(os.path.join(ckpt_dir, 'metrics.p'), 'wb'))
        print(f"Epoch train loss: {train_loss} | Time: {epoch_mins}m {epoch_secs}s")
        print(f"Epoch valid loss: {valid_loss}")
        print(f"Epoch EM: {em}")
        print(f"Epoch F1: {f1}")
        print("====================================================================================")
Example #3
def train(args, model, tokenizer, train_dataloader):
    """ Train the model """
    #if args.local_rank in [-1, 0]:
    tb_writer = None
    if is_first_worker():
        tb_writer = SummaryWriter(log_dir=args.log_dir)

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    real_batch_size = args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1)

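    # Total optimization steps: either the explicit max_steps cap or derived from the expected dataset size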
    if args.max_steps > 0:
        t_total = args.max_steps
        #args.num_train_epochs = args.max_steps // (args.expected_train_size // args.gradient_accumulation_steps) + 1 
    else:
        t_total = args.expected_train_size // real_batch_size * args.num_train_epochs    

    # layerwise optimization for lamb
    optimizer_grouped_parameters = []
    layer_optim_params = set()
    for layer_name in ["roberta.embeddings", "score_out", "downsample1", "downsample2", "downsample3", "embeddingHead"]:
        layer = getattr_recursive(model, layer_name)
        if layer is not None:
            optimizer_grouped_parameters.append({"params": layer.parameters()})
            for p in layer.parameters():
                layer_optim_params.add(p)
    if getattr_recursive(model, "roberta.encoder.layer") is not None:
        for layer in model.roberta.encoder.layer:
            optimizer_grouped_parameters.append({"params": layer.parameters()})
            for p in layer.parameters():
                layer_optim_params.add(p)
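    # Any parameters not covered by the per-layer groups above go into one catch-all group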
    optimizer_grouped_parameters.append({"params": [p for p in model.parameters() if p not in layer_optim_params]})
    if len(optimizer_grouped_parameters)==0:
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
        ]   

    if args.optimizer.lower()=="lamb":
        optimizer = Lamb(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    elif args.optimizer.lower()=="adamw":
        optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    else:
        raise Exception("optimizer {0} not recognized! Can only be lamb or adamW".format(args.optimizer))
    
    if args.scheduler.lower()=="linear":
        scheduler = get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
        )
    elif args.scheduler.lower()=="cosine":
        scheduler = CosineAnnealingLR(optimizer, t_total, 1e-8)
    else:
        raise Exception("Scheduler {0} not recognized! Can only be linear or cosine".format(args.scheduler))

    # Check if saved optimizer or scheduler states exist
    if os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
        os.path.join(args.model_name_or_path, "scheduler.pt")
    ) and args.load_optimizer_scheduler:
        # Load in optimizer and scheduler states
        # if is_first_worker():
        #     op_state = torch.load(os.path.join(args.model_name_or_path, "optimizer.pt"))
        #     print([len(x['params']) for x in op_state['param_groups']])
        #     real_op_state = optimizer.state_dict()
        #     print([len(x['params']) for x in real_op_state['param_groups']])
        optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True,
        )

    # Train!
    logger.info("***** Running training *****")
    #logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size
        * args.gradient_accumulation_steps
        * (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if os.path.exists(args.model_name_or_path):
        # set global_step to the global_step of the last saved checkpoint in the model path
        try:
            global_step = int(args.model_name_or_path.split("-")[-1].split("/")[0])
            epochs_trained = global_step // (args.expected_train_size // args.gradient_accumulation_steps)
            steps_trained_in_current_epoch = global_step % (args.expected_train_size // args.gradient_accumulation_steps)

            logger.info("  Continuing training from checkpoint, will skip to saved global_step")
            logger.info("  Continuing training from epoch %d", epochs_trained)
            logger.info("  Continuing training from global step %d", global_step)
            logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
        except Exception:
            logger.info("  Start training from a pretrained model") 

    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(
        epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0],
    )
    set_seed(args)  # Added here for reproducibility
    for m_epoch in train_iterator:
        #epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
        for step, batch in tqdm(enumerate(train_dataloader), desc="Iteration", disable=args.local_rank not in [-1, 0]):

            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            model.train()
            batch = tuple(t.to(args.device).long() for t in batch)

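            # Skip DDP gradient synchronization (no_sync) on intermediate accumulation micro-steps;
            # gradients are synced only on the step that triggers the optimizer update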
            if (step + 1) % args.gradient_accumulation_steps == 0:
                outputs = model(*batch)
            else:
                with model.no_sync():
                    outputs = model(*batch)
            loss = outputs[0]  # model outputs are always tuple in transformers (see doc)

            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    loss.backward()
                else:
                    with model.no_sync():
                        loss.backward()          

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if is_first_worker() and args.save_steps > 0 and global_step % args.save_steps == 0:
                    # Save model checkpoint
                    output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    model_to_save = (
                        model.module if hasattr(model, "module") else model
                    )  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)

                    torch.save(args, os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

                    torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                    logger.info("Saving optimizer and scheduler states to %s", output_dir)
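                # Synchronize all ranks (ensures rank 0 has finished any checkpoint write before continuing)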
                dist.barrier()

                if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    logs = {}
                    if args.evaluate_during_training and global_step % (args.logging_steps_per_eval*args.logging_steps)==0:
                        model.eval()
                        reranking_mrr, full_ranking_mrr = passage_dist_eval(args, model, tokenizer)
                        if is_first_worker():
                            print("Reranking/Full ranking mrr: {0}/{1}".format(str(reranking_mrr), str(full_ranking_mrr)))
                            mrr_dict = {"reranking": float(reranking_mrr), "full_ranking": float(full_ranking_mrr)}
                            tb_writer.add_scalars("mrr", mrr_dict, global_step)
                            print(args.output_dir)

                    loss_scalar = (tr_loss - logging_loss) / args.logging_steps
                    learning_rate_scalar = scheduler.get_lr()[0]
                    logs["learning_rate"] = learning_rate_scalar
                    logs["loss"] = loss_scalar
                    logging_loss = tr_loss


                    if is_first_worker():
                        for key, value in logs.items():
                            print(key, type(value))
                            tb_writer.add_scalar(key, value, global_step)
                        tb_writer.add_scalar("epoch", m_epoch, global_step)
                        print(json.dumps({**logs, **{"step": global_step}}))
                    dist.barrier()

        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank == -1 or torch.distributed.get_rank() == 0:
        tb_writer.close()

    return global_step, tr_loss / global_step
Example #4
def train(args, train_dataset, model_d, model_g, tokenizer):
    """ Train the model """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(
        train_dataset) if args.local_rank == -1 else DistributedSampler(
            train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_d_grouped_parameters = [
        {
            "params": [
                p for n, p in model_d.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            args.weight_decay,
        },
        {
            "params": [
                p for n, p in model_d.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0
        },
    ]
    # optimizer_d = AdamW(optimizer_d_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    optimizer_d = Lamb(optimizer_d_grouped_parameters,
                       lr=args.learning_rate,
                       betas=(0.9, 0.999),
                       eps=1e-6)
    scheduler_d = get_linear_schedule_with_warmup(
        optimizer_d,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)

    optimizer_g_grouped_parameters = [
        {
            "params": [
                p for n, p in model_g.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            args.weight_decay,
        },
        {
            "params": [
                p for n, p in model_g.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0
        },
    ]
    # optimizer_g = AdamW(optimizer_g_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    optimizer_g = Lamb(optimizer_g_grouped_parameters,
                       lr=args.learning_rate,
                       betas=(0.9, 0.999),
                       eps=1e-6)
    scheduler_g = get_linear_schedule_with_warmup(
        optimizer_g,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)

    # Check if saved optimizer or scheduler states exist
    if os.path.isfile(os.path.join(
            args.model_name_or_path, "optimizer_d.pt")) and os.path.isfile(
                os.path.join(args.model_name_or_path, "scheduler_d.pt")):
        # Load in optimizer and scheduler states
        optimizer_d.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path,
                                    "optimizer_d.pt")))
        scheduler_d.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path,
                                    "scheduler_d.pt")))
    if os.path.isfile(os.path.join(
            args.model_name_or_path, "optimizer_g.pt")) and os.path.isfile(
                os.path.join(args.model_name_or_path, "scheduler_g.pt")):
        # Load in optimizer and scheduler states
        optimizer_g.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path,
                                    "optimizer_g.pt")))
        scheduler_g.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path,
                                    "scheduler_g.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model_d, optimizer_d = amp.initialize(model_d,
                                              optimizer_d,
                                              opt_level=args.fp16_opt_level)
        model_g, optimizer_g = amp.initialize(model_g,
                                              optimizer_g,
                                              opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model_d = torch.nn.DataParallel(model_d)
        model_g = torch.nn.DataParallel(model_g)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model_d = torch.nn.parallel.DistributedDataParallel(
            model_d,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )
        model_g = torch.nn.parallel.DistributedDataParallel(
            model_g,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if os.path.exists(args.model_name_or_path):
        # set global_step to the global_step of the last saved checkpoint in the model path
        global_step = int(args.model_name_or_path.split("-")[-1].split("/")[0])
        epochs_trained = global_step // (len(train_dataloader) //
                                         args.gradient_accumulation_steps)
        steps_trained_in_current_epoch = global_step % (
            len(train_dataloader) // args.gradient_accumulation_steps)

        logger.info(
            "  Continuing training from checkpoint, will skip to saved global_step"
        )
        logger.info("  Continuing training from epoch %d", epochs_trained)
        logger.info("  Continuing training from global step %d", global_step)
        logger.info("  Will skip the first %d steps in the first epoch",
                    steps_trained_in_current_epoch)

    model_to_resize_d = model_d.module if hasattr(
        model_d,
        "module") else model_d  # Take care of distributed/parallel training
    # model_to_resize_d.resize_token_embeddings(len(tokenizer))
    model_to_resize_g = model_g.module if hasattr(
        model_g,
        "module") else model_g  # Take care of distributed/parallel training
    # model_to_resize_g.resize_token_embeddings(len(tokenizer))

    # model_to_resize_d.bert.embeddings = model_to_resize_g.bert.embeddings

    tr_loss, logging_loss = 0.0, 0.0
    tr_loss_d, logging_loss_d = 0.0, 0.0
    tr_loss_g, logging_loss_g = 0.0, 0.0
    model_d.zero_grad()
    model_g.zero_grad()
    train_iterator = trange(
        epochs_trained,
        int(args.num_train_epochs),
        desc="Epoch",
        disable=args.local_rank not in [-1, 0],
    )
    set_seed(args)  # Added here for reproducibility
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader,
                              desc="Iteration",
                              disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):

            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            model_d.train()
            model_g.train()
            # batch = tuple(t.to(args.device) for t in batch)
            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "labels": batch[3]
            }
            if args.model_type != "distilbert":
                inputs["token_type_ids"] = (
                    batch[2]
                    if args.model_type in ["bert", "xlnet", "albert"] else None
                )  # XLM, DistilBERT, RoBERTa, and XLM-RoBERTa don't use segment_ids
            # outputs = model(**inputs)
            # loss = outputs[0]  # model outputs are always tuple in transformers (see doc)

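            # Replaced-token-detection setup (ELECTRA-style): the generator fills in masked tokens,
            # sampled replacements are spliced into the input, and the discriminator predicts
            # which tokens were replaced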
            masked_input_ids, mask_labels = mask_tokens(
                inputs['input_ids'], tokenizer, args)
            outputs_g = model_g(
                input_ids=masked_input_ids.to(args.device),
                masked_lm_labels=mask_labels.to(args.device),
                attention_mask=inputs['attention_mask'].to(args.device),
                token_type_ids=inputs['token_type_ids'].to(args.device))
            masked_lm_loss, prediction_scores_g = outputs_g[0], outputs_g[1]

            prediction_g = prediction_scores_g.max(dim=-1)[1].cpu()
            acc_g = (prediction_g[mask_labels >= 0] == mask_labels[
                mask_labels >= 0]).float().mean().item()

            prediction_probs_g = F.softmax(prediction_scores_g, dim=-1).cpu()
            bsz, seq_len, vocab_size = prediction_probs_g.size()
            prediction_samples_g = torch.multinomial(prediction_probs_g.view(
                -1, vocab_size),
                                                     num_samples=1)
            prediction_samples_g = prediction_samples_g.view(bsz, seq_len)
            input_ids_replace = inputs['input_ids'].clone()
            input_ids_replace[mask_labels >= 0] = prediction_samples_g[
                mask_labels >= 0]
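            # labels_d: 1 where the token still matches the original input, 0 where the generator's sample replaced it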
            labels_d = input_ids_replace.eq(inputs['input_ids']).long()

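            # Exclude special tokens and padding positions from the discriminator loss (label -100)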
            special_tokens_mask = [
                tokenizer.get_special_tokens_mask(
                    val, already_has_special_tokens=True)
                for val in inputs['input_ids'].tolist()
            ]
            labels_d.masked_fill_(torch.tensor(special_tokens_mask,
                                               dtype=torch.bool),
                                  value=-100)
            padding_mask = inputs['input_ids'].eq(tokenizer.pad_token_id)
            labels_d.masked_fill_(padding_mask, value=-100)

            labels_d_ones = labels_d[labels_d >= 0].float().mean().item()
            acc_replace = 1 - ((labels_d == 0).sum().float() /
                               (mask_labels >= 0).sum().float()).item()

            outputs_d = model_d(
                input_ids=input_ids_replace.to(args.device),
                attention_mask=inputs['attention_mask'].to(args.device),
                token_type_ids=inputs['token_type_ids'].to(args.device),
                labels=labels_d.to(args.device))
            loss_d, prediction_scores_d = outputs_d[0], outputs_d[1]
            prediction_d = prediction_scores_d.max(dim=-1)[1].cpu()
            acc_d = (prediction_d[labels_d >= 0] == labels_d[labels_d >= 0]
                     ).float().mean().item()
            acc_d_0 = (prediction_d[labels_d == 0] == labels_d[labels_d == 0]
                       ).float().mean().item()
            acc_d_1 = (prediction_d[labels_d == 1] == labels_d[labels_d == 1]
                       ).float().mean().item()

            if args.n_gpu > 1:
                loss_d = loss_d.mean(
                )  # mean() to average on multi-gpu parallel training
                masked_lm_loss = masked_lm_loss.mean()
            if args.gradient_accumulation_steps > 1:
                loss_d = loss_d / args.gradient_accumulation_steps
                masked_lm_loss = masked_lm_loss / args.gradient_accumulation_steps

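            # Combined objective: discriminator loss weighted by lambd plus the generator's MLM loss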
            lambd = 50
            loss = loss_d * lambd + masked_lm_loss
            if args.fp16:
                loss_d = loss_d * lambd
                with amp.scale_loss(loss_d, optimizer_d) as scaled_loss_d:
                    scaled_loss_d.backward()
                with amp.scale_loss(masked_lm_loss,
                                    optimizer_g) as scaled_loss_g:
                    scaled_loss_g.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            tr_loss_d += loss_d.item()
            tr_loss_g += masked_lm_loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer_d), args.max_grad_norm)
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer_g), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model_d.parameters(),
                                                   args.max_grad_norm)
                    torch.nn.utils.clip_grad_norm_(model_g.parameters(),
                                                   args.max_grad_norm)

                optimizer_d.step()
                scheduler_d.step()  # Update learning rate schedule
                model_d.zero_grad()
                optimizer_g.step()
                scheduler_g.step()  # Update learning rate schedule
                model_g.zero_grad()

                if args.local_rank in [
                        -1, 0
                ] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    logs = {}
                    # if (
                    #     args.local_rank == -1 and args.evaluate_during_training
                    # ):  # Only evaluate when single GPU otherwise metrics may not average well
                    #     results = evaluate(args, model, tokenizer)
                    #     for key, value in results.items():
                    #         eval_key = "eval_{}".format(key)
                    #         logs[eval_key] = value

                    loss_scalar = (tr_loss - logging_loss) / args.logging_steps
                    loss_scalar_d = (tr_loss_d -
                                     logging_loss_d) / args.logging_steps
                    loss_scalar_g = (tr_loss_g -
                                     logging_loss_g) / args.logging_steps
                    learning_rate_scalar_d = scheduler_d.get_lr()[0]
                    learning_rate_scalar_g = scheduler_g.get_lr()[0]
                    logs["learning_rate_d"] = learning_rate_scalar_d
                    logs["learning_rate_g"] = learning_rate_scalar_g
                    logs["loss"] = loss_scalar
                    logs["loss_d"] = loss_scalar_d
                    logs["loss_g"] = loss_scalar_g
                    logs["acc_replace"] = acc_replace
                    logs["acc_d"] = acc_d
                    logs["acc_d_0"] = acc_d_0
                    logs["acc_d_1"] = acc_d_1
                    logs["acc_g"] = acc_g
                    logs["labels_d_ones"] = labels_d_ones
                    logs["masked_ratio"] = (mask_labels >= 0).float().sum(
                    ).item() / (labels_d >= 0).sum().float().item()
                    logging_loss = tr_loss
                    logging_loss_d = tr_loss_d
                    logging_loss_g = tr_loss_g

                    for key, value in logs.items():
                        tb_writer.add_scalar(key, value, global_step)
                    print(json.dumps({**logs, **{"step": global_step}}))

                # print(args.save_steps)
                if args.local_rank in [
                        -1, 0
                ] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    # Save model checkpoint
                    output_dir = os.path.join(
                        args.output_dir, "checkpoint-{}".format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    output_dir_d = os.path.join(
                        output_dir, "checkpoint-d-{}".format(global_step))
                    output_dir_g = os.path.join(
                        output_dir, "checkpoint-g-{}".format(global_step))
                    if not os.path.exists(output_dir_d):
                        os.makedirs(output_dir_d)
                    if not os.path.exists(output_dir_g):
                        os.makedirs(output_dir_g)
                    model_to_save_d = (
                        model_d.module if hasattr(model_d, "module") else
                        model_d)  # Take care of distributed/parallel training
                    model_to_save_g = (
                        model_g.module if hasattr(model_g, "module") else
                        model_g)  # Take care of distributed/parallel training
                    model_to_save_d.save_pretrained(output_dir_d)
                    model_to_save_g.save_pretrained(output_dir_g)
                    tokenizer.save_pretrained(output_dir)

                    torch.save(args,
                               os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

                    torch.save(optimizer_d.state_dict(),
                               os.path.join(output_dir_d, "optimizer_d.pt"))
                    torch.save(scheduler_d.state_dict(),
                               os.path.join(output_dir_d, "scheduler_d.pt"))
                    torch.save(optimizer_g.state_dict(),
                               os.path.join(output_dir_g, "optimizer_g.pt"))
                    torch.save(scheduler_g.state_dict(),
                               os.path.join(output_dir_g, "scheduler_g.pt"))
                    logger.info("Saving optimizer and scheduler states to %s",
                                output_dir)

                global_step += 1

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step