def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path, discriminator_or_generator):
    # Initialise PyTorch model
    config = ElectraConfig.from_json_file(config_file)
    print("Building PyTorch model from configuration: {}".format(str(config)))

    if discriminator_or_generator == "discriminator":
        model = ElectraForPreTraining(config)
    elif discriminator_or_generator == "generator":
        model = ElectraForMaskedLM(config)
    else:
        raise ValueError("The discriminator_or_generator argument should be either 'discriminator' or 'generator'")

    # Load weights from tf checkpoint
    load_tf_weights_in_electra(
        model, config, tf_checkpoint_path, discriminator_or_generator=discriminator_or_generator
    )

    # Save pytorch-model
    print("Save PyTorch model to {}".format(pytorch_dump_path))
    torch.save(model.state_dict(), pytorch_dump_path)
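# A minimal CLI wrapper sketch for the conversion function above. The flag names below are
# assumptions modeled on the usual HuggingFace conversion scripts and may differ from the
# actual script this snippet was taken from.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--tf_checkpoint_path", type=str, required=True,
                        help="Path to the TensorFlow checkpoint.")
    parser.add_argument("--config_file", type=str, required=True,
                        help="JSON config file of the pre-trained ELECTRA model.")
    parser.add_argument("--pytorch_dump_path", type=str, required=True,
                        help="Where to write the converted PyTorch weights.")
    parser.add_argument("--discriminator_or_generator", type=str, required=True,
                        help="Export either the 'discriminator' or the 'generator'.")
    cli_args = parser.parse_args()
    convert_tf_checkpoint_to_pytorch(cli_args.tf_checkpoint_path,
                                     cli_args.config_file,
                                     cli_args.pytorch_dump_path,
                                     cli_args.discriminator_or_generator)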
def load_electra_model(self):
    parser = argparse.ArgumentParser()
    args = parser.parse_args()
    args.output_encoded_layers = True
    args.output_attention_layers = True
    args.output_att_score = True
    args.output_att_sum = True
    self.args = args
    # Parse the config file; the teacher and the student model share the same vocab.
    # Here we use the teacher's config and the fine-tuned teacher model; they can be
    # swapped for the student config and the distilled student model.
    # student config: config/chinese_bert_config_L4t.json
    # distilled student model: distil_model/gs8316.pkl
    bert_config_file_S = self.model_conf
    tuned_checkpoint_S = self.model_file
    # Load the student config; the max sequence length must fit within it
    bert_config_S = ElectraConfig.from_json_file(bert_config_file_S)
    bert_config_S.num_labels = self.num_labels
    # Load the tokenizer
    self.predict_tokenizer = BertTokenizer(vocab_file=self.vocab_file)
    # Load the model
    self.predict_model = ElectraSPC(bert_config_S)
    assert os.path.exists(tuned_checkpoint_S), "Model file does not exist, please check"
    state_dict_S = torch.load(tuned_checkpoint_S, map_location=self.device)
    self.predict_model.load_state_dict(state_dict_S)
    if self.verbose:
        print("Model loaded")
    logger.info(f"Prediction model {tuned_checkpoint_S} loaded")
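# Hypothetical usage sketch (not part of the original class) showing how the tokenizer and
# model loaded above might be used for a single prediction. The method name, the tokenizer
# call style, and the assumption that ElectraSPC returns classification logits are unverified.
def predict_single(self, text, max_seq_length=128):
    inputs = self.predict_tokenizer(text,
                                    max_length=max_seq_length,
                                    truncation=True,
                                    padding='max_length',
                                    return_tensors='pt')
    self.predict_model.eval()
    with torch.no_grad():
        # ElectraSPC is assumed to accept input_ids/attention_mask and to return logits;
        # the real forward signature may differ (e.g. it could return a tuple).
        logits = self.predict_model(inputs['input_ids'].to(self.device),
                                    inputs['attention_mask'].to(self.device))
    return torch.argmax(logits, dim=-1).item()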
def main():
    # parse arguments
    config.parse()
    args = config.args
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
        datefmt='%Y/%m/%d %H:%M:%S',
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger = logging.getLogger("Main")

    # arguments check
    device, n_gpu = args_check(logger, args)
    if args.local_rank in [-1, 0]:
        os.makedirs(args.output_dir, exist_ok=True)
    if args.local_rank != -1:
        logger.warning(
            f"Process rank: {torch.distributed.get_rank()}, device : {args.device}, n_gpu : {args.n_gpu}, distributed training : {bool(args.local_rank != -1)}"
        )
    for k, v in vars(args).items():
        logger.info(f"{k}:{v}")

    # set seeds
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed_all(args.random_seed)
    np.random.seed(args.random_seed)
    random.seed(args.random_seed)

    forward_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps)
    args.forward_batch_size = forward_batch_size

    # load bert config
    bert_config_S = ElectraConfig.from_json_file(args.bert_config_file_S)
    bert_config_S.output_hidden_states = (args.output_encoded_layers == 'true')
    bert_config_S.num_labels = len(label2id_dict)
    assert args.max_seq_length <= bert_config_S.max_position_embeddings

    # read data
    train_examples = None
    train_dataset = None
    eval_examples = None
    eval_dataset = None
    num_train_steps = None
    tokenizer = BertTokenizer(vocab_file=args.vocab_file, do_lower_case=args.do_lower_case)
    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier()
    if args.do_train:
        train_examples, train_dataset = read_features(
            args.train_file, tokenizer=tokenizer, max_seq_length=args.max_seq_length)
    if args.do_predict:
        eval_examples, eval_dataset = read_features(
            args.predict_file, tokenizer=tokenizer, max_seq_length=args.max_seq_length)
    if args.local_rank == 0:
        torch.distributed.barrier()

    # Build Model and load checkpoint
    model_S = ElectraForTokenClassification(bert_config_S)
    # Load student
    if args.load_model_type == 'bert':
        assert args.init_checkpoint_S is not None
        state_dict_S = torch.load(args.init_checkpoint_S, map_location='cpu')
        # state_weight = {k[5:]: v for k, v in state_dict_S.items() if k.startswith('bert.')}
        # missing_keys, _ = model_S.bert.load_state_dict(state_weight, strict=False)
        missing_keys, unexpected_keys = model_S.load_state_dict(state_dict_S, strict=False)
        logger.info(f"missing keys:{missing_keys}")
        logger.info(f"unexpected keys:{unexpected_keys}")
    elif args.load_model_type == 'all':
        assert args.tuned_checkpoint_S is not None
        state_dict_S = torch.load(args.tuned_checkpoint_S, map_location='cpu')
        model_S.load_state_dict(state_dict_S)
    else:
        logger.info("Model is randomly initialized.")
    model_S.to(device)

    if args.do_train:
        # parameters
        if args.lr_decay is not None:
            outputs_params = list(model_S.classifier.named_parameters())
            outputs_params = divide_parameters(outputs_params, lr=args.learning_rate)

            electra_params = []
            n_layers = len(model_S.electra.encoder.layer)
            assert n_layers == 12
            for i, n in enumerate(reversed(range(n_layers))):
                encoder_params = list(model_S.electra.encoder.layer[n].named_parameters())
                lr = args.learning_rate * args.lr_decay ** (i + 1)
                electra_params += divide_parameters(encoder_params, lr=lr)
                logger.info(f"{i},{n},{lr}")

            embed_params = [(name, value) for name, value in model_S.electra.named_parameters()
                            if 'embedding' in name]
            logger.info(f"{[name for name, value in embed_params]}")
            lr = args.learning_rate * args.lr_decay ** (n_layers + 1)
            electra_params += divide_parameters(embed_params, lr=lr)
            logger.info(f"embed lr:{lr}")

            all_trainable_params = outputs_params + electra_params
            assert sum(map(lambda x: len(x['params']), all_trainable_params)) == len(list(model_S.parameters())), \
                (sum(map(lambda x: len(x['params']), all_trainable_params)), len(list(model_S.parameters())))
        else:
            params = list(model_S.named_parameters())
            all_trainable_params = divide_parameters(params, lr=args.learning_rate)
        logger.info("Length of all_trainable_params: %d", len(all_trainable_params))

        if args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset)
        else:
            train_sampler = DistributedSampler(train_dataset)
        train_dataloader = DataLoader(train_dataset,
                                      sampler=train_sampler,
                                      batch_size=args.forward_batch_size,
                                      drop_last=True)
        num_train_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

        optimizer = AdamW(all_trainable_params, lr=args.learning_rate, correct_bias=False)
        if args.official_schedule == 'const':
            scheduler_class = get_constant_schedule_with_warmup
            scheduler_args = {'num_warmup_steps': int(args.warmup_proportion * num_train_steps)}
            # scheduler = get_constant_schedule_with_warmup(optimizer, num_warmup_steps=int(args.warmup_proportion * num_train_steps))
        elif args.official_schedule == 'linear':
            scheduler_class = get_linear_schedule_with_warmup
            scheduler_args = {'num_warmup_steps': int(args.warmup_proportion * num_train_steps),
                              'num_training_steps': num_train_steps}
            # scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=int(args.warmup_proportion * num_train_steps), num_training_steps=num_train_steps)
        elif args.official_schedule == 'const_nowarmup':
            scheduler_class = get_constant_schedule
            scheduler_args = {}
        else:
            raise NotImplementedError

        logger.warning("***** Running training *****")
        logger.warning("local_rank %d Num orig examples = %d", args.local_rank, len(train_examples))
        logger.warning("local_rank %d Num split examples = %d", args.local_rank, len(train_dataset))
        logger.warning("local_rank %d Forward batch size = %d", args.local_rank, forward_batch_size)
        logger.warning("local_rank %d Num backward steps = %d", args.local_rank, num_train_steps)

        ########### TRAINING ###########
        train_config = TrainingConfig(
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            ckpt_frequency=args.ckpt_frequency,
            log_dir=args.output_dir,
            output_dir=args.output_dir,
            device=args.device,
            fp16=args.fp16,
            local_rank=args.local_rank)
        logger.info(f"{train_config}")

        distiller = BasicTrainer(
            train_config=train_config,
            model=model_S,
            adaptor=ElectraForTokenClassificationAdaptorTraining)

        # evaluate the model in a single process in ddp_predict
        callback_func = partial(ddp_predict,
                                eval_examples=eval_examples,
                                eval_dataset=eval_dataset,
                                args=args)
        with distiller:
            distiller.train(optimizer,
                            scheduler_class=scheduler_class,
                            scheduler_args=scheduler_args,
                            max_grad_norm=1.0,
                            dataloader=train_dataloader,
                            num_epochs=args.num_train_epochs,
                            callback=callback_func)

    if not args.do_train and args.do_predict:
        res = ddp_predict(model_S, eval_examples, eval_dataset, step=0, args=args)
        print(res)
def main():
    # parse arguments
    config.parse()
    args = config.args
    for k, v in vars(args).items():
        logger.info(f"{k}:{v}")

    # set seeds
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed_all(args.random_seed)
    np.random.seed(args.random_seed)
    random.seed(args.random_seed)

    # arguments check
    device, n_gpu = args_check(args)
    os.makedirs(args.output_dir, exist_ok=True)
    forward_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps)
    args.forward_batch_size = forward_batch_size

    # prepare the task
    processor = processors[args.task_name]()
    args.output_mode = output_modes[args.task_name]
    # e.g. MNLI: ['contradiction', 'entailment', 'neutral']
    label_list = processor.get_labels()
    num_labels = len(label_list)

    # Student configuration
    if args.model_architecture == "electra":
        # Load the config with ElectraConfig from the transformers package
        bert_config_S = ElectraConfig.from_json_file(args.bert_config_file_S)
        # (args.output_encoded_layers == 'true') --> True, output hidden states by default
        bert_config_S.output_hidden_states = (args.output_encoded_layers == 'true')
        bert_config_S.output_attentions = (args.output_attention_layers == 'true')
        # num_labels: number of classes
        bert_config_S.num_labels = num_labels
        assert args.max_seq_length <= bert_config_S.max_position_embeddings
    elif args.model_architecture == "albert":
        # Load the config with AlbertConfig from the transformers package
        bert_config_S = AlbertConfig.from_json_file(args.bert_config_file_S)
        # (args.output_encoded_layers == 'true') --> True, output hidden states by default
        bert_config_S.output_hidden_states = (args.output_encoded_layers == 'true')
        bert_config_S.output_attentions = (args.output_attention_layers == 'true')
        # num_labels: number of classes
        bert_config_S.num_labels = num_labels
        assert args.max_seq_length <= bert_config_S.max_position_embeddings
    else:
        bert_config_S = BertConfig.from_json_file(args.bert_config_file_S)
        assert args.max_seq_length <= bert_config_S.max_position_embeddings

    # read data
    train_dataset = None
    eval_datasets = None
    num_train_steps = None
    # both electra and bert use the bert tokenization scheme
    tokenizer = BertTokenizer(vocab_file=args.vocab_file, do_lower_case=args.do_lower_case)

    # load the datasets and compute the number of training steps
    if args.do_train:
        train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False)
        if args.aux_task_name:
            aux_train_dataset = load_and_cache_examples(args, args.aux_task_name, tokenizer,
                                                        evaluate=False, is_aux=True)
            train_dataset = torch.utils.data.ConcatDataset([train_dataset, aux_train_dataset])
        num_train_steps = int(len(train_dataset) / args.train_batch_size) * args.num_train_epochs
        logger.info("Training dataset loaded")
    if args.do_predict:
        eval_datasets = []
        eval_task_names = ("mnli", "mnli-mm") if args.task_name == "mnli" else (args.task_name,)
        for eval_task in eval_task_names:
            eval_datasets.append(load_and_cache_examples(args, eval_task, tokenizer, evaluate=True))
        logger.info("Prediction datasets loaded")

    # Student model
    if args.model_architecture == "electra":
        # Only the student model is used; this effectively trains a single (teacher) model
        model_S = ElectraSPC(bert_config_S)
    elif args.model_architecture == "albert":
        model_S = AlbertSPC(bert_config_S)
    else:
        # Only the student model is used; this effectively trains a single (teacher) model
        model_S = BertSPCSimple(bert_config_S, num_labels=num_labels, args=args)

    # Initialize the student model's parameters; the student model is used for prediction
    if args.load_model_type == 'bert' and args.model_architecture not in ["electra", "albert"]:
        assert args.init_checkpoint_S is not None
        state_dict_S = torch.load(args.init_checkpoint_S, map_location='cpu')
        if args.only_load_embedding:
            state_weight = {k[5:]: v for k, v in state_dict_S.items() if k.startswith('bert.embeddings')}
            missing_keys, _ = model_S.bert.load_state_dict(state_weight, strict=False)
            logger.info(f"Missing keys {list(missing_keys)}")
        else:
            state_weight = {k[5:]: v for k, v in state_dict_S.items() if k.startswith('bert.')}
            missing_keys, _ = model_S.bert.load_state_dict(state_weight, strict=False)
            print(f"missing_keys, note the missing parameters: {missing_keys}")
        logger.info("Model loaded")
    elif args.load_model_type == 'all':
        assert args.tuned_checkpoint_S is not None
        state_dict_S = torch.load(args.tuned_checkpoint_S, map_location='cpu')
        model_S.load_state_dict(state_dict_S)
        logger.info("Model loaded")
    elif args.model_architecture in ["electra", "albert"]:
        assert args.init_checkpoint_S is not None
        state_dict_S = torch.load(args.init_checkpoint_S, map_location='cpu')
        missing_keys, unexpected_keys = model_S.load_state_dict(state_dict_S, strict=False)
        logger.info(f"missing keys:{missing_keys}")
        logger.info(f"unexpected keys:{unexpected_keys}")
    else:
        logger.info("Model is randomly initialized.")

    # move the model to device
    model_S.to(device)
    if args.local_rank != -1 or n_gpu > 1:
        if args.local_rank != -1:
            raise NotImplementedError
        elif n_gpu > 1:
            model_S = torch.nn.DataParallel(model_S)  # , output_device=n_gpu-1)

    if args.do_train:
        # parameters: params is the list of all named parameters of the model
        params = list(model_S.named_parameters())
        all_trainable_params = divide_parameters(params, lr=args.learning_rate)
        logger.info("Number of trainable parameter groups (decay_group and no_decay_group): %d",
                    len(all_trainable_params))

        # optimizer setup
        optimizer = BERTAdam(all_trainable_params, lr=args.learning_rate,
                             warmup=args.warmup_proportion, t_total=num_train_steps,
                             schedule=args.schedule,
                             s_opt1=args.s_opt1, s_opt2=args.s_opt2, s_opt3=args.s_opt3)

        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_dataset))
        logger.info("  Forward batch size = %d", forward_batch_size)
        logger.info("  Num training steps = %d", num_train_steps)

        ########### TRAINING CONFIG ###########
        train_config = TrainingConfig(
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            ckpt_frequency=args.ckpt_frequency,
            log_dir=args.output_dir,
            output_dir=args.output_dir,
            device=args.device)

        # Initialize the trainer for supervised training rather than distillation;
        # it can train model_S into a teacher model
        distiller = BasicTrainer(train_config=train_config,
                                 model=model_S,
                                 adaptor=BertForGLUESimpleAdaptorTraining)

        if args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset)
        else:
            raise NotImplementedError
        # training dataloader
        train_dataloader = DataLoader(train_dataset, sampler=train_sampler,
                                      batch_size=args.forward_batch_size, drop_last=True)
        # callback function run on the eval datasets
        callback_func = partial(predict, eval_datasets=eval_datasets, args=args)
        with distiller:
            # start training
            distiller.train(optimizer, scheduler=None, dataloader=train_dataloader,
                            num_epochs=args.num_train_epochs, callback=callback_func)

    if not args.do_train and args.do_predict:
        res = predict(model_S, eval_datasets, step=0, args=args)
        print(res)
logging.info('Init pre-train model...')
if params.pre_model_type == 'NEZHA':
    bert_config = NEZHAConfig.from_json_file(os.path.join(params.bert_model_dir, 'bert_config.json'))
    model = BertForTokenClassification(config=bert_config, params=params)
    # NEZHA init
    torch_init_model(model, os.path.join(params.bert_model_dir, 'pytorch_model.bin'))
elif params.pre_model_type == 'RoBERTa':
    bert_config = BertConfig.from_json_file(os.path.join(params.bert_model_dir, 'bert_config.json'))
    model = BertForTokenClassification.from_pretrained(
        config=bert_config,
        pretrained_model_name_or_path=params.bert_model_dir,
        params=params)
elif params.pre_model_type == 'ELECTRA':
    bert_config = ElectraConfig.from_json_file(os.path.join(params.bert_model_dir, 'bert_config.json'))
    model = ElectraForTokenClassification.from_pretrained(
        config=bert_config,
        pretrained_model_name_or_path=params.bert_model_dir,
        params=params)
else:
    raise ValueError('Pre-train Model type must be NEZHA or ELECTRA or RoBERTa!')
logging.info('-done')

# Train and evaluate the model
logging.info("Starting training for {} epoch(s)".format(args.epoch_num))
train_and_evaluate(model, params, args.restore_file)
if 'albert' in args.model:
    model_type = 'albert'
    tokenizer = AlbertTokenizer(vocab_file=args.tokenizer)
    config = AlbertConfig.from_json_file(args.config)
    model = AlbertModel.from_pretrained(pretrained_model_name_or_path=None,
                                        config=config,
                                        state_dict=torch.load(args.model))
elif 'bert' in args.model:
    model_type = 'bert'
    tokenizer = BertTokenizer(vocab_file=args.tokenizer)
    config = BertConfig.from_json_file(args.config)
    model = BertModel.from_pretrained(pretrained_model_name_or_path=None,
                                      config=config,
                                      state_dict=torch.load(args.model))
elif 'electra' in args.model:
    model_type = 'electra'
    tokenizer = ElectraTokenizer(vocab_file=args.tokenizer)
    config = ElectraConfig.from_json_file(args.config)
    model = ElectraModel.from_pretrained(pretrained_model_name_or_path=None,
                                         config=config,
                                         state_dict=torch.load(args.model))
else:
    raise NotImplementedError("The model is currently not supported")


def process_line(line):
    data = json.loads(line)
    tokens = data['text'].split(' ')
    labels = data['targets']
    return tokens, labels


def retokenize(tokens_labels):
    tokens, labels = tokens_labels
def main():
    # parse arguments
    config.parse()
    args = config.args
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
        datefmt='%Y/%m/%d %H:%M:%S',
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger = logging.getLogger("Main")

    # arguments check
    device, n_gpu = args_check(logger, args)
    if args.local_rank in [-1, 0]:
        os.makedirs(args.output_dir, exist_ok=True)
    if args.local_rank != -1:
        logger.warning(
            f"Process rank: {torch.distributed.get_rank()}, device : {args.device}, n_gpu : {args.n_gpu}, distributed training : {bool(args.local_rank != -1)}"
        )
    for k, v in vars(args).items():
        logger.info(f"{k}:{v}")

    # set seeds
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed_all(args.random_seed)
    np.random.seed(args.random_seed)
    random.seed(args.random_seed)

    forward_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps)
    args.forward_batch_size = forward_batch_size

    # Load the config with ElectraConfig from the transformers package
    bert_config_S = ElectraConfig.from_json_file(args.bert_config_file_S)
    # (args.output_encoded_layers == 'true') --> True, output hidden states by default
    bert_config_S.output_hidden_states = (args.output_encoded_layers == 'true')
    # num_labels: number of classes
    bert_config_S.num_labels = len(label2id_dict)
    assert args.max_seq_length <= bert_config_S.max_position_embeddings

    # read data
    train_examples = None
    train_dataset = None
    eval_examples = None
    eval_dataset = None
    num_train_steps = None
    # load the tokenizer
    tokenizer = BertTokenizer(vocab_file=args.vocab_file, do_lower_case=args.do_lower_case)
    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier()
    if args.do_train:
        # returns all examples and the dataset built from them;
        # dataset format: [all_token_ids, all_input_mask, all_label_ids]
        print("Loading training data")
        train_examples, train_dataset = read_features(
            args.train_file, tokenizer=tokenizer, max_seq_length=args.max_seq_length)
    if args.do_predict:
        print("Loading test data")
        eval_examples, eval_dataset = read_features(
            args.predict_file, tokenizer=tokenizer, max_seq_length=args.max_seq_length)
    if args.local_rank == 0:
        torch.distributed.barrier()

    # Build Model and load checkpoint
    model_S = ElectraForTokenClassification(bert_config_S)
    # load the student model's parameters; the default type is bert
    if args.load_model_type == 'bert':
        assert args.init_checkpoint_S is not None
        state_dict_S = torch.load(args.init_checkpoint_S, map_location='cpu')
        # state_weight = {k[5:]: v for k, v in state_dict_S.items() if k.startswith('bert.')}
        # missing_keys, _ = model_S.bert.load_state_dict(state_weight, strict=False)
        missing_keys, unexpected_keys = model_S.load_state_dict(state_dict_S, strict=False)
        logger.info(f"missing keys:{missing_keys}")
        logger.info(f"unexpected keys:{unexpected_keys}")
    elif args.load_model_type == 'all':
        assert args.tuned_checkpoint_S is not None
        state_dict_S = torch.load(args.tuned_checkpoint_S, map_location='cpu')
        model_S.load_state_dict(state_dict_S)
    else:
        logger.info("Model is randomly initialized.")
    # move the model to device
    model_S.to(device)

    if args.do_train:
        # parameters
        if args.lr_decay is not None:
            # classifier layer parameters: weight, bias
            outputs_params = list(model_S.classifier.named_parameters())
            # split the parameters into decayed and non-decayed groups
            outputs_params = divide_parameters(outputs_params, lr=args.learning_rate)

            electra_params = []
            # e.g. 12: the encoder has 12 layers
            n_layers = len(model_S.electra.encoder.layer)
            assert n_layers == 12
            for i, n in enumerate(reversed(range(n_layers))):
                encoder_params = list(model_S.electra.encoder.layer[n].named_parameters())
                lr = args.learning_rate * args.lr_decay ** (i + 1)
                electra_params += divide_parameters(encoder_params, lr=lr)
                logger.info(f"{i}: learning rate of layer {n} is {lr}")

            embed_params = [(name, value) for name, value in model_S.electra.named_parameters()
                            if 'embedding' in name]
            logger.info(f"{[name for name, value in embed_params]}")
            lr = args.learning_rate * args.lr_decay ** (n_layers + 1)
            electra_params += divide_parameters(embed_params, lr=lr)
            logger.info(f"embedding layer lr: {lr}")

            all_trainable_params = outputs_params + electra_params
            assert sum(map(lambda x: len(x['params']), all_trainable_params)) == len(list(model_S.parameters())), \
                (sum(map(lambda x: len(x['params']), all_trainable_params)), len(list(model_S.parameters())))
        else:
            params = list(model_S.named_parameters())
            all_trainable_params = divide_parameters(params, lr=args.learning_rate)
        logger.info("Number of trainable parameter groups in all_trainable_params: %d",
                    len(all_trainable_params))

        if args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset)
        else:
            train_sampler = DistributedSampler(train_dataset)
        # build the dataloader
        train_dataloader = DataLoader(train_dataset,
                                      sampler=train_sampler,
                                      batch_size=args.forward_batch_size,
                                      drop_last=True)
        # compute the number of steps from the number of epochs
        num_train_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

        optimizer = AdamW(all_trainable_params, lr=args.learning_rate, correct_bias=False)
        if args.official_schedule == 'const':
            # constant learning rate with warmup
            scheduler_class = get_constant_schedule_with_warmup
            scheduler_args = {'num_warmup_steps': int(args.warmup_proportion * num_train_steps)}
            # scheduler = get_constant_schedule_with_warmup(optimizer, num_warmup_steps=int(args.warmup_proportion * num_train_steps))
        elif args.official_schedule == 'linear':
            # linear learning-rate schedule
            scheduler_class = get_linear_schedule_with_warmup
            # warmup steps: warmup_proportion (e.g. 10%) of the total training steps
            scheduler_args = {'num_warmup_steps': int(args.warmup_proportion * num_train_steps),
                              'num_training_steps': num_train_steps}
            # scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=int(args.warmup_proportion * num_train_steps), num_training_steps=num_train_steps)
        elif args.official_schedule == 'const_nowarmup':
            scheduler_class = get_constant_schedule
            scheduler_args = {}
        else:
            raise NotImplementedError

        logger.warning("***** Running training *****")
        logger.warning("local_rank %d Num orig examples = %d", args.local_rank, len(train_examples))
        logger.warning("local_rank %d Num split examples = %d", args.local_rank, len(train_dataset))
        logger.warning("local_rank %d Forward batch size = %d", args.local_rank, forward_batch_size)
        logger.warning("local_rank %d Num training steps = %d", args.local_rank, num_train_steps)

        ########### TRAINING ###########
        train_config = TrainingConfig(
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            ckpt_frequency=args.ckpt_frequency,  # checkpoint save frequency
            log_dir=args.output_dir,
            output_dir=args.output_dir,
            device=args.device,
            fp16=args.fp16,
            local_rank=args.local_rank)
        logger.info("Training configuration:")
        logger.info(f"{train_config}")

        # initialize the trainer
        distiller = BasicTrainer(
            train_config=train_config,
            model=model_S,
            adaptor=ElectraForTokenClassificationAdaptorTraining)

        # initialize the callback function; ddp_predict is used for evaluation
        callback_func = partial(ddp_predict,
                                eval_examples=eval_examples,
                                eval_dataset=eval_dataset,
                                args=args)
        with distiller:
            distiller.train(optimizer,
                            scheduler_class=scheduler_class,
                            scheduler_args=scheduler_args,
                            max_grad_norm=1.0,
                            dataloader=train_dataloader,
                            num_epochs=args.num_train_epochs,
                            callback=callback_func)

    if not args.do_train and args.do_predict:
        res = ddp_predict(model_S, eval_examples, eval_dataset, step=0, args=args)
        print(res)