def main():
    #parse arguments
    config.parse()
    args = config.args
    for k, v in vars(args).items():
        logger.info(f"{k}:{v}")
    #set seeds
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed_all(args.random_seed)
    np.random.seed(args.random_seed)
    random.seed(args.random_seed)
    #arguments check
    device, n_gpu = args_check(args)
    os.makedirs(args.output_dir, exist_ok=True)
    forward_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps)
    args.forward_batch_size = forward_batch_size
    #load config
    teachers_and_student = parse_model_config(args.model_config_json)
    #Prepare GLUE task
    processor = processors[args.task_name]()
    args.output_mode = output_modes[args.task_name]
    label_list = processor.get_labels()
    num_labels = len(label_list)
    #read data
    train_dataset = None
    eval_datasets = None
    num_train_steps = None
    tokenizer_S = teachers_and_student['student']['tokenizer']
    prefix_S = teachers_and_student['student']['prefix']
    if args.do_train:
        train_dataset = load_and_cache_examples(args, args.task_name, tokenizer_S,
                                                prefix=prefix_S, evaluate=False)
    if args.do_predict:
        eval_datasets = []
        eval_task_names = ("mnli", "mnli-mm") if args.task_name == "mnli" else (args.task_name,)
        for eval_task in eval_task_names:
            eval_datasets.append(load_and_cache_examples(args, eval_task, tokenizer_S,
                                                         prefix=prefix_S, evaluate=True))
    logger.info("Data loaded")
    #Build Model and load checkpoint
    if args.do_train:
        model_Ts = []
        for teacher in teachers_and_student['teachers']:
            model_type_T = teacher['model_type']
            model_config_T = teacher['config']
            checkpoint_T = teacher['checkpoint']
            _, _, model_class_T = MODEL_CLASSES[model_type_T]
            model_T = model_class_T(model_config_T, num_labels=num_labels)
            state_dict_T = torch.load(checkpoint_T, map_location='cpu')
            missing_keys, un_keys = model_T.load_state_dict(state_dict_T, strict=True)
            logger.info(f"Teacher Model {model_type_T} loaded")
            model_T.to(device)
            model_T.eval()
            model_Ts.append(model_T)
    student = teachers_and_student['student']
    model_type_S = student['model_type']
    model_config_S = student['config']
    checkpoint_S = student['checkpoint']
    _, _, model_class_S = MODEL_CLASSES[model_type_S]
    model_S = model_class_S(model_config_S, num_labels=num_labels)
    if checkpoint_S is not None:
        state_dict_S = torch.load(checkpoint_S, map_location='cpu')
        missing_keys, un_keys = model_S.load_state_dict(state_dict_S, strict=False)
        logger.info(f"missing keys:{missing_keys}")
        logger.info(f"unexpected keys:{un_keys}")
    else:
        logger.warning("Initializing student randomly")
    logger.info("Student Model loaded")
    model_S.to(device)
    if args.local_rank != -1 or n_gpu > 1:
        if args.local_rank != -1:
            raise NotImplementedError
        elif n_gpu > 1:
            if args.do_train:
                model_Ts = [torch.nn.DataParallel(model_T) for model_T in model_Ts]
            model_S = torch.nn.DataParallel(model_S)  #,output_device=n_gpu-1)
    if args.do_train:
        #parameters
        params = list(model_S.named_parameters())
        all_trainable_params = divide_parameters(params, lr=args.learning_rate)
        logger.info("Length of all_trainable_params: %d", len(all_trainable_params))
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset)
        else:
            raise NotImplementedError
        train_dataloader = DataLoader(train_dataset, sampler=train_sampler,
                                      batch_size=args.forward_batch_size, drop_last=True)
        num_train_steps = int(len(train_dataloader) // args.gradient_accumulation_steps
                              * args.num_train_epochs)

        ########### DISTILLATION ###########
        train_config = TrainingConfig(
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            ckpt_frequency=args.ckpt_frequency,
            log_dir=args.output_dir,
            output_dir=args.output_dir,
            fp16=args.fp16,
            device=args.device)
        distill_config = DistillationConfig(temperature=args.temperature, kd_loss_type='ce')
        logger.info(f"{train_config}")
        logger.info(f"{distill_config}")
        adaptor_T = BertForGLUESimpleAdaptor
        adaptor_S = BertForGLUESimpleAdaptor
        distiller = MultiTeacherDistiller(train_config=train_config,
                                          distill_config=distill_config,
                                          model_T=model_Ts, model_S=model_S,
                                          adaptor_T=adaptor_T, adaptor_S=adaptor_S)
        optimizer = AdamW(all_trainable_params, lr=args.learning_rate)
        scheduler_class = get_linear_schedule_with_warmup
        scheduler_args = {'num_warmup_steps': int(args.warmup_proportion * num_train_steps),
                          'num_training_steps': num_train_steps}
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_dataset))
        logger.info("  Forward batch size = %d", forward_batch_size)
        logger.info("  Num backward steps = %d", num_train_steps)
        callback_func = partial(predict, eval_datasets=eval_datasets, args=args)
        with distiller:
            distiller.train(optimizer,
                            scheduler_class=scheduler_class,
                            scheduler_args=scheduler_args,
                            dataloader=train_dataloader,
                            num_epochs=args.num_train_epochs,
                            callback=callback_func,
                            max_grad_norm=1)
    if not args.do_train and args.do_predict:
        res = predict(model_S, eval_datasets, step=0, args=args)
        print(res)
def main():
    #parse arguments
    config.parse()
    args = config.args
    for k, v in vars(args).items():
        logger.info(f"{k}:{v}")
    #set seeds
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed_all(args.random_seed)
    np.random.seed(args.random_seed)
    random.seed(args.random_seed)
    #arguments check
    device, n_gpu = args_check(args)
    os.makedirs(args.output_dir, exist_ok=True)
    forward_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps)
    args.forward_batch_size = forward_batch_size
    #load bert config
    bert_config_T = BertConfig.from_json_file(args.bert_config_file_T)
    bert_config_S = BertConfig.from_json_file(args.bert_config_file_S)
    assert args.max_seq_length <= bert_config_T.max_position_embeddings
    assert args.max_seq_length <= bert_config_S.max_position_embeddings
    #Prepare GLUE task
    processor = processors[args.task_name]()
    args.output_mode = output_modes[args.task_name]
    label_list = processor.get_labels()
    num_labels = len(label_list)
    #read data
    train_dataset = None
    eval_datasets = None
    num_train_steps = None
    tokenizer = BertTokenizer(vocab_file=args.vocab_file, do_lower_case=args.do_lower_case)
    # load datasets
    if args.do_train:
        train_dataset, examples = load_and_cache_examples(args, args.task_name, tokenizer,
                                                          evaluate=False)
        if args.aux_task_name:
            aux_train_dataset, examples = load_and_cache_examples(args, args.aux_task_name, tokenizer,
                                                                  evaluate=False, is_aux=True)
            train_dataset = torch.utils.data.ConcatDataset([train_dataset, aux_train_dataset])
        num_train_steps = int(len(train_dataset) / args.train_batch_size) * args.num_train_epochs
    if args.do_predict:
        eval_datasets = []
        eval_task_names = ("mnli", "mnli-mm") if args.task_name == "mnli" else (args.task_name,)
        for eval_task in eval_task_names:
            eval_dataset, examples = load_and_cache_examples(args, eval_task, tokenizer, evaluate=True)
            eval_datasets.append(eval_dataset)
    logger.info("Datasets loaded successfully")
    # build the teacher and student models
    model_T = BertForGLUESimple(bert_config_T, num_labels=num_labels, args=args)
    model_S = BertForGLUESimple(bert_config_S, num_labels=num_labels, args=args)
    # load teacher weights
    if args.tuned_checkpoint_T is not None:
        state_dict_T = torch.load(args.tuned_checkpoint_T, map_location='cpu')
        model_T.load_state_dict(state_dict_T)
        model_T.eval()
    else:
        assert args.do_predict is True
    #Load student
    if args.load_model_type == 'bert':
        assert args.init_checkpoint_S is not None
        state_dict_S = torch.load(args.init_checkpoint_S, map_location='cpu')
        if args.only_load_embedding:
            state_weight = {k[5:]: v for k, v in state_dict_S.items()
                            if k.startswith('bert.embeddings')}
            missing_keys, _ = model_S.bert.load_state_dict(state_weight, strict=False)
            logger.info(f"Missing keys {list(missing_keys)}")
        else:
            state_weight = {k[5:]: v for k, v in state_dict_S.items() if k.startswith('bert.')}
            missing_keys, _ = model_S.bert.load_state_dict(state_weight, strict=False)
            assert len(missing_keys) == 0
        logger.info("Model loaded")
    elif args.load_model_type == 'all':
        assert args.tuned_checkpoint_S is not None
        state_dict_S = torch.load(args.tuned_checkpoint_S, map_location='cpu')
        model_S.load_state_dict(state_dict_S)
        logger.info("Model loaded")
    else:
        logger.info("No weights to load for the student model; it is randomly initialized.")
    model_T.to(device)
    model_S.to(device)
    if args.local_rank != -1 or n_gpu > 1:
        if args.local_rank != -1:
            raise NotImplementedError
        elif n_gpu > 1:
            model_T = torch.nn.DataParallel(model_T)  #,output_device=n_gpu-1)
            model_S = torch.nn.DataParallel(model_S)  #,output_device=n_gpu-1)
    if args.do_train:
        #parameters
        params = list(model_S.named_parameters())
        all_trainable_params = divide_parameters(params, lr=args.learning_rate)
        logger.info("Length of all_trainable_params: %d", len(all_trainable_params))
        # optimizer configuration
        optimizer = BERTAdam(all_trainable_params, lr=args.learning_rate,
                             warmup=args.warmup_proportion, t_total=num_train_steps,
                             schedule=args.schedule,
                             s_opt1=args.s_opt1, s_opt2=args.s_opt2, s_opt3=args.s_opt3)
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_dataset))
        logger.info("  Forward batch size = %d", forward_batch_size)
        logger.info("  Num backward steps = %d", num_train_steps)

        ########### DISTILLATION ###########
        train_config = TrainingConfig(
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            ckpt_frequency=args.ckpt_frequency,
            log_dir=args.output_dir,
            output_dir=args.output_dir,
            device=args.device)
        # predefined intermediate-layer matching configurations
        from matches import matches
        intermediate_matches = None
        if isinstance(args.matches, (list, tuple)):
            intermediate_matches = []
            for match in args.matches:
                intermediate_matches += matches[match]
        logger.info(f"Intermediate matches: {intermediate_matches}")
        distill_config = DistillationConfig(
            temperature=args.temperature,
            intermediate_matches=intermediate_matches)
        logger.info(f"Training config: {train_config}")
        logger.info(f"Distillation config: {distill_config}")
        adaptor_T = partial(BertForGLUESimpleAdaptor, no_logits=args.no_logits,
                            no_mask=args.no_inputs_mask)
        adaptor_S = partial(BertForGLUESimpleAdaptor, no_logits=args.no_logits,
                            no_mask=args.no_inputs_mask)
        # general distiller supporting intermediate-state matching
        distiller = GeneralDistiller(train_config=train_config,
                                     distill_config=distill_config,
                                     model_T=model_T, model_S=model_S,
                                     adaptor_T=adaptor_T, adaptor_S=adaptor_S)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset)
        else:
            raise NotImplementedError
        train_dataloader = DataLoader(train_dataset, sampler=train_sampler,
                                      batch_size=args.forward_batch_size, drop_last=True)
        callback_func = partial(predict, eval_datasets=eval_datasets, args=args, examples=examples)
        with distiller:
            distiller.train(optimizer, scheduler=None, dataloader=train_dataloader,
                            num_epochs=args.num_train_epochs, callback=callback_func)
    if not args.do_train and args.do_predict:
        res = predict(model_S, eval_datasets, step=0, args=args,
                      examples=examples, label_list=label_list)
        print(res)
def train(args, train_dataset, model_T, model, tokenizer, labels, pad_token_label_id, predict_callback):
    """ Train the model """
    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 \
        else DistributedSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler,
                                  batch_size=args.train_batch_size)
    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (len(train_dataloader)
                                                   // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {"params": [p for n, p in model.named_parameters()
                    if not any(nd in n for nd in no_decay)],
         "weight_decay": args.weight_decay},
        {"params": [p for n, p in model.named_parameters()
                    if any(nd in n for nd in no_decay)],
         "weight_decay": 0.0},
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(optimizer,
                                                num_warmup_steps=int(args.warmup_steps * t_total),
                                                num_training_steps=t_total)
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(model,
                                                          device_ids=[args.local_rank],
                                                          output_device=args.local_rank,
                                                          find_unused_parameters=True)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d",
                args.train_batch_size * args.gradient_accumulation_steps *
                (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    if args.do_train and args.do_distill:
        from textbrewer import DistillationConfig, TrainingConfig, GeneralDistiller
        distill_config = DistillationConfig(
            temperature=8,
            # intermediate_matches=[{'layer_T': 10, 'layer_S': 3, 'feature': 'hidden', 'loss': 'hidden_mse', 'weight': 1}]
        )
        train_config = TrainingConfig(device="cuda", log_dir=args.output_dir,
                                      output_dir=args.output_dir)

        def adaptor_T(batch, model_output):
            return {"logits": (model_output[1],),
                    'logits_mask': (batch['attention_mask'],)}

        def adaptor_S(batch, model_output):
            return {"logits": (model_output[1],),
                    'logits_mask': (batch['attention_mask'],)}

        distiller = GeneralDistiller(
            train_config,
            distill_config,
            model_T,
            model,
            adaptor_T,
            adaptor_S,
        )
        distiller.train(optimizer, scheduler, train_dataloader, args.num_train_epochs,
                        callback=predict_callback)
    return
def main():
    #parse arguments
    config.parse()
    args = config.args
    for k, v in vars(args).items():
        logger.info(f"{k}:{v}")
    #set seeds
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed_all(args.random_seed)
    np.random.seed(args.random_seed)
    random.seed(args.random_seed)
    #arguments check
    device, n_gpu = args_check(args)
    os.makedirs(args.output_dir, exist_ok=True)
    forward_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps)
    args.forward_batch_size = forward_batch_size
    #load bert config
    bert_config_T = BertConfig.from_json_file(args.bert_config_file_T)
    bert_config_S = BertConfig.from_json_file(args.bert_config_file_S)
    assert args.max_seq_length <= bert_config_T.max_position_embeddings
    assert args.max_seq_length <= bert_config_S.max_position_embeddings
    #read data
    train_examples = None
    train_features = None
    eval_examples = None
    eval_features = None
    num_train_steps = None
    tokenizer = ChineseFullTokenizer(vocab_file=args.vocab_file, do_lower_case=args.do_lower_case)
    convert_fn = partial(convert_examples_to_features,
                         tokenizer=tokenizer,
                         max_seq_length=args.max_seq_length,
                         doc_stride=args.doc_stride,
                         max_query_length=args.max_query_length)
    if args.do_train:
        train_examples, train_features = read_and_convert(args.train_file, is_training=True,
                                                          do_lower_case=args.do_lower_case,
                                                          read_fn=read_squad_examples,
                                                          convert_fn=convert_fn)
        if args.fake_file_1:
            fake_examples1, fake_features1 = read_and_convert(args.fake_file_1, is_training=True,
                                                              do_lower_case=args.do_lower_case,
                                                              read_fn=read_squad_examples,
                                                              convert_fn=convert_fn)
            train_examples += fake_examples1
            train_features += fake_features1
        if args.fake_file_2:
            fake_examples2, fake_features2 = read_and_convert(args.fake_file_2, is_training=True,
                                                              do_lower_case=args.do_lower_case,
                                                              read_fn=read_squad_examples,
                                                              convert_fn=convert_fn)
            train_examples += fake_examples2
            train_features += fake_features2
        num_train_steps = int(len(train_features) / args.train_batch_size) * args.num_train_epochs
    if args.do_predict:
        eval_examples, eval_features = read_and_convert(args.predict_file, is_training=False,
                                                        do_lower_case=args.do_lower_case,
                                                        read_fn=read_squad_examples,
                                                        convert_fn=convert_fn)
    #Build Model and load checkpoint
    model_T = BertForQASimple(bert_config_T, args)
    model_S = BertForQASimple(bert_config_S, args)
    #Load teacher
    if args.tuned_checkpoint_T is not None:
        state_dict_T = torch.load(args.tuned_checkpoint_T, map_location='cpu')
        model_T.load_state_dict(state_dict_T)
        model_T.eval()
    else:
        assert args.do_predict is True
    #Load student
    if args.load_model_type == 'bert':
        assert args.init_checkpoint_S is not None
        state_dict_S = torch.load(args.init_checkpoint_S, map_location='cpu')
        # state_weight = {k[5:]: v for k, v in state_dict_S.items() if k.startswith('bert.')}
        state_weight = {k: v for k, v in state_dict_S.items()}
        missing_keys, _ = model_S.bert.load_state_dict(state_weight, strict=False)
        assert len(missing_keys) == 0
    elif args.load_model_type == 'all':
        assert args.tuned_checkpoint_S is not None
        state_dict_S = torch.load(args.tuned_checkpoint_S, map_location='cpu')
        model_S.load_state_dict(state_dict_S)
    else:
        logger.info("Model is randomly initialized.")
    model_T.to(device)
    model_S.to(device)
    if args.local_rank != -1 or n_gpu > 1:
        if args.local_rank != -1:
            raise NotImplementedError
        elif n_gpu > 1:
            model_T = torch.nn.DataParallel(model_T)  #,output_device=n_gpu-1)
            model_S = torch.nn.DataParallel(model_S)  #,output_device=n_gpu-1)
    if args.do_train:
        #parameters
        params = list(model_S.named_parameters())
        all_trainable_params = divide_parameters(params, lr=args.learning_rate)
        logger.info("Length of all_trainable_params: %d", len(all_trainable_params))
        optimizer = BERTAdam(all_trainable_params, lr=args.learning_rate,
                             warmup=args.warmup_proportion, t_total=num_train_steps,
                             schedule=args.schedule,
                             s_opt1=args.s_opt1, s_opt2=args.s_opt2, s_opt3=args.s_opt3)
        logger.info("***** Running training *****")
        logger.info("  Num orig examples = %d", len(train_examples))
        logger.info("  Num split examples = %d", len(train_features))
        logger.info("  Forward batch size = %d", forward_batch_size)
        logger.info("  Num backward steps = %d", num_train_steps)

        ########### DISTILLATION ###########
        train_config = TrainingConfig(
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            ckpt_frequency=args.ckpt_frequency,
            log_dir=args.output_dir,
            output_dir=args.output_dir,
            device=args.device)
        from matches import matches
        intermediate_matches = None
        if isinstance(args.matches, (list, tuple)):
            intermediate_matches = []
            for match in args.matches:
                intermediate_matches += matches[match]
        logger.info(f"{intermediate_matches}")
        distill_config = DistillationConfig(
            temperature=args.temperature,
            intermediate_matches=intermediate_matches)
        adaptor_T = partial(BertForQASimpleAdaptor, no_logits=args.no_logits,
                            no_mask=args.no_inputs_mask)
        adaptor_S = partial(BertForQASimpleAdaptor, no_logits=args.no_logits,
                            no_mask=args.no_inputs_mask)
        distiller = GeneralDistiller(train_config=train_config,
                                     distill_config=distill_config,
                                     model_T=model_T, model_S=model_S,
                                     adaptor_T=adaptor_T, adaptor_S=adaptor_S)
        all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
        all_doc_mask = torch.tensor([f.doc_mask for f in train_features], dtype=torch.float)
        all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
        all_start_positions = torch.tensor([f.start_position for f in train_features], dtype=torch.long)
        all_end_positions = torch.tensor([f.end_position for f in train_features], dtype=torch.long)
        train_dataset = TensorDataset(all_input_ids, all_segment_ids, all_input_mask, all_doc_mask,
                                      all_start_positions, all_end_positions)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset)
        else:
            raise NotImplementedError
        train_dataloader = DataLoader(train_dataset, sampler=train_sampler,
                                      batch_size=args.forward_batch_size, drop_last=True)
        callback_func = partial(predict, eval_examples=eval_examples,
                                eval_features=eval_features, args=args)
        with distiller:
            distiller.train(optimizer, scheduler=None, dataloader=train_dataloader,
                            num_epochs=args.num_train_epochs, callback=callback_func)
    if not args.do_train and args.do_predict:
        res = predict(model_S, eval_examples, eval_features, step=0, args=args)
        print(res)
train_config = TrainingConfig(device=device)
distill_config = DistillationConfig(
    temperature=8,
    hard_label_weight=0,
    kd_loss_type='ce',
    probability_shift=False,
    intermediate_matches=[
        {'layer_T': 0, 'layer_S': 0, 'feature': 'hidden', 'loss': 'hidden_mse', 'weight': 1},
        {'layer_T': 8, 'layer_S': 2, 'feature': 'hidden', 'loss': 'hidden_mse', 'weight': 1},
        {'layer_T': [0, 0], 'layer_S': [0, 0], 'feature': 'hidden', 'loss': 'nst', 'weight': 1},
        {'layer_T': [8, 8], 'layer_S': [2, 2], 'feature': 'hidden', 'loss': 'nst', 'weight': 1}])
import torch
from textbrewer import TrainingConfig, DistillationConfig, GeneralDistiller

# We omit the initialization of models, optimizer, and dataloader.
teacher_model: torch.nn.Module = ...
student_model: torch.nn.Module = ...
dataloader: torch.utils.data.DataLoader = ...
optimizer: torch.optim.Optimizer = ...
scheduler: torch.optim.lr_scheduler = ...
num_epochs: int = ...

def simple_adaptor(batch, model_outputs):
    # We assume that the first element of model_outputs
    # is the logits before softmax
    return {'logits': model_outputs[0]}

train_config = TrainingConfig()
distill_config = DistillationConfig()
distiller = GeneralDistiller(
    train_config=train_config, distill_config=distill_config,
    model_T=teacher_model, model_S=student_model,
    adaptor_T=simple_adaptor, adaptor_S=simple_adaptor)

distiller.train(optimizer, scheduler, dataloader, num_epochs, callback=None)

def predict(model, eval_dataset, step, args):
    raise NotImplementedError
# fill other arguments
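# --- Illustrative sketch (not part of the original example) ---
# One possible implementation of the `predict` stub above. TextBrewer invokes
# the callback as callback(model=model_S, step=global_step); extra arguments
# such as `eval_dataset` and `args` are assumed to be bound beforehand with
# functools.partial. The sketch further assumes a classification setup where
# `eval_dataset` yields (input_ids, attention_mask, labels) tuples, the model
# returns logits as its first output, and `args.eval_batch_size` exists; the
# name `example_predict` and these details are illustrative only.
def example_predict(model, eval_dataset, step, args):
    device = next(model.parameters()).device
    eval_dataloader = torch.utils.data.DataLoader(eval_dataset,
                                                  batch_size=args.eval_batch_size)
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for input_ids, attention_mask, labels in eval_dataloader:
            logits = model(input_ids.to(device), attention_mask.to(device))[0]
            correct += (logits.argmax(dim=-1).cpu() == labels).sum().item()
            total += labels.size(0)
    accuracy = correct / max(total, 1)
    print(f"step {step}: eval accuracy = {accuracy:.4f}")
    model.train()  # restore training mode before distillation resumes
    return accuracy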
def main():
    config_distil.parse()
    global args
    args = config_distil.args
    global logger
    logger = create_logger(args.log_file)
    for k, v in vars(args).items():
        logger.info(f"{k}:{v}")
    # set seeds
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed_all(args.random_seed)
    np.random.seed(args.random_seed)
    random.seed(args.random_seed)

    logger.info('Loading vocabulary')
    word2idx = load_chinese_base_vocab()

    # use GPU if available
    args.device = torch.device("cuda" if torch.cuda.is_available() and args.is_cuda else "cpu")
    logger.info('using device:{}'.format(args.device))

    # model hyperparameters
    bertconfig_T = BertConfig(vocab_size=len(word2idx))
    bertconfig_S = BertConfig(vocab_size=len(word2idx), num_hidden_layers=3)

    logger.info('Initializing BERT models')
    bert_model_T = Seq2SeqModel(config=bertconfig_T)
    bert_model_S = Seq2SeqModel(config=bertconfig_S)

    logger.info('Loading teacher model')
    load_model(bert_model_T, args.tuned_checkpoint_T)
    logger.info('Moving the teacher model to the compute device (GPU or CPU)')
    bert_model_T.to(args.device)
    bert_model_T.eval()

    logger.info('Loading student model')
    if args.load_model_type == 'bert':
        load_model(bert_model_S, args.init_checkpoint_S)
    else:
        logger.info(" Student Model is randomly initialized.")
    logger.info('Moving the student model to the compute device (GPU or CPU)')
    bert_model_S.to(args.device)

    # custom dataset and dataloader
    logger.info('Loading training data')
    train = SelfDataset(args.train_path, args.max_length)
    trainloader = DataLoader(train, batch_size=args.train_batch_size, shuffle=True,
                             collate_fn=collate_fn)

    if args.do_train:
        logger.info('Collecting parameters to optimize')
        num_train_steps = int(len(trainloader) / args.train_batch_size) * args.num_train_epochs
        optim_parameters = list(bert_model_S.named_parameters())
        all_trainable_params = divide_parameters(optim_parameters, lr=args.learning_rate)
        logger.info("Length of all_trainable_params: %d", len(all_trainable_params))
        optimizer = BERTAdam(all_trainable_params, lr=args.learning_rate,
                             warmup=args.warmup_proportion, t_total=num_train_steps,
                             schedule=args.schedule,
                             s_opt1=args.s_opt1, s_opt2=args.s_opt2, s_opt3=args.s_opt3)
        train_config = TrainingConfig(
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            ckpt_frequency=args.ckpt_frequency,
            log_dir=args.output_dir,
            output_dir=args.output_dir,
            device=args.device)
        from generative.matches import matches
        intermediate_matches = None
        if isinstance(args.matches, (list, tuple)):
            intermediate_matches = []
            for match in args.matches:
                intermediate_matches += matches[match]
        logger.info(f"{intermediate_matches}")
        distill_config = DistillationConfig(
            temperature=args.temperature,
            intermediate_matches=intermediate_matches)

        def BertForS2SSimpleAdaptor(batch, model_outputs):
            return {'hidden': model_outputs[0], 'logits': model_outputs[1],
                    'loss': model_outputs[2], 'attention': model_outputs[3]}

        adaptor_T = partial(BertForS2SSimpleAdaptor)
        adaptor_S = partial(BertForS2SSimpleAdaptor)
        distiller = GeneralDistiller(train_config=train_config,
                                     distill_config=distill_config,
                                     model_T=bert_model_T, model_S=bert_model_S,
                                     adaptor_T=adaptor_T, adaptor_S=adaptor_S)
        callback_func = partial(predict, data_path=args.dev_path, args=args)
        logger.info('Start distillation.')
        with distiller:
            distiller.train(optimizer, scheduler=None, dataloader=trainloader,
                            num_epochs=args.num_train_epochs, callback=callback_func)
    if not args.do_train and args.do_predict:
        res = predict(bert_model_S, args.test_path, step=0, args=args)
        print(res)
def main():
    #parse arguments
    config.parse()
    args = config.args
    for k, v in vars(args).items():
        logger.info(f"{k}:{v}")
    #set seeds
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed_all(args.random_seed)
    np.random.seed(args.random_seed)
    random.seed(args.random_seed)
    #arguments check
    device, n_gpu = args_check(args)
    os.makedirs(args.output_dir, exist_ok=True)
    forward_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps)
    args.forward_batch_size = forward_batch_size
    #load bert config
    bert_config_T = BertConfig.from_json_file(args.bert_config_file_T)
    bert_config_S = BertConfig.from_json_file(args.bert_config_file_S)
    assert args.max_seq_length <= bert_config_T.max_position_embeddings
    assert args.max_seq_length <= bert_config_S.max_position_embeddings
    #Prepare GLUE task
    processor = processors[args.task_name]()
    args.output_mode = output_modes[args.task_name]
    label_list = processor.get_labels()
    num_labels = len(label_list)
    #read data
    train_dataset = None
    eval_datasets = None
    num_train_steps = None
    tokenizer = BertTokenizer(vocab_file=args.vocab_file, do_lower_case=args.do_lower_case)
    if args.do_train:
        train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False)
        if args.aux_task_name:
            aux_train_dataset = load_and_cache_examples(args, args.aux_task_name, tokenizer,
                                                        evaluate=False, is_aux=True)
            train_dataset = torch.utils.data.ConcatDataset([train_dataset, aux_train_dataset])
        num_train_steps = int(len(train_dataset) / args.train_batch_size) * args.num_train_epochs
    if args.do_predict:
        eval_datasets = []
        eval_task_names = ("mnli", "mnli-mm") if args.task_name == "mnli" else (args.task_name,)
        for eval_task in eval_task_names:
            eval_datasets.append(load_and_cache_examples(args, eval_task, tokenizer, evaluate=True))
    logger.info("Data loaded")
    #Build Model and load checkpoint
    model_S = BertForGLUESimple(bert_config_S, num_labels=num_labels, args=args)
    #Load teacher
    if args.tuned_checkpoint_Ts:
        model_Ts = [BertForGLUESimple(bert_config_T, num_labels=num_labels, args=args)
                    for i in range(len(args.tuned_checkpoint_Ts))]
        for model_T, ckpt_T in zip(model_Ts, args.tuned_checkpoint_Ts):
            logger.info("Load state dict %s" % ckpt_T)
            state_dict_T = torch.load(ckpt_T, map_location='cpu')
            model_T.load_state_dict(state_dict_T)
            model_T.eval()
    else:
        assert args.do_predict is True
    #Load student
    if args.load_model_type == 'bert':
        assert args.init_checkpoint_S is not None
        state_dict_S = torch.load(args.init_checkpoint_S, map_location='cpu')
        state_weight = {k[5:]: v for k, v in state_dict_S.items() if k.startswith('bert.')}
        missing_keys, _ = model_S.bert.load_state_dict(state_weight, strict=False)
        assert len(missing_keys) == 0
    elif args.load_model_type == 'all':
        assert args.tuned_checkpoint_S is not None
        state_dict_S = torch.load(args.tuned_checkpoint_S, map_location='cpu')
        model_S.load_state_dict(state_dict_S)
    else:
        logger.info("Model is randomly initialized.")
    if args.do_train:
        for model_T in model_Ts:
            model_T.to(device)
    model_S.to(device)
    if args.local_rank != -1 or n_gpu > 1:
        if args.local_rank != -1:
            raise NotImplementedError
        elif n_gpu > 1:
            if args.do_train:
                model_Ts = [torch.nn.DataParallel(model_T) for model_T in model_Ts]
            model_S = torch.nn.DataParallel(model_S)  #,output_device=n_gpu-1)
    if args.do_train:
        #parameters
        params = list(model_S.named_parameters())
        all_trainable_params = divide_parameters(params, lr=args.learning_rate)
        logger.info("Length of all_trainable_params: %d", len(all_trainable_params))
        optimizer = BERTAdam(all_trainable_params, lr=args.learning_rate,
                             warmup=args.warmup_proportion, t_total=num_train_steps,
                             schedule=args.schedule,
                             s_opt1=args.s_opt1, s_opt2=args.s_opt2, s_opt3=args.s_opt3)
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_dataset))
        logger.info("  Forward batch size = %d", forward_batch_size)
        logger.info("  Num backward steps = %d", num_train_steps)

        ########### DISTILLATION ###########
        train_config = TrainingConfig(
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            ckpt_frequency=args.ckpt_frequency,
            log_dir=args.output_dir,
            output_dir=args.output_dir,
            device=args.device)
        distill_config = DistillationConfig(
            temperature=args.temperature,
            kd_loss_type='ce')
        logger.info(f"{train_config}")
        logger.info(f"{distill_config}")
        adaptor = partial(BertForGLUESimpleAdaptor, no_logits=False, no_mask=False)
        distiller = MultiTeacherDistiller(train_config=train_config,
                                          distill_config=distill_config,
                                          model_T=model_Ts, model_S=model_S,
                                          adaptor_T=adaptor, adaptor_S=adaptor)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset)
        else:
            raise NotImplementedError
        train_dataloader = DataLoader(train_dataset, sampler=train_sampler,
                                      batch_size=args.forward_batch_size, drop_last=True)
        callback_func = partial(predict, eval_datasets=eval_datasets, args=args)
        with distiller:
            distiller.train(optimizer, scheduler=None, dataloader=train_dataloader,
                            num_epochs=args.num_train_epochs, callback=callback_func)
    if not args.do_train and args.do_predict:
        res = predict(model_S, eval_datasets, step=0, args=args)
        print(res)
def main():
    #parse arguments
    config.parse()
    args = config.args
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
        datefmt='%Y/%m/%d %H:%M:%S',
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger = logging.getLogger("Main")
    #arguments check
    device, n_gpu = args_check(logger, args)
    if args.local_rank in [-1, 0]:
        os.makedirs(args.output_dir, exist_ok=True)
    if args.local_rank != -1:
        logger.warning(
            f"Process rank: {torch.distributed.get_rank()}, device : {args.device}, "
            f"n_gpu : {args.n_gpu}, distributed training : {bool(args.local_rank != -1)}")
    for k, v in vars(args).items():
        logger.info(f"{k}:{v}")
    #set seeds
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed_all(args.random_seed)
    np.random.seed(args.random_seed)
    random.seed(args.random_seed)
    forward_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps)
    args.forward_batch_size = forward_batch_size
    #load bert config
    bert_config_T = ElectraConfig.from_json_file(args.bert_config_file_T)
    bert_config_S = ElectraConfig.from_json_file(args.bert_config_file_S)
    bert_config_S.output_hidden_states = (args.output_encoded_layers == 'true')
    bert_config_T.output_hidden_states = (args.output_encoded_layers == 'true')
    bert_config_S.num_labels = len(label2id_dict)
    bert_config_T.num_labels = len(label2id_dict)
    assert args.max_seq_length <= bert_config_T.max_position_embeddings
    assert args.max_seq_length <= bert_config_S.max_position_embeddings
    #read data
    train_examples = None
    train_dataset = None
    eval_examples = None
    eval_dataset = None
    num_train_steps = None
    tokenizer = BertTokenizer(vocab_file=args.vocab_file, do_lower_case=args.do_lower_case)
    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier()
    if args.do_train:
        train_examples, train_dataset = read_features(args.train_file,
                                                      max_seq_length=args.max_seq_length)
    if args.do_predict:
        eval_examples, eval_dataset = read_features(args.predict_file,
                                                    max_seq_length=args.max_seq_length)
    if args.local_rank == 0:
        torch.distributed.barrier()
    #Build Model and load checkpoint
    model_T = ElectraForTokenClassification(bert_config_T)
    model_S = ElectraForTokenClassification(bert_config_S)
    #Load teacher
    if args.tuned_checkpoint_T is not None:
        state_dict_T = torch.load(args.tuned_checkpoint_T, map_location='cpu')
        model_T.load_state_dict(state_dict_T)
        model_T.eval()
    else:
        assert args.do_predict is True
    #Load student
    if args.load_model_type == 'bert':
        assert args.init_checkpoint_S is not None
        state_dict_S = torch.load(args.init_checkpoint_S, map_location='cpu')
        missing_keys, unexpected_keys = model_S.load_state_dict(state_dict_S, strict=False)
        logger.info(f"missing keys:{missing_keys}")
        logger.info(f"unexpected keys:{unexpected_keys}")
    elif args.load_model_type == 'all':
        assert args.tuned_checkpoint_S is not None
        state_dict_S = torch.load(args.tuned_checkpoint_S, map_location='cpu')
        model_S.load_state_dict(state_dict_S)
    else:
        logger.info("Model is randomly initialized.")
    model_T.to(device)
    model_S.to(device)
    if args.do_train:
        #parameters
        if args.lr_decay is not None:
            outputs_params = list(model_S.classifier.named_parameters())
            outputs_params = divide_parameters(outputs_params, lr=args.learning_rate)
            electra_params = []
            n_layers = len(model_S.electra.encoder.layer)
            assert n_layers == 12
            for i, n in enumerate(reversed(range(n_layers))):
                encoder_params = list(model_S.electra.encoder.layer[n].named_parameters())
                lr = args.learning_rate * args.lr_decay ** (i + 1)
                electra_params += divide_parameters(encoder_params, lr=lr)
                logger.info(f"{i},{n},{lr}")
            embed_params = [(name, value) for name, value in model_S.electra.named_parameters()
                            if 'embedding' in name]
            logger.info(f"{[name for name, value in embed_params]}")
            lr = args.learning_rate * args.lr_decay ** (n_layers + 1)
            electra_params += divide_parameters(embed_params, lr=lr)
            logger.info(f"embed lr:{lr}")
            all_trainable_params = outputs_params + electra_params
            assert sum(map(lambda x: len(x['params']), all_trainable_params)) == len(list(model_S.parameters())), \
                (sum(map(lambda x: len(x['params']), all_trainable_params)), len(list(model_S.parameters())))
        else:
            params = list(model_S.named_parameters())
            all_trainable_params = divide_parameters(params, lr=args.learning_rate)
        logger.info("Length of all_trainable_params: %d", len(all_trainable_params))
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset)
        else:
            train_sampler = DistributedSampler(train_dataset)
        train_dataloader = DataLoader(train_dataset, sampler=train_sampler,
                                      batch_size=args.forward_batch_size, drop_last=True)
        num_train_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
        optimizer = AdamW(all_trainable_params, lr=args.learning_rate, correct_bias=False)
        if args.official_schedule == 'const':
            scheduler_class = get_constant_schedule_with_warmup
            scheduler_args = {'num_warmup_steps': int(args.warmup_proportion * num_train_steps)}
            #scheduler = get_constant_schedule_with_warmup(optimizer, num_warmup_steps=int(args.warmup_proportion*num_train_steps))
        elif args.official_schedule == 'linear':
            scheduler_class = get_linear_schedule_with_warmup
            scheduler_args = {'num_warmup_steps': int(args.warmup_proportion * num_train_steps),
                              'num_training_steps': num_train_steps}
            #scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=int(args.warmup_proportion*num_train_steps), num_training_steps=num_train_steps)
        else:
            raise NotImplementedError
        logger.warning("***** Running training *****")
        logger.warning("local_rank %d Num orig examples = %d", args.local_rank, len(train_examples))
        logger.warning("local_rank %d Num split examples = %d", args.local_rank, len(train_dataset))
        logger.warning("local_rank %d Forward batch size = %d", args.local_rank, forward_batch_size)
        logger.warning("local_rank %d Num backward steps = %d", args.local_rank, num_train_steps)

        ########### DISTILLATION ###########
        train_config = TrainingConfig(
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            ckpt_frequency=args.ckpt_frequency,
            log_dir=args.output_dir,
            output_dir=args.output_dir,
            device=args.device,
            fp16=args.fp16,
            local_rank=args.local_rank)
        logger.info(f"{train_config}")
        from matches import matches
        intermediate_matches = None
        if isinstance(args.matches, (list, tuple)):
            intermediate_matches = []
            for match in args.matches:
                intermediate_matches += matches[match]
        logger.info(f"{intermediate_matches}")
        distill_config = DistillationConfig(
            temperature=args.temperature,
            intermediate_matches=intermediate_matches)
        adaptor_T = ElectraForTokenClassificationAdaptor
        adaptor_S = ElectraForTokenClassificationAdaptor
        distiller = GeneralDistiller(train_config=train_config,
                                     distill_config=distill_config,
                                     model_T=model_T, model_S=model_S,
                                     adaptor_T=adaptor_T, adaptor_S=adaptor_S)
        # evaluate the model in a single process in ddp_predict
        callback_func = partial(ddp_predict, eval_examples=eval_examples,
                                eval_dataset=eval_dataset, args=args)
        with distiller:
            distiller.train(optimizer,
                            scheduler_class=scheduler_class,
                            scheduler_args=scheduler_args,
                            max_grad_norm=1.0,
                            dataloader=train_dataloader,
                            num_epochs=args.num_train_epochs,
                            callback=callback_func)
    if not args.do_train and args.do_predict:
        res = ddp_predict(model_S, eval_examples, eval_dataset, step=0, args=args)
        print(res)