Example 1
def train(args, train_dataset, model, tokenizer, labels, pad_token_label_id):
    """ Train the model """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter(os.path.join(args.output_dir, 'tfboard'))

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
    ]
    optimizer = AdamW(
        optimizer_grouped_parameters,
        lr=args.learning_rate,
        eps=args.adam_epsilon,
        betas=(args.adam_beta1, args.adam_beta2),
    )
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
    )

    # Check if saved optimizer or scheduler states exist
    if os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
        os.path.join(args.model_name_or_path, "scheduler.pt")
    ):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
        )

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size
        * args.gradient_accumulation_steps
        * (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if os.path.exists(args.model_name_or_path):
        # set global_step to the global_step of the last saved checkpoint from the model path
        global_step = int(args.model_name_or_path.split("-")[-1].split("/")[0])
        epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
        steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)

        logger.info("  Continuing training from checkpoint, will skip to saved global_step")
        logger.info("  Continuing training from epoch %d", epochs_trained)
        logger.info("  Continuing training from global step %d", global_step)
        logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)

    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(
        epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]
    )
    set_seed(args)  # Added here for reproducibility
    best_dev, best_test = [0, 0, 0], [0, 0, 0]
    if args.mt:
        teacher_model = model
    update_step = 0  # keeps the mean-teacher branch below safe before the first teacher update

    for epoch in train_iterator:
        epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):

            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
            #inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[4]}
            if args.model_type != "distilbert":
                inputs["token_type_ids"] = (
                    batch[2] if args.model_type in ["bert", "xlnet"] else None
                )  # XLM and RoBERTa don"t use segment_ids

            outputs = model(**inputs)
            loss, logits, final_embeds = outputs[0], outputs[1], outputs[2]  # model outputs are always tuples in pytorch-transformers (see doc)
            mt_loss, vat_loss = 0, 0

            # Mean teacher training scheme
            if args.mt and global_step % args.mt_updatefreq == 0:
                update_step = global_step // args.mt_updatefreq
                if update_step == 1:
                    teacher_model = copy.deepcopy(model)
                    teacher_model.train(True)
                elif update_step < args.mt_rampup:
                    alpha = args.mt_alpha1
                else:
                    alpha = args.mt_alpha2
                mt_update(teacher_model.named_parameters(), model.named_parameters(), args.mt_avg, alpha, update_step)

            if args.mt and update_step > 0:
                with torch.no_grad():
                    teacher_outputs = teacher_model(**inputs)
                    teacher_logits, teacher_final_embeds = teacher_outputs[1], teacher_outputs[2]

                _lambda = args.mt_lambda
                if args.mt_class != 'smart':
                    _lambda = args.mt_lambda * min(1, math.exp(-5 * (1 - update_step / args.mt_rampup) ** 2))

                if args.mt_loss_type == "embeds":
                    mt_loss = get_mt_loss(final_embeds, teacher_final_embeds.detach(), args.mt_class, _lambda)
                else:
                    mt_loss = get_mt_loss(logits, teacher_logits.detach(), args.mt_class, _lambda)

            # Virtual adversarial training
            if args.vat:

                word_embed = None  # stays None for unsupported model types
                if args.model_type in ["roberta", "camembert", "xlmroberta"]:
                    word_embed = model.roberta.get_input_embeddings()
                elif args.model_type == "bert":
                    word_embed = model.bert.get_input_embeddings()
                elif args.model_type == "distilbert":
                    word_embed = model.distilbert.get_input_embeddings()

                if not word_embed:
                    print("Model type not supported. Unable to retrieve word embeddings.")
                else:
                    embeds = word_embed(batch[0])
                    vat_embeds = (embeds.data.detach() + embeds.data.new(embeds.size()).normal_(0, 1)*1e-5).detach()
                    vat_embeds.requires_grad_()
                    
                    vat_inputs = {"inputs_embeds": vat_embeds, "attention_mask": batch[1], "labels": batch[3]}
                    if args.model_type != "distilbert":
                        inputs["token_type_ids"] = (
                            batch[2] if args.model_type in ["bert", "xlnet"] else None
                        )  # XLM and RoBERTa don"t use segment_ids
                    
                    vat_outputs = model(**vat_inputs)
                    vat_logits, vat_final_embeds = vat_outputs[1], vat_outputs[2]

                    if args.vat_loss_type == "embeds":
                        vat_loss = get_mt_loss(vat_final_embeds, final_embeds.detach(), args.mt_class, 1)
                    else:
                        vat_loss = get_mt_loss(vat_logits, logits.detach(), args.mt_class, 1)
                    
                    vat_embeds.grad = opt_grad(vat_loss, vat_embeds, optimizer)[0]
                    norm = vat_embeds.grad.norm()

                    if torch.isnan(norm) or torch.isinf(norm):
                        logger.warning("Hit NaN/Inf gradient in embedding VAT")
                    else:
                        adv_direct = vat_embeds.grad / (vat_embeds.grad.abs().max(-1, keepdim=True)[0]+1e-4)
                        vat_embeds = vat_embeds + args.vat_eps * adv_direct
                        vat_embeds = vat_embeds.detach()
                        
                        vat_inputs = {"inputs_embeds": vat_embeds, "attention_mask": batch[1], "labels": batch[3]}
                        if args.model_type != "distilbert":
                            inputs["token_type_ids"] = (
                                batch[2] if args.model_type in ["bert", "xlnet"] else None
                            )  # XLM and RoBERTa don"t use segment_ids
                        
                        vat_outputs = model(**vat_inputs)
                        vat_logits, vat_final_embeds = vat_outputs[1], vat_outputs[2]
                        if args.vat_loss_type == "embeds":
                            vat_loss = get_mt_loss(vat_final_embeds, final_embeds.detach(), args.mt_class, args.vat_lambda) \
                                    + get_mt_loss(final_embeds, vat_final_embeds.detach(), args.mt_class, args.vat_lambda)
                        else:
                            vat_loss = get_mt_loss(vat_logits, logits.detach(), args.mt_class, args.vat_lambda) \
                                    + get_mt_loss(logits, vat_logits.detach(), args.mt_class, args.vat_lambda)
            
            loss = loss + args.mt_beta * mt_loss + args.vat_beta * vat_loss
            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics
                    if args.evaluate_during_training:

                        logger.info("***** Entropy loss: %.4f, mean teacher loss : %.4f; vat loss: %.4f *****", \
                            loss - args.mt_beta * mt_loss - args.vat_beta * vat_loss, \
                            args.mt_beta * mt_loss, args.vat_beta * vat_loss)
                        
                        results, _, best_dev, _ = evaluate(
                            args, model, tokenizer, labels, pad_token_label_id, best_dev,
                            mode="dev",
                            prefix='dev [Step {}/{} | Epoch {}/{}]'.format(
                                global_step, t_total, epoch, args.num_train_epochs),
                            verbose=False)
                        for key, value in results.items():
                            tb_writer.add_scalar("eval_{}".format(key), value, global_step)

                        results, _, best_test, is_updated = evaluate(
                            args, model, tokenizer, labels, pad_token_label_id, best_test,
                            mode="test",
                            prefix='test [Step {}/{} | Epoch {}/{}]'.format(
                                global_step, t_total, epoch, args.num_train_epochs),
                            verbose=False)
                        for key, value in results.items():
                            tb_writer.add_scalar("test_{}".format(key), value, global_step)

                        output_dirs = []
                        if args.local_rank in [-1, 0] and is_updated:
                            output_dirs.append(os.path.join(args.output_dir, "checkpoint-best"))

                        if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
                            output_dirs.append(os.path.join(args.output_dir, "checkpoint-{}".format(global_step)))

                        if len(output_dirs) > 0:
                            for output_dir in output_dirs:
                                logger.info("Saving model checkpoint to %s", args.output_dir)
                                # Save a trained model, configuration and tokenizer using `save_pretrained()`.
                                # They can then be reloaded using `from_pretrained()`
                                if not os.path.exists(output_dir):
                                    os.makedirs(output_dir)
                                model_to_save = (
                                    model.module if hasattr(model, "module") else model
                                )  # Take care of distributed/parallel training
                                model_to_save.save_pretrained(output_dir)
                                tokenizer.save_pretrained(output_dir)
                                torch.save(args, os.path.join(output_dir, "training_args.bin"))
                                torch.save(model.state_dict(), os.path.join(output_dir, "model.pt"))
                                torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                                torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                                logger.info("Saving optimizer and scheduler states to %s", output_dir)

                    tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar("loss", (tr_loss - logging_loss) / args.logging_steps, global_step)
                    logging_loss = tr_loss


            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break
    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step, best_dev, best_test
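
Both examples delegate the teacher-weight update to an mt_update helper that is not shown here. A minimal sketch of what such an update could look like, assuming exponential-moving-average semantics inferred from the call site (teacher parameters pulled toward the student's by a factor alpha); the name mt_update_sketch and its exact signature are hypothetical:

import torch

def mt_update_sketch(teacher_named_params, student_named_params, avg, alpha, step):
    # Hypothetical EMA update: teacher <- alpha * teacher + (1 - alpha) * student.
    # `avg` selects the averaging scheme at the real call site; this sketch
    # implements only the EMA case. Early in training, cap alpha by the true
    # running average so the teacher is not dominated by random initialization.
    alpha = min(1.0 - 1.0 / (step + 1), alpha)
    for (_, t_param), (_, s_param) in zip(teacher_named_params, student_named_params):
        with torch.no_grad():
            t_param.mul_(alpha).add_(s_param, alpha=1.0 - alpha)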
Example 2
def train(args, train_dataset, model_class, config, tokenizer, labels,
          pad_token_label_id):
    """
    训练模型
    :param args: argparse参数
    :param train_dataset:  训练集Dataset
    :param model_class: 加载好的model
    :param config: model配置
    :param tokenizer: 加载好的tokenizer
    :param labels:  所有的labels, eg: ['O', 'B-LOC', 'B-ORG', 'B-PER', 'B-MISC', 'I-PER', 'I-MISC', 'I-ORG', 'I-LOC', '<START>', '<STOP>']
    :param pad_token_label_id: pad token对应的label的id eg:-100
    :return:
    """

    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter(os.path.join(args.output_dir, 'tfboard'))
    # Compute the effective batch size
    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    # Random sampling locally; a distributed sampler under DDP
    train_sampler = RandomSampler(
        train_dataset) if args.local_rank == -1 else DistributedSampler(
            train_dataset)
    # Build the DataLoader with the chosen sampler and batch size
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)
    # Compute the total number of optimization steps
    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs

    model, optimizer, scheduler = initialize(args, model_class, config,
                                             t_total, 0)
    # Train!
    logger.info("***** 开始训练 *****")
    logger.info("  样本总数 = %d", len(train_dataset))
    logger.info("  Epochs总数 = %d", args.num_train_epochs)
    logger.info("  每个GPU的Batch size = %d", args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  梯度累积步数 = %d", args.gradient_accumulation_steps)
    logger.info("  总步数 = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check whether we are resuming from a checkpoint; if so, restore
    # global_step, epochs_trained, and steps_trained_in_current_epoch.
    # This assumes the model at model_name_or_path has already been loaded.
    if os.path.exists(args.model_name_or_path):
        # Recover global_step from the checkpoint directory name
        global_step = int(args.model_name_or_path.split("-")[-1].split("/")[0])
        # Number of full epochs already trained
        epochs_trained = global_step // (len(train_dataloader) //
                                         args.gradient_accumulation_steps)
        # Steps already completed within the current epoch
        steps_trained_in_current_epoch = global_step % (
            len(train_dataloader) // args.gradient_accumulation_steps)

        logger.info(
            "  Continuing training from checkpoint, will skip to saved global_step"
        )
        logger.info("  Continuing training from epoch %d", epochs_trained)
        logger.info("  Continuing training from global step %d", global_step)
        logger.info("  Will skip the first %d steps in the first epoch",
                    steps_trained_in_current_epoch)

    tr_loss, logging_loss = 0.0, 0.0
    # Overall epoch progress bar
    train_iterator = trange(epochs_trained,
                            int(args.num_train_epochs),
                            desc="Epoch",
                            disable=args.local_rank not in [-1, 0])
    set_seed(args)  # Added here for reproducibility
    best_dev, best_test = [0, 0, 0], [0, 0, 0]
    if args.mt:
        teacher_model = model
    update_step = 0  # keeps the mean-teacher branch below safe before the first teacher update
    self_training_teacher_model = model

    for epoch in train_iterator:
        # Per-epoch progress bar
        epoch_iterator = tqdm(train_dataloader,
                              desc="Iteration",
                              disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):

            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue
            # Put the model in training mode
            model.train()
            # Move the batch to the target device
            batch = tuple(t.to(args.device) for t in batch)
            # After self_training_begin_step, periodically refresh labels with a teacher
            if global_step >= args.self_training_begin_step:

                # Periodically refresh the teacher model
                delta = global_step - args.self_training_begin_step
                if delta % args.self_training_period == 0:
                    # Update condition met: copy the current model as the new teacher
                    self_training_teacher_model = copy.deepcopy(model)
                    # Put the teacher in eval mode
                    self_training_teacher_model.eval()

                    # After obtaining a new teacher, optionally re-initialize the student
                    if args.self_training_reinit:
                        model, optimizer, scheduler = initialize(
                            args, model_class, config, t_total, epoch)

                # Refresh the labels with the current teacher
                inputs = {"input_ids": batch[0], "attention_mask": batch[1]}
                with torch.no_grad():
                    # outputs:  (loss), logits, final_embedding, (hidden_states), (attentions)
                    outputs = self_training_teacher_model(**inputs)
                label_mask = None
                if args.self_training_label_mode == "hard":
                    #直接用最大值位置索引作为硬标签
                    pred_labels = torch.argmax(outputs[0], axis=2)
                    pred_labels, label_mask = multi_source_label_refine(
                        args,
                        batch[5],
                        batch[3],
                        pred_labels,
                        pad_token_label_id,
                        pred_logits=outputs[0])
                elif args.self_training_label_mode == "soft":
                    #计算软标签
                    pred_labels = soft_frequency(logits=outputs[0], power=2)
                    # combined_labels 用的是真实的labels, 根据self_training_hp_label 计算 pred_labels, label_mask
                    pred_labels, label_mask = multi_source_label_refine(
                        args=args,
                        hp_labels=batch[5],
                        combined_labels=batch[3],
                        pred_labels=pred_labels,
                        pad_token_label_id=pad_token_label_id)
                # Feed the teacher's pred_labels and label_mask to the student model
                inputs = {
                    "input_ids": batch[0],
                    "attention_mask": batch[1],
                    "labels": pred_labels,
                    "label_mask": label_mask
                }
            else:
                inputs = {
                    "input_ids": batch[0],
                    "attention_mask": batch[1],
                    "labels": batch[3]
                }
            # Models other than distilbert take segment ids via the token_type_ids key
            if args.model_type != "distilbert":
                inputs["token_type_ids"] = (batch[2] if args.model_type
                                            in ["bert", "xlnet"] else None)
            # Forward pass through the student model
            outputs = model(**inputs)
            # loss, logits, and final_embeds (the model's final embedding output);
            # model outputs are always tuples in pytorch-transformers (see doc)
            loss, logits, final_embeds = outputs[0], outputs[1], outputs[2]
            mt_loss, vat_loss = 0, 0

            # Mean teacher training scheme
            if args.mt and global_step % args.mt_updatefreq == 0:
                update_step = global_step // args.mt_updatefreq
                if update_step == 1:
                    teacher_model = copy.deepcopy(model)
                    teacher_model.train(True)
                elif update_step < args.mt_rampup:
                    alpha = args.mt_alpha1
                else:
                    alpha = args.mt_alpha2
                mt_update(teacher_model.named_parameters(),
                          model.named_parameters(), args.mt_avg, alpha,
                          update_step)

            if args.mt and update_step > 0:
                with torch.no_grad():
                    teacher_outputs = teacher_model(**inputs)
                    teacher_logits, teacher_final_embeds = teacher_outputs[
                        1], teacher_outputs[2]

                _lambda = args.mt_lambda
                if args.mt_class != 'smart':
                    _lambda = args.mt_lambda * min(
                        1, math.exp(-5 *
                                    (1 - update_step / args.mt_rampup)**2))

                if args.mt_loss_type == "embeds":
                    mt_loss = get_mt_loss(final_embeds,
                                          teacher_final_embeds.detach(),
                                          args.mt_class, _lambda)
                else:
                    mt_loss = get_mt_loss(logits, teacher_logits.detach(),
                                          args.mt_class, _lambda)

            # Virtual adversarial training (VAT)
            if args.vat:

                word_embed = None  # stays None for unsupported model types
                if args.model_type in ["roberta", "camembert", "xlmroberta"]:
                    word_embed = model.roberta.get_input_embeddings()
                elif args.model_type == "bert":
                    word_embed = model.bert.get_input_embeddings()
                elif args.model_type == "distilbert":
                    word_embed = model.distilbert.get_input_embeddings()

                if not word_embed:
                    logger.warning(
                        "Model type not supported. Unable to retrieve word embeddings."
                    )
                else:
                    embeds = word_embed(batch[0])
                    vat_embeds = (embeds.data.detach() + embeds.data.new(
                        embeds.size()).normal_(0, 1) * 1e-5).detach()
                    vat_embeds.requires_grad_()

                    vat_inputs = {
                        "inputs_embeds": vat_embeds,
                        "attention_mask": batch[1],
                        "labels": batch[3]
                    }
                    if args.model_type != "distilbert":
                        inputs["token_type_ids"] = (
                            batch[2] if args.model_type in ["bert", "xlnet"]
                            else None)  # XLM and RoBERTa don"t use segment_ids

                    vat_outputs = model(**vat_inputs)
                    vat_logits, vat_final_embeds = vat_outputs[1], vat_outputs[
                        2]

                    if args.vat_loss_type == "embeds":
                        vat_loss = get_mt_loss(vat_final_embeds,
                                               final_embeds.detach(),
                                               args.mt_class, 1)
                    else:
                        vat_loss = get_mt_loss(vat_logits, logits.detach(),
                                               args.mt_class, 1)
                    # Gradient of the consistency loss w.r.t. the perturbed embeddings
                    vat_embeds.grad = opt_grad(vat_loss, vat_embeds,
                                               optimizer)[0]
                    norm = vat_embeds.grad.norm()

                    if torch.isnan(norm) or torch.isinf(norm):
                        logger.warning("Hit NaN/Inf gradient in embedding VAT")
                    else:
                        adv_direct = vat_embeds.grad / (
                            vat_embeds.grad.abs().max(-1, keepdim=True)[0] +
                            1e-4)
                        vat_embeds = vat_embeds + args.vat_eps * adv_direct
                        vat_embeds = vat_embeds.detach()

                        vat_inputs = {
                            "inputs_embeds": vat_embeds,
                            "attention_mask": batch[1],
                            "labels": batch[3]
                        }
                        if args.model_type != "distilbert":
                            inputs["token_type_ids"] = (
                                batch[2] if args.model_type
                                in ["bert", "xlnet"] else None
                            )  # XLM and RoBERTa don"t use segment_ids

                        vat_outputs = model(**vat_inputs)
                        vat_logits, vat_final_embeds = vat_outputs[
                            1], vat_outputs[2]
                        if args.vat_loss_type == "embeds":
                            vat_loss = get_mt_loss(vat_final_embeds, final_embeds.detach(), args.mt_class, args.vat_lambda) \
                                    + get_mt_loss(final_embeds, vat_final_embeds.detach(), args.mt_class, args.vat_lambda)
                        else:
                            vat_loss = get_mt_loss(vat_logits, logits.detach(), args.mt_class, args.vat_lambda) \
                                    + get_mt_loss(logits, vat_logits.detach(), args.mt_class, args.vat_lambda)
            # mt and vat losses are added to the base loss; either, both, or neither may be active
            loss = loss + args.mt_beta * mt_loss + args.vat_beta * vat_loss

            if args.n_gpu > 1:
                loss = loss.mean(
                )  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            # Mixed-precision backward pass
            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            # Accumulate the running training loss
            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.local_rank in [
                        -1, 0
                ] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics every logging_steps steps
                    if args.evaluate_during_training:

                        logger.info(
                            "***** Entropy loss: %.4f, mean teacher loss: %.4f; vat loss: %.4f *****",
                            loss - args.mt_beta * mt_loss - args.vat_beta * vat_loss,
                            args.mt_beta * mt_loss,
                            args.vat_beta * vat_loss,
                        )

                        results, _, best_dev, _ = evaluate(
                            args,
                            model,
                            tokenizer,
                            labels,
                            pad_token_label_id,
                            best_dev,
                            mode="dev",
                            prefix='dev [Step {}/{} | Epoch {}/{}]'.format(
                                global_step, t_total, epoch,
                                args.num_train_epochs),
                            verbose=False)
                        for key, value in results.items():
                            tb_writer.add_scalar("eval_{}".format(key), value,
                                                 global_step)

                        results, _, best_test, is_updated = evaluate(
                            args,
                            model,
                            tokenizer,
                            labels,
                            pad_token_label_id,
                            best_test,
                            mode="test",
                            prefix='test [Step {}/{} | Epoch {}/{}]'.format(
                                global_step, t_total, epoch,
                                args.num_train_epochs),
                            verbose=False)
                        for key, value in results.items():
                            tb_writer.add_scalar("test_{}".format(key), value,
                                                 global_step)

                        output_dirs = []
                        if args.local_rank in [-1, 0] and is_updated:
                            updated_self_training_teacher = True
                            output_dirs.append(
                                os.path.join(args.output_dir,
                                             "checkpoint-best"))

                        if args.local_rank in [
                                -1, 0
                        ] and args.save_steps > 0 and global_step % args.save_steps == 0:
                            output_dirs.append(
                                os.path.join(
                                    args.output_dir,
                                    "checkpoint-{}".format(global_step)))

                        if len(output_dirs) > 0:
                            for output_dir in output_dirs:
                                logger.info("Saving model checkpoint to %s",
                                            args.output_dir)
                                # Save a trained model, configuration and tokenizer using `save_pretrained()`.
                                # They can then be reloaded using `from_pretrained()`
                                if not os.path.exists(output_dir):
                                    os.makedirs(output_dir)
                                model_to_save = (
                                    model.module
                                    if hasattr(model, "module") else model
                                )  # Take care of distributed/parallel training
                                model_to_save.save_pretrained(output_dir)
                                tokenizer.save_pretrained(output_dir)
                                torch.save(
                                    args,
                                    os.path.join(output_dir,
                                                 "training_args.bin"))
                                torch.save(
                                    model.state_dict(),
                                    os.path.join(output_dir, "model.pt"))
                                torch.save(
                                    optimizer.state_dict(),
                                    os.path.join(output_dir, "optimizer.pt"))
                                torch.save(
                                    scheduler.state_dict(),
                                    os.path.join(output_dir, "scheduler.pt"))
                                logger.info(
                                    "Saving optimizer and scheduler states to %s",
                                    output_dir)

                    tb_writer.add_scalar("lr",
                                         scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar("loss", (tr_loss - logging_loss) /
                                         args.logging_steps, global_step)
                    logging_loss = tr_loss
            # Stop once max_steps is reached
            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        # Stop the epoch loop once max_steps is reached
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return model, global_step, tr_loss / global_step, best_dev, best_test
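
In both examples the mean-teacher consistency weight _lambda follows a sigmoid-shaped ramp-up: it starts near zero and reaches mt_lambda after mt_rampup update steps, so the consistency term does not dominate while the teacher is still poor. A small worked sketch of that schedule (the function name rampup_lambda is hypothetical):

import math

def rampup_lambda(mt_lambda, update_step, mt_rampup):
    # Sigmoid ramp-up: exp(-5 * (1 - t/T)**2), clipped at 1 once t >= T.
    return mt_lambda * min(1.0, math.exp(-5 * (1 - update_step / mt_rampup) ** 2))

# For example, with mt_lambda = 1.0 and mt_rampup = 300:
#   update_step   0 -> exp(-5.00) ~ 0.0067
#   update_step 150 -> exp(-1.25) ~ 0.2865
#   update_step 300 -> 1.0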