def do_train(self, data, truncated=False):
    """
    Train the model. The dataset is split into a training set and a validation set, 9:1 by default.
    :param data: input data; note that if truncation is used, each item is (content, aspect, start_idx, end_idx, label)
    :param truncated: whether to truncate, using self.left_max_seq_len and self.right_max_seq_len
    :return:
    """
    if truncated:
        data, locations = self.do_truncate_data(data)
    train_data_len = int(len(data) * 0.9)
    train_data = data[:train_data_len]
    eval_data = data[train_data_len:]
    train_dataset = load_examples(train_data, self.max_seq_length, self.train_tokenizer,
                                  self.label_list, self.reverse_truncate)
    eval_dataset = load_examples(eval_data, self.max_seq_length, self.train_tokenizer,
                                 self.label_list, self.reverse_truncate)
    logger.info("Training dataset loaded, starting training")
    num_train_steps = int(len(train_dataset) / self.train_batch_size) * self.num_train_epochs
    forward_batch_size = int(self.train_batch_size / self.gradient_accumulation_steps)
    # Collect the trainable parameters
    params = list(self.train_model.named_parameters())
    all_trainable_params = divide_parameters(params, lr=self.learning_rate,
                                             weight_decay_rate=self.weight_decay_rate)
    # Optimizer setup
    optimizer = BERTAdam(all_trainable_params, lr=self.learning_rate,
                         warmup=self.warmup_proportion, t_total=num_train_steps,
                         schedule=self.schedule,
                         s_opt1=self.s_opt1, s_opt2=self.s_opt2, s_opt3=self.s_opt3)
    logger.info("***** Running training *****")
    logger.info("  Num training examples = %d", len(train_dataset))
    logger.info("  Num eval examples = %d", len(eval_dataset))
    logger.info("  Forward batch size = %d", forward_batch_size)
    logger.info("  Num training steps = %d", num_train_steps)
    ########### Training configuration ###########
    train_config = TrainingConfig(
        gradient_accumulation_steps=self.gradient_accumulation_steps,
        ckpt_frequency=self.ckpt_frequency,
        log_dir=self.output_dir,
        output_dir=self.output_dir,
        device=self.device)
    # Initialize the trainer: supervised training rather than distillation.
    # It can train model_S so that it can later serve as a teacher model.
    distiller = BasicTrainer(train_config=train_config,
                             model=self.train_model,
                             adaptor=BertForGLUESimpleAdaptorTraining)
    train_sampler = RandomSampler(train_dataset)
    # Training dataloader
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler,
                                  batch_size=forward_batch_size, drop_last=True)
    # Callback function, evaluated on the eval dataset
    callback_func = partial(self.do_predict, eval_dataset=eval_dataset)
    with distiller:
        # Start training
        distiller.train(optimizer, scheduler=None, dataloader=train_dataloader,
                        num_epochs=self.num_train_epochs, callback=callback_func)
    logger.info("Training finished")
    return "Done"
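The docstring above says truncation keeps at most self.left_max_seq_len / self.right_max_seq_len around the aspect span. The helper below is a hypothetical sketch of that kind of aspect-centered truncation (the function name, default lengths, and character-level slicing are assumptions; the real do_truncate_data may differ), shown only to illustrate the (content, aspect, start_idx, end_idx, label) input format.

def truncate_around_aspect(content, aspect, start_idx, end_idx, label,
                           left_max_seq_len=60, right_max_seq_len=60):
    # Keep at most left_max_seq_len characters before the aspect span and
    # right_max_seq_len characters after it, then shift the span indices.
    left = max(0, start_idx - left_max_seq_len)
    right = min(len(content), end_idx + right_max_seq_len)
    new_content = content[left:right]
    return new_content, aspect, start_idx - left, end_idx - left, label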
def main():
    # parse arguments
    config.parse()
    args = config.args
    for k, v in vars(args).items():
        logger.info(f"{k}:{v}")
    # set seeds
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed_all(args.random_seed)
    np.random.seed(args.random_seed)
    random.seed(args.random_seed)
    # arguments check
    device, n_gpu = args_check(args)
    os.makedirs(args.output_dir, exist_ok=True)
    forward_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps)
    args.forward_batch_size = forward_batch_size
    # load bert config
    bert_config_S = BertConfig.from_json_file(args.bert_config_file_S)
    assert args.max_seq_length <= bert_config_S.max_position_embeddings
    # Prepare GLUE task
    processor = processors[args.task_name]()
    args.output_mode = output_modes[args.task_name]
    # e.g. MNLI: ['contradiction', 'entailment', 'neutral']
    label_list = processor.get_labels()
    num_labels = len(label_list)
    # read data
    train_dataset = None
    eval_datasets = None
    num_train_steps = None
    tokenizer = BertTokenizer(vocab_file=args.vocab_file, do_lower_case=args.do_lower_case)
    # Load the datasets and compute the number of training steps
    if args.do_train:
        train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False)
        if args.aux_task_name:
            aux_train_dataset = load_and_cache_examples(args, args.aux_task_name, tokenizer,
                                                        evaluate=False, is_aux=True)
            train_dataset = torch.utils.data.ConcatDataset([train_dataset, aux_train_dataset])
        num_train_steps = int(len(train_dataset) / args.train_batch_size) * args.num_train_epochs
    if args.do_predict:
        eval_datasets = []
        eval_task_names = ("mnli", "mnli-mm") if args.task_name == "mnli" else (args.task_name,)
        for eval_task in eval_task_names:
            eval_datasets.append(load_and_cache_examples(args, eval_task, tokenizer, evaluate=True))
    logger.info("Datasets loaded")
    # Build and initialize the model. Only the student model is used; this is effectively
    # training a teacher model on the MNLI data -- a single model is trained.
    model_S = BertForGLUESimple(bert_config_S, num_labels=num_labels, args=args)
    # Initialize the student model
    if args.load_model_type == 'bert':
        assert args.init_checkpoint_S is not None
        state_dict_S = torch.load(args.init_checkpoint_S, map_location='cpu')
        if args.only_load_embedding:
            state_weight = {k[5:]: v for k, v in state_dict_S.items()
                            if k.startswith('bert.embeddings')}
            missing_keys, _ = model_S.bert.load_state_dict(state_weight, strict=False)
            logger.info(f"Missing keys {list(missing_keys)}")
        else:
            state_weight = {k[5:]: v for k, v in state_dict_S.items() if k.startswith('bert.')}
            missing_keys, _ = model_S.bert.load_state_dict(state_weight, strict=False)
            assert len(missing_keys) == 0
        logger.info("Model loaded")
    elif args.load_model_type == 'all':
        assert args.tuned_checkpoint_S is not None
        state_dict_S = torch.load(args.tuned_checkpoint_S, map_location='cpu')
        model_S.load_state_dict(state_dict_S)
        logger.info("Model loaded")
    else:
        logger.info("Model is randomly initialized.")
    model_S.to(device)
    if args.local_rank != -1 or n_gpu > 1:
        if args.local_rank != -1:
            raise NotImplementedError
        elif n_gpu > 1:
            model_S = torch.nn.DataParallel(model_S)  # ,output_device=n_gpu-1)
    if args.do_train:
        # parameters
        params = list(model_S.named_parameters())
        all_trainable_params = divide_parameters(params, lr=args.learning_rate)
        logger.info("Length of all_trainable_params: %d", len(all_trainable_params))
        # Optimizer setup
        optimizer = BERTAdam(all_trainable_params, lr=args.learning_rate,
                             warmup=args.warmup_proportion, t_total=num_train_steps,
                             schedule=args.schedule,
                             s_opt1=args.s_opt1, s_opt2=args.s_opt2, s_opt3=args.s_opt3)
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_dataset))
        logger.info("  Forward batch size = %d", forward_batch_size)
        logger.info("  Num backward steps = %d", num_train_steps)
        ########### Training ###########
        train_config = TrainingConfig(
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            ckpt_frequency=args.ckpt_frequency,
            log_dir=args.output_dir,
            output_dir=args.output_dir,
            device=args.device)
        # Supervised training rather than distillation; it can be used to train a teacher model.
        distiller = BasicTrainer(train_config=train_config,
                                 model=model_S,
                                 adaptor=BertForGLUESimpleAdaptorTraining)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset)
        else:
            raise NotImplementedError
        train_dataloader = DataLoader(train_dataset, sampler=train_sampler,
                                      batch_size=args.forward_batch_size, drop_last=True)
        callback_func = partial(predict, eval_datasets=eval_datasets, args=args)
        with distiller:
            distiller.train(optimizer, scheduler=None, dataloader=train_dataloader,
                            num_epochs=args.num_train_epochs, callback=callback_func)
    if not args.do_train and args.do_predict:
        res = predict(model_S, eval_datasets, step=0, args=args)
        print(res)
def main():
    # parse arguments
    config.parse()
    args = config.args
    for k, v in vars(args).items():
        logger.info(f"{k}:{v}")
    # set seeds
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed_all(args.random_seed)
    np.random.seed(args.random_seed)
    random.seed(args.random_seed)
    # arguments check
    device, n_gpu = args_check(args)
    os.makedirs(args.output_dir, exist_ok=True)
    forward_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps)
    args.forward_batch_size = forward_batch_size
    # load bert config
    bert_config_T = BertConfig.from_json_file(args.bert_config_file_T)
    bert_config_S = BertConfig.from_json_file(args.bert_config_file_S)
    assert args.max_seq_length <= bert_config_T.max_position_embeddings
    assert args.max_seq_length <= bert_config_S.max_position_embeddings
    # Prepare GLUE task
    processor = processors[args.task_name]()
    args.output_mode = output_modes[args.task_name]
    label_list = processor.get_labels()
    num_labels = len(label_list)
    # read data
    train_dataset = None
    eval_datasets = None
    num_train_steps = None
    tokenizer = BertTokenizer(vocab_file=args.vocab_file, do_lower_case=args.do_lower_case)
    # Load the datasets
    if args.do_train:
        train_dataset, examples = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False)
        if args.aux_task_name:
            aux_train_dataset, examples = load_and_cache_examples(args, args.aux_task_name, tokenizer,
                                                                  evaluate=False, is_aux=True)
            train_dataset = torch.utils.data.ConcatDataset([train_dataset, aux_train_dataset])
        num_train_steps = int(len(train_dataset) / args.train_batch_size) * args.num_train_epochs
    if args.do_predict:
        eval_datasets = []
        eval_task_names = ("mnli", "mnli-mm") if args.task_name == "mnli" else (args.task_name,)
        for eval_task in eval_task_names:
            eval_dataset, examples = load_and_cache_examples(args, eval_task, tokenizer, evaluate=True)
            eval_datasets.append(eval_dataset)
    logger.info("Datasets loaded")
    # Build the models: load both the teacher and the student
    model_T = BertForGLUESimple(bert_config_T, num_labels=num_labels, args=args)
    model_S = BertForGLUESimple(bert_config_S, num_labels=num_labels, args=args)
    # Load the teacher model parameters
    if args.tuned_checkpoint_T is not None:
        state_dict_T = torch.load(args.tuned_checkpoint_T, map_location='cpu')
        model_T.load_state_dict(state_dict_T)
        model_T.eval()
    else:
        assert args.do_predict is True
    # Load student
    if args.load_model_type == 'bert':
        assert args.init_checkpoint_S is not None
        state_dict_S = torch.load(args.init_checkpoint_S, map_location='cpu')
        if args.only_load_embedding:
            state_weight = {k[5:]: v for k, v in state_dict_S.items()
                            if k.startswith('bert.embeddings')}
            missing_keys, _ = model_S.bert.load_state_dict(state_weight, strict=False)
            logger.info(f"Missing keys {list(missing_keys)}")
        else:
            state_weight = {k[5:]: v for k, v in state_dict_S.items() if k.startswith('bert.')}
            missing_keys, _ = model_S.bert.load_state_dict(state_weight, strict=False)
            assert len(missing_keys) == 0
        logger.info("Model loaded")
    elif args.load_model_type == 'all':
        assert args.tuned_checkpoint_S is not None
        state_dict_S = torch.load(args.tuned_checkpoint_S, map_location='cpu')
        model_S.load_state_dict(state_dict_S)
        logger.info("Model loaded")
    else:
        logger.info("No checkpoint to load for the student model; it is randomly initialized.")
    model_T.to(device)
    model_S.to(device)
    if args.local_rank != -1 or n_gpu > 1:
        if args.local_rank != -1:
            raise NotImplementedError
        elif n_gpu > 1:
            model_T = torch.nn.DataParallel(model_T)  # ,output_device=n_gpu-1)
            model_S = torch.nn.DataParallel(model_S)  # ,output_device=n_gpu-1)
    if args.do_train:
        # parameters
        params = list(model_S.named_parameters())
        all_trainable_params = divide_parameters(params, lr=args.learning_rate)
        logger.info("Length of all_trainable_params: %d", len(all_trainable_params))
        # Optimizer setup
        optimizer = BERTAdam(all_trainable_params, lr=args.learning_rate,
                             warmup=args.warmup_proportion, t_total=num_train_steps,
                             schedule=args.schedule,
                             s_opt1=args.s_opt1, s_opt2=args.s_opt2, s_opt3=args.s_opt3)
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_dataset))
        logger.info("  Forward batch size = %d", forward_batch_size)
        logger.info("  Num backward steps = %d", num_train_steps)
        ########### DISTILLATION ###########
        train_config = TrainingConfig(
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            ckpt_frequency=args.ckpt_frequency,
            log_dir=args.output_dir,
            output_dir=args.output_dir,
            device=args.device)
        # matches defines a set of fixed intermediate-layer match configurations
        from matches import matches
        intermediate_matches = None
        if isinstance(args.matches, (list, tuple)):
            intermediate_matches = []
            for match in args.matches:
                intermediate_matches += matches[match]
        logger.info(f"Intermediate matches: {intermediate_matches}")
        distill_config = DistillationConfig(
            temperature=args.temperature,
            intermediate_matches=intermediate_matches)
        logger.info(f"Training config: {train_config}")
        logger.info(f"Distillation config: {distill_config}")
        adaptor_T = partial(BertForGLUESimpleAdaptor, no_logits=args.no_logits, no_mask=args.no_inputs_mask)
        adaptor_S = partial(BertForGLUESimpleAdaptor, no_logits=args.no_logits, no_mask=args.no_inputs_mask)
        # General distiller that supports intermediate-state matching
        distiller = GeneralDistiller(train_config=train_config, distill_config=distill_config,
                                     model_T=model_T, model_S=model_S,
                                     adaptor_T=adaptor_T, adaptor_S=adaptor_S)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset)
        else:
            raise NotImplementedError
        train_dataloader = DataLoader(train_dataset, sampler=train_sampler,
                                      batch_size=args.forward_batch_size, drop_last=True)
        callback_func = partial(predict, eval_datasets=eval_datasets, args=args, examples=examples)
        with distiller:
            distiller.train(optimizer, scheduler=None, dataloader=train_dataloader,
                            num_epochs=args.num_train_epochs, callback=callback_func)
    if not args.do_train and args.do_predict:
        res = predict(model_S, eval_datasets, step=0, args=args, examples=examples, label_list=label_list)
        print(res)
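The matches dict imported above maps a configuration name to a list of intermediate-layer match specs that are handed to DistillationConfig(intermediate_matches=...). The sketch below shows what one entry might look like; the key name and layer indices are illustrative assumptions, not necessarily the contents of the project's actual matches.py.

# Illustrative sketch of a matches.py entry (key name and layer indices are assumptions).
# Each dict asks the distiller to take the 'hidden' feature of teacher layer layer_T and
# student layer layer_S and penalize their difference with the given loss and weight.
matches = {
    "L3_hidden_mse": [
        {"layer_T": 4,  "layer_S": 1, "feature": "hidden", "loss": "hidden_mse", "weight": 1},
        {"layer_T": 8,  "layer_S": 2, "feature": "hidden", "loss": "hidden_mse", "weight": 1},
        {"layer_T": 12, "layer_S": 3, "feature": "hidden", "loss": "hidden_mse", "weight": 1},
    ],
}
# With args.matches = ["L3_hidden_mse"], the loop above flattens the selected lists into
# intermediate_matches before building DistillationConfig.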
def main():
    # parse arguments
    config.parse()
    args = config.args
    for k, v in vars(args).items():
        logger.info(f"{k}:{v}")
    # set seeds
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed_all(args.random_seed)
    np.random.seed(args.random_seed)
    random.seed(args.random_seed)
    # arguments check
    device, n_gpu = args_check(args)
    os.makedirs(args.output_dir, exist_ok=True)
    forward_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps)
    args.forward_batch_size = forward_batch_size
    # load bert config
    bert_config_S = BertConfig.from_json_file(args.bert_config_file_S)
    assert args.max_seq_length <= bert_config_S.max_position_embeddings
    # read data
    train_examples = None
    train_features = None
    eval_examples = None
    eval_features = None
    num_train_steps = None
    tokenizer = ChineseFullTokenizer(vocab_file=args.vocab_file, do_lower_case=args.do_lower_case)
    convert_fn = partial(convert_examples_to_features,
                         tokenizer=tokenizer,
                         max_seq_length=args.max_seq_length,
                         doc_stride=args.doc_stride,
                         max_query_length=args.max_query_length)
    if args.do_train:
        train_examples, train_features = read_and_convert(args.train_file, is_training=True,
                                                          do_lower_case=args.do_lower_case,
                                                          read_fn=read_squad_examples,
                                                          convert_fn=convert_fn)
        if args.fake_file_1:
            fake_examples1, fake_features1 = read_and_convert(args.fake_file_1, is_training=True,
                                                              do_lower_case=args.do_lower_case,
                                                              read_fn=read_squad_examples,
                                                              convert_fn=convert_fn)
            train_examples += fake_examples1
            train_features += fake_features1
        if args.fake_file_2:
            fake_examples2, fake_features2 = read_and_convert(args.fake_file_2, is_training=True,
                                                              do_lower_case=args.do_lower_case,
                                                              read_fn=read_squad_examples,
                                                              convert_fn=convert_fn)
            train_examples += fake_examples2
            train_features += fake_features2
        num_train_steps = int(len(train_features) / args.train_batch_size) * args.num_train_epochs
    if args.do_predict:
        eval_examples, eval_features = read_and_convert(args.predict_file, is_training=False,
                                                        do_lower_case=args.do_lower_case,
                                                        read_fn=read_squad_examples,
                                                        convert_fn=convert_fn)
    # Build Model and load checkpoint
    model_S = BertForQASimple(bert_config_S, args)
    # Load student
    if args.load_model_type == 'bert':
        assert args.init_checkpoint_S is not None
        state_dict_S = torch.load(args.init_checkpoint_S, map_location='cpu')
        state_weight = {k[5:]: v for k, v in state_dict_S.items() if k.startswith('bert.')}
        missing_keys, _ = model_S.bert.load_state_dict(state_weight, strict=False)
        assert len(missing_keys) == 0
    elif args.load_model_type == 'all':
        assert args.tuned_checkpoint_S is not None
        state_dict_S = torch.load(args.tuned_checkpoint_S, map_location='cpu')
        model_S.load_state_dict(state_dict_S)
    else:
        logger.info("Model is randomly initialized.")
    model_S.to(device)
    if args.local_rank != -1 or n_gpu > 1:
        if args.local_rank != -1:
            raise NotImplementedError
        elif n_gpu > 1:
            model_S = torch.nn.DataParallel(model_S)  # ,output_device=n_gpu-1)
    if args.do_train:
        # parameters
        params = list(model_S.named_parameters())
        all_trainable_params = divide_parameters(params, lr=args.learning_rate)
        logger.info("Length of all_trainable_params: %d", len(all_trainable_params))
        optimizer = BERTAdam(all_trainable_params, lr=args.learning_rate,
                             warmup=args.warmup_proportion, t_total=num_train_steps,
                             schedule=args.schedule,
                             s_opt1=args.s_opt1, s_opt2=args.s_opt2, s_opt3=args.s_opt3)
        logger.info("***** Running training *****")
        logger.info("  Num orig examples = %d", len(train_examples))
        logger.info("  Num split examples = %d", len(train_features))
        logger.info("  Forward batch size = %d", forward_batch_size)
        logger.info("  Num backward steps = %d", num_train_steps)
        ########### DISTILLATION ###########
        train_config = TrainingConfig(
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            ckpt_frequency=args.ckpt_frequency,
            log_dir=args.output_dir,
            output_dir=args.output_dir,
            device=args.device)
        distiller = BasicTrainer(train_config=train_config,
                                 model=model_S,
                                 adaptor=BertForQASimpleAdaptorTraining)
        all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
        all_doc_mask = torch.tensor([f.doc_mask for f in train_features], dtype=torch.float)
        all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
        all_start_positions = torch.tensor([f.start_position for f in train_features], dtype=torch.long)
        all_end_positions = torch.tensor([f.end_position for f in train_features], dtype=torch.long)
        train_dataset = TensorDataset(all_input_ids, all_segment_ids, all_input_mask, all_doc_mask,
                                      all_start_positions, all_end_positions)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset)
        else:
            raise NotImplementedError
        train_dataloader = DataLoader(train_dataset, sampler=train_sampler,
                                      batch_size=args.forward_batch_size, drop_last=True)
        callback_func = partial(predict, eval_examples=eval_examples,
                                eval_features=eval_features, args=args)
        with distiller:
            distiller.train(optimizer, scheduler=None, dataloader=train_dataloader,
                            num_epochs=args.num_train_epochs, callback=callback_func)
    if not args.do_train and args.do_predict:
        res = predict(model_S, eval_examples, eval_features, step=0, args=args)
        print(res)
def main():
    # parse arguments
    config.parse()
    args = config.args
    for k, v in vars(args).items():
        logger.info(f"{k}:{v}")
    # set seeds
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed_all(args.random_seed)
    np.random.seed(args.random_seed)
    random.seed(args.random_seed)
    # arguments check
    device, n_gpu = args_check(args)
    os.makedirs(args.output_dir, exist_ok=True)
    forward_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps)
    args.forward_batch_size = forward_batch_size
    # load config
    teachers_and_student = parse_model_config(args.model_config_json)
    # Prepare GLUE task
    processor = processors[args.task_name]()
    args.output_mode = output_modes[args.task_name]
    label_list = processor.get_labels()
    num_labels = len(label_list)
    # read data
    train_dataset = None
    eval_datasets = None
    num_train_steps = None
    tokenizer_S = teachers_and_student['student']['tokenizer']
    prefix_S = teachers_and_student['student']['prefix']
    if args.do_train:
        train_dataset = load_and_cache_examples(args, args.task_name, tokenizer_S,
                                                prefix=prefix_S, evaluate=False)
    if args.do_predict:
        eval_datasets = []
        eval_task_names = ("mnli", "mnli-mm") if args.task_name == "mnli" else (args.task_name,)
        for eval_task in eval_task_names:
            eval_datasets.append(load_and_cache_examples(args, eval_task, tokenizer_S,
                                                         prefix=prefix_S, evaluate=True))
    logger.info("Data loaded")
    # Build Model and load checkpoint
    if args.do_train:
        model_Ts = []
        for teacher in teachers_and_student['teachers']:
            model_type_T = teacher['model_type']
            model_config_T = teacher['config']
            checkpoint_T = teacher['checkpoint']
            _, _, model_class_T = MODEL_CLASSES[model_type_T]
            model_T = model_class_T(model_config_T, num_labels=num_labels)
            state_dict_T = torch.load(checkpoint_T, map_location='cpu')
            missing_keys, un_keys = model_T.load_state_dict(state_dict_T, strict=True)
            logger.info(f"Teacher Model {model_type_T} loaded")
            model_T.to(device)
            model_T.eval()
            model_Ts.append(model_T)
    student = teachers_and_student['student']
    model_type_S = student['model_type']
    model_config_S = student['config']
    checkpoint_S = student['checkpoint']
    _, _, model_class_S = MODEL_CLASSES[model_type_S]
    model_S = model_class_S(model_config_S, num_labels=num_labels)
    if checkpoint_S is not None:
        state_dict_S = torch.load(checkpoint_S, map_location='cpu')
        missing_keys, un_keys = model_S.load_state_dict(state_dict_S, strict=False)
        logger.info(f"missing keys:{missing_keys}")
        logger.info(f"unexpected keys:{un_keys}")
    else:
        logger.warning("Initializing student randomly")
    logger.info("Student Model loaded")
    model_S.to(device)
    if args.local_rank != -1 or n_gpu > 1:
        if args.local_rank != -1:
            raise NotImplementedError
        elif n_gpu > 1:
            if args.do_train:
                model_Ts = [torch.nn.DataParallel(model_T) for model_T in model_Ts]
            model_S = torch.nn.DataParallel(model_S)  # ,output_device=n_gpu-1)
    if args.do_train:
        # parameters
        params = list(model_S.named_parameters())
        all_trainable_params = divide_parameters(params, lr=args.learning_rate)
        logger.info("Length of all_trainable_params: %d", len(all_trainable_params))
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset)
        else:
            raise NotImplementedError
        train_dataloader = DataLoader(train_dataset, sampler=train_sampler,
                                      batch_size=args.forward_batch_size, drop_last=True)
        num_train_steps = int(len(train_dataloader) // args.gradient_accumulation_steps
                              * args.num_train_epochs)
        ########### DISTILLATION ###########
        train_config = TrainingConfig(
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            ckpt_frequency=args.ckpt_frequency,
            log_dir=args.output_dir,
            output_dir=args.output_dir,
            fp16=args.fp16,
            device=args.device)
        distill_config = DistillationConfig(temperature=args.temperature, kd_loss_type='ce')
        logger.info(f"{train_config}")
        logger.info(f"{distill_config}")
        adaptor_T = BertForGLUESimpleAdaptor
        adaptor_S = BertForGLUESimpleAdaptor
        distiller = MultiTeacherDistiller(train_config=train_config, distill_config=distill_config,
                                          model_T=model_Ts, model_S=model_S,
                                          adaptor_T=adaptor_T, adaptor_S=adaptor_S)
        optimizer = AdamW(all_trainable_params, lr=args.learning_rate)
        scheduler_class = get_linear_schedule_with_warmup
        scheduler_args = {'num_warmup_steps': int(args.warmup_proportion * num_train_steps),
                          'num_training_steps': num_train_steps}
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_dataset))
        logger.info("  Forward batch size = %d", forward_batch_size)
        logger.info("  Num backward steps = %d", num_train_steps)
        callback_func = partial(predict, eval_datasets=eval_datasets, args=args)
        with distiller:
            distiller.train(optimizer, scheduler_class=scheduler_class, scheduler_args=scheduler_args,
                            dataloader=train_dataloader, num_epochs=args.num_train_epochs,
                            callback=callback_func, max_grad_norm=1)
    if not args.do_train and args.do_predict:
        res = predict(model_S, eval_datasets, step=0, args=args)
        print(res)
def train(args, train_dataset, model_T, model, tokenizer, labels, pad_token_label_id, predict_callback):
    """ Train the model """
    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
         "weight_decay": args.weight_decay},
        {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
         "weight_decay": 0.0},
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(optimizer,
                                                num_warmup_steps=int(args.warmup_steps * t_total),
                                                num_training_steps=t_total)
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)
    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
                                                          output_device=args.local_rank,
                                                          find_unused_parameters=True)
    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d",
                args.train_batch_size * args.gradient_accumulation_steps
                * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)
    if args.do_train and args.do_distill:
        from textbrewer import DistillationConfig, TrainingConfig, GeneralDistiller
        distill_config = DistillationConfig(
            temperature=8,
            # intermediate_matches = [{'layer_T':10, 'layer_S':3, 'feature':'hidden','loss': 'hidden_mse', 'weight' : 1}]
        )
        train_config = TrainingConfig(device="cuda", log_dir=args.output_dir, output_dir=args.output_dir)

        def adaptor_T(batch, model_output):
            return {"logits": (model_output[1],),
                    'logits_mask': (batch['attention_mask'],)}

        def adaptor_S(batch, model_output):
            return {"logits": (model_output[1],),
                    'logits_mask': (batch['attention_mask'],)}

        distiller = GeneralDistiller(train_config, distill_config, model_T, model, adaptor_T, adaptor_S)
        distiller.train(optimizer, scheduler, train_dataloader, args.num_train_epochs,
                        callback=predict_callback)
    return
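The commented-out intermediate_matches entry in the DistillationConfig above would require the adaptors to expose hidden states as well as logits. The sketch below shows one way to do that; it assumes the models were built with output_hidden_states=True and that model_output[2] holds the tuple of per-layer hidden states, which may differ for the actual model class used here.

# Sketch: an adaptor that also exposes hidden states, so that 'feature': 'hidden'
# intermediate matches can be enabled. Assumes model_output is (loss, logits, hidden_states);
# adjust the indices if the model returns a different tuple.
def adaptor_with_hidden(batch, model_output):
    return {
        "logits": (model_output[1],),
        "logits_mask": (batch["attention_mask"],),
        "hidden": model_output[2],  # tuple of per-layer hidden states
    }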
        label_ids.append(labels[i])
    model.train()
    pred_logits = np.array(pred_logits)
    label_ids = np.array(label_ids)
    y_p = pred_logits.argmax(axis=-1)
    accuracy = (y_p == label_ids).sum() / len(label_ids)
    print("Number of examples: ", len(y_p))
    print("Acc: ", accuracy)

from functools import partial
callback_fun = partial(predict, eval_dataset=eval_dataset, device=device)  # fill other arguments

# Initialize configurations and distiller
train_config = TrainingConfig(device=device)
distill_config = DistillationConfig(
    temperature=8,
    hard_label_weight=0,
    kd_loss_type='ce',
    probability_shift=False,
    intermediate_matches=[
        {'layer_T': 0, 'layer_S': 0, 'feature': 'hidden', 'loss': 'hidden_mse', 'weight': 1},
        {'layer_T': 8, 'layer_S': 2, 'feature': 'hidden', 'loss': 'hidden_mse',
def main():
    # parse arguments
    config.parse()
    args = config.args
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
        datefmt='%Y/%m/%d %H:%M:%S',
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger = logging.getLogger("Main")
    # arguments check
    device, n_gpu = args_check(logger, args)
    if args.local_rank in [-1, 0]:
        os.makedirs(args.output_dir, exist_ok=True)
    if args.local_rank != -1:
        logger.warning(
            f"Process rank: {torch.distributed.get_rank()}, device : {args.device}, "
            f"n_gpu : {args.n_gpu}, distributed training : {bool(args.local_rank!=-1)}")
    for k, v in vars(args).items():
        logger.info(f"{k}:{v}")
    # set seeds
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed_all(args.random_seed)
    np.random.seed(args.random_seed)
    random.seed(args.random_seed)
    forward_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps)
    args.forward_batch_size = forward_batch_size
    # load bert config
    bert_config_S = ElectraConfig.from_json_file(args.bert_config_file_S)
    bert_config_S.output_hidden_states = (args.output_encoded_layers == 'true')
    bert_config_S.num_labels = len(label2id_dict)
    assert args.max_seq_length <= bert_config_S.max_position_embeddings
    # read data
    train_examples = None
    train_dataset = None
    eval_examples = None
    eval_dataset = None
    num_train_steps = None
    tokenizer = BertTokenizer(vocab_file=args.vocab_file, do_lower_case=args.do_lower_case)
    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier()
    if args.do_train:
        train_examples, train_dataset = read_features(args.train_file, tokenizer=tokenizer,
                                                      max_seq_length=args.max_seq_length)
    if args.do_predict:
        eval_examples, eval_dataset = read_features(args.predict_file, tokenizer=tokenizer,
                                                    max_seq_length=args.max_seq_length)
    if args.local_rank == 0:
        torch.distributed.barrier()
    # Build Model and load checkpoint
    model_S = ElectraForTokenClassification(bert_config_S)
    # Load student
    if args.load_model_type == 'bert':
        assert args.init_checkpoint_S is not None
        state_dict_S = torch.load(args.init_checkpoint_S, map_location='cpu')
        # state_weight = {k[5:]:v for k,v in state_dict_S.items() if k.startswith('bert.')}
        # missing_keys,_ = model_S.bert.load_state_dict(state_weight,strict=False)
        missing_keys, unexpected_keys = model_S.load_state_dict(state_dict_S, strict=False)
        logger.info(f"missing keys:{missing_keys}")
        logger.info(f"unexpected keys:{unexpected_keys}")
    elif args.load_model_type == 'all':
        assert args.tuned_checkpoint_S is not None
        state_dict_S = torch.load(args.tuned_checkpoint_S, map_location='cpu')
        model_S.load_state_dict(state_dict_S)
    else:
        logger.info("Model is randomly initialized.")
    model_S.to(device)
    if args.do_train:
        # parameters
        if args.lr_decay is not None:
            outputs_params = list(model_S.classifier.named_parameters())
            outputs_params = divide_parameters(outputs_params, lr=args.learning_rate)
            electra_params = []
            n_layers = len(model_S.electra.encoder.layer)
            assert n_layers == 12
            for i, n in enumerate(reversed(range(n_layers))):
                encoder_params = list(model_S.electra.encoder.layer[n].named_parameters())
                lr = args.learning_rate * args.lr_decay**(i + 1)
                electra_params += divide_parameters(encoder_params, lr=lr)
                logger.info(f"{i},{n},{lr}")
            embed_params = [(name, value) for name, value in model_S.electra.named_parameters()
                            if 'embedding' in name]
            logger.info(f"{[name for name,value in embed_params]}")
            lr = args.learning_rate * args.lr_decay**(n_layers + 1)
            electra_params += divide_parameters(embed_params, lr=lr)
            logger.info(f"embed lr:{lr}")
            all_trainable_params = outputs_params + electra_params
            assert sum(map(lambda x: len(x['params']), all_trainable_params)) == len(list(model_S.parameters())), \
                (sum(map(lambda x: len(x['params']), all_trainable_params)), len(list(model_S.parameters())))
        else:
            params = list(model_S.named_parameters())
            all_trainable_params = divide_parameters(params, lr=args.learning_rate)
        logger.info("Length of all_trainable_params: %d", len(all_trainable_params))
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset)
        else:
            train_sampler = DistributedSampler(train_dataset)
        train_dataloader = DataLoader(train_dataset, sampler=train_sampler,
                                      batch_size=args.forward_batch_size, drop_last=True)
        num_train_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
        optimizer = AdamW(all_trainable_params, lr=args.learning_rate, correct_bias=False)
        if args.official_schedule == 'const':
            scheduler_class = get_constant_schedule_with_warmup
            scheduler_args = {'num_warmup_steps': int(args.warmup_proportion * num_train_steps)}
            # scheduler = get_constant_schedule_with_warmup(optimizer, num_warmup_steps=int(args.warmup_proportion*num_train_steps))
        elif args.official_schedule == 'linear':
            scheduler_class = get_linear_schedule_with_warmup
            scheduler_args = {'num_warmup_steps': int(args.warmup_proportion * num_train_steps),
                              'num_training_steps': num_train_steps}
            # scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=int(args.warmup_proportion*num_train_steps), num_training_steps=num_train_steps)
        elif args.official_schedule == 'const_nowarmup':
            scheduler_class = get_constant_schedule
            scheduler_args = {}
        else:
            raise NotImplementedError
        logger.warning("***** Running training *****")
        logger.warning("local_rank %d Num orig examples = %d", args.local_rank, len(train_examples))
        logger.warning("local_rank %d Num split examples = %d", args.local_rank, len(train_dataset))
        logger.warning("local_rank %d Forward batch size = %d", args.local_rank, forward_batch_size)
        logger.warning("local_rank %d Num backward steps = %d", args.local_rank, num_train_steps)
        ########### TRAINING ###########
        train_config = TrainingConfig(
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            ckpt_frequency=args.ckpt_frequency,
            log_dir=args.output_dir,
            output_dir=args.output_dir,
            device=args.device,
            fp16=args.fp16,
            local_rank=args.local_rank)
        logger.info(f"{train_config}")
        distiller = BasicTrainer(train_config=train_config,
                                 model=model_S,
                                 adaptor=ElectraForTokenClassificationAdaptorTraining)
        # evaluate the model in a single process in ddp_predict
        callback_func = partial(ddp_predict, eval_examples=eval_examples,
                                eval_dataset=eval_dataset, args=args)
        with distiller:
            distiller.train(optimizer, scheduler_class=scheduler_class, scheduler_args=scheduler_args,
                            max_grad_norm=1.0, dataloader=train_dataloader,
                            num_epochs=args.num_train_epochs, callback=callback_func)
    if not args.do_train and args.do_predict:
        res = ddp_predict(model_S, eval_examples, eval_dataset, step=0, args=args)
        print(res)
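Note that distiller.train here receives scheduler_class and scheduler_args rather than an already-built scheduler; the trainer is expected to construct the scheduler from the (possibly fp16-wrapped) optimizer itself, roughly as scheduler_class(optimizer=optimizer, **scheduler_args). The sketch below shows what the 'linear' branch amounts to if built by hand; it is an illustration, not code from this project.

# Illustrative only: manual equivalent of the 'linear' schedule selected above.
from transformers import get_linear_schedule_with_warmup

def build_linear_scheduler(optimizer, num_train_steps, warmup_proportion):
    # warmup_proportion of the total steps are used for linear warmup,
    # after which the learning rate decays linearly to zero.
    return get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(warmup_proportion * num_train_steps),
        num_training_steps=num_train_steps)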
from textbrewer import GeneralDistiller
from textbrewer import TrainingConfig, DistillationConfig

# We omit the initialization of models, optimizer, and dataloader.
teacher_model: torch.nn.Module = ...
student_model: torch.nn.Module = ...
dataloader: torch.utils.data.DataLoader = ...
optimizer: torch.optim.Optimizer = ...
scheduler: torch.optim.lr_scheduler = ...

def simple_adaptor(batch, model_outputs):
    # We assume that the first element of model_outputs
    # is the logits before softmax
    return {'logits': model_outputs[0]}

train_config = TrainingConfig()
distill_config = DistillationConfig()
distiller = GeneralDistiller(
    train_config=train_config, distill_config=distill_config,
    model_T=teacher_model, model_S=student_model,
    adaptor_T=simple_adaptor, adaptor_S=simple_adaptor)

distiller.train(optimizer, scheduler, dataloader, num_epochs, callback=None)

def predict(model, eval_dataset, step, args):
    raise NotImplementedError
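A minimal sketch of one possible body for the predict() callback left unimplemented above. TextBrewer's distillers call the callback with the student model and the current global step; the remaining arguments are pre-bound, e.g. callback = partial(predict, eval_dataset=eval_dataset, args=args). The batch layout (input_ids, attention_mask, labels) and args.eval_batch_size are assumptions for illustration, not part of the quickstart.

import torch
from torch.utils.data import DataLoader

def predict(model, eval_dataset, step, args):
    # Simple accuracy evaluation; assumes each batch is (input_ids, attention_mask, labels)
    # and that the model's first output is the classification logits (as in simple_adaptor).
    model.eval()
    device = next(model.parameters()).device
    loader = DataLoader(eval_dataset, batch_size=args.eval_batch_size)
    correct = total = 0
    with torch.no_grad():
        for input_ids, attention_mask, labels in loader:
            logits = model(input_ids.to(device), attention_mask.to(device))[0]
            correct += (logits.argmax(dim=-1).cpu() == labels).sum().item()
            total += labels.size(0)
    model.train()
    print(f"step {step}: eval accuracy = {correct / total:.4f}")
    return correct / total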
def main():
    # parse arguments
    config.parse()
    args = config.args
    for k, v in vars(args).items():
        logger.info(f"{k}:{v}")
    # set seeds
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed_all(args.random_seed)
    np.random.seed(args.random_seed)
    random.seed(args.random_seed)
    # arguments check
    device, n_gpu = args_check(args)
    os.makedirs(args.output_dir, exist_ok=True)
    forward_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps)
    args.forward_batch_size = forward_batch_size
    # Prepare the task
    processor = processors[args.task_name]()
    args.output_mode = output_modes[args.task_name]
    # e.g. MNLI: ['contradiction', 'entailment', 'neutral']
    label_list = processor.get_labels()
    num_labels = len(label_list)
    # Student configuration
    if args.model_architecture == "electra":
        # Load ElectraConfig (from the transformers package)
        bert_config_S = ElectraConfig.from_json_file(args.bert_config_file_S)
        # (args.output_encoded_layers == 'true') --> True: output hidden states by default
        bert_config_S.output_hidden_states = (args.output_encoded_layers == 'true')
        bert_config_S.output_attentions = (args.output_attention_layers == 'true')
        # num_labels: number of classes
        bert_config_S.num_labels = num_labels
        assert args.max_seq_length <= bert_config_S.max_position_embeddings
    elif args.model_architecture == "albert":
        # Load AlbertConfig (from the transformers package)
        bert_config_S = AlbertConfig.from_json_file(args.bert_config_file_S)
        # (args.output_encoded_layers == 'true') --> True: output hidden states by default
        bert_config_S.output_hidden_states = (args.output_encoded_layers == 'true')
        bert_config_S.output_attentions = (args.output_attention_layers == 'true')
        # num_labels: number of classes
        bert_config_S.num_labels = num_labels
        assert args.max_seq_length <= bert_config_S.max_position_embeddings
    else:
        bert_config_S = BertConfig.from_json_file(args.bert_config_file_S)
        assert args.max_seq_length <= bert_config_S.max_position_embeddings
    # read data
    train_dataset = None
    eval_datasets = None
    num_train_steps = None
    # Both ELECTRA and BERT use the BERT tokenizer
    tokenizer = BertTokenizer(vocab_file=args.vocab_file, do_lower_case=args.do_lower_case)
    # Load the datasets and compute the number of training steps
    if args.do_train:
        train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False)
        if args.aux_task_name:
            aux_train_dataset = load_and_cache_examples(args, args.aux_task_name, tokenizer,
                                                        evaluate=False, is_aux=True)
            train_dataset = torch.utils.data.ConcatDataset([train_dataset, aux_train_dataset])
        num_train_steps = int(len(train_dataset) / args.train_batch_size) * args.num_train_epochs
        logger.info("Training dataset loaded")
    if args.do_predict:
        eval_datasets = []
        eval_task_names = ("mnli", "mnli-mm") if args.task_name == "mnli" else (args.task_name,)
        for eval_task in eval_task_names:
            eval_datasets.append(load_and_cache_examples(args, eval_task, tokenizer, evaluate=True))
        logger.info("Evaluation datasets loaded")
    # Student model
    if args.model_architecture == "electra":
        # Only the student model is used; this is effectively training a teacher model -- a single model is trained
        model_S = ElectraSPC(bert_config_S)
    elif args.model_architecture == "albert":
        model_S = AlbertSPC(bert_config_S)
    else:
        # Only the student model is used; this is effectively training a teacher model -- a single model is trained
        model_S = BertSPCSimple(bert_config_S, num_labels=num_labels, args=args)
    # Initialize the loaded student model's parameters; the student model is used for prediction
    if args.load_model_type == 'bert' and args.model_architecture not in ["electra", "albert"]:
        assert args.init_checkpoint_S is not None
        state_dict_S = torch.load(args.init_checkpoint_S, map_location='cpu')
        if args.only_load_embedding:
            state_weight = {k[5:]: v for k, v in state_dict_S.items()
                            if k.startswith('bert.embeddings')}
            missing_keys, _ = model_S.bert.load_state_dict(state_weight, strict=False)
            logger.info(f"Missing keys {list(missing_keys)}")
        else:
            state_weight = {k[5:]: v for k, v in state_dict_S.items() if k.startswith('bert.')}
            missing_keys, _ = model_S.bert.load_state_dict(state_weight, strict=False)
            print(f"Note the missing keys: {missing_keys}")
        logger.info("Model loaded")
    elif args.load_model_type == 'all':
        assert args.tuned_checkpoint_S is not None
        state_dict_S = torch.load(args.tuned_checkpoint_S, map_location='cpu')
        model_S.load_state_dict(state_dict_S)
        logger.info("Model loaded")
    elif args.model_architecture in ["electra", "albert"]:
        assert args.init_checkpoint_S is not None
        state_dict_S = torch.load(args.init_checkpoint_S, map_location='cpu')
        missing_keys, unexpected_keys = model_S.load_state_dict(state_dict_S, strict=False)
        logger.info(f"missing keys:{missing_keys}")
        logger.info(f"unexpected keys:{unexpected_keys}")
    else:
        logger.info("Model is randomly initialized.")
    # Move the model to the device
    model_S.to(device)
    if args.local_rank != -1 or n_gpu > 1:
        if args.local_rank != -1:
            raise NotImplementedError
        elif n_gpu > 1:
            model_S = torch.nn.DataParallel(model_S)  # ,output_device=n_gpu-1)
    if args.do_train:
        # params is the list of all named parameters of the model
        params = list(model_S.named_parameters())
        all_trainable_params = divide_parameters(params, lr=args.learning_rate)
        logger.info("Number of trainable parameter groups (decay_group and no_decay_group): %d",
                    len(all_trainable_params))
        # Optimizer setup
        optimizer = BERTAdam(all_trainable_params, lr=args.learning_rate,
                             warmup=args.warmup_proportion, t_total=num_train_steps,
                             schedule=args.schedule,
                             s_opt1=args.s_opt1, s_opt2=args.s_opt2, s_opt3=args.s_opt3)
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_dataset))
        logger.info("  Forward batch size = %d", forward_batch_size)
        logger.info("  Num training steps = %d", num_train_steps)
        ########### Training configuration ###########
        train_config = TrainingConfig(
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            ckpt_frequency=args.ckpt_frequency,
            log_dir=args.output_dir,
            output_dir=args.output_dir,
            device=args.device)
        # Initialize the trainer: supervised training rather than distillation.
        # It can train model_S so that it can later serve as a teacher model.
        distiller = BasicTrainer(train_config=train_config,
                                 model=model_S,
                                 adaptor=BertForGLUESimpleAdaptorTraining)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset)
        else:
            raise NotImplementedError
        # Training dataloader
        train_dataloader = DataLoader(train_dataset, sampler=train_sampler,
                                      batch_size=args.forward_batch_size, drop_last=True)
        # Callback function, evaluated on the eval datasets
        callback_func = partial(predict, eval_datasets=eval_datasets, args=args)
        with distiller:
            # Start training
            distiller.train(optimizer, scheduler=None, dataloader=train_dataloader,
                            num_epochs=args.num_train_epochs, callback=callback_func)
    if not args.do_train and args.do_predict:
        res = predict(model_S, eval_datasets, step=0, args=args)
        print(res)
def main():
    config_distil.parse()
    global args
    args = config_distil.args
    global logger
    logger = create_logger(args.log_file)
    for k, v in vars(args).items():
        logger.info(f"{k}:{v}")
    # set seeds
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed_all(args.random_seed)
    np.random.seed(args.random_seed)
    random.seed(args.random_seed)
    logger.info('Loading the vocabulary')
    word2idx = load_chinese_base_vocab()
    # Check whether a GPU is available
    args.device = torch.device("cuda" if torch.cuda.is_available() and args.is_cuda else "cpu")
    logger.info('using device:{}'.format(args.device))
    # Define model hyperparameters
    bertconfig_T = BertConfig(vocab_size=len(word2idx))
    bertconfig_S = BertConfig(vocab_size=len(word2idx), num_hidden_layers=3)
    logger.info('Initializing the BERT models')
    bert_model_T = Seq2SeqModel(config=bertconfig_T)
    bert_model_S = Seq2SeqModel(config=bertconfig_S)
    logger.info('Loading the teacher model')
    load_model(bert_model_T, args.tuned_checkpoint_T)
    logger.info('Moving the model to the compute device (GPU or CPU)')
    bert_model_T.to(args.device)
    bert_model_T.eval()
    logger.info('Loading the student model')
    if args.load_model_type == 'bert':
        load_model(bert_model_S, args.init_checkpoint_S)
    else:
        logger.info(" Student Model is randomly initialized.")
    logger.info('Moving the model to the compute device (GPU or CPU)')
    bert_model_S.to(args.device)
    # Build the custom data loader
    logger.info('Loading the training data')
    train = SelfDataset(args.train_path, args.max_length)
    trainloader = DataLoader(train, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn)
    if args.do_train:
        logger.info('Collecting the parameters to optimize')
        num_train_steps = int(len(trainloader) / args.train_batch_size) * args.num_train_epochs
        optim_parameters = list(bert_model_S.named_parameters())
        all_trainable_params = divide_parameters(optim_parameters, lr=args.learning_rate)
        logger.info("Length of all_trainable_params: %d", len(all_trainable_params))
        optimizer = BERTAdam(all_trainable_params, lr=args.learning_rate,
                             warmup=args.warmup_proportion, t_total=num_train_steps,
                             schedule=args.schedule,
                             s_opt1=args.s_opt1, s_opt2=args.s_opt2, s_opt3=args.s_opt3)
        train_config = TrainingConfig(
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            ckpt_frequency=args.ckpt_frequency,
            log_dir=args.output_dir,
            output_dir=args.output_dir,
            device=args.device)
        from generative.matches import matches
        intermediate_matches = None
        if isinstance(args.matches, (list, tuple)):
            intermediate_matches = []
            for match in args.matches:
                intermediate_matches += matches[match]
        logger.info(f"{intermediate_matches}")
        distill_config = DistillationConfig(
            temperature=args.temperature,
            intermediate_matches=intermediate_matches)

        def BertForS2SSimpleAdaptor(batch, model_outputs):
            return {'hidden': model_outputs[0], 'logits': model_outputs[1],
                    'loss': model_outputs[2], 'attention': model_outputs[3]}

        adaptor_T = partial(BertForS2SSimpleAdaptor)
        adaptor_S = partial(BertForS2SSimpleAdaptor)
        distiller = GeneralDistiller(train_config=train_config, distill_config=distill_config,
                                     model_T=bert_model_T, model_S=bert_model_S,
                                     adaptor_T=adaptor_T, adaptor_S=adaptor_S)
        callback_func = partial(predict, data_path=args.dev_path, args=args)
        logger.info('Start distillation.')
        with distiller:
            distiller.train(optimizer, scheduler=None, dataloader=trainloader,
                            num_epochs=args.num_train_epochs, callback=None)
    if not args.do_train and args.do_predict:
        res = predict(bert_model_S, args.test_path, step=0, args=args)
        print(res)
def main():
    # parse arguments
    config.parse()
    args = config.args
    for k, v in vars(args).items():
        logger.info(f"{k}:{v}")
    # set seeds
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed_all(args.random_seed)
    np.random.seed(args.random_seed)
    random.seed(args.random_seed)
    # arguments check
    device, n_gpu = args_check(args)
    os.makedirs(args.output_dir, exist_ok=True)
    forward_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps)
    args.forward_batch_size = forward_batch_size
    # load bert config
    bert_config_T = BertConfig.from_json_file(args.bert_config_file_T)
    bert_config_S = BertConfig.from_json_file(args.bert_config_file_S)
    assert args.max_seq_length <= bert_config_T.max_position_embeddings
    assert args.max_seq_length <= bert_config_S.max_position_embeddings
    # Prepare GLUE task
    processor = processors[args.task_name]()
    args.output_mode = output_modes[args.task_name]
    label_list = processor.get_labels()
    num_labels = len(label_list)
    # read data
    train_dataset = None
    eval_datasets = None
    num_train_steps = None
    tokenizer = BertTokenizer(vocab_file=args.vocab_file, do_lower_case=args.do_lower_case)
    if args.do_train:
        train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False)
        if args.aux_task_name:
            aux_train_dataset = load_and_cache_examples(args, args.aux_task_name, tokenizer,
                                                        evaluate=False, is_aux=True)
            train_dataset = torch.utils.data.ConcatDataset([train_dataset, aux_train_dataset])
        num_train_steps = int(len(train_dataset) / args.train_batch_size) * args.num_train_epochs
    if args.do_predict:
        eval_datasets = []
        eval_task_names = ("mnli", "mnli-mm") if args.task_name == "mnli" else (args.task_name,)
        for eval_task in eval_task_names:
            eval_datasets.append(load_and_cache_examples(args, eval_task, tokenizer, evaluate=True))
    logger.info("Data loaded")
    # Build Model and load checkpoint
    model_S = BertForGLUESimple(bert_config_S, num_labels=num_labels, args=args)
    # Load teacher
    if args.tuned_checkpoint_Ts:
        model_Ts = [BertForGLUESimple(bert_config_T, num_labels=num_labels, args=args)
                    for i in range(len(args.tuned_checkpoint_Ts))]
        for model_T, ckpt_T in zip(model_Ts, args.tuned_checkpoint_Ts):
            logger.info("Load state dict %s" % ckpt_T)
            state_dict_T = torch.load(ckpt_T, map_location='cpu')
            model_T.load_state_dict(state_dict_T)
            model_T.eval()
    else:
        assert args.do_predict is True
    # Load student
    if args.load_model_type == 'bert':
        assert args.init_checkpoint_S is not None
        state_dict_S = torch.load(args.init_checkpoint_S, map_location='cpu')
        state_weight = {k[5:]: v for k, v in state_dict_S.items() if k.startswith('bert.')}
        missing_keys, _ = model_S.bert.load_state_dict(state_weight, strict=False)
        assert len(missing_keys) == 0
    elif args.load_model_type == 'all':
        assert args.tuned_checkpoint_S is not None
        state_dict_S = torch.load(args.tuned_checkpoint_S, map_location='cpu')
        model_S.load_state_dict(state_dict_S)
    else:
        logger.info("Model is randomly initialized.")
    if args.do_train:
        for model_T in model_Ts:
            model_T.to(device)
    model_S.to(device)
    if args.local_rank != -1 or n_gpu > 1:
        if args.local_rank != -1:
            raise NotImplementedError
        elif n_gpu > 1:
            if args.do_train:
                model_Ts = [torch.nn.DataParallel(model_T) for model_T in model_Ts]
            model_S = torch.nn.DataParallel(model_S)  # ,output_device=n_gpu-1)
    if args.do_train:
        # parameters
        params = list(model_S.named_parameters())
        all_trainable_params = divide_parameters(params, lr=args.learning_rate)
        logger.info("Length of all_trainable_params: %d", len(all_trainable_params))
        optimizer = BERTAdam(all_trainable_params, lr=args.learning_rate,
                             warmup=args.warmup_proportion, t_total=num_train_steps,
                             schedule=args.schedule,
                             s_opt1=args.s_opt1, s_opt2=args.s_opt2, s_opt3=args.s_opt3)
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_dataset))
        logger.info("  Forward batch size = %d", forward_batch_size)
        logger.info("  Num backward steps = %d", num_train_steps)
        ########### DISTILLATION ###########
        train_config = TrainingConfig(
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            ckpt_frequency=args.ckpt_frequency,
            log_dir=args.output_dir,
            output_dir=args.output_dir,
            device=args.device)
        distill_config = DistillationConfig(temperature=args.temperature, kd_loss_type='ce')
        logger.info(f"{train_config}")
        logger.info(f"{distill_config}")
        adaptor = partial(BertForGLUESimpleAdaptor, no_logits=False, no_mask=False)
        distiller = MultiTeacherDistiller(train_config=train_config, distill_config=distill_config,
                                          model_T=model_Ts, model_S=model_S,
                                          adaptor_T=adaptor, adaptor_S=adaptor)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset)
        else:
            raise NotImplementedError
        train_dataloader = DataLoader(train_dataset, sampler=train_sampler,
                                      batch_size=args.forward_batch_size, drop_last=True)
        callback_func = partial(predict, eval_datasets=eval_datasets, args=args)
        with distiller:
            distiller.train(optimizer, scheduler=None, dataloader=train_dataloader,
                            num_epochs=args.num_train_epochs, callback=callback_func)
    if not args.do_train and args.do_predict:
        res = predict(model_S, eval_datasets, step=0, args=args)
        print(res)
def main():
    # parse arguments
    config.parse()
    args = config.args
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
        datefmt='%Y/%m/%d %H:%M:%S',
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger = logging.getLogger("Main")
    # arguments check
    device, n_gpu = args_check(logger, args)
    if args.local_rank in [-1, 0]:
        os.makedirs(args.output_dir, exist_ok=True)
    if args.local_rank != -1:
        logger.warning(
            f"Process rank: {torch.distributed.get_rank()}, device : {args.device}, "
            f"n_gpu : {args.n_gpu}, distributed training : {bool(args.local_rank!=-1)}")
    for k, v in vars(args).items():
        logger.info(f"{k}:{v}")
    # set seeds
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed_all(args.random_seed)
    np.random.seed(args.random_seed)
    random.seed(args.random_seed)
    forward_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps)
    args.forward_batch_size = forward_batch_size
    # Load ElectraConfig (from the transformers package)
    bert_config_S = ElectraConfig.from_json_file(args.bert_config_file_S)
    # (args.output_encoded_layers == 'true') --> True: output hidden states by default
    bert_config_S.output_hidden_states = (args.output_encoded_layers == 'true')
    # num_labels: number of classes
    bert_config_S.num_labels = len(label2id_dict)
    assert args.max_seq_length <= bert_config_S.max_position_embeddings
    # read data
    train_examples = None
    train_dataset = None
    eval_examples = None
    eval_dataset = None
    num_train_steps = None
    # Load the tokenizer
    tokenizer = BertTokenizer(vocab_file=args.vocab_file, do_lower_case=args.do_lower_case)
    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier()
    if args.do_train:
        # Returns all examples plus the dataset built from them;
        # the dataset format is [all_token_ids, all_input_mask, all_label_ids]
        print("Loading the training data")
        train_examples, train_dataset = read_features(args.train_file, tokenizer=tokenizer,
                                                      max_seq_length=args.max_seq_length)
    if args.do_predict:
        print("Loading the evaluation data")
        eval_examples, eval_dataset = read_features(args.predict_file, tokenizer=tokenizer,
                                                    max_seq_length=args.max_seq_length)
    if args.local_rank == 0:
        torch.distributed.barrier()
    # Build Model and load checkpoint
    model_S = ElectraForTokenClassification(bert_config_S)
    # Load the student model parameters; the default type is 'bert'
    if args.load_model_type == 'bert':
        assert args.init_checkpoint_S is not None
        state_dict_S = torch.load(args.init_checkpoint_S, map_location='cpu')
        # state_weight = {k[5:]:v for k,v in state_dict_S.items() if k.startswith('bert.')}
        # missing_keys,_ = model_S.bert.load_state_dict(state_weight,strict=False)
        missing_keys, unexpected_keys = model_S.load_state_dict(state_dict_S, strict=False)
        logger.info(f"missing keys:{missing_keys}")
        logger.info(f"unexpected keys:{unexpected_keys}")
    elif args.load_model_type == 'all':
        assert args.tuned_checkpoint_S is not None
        state_dict_S = torch.load(args.tuned_checkpoint_S, map_location='cpu')
        model_S.load_state_dict(state_dict_S)
    else:
        logger.info("Model is randomly initialized.")
    # Move the model to the device
    model_S.to(device)
    if args.do_train:
        # parameters
        if args.lr_decay is not None:
            # Classifier-layer parameters (weight, bias)
            outputs_params = list(model_S.classifier.named_parameters())
            # Split into decay and no-decay parameter groups
            outputs_params = divide_parameters(outputs_params, lr=args.learning_rate)
            electra_params = []
            # e.g. 12: the encoder has 12 layers
            n_layers = len(model_S.electra.encoder.layer)
            assert n_layers == 12
            for i, n in enumerate(reversed(range(n_layers))):
                encoder_params = list(model_S.electra.encoder.layer[n].named_parameters())
                lr = args.learning_rate * args.lr_decay**(i + 1)
                electra_params += divide_parameters(encoder_params, lr=lr)
                logger.info(f"{i}: learning rate of layer {n} is {lr}")
            embed_params = [(name, value) for name, value in model_S.electra.named_parameters()
                            if 'embedding' in name]
            logger.info(f"{[name for name,value in embed_params]}")
            lr = args.learning_rate * args.lr_decay**(n_layers + 1)
            electra_params += divide_parameters(embed_params, lr=lr)
            logger.info(f"Embedding-layer learning rate: {lr}")
            all_trainable_params = outputs_params + electra_params
            assert sum(map(lambda x: len(x['params']), all_trainable_params)) == len(list(model_S.parameters())), \
                (sum(map(lambda x: len(x['params']), all_trainable_params)), len(list(model_S.parameters())))
        else:
            params = list(model_S.named_parameters())
            all_trainable_params = divide_parameters(params, lr=args.learning_rate)
        logger.info("Number of parameter groups in all_trainable_params: %d", len(all_trainable_params))
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset)
        else:
            train_sampler = DistributedSampler(train_dataset)
        # Build the dataloader
        train_dataloader = DataLoader(train_dataset, sampler=train_sampler,
                                      batch_size=args.forward_batch_size, drop_last=True)
        # Compute the number of training steps from the number of epochs
        num_train_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
        optimizer = AdamW(all_trainable_params, lr=args.learning_rate, correct_bias=False)
        if args.official_schedule == 'const':
            # Constant learning rate (with warmup)
            scheduler_class = get_constant_schedule_with_warmup
            scheduler_args = {'num_warmup_steps': int(args.warmup_proportion * num_train_steps)}
            # scheduler = get_constant_schedule_with_warmup(optimizer, num_warmup_steps=int(args.warmup_proportion*num_train_steps))
        elif args.official_schedule == 'linear':
            # Linear learning rate schedule
            scheduler_class = get_linear_schedule_with_warmup
            # Warmup steps: warmup_proportion (e.g. 10%) of the total steps
            scheduler_args = {'num_warmup_steps': int(args.warmup_proportion * num_train_steps),
                              'num_training_steps': num_train_steps}
            # scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=int(args.warmup_proportion*num_train_steps), num_training_steps=num_train_steps)
        elif args.official_schedule == 'const_nowarmup':
            scheduler_class = get_constant_schedule
            scheduler_args = {}
        else:
            raise NotImplementedError
        logger.warning("***** Running training *****")
        logger.warning("local_rank %d Num orig examples = %d", args.local_rank, len(train_examples))
        logger.warning("local_rank %d Num split examples = %d", args.local_rank, len(train_dataset))
        logger.warning("local_rank %d Forward batch size = %d", args.local_rank, forward_batch_size)
        logger.warning("local_rank %d Num training steps = %d", args.local_rank, num_train_steps)
        ########### TRAINING ###########
        train_config = TrainingConfig(
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            ckpt_frequency=args.ckpt_frequency,  # checkpoint-saving frequency
            log_dir=args.output_dir,
            output_dir=args.output_dir,
            device=args.device,
            fp16=args.fp16,
            local_rank=args.local_rank)
        logger.info("Training configuration:")
        logger.info(f"{train_config}")
        # Initialize the trainer
        distiller = BasicTrainer(train_config=train_config,
                                 model=model_S,
                                 adaptor=ElectraForTokenClassificationAdaptorTraining)
        # Initialize the callback; ddp_predict is used for evaluation
        callback_func = partial(ddp_predict, eval_examples=eval_examples,
                                eval_dataset=eval_dataset, args=args)
        with distiller:
            distiller.train(optimizer, scheduler_class=scheduler_class, scheduler_args=scheduler_args,
                            max_grad_norm=1.0, dataloader=train_dataloader,
                            num_epochs=args.num_train_epochs, callback=callback_func)
    if not args.do_train and args.do_predict:
        res = ddp_predict(model_S, eval_examples, eval_dataset, step=0, args=args)
        print(res)