def train(args, train_dataset, model, tokenizer, labels, pad_token_label_id):
    """ Train the model """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter(os.path.join(args.output_dir, 'tfboard'))

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate,
                      eps=args.adam_epsilon, betas=(args.adam_beta1, args.adam_beta2))
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
    )

    # Check if saved optimizer or scheduler states exist
    if os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
        os.path.join(args.model_name_or_path, "scheduler.pt")
    ):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)

    # Multi-GPU training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
        )

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size
        * args.gradient_accumulation_steps
        * (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if os.path.exists(args.model_name_or_path):
        # Set global_step to the global_step of the last saved checkpoint from the model path
        global_step = int(args.model_name_or_path.split("-")[-1].split("/")[0])
        epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
        steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)

        logger.info("  Continuing training from checkpoint, will skip to saved global_step")
        logger.info("  Continuing training from epoch %d", epochs_trained)
        logger.info("  Continuing training from global step %d", global_step)
        logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)

    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(
        epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]
    )
    set_seed(args)  # Added here for reproducibility
    best_dev, best_test = [0, 0, 0], [0, 0, 0]

    if args.mt:
        teacher_model = model

    for epoch in train_iterator:
        epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):
            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
            # inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[4]}
            if args.model_type != "distilbert":
                inputs["token_type_ids"] = (
                    batch[2] if args.model_type in ["bert", "xlnet"] else None
                )  # XLM and RoBERTa don't use segment_ids

            outputs = model(**inputs)
            # model outputs are always tuples in pytorch-transformers (see doc)
            loss, logits, final_embeds = outputs[0], outputs[1], outputs[2]
            mt_loss, vat_loss = 0, 0

            # Mean teacher training scheme
            if args.mt and global_step % args.mt_updatefreq == 0:
                update_step = global_step // args.mt_updatefreq
                if update_step == 1:
                    teacher_model = copy.deepcopy(model)
                    teacher_model.train(True)
                elif update_step < args.mt_rampup:
                    alpha = args.mt_alpha1
                else:
                    alpha = args.mt_alpha2
                mt_update(teacher_model.named_parameters(), model.named_parameters(),
                          args.mt_avg, alpha, update_step)

            if args.mt and update_step > 0:
                with torch.no_grad():
                    teacher_outputs = teacher_model(**inputs)
                    teacher_logits, teacher_final_embeds = teacher_outputs[1], teacher_outputs[2]

                _lambda = args.mt_lambda
                if args.mt_class != 'smart':
                    _lambda = args.mt_lambda * min(1, math.exp(-5 * (1 - update_step / args.mt_rampup) ** 2))

                if args.mt_loss_type == "embeds":
                    mt_loss = get_mt_loss(final_embeds, teacher_final_embeds.detach(), args.mt_class, _lambda)
                else:
                    mt_loss = get_mt_loss(logits, teacher_logits.detach(), args.mt_class, _lambda)

            # Virtual adversarial training
            if args.vat:
                word_embed = None
                if args.model_type in ["roberta", "camembert", "xlmroberta"]:
                    word_embed = model.roberta.get_input_embeddings()
                elif args.model_type == "bert":
                    word_embed = model.bert.get_input_embeddings()
                elif args.model_type == "distilbert":
                    word_embed = model.distilbert.get_input_embeddings()

                if not word_embed:
                    print("Model type not supported. Unable to retrieve word embeddings.")
                else:
                    embeds = word_embed(batch[0])
                    vat_embeds = (embeds.data.detach()
                                  + embeds.data.new(embeds.size()).normal_(0, 1) * 1e-5).detach()
                    vat_embeds.requires_grad_()

                    vat_inputs = {"inputs_embeds": vat_embeds, "attention_mask": batch[1], "labels": batch[3]}
                    if args.model_type != "distilbert":
                        vat_inputs["token_type_ids"] = (
                            batch[2] if args.model_type in ["bert", "xlnet"] else None
                        )  # XLM and RoBERTa don't use segment_ids

                    vat_outputs = model(**vat_inputs)
                    vat_logits, vat_final_embeds = vat_outputs[1], vat_outputs[2]

                    if args.vat_loss_type == "embeds":
                        vat_loss = get_mt_loss(vat_final_embeds, final_embeds.detach(), args.mt_class, 1)
                    else:
                        vat_loss = get_mt_loss(vat_logits, logits.detach(), args.mt_class, 1)

                    vat_embeds.grad = opt_grad(vat_loss, vat_embeds, optimizer)[0]
                    norm = vat_embeds.grad.norm()

                    if torch.isnan(norm) or torch.isinf(norm):
                        print("Hit nan gradient in embed vat")
                    else:
                        adv_direct = vat_embeds.grad / (vat_embeds.grad.abs().max(-1, keepdim=True)[0] + 1e-4)
                        vat_embeds = vat_embeds + args.vat_eps * adv_direct
                        vat_embeds = vat_embeds.detach()

                        vat_inputs = {"inputs_embeds": vat_embeds, "attention_mask": batch[1], "labels": batch[3]}
                        if args.model_type != "distilbert":
                            vat_inputs["token_type_ids"] = (
                                batch[2] if args.model_type in ["bert", "xlnet"] else None
                            )  # XLM and RoBERTa don't use segment_ids

                        vat_outputs = model(**vat_inputs)
                        vat_logits, vat_final_embeds = vat_outputs[1], vat_outputs[2]

                        if args.vat_loss_type == "embeds":
                            vat_loss = get_mt_loss(vat_final_embeds, final_embeds.detach(), args.mt_class, args.vat_lambda) \
                                + get_mt_loss(final_embeds, vat_final_embeds.detach(), args.mt_class, args.vat_lambda)
                        else:
                            vat_loss = get_mt_loss(vat_logits, logits.detach(), args.mt_class, args.vat_lambda) \
                                + get_mt_loss(logits, vat_logits.detach(), args.mt_class, args.vat_lambda)

            loss = loss + args.mt_beta * mt_loss + args.vat_beta * vat_loss

            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics
                    if args.evaluate_during_training:
                        logger.info(
                            "***** Entropy loss: %.4f, mean teacher loss: %.4f; vat loss: %.4f *****",
                            loss - args.mt_beta * mt_loss - args.vat_beta * vat_loss,
                            args.mt_beta * mt_loss,
                            args.vat_beta * vat_loss,
                        )
                        results, _, best_dev, _ = evaluate(
                            args, model, tokenizer, labels, pad_token_label_id, best_dev, mode="dev",
                            prefix='dev [Step {}/{} | Epoch {}/{}]'.format(
                                global_step, t_total, epoch, args.num_train_epochs),
                            verbose=False,
                        )
                        for key, value in results.items():
                            tb_writer.add_scalar("eval_{}".format(key), value, global_step)

                        results, _, best_test, is_updated = evaluate(
                            args, model, tokenizer, labels, pad_token_label_id, best_test, mode="test",
                            prefix='test [Step {}/{} | Epoch {}/{}]'.format(
                                global_step, t_total, epoch, args.num_train_epochs),
                            verbose=False,
                        )
                        for key, value in results.items():
                            tb_writer.add_scalar("test_{}".format(key), value, global_step)

                        output_dirs = []
                        if args.local_rank in [-1, 0] and is_updated:
                            output_dirs.append(os.path.join(args.output_dir, "checkpoint-best"))

                        if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
                            output_dirs.append(os.path.join(args.output_dir, "checkpoint-{}".format(global_step)))

                        if len(output_dirs) > 0:
                            for output_dir in output_dirs:
                                logger.info("Saving model checkpoint to %s", output_dir)
                                # Save a trained model, configuration and tokenizer using `save_pretrained()`.
                                # They can then be reloaded using `from_pretrained()`.
                                if not os.path.exists(output_dir):
                                    os.makedirs(output_dir)
                                model_to_save = (
                                    model.module if hasattr(model, "module") else model
                                )  # Take care of distributed/parallel training
                                model_to_save.save_pretrained(output_dir)
                                tokenizer.save_pretrained(output_dir)
                                torch.save(args, os.path.join(output_dir, "training_args.bin"))
                                torch.save(model.state_dict(), os.path.join(output_dir, "model.pt"))
                                torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                                torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                                logger.info("Saving optimizer and scheduler states to %s", output_dir)

                    tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar("loss", (tr_loss - logging_loss) / args.logging_steps, global_step)
                    logging_loss = tr_loss

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break

        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step, best_dev, best_test
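# The training loop above calls two helpers that are defined elsewhere in the repository:
# mt_update (teacher parameter update) and get_mt_loss (consistency loss between student and
# teacher outputs). The sketches below are minimal, self-contained stand-ins inferred from how
# they are called above -- an EMA parameter update and a weighted KL/MSE consistency term --
# and are assumptions, not the repository's actual implementations.

import torch
import torch.nn.functional as F


def mt_update(teacher_params, student_params, avg="exponential", alpha=0.99, step=1):
    """Sketch: update the teacher as an exponential moving average of the student."""
    if avg == "exponential":
        # Ramp the EMA coefficient up from 0 so early teacher states track the student closely.
        alpha = min(1.0 - 1.0 / (step + 1), alpha)
    for (t_name, t_param), (s_name, s_param) in zip(teacher_params, student_params):
        # teacher <- alpha * teacher + (1 - alpha) * student
        t_param.data.mul_(alpha).add_(s_param.data, alpha=1.0 - alpha)


def get_mt_loss(student_out, teacher_out, mt_class="kl", weight=1.0):
    """Sketch: weighted consistency loss between student and teacher outputs."""
    if mt_class == "kl":
        # KL divergence between the token-level label distributions.
        loss = F.kl_div(
            F.log_softmax(student_out, dim=-1),
            F.softmax(teacher_out, dim=-1),
            reduction="batchmean",
        )
    else:
        # Fall back to mean squared error (also reasonable for the "embeds" loss variants).
        loss = F.mse_loss(student_out, teacher_out)
    return weight * loss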
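# opt_grad is likewise defined elsewhere; the VAT block above only needs the gradient of
# vat_loss with respect to the perturbed embeddings. A minimal stand-in, ignoring any apex
# fp16 loss-scale correction the real helper may perform, could look like this:

import torch


def opt_grad(loss, in_var, optimizer):
    """Sketch: return d(loss)/d(in_var) as a tuple of tensors.

    `optimizer` is accepted only to mirror the call site; a full implementation would
    use it to undo fp16 loss scaling when apex amp is active.
    """
    return torch.autograd.grad(loss, in_var, retain_graph=True)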
def train(args, train_dataset, model_class, config, tokenizer, labels, pad_token_label_id):
    """Train the model.

    :param args: argparse arguments
    :param train_dataset: the training Dataset
    :param model_class: the model class to instantiate
    :param config: the model config
    :param tokenizer: the loaded tokenizer
    :param labels: all labels, e.g. ['O', 'B-LOC', 'B-ORG', 'B-PER', 'B-MISC', 'I-PER', 'I-MISC', 'I-ORG', 'I-LOC', '<START>', '<STOP>']
    :param pad_token_label_id: the label id assigned to pad tokens, e.g. -100
    :return:
    """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter(os.path.join(args.output_dir, 'tfboard'))

    # Compute the effective batch size
    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    # Random sampling (distributed sampling when local_rank != -1)
    train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
    # DataLoader with the chosen sampler and batch size
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)

    # Compute the total number of optimization steps
    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

    model, optimizer, scheduler = initialize(args, model_class, config, t_total, 0)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size
        * args.gradient_accumulation_steps
        * (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # If resuming from a checkpoint, restore global_step, epochs_trained and
    # steps_trained_in_current_epoch; the model in model_name_or_path is expected
    # to have been loaded already.
    if os.path.exists(args.model_name_or_path):
        # Recover global_step from the checkpoint directory name
        global_step = int(args.model_name_or_path.split("-")[-1].split("/")[0])
        # Number of epochs already trained
        epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
        # Number of steps already trained within the current epoch
        steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)

        logger.info("  Continuing training from checkpoint, will skip to saved global_step")
        logger.info("  Continuing training from epoch %d", epochs_trained)
        logger.info("  Continuing training from global step %d", global_step)
        logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)

    tr_loss, logging_loss = 0.0, 0.0
    # Overall epoch progress bar
    train_iterator = trange(
        epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]
    )
    set_seed(args)  # Added here for reproducibility
    best_dev, best_test = [0, 0, 0], [0, 0, 0]

    if args.mt:
        teacher_model = model
    self_training_teacher_model = model

    for epoch in train_iterator:
        # Per-epoch progress bar
        epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):
            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            # Put the model in training mode
            model.train()
            # Move the batch to the GPU
            batch = tuple(t.to(args.device) for t in batch)

            # After self_training_begin_step, periodically refresh the pseudo labels
            if global_step >= args.self_training_begin_step:
                # Periodically promote a copy of the current model to be the new teacher
                delta = global_step - args.self_training_begin_step
                if delta % args.self_training_period == 0:
                    # Update condition met: copy the current model as the teacher model
                    self_training_teacher_model = copy.deepcopy(model)
                    # Put the teacher in eval mode
                    self_training_teacher_model.eval()

                    # After obtaining the new teacher, optionally re-initialize the student model
                    if args.self_training_reinit:
                        model, optimizer, scheduler = initialize(args, model_class, config, t_total, epoch)

                # Use the current teacher to refresh the labels
                inputs = {"input_ids": batch[0], "attention_mask": batch[1]}
                with torch.no_grad():
                    # outputs: (loss), logits, final_embedding, (hidden_states), (attentions)
                    outputs = self_training_teacher_model(**inputs)

                label_mask = None
                if args.self_training_label_mode == "hard":
                    # Hard labels: use the argmax index directly
                    pred_labels = torch.argmax(outputs[0], axis=2)
                    pred_labels, label_mask = multi_source_label_refine(
                        args, batch[5], batch[3], pred_labels, pad_token_label_id, pred_logits=outputs[0])
                elif args.self_training_label_mode == "soft":
                    # Soft labels
                    pred_labels = soft_frequency(logits=outputs[0], power=2)
                    # combined_labels carries the original (distant) labels; pred_labels and
                    # label_mask are derived from it according to self_training_hp_label
                    pred_labels, label_mask = multi_source_label_refine(
                        args=args,
                        hp_labels=batch[5],
                        combined_labels=batch[3],
                        pred_labels=pred_labels,
                        pad_token_label_id=pad_token_label_id)

                # Feed the teacher's pred_labels and label_mask to the student model
                inputs = {
                    "input_ids": batch[0],
                    "attention_mask": batch[1],
                    "labels": pred_labels,
                    "label_mask": label_mask,
                }
            else:
                inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}

            # For models other than distilbert, pass the segment ids under the token_type_ids key
            if args.model_type != "distilbert":
                inputs["token_type_ids"] = (
                    batch[2] if args.model_type in ["bert", "xlnet"] else None
                )  # XLM and RoBERTa don't use segment_ids

            # Forward pass through the student model
            outputs = model(**inputs)
            # loss, logits and final_embeds; final_embeds is the final embedding output of the
            # transformers RoBERTa model. Model outputs are always tuples in pytorch-transformers (see doc).
            loss, logits, final_embeds = outputs[0], outputs[1], outputs[2]
            mt_loss, vat_loss = 0, 0

            # Mean teacher training scheme
            if args.mt and global_step % args.mt_updatefreq == 0:
                update_step = global_step // args.mt_updatefreq
                if update_step == 1:
                    teacher_model = copy.deepcopy(model)
                    teacher_model.train(True)
                elif update_step < args.mt_rampup:
                    alpha = args.mt_alpha1
                else:
                    alpha = args.mt_alpha2
                mt_update(teacher_model.named_parameters(), model.named_parameters(),
                          args.mt_avg, alpha, update_step)

            if args.mt and update_step > 0:
                with torch.no_grad():
                    teacher_outputs = teacher_model(**inputs)
                    teacher_logits, teacher_final_embeds = teacher_outputs[1], teacher_outputs[2]

                _lambda = args.mt_lambda
                if args.mt_class != 'smart':
                    _lambda = args.mt_lambda * min(1, math.exp(-5 * (1 - update_step / args.mt_rampup) ** 2))

                if args.mt_loss_type == "embeds":
                    mt_loss = get_mt_loss(final_embeds, teacher_final_embeds.detach(), args.mt_class, _lambda)
                else:
                    mt_loss = get_mt_loss(logits, teacher_logits.detach(), args.mt_class, _lambda)

            # Virtual adversarial training
            if args.vat:
                word_embed = None
                if args.model_type in ["roberta", "camembert", "xlmroberta"]:
                    word_embed = model.roberta.get_input_embeddings()
                elif args.model_type == "bert":
                    word_embed = model.bert.get_input_embeddings()
                elif args.model_type == "distilbert":
                    word_embed = model.distilbert.get_input_embeddings()

                if not word_embed:
                    print("Model type not supported. Unable to retrieve word embeddings.")
                else:
                    embeds = word_embed(batch[0])
                    vat_embeds = (embeds.data.detach()
                                  + embeds.data.new(embeds.size()).normal_(0, 1) * 1e-5).detach()
                    vat_embeds.requires_grad_()

                    vat_inputs = {"inputs_embeds": vat_embeds, "attention_mask": batch[1], "labels": batch[3]}
                    if args.model_type != "distilbert":
                        vat_inputs["token_type_ids"] = (
                            batch[2] if args.model_type in ["bert", "xlnet"] else None
                        )  # XLM and RoBERTa don't use segment_ids

                    vat_outputs = model(**vat_inputs)
                    vat_logits, vat_final_embeds = vat_outputs[1], vat_outputs[2]

                    if args.vat_loss_type == "embeds":
                        vat_loss = get_mt_loss(vat_final_embeds, final_embeds.detach(), args.mt_class, 1)
                    else:
                        vat_loss = get_mt_loss(vat_logits, logits.detach(), args.mt_class, 1)

                    # Gradient of the VAT loss w.r.t. the perturbed embeddings
                    vat_embeds.grad = opt_grad(vat_loss, vat_embeds, optimizer)[0]
                    norm = vat_embeds.grad.norm()

                    if torch.isnan(norm) or torch.isinf(norm):
                        print("Hit nan gradient in embed vat")
                    else:
                        adv_direct = vat_embeds.grad / (vat_embeds.grad.abs().max(-1, keepdim=True)[0] + 1e-4)
                        vat_embeds = vat_embeds + args.vat_eps * adv_direct
                        vat_embeds = vat_embeds.detach()

                        vat_inputs = {"inputs_embeds": vat_embeds, "attention_mask": batch[1], "labels": batch[3]}
                        if args.model_type != "distilbert":
                            vat_inputs["token_type_ids"] = (
                                batch[2] if args.model_type in ["bert", "xlnet"] else None
                            )  # XLM and RoBERTa don't use segment_ids

                        vat_outputs = model(**vat_inputs)
                        vat_logits, vat_final_embeds = vat_outputs[1], vat_outputs[2]

                        if args.vat_loss_type == "embeds":
                            vat_loss = get_mt_loss(vat_final_embeds, final_embeds.detach(), args.mt_class, args.vat_lambda) \
                                + get_mt_loss(final_embeds, vat_final_embeds.detach(), args.mt_class, args.vat_lambda)
                        else:
                            vat_loss = get_mt_loss(vat_logits, logits.detach(), args.mt_class, args.vat_lambda) \
                                + get_mt_loss(logits, vat_logits.detach(), args.mt_class, args.vat_lambda)

            # mt and vat can be used together, separately, or not at all; combine the losses
            loss = loss + args.mt_beta * mt_loss + args.vat_beta * vat_loss

            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            # Mixed precision
            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            # Accumulate the total loss
            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                # Log metrics every logging_steps
                if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    if args.evaluate_during_training:
                        logger.info(
                            "***** Entropy loss: %.4f, mean teacher loss: %.4f; vat loss: %.4f *****",
                            loss - args.mt_beta * mt_loss - args.vat_beta * vat_loss,
                            args.mt_beta * mt_loss,
                            args.vat_beta * vat_loss,
                        )
                        results, _, best_dev, _ = evaluate(
                            args, model, tokenizer, labels, pad_token_label_id, best_dev, mode="dev",
                            prefix='dev [Step {}/{} | Epoch {}/{}]'.format(
                                global_step, t_total, epoch, args.num_train_epochs),
                            verbose=False,
                        )
                        for key, value in results.items():
                            tb_writer.add_scalar("eval_{}".format(key), value, global_step)

                        results, _, best_test, is_updated = evaluate(
                            args, model, tokenizer, labels, pad_token_label_id, best_test, mode="test",
                            prefix='test [Step {}/{} | Epoch {}/{}]'.format(
                                global_step, t_total, epoch, args.num_train_epochs),
                            verbose=False,
                        )
                        for key, value in results.items():
                            tb_writer.add_scalar("test_{}".format(key), value, global_step)

                        output_dirs = []
                        if args.local_rank in [-1, 0] and is_updated:
                            updated_self_training_teacher = True
                            output_dirs.append(os.path.join(args.output_dir, "checkpoint-best"))

                        if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
                            output_dirs.append(os.path.join(args.output_dir, "checkpoint-{}".format(global_step)))

                        if len(output_dirs) > 0:
                            for output_dir in output_dirs:
                                logger.info("Saving model checkpoint to %s", output_dir)
                                # Save a trained model, configuration and tokenizer using `save_pretrained()`.
                                # They can then be reloaded using `from_pretrained()`.
                                if not os.path.exists(output_dir):
                                    os.makedirs(output_dir)
                                model_to_save = (
                                    model.module if hasattr(model, "module") else model
                                )  # Take care of distributed/parallel training
                                model_to_save.save_pretrained(output_dir)
                                tokenizer.save_pretrained(output_dir)
                                torch.save(args, os.path.join(output_dir, "training_args.bin"))
                                torch.save(model.state_dict(), os.path.join(output_dir, "model.pt"))
                                torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                                torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                                logger.info("Saving optimizer and scheduler states to %s", output_dir)

                    tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar("loss", (tr_loss - logging_loss) / args.logging_steps, global_step)
                    logging_loss = tr_loss

            # Stop the inner loop once max_steps is reached
            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break

        # Stop the epoch loop once max_steps is reached
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return model, global_step, tr_loss / global_step, best_dev, best_test