def main():
    """Distill a student BERT from one fine-tuned teacher on a GLUE task.

    Steps, in order:

    1. Parse the command line through ``config`` and log every option.
    2. Seed ``torch``, ``numpy`` and ``random`` with ``args.random_seed``.
    3. Load the teacher/student BERT configs and check ``max_seq_length``
       against both.
    4. Load the train and/or eval datasets, depending on ``args.do_train``
       and ``args.do_predict``.
    5. Build both models, load the teacher checkpoint and initialise the
       student according to ``args.load_model_type``.
    6. If ``args.do_train``: run ``GeneralDistiller.train`` with ``predict``
       as the callback.
    7. If only ``args.do_predict``: evaluate the student once and print the
       result.

    Raises:
        AssertionError: if a sanity check on the arguments fails (sequence
            length, missing student checkpoint, missing keys, or no teacher
            checkpoint while ``do_predict`` is not set).
        ValueError: if ``do_train`` is set but ``tuned_checkpoint_T`` is
            ``None``; distilling from a randomly initialised teacher is
            never meaningful.
        NotImplementedError: if ``args.local_rank != -1`` (distributed
            training is not supported here).
    """
    # Parse arguments and log them.
    config.parse()
    args = config.args
    for k, v in vars(args).items():
        logger.info(f"{k}:{v}")

    # Set seeds.
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed_all(args.random_seed)
    np.random.seed(args.random_seed)
    random.seed(args.random_seed)

    # Arguments check.
    device, n_gpu = args_check(args)
    os.makedirs(args.output_dir, exist_ok=True)
    # Per-forward-pass batch size; gradients are accumulated over
    # `gradient_accumulation_steps` forward passes.
    forward_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps)
    args.forward_batch_size = forward_batch_size

    # Load BERT configs for teacher (T) and student (S).
    bert_config_T = BertConfig.from_json_file(args.bert_config_file_T)
    bert_config_S = BertConfig.from_json_file(args.bert_config_file_S)
    assert args.max_seq_length <= bert_config_T.max_position_embeddings
    assert args.max_seq_length <= bert_config_S.max_position_embeddings

    # Prepare the GLUE task.
    processor = processors[args.task_name]()
    args.output_mode = output_modes[args.task_name]
    label_list = processor.get_labels()
    num_labels = len(label_list)

    # Read data.
    train_dataset = None
    eval_datasets = None
    num_train_steps = None
    tokenizer = BertTokenizer(vocab_file=args.vocab_file,
                              do_lower_case=args.do_lower_case)
    # Load the datasets.
    # NOTE(review): `examples` is overwritten by every load call below, so
    # the value handed to `predict` is the one from the last call made --
    # confirm this is intended.
    if args.do_train:
        train_dataset, examples = load_and_cache_examples(
            args, args.task_name, tokenizer, evaluate=False)
        if args.aux_task_name:
            aux_train_dataset, examples = load_and_cache_examples(
                args, args.aux_task_name, tokenizer, evaluate=False, is_aux=True)
            train_dataset = torch.utils.data.ConcatDataset(
                [train_dataset, aux_train_dataset])
        num_train_steps = int(
            len(train_dataset) / args.train_batch_size) * args.num_train_epochs
    if args.do_predict:
        eval_datasets = []
        # MNLI is evaluated on two sets ("mnli" and "mnli-mm").
        eval_task_names = ("mnli", "mnli-mm") if args.task_name == "mnli" else (
            args.task_name, )
        for eval_task in eval_task_names:
            eval_dataset, examples = load_and_cache_examples(
                args, eval_task, tokenizer, evaluate=True)
            eval_datasets.append(eval_dataset)
    logger.info("数据集加载成功")

    # Build the teacher and student models.
    model_T = BertForGLUESimple(bert_config_T, num_labels=num_labels, args=args)
    model_S = BertForGLUESimple(bert_config_S, num_labels=num_labels, args=args)
    # Load the teacher weights.
    if args.tuned_checkpoint_T is not None:
        state_dict_T = torch.load(args.tuned_checkpoint_T, map_location='cpu')
        model_T.load_state_dict(state_dict_T)
        model_T.eval()
    else:
        # Without teacher weights only stand-alone evaluation makes sense.
        # The assert is kept so the previously failing case fails the same way.
        assert args.do_predict is True
        if args.do_train:
            raise ValueError(
                "do_train requires tuned_checkpoint_T: refusing to distill "
                "from a randomly initialized teacher.")
    # Load the student.
    if args.load_model_type == 'bert':
        assert args.init_checkpoint_S is not None
        state_dict_S = torch.load(args.init_checkpoint_S, map_location='cpu')
        if args.only_load_embedding:
            # Keep only the embedding weights; k[5:] drops the 'bert.' prefix.
            state_weight = {
                k[5:]: v
                for k, v in state_dict_S.items()
                if k.startswith('bert.embeddings')
            }
            missing_keys, _ = model_S.bert.load_state_dict(state_weight,
                                                           strict=False)
            logger.info(f"Missing keys {list(missing_keys)}")
        else:
            # Load the whole encoder; k[5:] drops the 'bert.' prefix.
            state_weight = {
                k[5:]: v
                for k, v in state_dict_S.items() if k.startswith('bert.')
            }
            missing_keys, _ = model_S.bert.load_state_dict(state_weight,
                                                           strict=False)
            assert len(missing_keys) == 0
        logger.info("Model loaded")
    elif args.load_model_type == 'all':
        assert args.tuned_checkpoint_S is not None
        state_dict_S = torch.load(args.tuned_checkpoint_S, map_location='cpu')
        model_S.load_state_dict(state_dict_S)
        logger.info("Model loaded")
    else:
        logger.info("Student模型没有可加载参数,随机初始化参数 randomly initialized.")
    model_T.to(device)
    model_S.to(device)

    if args.local_rank != -1 or n_gpu > 1:
        if args.local_rank != -1:
            raise NotImplementedError
        elif n_gpu > 1:
            model_T = torch.nn.DataParallel(model_T)
            model_S = torch.nn.DataParallel(model_S)

    if args.do_train:
        # Only the student's parameters are optimised.
        params = list(model_S.named_parameters())
        all_trainable_params = divide_parameters(params, lr=args.learning_rate)
        logger.info("Length of all_trainable_params: %d",
                    len(all_trainable_params))
        # Optimizer configuration.
        optimizer = BERTAdam(all_trainable_params,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=num_train_steps,
                             schedule=args.schedule,
                             s_opt1=args.s_opt1,
                             s_opt2=args.s_opt2,
                             s_opt3=args.s_opt3)

        logger.info("***** Running training *****")
        logger.info(" Num examples = %d", len(train_dataset))
        logger.info(" Forward batch size = %d", forward_batch_size)
        logger.info(" Num backward steps = %d", num_train_steps)

        ########### DISTILLATION ###########
        train_config = TrainingConfig(
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            ckpt_frequency=args.ckpt_frequency,
            log_dir=args.output_dir,
            output_dir=args.output_dir,
            device=args.device)
        # `matches` holds predefined intermediate-layer match configurations.
        from matches import matches
        intermediate_matches = None
        if isinstance(args.matches, (list, tuple)):
            intermediate_matches = []
            for match in args.matches:
                intermediate_matches += matches[match]
        logger.info(f"中间层match信息: {intermediate_matches}")
        distill_config = DistillationConfig(
            temperature=args.temperature,
            intermediate_matches=intermediate_matches)

        logger.info(f"训练配置: {train_config}")
        logger.info(f"蒸馏配置: {distill_config}")

        adaptor_T = partial(BertForGLUESimpleAdaptor,
                            no_logits=args.no_logits,
                            no_mask=args.no_inputs_mask)
        adaptor_S = partial(BertForGLUESimpleAdaptor,
                            no_logits=args.no_logits,
                            no_mask=args.no_inputs_mask)
        # General-purpose distiller, given the intermediate matches above.
        distiller = GeneralDistiller(train_config=train_config,
                                     distill_config=distill_config,
                                     model_T=model_T,
                                     model_S=model_S,
                                     adaptor_T=adaptor_T,
                                     adaptor_S=adaptor_S)

        if args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset)
        else:
            raise NotImplementedError
        train_dataloader = DataLoader(train_dataset,
                                      sampler=train_sampler,
                                      batch_size=args.forward_batch_size,
                                      drop_last=True)
        # Passed to the distiller as its `callback`.
        callback_func = partial(predict,
                                eval_datasets=eval_datasets,
                                args=args,
                                examples=examples)
        with distiller:
            distiller.train(optimizer,
                            scheduler=None,
                            dataloader=train_dataloader,
                            num_epochs=args.num_train_epochs,
                            callback=callback_func)

    if not args.do_train and args.do_predict:
        res = predict(model_S,
                      eval_datasets,
                      step=0,
                      args=args,
                      examples=examples,
                      label_list=label_list)
        print(res)
def main():
    """Train and/or evaluate a single BERT model on a GLUE task.

    This entry point performs plain supervised training with
    ``BasicTrainer`` rather than distillation; the model is still called the
    "student" (``model_S``), and the result can serve as a teacher later.

    Behaviour is driven by the parsed arguments:

    * ``args.do_train``: load the training data (plus the optional
      auxiliary task), build the optimizer and run ``BasicTrainer.train``
      with ``predict`` as the callback.
    * ``args.do_predict`` without ``args.do_train``: call ``predict`` once
      on the loaded model and print the result.

    Raises:
        AssertionError: if an argument sanity check fails.
        NotImplementedError: if ``args.local_rank != -1``.
    """
    # Parse the command line and dump every option to the log.
    config.parse()
    args = config.args
    for name, value in vars(args).items():
        logger.info(f"{name}:{value}")

    # Seed every random number generator in use.
    seed = args.random_seed
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)

    # Validate arguments and prepare the output directory.
    device, n_gpu = args_check(args)
    os.makedirs(args.output_dir, exist_ok=True)
    # Batch size of one forward pass under gradient accumulation.
    args.forward_batch_size = forward_batch_size = int(
        args.train_batch_size / args.gradient_accumulation_steps)

    # Model configuration.
    bert_config_S = BertConfig.from_json_file(args.bert_config_file_S)
    assert args.max_seq_length <= bert_config_S.max_position_embeddings

    # GLUE task set-up.
    processor = processors[args.task_name]()
    args.output_mode = output_modes[args.task_name]
    # e.g. for MNLI: ['contradiction', 'entailment', 'neutral']
    label_list = processor.get_labels()
    num_labels = len(label_list)

    # Load the datasets and compute the number of training steps.
    train_dataset = eval_datasets = num_train_steps = None
    tokenizer = BertTokenizer(vocab_file=args.vocab_file,
                              do_lower_case=args.do_lower_case)
    if args.do_train:
        train_dataset = load_and_cache_examples(
            args, args.task_name, tokenizer, evaluate=False)
        if args.aux_task_name:
            aux_dataset = load_and_cache_examples(
                args, args.aux_task_name, tokenizer,
                evaluate=False, is_aux=True)
            train_dataset = torch.utils.data.ConcatDataset(
                [train_dataset, aux_dataset])
        steps_per_epoch = int(len(train_dataset) / args.train_batch_size)
        num_train_steps = steps_per_epoch * args.num_train_epochs
    if args.do_predict:
        # MNLI is evaluated on two sets ("mnli" and "mnli-mm").
        if args.task_name == "mnli":
            eval_task_names = ("mnli", "mnli-mm")
        else:
            eval_task_names = (args.task_name, )
        eval_datasets = [
            load_and_cache_examples(args, task, tokenizer, evaluate=True)
            for task in eval_task_names
        ]
    logger.info("数据集已加载")

    # Build and initialise the model. Only the "student" model exists here.
    model_S = BertForGLUESimple(bert_config_S, num_labels=num_labels, args=args)
    load_type = args.load_model_type
    if load_type == 'bert':
        assert args.init_checkpoint_S is not None
        checkpoint = torch.load(args.init_checkpoint_S, map_location='cpu')
        embeddings_only = args.only_load_embedding
        # Take either just the embeddings or the whole 'bert.' sub-module.
        wanted_prefix = 'bert.embeddings' if embeddings_only else 'bert.'
        # key[5:] strips the leading 'bert.' from each parameter name.
        encoder_weights = {
            key[5:]: tensor
            for key, tensor in checkpoint.items()
            if key.startswith(wanted_prefix)
        }
        missing_keys, _ = model_S.bert.load_state_dict(encoder_weights,
                                                       strict=False)
        if embeddings_only:
            logger.info(f"Missing keys {list(missing_keys)}")
        else:
            assert len(missing_keys) == 0
        logger.info("Model loaded")
    elif load_type == 'all':
        assert args.tuned_checkpoint_S is not None
        checkpoint = torch.load(args.tuned_checkpoint_S, map_location='cpu')
        model_S.load_state_dict(checkpoint)
        logger.info("Model loaded")
    else:
        logger.info("Model is randomly initialized.")
    model_S.to(device)

    # Distributed runs are unsupported; multiple GPUs use DataParallel.
    if args.local_rank != -1:
        raise NotImplementedError
    if n_gpu > 1:
        model_S = torch.nn.DataParallel(model_S)

    if args.do_train:
        # Parameter groups and optimizer.
        named_params = list(model_S.named_parameters())
        param_groups = divide_parameters(named_params, lr=args.learning_rate)
        logger.info("Length of all_trainable_params: %d", len(param_groups))
        optimizer = BERTAdam(param_groups,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=num_train_steps,
                             schedule=args.schedule,
                             s_opt1=args.s_opt1,
                             s_opt2=args.s_opt2,
                             s_opt3=args.s_opt3)

        logger.info("***** Running training *****")
        logger.info(" Num examples = %d", len(train_dataset))
        logger.info(" Forward batch size = %d", forward_batch_size)
        logger.info(" Num backward steps = %d", num_train_steps)

        # Training configuration.
        train_config = TrainingConfig(
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            ckpt_frequency=args.ckpt_frequency,
            log_dir=args.output_dir,
            output_dir=args.output_dir,
            device=args.device)
        # Supervised training instead of distillation.
        trainer = BasicTrainer(train_config=train_config,
                               model=model_S,
                               adaptor=BertForGLUESimpleAdaptorTraining)

        if args.local_rank != -1:
            raise NotImplementedError
        sampler = RandomSampler(train_dataset)
        loader = DataLoader(train_dataset,
                            sampler=sampler,
                            batch_size=args.forward_batch_size,
                            drop_last=True)
        # Passed to the trainer as its `callback`.
        evaluate_callback = partial(predict,
                                    eval_datasets=eval_datasets,
                                    args=args)
        with trainer:
            trainer.train(optimizer,
                          scheduler=None,
                          dataloader=loader,
                          num_epochs=args.num_train_epochs,
                          callback=evaluate_callback)

    # Evaluation-only mode.
    if not args.do_train and args.do_predict:
        print(predict(model_S, eval_datasets, step=0, args=args))
def main():
    """Distill a student BERT from several fine-tuned teachers on a GLUE task.

    Steps, in order:

    1. Parse the command line through ``config`` and log every option.
    2. Seed ``torch``, ``numpy`` and ``random`` with ``args.random_seed``.
    3. Load the teacher/student BERT configs and check ``max_seq_length``
       against both.
    4. Load the train and/or eval datasets, depending on ``args.do_train``
       and ``args.do_predict``.
    5. Build the student and one teacher per entry of
       ``args.tuned_checkpoint_Ts``, loading each teacher checkpoint.
    6. If ``args.do_train``: run ``MultiTeacherDistiller.train`` with
       ``predict`` as the callback.
    7. If only ``args.do_predict``: evaluate the student once and print the
       result.

    Raises:
        AssertionError: if a sanity check on the arguments fails (sequence
            length, missing student checkpoint, missing keys, or no teacher
            checkpoints while ``do_predict`` is not set).
        ValueError: if ``do_train`` is set but no teacher checkpoints were
            given (this used to surface as a ``NameError`` on ``model_Ts``).
        NotImplementedError: if ``args.local_rank != -1`` (distributed
            training is not supported here).
    """
    # Parse arguments and log them.
    config.parse()
    args = config.args
    for k, v in vars(args).items():
        logger.info(f"{k}:{v}")

    # Set seeds.
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed_all(args.random_seed)
    np.random.seed(args.random_seed)
    random.seed(args.random_seed)

    # Arguments check.
    device, n_gpu = args_check(args)
    os.makedirs(args.output_dir, exist_ok=True)
    # Per-forward-pass batch size; gradients are accumulated over
    # `gradient_accumulation_steps` forward passes.
    forward_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps)
    args.forward_batch_size = forward_batch_size

    # Load BERT configs for teachers (T) and student (S).
    bert_config_T = BertConfig.from_json_file(args.bert_config_file_T)
    bert_config_S = BertConfig.from_json_file(args.bert_config_file_S)
    assert args.max_seq_length <= bert_config_T.max_position_embeddings
    assert args.max_seq_length <= bert_config_S.max_position_embeddings

    # Prepare the GLUE task.
    processor = processors[args.task_name]()
    args.output_mode = output_modes[args.task_name]
    label_list = processor.get_labels()
    num_labels = len(label_list)

    # Read data.
    train_dataset = None
    eval_datasets = None
    num_train_steps = None
    tokenizer = BertTokenizer(vocab_file=args.vocab_file,
                              do_lower_case=args.do_lower_case)
    if args.do_train:
        train_dataset = load_and_cache_examples(
            args, args.task_name, tokenizer, evaluate=False)
        if args.aux_task_name:
            aux_train_dataset = load_and_cache_examples(
                args, args.aux_task_name, tokenizer, evaluate=False, is_aux=True)
            train_dataset = torch.utils.data.ConcatDataset(
                [train_dataset, aux_train_dataset])
        num_train_steps = int(
            len(train_dataset) / args.train_batch_size) * args.num_train_epochs
    if args.do_predict:
        eval_datasets = []
        # MNLI is evaluated on two sets ("mnli" and "mnli-mm").
        eval_task_names = ("mnli", "mnli-mm") if args.task_name == "mnli" else (
            args.task_name, )
        for eval_task in eval_task_names:
            eval_datasets.append(
                load_and_cache_examples(args, eval_task, tokenizer, evaluate=True))
    logger.info("Data loaded")

    # Build the student model.
    model_S = BertForGLUESimple(bert_config_S, num_labels=num_labels, args=args)
    # Load the teachers. `model_Ts` is always bound so the later code can
    # iterate it safely even when no checkpoints were supplied.
    model_Ts = []
    if args.tuned_checkpoint_Ts:
        model_Ts = [
            BertForGLUESimple(bert_config_T, num_labels=num_labels, args=args)
            for _ in args.tuned_checkpoint_Ts
        ]
        for model_T, ckpt_T in zip(model_Ts, args.tuned_checkpoint_Ts):
            logger.info("Load state dict %s", ckpt_T)
            state_dict_T = torch.load(ckpt_T, map_location='cpu')
            model_T.load_state_dict(state_dict_T)
            model_T.eval()
    else:
        # Without teacher weights only stand-alone evaluation makes sense.
        # The assert is kept so the previously failing case fails the same way.
        assert args.do_predict is True
        if args.do_train:
            raise ValueError(
                "do_train requires at least one checkpoint in "
                "tuned_checkpoint_Ts: there is no teacher to distill from.")
    # Load the student.
    if args.load_model_type == 'bert':
        assert args.init_checkpoint_S is not None
        state_dict_S = torch.load(args.init_checkpoint_S, map_location='cpu')
        # Keep the encoder weights only; k[5:] drops the 'bert.' prefix.
        state_weight = {
            k[5:]: v for k, v in state_dict_S.items() if k.startswith('bert.')
        }
        missing_keys, _ = model_S.bert.load_state_dict(state_weight, strict=False)
        assert len(missing_keys) == 0
    elif args.load_model_type == 'all':
        assert args.tuned_checkpoint_S is not None
        state_dict_S = torch.load(args.tuned_checkpoint_S, map_location='cpu')
        model_S.load_state_dict(state_dict_S)
    else:
        logger.info("Model is randomly initialized.")
    # Teachers are moved to the device only when training.
    if args.do_train:
        for model_T in model_Ts:
            model_T.to(device)
    model_S.to(device)

    if args.local_rank != -1 or n_gpu > 1:
        if args.local_rank != -1:
            raise NotImplementedError
        elif n_gpu > 1:
            if args.do_train:
                model_Ts = [torch.nn.DataParallel(model_T) for model_T in model_Ts]
            model_S = torch.nn.DataParallel(model_S)

    if args.do_train:
        # Only the student's parameters are optimised.
        params = list(model_S.named_parameters())
        all_trainable_params = divide_parameters(params, lr=args.learning_rate)
        logger.info("Length of all_trainable_params: %d",
                    len(all_trainable_params))

        optimizer = BERTAdam(all_trainable_params,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=num_train_steps,
                             schedule=args.schedule,
                             s_opt1=args.s_opt1,
                             s_opt2=args.s_opt2,
                             s_opt3=args.s_opt3)

        logger.info("***** Running training *****")
        logger.info(" Num examples = %d", len(train_dataset))
        logger.info(" Forward batch size = %d", forward_batch_size)
        logger.info(" Num backward steps = %d", num_train_steps)

        ########### DISTILLATION ###########
        train_config = TrainingConfig(
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            ckpt_frequency=args.ckpt_frequency,
            log_dir=args.output_dir,
            output_dir=args.output_dir,
            device=args.device)

        distill_config = DistillationConfig(
            temperature=args.temperature,
            kd_loss_type='ce')

        logger.info(f"{train_config}")
        logger.info(f"{distill_config}")
        # The same adaptor is used for the teachers and the student.
        adaptor = partial(BertForGLUESimpleAdaptor, no_logits=False, no_mask=False)

        distiller = MultiTeacherDistiller(train_config=train_config,
                                          distill_config=distill_config,
                                          model_T=model_Ts,
                                          model_S=model_S,
                                          adaptor_T=adaptor,
                                          adaptor_S=adaptor)

        if args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset)
        else:
            raise NotImplementedError
        train_dataloader = DataLoader(train_dataset,
                                      sampler=train_sampler,
                                      batch_size=args.forward_batch_size,
                                      drop_last=True)
        # Passed to the distiller as its `callback`.
        callback_func = partial(predict, eval_datasets=eval_datasets, args=args)
        with distiller:
            distiller.train(optimizer,
                            scheduler=None,
                            dataloader=train_dataloader,
                            num_epochs=args.num_train_epochs,
                            callback=callback_func)

    if not args.do_train and args.do_predict:
        res = predict(model_S, eval_datasets, step=0, args=args)
        print(res)