def load_model(self):
    parser = argparse.ArgumentParser()
    # Parse an empty list so stray CLI arguments can't break in-process loading
    args = parser.parse_args([])
    args.output_encoded_layers = True
    args.output_attention_layers = True
    args.output_att_score = True
    args.output_att_sum = True
    self.args = args
    # Parse the config files; the teacher and student models share the same vocab
    self.vocab_file = "bert_model/vocab.txt"
    # This uses the teacher's config and the fine-tuned teacher model; they can be
    # swapped for the student's config and the distilled student model:
    #   student config:          config/chinese_bert_config_L4t.json
    #   distilled student model: distil_model/gs8316.pkl
    self.bert_config_file_S = "bert_model/config.json"
    self.tuned_checkpoint_S = "trained_teacher_model/gs3024.pkl"
    self.max_seq_length = 70
    # Batch size used for prediction
    self.predict_batch_size = 64
    # Load the student config and check that our max sequence length fits within it
    bert_config_S = BertConfig.from_json_file(self.bert_config_file_S)
    assert self.max_seq_length <= bert_config_S.max_position_embeddings
    # Load the tokenizer
    tokenizer = BertTokenizer(vocab_file=self.vocab_file)
    # Load the model
    model_S = BertSPCSimple(bert_config_S, num_labels=self.num_labels, args=self.args)
    state_dict_S = torch.load(self.tuned_checkpoint_S, map_location=self.device)
    model_S.load_state_dict(state_dict_S)
    if self.verbose:
        print("Model loaded")
    return tokenizer, model_S
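For context, a minimal usage sketch of the loader above. The surrounding predictor object, its `num_labels`/`device`/`verbose` attributes, and the exact forward signature of `BertSPCSimple` are assumptions here, and the tokenization calls assume the older pytorch_pretrained_bert-style `BertTokenizer` API:

# Hypothetical calling context; `predictor` is assumed to be an instance of the
# class that defines load_model(), num_labels, device, verbose, max_seq_length.
tokenizer, model = predictor.load_model()
model.to(predictor.device)
model.eval()

text = "这个口红颜色很好看"  # example input sentence
tokens = ["[CLS]"] + tokenizer.tokenize(text)[:predictor.max_seq_length - 2] + ["[SEP]"]
input_ids = torch.tensor([tokenizer.convert_tokens_to_ids(tokens)],
                         dtype=torch.long, device=predictor.device)
with torch.no_grad():
    logits = model(input_ids)  # exact output shape depends on BertSPCSimple's forward()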
def load_macbert_model(self):
    parser = argparse.ArgumentParser()
    # Parse an empty list so stray CLI arguments can't break in-process loading
    args = parser.parse_args([])
    args.output_encoded_layers = True
    args.output_attention_layers = True
    args.output_att_score = True
    args.output_att_sum = True
    self.args = args
    # Parse the config files; the teacher and student models share the same vocab
    self.vocab_file = "mac_bert_model/vocab.txt"
    # This uses the teacher's config and the fine-tuned teacher model; they can be
    # swapped for the student's config and the distilled student model:
    #   student config:          config/chinese_bert_config_L4t.json
    #   distilled student model: distil_model/gs8316.pkl
    self.bert_config_file_S = "mac_bert_model/config.json"
    self.tuned_checkpoint_S = "trained_teacher_model/macbert_2290_cosmetics_weibo.pkl"
    # self.tuned_checkpoint_S = "trained_teacher_model/macbert_894_cosmetics.pkl"
    # self.tuned_checkpoint_S = "trained_teacher_model/macbert_teacher_max75len_5000.pkl"
    # Load the student config
    bert_config_S = BertConfig.from_json_file(self.bert_config_file_S)
    # Load the tokenizer
    tokenizer = BertTokenizer(vocab_file=self.vocab_file)
    # Load the model
    model_S = BertSPCSimple(bert_config_S, num_labels=self.num_labels, args=self.args)
    state_dict_S = torch.load(self.tuned_checkpoint_S, map_location=self.device)
    model_S.load_state_dict(state_dict_S)
    if self.verbose:
        print("Model loaded")
    self.predict_tokenizer = tokenizer
    self.predict_model = model_S
    logger.info("MacBERT prediction model loaded")
def load_train_model(self):
    """Initialize the model used for training."""
    parser = argparse.ArgumentParser()
    # Parse an empty list so stray CLI arguments can't break in-process loading
    args = parser.parse_args([])
    args.output_encoded_layers = True
    args.output_attention_layers = True
    args.output_att_score = True
    args.output_att_sum = True
    # Learning rate and warmup proportion
    self.learning_rate = 2e-05
    self.warmup_proportion = 0.1
    self.num_train_epochs = 1
    # Learning-rate scheduler to use
    self.schedule = 'slanted_triangular'
    self.s_opt1 = 30.0
    self.s_opt2 = 0.0
    self.s_opt3 = 1.0
    self.weight_decay_rate = 0.01
    # Save a checkpoint every this many epochs
    self.ckpt_frequency = 1
    # Where models and logs are saved
    self.output_dir = "output_root_dir/train_api"
    # Gradient accumulation steps
    self.gradient_accumulation_steps = 1
    self.args = args
    # The teacher and student models share the same vocab
    self.vocab_file = "mac_bert_model/vocab.txt"
    self.bert_config_file_S = "mac_bert_model/config.json"
    self.tuned_checkpoint_S = "mac_bert_model/pytorch_model.bin"
    # Load the student config
    bert_config_S = BertConfig.from_json_file(self.bert_config_file_S)
    # Load the tokenizer
    tokenizer = BertTokenizer(vocab_file=self.vocab_file)
    # Build the model and load only the 'bert.'-prefixed weights from the checkpoint
    model_S = BertSPCSimple(bert_config_S, num_labels=self.num_labels, args=self.args)
    state_dict_S = torch.load(self.tuned_checkpoint_S, map_location=self.device)
    state_weight = {k[5:]: v for k, v in state_dict_S.items() if k.startswith('bert.')}
    missing_keys, _ = model_S.bert.load_state_dict(state_weight, strict=False)
    # Verify that no parameters are missing
    assert len(missing_keys) == 0
    self.train_tokenizer = tokenizer
    self.train_model = model_S
    logger.info(f"Training model {self.tuned_checkpoint_S} loaded")
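The `k[5:]` slice above strips the leading `'bert.'` prefix so checkpoint keys line up with the parameter names `model_S.bert` expects. A toy illustration with hypothetical key names:

# Pretrained checkpoints store weights under a top-level 'bert.' prefix;
# model_S.bert expects names without it, and non-'bert.' keys are dropped.
toy_state_dict = {
    "bert.embeddings.word_embeddings.weight": "...",
    "bert.encoder.layer.0.attention.self.query.weight": "...",
    "classifier.weight": "...",  # filtered out by the startswith check
}
toy_weight = {k[5:]: v for k, v in toy_state_dict.items() if k.startswith('bert.')}
# -> {'embeddings.word_embeddings.weight': '...',
#     'encoder.layer.0.attention.self.query.weight': '...'}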
def main():
    # Parse arguments
    config.parse()
    args = config.args
    for k, v in vars(args).items():
        logger.info(f"{k}:{v}")
    # Check arguments and decide which device to use
    device, n_gpu = args_check(args)
    os.makedirs(args.output_dir, exist_ok=True)
    forward_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps)
    args.forward_batch_size = forward_batch_size
    # Load the student config and check the max sequence length fits within it
    bert_config_S = BertConfig.from_json_file(args.bert_config_file_S)
    assert args.max_seq_length <= bert_config_S.max_position_embeddings
    # Prepare the task
    processor = processors[args.task_name]()
    args.output_mode = output_modes[args.task_name]
    # All labels
    label_list = processor.get_labels()
    num_labels = len(label_list)
    # Read the data
    tokenizer = BertTokenizer(vocab_file=args.vocab_file, do_lower_case=args.do_lower_case)
    eval_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=True)
    logger.info("Evaluation dataset loaded")
    model_S = BertSPCSimple(bert_config_S, num_labels=num_labels, args=args)
    # Load the student model
    assert args.tuned_checkpoint_S is not None
    state_dict_S = torch.load(args.tuned_checkpoint_S, map_location='cpu')
    model_S.load_state_dict(state_dict_S)
    logger.info("Student model loaded")
    # Run prediction
    res = predict(model_S, eval_dataset, args=args)
    print(res)
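A note on the batch-size arithmetic these scripts share: the dataloader batches at `forward_batch_size`, and gradients accumulate over `gradient_accumulation_steps` forward passes before each optimizer step, so the effective (backward) batch size stays `train_batch_size`:

# Illustrative numbers, not values from this repo's configs:
train_batch_size = 64
gradient_accumulation_steps = 4
forward_batch_size = int(train_batch_size / gradient_accumulation_steps)  # 16
# One optimizer update is taken every 4 forward passes, i.e. per 64 examples.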
def main():
    # Parse arguments
    config.parse()
    args = config.args
    for k, v in vars(args).items():
        logger.info(f"{k}:{v}")
    # Set seeds
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed_all(args.random_seed)
    np.random.seed(args.random_seed)
    random.seed(args.random_seed)
    # Check arguments
    device, n_gpu = args_check(args)
    os.makedirs(args.output_dir, exist_ok=True)
    forward_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps)
    args.forward_batch_size = forward_batch_size
    # Load the BERT config
    bert_config_S = BertConfig.from_json_file(args.bert_config_file_S)
    assert args.max_seq_length <= bert_config_S.max_position_embeddings
    # Prepare the GLUE task
    processor = processors[args.task_name]()
    args.output_mode = output_modes[args.task_name]
    # e.g. MNLI: ['contradiction', 'entailment', 'neutral']
    label_list = processor.get_labels()
    num_labels = len(label_list)
    # Read the data
    train_dataset = None
    eval_datasets = None
    num_train_steps = None
    tokenizer = BertTokenizer(vocab_file=args.vocab_file, do_lower_case=args.do_lower_case)
    # Load the datasets and compute the number of training steps
    if args.do_train:
        train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False)
        if args.aux_task_name:
            aux_train_dataset = load_and_cache_examples(args, args.aux_task_name, tokenizer,
                                                        evaluate=False, is_aux=True)
            train_dataset = torch.utils.data.ConcatDataset([train_dataset, aux_train_dataset])
        num_train_steps = int(len(train_dataset) / args.train_batch_size) * args.num_train_epochs
    if args.do_predict:
        eval_datasets = []
        eval_task_names = ("mnli", "mnli-mm") if args.task_name == "mnli" else (args.task_name,)
        for eval_task in eval_task_names:
            eval_datasets.append(load_and_cache_examples(args, eval_task, tokenizer, evaluate=True))
    logger.info("Datasets loaded")
    # Build and initialize the model. Only the student model is used here; this is
    # effectively training a teacher model on the task (a single model, no distillation).
    model_S = BertForGLUESimple(bert_config_S, num_labels=num_labels, args=args)
    # Initialize the student model
    if args.load_model_type == 'bert':
        assert args.init_checkpoint_S is not None
        state_dict_S = torch.load(args.init_checkpoint_S, map_location='cpu')
        if args.only_load_embedding:
            state_weight = {k[5:]: v for k, v in state_dict_S.items()
                            if k.startswith('bert.embeddings')}
            missing_keys, _ = model_S.bert.load_state_dict(state_weight, strict=False)
            logger.info(f"Missing keys {list(missing_keys)}")
        else:
            state_weight = {k[5:]: v for k, v in state_dict_S.items() if k.startswith('bert.')}
            missing_keys, _ = model_S.bert.load_state_dict(state_weight, strict=False)
            assert len(missing_keys) == 0
        logger.info("Model loaded")
    elif args.load_model_type == 'all':
        assert args.tuned_checkpoint_S is not None
        state_dict_S = torch.load(args.tuned_checkpoint_S, map_location='cpu')
        model_S.load_state_dict(state_dict_S)
        logger.info("Model loaded")
    else:
        logger.info("Model is randomly initialized.")
    model_S.to(device)
    if args.local_rank != -1 or n_gpu > 1:
        if args.local_rank != -1:
            raise NotImplementedError
        elif n_gpu > 1:
            model_S = torch.nn.DataParallel(model_S)  # ,output_device=n_gpu-1)
    if args.do_train:
        # Parameters
        params = list(model_S.named_parameters())
        all_trainable_params = divide_parameters(params, lr=args.learning_rate)
        logger.info("Length of all_trainable_params: %d", len(all_trainable_params))
        # Optimizer setup
        optimizer = BERTAdam(all_trainable_params, lr=args.learning_rate,
                             warmup=args.warmup_proportion, t_total=num_train_steps,
                             schedule=args.schedule,
                             s_opt1=args.s_opt1, s_opt2=args.s_opt2, s_opt3=args.s_opt3)
        logger.info("***** Running training *****")
        logger.info(" Num examples = %d", len(train_dataset))
        logger.info(" Forward batch size = %d", forward_batch_size)
        logger.info(" Num backward steps = %d", num_train_steps)
        ########### TRAINING ###########
        train_config = TrainingConfig(
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            ckpt_frequency=args.ckpt_frequency,
            log_dir=args.output_dir,
            output_dir=args.output_dir,
            device=args.device)
        # Runs supervised training rather than distillation; this can be used to
        # train the teacher model.
        distiller = BasicTrainer(train_config=train_config,
                                 model=model_S,
                                 adaptor=BertForGLUESimpleAdaptorTraining)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset)
        else:
            raise NotImplementedError
        train_dataloader = DataLoader(train_dataset, sampler=train_sampler,
                                      batch_size=args.forward_batch_size, drop_last=True)
        callback_func = partial(predict, eval_datasets=eval_datasets, args=args)
        with distiller:
            distiller.train(optimizer, scheduler=None, dataloader=train_dataloader,
                            num_epochs=args.num_train_epochs, callback=callback_func)
    if not args.do_train and args.do_predict:
        res = predict(model_S, eval_datasets, step=0, args=args)
        print(res)
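For reference, a TextBrewer adaptor maps `(batch, model_outputs)` to a dict of fields the trainer consumes, and `BasicTrainer` only needs the `'losses'` entry. A minimal sketch of what `BertForGLUESimpleAdaptorTraining` is assumed to look like (the tuple positions are assumptions):

def adaptor_training_sketch(batch, model_outputs):
    # Assumes the model returns (logits, loss) when labels are provided;
    # BasicTrainer only consumes the 'losses' entry of the returned dict.
    return {'losses': (model_outputs[1],)}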
def main():
    # Parse arguments
    config.parse()
    args = config.args
    for k, v in vars(args).items():
        logger.info(f"{k}:{v}")
    # Set seeds
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed_all(args.random_seed)
    np.random.seed(args.random_seed)
    random.seed(args.random_seed)
    # Check arguments
    device, n_gpu = args_check(args)
    os.makedirs(args.output_dir, exist_ok=True)
    forward_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps)
    args.forward_batch_size = forward_batch_size
    # Load the teacher and student BERT configs
    bert_config_T = BertConfig.from_json_file(args.bert_config_file_T)
    bert_config_S = BertConfig.from_json_file(args.bert_config_file_S)
    assert args.max_seq_length <= bert_config_T.max_position_embeddings
    assert args.max_seq_length <= bert_config_S.max_position_embeddings
    # Prepare the GLUE task
    processor = processors[args.task_name]()
    args.output_mode = output_modes[args.task_name]
    label_list = processor.get_labels()
    num_labels = len(label_list)
    # Read the data
    train_dataset = None
    eval_datasets = None
    num_train_steps = None
    tokenizer = BertTokenizer(vocab_file=args.vocab_file, do_lower_case=args.do_lower_case)
    # Load the datasets
    if args.do_train:
        train_dataset, examples = load_and_cache_examples(args, args.task_name, tokenizer,
                                                          evaluate=False)
        if args.aux_task_name:
            aux_train_dataset, examples = load_and_cache_examples(args, args.aux_task_name,
                                                                  tokenizer, evaluate=False,
                                                                  is_aux=True)
            train_dataset = torch.utils.data.ConcatDataset([train_dataset, aux_train_dataset])
        num_train_steps = int(len(train_dataset) / args.train_batch_size) * args.num_train_epochs
    if args.do_predict:
        eval_datasets = []
        eval_task_names = ("mnli", "mnli-mm") if args.task_name == "mnli" else (args.task_name,)
        for eval_task in eval_task_names:
            eval_dataset, examples = load_and_cache_examples(args, eval_task, tokenizer,
                                                             evaluate=True)
            eval_datasets.append(eval_dataset)
    logger.info("Datasets loaded")
    # Build both the teacher and the student model
    model_T = BertForGLUESimple(bert_config_T, num_labels=num_labels, args=args)
    model_S = BertForGLUESimple(bert_config_S, num_labels=num_labels, args=args)
    # Load the teacher parameters
    if args.tuned_checkpoint_T is not None:
        state_dict_T = torch.load(args.tuned_checkpoint_T, map_location='cpu')
        model_T.load_state_dict(state_dict_T)
        model_T.eval()
    else:
        assert args.do_predict is True
    # Load the student
    if args.load_model_type == 'bert':
        assert args.init_checkpoint_S is not None
        state_dict_S = torch.load(args.init_checkpoint_S, map_location='cpu')
        if args.only_load_embedding:
            state_weight = {k[5:]: v for k, v in state_dict_S.items()
                            if k.startswith('bert.embeddings')}
            missing_keys, _ = model_S.bert.load_state_dict(state_weight, strict=False)
            logger.info(f"Missing keys {list(missing_keys)}")
        else:
            state_weight = {k[5:]: v for k, v in state_dict_S.items() if k.startswith('bert.')}
            missing_keys, _ = model_S.bert.load_state_dict(state_weight, strict=False)
            assert len(missing_keys) == 0
        logger.info("Model loaded")
    elif args.load_model_type == 'all':
        assert args.tuned_checkpoint_S is not None
        state_dict_S = torch.load(args.tuned_checkpoint_S, map_location='cpu')
        model_S.load_state_dict(state_dict_S)
        logger.info("Model loaded")
    else:
        logger.info("The student model has no parameters to load; it is randomly initialized.")
    model_T.to(device)
    model_S.to(device)
    if args.local_rank != -1 or n_gpu > 1:
        if args.local_rank != -1:
            raise NotImplementedError
        elif n_gpu > 1:
            model_T = torch.nn.DataParallel(model_T)  # ,output_device=n_gpu-1)
            model_S = torch.nn.DataParallel(model_S)  # ,output_device=n_gpu-1)
    if args.do_train:
        # Parameters
        params = list(model_S.named_parameters())
        all_trainable_params = divide_parameters(params, lr=args.learning_rate)
        logger.info("Length of all_trainable_params: %d", len(all_trainable_params))
        # Optimizer setup
        optimizer = BERTAdam(all_trainable_params, lr=args.learning_rate,
                             warmup=args.warmup_proportion, t_total=num_train_steps,
                             schedule=args.schedule,
                             s_opt1=args.s_opt1, s_opt2=args.s_opt2, s_opt3=args.s_opt3)
        logger.info("***** Running training *****")
        logger.info(" Num examples = %d", len(train_dataset))
        logger.info(" Forward batch size = %d", forward_batch_size)
        logger.info(" Num backward steps = %d", num_train_steps)
        ########### DISTILLATION ###########
        train_config = TrainingConfig(
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            ckpt_frequency=args.ckpt_frequency,
            log_dir=args.output_dir,
            output_dir=args.output_dir,
            device=args.device)
        # Predefined layer-matching recipes live in matches.py
        from matches import matches
        intermediate_matches = None
        if isinstance(args.matches, (list, tuple)):
            intermediate_matches = []
            for match in args.matches:
                intermediate_matches += matches[match]
        logger.info(f"Intermediate matches: {intermediate_matches}")
        distill_config = DistillationConfig(
            temperature=args.temperature,
            intermediate_matches=intermediate_matches)
        logger.info(f"Training config: {train_config}")
        logger.info(f"Distillation config: {distill_config}")
        adaptor_T = partial(BertForGLUESimpleAdaptor, no_logits=args.no_logits,
                            no_mask=args.no_inputs_mask)
        adaptor_S = partial(BertForGLUESimpleAdaptor, no_logits=args.no_logits,
                            no_mask=args.no_inputs_mask)
        # General-purpose distiller that supports intermediate-state matching
        distiller = GeneralDistiller(train_config=train_config,
                                     distill_config=distill_config,
                                     model_T=model_T, model_S=model_S,
                                     adaptor_T=adaptor_T, adaptor_S=adaptor_S)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset)
        else:
            raise NotImplementedError
        train_dataloader = DataLoader(train_dataset, sampler=train_sampler,
                                      batch_size=args.forward_batch_size, drop_last=True)
        callback_func = partial(predict, eval_datasets=eval_datasets, args=args,
                                examples=examples)
        with distiller:
            distiller.train(optimizer, scheduler=None, dataloader=train_dataloader,
                            num_epochs=args.num_train_epochs, callback=callback_func)
    if not args.do_train and args.do_predict:
        res = predict(model_S, eval_datasets, step=0, args=args, examples=examples,
                      label_list=label_list)
        print(res)
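The `matches[...]` entries pulled from `matches.py` above are lists of TextBrewer intermediate-match dicts that pair teacher and student layers. An illustrative entry (values here are examples, not this repo's actual recipes):

match_example = {
    'layer_T': 8,          # teacher layer to match
    'layer_S': 2,          # student layer to match
    'feature': 'hidden',   # which feature to compare ('hidden' or 'attention')
    'loss': 'hidden_mse',  # loss used for the comparison
    'weight': 1,           # weight of this loss term
    # 'proj': ['linear', 312, 768],  # optional projection when hidden sizes differ
}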
def main():
    # Parse arguments
    config.parse()
    args = config.args
    for k, v in vars(args).items():
        logger.info(f"{k}:{v}")
    # Set seeds
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed_all(args.random_seed)
    np.random.seed(args.random_seed)
    random.seed(args.random_seed)
    # Check arguments
    device, n_gpu = args_check(args)
    os.makedirs(args.output_dir, exist_ok=True)
    forward_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps)
    args.forward_batch_size = forward_batch_size
    # Load the BERT config
    bert_config_S = BertConfig.from_json_file(args.bert_config_file_S)
    assert args.max_seq_length <= bert_config_S.max_position_embeddings
    # Read the data
    train_examples = None
    train_features = None
    eval_examples = None
    eval_features = None
    num_train_steps = None
    tokenizer = ChineseFullTokenizer(vocab_file=args.vocab_file, do_lower_case=args.do_lower_case)
    convert_fn = partial(convert_examples_to_features,
                         tokenizer=tokenizer,
                         max_seq_length=args.max_seq_length,
                         doc_stride=args.doc_stride,
                         max_query_length=args.max_query_length)
    if args.do_train:
        train_examples, train_features = read_and_convert(args.train_file, is_training=True,
                                                          do_lower_case=args.do_lower_case,
                                                          read_fn=read_squad_examples,
                                                          convert_fn=convert_fn)
        if args.fake_file_1:
            fake_examples1, fake_features1 = read_and_convert(args.fake_file_1, is_training=True,
                                                              do_lower_case=args.do_lower_case,
                                                              read_fn=read_squad_examples,
                                                              convert_fn=convert_fn)
            train_examples += fake_examples1
            train_features += fake_features1
        if args.fake_file_2:
            fake_examples2, fake_features2 = read_and_convert(args.fake_file_2, is_training=True,
                                                              do_lower_case=args.do_lower_case,
                                                              read_fn=read_squad_examples,
                                                              convert_fn=convert_fn)
            train_examples += fake_examples2
            train_features += fake_features2
        num_train_steps = int(len(train_features) / args.train_batch_size) * args.num_train_epochs
    if args.do_predict:
        eval_examples, eval_features = read_and_convert(args.predict_file, is_training=False,
                                                        do_lower_case=args.do_lower_case,
                                                        read_fn=read_squad_examples,
                                                        convert_fn=convert_fn)
    # Build the model and load the checkpoint
    model_S = BertForQASimple(bert_config_S, args)
    # Load the student
    if args.load_model_type == 'bert':
        assert args.init_checkpoint_S is not None
        state_dict_S = torch.load(args.init_checkpoint_S, map_location='cpu')
        state_weight = {k[5:]: v for k, v in state_dict_S.items() if k.startswith('bert.')}
        missing_keys, _ = model_S.bert.load_state_dict(state_weight, strict=False)
        assert len(missing_keys) == 0
    elif args.load_model_type == 'all':
        assert args.tuned_checkpoint_S is not None
        state_dict_S = torch.load(args.tuned_checkpoint_S, map_location='cpu')
        model_S.load_state_dict(state_dict_S)
    else:
        logger.info("Model is randomly initialized.")
    model_S.to(device)
    if args.local_rank != -1 or n_gpu > 1:
        if args.local_rank != -1:
            raise NotImplementedError
        elif n_gpu > 1:
            model_S = torch.nn.DataParallel(model_S)  # ,output_device=n_gpu-1)
    if args.do_train:
        # Parameters
        params = list(model_S.named_parameters())
        all_trainable_params = divide_parameters(params, lr=args.learning_rate)
        logger.info("Length of all_trainable_params: %d", len(all_trainable_params))
        optimizer = BERTAdam(all_trainable_params, lr=args.learning_rate,
                             warmup=args.warmup_proportion, t_total=num_train_steps,
                             schedule=args.schedule,
                             s_opt1=args.s_opt1, s_opt2=args.s_opt2, s_opt3=args.s_opt3)
        logger.info("***** Running training *****")
        logger.info(" Num orig examples = %d", len(train_examples))
        logger.info(" Num split examples = %d", len(train_features))
        logger.info(" Forward batch size = %d", forward_batch_size)
        logger.info(" Num backward steps = %d", num_train_steps)
        ########### TRAINING ###########
        train_config = TrainingConfig(
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            ckpt_frequency=args.ckpt_frequency,
            log_dir=args.output_dir,
            output_dir=args.output_dir,
            device=args.device)
        # BasicTrainer performs supervised training (no teacher model involved)
        distiller = BasicTrainer(train_config=train_config,
                                 model=model_S,
                                 adaptor=BertForQASimpleAdaptorTraining)
        all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
        all_doc_mask = torch.tensor([f.doc_mask for f in train_features], dtype=torch.float)
        all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
        all_start_positions = torch.tensor([f.start_position for f in train_features],
                                           dtype=torch.long)
        all_end_positions = torch.tensor([f.end_position for f in train_features],
                                         dtype=torch.long)
        train_dataset = TensorDataset(all_input_ids, all_segment_ids, all_input_mask,
                                      all_doc_mask, all_start_positions, all_end_positions)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset)
        else:
            raise NotImplementedError
        train_dataloader = DataLoader(train_dataset, sampler=train_sampler,
                                      batch_size=args.forward_batch_size, drop_last=True)
        callback_func = partial(predict, eval_examples=eval_examples,
                                eval_features=eval_features, args=args)
        with distiller:
            distiller.train(optimizer, scheduler=None, dataloader=train_dataloader,
                            num_epochs=args.num_train_epochs, callback=callback_func)
    if not args.do_train and args.do_predict:
        res = predict(model_S, eval_examples, eval_features, step=0, args=args)
        print(res)
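Note how the callback is wired: the distiller is expected to invoke it as `callback(model=..., step=...)` at each checkpoint, so the `partial` must bind every other parameter of `predict` by keyword. A sketch of the assumed signature (hypothetical, named `predict_sketch` to avoid shadowing the real function):

def predict_sketch(model, eval_examples=None, eval_features=None, args=None, step=0):
    ...  # run evaluation on eval_features and return/log the metrics

callback_sketch = partial(predict_sketch, eval_examples=eval_examples,
                          eval_features=eval_features, args=args)
# The distiller then only needs to supply `model` and `step`.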
def main():
    # Parse arguments
    config.parse()
    args = config.args
    for k, v in vars(args).items():
        logger.info(f"{k}:{v}")
    # Set seeds
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed_all(args.random_seed)
    np.random.seed(args.random_seed)
    random.seed(args.random_seed)
    # Check arguments
    device, n_gpu = args_check(args)
    os.makedirs(args.output_dir, exist_ok=True)
    forward_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps)
    args.forward_batch_size = forward_batch_size
    # Prepare the task
    processor = processors[args.task_name]()
    args.output_mode = output_modes[args.task_name]
    # e.g. MNLI: ['contradiction', 'entailment', 'neutral']
    label_list = processor.get_labels()
    num_labels = len(label_list)
    # Student config
    if args.model_architecture == "electra":
        # Load the config via transformers' ElectraConfig
        bert_config_S = ElectraConfig.from_json_file(args.bert_config_file_S)
        # (args.output_encoded_layers == 'true') -> True; output hidden states by default
        bert_config_S.output_hidden_states = (args.output_encoded_layers == 'true')
        bert_config_S.output_attentions = (args.output_attention_layers == 'true')
        # num_labels: number of classes
        bert_config_S.num_labels = num_labels
        assert args.max_seq_length <= bert_config_S.max_position_embeddings
    elif args.model_architecture == "albert":
        # Load the config via transformers' AlbertConfig
        bert_config_S = AlbertConfig.from_json_file(args.bert_config_file_S)
        bert_config_S.output_hidden_states = (args.output_encoded_layers == 'true')
        bert_config_S.output_attentions = (args.output_attention_layers == 'true')
        bert_config_S.num_labels = num_labels
        assert args.max_seq_length <= bert_config_S.max_position_embeddings
    else:
        bert_config_S = BertConfig.from_json_file(args.bert_config_file_S)
        assert args.max_seq_length <= bert_config_S.max_position_embeddings
    # Read the data
    train_dataset = None
    eval_datasets = None
    num_train_steps = None
    # Both ELECTRA and ALBERT use BERT's tokenization here
    tokenizer = BertTokenizer(vocab_file=args.vocab_file, do_lower_case=args.do_lower_case)
    # Load the datasets and compute the number of training steps
    if args.do_train:
        train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False)
        if args.aux_task_name:
            aux_train_dataset = load_and_cache_examples(args, args.aux_task_name, tokenizer,
                                                        evaluate=False, is_aux=True)
            train_dataset = torch.utils.data.ConcatDataset([train_dataset, aux_train_dataset])
        num_train_steps = int(len(train_dataset) / args.train_batch_size) * args.num_train_epochs
        logger.info("Training dataset loaded")
    if args.do_predict:
        eval_datasets = []
        eval_task_names = ("mnli", "mnli-mm") if args.task_name == "mnli" else (args.task_name,)
        for eval_task in eval_task_names:
            eval_datasets.append(load_and_cache_examples(args, eval_task, tokenizer, evaluate=True))
        logger.info("Evaluation datasets loaded")
    # Build the student model. Only the student is used here; this is effectively
    # training a teacher model on the task (a single model, no distillation).
    if args.model_architecture == "electra":
        model_S = ElectraSPC(bert_config_S)
    elif args.model_architecture == "albert":
        model_S = AlbertSPC(bert_config_S)
    else:
        model_S = BertSPCSimple(bert_config_S, num_labels=num_labels, args=args)
    # Initialize the student model's parameters; the student is used for prediction
    if args.load_model_type == 'bert' and args.model_architecture not in ["electra", "albert"]:
        assert args.init_checkpoint_S is not None
        state_dict_S = torch.load(args.init_checkpoint_S, map_location='cpu')
        if args.only_load_embedding:
            state_weight = {k[5:]: v for k, v in state_dict_S.items()
                            if k.startswith('bert.embeddings')}
            missing_keys, _ = model_S.bert.load_state_dict(state_weight, strict=False)
            logger.info(f"Missing keys {list(missing_keys)}")
        else:
            state_weight = {k[5:]: v for k, v in state_dict_S.items() if k.startswith('bert.')}
            missing_keys, _ = model_S.bert.load_state_dict(state_weight, strict=False)
            print(f"missing_keys (parameters not found in the checkpoint): {missing_keys}")
        logger.info("Model loaded")
    elif args.load_model_type == 'all':
        assert args.tuned_checkpoint_S is not None
        state_dict_S = torch.load(args.tuned_checkpoint_S, map_location='cpu')
        model_S.load_state_dict(state_dict_S)
        logger.info("Model loaded")
    elif args.model_architecture in ["electra", "albert"]:
        assert args.init_checkpoint_S is not None
        state_dict_S = torch.load(args.init_checkpoint_S, map_location='cpu')
        missing_keys, unexpected_keys = model_S.load_state_dict(state_dict_S, strict=False)
        logger.info(f"missing keys:{missing_keys}")
        logger.info(f"unexpected keys:{unexpected_keys}")
    else:
        logger.info("Model is randomly initialized.")
    # Move the model to the device
    model_S.to(device)
    if args.local_rank != -1 or n_gpu > 1:
        if args.local_rank != -1:
            raise NotImplementedError
        elif n_gpu > 1:
            model_S = torch.nn.DataParallel(model_S)  # ,output_device=n_gpu-1)
    if args.do_train:
        # params is the list of all named parameters of the model
        params = list(model_S.named_parameters())
        all_trainable_params = divide_parameters(params, lr=args.learning_rate)
        logger.info("Number of trainable parameter groups (decay and no-decay groups): %d",
                    len(all_trainable_params))
        # Optimizer setup
        optimizer = BERTAdam(all_trainable_params, lr=args.learning_rate,
                             warmup=args.warmup_proportion, t_total=num_train_steps,
                             schedule=args.schedule,
                             s_opt1=args.s_opt1, s_opt2=args.s_opt2, s_opt3=args.s_opt3)
        logger.info("***** Running training *****")
        logger.info(" Num examples = %d", len(train_dataset))
        logger.info(" Forward batch size = %d", forward_batch_size)
        logger.info(" Num training steps = %d", num_train_steps)
        ########### TRAINING CONFIG ###########
        train_config = TrainingConfig(
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            ckpt_frequency=args.ckpt_frequency,
            log_dir=args.output_dir,
            output_dir=args.output_dir,
            device=args.device)
        # Initialize the trainer. This performs supervised training, not distillation;
        # it can be used to train model_S into a teacher model.
        distiller = BasicTrainer(train_config=train_config,
                                 model=model_S,
                                 adaptor=BertForGLUESimpleAdaptorTraining)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset)
        else:
            raise NotImplementedError
        # Training dataloader
        train_dataloader = DataLoader(train_dataset, sampler=train_sampler,
                                      batch_size=args.forward_batch_size, drop_last=True)
        # Callback that evaluates on the eval datasets
        callback_func = partial(predict, eval_datasets=eval_datasets, args=args)
        with distiller:
            # Start training
            distiller.train(optimizer, scheduler=None, dataloader=train_dataloader,
                            num_epochs=args.num_train_epochs, callback=callback_func)
    if not args.do_train and args.do_predict:
        res = predict(model_S, eval_datasets, step=0, args=args)
        print(res)
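The repeated architecture branches above (config loading and model construction) could be collapsed with a dispatch table; a hypothetical refactoring sketch using the classes this script already imports:

CONFIG_CLASSES = {"electra": ElectraConfig, "albert": AlbertConfig, "bert": BertConfig}
config_cls = CONFIG_CLASSES.get(args.model_architecture, BertConfig)
bert_config_S = config_cls.from_json_file(args.bert_config_file_S)
# Model construction still differs (BertSPCSimple additionally takes num_labels and
# args), so that branch is harder to table-drive without a uniform constructor.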
def main():
    # Parse arguments
    config.parse()
    args = config.args
    for k, v in vars(args).items():
        logger.info(f"{k}:{v}")
    # Set seeds
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed_all(args.random_seed)
    np.random.seed(args.random_seed)
    random.seed(args.random_seed)
    # Check arguments
    device, n_gpu = args_check(args)
    os.makedirs(args.output_dir, exist_ok=True)
    forward_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps)
    args.forward_batch_size = forward_batch_size
    # Load the teacher and student BERT configs
    bert_config_T = BertConfig.from_json_file(args.bert_config_file_T)
    bert_config_S = BertConfig.from_json_file(args.bert_config_file_S)
    assert args.max_seq_length <= bert_config_T.max_position_embeddings
    assert args.max_seq_length <= bert_config_S.max_position_embeddings
    # Prepare the GLUE task
    processor = processors[args.task_name]()
    args.output_mode = output_modes[args.task_name]
    label_list = processor.get_labels()
    num_labels = len(label_list)
    # Read the data
    train_dataset = None
    eval_datasets = None
    num_train_steps = None
    tokenizer = BertTokenizer(vocab_file=args.vocab_file, do_lower_case=args.do_lower_case)
    if args.do_train:
        train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False)
        if args.aux_task_name:
            aux_train_dataset = load_and_cache_examples(args, args.aux_task_name, tokenizer,
                                                        evaluate=False, is_aux=True)
            train_dataset = torch.utils.data.ConcatDataset([train_dataset, aux_train_dataset])
        num_train_steps = int(len(train_dataset) / args.train_batch_size) * args.num_train_epochs
    if args.do_predict:
        eval_datasets = []
        eval_task_names = ("mnli", "mnli-mm") if args.task_name == "mnli" else (args.task_name,)
        for eval_task in eval_task_names:
            eval_datasets.append(load_and_cache_examples(args, eval_task, tokenizer, evaluate=True))
    logger.info("Data loaded")
    # Build the student model and load checkpoints
    model_S = BertForGLUESimple(bert_config_S, num_labels=num_labels, args=args)
    # Load the teachers
    if args.tuned_checkpoint_Ts:
        model_Ts = [BertForGLUESimple(bert_config_T, num_labels=num_labels, args=args)
                    for i in range(len(args.tuned_checkpoint_Ts))]
        for model_T, ckpt_T in zip(model_Ts, args.tuned_checkpoint_Ts):
            logger.info("Load state dict %s" % ckpt_T)
            state_dict_T = torch.load(ckpt_T, map_location='cpu')
            model_T.load_state_dict(state_dict_T)
            model_T.eval()
    else:
        assert args.do_predict is True
    # Load the student
    if args.load_model_type == 'bert':
        assert args.init_checkpoint_S is not None
        state_dict_S = torch.load(args.init_checkpoint_S, map_location='cpu')
        state_weight = {k[5:]: v for k, v in state_dict_S.items() if k.startswith('bert.')}
        missing_keys, _ = model_S.bert.load_state_dict(state_weight, strict=False)
        assert len(missing_keys) == 0
    elif args.load_model_type == 'all':
        assert args.tuned_checkpoint_S is not None
        state_dict_S = torch.load(args.tuned_checkpoint_S, map_location='cpu')
        model_S.load_state_dict(state_dict_S)
    else:
        logger.info("Model is randomly initialized.")
    if args.do_train:
        for model_T in model_Ts:
            model_T.to(device)
    model_S.to(device)
    if args.local_rank != -1 or n_gpu > 1:
        if args.local_rank != -1:
            raise NotImplementedError
        elif n_gpu > 1:
            if args.do_train:
                model_Ts = [torch.nn.DataParallel(model_T) for model_T in model_Ts]
            model_S = torch.nn.DataParallel(model_S)  # ,output_device=n_gpu-1)
    if args.do_train:
        # Parameters
        params = list(model_S.named_parameters())
        all_trainable_params = divide_parameters(params, lr=args.learning_rate)
        logger.info("Length of all_trainable_params: %d", len(all_trainable_params))
        optimizer = BERTAdam(all_trainable_params, lr=args.learning_rate,
                             warmup=args.warmup_proportion, t_total=num_train_steps,
                             schedule=args.schedule,
                             s_opt1=args.s_opt1, s_opt2=args.s_opt2, s_opt3=args.s_opt3)
        logger.info("***** Running training *****")
        logger.info(" Num examples = %d", len(train_dataset))
        logger.info(" Forward batch size = %d", forward_batch_size)
        logger.info(" Num backward steps = %d", num_train_steps)
        ########### DISTILLATION ###########
        train_config = TrainingConfig(
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            ckpt_frequency=args.ckpt_frequency,
            log_dir=args.output_dir,
            output_dir=args.output_dir,
            device=args.device)
        distill_config = DistillationConfig(
            temperature=args.temperature,
            kd_loss_type='ce')
        logger.info(f"{train_config}")
        logger.info(f"{distill_config}")
        adaptor = partial(BertForGLUESimpleAdaptor, no_logits=False, no_mask=False)
        # Distills an ensemble of teachers into a single student
        distiller = MultiTeacherDistiller(train_config=train_config,
                                          distill_config=distill_config,
                                          model_T=model_Ts, model_S=model_S,
                                          adaptor_T=adaptor, adaptor_S=adaptor)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset)
        else:
            raise NotImplementedError
        train_dataloader = DataLoader(train_dataset, sampler=train_sampler,
                                      batch_size=args.forward_batch_size, drop_last=True)
        callback_func = partial(predict, eval_datasets=eval_datasets, args=args)
        with distiller:
            distiller.train(optimizer, scheduler=None, dataloader=train_dataloader,
                            num_epochs=args.num_train_epochs, callback=callback_func)
    if not args.do_train and args.do_predict:
        res = predict(model_S, eval_datasets, step=0, args=args)
        print(res)
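With `kd_loss_type='ce'`, `MultiTeacherDistiller` is understood to average the teachers' logits and train the student against the resulting soft labels. A conceptual sketch of that loss (an assumption about the behavior, not the library's internal code):

import torch
import torch.nn.functional as F

def ensemble_kd_loss_sketch(teacher_logits_list, student_logits, T=4.0):
    avg_teacher = torch.stack(teacher_logits_list).mean(dim=0)  # ensemble logits
    soft_targets = F.softmax(avg_teacher / T, dim=-1)           # teacher soft labels
    log_probs = F.log_softmax(student_logits / T, dim=-1)       # student log-probs
    return -(soft_targets * log_probs).sum(dim=-1).mean()       # soft cross entropy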