def main():
    """Prune a fine-tuned TinyBERT classifier to a smaller architecture in one step.

    Steps:

    1. Load ``PrunTinyBertForSequenceClassification`` and its tokenizer from
       ``<model_path>/<task>``.
    2. Keep only the first ``num_layers`` encoder layers.
    3. Call ``Taylor_pruning_structured`` with the importance scores stored at
       ``<model_path>/<task>/taylor_score/taylor.pkl`` to shrink the attention
       heads, the FFN hidden size and (optionally) the embedding factorization.
    4. Save weights, config and vocabulary under
       ``<output_dir>/<task>/a{heads}_l{layers}_f{ffn}_e{emb}``, reload that
       checkpoint through ``from_pretrained``, save its state dict again and
       print the parameter count and parameter names.

    Raises:
        ValueError: if any requested size (layers, attention heads, FFN hidden
            size, embedding hidden size) is larger than the loaded model's.
    """
    parser = argparse.ArgumentParser(description='pruning_one-step.py')
    parser.add_argument('-model_path', default='../KD/models/bert_ft', type=str,
                        help="directory of the fine-tuned model; the task name is appended")
    parser.add_argument('-output_dir', default='models/prun_bert', type=str,
                        help="output dir")
    parser.add_argument('-task', default='CoLA', type=str,
                        help="Name of the task")
    parser.add_argument('-keep_heads', type=int, default=2,
                        help="the number of attention heads to keep")
    parser.add_argument('-ffn_hidden_dim', type=int, default=512,
                        help="Hidden size of the FFN subnetworks.")
    parser.add_argument('-num_layers', type=int, default=8,
                        help="the number of layers of the pruned model")
    parser.add_argument('-emb_hidden_dim', type=int, default=128,
                        help="Hidden size of embedding factorization. "
                             "Do not factorize embedding if value==-1")
    args = parser.parse_args()
    torch.manual_seed(0)

    args.model_path = os.path.join(args.model_path, args.task)
    args.output_dir = os.path.join(args.output_dir, args.task)
    # ``num_labels`` is a module-level task -> label-count mapping.
    task_num_labels = num_labels[args.task.lower()]

    print('Loading BERT from %s...' % args.model_path)
    model = PrunTinyBertForSequenceClassification.from_pretrained(
        args.model_path, num_labels=task_num_labels)
    config = model.config
    tokenizer = BertTokenizer.from_pretrained(args.model_path, do_lower_case=True)

    # Validate every requested dimension before modifying the model, so an
    # oversized request fails with a clear message instead of an IndexError
    # while slicing the encoder layers.
    num_src_layers = len(model.bert.encoder.layer)
    if args.num_layers > num_src_layers:
        raise ValueError(
            'Cannot prune the model to a larger size! Requested %d layers, '
            'but the model has %d.' % (args.num_layers, num_src_layers))
    if args.keep_heads > config.num_attention_heads:
        raise ValueError(
            'Cannot prune the model to a larger size! Requested %d heads, '
            'but the model has %d.' % (args.keep_heads, config.num_attention_heads))
    if args.ffn_hidden_dim > config.prun_intermediate_size or \
            (args.emb_hidden_dim > config.emb_hidden_dim and config.emb_hidden_dim != -1):
        raise ValueError('Cannot prune the model to a larger size!')

    # Keep only the bottom ``num_layers`` encoder layers.
    model.bert.encoder.layer = torch.nn.ModuleList(
        list(model.bert.encoder.layer)[:args.num_layers])

    # Fraction of the FFN hidden units that survive pruning.
    args.prun_ratio = args.ffn_hidden_dim / config.prun_intermediate_size
    print('Pruning to %d heads, %d layers, %d FFN hidden dim, %d emb hidden dim...'
          % (args.keep_heads, args.num_layers, args.ffn_hidden_dim, args.emb_hidden_dim))
    importance_dir = os.path.join(args.model_path, 'taylor_score', 'taylor.pkl')
    # NOTE(review): the pruned hidden size assumes 64 dimensions per attention
    # head — confirm this matches the source model's head size.
    new_config = BertConfigPrun(num_attention_heads=args.keep_heads,
                                prun_hidden_size=int(args.keep_heads * 64),
                                prun_intermediate_size=args.ffn_hidden_dim,
                                num_hidden_layers=args.num_layers,
                                emb_hidden_dim=args.emb_hidden_dim)
    model = Taylor_pruning_structured(model, args.prun_ratio, config.num_attention_heads,
                                      args.keep_heads, importance_dir,
                                      args.emb_hidden_dim, new_config)

    output_dir = os.path.join(
        args.output_dir,
        'a%d_l%d_f%d_e%d' % (args.keep_heads, args.num_layers,
                             args.ffn_hidden_dim, args.emb_hidden_dim))
    print('Saving model to %s' % output_dir)
    os.makedirs(output_dir, exist_ok=True)
    weights_path = os.path.join(output_dir, 'pytorch_model.bin')
    torch.save(model.state_dict(), weights_path)
    new_config.save_pretrained(output_dir)
    tokenizer.save_vocabulary(output_dir)

    # Round-trip through from_pretrained and re-save the weights; presumably
    # this normalizes the checkpoint to the loader's parameter layout — TODO
    # confirm the intent.
    model = PrunTinyBertForSequenceClassification.from_pretrained(
        output_dir, num_labels=task_num_labels)
    state_dict = model.state_dict()
    torch.save(state_dict, weights_path)
    print("Number of parameters: %d"
          % sum(tensor.nelement() for tensor in state_dict.values()))
    print(state_dict.keys())
type=int, default=[1, 12]) parser.add_argument('--intermediate_size_space', nargs='+', type=int, default=[128, 3072]) parser.add_argument('--mlm', action='store_true') parser.add_argument('--infer_cnt', type=int, default=10) args = parser.parse_args() config = BertConfig.from_pretrained( os.path.join(args.bert_model, 'config.json')) model = SuperTinyBertForPreTraining.from_scratch(args.bert_model, config) tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=True) device = 'cpu' model.to(device) model.eval() torch.set_num_threads(1) random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) # build arch space min_hidden_size, max_hidden_size = args.hidden_size_space min_ffn_size, max_ffn_size = args.intermediate_size_space min_qkv_size, max_qkv_size = args.qkv_size_space
def main():
    """Train or evaluate a multi-domain meta-teacher for sequence classification.

    Evaluation mode (``--do_eval``):
        Evaluates the checkpoint at ``--pretrain_model_name_or_path`` on the
        test examples of ``--domain`` and logs the metrics.

    Training mode (default):
        Fine-tunes the model on the train examples of ``--domain``, or on the
        augmented examples with ``--aug_train``. Every ``--eval_step`` optimizer
        steps it evaluates on the dev examples and appends the metrics and
        running losses to ``<output_dir>/eval_results.txt``. It also saves the
        checkpoint with the best dev metric so far to ``--output_dir``.

    Loss:
        The per-sample training loss is the label cross-entropy. With
        ``--use_domain_loss`` it adds ``--domain_loss_weight`` times the
        cross-entropy of the domain head against a random permutation of the
        batch's domain ids. With ``--use_sample_weights`` the per-sample losses
        are multiplied by the sample weights before averaging.

    Raises:
        ValueError: for a task not in ``processors`` or, when training, a
            ``--gradient_accumulation_steps`` value below 1.

    argparse exits with a usage error if ``--use_domain_loss`` or
    ``--use_sample_weights`` is given a non-boolean value.
    """

    def _str2bool(value):
        """Parse a command-line boolean.

        ``type=bool`` is not usable for this: ``bool("False")`` is True.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError(
            "Boolean value expected, got {!r}".format(value))

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data_dir", default=None, type=str, required=True,
        help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
    parser.add_argument("--pretrain_model_name_or_path", default=None, type=str,
                        help="The pretrain model name or path.")
    parser.add_argument("--task_name", default=None, type=str, required=True,
                        help="The name of the task to train.")
    parser.add_argument("--domain", default='all', type=str, required=True,
                        help="The domain of given model.")
    parser.add_argument("--use_domain_loss", default=False, type=_str2bool,
                        help="Whether to use domain loss.")
    parser.add_argument("--data_portion", default=1.0, type=float, required=False,
                        help="How many data selected.")
    parser.add_argument("--domain_loss_weight", default=0.2, type=float,
                        help="The loss weight of domain.")
    parser.add_argument("--use_sample_weights", default=False, type=_str2bool,
                        help="Whether to weight each sample's loss by its sample weight.")
    parser.add_argument(
        "--output_dir", default=None, type=str, required=True,
        help="The output directory where the model predictions and checkpoints will be written.")
    parser.add_argument(
        "--cache_dir", default="", type=str,
        help="Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument(
        "--max_seq_length", default=128, type=int,
        help="The maximum total input sequence length after WordPiece tokenization. \n"
             "Sequences longer than this will be truncated, and sequences shorter \n"
             "than this will be padded.")
    parser.add_argument("--do_eval", action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--do_lower_case", action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size", default=32, type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size", default=32, type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate", default=5e-5, type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument('--weight_decay', '--wd', default=1e-4, type=float,
                        metavar='W', help='weight decay')
    parser.add_argument("--num_train_epochs", default=3.0, type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion", default=0.1, type=float,
        help="Proportion of training to perform linear learning rate warmup for. "
             "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda", action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument('--seed', type=int, default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps', type=int, default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.")
    # added arguments
    parser.add_argument('--aug_train', action='store_true')
    parser.add_argument('--eval_step', type=int, default=50)
    parser.add_argument('--pred_distill', action='store_true')
    parser.add_argument('--data_url', type=str, default="")
    parser.add_argument('--temperature', type=float, default=1.)
    args = parser.parse_args()

    # Configure logging before the first logger call: if nothing configured
    # logging earlier, the INFO record with the args would otherwise be dropped.
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO)
    logger.info('The args: {}'.format(args))

    processors = {
        "mnli": MnliProcessor,
        "mnli-mm": MnliMismatchedProcessor,
        "senti": SentiProcessor,
    }
    # Every task in ``processors`` needs an output mode; "mnli-mm" was missing
    # and failed with a KeyError.
    output_modes = {
        "mnli": "classification",
        "mnli-mm": "classification",
        "senti": "classification",
    }

    # NOTE(review): only "mnli" gets the MNLI genre list; every other task,
    # including "mnli-mm", gets the sentiment domains — confirm this is intended.
    if args.task_name.lower() == "mnli":
        domain_names = "telephone,government,slate,fiction,travel".split(",")
    else:
        domain_names = "books,dvd,electronics,kitchen".split(",")
    domain_idx_mapping = {domain: idx for idx, domain in enumerate(domain_names)}
    num_domains = len(domain_idx_mapping)

    # intermediate distillation default parameters
    default_params = {
        "mnli": {"num_train_epochs": 5, "max_seq_length": 128},
        "senti": {"num_train_epochs": 5, "max_seq_length": 128},
    }

    acc_tasks = ["mnli", "mrpc", "sst-2", "qqp", "qnli", "rte", "senti"]
    corr_tasks = ["sts-b"]
    mcc_tasks = ["cola"]

    # Prepare devices
    device = torch.device(
        "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    n_gpu = torch.cuda.device_count()
    logger.info("device: {} n_gpu: {}".format(device, n_gpu))

    # Prepare seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    # Prepare task settings
    os.makedirs(args.output_dir, exist_ok=True)

    task_name = args.task_name.lower()
    # NOTE(review): these assignments write ``max_seq_len`` / ``num_train_epoch``,
    # which nothing below reads (the options are ``max_seq_length`` /
    # ``num_train_epochs``), so the defaults currently have no effect. They are
    # kept as-is so that user-supplied values are not silently overridden.
    if task_name in default_params:
        args.max_seq_len = default_params[task_name]["max_seq_length"]
    if not args.do_eval and task_name in default_params:
        args.num_train_epoch = default_params[task_name]["num_train_epochs"]

    if task_name not in processors:
        raise ValueError("Task not found: %s" % task_name)

    processor = processors[task_name](portion=args.data_portion)
    output_mode = output_modes[task_name]
    label_list = processor.get_labels()
    num_labels = len(label_list)

    tokenizer = BertTokenizer.from_pretrained(args.pretrain_model_name_or_path,
                                              do_lower_case=args.do_lower_case)

    if not args.do_eval:
        if args.aug_train:
            train_examples = processor.get_aug_examples(args.data_dir, args.domain)
        else:
            train_examples = processor.get_train_examples(args.data_dir, args.domain)
        if args.gradient_accumulation_steps < 1:
            raise ValueError(
                "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
                .format(args.gradient_accumulation_steps))
        args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

        num_train_optimization_steps = int(
            len(train_examples) / args.train_batch_size /
            args.gradient_accumulation_steps) * args.num_train_epochs

        # The cache file name must reflect every option that changes the
        # features. ``_aug`` keeps augmented runs from reusing (or clobbering)
        # the plain cache; non-augmented cache names are unchanged.
        portion_str = "_{}".format(args.data_portion) if args.data_portion != 1.0 else ""
        meta_str = "meta" if args.use_domain_loss or args.use_sample_weights else ""
        weights_str = "_with_weights" if args.use_sample_weights else ""
        aug_str = "_aug" if args.aug_train else ""
        cached_train_path = os.path.join(
            args.data_dir, "cached_train_features_{}{}{}{}{}.pt".format(
                args.domain, meta_str, weights_str, portion_str, aug_str))
        if os.path.exists(cached_train_path):
            train_features = torch.load(cached_train_path)
        else:
            train_features = convert_examples_to_features(
                train_examples, label_list, args.max_seq_length, tokenizer,
                output_mode, domain_idx_mapping)
            torch.save(train_features, cached_train_path)
            print("Save to cached path %s" % cached_train_path)

        train_data, _ = get_tensor_data(output_mode, train_features)
        train_sampler = RandomSampler(train_data)
        train_dataloader = DataLoader(train_data, sampler=train_sampler,
                                      batch_size=args.train_batch_size)

    if args.do_eval:
        eval_examples = processor.get_test_examples(args.data_dir, args.domain)
    else:
        eval_examples = processor.get_dev_examples(args.data_dir, args.domain)
    eval_features = convert_examples_to_features(eval_examples, label_list,
                                                 args.max_seq_length, tokenizer,
                                                 output_mode, domain_idx_mapping)
    eval_data, eval_labels = get_tensor_data(output_mode, eval_features)
    eval_sampler = SequentialSampler(eval_data)
    eval_dataloader = DataLoader(eval_data, sampler=eval_sampler,
                                 batch_size=args.eval_batch_size)

    meta_teacher_model = MetaTeacherForSequenceClassification.from_pretrained(
        args.pretrain_model_name_or_path, num_labels=num_labels, num_domains=num_domains)
    meta_teacher_model.to(device)

    if args.do_eval:
        logger.info("***** Running evaluation *****")
        logger.info(" Num examples = %d", len(eval_examples))
        logger.info(" Batch size = %d", args.eval_batch_size)
        meta_teacher_model.eval()
        result = do_eval(meta_teacher_model, task_name, eval_dataloader, device,
                         output_mode, eval_labels, num_labels)
        logger.info("***** Eval results *****")
        for key in sorted(result.keys()):
            logger.info(" %s = %s", key, str(result[key]))
        return

    logger.info("***** Running training *****")
    logger.info(" Num examples = %d", len(train_examples))
    logger.info(" Batch size = %d", args.train_batch_size)
    logger.info(" Num steps = %d", num_train_optimization_steps)
    if n_gpu > 1:
        meta_teacher_model = torch.nn.DataParallel(meta_teacher_model)

    # Prepare optimizer
    param_optimizer = list(meta_teacher_model.named_parameters())
    size = 0
    for n, p in param_optimizer:
        logger.info('n: {}'.format(n))
        size += p.nelement()
    logger.info('Total parameters: {}'.format(size))

    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    # NOTE(review): weight decay is fixed at 0.01 here; --weight_decay is unused.
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
         'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0},
    ]
    schedule = 'warmup_linear'
    optimizer = BertAdam(optimizer_grouped_parameters,
                         schedule=schedule,
                         lr=args.learning_rate,
                         warmup=args.warmup_proportion,
                         t_total=num_train_optimization_steps)

    # Dev metric used for checkpoint selection (None: never save).
    if task_name in acc_tasks:
        metric_key = 'acc'
    elif task_name in corr_tasks:
        metric_key = 'corr'
    elif task_name in mcc_tasks:
        metric_key = 'mcc'
    else:
        metric_key = None

    # Train and evaluate
    global_step = 0
    best_dev_acc = 0.0
    output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
    ce_loss_fn = CrossEntropyLoss(reduction="none")

    for epoch_ in trange(int(args.num_train_epochs), desc="Epoch"):
        tr_loss = 0.
        tr_cls_loss = 0.
        meta_teacher_model.train()
        nb_tr_examples, nb_tr_steps = 0, 0

        for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration", ascii=True)):
            batch = tuple(t.to(device) for t in batch)
            (input_ids, input_mask, segment_ids, label_ids, seq_lengths,
             domain_ids, sample_weights) = batch
            # Skip any batch that is not full-sized (e.g. the last one).
            if input_ids.size()[0] != args.train_batch_size:
                continue

            logits, domain_logits, *_ = meta_teacher_model(
                input_ids, segment_ids, input_mask, domain_ids)
            losses = ce_loss_fn(logits, label_ids)
            # Plain classification loss, tracked for logging only.
            cls_loss = losses.mean()
            if args.use_domain_loss:
                shuffled_domain_ids = domain_ids[torch.randperm(domain_ids.shape[0])]
                domain_losses = ce_loss_fn(domain_logits, shuffled_domain_ids)
                losses = losses + args.domain_loss_weight * domain_losses
            if args.use_sample_weights:
                loss = torch.mean(losses * sample_weights)
            else:
                loss = torch.mean(losses)

            if n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu.
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            loss.backward()

            tr_loss += loss.item()
            tr_cls_loss += cls_loss.item()
            nb_tr_examples += label_ids.size(0)
            nb_tr_steps += 1

            if (step + 1) % args.gradient_accumulation_steps != 0:
                continue
            optimizer.step()
            optimizer.zero_grad()
            global_step += 1

            # Evaluate inside the optimizer-step branch so that, with gradient
            # accumulation, each global step is evaluated at most once.
            if (global_step + 1) % args.eval_step != 0:
                continue

            logger.info("***** Running evaluation *****")
            logger.info(" Epoch = {} iter {} step".format(epoch_, global_step))
            logger.info(" Num examples = %d", len(eval_examples))
            logger.info(" Batch size = %d", args.eval_batch_size)

            meta_teacher_model.eval()
            # Running means over the steps seen in this epoch. ``loss`` is
            # scaled by 1/gradient_accumulation_steps; ``cls_loss`` is not.
            mean_loss = tr_loss / (step + 1)
            mean_cls_loss = tr_cls_loss / (step + 1)

            result = do_eval(meta_teacher_model, task_name, eval_dataloader, device,
                             output_mode, eval_labels, num_labels)
            result['global_step'] = global_step
            result['cls_loss'] = mean_cls_loss
            result['loss'] = mean_loss
            result_to_file(result, output_eval_file)

            if metric_key is not None and result[metric_key] > best_dev_acc:
                best_dev_acc = result[metric_key]
                logger.info("***** Save model *****")
                model_to_save = (meta_teacher_model.module
                                 if hasattr(meta_teacher_model, 'module')
                                 else meta_teacher_model)
                output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
                output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
                torch.save(model_to_save.state_dict(), output_model_file)
                model_to_save.config.to_json_file(output_config_file)
                tokenizer.save_vocabulary(args.output_dir)
                if oncloud:
                    logging.info(mox.file.list_directory(args.output_dir, recursive=True))
                    logging.info(mox.file.list_directory('.', recursive=True))
                    mox.file.copy_parallel(args.output_dir, args.data_url)
                    mox.file.copy_parallel('.', args.data_url)

            meta_teacher_model.train()
def main():
    """Distill a fine-tuned TinyBERT teacher into a student on a text task.

    Training data is read with ``read_examples`` from ``<data_dir>/train.txt``.
    ``<data_dir>/eval.txt`` is read only to report its size.

    Stage 1 (without ``--pred_distill``):
        The student matches the teacher's attention maps and hidden states with
        MSE, using a uniform teacher-to-student layer mapping.

    Stage 2 (with ``--pred_distill``):
        For ``classification``, the student matches the teacher's logits with a
        temperature-scaled soft cross-entropy. For ``regression``, it is trained
        with MSE against the gold labels.

    Every ``--eval_step`` optimizer steps, the running losses are appended to
    ``<output_dir>/eval_results.txt`` and a checkpoint named
    ``<epoch>_<WEIGHTS_NAME>`` is saved. Nothing is trained unless
    ``--do_train`` is given; ``--do_eval`` is currently unused.

    Raises:
        ValueError: if ``--teacher_model`` is missing, if ``--task_mode`` is
            unsupported for stage-2 training, if ``--output_dir`` exists and is
            not empty, if ``--gradient_accumulation_steps`` is below 1, or if the
            teacher's layer count is not a multiple of the student's.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data_dir", default=None, type=str, required=True,
        help="The input data dir. Should contain the .tsv files or the task.")
    parser.add_argument("--teacher_model", default=None, type=str,
                        help="The teacher model dir.")
    parser.add_argument("--student_model", default=None, type=str, required=True,
                        help="The student model dir.")
    parser.add_argument(
        "--output_dir", default=None, type=str, required=True,
        help="The output directory where model checkpoints will be written.")
    parser.add_argument(
        "--cache_dir", default="", type=str,
        help="Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument("--max_seq_len", default=128, type=int,
                        help="The maximum total input sequence length ")
    parser.add_argument("--num_labels", default=2, type=int, required=True, help="")
    parser.add_argument("--task_mode", default='classification', type=str,
                        required=False, help="task type")
    parser.add_argument("--do_eval", action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--do_train", action='store_true',
                        help="Whether to run train on the train set.")
    parser.add_argument("--do_lower_case", action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size", default=32, type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size", default=32, type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate", default=5e-5, type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument('--weight_decay', '--wd', default=1e-4, type=float,
                        metavar='W', help='weight decay')
    parser.add_argument("--num_train_epochs", default=3.0, type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion", default=0.1, type=float,
        help="Proportion of training to perform linear learning rate warmup for. "
             "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda", action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument('--seed', type=int, default=42,
                        help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help="Number of updates steps to accumulate")
    # added arguments
    parser.add_argument('--aug_train', action='store_true')
    parser.add_argument('--eval_step', type=int, default=50)
    parser.add_argument('--pred_distill', action='store_true')
    parser.add_argument('--data_url', type=str, default="")
    parser.add_argument('--temperature', type=float, default=1.)
    args = parser.parse_args()

    # Configure logging before the first logger call: if nothing configured
    # logging earlier, the INFO record with the args would otherwise be dropped.
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO)
    logger.info('The args: {}'.format(args))

    # Prepare devices
    device = torch.device(
        "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    n_gpu = torch.cuda.device_count()
    logger.info("device: {} n_gpu: {}".format(device, n_gpu))

    # Prepare seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    # Validate arguments whose misuse would otherwise fail late and obscurely.
    if args.teacher_model is None:
        raise ValueError("--teacher_model is required.")
    if args.do_train and args.pred_distill and \
            args.task_mode not in ("classification", "regression"):
        raise ValueError(
            "Unsupported --task_mode: {} (expected 'classification' or "
            "'regression')".format(args.task_mode))

    # Prepare task settings
    if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
        raise ValueError(
            "Output directory ({}) already exists and is not empty.".format(args.output_dir))
    os.makedirs(args.output_dir, exist_ok=True)

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))
    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    num_labels = args.num_labels
    tokenizer = BertTokenizer.from_pretrained(args.student_model,
                                              do_lower_case=args.do_lower_case)

    if args.do_train:
        train_examples = read_examples(os.path.join(args.data_dir, 'train.txt'))
        # Only used to report the eval set size in the logs.
        eval_examples = read_examples(os.path.join(args.data_dir, 'eval.txt'))
        num_train_optimization_steps = int(
            len(train_examples) / args.train_batch_size /
            args.gradient_accumulation_steps) * args.num_train_epochs

        train_features = MyDataLoader(
            convert_examples_to_features(train_examples, tokenizer, args.max_seq_len))
        train_dataloader = DataLoader(train_features, shuffle=True,
                                      batch_size=args.train_batch_size)

    teacher_model = TinyBertForSequenceClassification.from_pretrained(
        args.teacher_model, num_labels=num_labels)
    teacher_model.to(device)
    # The teacher only supplies targets. Eval mode disables its dropout, which
    # torch.no_grad() alone does not do.
    teacher_model.eval()

    student_model = TinyBertForSequenceClassification.from_pretrained(
        args.student_model, num_labels=num_labels)
    student_model.to(device)

    # Training is the only mode this script implements.
    if not args.do_train:
        return

    logger.info("***** Running training *****")
    logger.info(" Num examples = %d", len(train_examples))
    logger.info(" Batch size = %d", args.train_batch_size)
    logger.info(" Num steps = %d", num_train_optimization_steps)
    if n_gpu > 1:
        student_model = torch.nn.DataParallel(student_model)
        teacher_model = torch.nn.DataParallel(teacher_model)

    # Prepare optimizer
    param_optimizer = list(student_model.named_parameters())
    size = 0
    for n, p in param_optimizer:
        logger.info('n: {}'.format(n))
        size += p.nelement()
    logger.info('Total parameters: {}'.format(size))

    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    # NOTE(review): weight decay is fixed at 0.01 here; --weight_decay is unused.
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
         'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0},
    ]
    # Stage 1 runs without an LR schedule; stage 2 uses linear warmup/decay.
    schedule = 'warmup_linear' if args.pred_distill else 'none'
    optimizer = BertAdam(optimizer_grouped_parameters,
                         schedule=schedule,
                         lr=args.learning_rate,
                         warmup=args.warmup_proportion,
                         t_total=num_train_optimization_steps)

    # Prepare loss functions
    loss_mse = MSELoss()

    def soft_cross_entropy(predicts, targets):
        """Cross-entropy of ``predicts`` against ``softmax(targets)``, averaged over all elements."""
        student_likelihood = torch.nn.functional.log_softmax(predicts, dim=-1)
        targets_prob = torch.nn.functional.softmax(targets, dim=-1)
        return (-targets_prob * student_likelihood).mean()

    # Train and evaluate
    global_step = 0
    output_eval_file = os.path.join(args.output_dir, "eval_results.txt")

    for epoch_ in trange(int(args.num_train_epochs), desc="Epoch"):
        tr_loss = 0.
        tr_att_loss = 0.
        tr_rep_loss = 0.
        tr_cls_loss = 0.

        student_model.train()
        nb_tr_examples, nb_tr_steps = 0, 0

        for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration", ascii=True)):
            batch = tuple(t.to(device) for t in batch)
            input_ids, input_mask, segment_ids, label_ids, seq_lengths = batch
            # Skip any batch that is not full-sized (e.g. the last one).
            if input_ids.size()[0] != args.train_batch_size:
                continue

            att_loss = 0.
            rep_loss = 0.
            cls_loss = 0.

            student_logits, student_atts, student_reps = student_model(
                input_ids, segment_ids, input_mask, is_student=True)
            with torch.no_grad():
                teacher_logits, teacher_atts, teacher_reps = teacher_model(
                    input_ids, segment_ids, input_mask)

            if not args.pred_distill:
                # Stage 1: intermediate-layer (attention + hidden state) distillation.
                teacher_layer_num = len(teacher_atts)
                student_layer_num = len(student_atts)
                if teacher_layer_num % student_layer_num != 0:
                    raise ValueError(
                        "Teacher layers ({}) must be a multiple of student layers ({})."
                        .format(teacher_layer_num, student_layer_num))
                layers_per_block = teacher_layer_num // student_layer_num

                # Student attention i mimics the last teacher layer of block i.
                new_teacher_atts = [
                    teacher_atts[(i + 1) * layers_per_block - 1]
                    for i in range(student_layer_num)
                ]
                for student_att, teacher_att in zip(student_atts, new_teacher_atts):
                    # Zero out scores <= -100 (presumably masked positions from
                    # an additive attention mask) so they do not dominate the MSE.
                    student_att = torch.where(student_att <= -1e2,
                                              torch.zeros_like(student_att), student_att)
                    teacher_att = torch.where(teacher_att <= -1e2,
                                              torch.zeros_like(teacher_att), teacher_att)
                    att_loss += loss_mse(student_att, teacher_att)

                # student_layer_num + 1 hidden states are compared (presumably
                # including the embedding output at index 0). Student state j is
                # matched with teacher state j * layers_per_block.
                new_teacher_reps = [
                    teacher_reps[i * layers_per_block]
                    for i in range(student_layer_num + 1)
                ]
                for student_rep, teacher_rep in zip(student_reps, new_teacher_reps):
                    rep_loss += loss_mse(student_rep, teacher_rep)

                loss = rep_loss + att_loss
                tr_att_loss += att_loss.item()
                tr_rep_loss += rep_loss.item()
            else:
                # Stage 2: prediction-layer distillation.
                if args.task_mode == "classification":
                    cls_loss = soft_cross_entropy(student_logits / args.temperature,
                                                  teacher_logits / args.temperature)
                else:  # "regression" (validated at start-up)
                    cls_loss = loss_mse(student_logits.view(-1), label_ids.view(-1))
                loss = cls_loss
                tr_cls_loss += cls_loss.item()

            if n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu.
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            loss.backward()

            tr_loss += loss.item()
            nb_tr_examples += label_ids.size(0)
            nb_tr_steps += 1

            if (step + 1) % args.gradient_accumulation_steps != 0:
                continue
            optimizer.step()
            optimizer.zero_grad()
            global_step += 1

            # Log/checkpoint inside the optimizer-step branch so that, with
            # gradient accumulation, each global step is handled at most once.
            if (global_step + 1) % args.eval_step != 0:
                continue

            logger.info("***** Running evaluation *****")
            logger.info(" Epoch = {} iter {} step".format(epoch_, global_step))
            logger.info(" Num examples = %d", len(eval_examples))
            logger.info(" Batch size = %d", args.eval_batch_size)

            student_model.eval()
            # Running means over the steps seen in this epoch.
            result = {}
            result['global_step'] = global_step
            result['cls_loss'] = tr_cls_loss / (step + 1)
            result['att_loss'] = tr_att_loss / (step + 1)
            result['rep_loss'] = tr_rep_loss / (step + 1)
            result['loss'] = tr_loss / (step + 1)
            result_to_file(result, output_eval_file)

            logger.info("***** Save model *****")
            model_to_save = student_model.module if hasattr(
                student_model, 'module') else student_model
            model_name = f'{epoch_}_{WEIGHTS_NAME}'
            output_model_file = os.path.join(args.output_dir, model_name)
            output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
            torch.save(model_to_save.state_dict(), output_model_file)
            model_to_save.config.to_json_file(output_config_file)
            tokenizer.save_vocabulary(args.output_dir)

            # Return to training mode. Previously the student stayed in eval
            # mode (dropout off) for the rest of the epoch.
            student_model.train()
def main(): parser = argparse.ArgumentParser() # Required parameters parser.add_argument("--pregenerated_data", type=Path, required=True) parser.add_argument("--teacher_model", default=None, type=str, required=True) parser.add_argument("--student_model", default=None, type=str, required=True) parser.add_argument("--output_dir", default=None, type=str, required=True) # Other parameters parser.add_argument("--max_seq_length", default=128, type=int, help="The maximum total input sequence length after WordPiece tokenization. \n" "Sequences longer than this will be truncated, and sequences shorter \n" "than this will be padded.") parser.add_argument("--reduce_memory", action="store_true", help="Store training data as on-disc memmaps to massively reduce memory usage") parser.add_argument("--do_eval", action='store_true', help="Whether to run eval on the dev set.") parser.add_argument("--do_lower_case", action='store_true', help="Set this flag if you are using an uncased model.") parser.add_argument("--train_batch_size", default=32, type=int, help="Total batch size for training.") parser.add_argument("--eval_batch_size", default=8, type=int, help="Total batch size for eval.") parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.") parser.add_argument('--weight_decay', '--wd', default=1e-4, type=float, metavar='W', help='weight decay') parser.add_argument("--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform.") parser.add_argument("--warmup_proportion", default=0.1, type=float, help="Proportion of training to perform linear learning rate warmup for. 
" "E.g., 0.1 = 10%% of training.") parser.add_argument("--no_cuda", action='store_true', help="Whether not to use CUDA when available") parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus") parser.add_argument('--seed', type=int, default=42, help="random seed for initialization") parser.add_argument('--gradient_accumulation_steps', type=int, default=1, help="Number of updates steps to accumulate before performing a backward/update pass.") parser.add_argument('--fp16', action='store_true', help="Whether to use 16-bit float precision instead of 32-bit") parser.add_argument('--continue_train', action='store_true', help='Whether to train from checkpoints') # Additional arguments parser.add_argument('--eval_step', type=int, default=1000) # This is used for running on Huawei Cloud. parser.add_argument('--data_url', type=str, default="") args = parser.parse_args() logger.info('args:{}'.format(args)) samples_per_epoch = [] for i in range(int(args.num_train_epochs)): epoch_file = args.pregenerated_data / "epoch_{}.json".format(i) metrics_file = args.pregenerated_data / "epoch_{}_metrics.json".format(i) if epoch_file.is_file() and metrics_file.is_file(): metrics = json.loads(metrics_file.read_text()) samples_per_epoch.append(metrics['num_training_examples']) else: if i == 0: exit("No training data was found!") print("Warning! 
There are fewer epochs of pregenerated data ({}) than training epochs ({}).".format(i, args.num_train_epochs)) print("This script will loop over the available data, but training diversity may be negatively impacted.") num_data_epochs = i break else: num_data_epochs = args.num_train_epochs if args.local_rank == -1 or args.no_cuda: device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") n_gpu = torch.cuda.device_count() else: torch.cuda.set_device(args.local_rank) device = torch.device("cuda", args.local_rank) n_gpu = 1 # Initializes the distributed backend which will take care of sychronizing nodes/GPUs torch.distributed.init_process_group(backend='nccl') logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', datefmt='%m/%d/%Y %H:%M:%S', level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN) logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format( device, n_gpu, bool(args.local_rank != -1), args.fp16)) if args.gradient_accumulation_steps < 1: raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format( args.gradient_accumulation_steps)) args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) if n_gpu > 0: torch.cuda.manual_seed_all(args.seed) if os.path.exists(args.output_dir) and os.listdir(args.output_dir): raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir)) if not os.path.exists(args.output_dir): os.makedirs(args.output_dir) tokenizer = BertTokenizer.from_pretrained(args.teacher_model, do_lower_case=args.do_lower_case) total_train_examples = 0 for i in range(int(args.num_train_epochs)): # The modulo takes into account the fact that we may loop over limited epochs of data total_train_examples += samples_per_epoch[i % len(samples_per_epoch)] num_train_optimization_steps = int( 
total_train_examples / args.train_batch_size / args.gradient_accumulation_steps) if args.local_rank != -1: num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size() if args.continue_train: student_model = TinyBertForPreTraining.from_pretrained(args.student_model) else: student_model = TinyBertForPreTraining.from_scratch(args.student_model) teacher_model = BertModel.from_pretrained(args.teacher_model) # student_model = TinyBertForPreTraining.from_scratch(args.student_model, fit_size=teacher_model.config.hidden_size) student_model.to(device) teacher_model.to(device) if args.local_rank != -1: try: from apex.parallel import DistributedDataParallel as DDP except ImportError: raise ImportError( "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.") teacher_model = DDP(teacher_model) elif n_gpu > 1: student_model = torch.nn.DataParallel(student_model) teacher_model = torch.nn.DataParallel(teacher_model) size = 0 for n, p in student_model.named_parameters(): logger.info('n: {}'.format(n)) logger.info('p: {}'.format(p.nelement())) size += p.nelement() logger.info('Total parameters: {}'.format(size)) # Prepare optimizer param_optimizer = list(student_model.named_parameters()) no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] optimizer_grouped_parameters = [ {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01}, {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} ] loss_mse = MSELoss() optimizer = BertAdam(optimizer_grouped_parameters, lr=args.learning_rate, warmup=args.warmup_proportion, t_total=num_train_optimization_steps) global_step = 0 logging.info("***** Running training *****") logging.info(" Num examples = {}".format(total_train_examples)) logging.info(" Batch size = %d", args.train_batch_size) logging.info(" Num steps = %d", num_train_optimization_steps) for epoch in 
trange(int(args.num_train_epochs), desc="Epoch"): epoch_dataset = PregeneratedDataset(epoch=epoch, training_path=args.pregenerated_data, tokenizer=tokenizer, num_data_epochs=num_data_epochs, reduce_memory=args.reduce_memory) if args.local_rank == -1: train_sampler = RandomSampler(epoch_dataset) else: train_sampler = DistributedSampler(epoch_dataset) train_dataloader = DataLoader(epoch_dataset, sampler=train_sampler, batch_size=args.train_batch_size) tr_loss = 0. tr_att_loss = 0. tr_rep_loss = 0. student_model.train() nb_tr_examples, nb_tr_steps = 0, 0 with tqdm(total=len(train_dataloader), desc="Epoch {}".format(epoch)) as pbar: for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration", ascii=True)): batch = tuple(t.to(device) for t in batch) input_ids, input_mask, segment_ids, lm_label_ids, is_next = batch if input_ids.size()[0] != args.train_batch_size: continue att_loss = 0. rep_loss = 0. student_atts, student_reps = student_model(input_ids, segment_ids, input_mask) teacher_reps, teacher_atts, _ = teacher_model(input_ids, segment_ids, input_mask) teacher_reps = [teacher_rep.detach() for teacher_rep in teacher_reps] # speedup 1.5x teacher_atts = [teacher_att.detach() for teacher_att in teacher_atts] teacher_layer_num = len(teacher_atts) student_layer_num = len(student_atts) assert teacher_layer_num % student_layer_num == 0 layers_per_block = int(teacher_layer_num / student_layer_num) new_teacher_atts = [teacher_atts[i * layers_per_block + layers_per_block - 1] for i in range(student_layer_num)] for student_att, teacher_att in zip(student_atts, new_teacher_atts): student_att = torch.where(student_att <= -1e2, torch.zeros_like(student_att).to(device), student_att) teacher_att = torch.where(teacher_att <= -1e2, torch.zeros_like(teacher_att).to(device), teacher_att) att_loss += loss_mse(student_att, teacher_att) new_teacher_reps = [teacher_reps[i * layers_per_block] for i in range(student_layer_num + 1)] new_student_reps = student_reps for student_rep, 
teacher_rep in zip(new_student_reps, new_teacher_reps): rep_loss += loss_mse(student_rep, teacher_rep) loss = att_loss + rep_loss if n_gpu > 1: loss = loss.mean() # mean() to average on multi-gpu. if args.gradient_accumulation_steps > 1: loss = loss / args.gradient_accumulation_steps if args.fp16: optimizer.backward(loss) else: loss.backward() tr_att_loss += att_loss.item() tr_rep_loss += rep_loss.item() tr_loss += loss.item() nb_tr_examples += input_ids.size(0) nb_tr_steps += 1 pbar.update(1) mean_loss = tr_loss * args.gradient_accumulation_steps / nb_tr_steps mean_att_loss = tr_att_loss * args.gradient_accumulation_steps / nb_tr_steps mean_rep_loss = tr_rep_loss * args.gradient_accumulation_steps / nb_tr_steps if (step + 1) % args.gradient_accumulation_steps == 0: optimizer.step() optimizer.zero_grad() global_step += 1 if (global_step + 1) % args.eval_step == 0: result = {} result['global_step'] = global_step result['loss'] = mean_loss result['att_loss'] = mean_att_loss result['rep_loss'] = mean_rep_loss output_eval_file = os.path.join(args.output_dir, "log.txt") with open(output_eval_file, "a") as writer: logger.info("***** Eval results *****") for key in sorted(result.keys()): logger.info(" %s = %s", key, str(result[key])) writer.write("%s = %s\n" % (key, str(result[key]))) # Save a trained model model_name = "step_{}_{}".format(global_step, WEIGHTS_NAME) logging.info("** ** * Saving fine-tuned model ** ** * ") # Only save the model it-self model_to_save = student_model.module if hasattr(student_model, 'module') else student_model output_model_file = os.path.join(args.output_dir, model_name) output_config_file = os.path.join(args.output_dir, CONFIG_NAME) torch.save(model_to_save.state_dict(), output_model_file) model_to_save.config.to_json_file(output_config_file) tokenizer.save_vocabulary(args.output_dir) if oncloud: logging.info(mox.file.list_directory(args.output_dir, recursive=True)) logging.info(mox.file.list_directory('.', recursive=True)) 
mox.file.copy_parallel(args.output_dir, args.data_url) mox.file.copy_parallel('.', args.data_url) model_name = "step_{}_{}".format(global_step, WEIGHTS_NAME) logging.info("** ** * Saving fine-tuned model ** ** * ") model_to_save = student_model.module if hasattr(student_model, 'module') else student_model output_model_file = os.path.join(args.output_dir, model_name) output_config_file = os.path.join(args.output_dir, CONFIG_NAME) torch.save(model_to_save.state_dict(), output_model_file) model_to_save.config.to_json_file(output_config_file) tokenizer.save_vocabulary(args.output_dir) if oncloud: logging.info(mox.file.list_directory(args.output_dir, recursive=True)) logging.info(mox.file.list_directory('.', recursive=True)) mox.file.copy_parallel(args.output_dir, args.data_url) mox.file.copy_parallel('.', args.data_url)
def main():
    """Run MNLI task distillation (intermediate layers or predictions) under DDP.

    Parses the command line, builds the MNLI training set from
    ``<data_dir>/train_aug.tsv`` via the ``datasets`` text loader and the dev
    set via the GLUE ``processor``.  With ``--do_eval`` the student is only
    evaluated on the dev set.  Otherwise the student is distilled from the
    teacher with torch.distributed (NCCL) + ``DistributedDataParallel`` and
    mixed precision (``autocast`` + ``GradScaler``):

    * without ``--pred_distill``: MSE between selected teacher/student
      attention maps and hidden states; a checkpoint is written at every
      evaluation step;
    * with ``--pred_distill``: soft cross-entropy between temperature-scaled
      logits; a checkpoint is written only when the dev metric improves, and
      MNLI is then also evaluated on the mismatched dev set.

    Evaluation, logging and checkpointing run on ``local_rank == 0`` only.

    Raises:
        ValueError: output dir is non-empty, unknown task, invalid
            ``--gradient_accumulation_steps``, or training requested for a
            task other than MNLI.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
    parser.add_argument("--teacher_model", default=None, type=str, help="The teacher model dir.")
    parser.add_argument("--student_model", default=None, type=str, required=True,
                        help="The student model dir.")
    parser.add_argument("--task_name", default=None, type=str, required=True,
                        help="The name of the task to train.")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the model predictions and checkpoints will be written.")
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help="Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help="The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_eval", action='store_true', help="Whether to run eval on the dev set.")
    parser.add_argument("--do_lower_case", action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size", default=32, type=int, help="Total batch size for training.")
    parser.add_argument("--eval_batch_size", default=32, type=int, help="Total batch size for eval.")
    parser.add_argument("--learning_rate", default=5e-5, type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument('--weight_decay', '--wd', default=1e-4, type=float, metavar='W',
                        help='weight decay')
    parser.add_argument("--num_train_epochs", default=3.0, type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help="Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda", action='store_true', help="Whether not to use CUDA when available")
    parser.add_argument('--seed', type=int, default=42, help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.")
    # added arguments
    parser.add_argument('--aug_train', action='store_true')
    parser.add_argument('--eval_step', type=int, default=50)
    parser.add_argument('--pred_distill', action='store_true')
    parser.add_argument('--data_url', type=str, default="")
    parser.add_argument('--temperature', type=float, default=1.)
    parser.add_argument('--local_rank', type=int, default=-1)
    args = parser.parse_args()
    logger.info('The args: {}'.format(args))

    processors = {
        "cola": ColaProcessor,
        "mnli": MnliProcessor,
        "mnli-mm": MnliMismatchedProcessor,
        "mrpc": MrpcProcessor,
        "sst-2": Sst2Processor,
        "sts-b": StsbProcessor,
        "qqp": QqpProcessor,
        "qnli": QnliProcessor,
        "rte": RteProcessor,
        "wnli": WnliProcessor
    }

    output_modes = {
        "cola": "classification",
        "mnli": "classification",
        "mrpc": "classification",
        "sst-2": "classification",
        "sts-b": "regression",
        "qqp": "classification",
        "qnli": "classification",
        "rte": "classification",
        "wnli": "classification"
    }

    # intermediate distillation default parameters
    default_params = {
        "cola": {"num_train_epochs": 50, "max_seq_length": 64},
        "mnli": {"num_train_epochs": 5, "max_seq_length": 128},
        "mrpc": {"num_train_epochs": 20, "max_seq_length": 128},
        "sst-2": {"num_train_epochs": 10, "max_seq_length": 64},
        "sts-b": {"num_train_epochs": 20, "max_seq_length": 128},
        "qqp": {"num_train_epochs": 5, "max_seq_length": 128},
        "qnli": {"num_train_epochs": 10, "max_seq_length": 128},
        "rte": {"num_train_epochs": 20, "max_seq_length": 128}
    }

    acc_tasks = ["mnli", "mrpc", "sst-2", "qqp", "qnli", "rte"]
    corr_tasks = ["sts-b"]
    mcc_tasks = ["cola"]

    # Prepare devices
    n_gpu = torch.cuda.device_count()
    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                        datefmt='%m/%d/%Y %H:%M:%S',
                        level=logging.INFO)
    device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    logger.info("device: {} n_gpu: {}".format(device, n_gpu))

    # Prepare seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    # Prepare task settings
    if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
        raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
    # exist_ok: several DDP ranks may reach this line at the same time.
    os.makedirs(args.output_dir, exist_ok=True)

    task_name = args.task_name.lower()

    # The overrides must target the attribute names argparse created
    # (max_seq_length / num_train_epochs); otherwise they are silently ignored.
    if task_name in default_params:
        args.max_seq_length = default_params[task_name]["max_seq_length"]

    if not args.pred_distill and not args.do_eval:
        if task_name in default_params:
            args.num_train_epochs = default_params[task_name]["num_train_epochs"]

    if task_name not in processors:
        raise ValueError("Task not found: %s" % task_name)

    processor = processors[task_name]()
    output_mode = output_modes[task_name]
    label_list = processor.get_labels()
    num_labels = len(label_list)

    tokenizer = BertTokenizer.from_pretrained(args.student_model, do_lower_case=args.do_lower_case)

    if not args.do_eval:
        if args.gradient_accumulation_steps < 1:
            raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
                args.gradient_accumulation_steps))

        args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

        # The training pipeline below reads a raw TSV with fixed column
        # positions, so it only supports MNLI.
        if task_name != "mnli":
            raise ValueError("the script is designed for MNLI only now")
        mnli_datasets = load_dataset("text", data_files=os.path.join(args.data_dir, "train_aug.tsv"))
        label_map = {label: i for i, label in enumerate(label_list)}

        def preprocess_func(examples, max_seq_length=args.max_seq_length):
            """Turn a batch of raw TSV lines into fixed-length BERT features.

            Columns 8 and 9 are used as the sentence pair and the last column
            as the label (presumably the GLUE MNLI TSV layout -- confirm for
            train_aug.tsv).  Returns a dict of per-example lists.
            """
            results = {
                "input_ids": [],
                "input_mask": [],
                "segment_ids": [],
                "seq_length": [],
                "label_ids": []
            }
            for line in examples['text']:
                fields = line.split('\t')
                tokens_a = tokenizer.tokenize(fields[8])
                tokens_b = tokenizer.tokenize(fields[9])
                # Reserve room for [CLS], [SEP], [SEP]; truncates in place.
                truncate_seq_pair(tokens_a, tokens_b, max_length=max_seq_length - 3)

                tokens = ["[CLS]"] + tokens_a + ["[SEP]"]
                segment_ids = [0] * len(tokens)
                tokens += tokens_b + ["[SEP]"]
                segment_ids += [1] * (len(tokens_b) + 1)

                input_ids = tokenizer.convert_tokens_to_ids(tokens)
                input_mask = [1] * len(input_ids)
                seq_length = len(input_ids)

                padding = [0] * (max_seq_length - len(input_ids))
                input_ids += padding
                input_mask += padding
                segment_ids += padding

                assert len(input_ids) == max_seq_length
                assert len(input_mask) == max_seq_length
                assert len(segment_ids) == max_seq_length

                results["input_ids"].append(input_ids)
                results["input_mask"].append(input_mask)
                results["segment_ids"].append(segment_ids)
                results["seq_length"].append(seq_length)
                results["label_ids"].append(label_map[fields[-1]])
            return results

        mnli_datasets = mnli_datasets.map(preprocess_func, batched=True)
        train_data = mnli_datasets['train'].remove_columns('text')
        logger.info("First training example: %s", train_data[0])

        logger.info("Initializing Distributed Environment")
        torch.cuda.set_device(args.local_rank)
        dist.init_process_group(backend="nccl")
        world_size = dist.get_world_size()

        # DistributedSampler hands each rank ~1/world_size of the data, so this
        # rank performs correspondingly fewer optimizer steps; t_total for the
        # LR schedule must reflect that.
        num_train_optimization_steps = int(
            len(train_data) / world_size / args.train_batch_size /
            args.gradient_accumulation_steps) * args.num_train_epochs

        train_sampler = torch.utils.data.DistributedSampler(train_data)
        train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)

    eval_examples = processor.get_dev_examples(args.data_dir)
    eval_features = convert_examples_to_features(eval_examples, label_list, args.max_seq_length,
                                                 tokenizer, output_mode, logger)
    eval_data, eval_labels = get_tensor_data(output_mode, eval_features)
    eval_sampler = SequentialSampler(eval_data)
    eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)

    # DDP setting
    local_rank = args.local_rank
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")

    student_model = TinyBertForSequenceClassification.from_pretrained(
        args.student_model, num_labels=num_labels).to(device)

    if args.do_eval:
        logger.info("***** Running evaluation *****")
        logger.info("  Num examples = %d", len(eval_examples))
        logger.info("  Batch size = %d", args.eval_batch_size)

        student_model.eval()
        result = do_eval(student_model, task_name, eval_dataloader, device, output_mode, eval_labels,
                         num_labels)
        logger.info("***** Eval results *****")
        for key in sorted(result.keys()):
            logger.info("  %s = %s", key, str(result[key]))
        return

    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_data))
    logger.info("  Batch size = %d", args.train_batch_size)
    logger.info("  Num steps = %d", num_train_optimization_steps)

    teacher_model = TinyBertForSequenceClassification.from_pretrained(
        args.teacher_model, num_labels=num_labels).to(device)
    student_model = DDP(student_model, device_ids=[local_rank], output_device=local_rank)
    teacher_model = DDP(teacher_model, device_ids=[local_rank], output_device=local_rank)

    # Prepare optimizer
    param_optimizer = list(student_model.named_parameters())
    size = 0
    for n, p in student_model.named_parameters():
        logger.info('n: {}'.format(n))
        size += p.nelement()
    logger.info('Total parameters: {}'.format(size))

    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
         'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0}
    ]
    # Intermediate-layer distillation runs with a constant learning rate.
    schedule = 'warmup_linear' if args.pred_distill else 'none'
    optimizer = BertAdam(optimizer_grouped_parameters,
                         schedule=schedule,
                         lr=args.learning_rate,
                         warmup=args.warmup_proportion,
                         t_total=num_train_optimization_steps)
    scaler = torch.cuda.amp.GradScaler()

    # Prepare loss functions
    loss_mse = MSELoss()

    def soft_cross_entropy(predicts, targets):
        """Mean cross-entropy of softmax(targets) against log_softmax(predicts)."""
        student_likelihood = torch.nn.functional.log_softmax(predicts, dim=-1)
        targets_prob = torch.nn.functional.softmax(targets, dim=-1)
        return (-targets_prob * student_likelihood).mean()

    # Train and evaluate
    global_step = 0
    best_dev_acc = 0.0
    output_eval_file = os.path.join(args.output_dir, "eval_results.txt")

    for epoch_ in trange(int(args.num_train_epochs), desc="Epoch"):
        # Without set_epoch, DistributedSampler replays the same order every epoch.
        train_sampler.set_epoch(epoch_)
        tr_loss = 0.
        tr_att_loss = 0.
        tr_rep_loss = 0.
        tr_cls_loss = 0.

        student_model.train()
        nb_tr_examples, nb_tr_steps = 0, 0

        for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration", ascii=True)):
            inputs = {}
            for k, v in batch.items():
                if isinstance(v, torch.Tensor):
                    inputs[k] = v.to(device)
                elif isinstance(v, list):
                    # List-valued features are collated into a list of tensors
                    # (one per position); stack them back along dim 1.
                    inputs[k] = torch.stack(v, dim=1).to(device)

            # Skip the trailing partial batch.
            if inputs['input_ids'].size()[0] != args.train_batch_size:
                continue

            att_loss = 0.
            rep_loss = 0.
            cls_loss = 0.

            with autocast():
                student_logits, student_atts, student_reps = student_model(
                    inputs['input_ids'], inputs['segment_ids'], inputs['input_mask'], is_student=True)
                with torch.no_grad():
                    teacher_logits, teacher_atts, teacher_reps = teacher_model(
                        inputs['input_ids'], inputs['segment_ids'], inputs['input_mask'])

                if not args.pred_distill:
                    teacher_layer_num = len(teacher_atts)
                    student_layer_num = len(student_atts)
                    assert teacher_layer_num % student_layer_num == 0
                    layers_per_block = teacher_layer_num // student_layer_num

                    # Student attention layer i is matched to the last teacher
                    # layer of the i-th block of layers_per_block layers.
                    new_teacher_atts = [
                        teacher_atts[i * layers_per_block + layers_per_block - 1]
                        for i in range(student_layer_num)
                    ]
                    for student_att, teacher_att in zip(student_atts, new_teacher_atts):
                        # Scores <= -100 (presumably additively masked positions)
                        # are zeroed so they do not dominate the MSE.
                        student_att = torch.where(student_att <= -1e2, torch.zeros_like(student_att),
                                                  student_att)
                        teacher_att = torch.where(teacher_att <= -1e2, torch.zeros_like(teacher_att),
                                                  teacher_att)
                        att_loss += loss_mse(student_att, teacher_att)

                    # student_layer_num + 1 hidden states are matched
                    # (presumably index 0 is the embedding output).
                    new_teacher_reps = [teacher_reps[i * layers_per_block]
                                        for i in range(student_layer_num + 1)]
                    for student_rep, teacher_rep in zip(student_reps, new_teacher_reps):
                        rep_loss += loss_mse(student_rep, teacher_rep)

                    # The zero-weighted logit term keeps the classifier output in the
                    # graph (original note: "add this term for amp detection").
                    loss = rep_loss + att_loss + 0 * soft_cross_entropy(
                        student_logits / args.temperature, teacher_logits / args.temperature)

                    tr_att_loss += att_loss.item()
                    tr_rep_loss += rep_loss.item()
                else:
                    if output_mode == "classification":
                        cls_loss = soft_cross_entropy(student_logits / args.temperature,
                                                      teacher_logits / args.temperature)
                    elif output_mode == "regression":
                        cls_loss = loss_mse(student_logits.view(-1),
                                            inputs['label_ids'].view(-1).float())

                    # Zero-weighted terms keep attention/hidden outputs in the graph.
                    loss = cls_loss + 0 * loss_mse(student_atts[0], teacher_atts[0]) + 0 * loss_mse(
                        teacher_reps[0], student_reps[0])
                    tr_cls_loss += cls_loss.item()

            scaler.scale(loss).backward()

            tr_loss += loss.item()
            nb_tr_examples += inputs['label_ids'].size(0)
            nb_tr_steps += 1

            if (step + 1) % args.gradient_accumulation_steps != 0:
                continue

            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            global_step += 1

            # Evaluate once per optimizer step (never repeatedly for the same
            # global_step during accumulation), on rank 0 only.
            if (global_step + 1) % args.eval_step != 0 or args.local_rank != 0:
                continue

            logger.info("***** Running evaluation *****")
            logger.info("  Epoch = {} iter {} step".format(epoch_, global_step))
            logger.info("  Num examples = %d", len(eval_examples))
            logger.info("  Batch size = %d", args.eval_batch_size)

            student_model.eval()

            result = do_eval(student_model, task_name, eval_dataloader, device, output_mode,
                             eval_labels, num_labels) if args.pred_distill else {}
            result['global_step'] = global_step
            result['cls_loss'] = tr_cls_loss / (step + 1)
            result['att_loss'] = tr_att_loss / (step + 1)
            result['rep_loss'] = tr_rep_loss / (step + 1)
            result['loss'] = tr_loss / (step + 1)

            result_to_file(result, output_eval_file)

            if not args.pred_distill:
                save_model = True
            else:
                save_model = False
                if task_name in acc_tasks and result['acc'] > best_dev_acc:
                    best_dev_acc = result['acc']
                    save_model = True
                if task_name in corr_tasks and result['corr'] > best_dev_acc:
                    best_dev_acc = result['corr']
                    save_model = True
                if task_name in mcc_tasks and result['mcc'] > best_dev_acc:
                    best_dev_acc = result['mcc']
                    save_model = True

            if save_model:
                logger.info("***** Save model *****")
                model_to_save = student_model.module if hasattr(student_model, 'module') else student_model
                output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
                output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
                torch.save(model_to_save.state_dict(), output_model_file)
                model_to_save.config.to_json_file(output_config_file)
                tokenizer.save_vocabulary(args.output_dir)

                # Test mnli-mm.  Separate locals are used so the matched-dev
                # loader/labels used by later evaluations are not replaced.
                if args.pred_distill and task_name == "mnli":
                    mm_output_dir = args.output_dir + '-MM'
                    os.makedirs(mm_output_dir, exist_ok=True)
                    mm_processor = processors["mnli-mm"]()
                    mm_examples = mm_processor.get_dev_examples(args.data_dir)
                    mm_features = convert_examples_to_features(mm_examples, label_list, args.max_seq_length,
                                                               tokenizer, output_mode, logger)
                    mm_data, mm_labels = get_tensor_data(output_mode, mm_features)

                    logger.info("***** Running mm evaluation *****")
                    logger.info("  Num examples = %d", len(mm_examples))
                    logger.info("  Batch size = %d", args.eval_batch_size)

                    mm_dataloader = DataLoader(mm_data, sampler=SequentialSampler(mm_data),
                                               batch_size=args.eval_batch_size)
                    mm_result = do_eval(student_model, "mnli-mm", mm_dataloader, device, output_mode,
                                        mm_labels, num_labels)
                    mm_result['global_step'] = global_step
                    result_to_file(mm_result, os.path.join(mm_output_dir, "eval_results.txt"))

                if oncloud:
                    logging.info(mox.file.list_directory(args.output_dir, recursive=True))
                    logging.info(mox.file.list_directory('.', recursive=True))
                    mox.file.copy_parallel(args.output_dir, args.data_url)
                    mox.file.copy_parallel('.', args.data_url)

            student_model.train()
def main():
    """Pre-generate training data files from a raw text corpus.

    Every file found by walking ``--train_corpus`` is read line by line: each
    non-blank line is tokenized and appended to the current document, and a
    blank line closes the document.  The documents are then passed to
    ``create_training_file`` once per epoch (``--epochs_to_generate``),
    either sequentially or through a process pool (``--num_workers > 1``).
    In the sequential path, when ``oncloud`` is set, each epoch's output is
    copied to ``--data_url`` and the local epoch/metric files are removed.

    Raises:
        ValueError: ``--num_workers > 1`` combined with ``--reduce_memory``.
        SystemExit: fewer than two documents were found in the corpus.
    """
    import os
    import sys
    from functools import partial

    parser = ArgumentParser()
    parser.add_argument('--train_corpus', type=Path, required=True)
    parser.add_argument("--output_dir", type=Path, required=True)
    parser.add_argument("--bert_model", type=str, required=True)
    parser.add_argument("--do_lower_case", action="store_true")
    parser.add_argument(
        "--do_whole_word_mask",
        action="store_true",
        help="Whether to use whole word masking rather than per-WordPiece masking.")
    parser.add_argument(
        "--reduce_memory",
        action="store_true",
        help="Reduce memory usage for large datasets by keeping data on disc rather than in memory")
    parser.add_argument("--num_workers", type=int, default=1,
                        help="The number of workers to use to write the files")
    parser.add_argument("--epochs_to_generate", type=int, default=3,
                        help="Number of epochs of data to pregenerate")
    parser.add_argument("--max_seq_len", type=int, default=128)
    parser.add_argument(
        "--short_seq_prob",
        type=float,
        default=0.1,
        help="Probability of making a short sentence as a training example")
    parser.add_argument(
        "--masked_lm_prob",
        type=float,
        default=0.0,
        help="Probability of masking each token for the LM task")  # no [mask] symbol in corpus
    parser.add_argument(
        "--max_predictions_per_seq",
        type=int,
        default=20,
        help="Maximum number of tokens to mask in each sequence")
    parser.add_argument('--data_url', type=str, default="")
    parser.add_argument('--one_seq', action='store_true')

    args = parser.parse_args()

    if args.num_workers > 1 and args.reduce_memory:
        raise ValueError("Cannot use multiple workers while reducing memory")

    tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
    vocab_list = list(tokenizer.vocab.keys())
    # --one_seq turns off the bi_text mode of create_training_file; applied to
    # both the sequential and the multi-worker paths.
    bi_text = not args.one_seq
    doc_num = 0
    with DocumentDatabase(reduce_memory=args.reduce_memory) as docs:
        for root, _dirs, files in os.walk(args.train_corpus, topdown=False):
            for name in files:
                corpus_file = Path(root, name)
                logger.info(f'Start on {corpus_file}')
                with corpus_file.open(encoding="utf-8") as f:
                    doc = []
                    for line in tqdm(f, desc="Loading Dataset", unit=" lines"):
                        line = line.strip()
                        if line == "":
                            docs.add_document(doc)
                            doc = []
                            doc_num += 1
                            if doc_num % 100 == 0:
                                logger.info('loaded {} docs!'.format(doc_num))
                        else:
                            doc.append(tokenizer.tokenize(line))
                    if doc:
                        # If the last doc didn't end on a newline, make sure it still gets added
                        docs.add_document(doc)

        if len(docs) <= 1:
            sys.exit(
                "ERROR: No document breaks were found in the input file! These are necessary to allow the script to "
                "ensure that random NextSentences are not sampled from the same document. Please add blank lines to "
                "indicate breaks between documents in your input file. If your dataset does not contain multiple "
                "documents, blank lines can be inserted at any natural boundary, such as the ends of chapters, "
                "sections or paragraphs.")

        args.output_dir.mkdir(parents=True, exist_ok=True)

        if args.num_workers > 1:
            arguments = [(docs, vocab_list, args, idx) for idx in range(args.epochs_to_generate)]
            # Context manager releases the worker processes; starmap blocks
            # until every epoch has been written.
            with Pool(min(args.num_workers, args.epochs_to_generate)) as writer_workers:
                writer_workers.starmap(partial(create_training_file, bi_text=bi_text), arguments)
        else:
            for epoch in trange(args.epochs_to_generate, desc="Epoch"):
                epoch_file, metric_file = create_training_file(docs, vocab_list, args, epoch,
                                                               bi_text=bi_text)
                if oncloud:
                    logging.info(mox.file.list_directory(str(args.output_dir), recursive=True))
                    logging.info(mox.file.list_directory('.', recursive=True))
                    mox.file.copy_parallel(str(args.output_dir), args.data_url)
                    mox.file.copy_parallel('.', args.data_url)
                    # Uploaded already; drop the local copies.
                    os.remove(str(epoch_file))
                    os.remove(str(metric_file))
def main(): parser = argparse.ArgumentParser() parser.add_argument("--data_dir", default=None, type=str, required=True, help="The input data dir. Should contain the .tsv files (or other data files) for the task.") parser.add_argument("--teacher_model", default=None, type=str, help="The teacher model dir.") parser.add_argument("--student_model", default=None, type=str, required=True, help="The student model dir.") parser.add_argument("--task_name", default=None, type=str, required=True, help="The name of the task to train.") parser.add_argument("--output_dir", default=None, type=str, required=True, help="The output directory where the model predictions and checkpoints will be written.") parser.add_argument("--cache_dir", default="", type=str, help="Where do you want to store the pre-trained models downloaded from s3") parser.add_argument("--max_seq_length", default=128, type=int, help="The maximum total input sequence length after WordPiece tokenization. \n" "Sequences longer than this will be truncated, and sequences shorter \n" "than this will be padded.") parser.add_argument("--do_eval", action='store_true', help="Whether to run eval on the dev set.") parser.add_argument("--do_lower_case", action='store_true', help="Set this flag if you are using an uncased model.") parser.add_argument("--train_batch_size", default=32, type=int, help="Total batch size for training.") parser.add_argument("--eval_batch_size", default=32, type=int, help="Total batch size for eval.") parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.") parser.add_argument('--weight_decay', '--wd', default=1e-4, type=float, metavar='W', help='weight decay') parser.add_argument("--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform.") parser.add_argument("--warmup_proportion", default=0.1, type=float, help="Proportion of training to perform linear learning rate warmup for. 
" "E.g., 0.1 = 10%% of training.") parser.add_argument("--no_cuda", action='store_true', help="Whether not to use CUDA when available") parser.add_argument('--seed', type=int, default=42, help="random seed for initialization") parser.add_argument('--gradient_accumulation_steps', type=int, default=1, help="Number of updates steps to accumulate before performing a backward/update pass.") parser.add_argument('--fp16', default=False, action='store_true', help="Whether to use 16-bit float precision instead of 32-bit") # added arguments parser.add_argument('--aug_train', action='store_true') parser.add_argument('--eval_step', type=int, default=50) parser.add_argument('--pred_distill', action='store_true') parser.add_argument('--data_url', type=str, default="") parser.add_argument('--temperature', type=float, default=1.) args = parser.parse_args() logger.info('The args: {}'.format(args)) wandb.config.update(args) processors = { "race": RaceProcessor, } # intermediate distillation default parameters default_params = { "race": {"num_train_epochs": 3, "max_seq_length": 80}, } # Prepare devices device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") n_gpu = torch.cuda.device_count() logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', datefmt='%m/%d/%Y %H:%M:%S', level=logging.INFO) logger.info("device: {} n_gpu: {}".format(device, n_gpu)) # Prepare seed random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) if n_gpu > 0: torch.cuda.manual_seed_all(args.seed) # Prepare task settings if os.path.exists(args.output_dir) and os.listdir(args.output_dir): raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir)) if not os.path.exists(args.output_dir): os.makedirs(args.output_dir) task_name = args.task_name.lower() if task_name in default_params: args.max_seq_len = default_params[task_name]["max_seq_length"] if not args.pred_distill and not args.do_eval: 
if task_name in default_params: args.num_train_epoch = default_params[task_name]["num_train_epochs"] if task_name not in processors: raise ValueError("Task not found: %s" % task_name) processor = processors[task_name]() label_list = processor.get_labels() num_labels = len(label_list) tokenizer = BertTokenizer.from_pretrained(args.student_model, do_lower_case=args.do_lower_case) if not args.do_eval: if not args.aug_train: train_examples = processor.get_train_examples(args.data_dir) else: train_examples = processor.get_aug_examples(args.data_dir) if args.gradient_accumulation_steps < 1: raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format( args.gradient_accumulation_steps)) args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps num_train_optimization_steps = int( len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs cached_features_file_train = os.path.join( args.data_dir, "cached_train_{}_{}_{}_tinybert".format(tokenizer.__class__.__name__, str(args.max_seq_length), task_name, ), ) if os.path.exists(cached_features_file_train): train_features = torch.load(cached_features_file_train) else: train_features = convert_examples_to_features(train_examples, label_list, args.max_seq_length, tokenizer) torch.save(train_features, cached_features_file_train) train_data, _ = get_tensor_data(train_features) train_sampler = RandomSampler(train_data) train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size) eval_examples = processor.get_dev_examples(args.data_dir) cached_features_file_eval = os.path.join( args.data_dir, "cached_dev_{}_{}_{}_tinybert".format(tokenizer.__class__.__name__, str(args.max_seq_length), task_name, ), ) if os.path.exists(cached_features_file_eval): eval_features = torch.load(cached_features_file_eval) else: eval_features = convert_examples_to_features(eval_examples, label_list, args.max_seq_length, 
tokenizer) torch.save(eval_features, cached_features_file_eval) eval_data, eval_labels = get_tensor_data(eval_features) eval_sampler = SequentialSampler(eval_data) eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size) if not args.do_eval: teacher_model = TinyBertForMultipleChoice.from_pretrained(args.teacher_model) teacher_model.to(device) student_model = TinyBertForMultipleChoice.from_pretrained(args.student_model) student_model.to(device) wandb.watch(student_model, log='all') if args.do_eval: logger.info("***** Running evaluation *****") logger.info(" Num examples = %d", len(eval_examples)) logger.info(" Batch size = %d", args.eval_batch_size) student_model.eval() result = do_eval(student_model, task_name, eval_dataloader, device, eval_labels, num_labels) logger.info("***** Eval results *****") for key in sorted(result.keys()): logger.info(" %s = %s", key, str(result[key])) wandb.log(result) else: logger.info("***** Running training *****") logger.info(" Num examples = %d", len(train_examples)) logger.info(" Batch size = %d", args.train_batch_size) logger.info(" Num steps = %d", num_train_optimization_steps) # Prepare optimizer param_optimizer = list(student_model.named_parameters()) size = 0 for n, p in student_model.named_parameters(): logger.info('n: {}'.format(n)) size += p.nelement() logger.info('Total parameters: {}'.format(size)) no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] optimizer_grouped_parameters = [ {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01}, {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} ] schedule = 'warmup_linear' if not args.pred_distill: schedule = 'none' optimizer = BertAdam(optimizer_grouped_parameters, schedule=schedule, lr=args.learning_rate, warmup=args.warmup_proportion, t_total=num_train_optimization_steps) if args.fp16: if not _has_apex: raise ImportError("Please install 
apex from https://www.github.com/nvidia/apex to use fp16 training.") student_model, optimizer = amp.initialize(student_model, optimizer, opt_level='O1') if n_gpu > 1: student_model = torch.nn.DataParallel(student_model) teacher_model = torch.nn.DataParallel(teacher_model) # Prepare loss functions loss_mse = MSELoss() def soft_cross_entropy(predicts, targets): student_likelihood = torch.nn.functional.log_softmax(predicts, dim=-1) targets_prob = torch.nn.functional.softmax(targets, dim=-1) return (- targets_prob * student_likelihood).mean() # Train and evaluate global_step = 0 best_dev_acc = 0.0 output_eval_file = os.path.join(args.output_dir, "eval_results.txt") for epoch_ in trange(int(args.num_train_epochs), desc="Epoch"): tr_loss = 0. tr_att_loss = 0. tr_rep_loss = 0. tr_cls_loss = 0. student_model.train() nb_tr_examples, nb_tr_steps = 0, 0 for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration", ascii=True)): batch = tuple(t.to(device) for t in batch) input_ids, input_mask, segment_ids, label_ids = batch if input_ids.size()[0] != args.train_batch_size: continue att_loss = 0. rep_loss = 0. cls_loss = 0. 
student_logits, student_atts, student_reps = student_model(input_ids=input_ids, token_type_ids=segment_ids, attention_mask=input_mask, is_student=True) with torch.no_grad(): teacher_logits, teacher_atts, teacher_reps = teacher_model(input_ids=input_ids, token_type_ids=segment_ids, attention_mask=input_mask) if not args.pred_distill: teacher_layer_num = len(teacher_atts) student_layer_num = len(student_atts) assert teacher_layer_num % student_layer_num == 0 layers_per_block = int(teacher_layer_num / student_layer_num) new_teacher_atts = [teacher_atts[i * layers_per_block + layers_per_block - 1] for i in range(student_layer_num)] for student_att, teacher_att in zip(student_atts, new_teacher_atts): student_att = torch.where(student_att <= -1e2, torch.zeros_like(student_att).to(device), student_att) teacher_att = torch.where(teacher_att <= -1e2, torch.zeros_like(teacher_att).to(device), teacher_att) tmp_loss = loss_mse(student_att, teacher_att) att_loss += tmp_loss new_teacher_reps = [teacher_reps[i * layers_per_block] for i in range(student_layer_num + 1)] new_student_reps = student_reps for student_rep, teacher_rep in zip(new_student_reps, new_teacher_reps): tmp_loss = loss_mse(student_rep, teacher_rep) rep_loss += tmp_loss loss = rep_loss + att_loss tr_att_loss += att_loss.item() tr_rep_loss += rep_loss.item() else: cls_loss = soft_cross_entropy(student_logits / args.temperature, teacher_logits / args.temperature) loss = cls_loss tr_cls_loss += cls_loss.item() if n_gpu > 1: loss = loss.mean() # mean() to average on multi-gpu. 
if args.gradient_accumulation_steps > 1: loss = loss / args.gradient_accumulation_steps if args.fp16: with amp.scale_loss(loss, optimizer) as scaled_loss: scaled_loss.backward() else: loss.backward() tr_loss += loss.item() nb_tr_examples += label_ids.size(0) nb_tr_steps += 1 if (step + 1) % args.gradient_accumulation_steps == 0: if args.fp16: torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), 1) optimizer.step() optimizer.zero_grad() global_step += 1 if (global_step + 1) % args.eval_step == 0: logger.info("***** Running evaluation *****") logger.info(" Epoch = {} iter {} step".format(epoch_, global_step)) logger.info(" Num examples = %d", len(eval_examples)) logger.info(" Batch size = %d", args.eval_batch_size) student_model.eval() loss = tr_loss / (step + 1) cls_loss = tr_cls_loss / (step + 1) att_loss = tr_att_loss / (step + 1) rep_loss = tr_rep_loss / (step + 1) result = {} if args.pred_distill: result = do_eval(student_model, task_name, eval_dataloader, device, eval_labels, num_labels) result['global_step'] = global_step result['cls_loss'] = cls_loss result['att_loss'] = att_loss result['rep_loss'] = rep_loss result['loss'] = loss wandb.log(result, step=global_step) result_to_file(result, output_eval_file) if not args.pred_distill: save_model = True else: save_model = False if result['acc'] > best_dev_acc: best_dev_acc = result['acc'] save_model = True if save_model: logger.info("***** Save model *****") model_to_save = student_model.module if hasattr(student_model, 'module') else student_model model_name = WEIGHTS_NAME # if not args.pred_distill: # model_name = "step_{}_{}".format(global_step, WEIGHTS_NAME) output_model_file = os.path.join(args.output_dir, model_name) output_config_file = os.path.join(args.output_dir, CONFIG_NAME) torch.save(model_to_save.state_dict(), output_model_file) model_to_save.config.to_json_file(output_config_file) tokenizer.save_vocabulary(args.output_dir) if oncloud: logging.info(mox.file.list_directory(args.output_dir, 
recursive=True)) logging.info(mox.file.list_directory('.', recursive=True)) mox.file.copy_parallel(args.output_dir, args.data_url) mox.file.copy_parallel('.', args.data_url) student_model.train()
def main():
    """Fine-tune or evaluate a quantized BERT classifier on a GLUE task.

    Driven entirely by command-line flags:

    * builds the dev dataloader (and, unless ``--do_eval``, the train
      dataloader) via ``convert_examples_to_features``/``get_tensor_data``;
    * loads ``config/new_example_config.json`` as the quantization config and
      overrides every ``*layer_bits``/``*embed_bits`` entry with
      ``--weight_bit``, every ``*activation_bits`` entry with
      ``--activation_bit`` and ``quant_group_number`` with
      ``--quant_group_number``;
    * with ``--do_eval``: evaluates ``--model`` on the dev set and logs the
      metrics;
    * otherwise: trains ``QBertForSequenceClassification`` with AdamW and a
      linear warmup schedule, evaluates on the dev set after every optimizer
      step, appends each result to ``<output_dir>/eval_results.txt`` and saves
      weights/config/vocab to ``--output_dir`` whenever the task's dev metric
      improves (tasks without a tracked metric are saved every time).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
    parser.add_argument("--model", default=None, type=str, required=True,
                        help="The model dir.")
    parser.add_argument("--task_name", default=None, type=str, required=True,
                        help="The name of the task to train.")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the model predictions and checkpoints will be written.")
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help="Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help="The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    # A real boolean flag: without store_true the option demanded a value and
    # any value (even "False") was truthy, silently forcing eval-only mode.
    parser.add_argument("--do_eval", action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--do_lower_case", action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size", default=32, type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size", default=32, type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate", default=5e-5, type=float,
                        help="The initial learning rate for Adam.")
    # NOTE(review): parsed but unused; the optimizer below hard-codes 0.01.
    parser.add_argument('--weight_decay', '--wd', default=1e-4, type=float,
                        metavar='W', help='weight decay')
    parser.add_argument("--num_train_epochs", default=3.0, type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.06,
        type=float,
        help="Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda", action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument('--seed', type=int, default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument('--weight_bit', type=int, default=4,
                        help="Number of bits for weight.")
    parser.add_argument('--quant_group_number', type=int, default=1,
                        help="Number of quantization groups.")
    parser.add_argument('--activation_bit', type=int, default=8,
                        help="Number of bits for activation.")
    # added arguments
    parser.add_argument('--aug_train', type=str, default='none',
                        help="Whether to train with augmented data.")
    # NOTE(review): parsed but unused; evaluation runs after every optimizer step.
    parser.add_argument('--eval_step', type=int, default=50)
    parser.add_argument('--data_url', type=str, default="")
    parser.add_argument('--temperature', type=float, default=1.)
    parser.add_argument('--train_name', type=str, default="")
    parser.add_argument('--val_name', type=str, default="")
    args = parser.parse_args()

    # Configure logging before the first log call so the INFO line below is
    # not dropped when no handler has been installed yet.
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO)
    logger.info('The args: {}'.format(args))

    processors = {
        "cola": ColaProcessor,
        "mnli": MnliProcessor,
        "mnli-mm": MnliMismatchedProcessor,
        "mrpc": MrpcProcessor,
        "sst-2": Sst2Processor,
        "sts-b": StsbProcessor,
        "qqp": QqpProcessor,
        "qnli": QnliProcessor,
        "rte": RteProcessor,
        "wnli": WnliProcessor
    }
    # Every key of `processors` needs an entry here; "mnli-mm" was missing and
    # raised KeyError after passing the task-name check.
    output_modes = {
        "cola": "classification",
        "mnli": "classification",
        "mnli-mm": "classification",
        "mrpc": "classification",
        "sst-2": "classification",
        "sts-b": "regression",
        "qqp": "classification",
        "qnli": "classification",
        "rte": "classification",
        "wnli": "classification"
    }
    # Per-task fallbacks, used only for hyper-parameters passed as 0.
    default_params = {
        "cola": {"num_train_epochs": 10, "max_seq_length": 64,
                 "learning_rate": 2e-5, "train_batch_size": 32},
        "sst-2": {"num_train_epochs": 10, "max_seq_length": 64,
                  "learning_rate": 2e-5, "train_batch_size": 32},
        "mnli": {"num_train_epochs": 10, "max_seq_length": 128,
                 "learning_rate": 1e-5, "train_batch_size": 32},
        "mrpc": {"num_train_epochs": 10, "max_seq_length": 128,
                 "learning_rate": 1e-5, "train_batch_size": 32},
        "sts-b": {"num_train_epochs": 10, "max_seq_length": 128,
                  "learning_rate": 2e-5, "train_batch_size": 16},
        "qqp": {"num_train_epochs": 10, "max_seq_length": 128,
                "learning_rate": 1e-5, "train_batch_size": 32},
        "qnli": {"num_train_epochs": 10, "max_seq_length": 128,
                 "learning_rate": 1e-5, "train_batch_size": 32},
        "rte": {"num_train_epochs": 10, "max_seq_length": 128,
                "learning_rate": 2e-5, "train_batch_size": 16}
    }
    acc_tasks = ["mnli", "mrpc", "sst-2", "qqp", "qnli", "rte"]
    corr_tasks = ["sts-b"]
    mcc_tasks = ["cola"]

    # Prepare devices
    device = torch.device(
        "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    n_gpu = torch.cuda.device_count()
    logger.info("device: {} n_gpu: {}".format(device, n_gpu))

    # Prepare seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    # Reusing a non-empty output dir is allowed; only stale eval results from
    # a previous run are removed so result_to_file starts from a clean file.
    stale_eval_file = os.path.join(args.output_dir, "eval_results.txt")
    if os.path.exists(stale_eval_file):
        os.remove(stale_eval_file)
    os.makedirs(args.output_dir, exist_ok=True)

    # Prepare task settings
    task_name = args.task_name.lower()
    if task_name not in processors:
        raise ValueError("Task not found: %s" % task_name)

    # Fill hyper-parameters passed as 0 from the task defaults (the
    # max_seq_length fallback used to land on a misspelled attribute).
    task_defaults = default_params.get(task_name, {})
    for key in ("num_train_epochs", "learning_rate", "train_batch_size",
                "max_seq_length"):
        if not getattr(args, key) and key in task_defaults:
            setattr(args, key, task_defaults[key])

    processor = processors[task_name]()
    output_mode = output_modes[task_name]
    label_list = processor.get_labels()
    num_labels = len(label_list)

    tokenizer = BertTokenizer.from_pretrained(args.model,
                                              do_lower_case=args.do_lower_case)

    if not args.do_eval:
        if args.gradient_accumulation_steps < 1:
            raise ValueError(
                "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
                .format(args.gradient_accumulation_steps))
        if args.aug_train == 'none':
            train_examples = processor.get_train_examples(
                args.data_dir, args.train_name)
        else:
            train_examples = processor.get_aug_examples(
                args.data_dir, args.aug_train)
        # Per-forward batch size; the effective batch is restored by
        # accumulating gradients over gradient_accumulation_steps batches.
        args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps
        num_train_optimization_steps = int(
            len(train_examples) / args.train_batch_size /
            args.gradient_accumulation_steps) * args.num_train_epochs
        train_features = convert_examples_to_features(
            train_examples, label_list, args.max_seq_length, tokenizer,
            output_mode)
        train_data, _ = get_tensor_data(output_mode, train_features)
        train_dataloader = DataLoader(train_data,
                                      sampler=RandomSampler(train_data),
                                      batch_size=args.train_batch_size)

    eval_examples = processor.get_dev_examples(args.data_dir, args.val_name)
    eval_features = convert_examples_to_features(eval_examples, label_list,
                                                 args.max_seq_length,
                                                 tokenizer, output_mode)
    eval_data, eval_labels = get_tensor_data(output_mode, eval_features)
    eval_dataloader = DataLoader(eval_data,
                                 sampler=SequentialSampler(eval_data),
                                 batch_size=args.eval_batch_size)

    # Load the quantization config and override bit widths from the CLI.
    quant_config = BertConfig.from_json_file("config/new_example_config.json")
    config_dict = quant_config.__dict__
    if "quant_group_number" in config_dict:
        config_dict["quant_group_number"] = args.quant_group_number
    for item in config_dict:
        if "layer_bits" in item or "embed_bits" in item:
            # Nested mapping: every entry gets the weight bit width.
            for b_item in config_dict[item]:
                config_dict[item][b_item] = args.weight_bit
        elif "activation_bits" in item:
            config_dict[item] = args.activation_bit

    model = QBertForSequenceClassification.from_pretrained(
        args.model, num_labels=num_labels, quant_config=quant_config)
    model.to(device)

    if args.do_eval:
        logger.info("***** Running evaluation *****")
        logger.info(" Num examples = %d", len(eval_examples))
        logger.info(" Batch size = %d", args.eval_batch_size)
        model.eval()
        result = do_eval(model, task_name, eval_dataloader, device,
                         output_mode, eval_labels, num_labels)
        logger.info("***** Eval results *****")
        for key in sorted(result.keys()):
            logger.info(" %s = %s", key, str(result[key]))
        return

    logger.info("***** Running training *****")
    logger.info(" Num examples = %d", len(train_examples))
    logger.info(" Batch size = %d", args.train_batch_size)
    logger.info(" Num steps = %d", num_train_optimization_steps)
    if n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Prepare optimizer: no weight decay on biases and LayerNorm weights.
    param_optimizer = list(model.named_parameters())
    size = 0
    for n, p in param_optimizer:
        logger.info('n: {}'.format(n))
        size += p.nelement()
    logger.info('Total parameters: {}'.format(size))
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params': [p for n, p in param_optimizer
                   if not any(nd in n for nd in no_decay)],
        'weight_decay': 0.01
    }, {
        'params': [p for n, p in param_optimizer
                   if any(nd in n for nd in no_decay)],
        'weight_decay': 0.0
    }]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate,
                      eps=1e-8)
    scheduler = NewWarmupLinearSchedule(
        optimizer,
        warmup_steps=int(args.warmup_proportion * num_train_optimization_steps),
        t_total=num_train_optimization_steps)

    # Prepare loss functions
    loss_mse = MSELoss()

    def soft_cross_entropy(predicts, targets):
        """Mean over all elements of ``-targets * log_softmax(predicts)``.

        No softmax is applied to ``targets``; they are used as probabilities
        as-is. NOTE(review): this is called with ``label_ids``, which only
        works if ``get_tensor_data`` yields per-class label distributions for
        classification tasks — confirm.
        """
        student_likelihood = torch.nn.functional.log_softmax(predicts, dim=-1)
        targets_prob = targets
        return (-targets_prob * student_likelihood).mean()

    # Dev metric deciding whether a checkpoint is the best so far. Tasks with
    # no tracked metric (e.g. wnli) keep the latest checkpoint instead.
    if task_name in acc_tasks:
        metric_name = 'acc'
    elif task_name in corr_tasks:
        metric_name = 'corr'
    elif task_name in mcc_tasks:
        metric_name = 'mcc'
    else:
        metric_name = None

    # Train and evaluate
    global_step = 0
    best_dev_acc = float('-inf')
    output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
    for epoch_ in trange(int(args.num_train_epochs), desc="Epoch"):
        tr_loss = 0.
        tr_att_loss = 0.  # no attention loss in this script; stays 0
        tr_rep_loss = 0.  # no hidden-state loss in this script; stays 0
        tr_cls_loss = 0.
        model.train()
        nb_tr_examples, nb_tr_steps = 0, 0
        for step, batch in enumerate(
                tqdm(train_dataloader, desc="Iteration", ascii=True)):
            batch = tuple(t.to(device) for t in batch)
            input_ids, input_mask, segment_ids, label_ids, seq_lengths = batch
            # Skip the trailing partial batch.
            if input_ids.size()[0] != args.train_batch_size:
                continue

            cls_loss = 0.
            student_logits, student_atts, student_reps = model(
                input_ids, segment_ids, input_mask, is_student=True)
            if output_mode == "classification":
                cls_loss = soft_cross_entropy(student_logits, label_ids)
            elif output_mode == "regression":
                cls_loss = loss_mse(student_logits.view(-1), label_ids.view(-1))

            loss = cls_loss
            tr_cls_loss += cls_loss.item()
            if n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu.
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            loss.backward()

            tr_loss += loss.item()
            nb_tr_examples += label_ids.size(0)
            nb_tr_steps += 1

            # Still accumulating gradients: no update, no evaluation.
            if (step + 1) % args.gradient_accumulation_steps != 0:
                continue
            optimizer.step()
            scheduler.step()
            model.zero_grad()
            global_step += 1

            logger.info("***** Running evaluation *****")
            logger.info(" Epoch = {} iter {} step".format(epoch_, global_step))
            logger.info(" Num examples = %d", len(eval_examples))
            logger.info(" Batch size = %d", args.eval_batch_size)
            model.eval()

            # Running averages over the batches seen so far this epoch.
            loss = tr_loss / (step + 1)
            cls_loss = tr_cls_loss / (step + 1)
            att_loss = tr_att_loss / (step + 1)
            rep_loss = tr_rep_loss / (step + 1)

            result = do_eval(model, task_name, eval_dataloader, device,
                             output_mode, eval_labels, num_labels)
            result['global_step'] = global_step
            result['cls_loss'] = cls_loss
            result['att_loss'] = att_loss
            result['rep_loss'] = rep_loss
            result['loss'] = loss
            result_to_file(result, output_eval_file)

            # Only overwrite the checkpoint when the dev metric improves.
            if metric_name is None:
                save_model = True
            elif result[metric_name] > best_dev_acc:
                best_dev_acc = result[metric_name]
                save_model = True
            else:
                save_model = False

            if save_model:
                logger.info("***** Save model *****")
                model_to_save = model.module if hasattr(model, 'module') else model
                output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
                output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
                torch.save(model_to_save.state_dict(), output_model_file)
                model_to_save.config.to_json_file(output_config_file)
                tokenizer.save_vocabulary(args.output_dir)

            model.train()
def main():
    """Run TinyBERT-style task-specific distillation, or evaluate a student.

    * ``--do_eval``: evaluates ``--student_model`` on the dev set of
      ``--task_name`` and logs the metrics.
    * Otherwise, without ``--pred_distill`` (intermediate-layer distillation):
      trains the student with MSE losses between its attention maps / hidden
      states and a uniformly strided subset of the teacher's, saving the
      student at every evaluation point.
    * With ``--pred_distill`` (prediction-layer distillation): trains with a
      temperature-scaled soft cross-entropy against the teacher logits (MSE
      on the labels for regression tasks), evaluates on the dev set and saves
      the student whenever the task's dev metric improves; for MNLI it also
      evaluates the mismatched dev set into ``<output_dir>-MM``.

    Evaluation happens every ``eval_step * num_train_optimization_steps``
    optimizer steps (``--eval_step`` is a fraction of total training). The
    per-batch training loss is written to TensorBoard under ``./runs``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data_dir",
        default="data/MNLI",
        type=str,
        help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
    parser.add_argument("--teacher_model", default="pretrained/checkpoint-31280/",
                        type=str, help="The teacher model dir.")
    parser.add_argument("--student_model", default="pretrained/generalbert",
                        type=str, help="The student model dir.")
    parser.add_argument("--task_name", default="MNLI", type=str,
                        help="The name of the task to train.")
    parser.add_argument(
        "--output_dir",
        default="output",
        type=str,
        help="The output directory where the model predictions and checkpoints will be written.")
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help="The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_eval", action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--do_lower_case", action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size", default=384, type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size", default=128, type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate", default=5e-5, type=float,
                        help="The initial learning rate for Adam.")
    # NOTE(review): parsed but unused; the optimizer below hard-codes 0.01.
    parser.add_argument('--weight_decay', '--wd', default=1e-4, type=float,
                        metavar='W', help='weight decay')
    parser.add_argument("--num_train_epochs", default=5.0, type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help="Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda", action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument('--seed', type=int, default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.")
    # added arguments
    parser.add_argument('--aug_train', action='store_true')
    parser.add_argument('--eval_step', type=float, default=0.1)
    parser.add_argument('--pred_distill', action='store_true')
    parser.add_argument('--data_url', type=str, default="")
    parser.add_argument('--temperature', type=float, default=1.)
    args = parser.parse_args()

    # Configure logging before the first log call so the INFO line below is
    # not dropped when no handler has been installed yet.
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO)
    logger.info('The args: {}'.format(args))

    # intermediate distillation default parameters
    default_params = {
        "cola": {"num_train_epochs": 50, "max_seq_length": 64},
        "mnli": {"num_train_epochs": 5, "max_seq_length": 128},
        "mrpc": {"num_train_epochs": 20, "max_seq_length": 128},
        "sst-2": {"num_train_epochs": 10, "max_seq_length": 64},
        "sts-b": {"num_train_epochs": 20, "max_seq_length": 128},
        "qqp": {"num_train_epochs": 5, "max_seq_length": 128},
        "qnli": {"num_train_epochs": 10, "max_seq_length": 128},
        "rte": {"num_train_epochs": 20, "max_seq_length": 128}
    }
    acc_tasks = ["mnli", "mrpc", "sst-2", "qqp", "qnli", "rte"]
    corr_tasks = ["sts-b"]
    mcc_tasks = ["cola"]

    # Prepare devices
    device = torch.device(
        "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    n_gpu = torch.cuda.device_count()
    logger.info("device: {} n_gpu: {}".format(device, n_gpu))
    tb = SummaryWriter("./runs")

    # Prepare seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    # Prepare task settings
    if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
        raise ValueError(
            "Output directory ({}) already exists and is not empty.".format(
                args.output_dir))
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    task_name = args.task_name.lower()
    # Apply the per-task defaults. They used to be written to misspelled
    # attributes (max_seq_len / num_train_epoch) and never took effect.
    if task_name in default_params:
        args.max_seq_length = default_params[task_name]["max_seq_length"]
        # Keep the legacy attribute too, in case a helper reads that name.
        args.max_seq_len = args.max_seq_length
        if not args.pred_distill and not args.do_eval:
            args.num_train_epochs = default_params[task_name]["num_train_epochs"]

    if task_name not in processors:
        raise ValueError("Task not found: %s" % task_name)

    processor = processors[task_name]()
    output_mode = output_modes[task_name]
    label_list = processor.get_labels()
    num_labels = len(label_list)

    tokenizer = BertTokenizer.from_pretrained(args.student_model,
                                              do_lower_case=args.do_lower_case)

    if not args.do_eval:
        if args.gradient_accumulation_steps < 1:
            raise ValueError(
                "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
                .format(args.gradient_accumulation_steps))
        # Per-forward batch size; the effective batch is restored by
        # accumulating gradients over gradient_accumulation_steps batches.
        args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps
        train_data, _ = get_tensor_data(args, task_name, tokenizer, False,
                                        args.aug_train)
        train_dataloader = DataLoader(train_data,
                                      sampler=RandomSampler(train_data),
                                      batch_size=args.train_batch_size)
        num_train_optimization_steps = int(
            len(train_dataloader) /
            args.gradient_accumulation_steps) * args.num_train_epochs

    eval_data, eval_labels = get_tensor_data(args, task_name, tokenizer, True,
                                             False)
    eval_dataloader = DataLoader(eval_data,
                                 sampler=SequentialSampler(eval_data),
                                 batch_size=args.eval_batch_size)

    if not args.do_eval:
        teacher_model = TinyBertForSequenceClassification.from_pretrained(
            args.teacher_model, num_labels=num_labels)
        teacher_model.to(device)

    student_model = TinyBertForSequenceClassification.from_pretrained(
        args.student_model, num_labels=num_labels)
    student_model.to(device)

    if args.do_eval:
        logger.info("***** Running evaluation *****")
        logger.info(" Num examples = %d", len(eval_data))
        logger.info(" Batch size = %d", args.eval_batch_size)
        student_model.eval()
        result = do_eval(student_model, task_name, eval_dataloader, device,
                         output_mode, eval_labels, num_labels)
        logger.info("***** Eval results *****")
        for key in sorted(result.keys()):
            logger.info(" %s = %s", key, str(result[key]))
        tb.close()
        return

    logger.info("***** Running training *****")
    logger.info(" Num examples = %d", len(train_data))
    logger.info(" Batch size = %d", args.train_batch_size)
    logger.info(" Num steps = %d", num_train_optimization_steps)
    if n_gpu > 1:
        student_model = torch.nn.DataParallel(student_model)
        teacher_model = torch.nn.DataParallel(teacher_model)

    # Prepare optimizer: no weight decay on biases and LayerNorm parameters.
    param_optimizer = list(student_model.named_parameters())
    size = 0
    for n, p in param_optimizer:
        logger.info('n: {}'.format(n))
        size += p.nelement()
    logger.info('Total parameters: {}'.format(size))
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params': [p for n, p in param_optimizer
                   if not any(nd in n for nd in no_decay)],
        'weight_decay': 0.01
    }, {
        'params': [p for n, p in param_optimizer
                   if any(nd in n for nd in no_decay)],
        'weight_decay': 0.0
    }]
    # LR warmup/decay only for prediction-layer distillation.
    schedule = 'warmup_linear' if args.pred_distill else 'none'
    optimizer = BertAdam(optimizer_grouped_parameters,
                         schedule=schedule,
                         lr=args.learning_rate,
                         warmup=args.warmup_proportion,
                         t_total=num_train_optimization_steps)

    # Prepare loss functions
    loss_mse = MSELoss()

    def soft_cross_entropy(predicts, targets):
        """Mean over all elements of ``-softmax(targets) * log_softmax(predicts)``."""
        student_likelihood = torch.nn.functional.log_softmax(predicts, dim=-1)
        targets_prob = torch.nn.functional.softmax(targets, dim=-1)
        return (-targets_prob * student_likelihood).mean()

    # Evaluate every `eval_step` fraction of training. Clamp to 1 so a short
    # run cannot produce an interval of 0 (modulo by zero).
    eval_interval = max(1, int(args.eval_step * num_train_optimization_steps))

    # Dev metric used to pick the best checkpoint under --pred_distill.
    # Tasks with no tracked metric (e.g. wnli) keep the latest checkpoint.
    if task_name in acc_tasks:
        metric_name = 'acc'
    elif task_name in corr_tasks:
        metric_name = 'corr'
    elif task_name in mcc_tasks:
        metric_name = 'mcc'
    else:
        metric_name = None

    # MNLI-mismatched dev data, loaded lazily and kept separate so the
    # matched eval_data/eval_labels/eval_dataloader are never overwritten.
    mm_eval_data = mm_eval_labels = mm_eval_dataloader = None

    # Train and evaluate
    global_step = 0
    best_dev_acc = float('-inf')
    output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
    for epoch_ in trange(int(args.num_train_epochs), desc="Epoch"):
        tr_loss = 0.
        tr_att_loss = 0.
        tr_rep_loss = 0.
        tr_cls_loss = 0.
        student_model.train()
        nb_tr_examples, nb_tr_steps = 0, 0
        for step, batch in enumerate(
                tqdm(train_dataloader, desc="Iteration", ascii=True)):
            batch = tuple(t.to(device) for t in batch)
            input_ids, input_mask, segment_ids, label_ids, seq_lengths = batch
            # Skip the trailing partial batch.
            if input_ids.size()[0] != args.train_batch_size:
                continue

            att_loss = 0.
            rep_loss = 0.
            cls_loss = 0.

            student_logits, student_atts, student_reps = student_model(
                input_ids, segment_ids, input_mask, is_student=True)
            with torch.no_grad():
                teacher_logits, teacher_atts, teacher_reps = teacher_model(
                    input_ids, segment_ids, input_mask)

            if not args.pred_distill:
                teacher_layer_num = len(teacher_atts)
                student_layer_num = len(student_atts)
                if teacher_layer_num % student_layer_num != 0:
                    raise ValueError(
                        "teacher layers must be a multiple of student layers")
                layers_per_block = teacher_layer_num // student_layer_num
                # Student layer i mimics the last teacher layer of block i.
                new_teacher_atts = [
                    teacher_atts[i * layers_per_block + layers_per_block - 1]
                    for i in range(student_layer_num)
                ]
                for student_att, teacher_att in zip(student_atts,
                                                    new_teacher_atts):
                    # Zero out large negative (masked) attention scores so
                    # they do not dominate the MSE.
                    student_att = torch.where(student_att <= -1e2,
                                              torch.zeros_like(student_att),
                                              student_att)
                    teacher_att = torch.where(teacher_att <= -1e2,
                                              torch.zeros_like(teacher_att),
                                              teacher_att)
                    att_loss += loss_mse(student_att, teacher_att)

                # Hidden states include the embedding output, hence +1.
                new_teacher_reps = [
                    teacher_reps[i * layers_per_block]
                    for i in range(student_layer_num + 1)
                ]
                for student_rep, teacher_rep in zip(student_reps,
                                                    new_teacher_reps):
                    rep_loss += loss_mse(student_rep, teacher_rep)

                loss = rep_loss + att_loss
                tr_att_loss += att_loss.item()
                tr_rep_loss += rep_loss.item()
            else:
                if output_mode == "classification":
                    cls_loss = soft_cross_entropy(
                        student_logits / args.temperature,
                        teacher_logits / args.temperature)
                elif output_mode == "regression":
                    cls_loss = loss_mse(student_logits.view(-1),
                                        label_ids.view(-1))
                loss = cls_loss
                tr_cls_loss += cls_loss.item()

            if n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu.
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            loss.backward()
            tb.add_scalar("loss", loss.item(), global_step)

            tr_loss += loss.item()
            nb_tr_examples += label_ids.size(0)
            nb_tr_steps += 1

            # Still accumulating gradients: no update, no evaluation. Keeping
            # the eval check behind this gate avoids repeating the same
            # evaluation on every micro-batch of one optimizer step.
            if (step + 1) % args.gradient_accumulation_steps != 0:
                continue
            optimizer.step()
            optimizer.zero_grad()
            global_step += 1

            if (global_step + 1) % eval_interval != 0:
                continue

            logger.info("***** Running evaluation *****")
            logger.info(" Epoch = {} iter {} step".format(epoch_, global_step))
            logger.info(" Num examples = %d", len(eval_data))
            logger.info(" Batch size = %d", args.eval_batch_size)
            student_model.eval()

            # Running averages over the batches seen so far this epoch.
            loss = tr_loss / (step + 1)
            cls_loss = tr_cls_loss / (step + 1)
            att_loss = tr_att_loss / (step + 1)
            rep_loss = tr_rep_loss / (step + 1)

            result = {}
            if args.pred_distill:
                result = do_eval(student_model, task_name, eval_dataloader,
                                 device, output_mode, eval_labels, num_labels)
            result['global_step'] = global_step
            result['cls_loss'] = cls_loss
            result['att_loss'] = att_loss
            result['rep_loss'] = rep_loss
            result['loss'] = loss
            result_to_file(result, output_eval_file)

            if not args.pred_distill or metric_name is None:
                save_model = True
            elif result[metric_name] > best_dev_acc:
                best_dev_acc = result[metric_name]
                save_model = True
            else:
                save_model = False

            if save_model:
                logger.info("***** Save model *****")
                model_to_save = student_model.module if hasattr(
                    student_model, 'module') else student_model
                output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
                output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
                torch.save(model_to_save.state_dict(), output_model_file)
                model_to_save.config.to_json_file(output_config_file)
                tokenizer.save_vocabulary(args.output_dir)

                # Test mnli-mm
                if args.pred_distill and task_name == "mnli":
                    mm_output_dir = args.output_dir + '-MM'
                    if not os.path.exists(mm_output_dir):
                        os.makedirs(mm_output_dir)
                    if mm_eval_dataloader is None:
                        mm_eval_data, mm_eval_labels = get_tensor_data(
                            args, "mnli-mm", tokenizer, True, False)
                        mm_eval_dataloader = DataLoader(
                            mm_eval_data,
                            sampler=SequentialSampler(mm_eval_data),
                            batch_size=args.eval_batch_size)
                    logger.info("***** Running mm evaluation *****")
                    logger.info(" Num examples = %d", len(mm_eval_data))
                    logger.info(" Batch size = %d", args.eval_batch_size)
                    mm_result = do_eval(student_model, "mnli-mm",
                                        mm_eval_dataloader, device,
                                        output_mode, mm_eval_labels,
                                        num_labels)
                    mm_result['global_step'] = global_step
                    result_to_file(
                        mm_result,
                        os.path.join(mm_output_dir, "eval_results.txt"))

            student_model.train()

    tb.close()
def main():
    """Pre-train a SuperBERT super-network over a sampled architecture space.

    For every batch, ``--sample_times_per_batch`` sub-model configurations are
    sampled from the search space given by ``--layer_num_space``,
    ``--hidden_size_space``, ``--intermediate_size_space``,
    ``--qkv_size_space`` and ``--head_num_space``, and trained with either

    * MSE distillation against the (rep, att) outputs of a frozen teacher
      ``BertModel`` (default), or
    * the student's own masked-LM loss (``--mlm_loss``).

    With ``--further_train`` no sampling happens: the full architecture
    described by the student config is trained.  The process group is
    initialised from the RANK / WORLD_SIZE / MASTER_ADDR / MASTER_PORT
    environment variables and both models are wrapped in apex ``DDP``.
    After every epoch, weights, config, vocab and optimizer state are written
    to ``<cache_dir>/output/superbert/checkpoints/<run name>/epoch_<k>``; when
    ``oncloud`` is set, ``<cache_dir>/output`` is then copied to
    ``--s3_output_dir``.
    """
    parser = ArgumentParser()
    parser.add_argument(
        '--pregenerated_data',
        type=str,
        required=True,
        default='/nas/hebin/data/english-exp/books_wiki_tokens_ngrams')
    parser.add_argument('--s3_output_dir', type=str, default='huawei_yun')
    parser.add_argument('--student_model',
                        type=str,
                        default='8layer_bert',
                        required=True)
    parser.add_argument('--teacher_model', type=str, default='electra_base')
    parser.add_argument('--cache_dir', type=str, default='/cache', help='')
    parser.add_argument("--epochs",
                        type=int,
                        default=2,
                        help="Number of epochs to train for")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument("--train_batch_size",
                        default=16,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--learning_rate",
                        default=1e-4,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--max_seq_length", type=int, default=512)
    parser.add_argument("--do_lower_case", action="store_true")
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--scratch',
                        action='store_true',
                        help="Whether to train from scratch")
    parser.add_argument(
        "--reduce_memory",
        action="store_true",
        help=
        "Store training data as on-disc memmaps to massively reduce memory usage"
    )
    parser.add_argument('--debug',
                        action='store_true',
                        help="Whether to debug")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument(
        "--fp16_opt_level",
        type=str,
        default="O1",
        help=
        "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
        "See details at https://nvidia.github.io/apex/amp.html",
    )
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument("--already_trained_epoch", default=0, type=int)
    parser.add_argument(
        "--masked_lm_prob",
        type=float,
        default=0.0,
        help="Probability of masking each token for the LM task")
    parser.add_argument(
        "--max_predictions_per_seq",
        type=int,
        default=77,
        help="Maximum number of tokens to mask in each sequence")
    parser.add_argument("--adam_epsilon",
                        default=1e-8,
                        type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument("--logging_steps",
                        type=int,
                        default=500,
                        help="Log every X updates steps.")
    parser.add_argument("--warmup_steps",
                        default=10000,
                        type=int,
                        help="Linear warmup over warmup_steps.")
    parser.add_argument("--max_grad_norm",
                        default=1.0,
                        type=float,
                        help="Max gradient norm.")
    parser.add_argument("--num_workers",
                        type=int,
                        default=4,
                        help="num_workers.")
    parser.add_argument("--continue_index", type=int, default=0, help="")
    parser.add_argument("--threads",
                        type=int,
                        default=27,
                        help="Number of threads to preprocess input data")

    # Search space for sub_bart architecture
    parser.add_argument('--layer_num_space',
                        nargs='+',
                        type=int,
                        default=[1, 8])
    parser.add_argument('--hidden_size_space',
                        nargs='+',
                        type=int,
                        default=[128, 768])
    parser.add_argument('--qkv_size_space',
                        nargs='+',
                        type=int,
                        default=[180, 768])
    parser.add_argument('--intermediate_size_space',
                        nargs='+',
                        type=int,
                        default=[128, 3072])
    parser.add_argument('--head_num_space',
                        nargs='+',
                        type=int,
                        default=[1, 12])
    parser.add_argument('--sample_times_per_batch', type=int, default=1)
    parser.add_argument('--further_train', action='store_true')
    parser.add_argument('--mlm_loss', action='store_true')

    # Argument for Huawei yun
    parser.add_argument('--data_url', type=str, default='', help='s3 url')
    parser.add_argument("--train_url", type=str, default="", help="s3 url")

    args = parser.parse_args()

    assert (torch.cuda.is_available())
    device_count = torch.cuda.device_count()
    args.rank = int(os.getenv('RANK', '0'))
    args.world_size = int(os.getenv("WORLD_SIZE", '1'))

    # Call the init process
    # NOTE(review): no 'tcp://' scheme is prepended; stock PyTorch expects a
    # URL here — confirm the target environment accepts 'host:port'.
    init_method = ''
    master_ip = os.getenv('MASTER_ADDR', 'localhost')
    master_port = os.getenv('MASTER_PORT', '6000')
    init_method += master_ip + ':' + master_port

    # Manually set the device ids.
    # if device_count > 0:
    #     args.local_rank = args.rank % device_count
    torch.cuda.set_device(args.local_rank)
    device = torch.device("cuda", args.local_rank)
    print('device_id: %s' % args.local_rank)
    print('device_count: %s, rank: %s, world_size: %s' %
          (device_count, args.rank, args.world_size))
    print(init_method)

    torch.distributed.init_process_group(backend='nccl',
                                         world_size=args.world_size,
                                         rank=args.rank,
                                         init_method=init_method)

    LOCAL_DIR = args.cache_dir
    if oncloud:
        assert mox.file.exists(LOCAL_DIR)

    if args.local_rank == 0 and oncloud:
        logging.info(
            mox.file.list_directory(args.pregenerated_data, recursive=True))
        logging.info(
            mox.file.list_directory(args.student_model, recursive=True))

    local_save_dir = os.path.join(LOCAL_DIR, 'output', 'superbert',
                                  'checkpoints')
    local_tsbd_dir = os.path.join(LOCAL_DIR, 'output', 'superbert',
                                  'tensorboard')
    # Run name encodes the main hyper-parameters so runs do not collide.
    save_name = '_'.join([
        'superbert',
        'epoch',
        str(args.epochs),
        'lr',
        str(args.learning_rate),
        'bsz',
        str(args.train_batch_size),
        'grad_accu',
        str(args.gradient_accumulation_steps),
        str(args.max_seq_length),
        'gpu',
        str(args.world_size),
    ])
    bash_save_dir = os.path.join(local_save_dir, save_name)
    bash_tsbd_dir = os.path.join(local_tsbd_dir, save_name)
    if args.local_rank == 0:
        if not os.path.exists(bash_save_dir):
            os.makedirs(bash_save_dir)
            logger.info(bash_save_dir + ' created!')
        if not os.path.exists(bash_tsbd_dir):
            os.makedirs(bash_tsbd_dir)
            logger.info(bash_tsbd_dir + ' created!')

    local_data_dir_tmp = '/cache/data/tmp/'
    local_data_dir = local_data_dir_tmp + save_name

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    # Per-forward batch size; the effective batch is restored by accumulation.
    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    args.tokenizer = BertTokenizer.from_pretrained(
        args.student_model, do_lower_case=args.do_lower_case)
    args.vocab_list = list(args.tokenizer.vocab.keys())

    config = BertConfig.from_pretrained(
        os.path.join(args.student_model, CONFIG_NAME))
    logger.info("Model config {}".format(config))

    # further_train resumes from saved weights; otherwise start from scratch.
    if args.further_train:
        if args.mlm_loss:
            student_model = SuperBertForPreTraining.from_pretrained(
                args.student_model, config)
        else:
            student_model = SuperTinyBertForPreTraining.from_pretrained(
                args.student_model, config)
    else:
        if args.mlm_loss:
            student_model = SuperBertForPreTraining.from_scratch(
                args.student_model, config)
        else:
            student_model = SuperTinyBertForPreTraining.from_scratch(
                args.student_model, config)

    student_model.to(device)

    if not args.mlm_loss:
        teacher_model = BertModel.from_pretrained(args.teacher_model)
        teacher_model.to(device)

    # build arch space: evenly spaced candidate values between min and max
    min_hidden_size, max_hidden_size = args.hidden_size_space
    min_ffn_size, max_ffn_size = args.intermediate_size_space
    min_qkv_size, max_qkv_size = args.qkv_size_space
    min_head_num, max_head_num = args.head_num_space

    hidden_step = 4
    ffn_step = 4
    qkv_step = 12
    head_step = 1

    number_hidden_step = int((max_hidden_size - min_hidden_size) / hidden_step)
    number_ffn_step = int((max_ffn_size - min_ffn_size) / ffn_step)
    number_qkv_step = int((max_qkv_size - min_qkv_size) / qkv_step)
    number_head_step = int((max_head_num - min_head_num) / head_step)

    layer_numbers = list(
        range(args.layer_num_space[0], args.layer_num_space[1] + 1))
    hidden_sizes = [
        i * hidden_step + min_hidden_size
        for i in range(number_hidden_step + 1)
    ]
    ffn_sizes = [
        i * ffn_step + min_ffn_size for i in range(number_ffn_step + 1)
    ]
    qkv_sizes = [
        i * qkv_step + min_qkv_size for i in range(number_qkv_step + 1)
    ]
    head_numbers = [
        i * head_step + min_head_num for i in range(number_head_step + 1)
    ]

    ######
    if args.local_rank == 0:
        tb_writer = SummaryWriter(bash_tsbd_dir)

    global_step = 0
    step = 0
    tr_loss, tr_rep_loss, tr_att_loss = 0.0, 0.0, 0.0
    logging_loss, rep_logging_loss, att_logging_loss = 0.0, 0.0, 0.0
    end_time, start_time = 0, 0

    # With further_train the full architecture is trained instead of sampling.
    submodel_config = dict()
    if args.further_train:
        submodel_config['sample_layer_num'] = config.num_hidden_layers
        submodel_config['sample_hidden_size'] = config.hidden_size
        submodel_config[
            'sample_intermediate_sizes'] = config.num_hidden_layers * [
                config.intermediate_size
            ]
        submodel_config[
            'sample_num_attention_heads'] = config.num_hidden_layers * [
                config.num_attention_heads
            ]
        submodel_config['sample_qkv_sizes'] = config.num_hidden_layers * [
            config.qkv_size
        ]

    loss_mse = torch.nn.MSELoss()

    for epoch in range(args.epochs):
        # Skip epochs already done when resuming; no warmup after a resume.
        if epoch < args.continue_index:
            args.warmup_steps = 0
            continue

        args.local_data_dir = os.path.join(local_data_dir, str(epoch))
        if args.local_rank == 0:
            # exist_ok: a restarted job must not crash on a leftover directory.
            os.makedirs(args.local_data_dir, exist_ok=True)
        # Other local ranks wait for local rank 0 to create the directory;
        # sleep between polls instead of spinning a CPU core.
        while not os.path.exists(args.local_data_dir):
            time.sleep(1)
        epoch_dataset = load_doc_tokens_ngrams(args)

        if args.local_rank == 0 and oncloud:
            logging.info('Dataset in epoch %s', epoch)
            logging.info(
                mox.file.list_directory(args.local_data_dir, recursive=True))

        # NOTE(review): num_replicas=1 / rank=0 means every process iterates
        # its whole epoch_dataset — presumably data is already sharded per
        # process by load_doc_tokens_ngrams; confirm.
        train_sampler = DistributedSampler(epoch_dataset,
                                           num_replicas=1,
                                           rank=0)

        train_dataloader = DataLoader(epoch_dataset,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)

        step_in_each_epoch = len(
            train_dataloader) // args.gradient_accumulation_steps
        num_train_optimization_steps = step_in_each_epoch * args.epochs
        logging.info("***** Running training *****")
        logging.info(" Num examples = %d",
                     len(epoch_dataset) * args.world_size)
        logger.info(" Num Epochs = %d", args.epochs)
        logging.info(
            " Total train batch size (w. parallel, distributed & accumulation) = %d",
            args.train_batch_size * args.gradient_accumulation_steps *
            args.world_size)
        logger.info(" Gradient Accumulation steps = %d",
                    args.gradient_accumulation_steps)
        logging.info(" Num steps = %d", num_train_optimization_steps)

        # Optimizer, AMP and DDP wrapping happen once, on the first epoch run.
        if epoch == args.continue_index:
            # Prepare optimizer
            param_optimizer = list(student_model.named_parameters())
            no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
            optimizer_grouped_parameters = [{
                'params': [
                    p for n, p in param_optimizer
                    if not any(nd in n for nd in no_decay)
                ],
                'weight_decay':
                0.01
            }, {
                'params': [
                    p for n, p in param_optimizer
                    if any(nd in n for nd in no_decay)
                ],
                'weight_decay':
                0.0
            }]

            warm_up_ratio = args.warmup_steps / num_train_optimization_steps
            print('warm_up_ratio: {}'.format(warm_up_ratio))
            optimizer = BertAdam(optimizer_grouped_parameters,
                                 lr=args.learning_rate,
                                 e=args.adam_epsilon,
                                 schedule='warmup_linear',
                                 t_total=num_train_optimization_steps,
                                 warmup=warm_up_ratio)

            if args.fp16:
                try:
                    from apex import amp
                except ImportError:
                    raise ImportError(
                        "Please install apex from https://www.github.com/nvidia/apex"
                        " to use fp16 training.")

                student_model, optimizer = amp.initialize(
                    student_model,
                    optimizer,
                    opt_level=args.fp16_opt_level,
                    min_loss_scale=1)

            # apex DDP
            student_model = DDP(
                student_model,
                message_size=10000000,
                gradient_predivide_factor=torch.distributed.get_world_size(),
                delay_allreduce=True)

            if not args.mlm_loss:
                teacher_model = DDP(teacher_model,
                                    message_size=10000000,
                                    gradient_predivide_factor=torch.
                                    distributed.get_world_size(),
                                    delay_allreduce=True)
                teacher_model.eval()

            logger.info('apex data paralleled!')

        student_model.train()
        for batch in train_dataloader:
            step += 1
            batch = tuple(t.to(device) for t in batch)
            input_ids, input_masks, lm_label_ids = batch

            if not args.mlm_loss:
                # The teacher is never given to the optimizer, so run it
                # without building an autograd graph.  A bare `t.detach()`
                # would not help: it returns a new tensor, it does not
                # detach `t` in place.
                with torch.no_grad():
                    teacher_last_rep, teacher_last_att = teacher_model(
                        input_ids, input_masks)
                    # Zero very negative scores (<= -1e2) — presumably the
                    # additive attention mask — before the MSE comparison.
                    teacher_last_att = torch.where(
                        teacher_last_att <= -1e2,
                        torch.zeros_like(teacher_last_att).to(device),
                        teacher_last_att)

            for sample_idx in range(args.sample_times_per_batch):
                att_loss = 0.
                rep_loss = 0.
                # The seed depends only on the step and sample index, so every
                # rank samples the same sub-architecture.
                rand_seed = int(global_step * args.world_size +
                                sample_idx)  # + args.rank % args.world_size)

                if not args.mlm_loss:
                    if not args.further_train:
                        submodel_config = sample_arch_4_kd(
                            layer_numbers,
                            hidden_sizes,
                            ffn_sizes,
                            qkv_sizes,
                            reset_rand_seed=True,
                            rand_seed=rand_seed)

                    # knowledge distillation
                    student_last_rep, student_last_att = student_model(
                        input_ids, submodel_config, attention_mask=input_masks)
                    student_last_att = torch.where(
                        student_last_att <= -1e2,
                        torch.zeros_like(student_last_att).to(device),
                        student_last_att)

                    att_loss += loss_mse(student_last_att, teacher_last_att)
                    rep_loss += loss_mse(student_last_rep, teacher_last_rep)
                    loss = att_loss + rep_loss

                    if args.gradient_accumulation_steps > 1:
                        rep_loss = rep_loss / args.gradient_accumulation_steps
                        att_loss = att_loss / args.gradient_accumulation_steps
                        loss = loss / args.gradient_accumulation_steps

                    tr_rep_loss += rep_loss.item()
                    tr_att_loss += att_loss.item()
                else:
                    if not args.further_train:
                        submodel_config = sample_arch_4_mlm(
                            layer_numbers,
                            hidden_sizes,
                            ffn_sizes,
                            head_numbers,
                            reset_rand_seed=True,
                            rand_seed=rand_seed)
                    loss = student_model(input_ids,
                                         submodel_config,
                                         attention_mask=input_masks,
                                         masked_lm_labels=lm_label_ids)

                tr_loss += loss.item()

                if args.fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward(retain_graph=True)
                else:
                    loss.backward(retain_graph=True)

            # Gradients of every sampled sub-model of this batch are
            # accumulated before the update.
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(student_model.parameters(),
                                                   args.max_grad_norm)

                optimizer.step()
                optimizer.zero_grad()
                global_step += 1

                # Precedence kept from the original: (periodic step on
                # local_rank < 2) OR any of the first 100 steps.  The
                # logging_steps > 0 guard avoids a modulo by zero.
                periodic = (args.logging_steps > 0
                            and (step + 1) % (args.gradient_accumulation_steps *
                                              args.logging_steps) == 0)
                if (periodic and args.local_rank < 2) or global_step < 100:
                    end_time = time.time()

                    if not args.mlm_loss:
                        logger.info(
                            'Epoch: %s, global_step: %s/%s, lr: %s, loss is %s; '
                            'rep_loss is %s; att_loss is %s; (%.2f sec)' %
                            (epoch, global_step + 1, step_in_each_epoch,
                             optimizer.get_lr()[0],
                             loss.item() * args.gradient_accumulation_steps,
                             rep_loss.item() * args.gradient_accumulation_steps,
                             att_loss.item() * args.gradient_accumulation_steps,
                             end_time - start_time))
                    else:
                        logger.info(
                            'Epoch: %s, global_step: %s/%s, lr: %s, loss is %s; '
                            ' (%.2f sec)' %
                            (epoch, global_step + 1, step_in_each_epoch,
                             optimizer.get_lr()[0],
                             loss.item() * args.gradient_accumulation_steps,
                             end_time - start_time))
                    start_time = time.time()

                if args.logging_steps > 0 and global_step % args.logging_steps == 0 and args.local_rank == 0:
                    tb_writer.add_scalar("lr",
                                         optimizer.get_lr()[0], global_step)
                    tb_writer.add_scalar("loss", (tr_loss - logging_loss) /
                                         args.logging_steps, global_step)

                    if not args.mlm_loss:
                        tb_writer.add_scalar("rep_loss",
                                             (tr_rep_loss - rep_logging_loss) /
                                             args.logging_steps, global_step)
                        tb_writer.add_scalar("att_loss",
                                             (tr_att_loss - att_logging_loss) /
                                             args.logging_steps, global_step)
                        rep_logging_loss = tr_rep_loss
                        att_logging_loss = tr_att_loss

                    logging_loss = tr_loss

        # Save a trained model
        if args.rank == 0:
            saving_path = bash_save_dir
            saving_path = Path(os.path.join(saving_path, "epoch_" + str(epoch)))

            if saving_path.is_dir() and list(saving_path.iterdir()):
                logging.warning(
                    f"Output directory ({ saving_path }) already exists and is not empty!"
                )
            saving_path.mkdir(parents=True, exist_ok=True)

            logging.info("** ** * Saving fine-tuned model ** ** * ")
            # Only save the model it-self (unwrap DDP)
            model_to_save = student_model.module if hasattr(
                student_model, 'module') else student_model

            output_model_file = os.path.join(saving_path, WEIGHTS_NAME)
            output_config_file = os.path.join(saving_path, CONFIG_NAME)

            torch.save(model_to_save.state_dict(), output_model_file)
            model_to_save.config.to_json_file(output_config_file)
            args.tokenizer.save_vocabulary(saving_path)

            torch.save(optimizer.state_dict(),
                       os.path.join(saving_path, "optimizer.pt"))
            logger.info("Saving optimizer and scheduler states to %s",
                        saving_path)

        # debug: mirror local outputs to object storage when on the cloud
        if oncloud:
            local_output_dir = os.path.join(LOCAL_DIR, 'output')
            logger.info(
                mox.file.list_directory(local_output_dir, recursive=True))
            logger.info('s3_output_dir: ' + args.s3_output_dir)
            mox.file.copy_parallel(local_output_dir, args.s3_output_dir)

    if args.local_rank == 0:
        tb_writer.close()
def main():
    """Train or evaluate a quantized / binarized BERT student on SQuAD.

    Builds SQuAD v1.1 (or v2.0 with ``--version_2_with_negative``) train and
    dev dataloaders.  It then loads a full-precision teacher (training only)
    and a quantized student (``modeling_dynabert_quant``) or binary student
    (``modeling_dynabert_binary`` with ``--is_binarybert``).  With ``--split``,
    the student checkpoint is first split via ``tws_split`` into
    ``<output_dir>/binary_model_init``.

    ``--do_eval`` only evaluates the student.  Otherwise training runs
    according to ``--kd_type``:

    * ``joint_kd``: logit and rep/attn distillation together.
    * ``logit_kd``: logit distillation only.
    * ``two_stage``: rep/attn distillation at 2.5x lr, then reload the saved
      weights and run logit distillation.
    * ``no_kd``: plain training with the learner's default objective.

    Returns:
        0 on completion.
    """
    import shutil

    parser = argparse.ArgumentParser()
    parser.add_argument("--job_id", default='tmp', type=str, help='Jobid to save training logs')
    parser.add_argument("--data_dir", default=None, type=str, help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
    parser.add_argument("--teacher_model", default=None, type=str, help="The teacher model dir.")
    parser.add_argument("--student_model", default=None, type=str, help="The student model dir.")
    parser.add_argument("--output_dir", default='output', type=str, help="The output directory where the model predictions and checkpoints will be written.")

    # default params for SQuAD
    parser.add_argument('--version_2_with_negative', action='store_true')
    parser.add_argument("--max_seq_length", default=384, type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. \n"
                             "Sequences longer than this will be truncated, and sequences shorter \n"
                             "than this will be padded.")
    parser.add_argument("--doc_stride", default=128, type=int,
                        help="When splitting up a long document into chunks, how much stride to take between chunks.")
    parser.add_argument("--max_query_length", default=64, type=int,
                        help="The maximum number of tokens for the question. Questions longer than this will "
                             "be truncated to this length.")
    parser.add_argument("--n_best_size", default=20, type=int,
                        help="The total number of n-best predictions to generate in the nbest_predictions.json "
                             "output file.")
    parser.add_argument("--max_answer_length", default=30, type=int,
                        help="The maximum length of an answer that can be generated. This is needed because the start "
                             "and end predictions are not conditioned on one another.")
    parser.add_argument('--null_score_diff_threshold', type=float, default=0.0,
                        help="If null_score - best_non_null is greater than the threshold predict null.")
    parser.add_argument("--batch_size", default=32, type=int, help="Total batch size for training.")
    parser.add_argument("--learning_rate", default=2e-5, type=float, help="The initial learning rate for Adam.")
    parser.add_argument('--weight_decay', '--wd', default=1e-4, type=float, metavar='W', help='weight decay')
    parser.add_argument("--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform.")
    parser.add_argument('--eval_step', type=int, default=200)
    parser.add_argument("--warmup_proportion", default=0.1, type=float,
                        help="Proportion of training to perform linear learning rate warmup for. "
                             "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda", action='store_true', help="Whether not to use CUDA when available")
    parser.add_argument('--seed', type=int, default=42, help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument('--do_eval', default=0, type=int)

    # distillation params
    parser.add_argument('--aug_train', action='store_true', help="Whether using data augmentation or not")
    parser.add_argument('--kd_type', default='no_kd', choices=['no_kd', 'two_stage', 'logit_kd', 'joint_kd'],
                        help="choose one of the kd type")
    parser.add_argument('--distill_logit', action='store_true', help="Whether using distillation over logits or not")
    parser.add_argument('--distill_rep_attn', action='store_true', help="Whether using distillation over reps and attns or not")
    parser.add_argument('--temperature', type=float, default=1.)

    # quantization params
    parser.add_argument("--weight_bits", default=32, type=int, help="number of bits for weight")
    parser.add_argument("--weight_quant_method", default='twn', type=str, choices=['twn', 'bwn', 'uniform', 'laq'],
                        help="weight quantization methods, can be bwn, twn, laq")
    parser.add_argument("--input_bits", default=32, type=int, help="number of bits for activation")
    parser.add_argument("--input_quant_method", default='uniform', type=str, choices=['uniform', 'lsq'],
                        help="input (activation) quantization methods, can be uniform or lsq")
    # NOTE(review): store_true with default=True means this flag can never be False from the CLI.
    parser.add_argument('--learnable_scaling', action='store_true', default=True)
    parser.add_argument("--ACT2FN", default='gelu', type=str,
                        help='activation fn for ffn-mid. A8 uses uq + gelu; A4 uses lsq + relu.')

    # training config
    parser.add_argument('--sym_quant_ffn_attn', action='store_true',
                        help='whether use sym quant for attn score and ffn after act')  # default asym
    # NOTE(review): store_true with default=True means this flag can never be False from the CLI.
    parser.add_argument('--sym_quant_qkvo', action='store_true', default=True,
                        help='whether use sym quant for Q/K/V and others.')  # default sym

    # clipping / step-size config
    parser.add_argument('--clip_init_file', default='threshold_std.pkl', help='files to restore init clip values.')
    parser.add_argument('--clip_init_val', default=2.5, type=float, help='init value of clip_vals, default to (-2.5, +2.5).')
    parser.add_argument('--clip_lr', default=1e-4, type=float, help='Use a seperate lr for clip_vals / stepsize')
    parser.add_argument('--clip_wd', default=0.0, type=float, help='weight decay for clip_vals / stepsize')

    # layerwise quantization config (values are parsed as 0/1 integers)
    parser.add_argument('--embed_layerwise', default=False, type=lambda x: bool(int(x)))
    parser.add_argument('--weight_layerwise', default=True, type=lambda x: bool(int(x)))
    parser.add_argument('--input_layerwise', default=True, type=lambda x: bool(int(x)))

    # splitting
    parser.add_argument('--split', action='store_true',
                        help='whether to conduct tws spliting. NOTE this is only for training binarybert')
    parser.add_argument('--is_binarybert', action='store_true', help='whether to use binarybert modelling.')

    args = parser.parse_args()

    # Create the output dir before logging: the log path passed to
    # init_logging lives inside it.
    os.makedirs(args.output_dir, exist_ok=True)
    log_dir = os.path.join(args.output_dir, 'record_%s.log' % args.job_id)
    init_logging(log_dir)
    print_args(vars(args))

    # Prepare devices
    device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    n_gpu = torch.cuda.device_count()
    logging.info("device: {} n_gpu: {}".format(device, n_gpu))

    # Prepare seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    tokenizer = BertTokenizer.from_pretrained(args.teacher_model, do_lower_case=True)
    config = BertConfig.from_pretrained(args.teacher_model)
    config.num_labels = 2

    # The student shares the teacher's architecture plus quantization settings.
    student_config = copy.deepcopy(config)
    student_config.weight_bits = args.weight_bits
    student_config.input_bits = args.input_bits
    student_config.weight_quant_method = args.weight_quant_method
    student_config.input_quant_method = args.input_quant_method
    student_config.clip_init_val = args.clip_init_val
    student_config.learnable_scaling = args.learnable_scaling
    student_config.sym_quant_qkvo = args.sym_quant_qkvo
    student_config.sym_quant_ffn_attn = args.sym_quant_ffn_attn
    student_config.embed_layerwise = args.embed_layerwise
    student_config.weight_layerwise = args.weight_layerwise
    student_config.input_layerwise = args.input_layerwise
    student_config.hidden_act = args.ACT2FN

    logging.info("***** Training data *****")
    input_file = 'train-v2.0.json' if args.version_2_with_negative else 'train-v1.1.json'
    input_file = os.path.join(args.data_dir, input_file)
    cache_file = input_file + '.features.pkl'
    if os.path.exists(cache_file):
        logging.info(" loading from cache %s", cache_file)
        # NOTE: pickle can execute code; this cache is expected to be the one
        # written by the branch below — never point it at untrusted files.
        with open(cache_file, 'rb') as f:
            train_features = pickle.load(f)
    else:
        _, train_examples = read_squad_examples(input_file=input_file, is_training=True,
                                                version_2_with_negative=args.version_2_with_negative)
        train_features = convert_examples_to_features(
            examples=train_examples,
            tokenizer=tokenizer,
            max_seq_length=args.max_seq_length,
            doc_stride=args.doc_stride,
            max_query_length=args.max_query_length,
            is_training=True)
        # `with` guarantees the cache is flushed and closed.
        with open(cache_file, 'wb') as f:
            pickle.dump(train_features, f)

    # Per-forward batch size; the effective batch is restored by accumulation.
    args.batch_size = args.batch_size // args.gradient_accumulation_steps
    num_train_optimization_steps = int(
        len(train_features) / args.batch_size / args.gradient_accumulation_steps) * args.num_train_epochs
    logging.info(" Num examples = %d", len(train_features))
    logging.info(" Num total steps = %d", num_train_optimization_steps)

    all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
    all_start_positions = torch.tensor([f.start_position for f in train_features], dtype=torch.long)
    all_end_positions = torch.tensor([f.end_position for f in train_features], dtype=torch.long)
    train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
                               all_start_positions, all_end_positions)
    train_sampler = RandomSampler(train_data)
    train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.batch_size)

    logging.info("***** Evaluation data *****")
    input_file = 'dev-v2.0.json' if args.version_2_with_negative else 'dev-v1.1.json'
    args.dev_file = os.path.join(args.data_dir, input_file)
    dev_dataset, eval_examples = read_squad_examples(
        input_file=args.dev_file, is_training=False,
        version_2_with_negative=args.version_2_with_negative)
    eval_features = convert_examples_to_features(
        examples=eval_examples,
        tokenizer=tokenizer,
        max_seq_length=args.max_seq_length,
        doc_stride=args.doc_stride,
        max_query_length=args.max_query_length,
        is_training=False)
    logging.info(" Num examples = %d", len(eval_features))
    all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
    # Row index lets predictions be mapped back to eval_features.
    all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
    eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_example_index)
    eval_sampler = SequentialSampler(eval_data)
    eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.batch_size)

    # The teacher is only loaded for training.  It must still be bound in
    # eval mode, otherwise the KDLearner call below raises UnboundLocalError.
    # NOTE(review): assumes KDLearner tolerates teacher_model=None when only
    # eval() is used — confirm.
    teacher_model = None
    if not args.do_eval:
        from transformer.modeling_dynabert import BertForQuestionAnswering
        teacher_model = BertForQuestionAnswering.from_pretrained(args.teacher_model, config=config)
        teacher_model.to(device)
        if n_gpu > 1:
            teacher_model = torch.nn.DataParallel(teacher_model)

    if args.split:
        # rename the checkpoint to restore
        split_model_dir = os.path.join(args.output_dir, 'binary_model_init')
        os.makedirs(split_model_dir, exist_ok=True)
        # copy the json file, avoid over-writing the source model dir
        source_model_dir = os.path.join(args.student_model, CONFIG_NAME)
        target_model_dir = os.path.join(split_model_dir, CONFIG_NAME)
        # shutil instead of a shell `cp`: safe with spaces in paths and
        # raises on failure instead of silently continuing.
        shutil.copy(source_model_dir, target_model_dir)
        logging.info("copied %s -> %s", source_model_dir, target_model_dir)
        # create the split model ckpt
        source_model_dir = os.path.join(args.student_model, WEIGHTS_NAME)
        target_model_dir = os.path.join(split_model_dir, WEIGHTS_NAME)
        target_model_dir = tws_split(source_model_dir, target_model_dir)
        args.student_model = split_model_dir  # over-write student_model dir
        print("transformed binary model stored at: {}".format(target_model_dir))

    if args.is_binarybert:
        from transformer.modeling_dynabert_binary import BertForQuestionAnswering
        student_model = BertForQuestionAnswering.from_pretrained(args.student_model, config=student_config)
    else:
        from transformer.modeling_dynabert_quant import BertForQuestionAnswering
        student_model = BertForQuestionAnswering.from_pretrained(args.student_model, config=student_config)
    student_model.to(device)
    if n_gpu > 1:
        student_model = torch.nn.DataParallel(student_model)

    learner = KDLearner(args, device, student_model, teacher_model, num_train_optimization_steps)

    if args.do_eval:
        # evaluation only
        learner.eval(student_model, eval_dataloader, eval_features, eval_examples, dev_dataset)
        return 0

    # perform training
    if args.kd_type == 'joint_kd':
        learner.args.distill_logit = True
        learner.args.distill_rep_attn = True
        learner.build()
        learner.train(train_dataloader, eval_dataloader, eval_features, eval_examples, dev_dataset)

    elif args.kd_type == 'logit_kd':
        # only perform the logits kd
        learner.args.distill_logit = True
        learner.args.distill_rep_attn = False
        learner.build(lr=args.learning_rate)
        learner.train(train_dataloader, eval_dataloader, eval_features, eval_examples, dev_dataset)

    elif args.kd_type == 'two_stage':
        # stage 1: intermediate layer distillation
        learner.args.distill_logit = False
        learner.args.distill_rep_attn = True
        learner.build(lr=2.5 * args.learning_rate)
        learner.train(train_dataloader, eval_dataloader, eval_features, eval_examples, dev_dataset)

        # stage 2: prediction layer distillation, starting from stage-1 weights
        learner.student_model.load_state_dict(torch.load(os.path.join(learner.output_dir, WEIGHTS_NAME)))
        learner.args.distill_logit = True
        learner.args.distill_rep_attn = False
        learner.build(lr=args.learning_rate)  # prepare the optimizer again.
        learner.train(train_dataloader, eval_dataloader, eval_features, eval_examples, dev_dataset)

    else:
        assert args.kd_type == 'no_kd'
        # NO kd training, vanilla cross entropy with hard label
        learner.build(lr=args.learning_rate)  # prepare the optimizer again.
        learner.train(train_dataloader, eval_dataloader, eval_features, eval_examples, dev_dataset)

    del learner
    return 0
def main():
    """Evaluate and/or run prediction with a pruned BERT student on GLUE.

    Parses the command line and loads ``PrunBertForSequenceClassification``
    plus its tokenizer from ``--student_model``. Then:

    * with ``--do_eval``: evaluates the dev split and writes the metrics to
      ``<output_dir>/eval_results.txt`` via ``result_to_file``;
    * with ``--do_predict``: runs ``do_predict`` on the test split and passes
      the predictions to ``write_predictions``.

    For MNLI, each step is repeated on the mismatched split (``mnli-mm``).
    Mismatched dev metrics go to ``eval_results-mm.txt``.

    Raises:
        ValueError: if the task is unknown, or ``--output_dir`` already
            exists and is not empty.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        help=
        "The input data dir. Should contain the .tsv files (or other data files) for the task."
    )
    parser.add_argument("--student_model",
                        default=None,
                        type=str,
                        help="The student model dir.")
    parser.add_argument("--task_name",
                        default=None,
                        type=str,
                        required=True,
                        help="The name of the task to train.")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help=
        "Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--do_predict",
                        action='store_true',
                        help="Whether to run eval on the test set.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--eval_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--data_url', type=str, default="")
    parser.add_argument('--temperature', type=float, default=1.)
    args = parser.parse_args()

    # Configure logging before the first logger.info(). Until basicConfig()
    # runs, the root logger stays at WARNING and INFO records are dropped.
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO)
    logger.info('The args: {}'.format(args))

    processors = {
        "cola": ColaProcessor,
        "mnli": MnliProcessor,
        "mnli-mm": MnliMismatchedProcessor,
        "mrpc": MrpcProcessor,
        "sst-2": Sst2Processor,
        "sts-b": StsbProcessor,
        "qqp": QqpProcessor,
        "qnli": QnliProcessor,
        "rte": RteProcessor,
        "wnli": WnliProcessor
    }

    output_modes = {
        "cola": "classification",
        "mnli": "classification",
        "mrpc": "classification",
        "sst-2": "classification",
        "sts-b": "regression",
        "qqp": "classification",
        "qnli": "classification",
        "rte": "classification",
        "wnli": "classification"
    }

    # intermediate distillation default parameters
    default_params = {
        "cola": {"num_train_epochs": 50, "max_seq_length": 64, 'train_batch_size': 32},
        "mnli": {"num_train_epochs": 5, "max_seq_length": 128, 'train_batch_size': 64},
        "mrpc": {"num_train_epochs": 20, "max_seq_length": 128, 'train_batch_size': 32},
        "sst-2": {"num_train_epochs": 10, "max_seq_length": 64, 'train_batch_size': 32},
        "sts-b": {"num_train_epochs": 20, "max_seq_length": 128, 'train_batch_size': 32},
        "qqp": {"num_train_epochs": 5, "max_seq_length": 128, 'train_batch_size': 64},
        "qnli": {"num_train_epochs": 10, "max_seq_length": 128, 'train_batch_size': 64},
        "rte": {"num_train_epochs": 20, "max_seq_length": 128, 'train_batch_size': 32}
    }

    # Prepare devices
    device = torch.device(
        "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    n_gpu = torch.cuda.device_count()
    logger.info("device: {} n_gpu: {}".format(device, n_gpu))

    # Prepare seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    # Validate the task before touching the filesystem, so a typo in
    # --task_name does not leave an empty output directory behind.
    task_name = args.task_name.lower()
    if task_name not in processors:
        raise ValueError("Task not found: %s" % task_name)

    if args.output_dir is None:
        parser.error("--output_dir is required")
    if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
        raise ValueError(
            "Output directory ({}) already exists and is not empty.".format(
                args.output_dir))
    os.makedirs(args.output_dir, exist_ok=True)

    # For known tasks the per-task sequence length always takes precedence
    # over --max_seq_length.
    if task_name in default_params:
        args.max_seq_length = default_params[task_name]["max_seq_length"]

    processor = processors[task_name]()
    output_mode = output_modes[task_name]
    label_list = processor.get_labels()
    num_labels = len(label_list)

    tokenizer = BertTokenizer.from_pretrained(args.student_model,
                                              do_lower_case=args.do_lower_case)
    student_model = PrunBertForSequenceClassification.from_pretrained(
        args.student_model, num_labels=num_labels)
    student_model.to(device)
    student_model.eval()

    def evaluate_split(eval_task, eval_processor, banner):
        """Evaluate the student on the dev split read by ``eval_processor``."""
        eval_dataloader, num_eval_examples, eval_labels = build_dataloader(
            'dev', args, eval_processor, label_list, tokenizer, output_mode)
        logger.info(banner)
        logger.info(" Num examples = %d", num_eval_examples)
        logger.info(" Batch size = %d", args.eval_batch_size)
        return do_eval(student_model, eval_task, eval_dataloader, device,
                       output_mode, eval_labels, num_labels)

    def predict_split(pred_task, banner):
        """Predict on the test split of ``pred_task`` and write the results."""
        test_processor = processors[pred_task]()
        test_dataloader, num_test_examples, _ = build_dataloader(
            'test', args, test_processor, label_list, tokenizer, output_mode)
        logger.info(banner)
        logger.info(" Num examples = %d", num_test_examples)
        logger.info(" Batch size = %d", args.eval_batch_size)
        predictions = do_predict(student_model, pred_task, test_dataloader,
                                 device, output_mode, num_labels)
        write_predictions(predictions, args, pred_task, output_mode,
                          label_list)

    if args.do_eval:
        result = evaluate_split(task_name, processor,
                                "***** Running evaluation *****")
        logger.info("***** Eval results *****")
        result_to_file(result,
                       os.path.join(args.output_dir, "eval_results.txt"))
        if task_name == "mnli":
            # MNLI also has a mismatched dev set; the label set is shared.
            result = evaluate_split("mnli-mm", processors["mnli-mm"](),
                                    "***** Running mm evaluation *****")
            result_to_file(
                result, os.path.join(args.output_dir, "eval_results-mm.txt"))

    if args.do_predict:
        predict_split(task_name, "***** Running prediction *****")
        if task_name == "mnli":
            predict_split("mnli-mm", "***** Running mm prediction *****")
def main():
    """Tokenize a raw text corpus and pack it into fixed-length JSON shards.

    The corpus (``--train_corpus``) is read line by line, and blank lines
    separate documents. Each non-blank line is tokenized with the BERT
    tokenizer and kept as one segment.

    Within a document, consecutive segments are concatenated greedily until
    the next one would push the chunk past ``--max_seq_len`` tokens. Each
    packed chunk is written as one ``{"tokens": [...]}`` JSON line. Lines go
    round-robin into ``--num_output_files`` files named
    ``train_doc_tokens_ngrams_<i>.json`` in ``--output_dir``. A single segment
    longer than ``--max_seq_len`` is emitted on its own, untruncated.

    When ``oncloud`` is set, inputs are first copied from ``--data_url`` into
    the local cache, and the outputs are copied to ``--train_url`` at the end.
    """
    import contextlib
    import sys

    parser = ArgumentParser()
    parser.add_argument('--train_corpus', type=Path, required=True)
    parser.add_argument("--output_dir", type=Path, required=True)
    parser.add_argument("--bert_model", type=str, required=True)
    parser.add_argument("--do_lower_case", action="store_true")
    parser.add_argument("--max_seq_len", type=int, default=128)
    parser.add_argument(
        "--reduce_memory",
        action="store_true",
        help=
        "Reduce memory usage for large datasets by keeping data on disc rather than in memory"
    )
    parser.add_argument("--num_workers",
                        type=int,
                        default=1,
                        help="The number of workers to use to write the files")
    parser.add_argument("--num_output_files",
                        type=int,
                        default=28,
                        help="Number of JSON shard files to write")
    # Extra arguments used when running on Huawei Cloud.
    parser.add_argument("--data_url", type=str, default="", help="s3 url")
    parser.add_argument("--train_url", type=str, default="", help="s3 url")
    parser.add_argument("--init_method", default='', type=str)
    args = parser.parse_args()
    if args.num_output_files < 1:
        parser.error("--num_output_files must be >= 1")

    # Huawei Cloud: stage the input data on the local disk first.
    if oncloud:
        os.environ['DLS_LOCAL_CACHE_PATH'] = "/cache"
        local_data_dir = os.environ['DLS_LOCAL_CACHE_PATH']
        if not mox.file.exists(local_data_dir):
            raise RuntimeError(
                "local cache dir does not exist: " + local_data_dir)
        logging.info("local disk: " + local_data_dir)
        logging.info("copy data from s3 to local")
        logging.info(mox.file.list_directory(args.data_url, recursive=True))
        mox.file.copy_parallel(args.data_url, local_data_dir)
        logging.info("copy finish...........")
        args.train_corpus = Path(
            os.path.join(local_data_dir, args.train_corpus))
        args.bert_model = os.path.join(local_data_dir, args.bert_model)
        args.train_url = os.path.join(args.train_url, args.output_dir)
        args.output_dir = Path(os.path.join(local_data_dir, args.output_dir))

    if args.num_workers > 1 and args.reduce_memory:
        raise ValueError("Cannot use multiple workers while reducing memory")

    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)
    doc_num = 0
    with DocumentDatabase(reduce_memory=args.reduce_memory) as docs:
        with args.train_corpus.open() as f:
            doc = []
            for line in tqdm(f, desc="Loading Dataset", unit=" lines"):
                line = line.strip()
                if line:
                    doc.append(tokenizer.tokenize(line))
                    continue
                # A blank line closes the current document. Runs of blank
                # lines would otherwise add empty documents.
                if doc:
                    docs.add_document(doc)
                    doc = []
                    doc_num += 1
                    if doc_num % 100 == 0:
                        logger.info('loaded {} docs!'.format(doc_num))
            if doc:
                # If the last doc didn't end on a newline, make sure it still gets added
                docs.add_document(doc)
        if len(docs) <= 1:
            sys.exit(
                "ERROR: No document breaks were found in the input file! These are necessary to allow the script to "
                "ensure that random NextSentences are not sampled from the same document. Please add blank lines to "
                "indicate breaks between documents in your input file. If your dataset does not contain multiple "
                "documents, blank lines can be inserted at any natural boundary, such as the ends of chapters, "
                "sections or paragraphs.")

        args.output_dir.mkdir(exist_ok=True)
        num_files = args.num_output_files
        cnt = 0
        # ExitStack closes every shard even if writing fails part-way.
        with contextlib.ExitStack() as stack:
            fouts = [
                stack.enter_context(
                    open(
                        os.path.join(
                            str(args.output_dir),
                            'train_doc_tokens_ngrams_{}.json'.format(i)),
                        'w')) for i in range(num_files)
            ]

            def write_instance(tokens):
                """Write one packed instance to the next shard (round-robin)."""
                nonlocal cnt
                instance = {"tokens": tokens}
                fouts[cnt % num_files].write(json.dumps(instance) + '\n')
                cnt += 1
                if cnt % 100000 == 0:
                    logger.info('loaded {} examples!'.format(cnt))
                if cnt <= 10:
                    logger.info('instance: {}'.format(instance))

            for doc_idx in trange(len(docs), desc="Document"):
                tokens = []
                for segment in docs[doc_idx]:
                    # Flush only a non-empty chunk. Otherwise an oversized
                    # first segment would produce an empty instance.
                    if tokens and len(tokens) + len(segment) > args.max_seq_len:
                        write_instance(tokens)
                        tokens = []
                    tokens += segment
                if tokens:
                    write_instance(tokens)

    # Huawei Cloud: upload the generated shards.
    if oncloud:
        logging.info(
            mox.file.list_directory(str(args.output_dir), recursive=True))
        mox.file.copy_parallel(str(args.output_dir), args.train_url)
def main():
    """Quantization-aware training or evaluation of a BERT student on GLUE.

    Parses the command line and builds the student from ``--student_model``.
    The student is ``QuantBertForSequenceClassification``, or
    ``BertForSequenceClassification_binary`` with ``--is_binarybert``; its
    config is the teacher's config plus the quantization options below.
    With ``--do_eval`` it only evaluates. Otherwise it trains with
    ``KDLearner`` against a teacher loaded from ``--teacher_model``.

    ``--kd_type`` selects the recipe:

    * ``joint_kd``: logit and representation/attention distillation together;
    * ``logit_kd``: logit distillation only;
    * ``two_stage``: representation/attention distillation at 2.5x the
      learning rate, then logit distillation from the checkpoint saved in
      ``learner.output_dir``;
    * ``no_kd``: vanilla cross entropy with hard labels.

    Returns:
        0 on completion.

    Raises:
        ValueError: for an unknown task, ``--gradient_accumulation_steps < 1``,
            or a missing ``--num_train_epochs`` when training.
    """
    import shutil

    parser = argparse.ArgumentParser()
    parser.add_argument("--job_id",
                        default='tmp',
                        type=str,
                        help='jobid to save training logs')
    parser.add_argument("--data_dir",
                        default=None,
                        type=str,
                        help="The root dir of glue data")
    parser.add_argument("--teacher_model",
                        default='',
                        type=str,
                        help="The teacher model dir.")
    parser.add_argument("--student_model",
                        default='',
                        type=str,
                        help="The student model dir.")
    parser.add_argument("--task_name",
                        default=None,
                        type=str,
                        help="The name of the glue task to train.")
    parser.add_argument(
        "--output_dir",
        default='output',
        type=str,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )
    parser.add_argument(
        "--max_seq_length",
        default=None,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. Sequences longer than this will be truncated, and sequences shorter than this will be padded."
    )
    parser.add_argument("--batch_size",
                        default=None,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--learning_rate",
                        default=2e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument('--weight_decay',
                        '--wd',
                        default=0.01,
                        type=float,
                        metavar='W',
                        help='weight decay')
    parser.add_argument("--num_train_epochs",
                        default=None,
                        type=int,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument("--do_eval", action='store_true')
    parser.add_argument('--eval_step', type=int, default=100)

    # distillation params
    parser.add_argument('--aug_train',
                        action='store_true',
                        help="Whether using data augmentation or not")
    parser.add_argument('--kd_type',
                        default='no_kd',
                        choices=['no_kd', 'two_stage', 'logit_kd', 'joint_kd'],
                        help="choose one of the kd type")
    parser.add_argument('--distill_logit',
                        action='store_true',
                        help="Whether using distillation over logits or not")
    parser.add_argument(
        '--distill_rep_attn',
        action='store_true',
        help="Whether using distillation over reps and attns or not")
    parser.add_argument('--temperature', type=float, default=1.)

    # quantization params
    parser.add_argument("--weight_bits",
                        default=32,
                        type=int,
                        help="number of bits for weight")
    parser.add_argument(
        "--weight_quant_method",
        default='twn',
        type=str,
        choices=['twn', 'bwn', 'uniform', 'laq'],
        help="weight quantization methods, can be bwn, twn, laq")
    parser.add_argument("--input_bits",
                        default=32,
                        type=int,
                        help="number of bits for activation")
    parser.add_argument(
        "--input_quant_method",
        default='uniform',
        type=str,
        choices=['uniform', 'lsq'],
        help="activation quantization methods, can be uniform (default) or lsq")
    # NOTE: store_true combined with default=True makes this flag always True.
    parser.add_argument('--learnable_scaling',
                        action='store_true',
                        default=True)
    parser.add_argument(
        "--ACT2FN",
        default='gelu',
        type=str,
        help='activation fn for ffn-mid. A8 uses uq + gelu; A4 uses lsq + relu.'
    )

    # training config
    parser.add_argument(
        '--sym_quant_ffn_attn',
        action='store_true',
        help='whether use sym quant for attn score and ffn after act'
    )  # default asym
    # NOTE: store_true combined with default=True makes this flag always True.
    parser.add_argument(
        '--sym_quant_qkvo',
        action='store_true',
        default=True,
        help='whether use asym quant for Q/K/V and others.')  # default sym

    # hypers clipping threshold
    parser.add_argument('--clip_init_file',
                        default='threshold_std.pkl',
                        help='files to restore init clip values.')
    parser.add_argument(
        '--clip_init_val',
        default=2.5,
        type=float,
        help='init value of clip_vals, default to (-2.5, +2.5).')
    parser.add_argument('--clip_lr',
                        default=1e-4,
                        type=float,
                        help='Use a seperate lr for clip_vals / stepsize')
    parser.add_argument('--clip_wd',
                        default=0.0,
                        type=float,
                        help='weight decay for clip_vals / stepsize')

    # layerwise quantization config (values parsed from "0"/"1")
    parser.add_argument('--embed_layerwise',
                        default=False,
                        type=lambda x: bool(int(x)))
    parser.add_argument('--weight_layerwise',
                        default=True,
                        type=lambda x: bool(int(x)))
    parser.add_argument('--input_layerwise',
                        default=True,
                        type=lambda x: bool(int(x)))

    ### spliting
    parser.add_argument(
        '--split',
        action='store_true',
        help=
        'whether to conduct tws spliting. NOTE this is only for training binarybert'
    )
    parser.add_argument('--is_binarybert',
                        action='store_true',
                        help='whether to use binarybert modelling.')

    args = parser.parse_args()
    args.do_lower_case = True

    if args.task_name is None:
        parser.error("--task_name is required")
    task_name = args.task_name.lower()
    if task_name not in processors:
        raise ValueError("Task not found: %s" % task_name)

    # Create the output directory before the log file is placed inside it.
    os.makedirs(args.output_dir, exist_ok=True)
    log_dir = os.path.join(args.output_dir, 'record_%s.log' % args.job_id)
    init_logging(log_dir)

    # Prepare devices
    device = torch.device(
        "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    n_gpu = torch.cuda.device_count()
    logging.info("device: {} n_gpu: {}".format(device, n_gpu))

    # Prepare seed
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    # restore the default setting if they are None
    if args.batch_size is None and task_name in default_params:
        # Scale the default by the number of visible GPUs. A CPU-only run
        # (n_gpu == 0) counts as one device instead of giving batch size 0.
        args.batch_size = int(default_params[task_name]["batch_size"] *
                              max(n_gpu, 1))
    if args.max_seq_length is None and task_name in default_params:
        args.max_seq_length = default_params[task_name]["max_seq_length"]
    if args.batch_size is None or args.max_seq_length is None:
        parser.error(
            "no default --batch_size/--max_seq_length for task %s; set them explicitly"
            % task_name)

    print_args(vars(args))

    processor = processors[task_name]()
    output_mode = output_modes[task_name]
    label_list = processor.get_labels()
    num_labels = len(label_list)

    tokenizer = BertTokenizer.from_pretrained(args.teacher_model,
                                              do_lower_case=args.do_lower_case)
    config = BertConfig.from_pretrained(args.teacher_model)
    config.num_labels = num_labels

    # The student shares the teacher's architecture config, plus quantization.
    student_config = copy.deepcopy(config)
    student_config.weight_bits = args.weight_bits
    student_config.input_bits = args.input_bits
    student_config.weight_quant_method = args.weight_quant_method
    student_config.input_quant_method = args.input_quant_method
    student_config.clip_init_val = args.clip_init_val
    student_config.learnable_scaling = args.learnable_scaling
    student_config.sym_quant_qkvo = args.sym_quant_qkvo
    student_config.sym_quant_ffn_attn = args.sym_quant_ffn_attn
    student_config.embed_layerwise = args.embed_layerwise
    student_config.weight_layerwise = args.weight_layerwise
    student_config.input_layerwise = args.input_layerwise
    student_config.hidden_act = args.ACT2FN

    num_train_optimization_steps = 0
    if not args.do_eval:
        if args.gradient_accumulation_steps < 1:
            raise ValueError(
                "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
                .format(args.gradient_accumulation_steps))
        if args.num_train_epochs is None:
            raise ValueError("--num_train_epochs must be set for training")
        if args.aug_train:
            train_examples = processor.get_aug_examples(args.data_dir)
        else:
            train_examples = processor.get_train_examples(args.data_dir)

        # Per-step batch size; gradients are accumulated back to the total.
        args.batch_size = args.batch_size // args.gradient_accumulation_steps

        train_features = convert_examples_to_features(train_examples,
                                                      label_list,
                                                      args.max_seq_length,
                                                      tokenizer, output_mode)
        train_data, _ = get_tensor_data(output_mode, train_features)
        train_sampler = RandomSampler(train_data)
        train_dataloader = DataLoader(train_data,
                                      sampler=train_sampler,
                                      batch_size=args.batch_size)
        num_train_optimization_steps = int(
            len(train_features) / args.batch_size /
            args.gradient_accumulation_steps) * args.num_train_epochs

    eval_examples = processor.get_dev_examples(args.data_dir)
    eval_features = convert_examples_to_features(eval_examples, label_list,
                                                 args.max_seq_length,
                                                 tokenizer, output_mode)
    eval_data, eval_labels = get_tensor_data(output_mode, eval_features)
    eval_sampler = SequentialSampler(eval_data)
    eval_dataloader = DataLoader(eval_data,
                                 sampler=eval_sampler,
                                 batch_size=args.batch_size)

    if task_name == "mnli":
        # MNLI additionally gets evaluated on its mismatched dev set.
        processor = processors["mnli-mm"]()
        os.makedirs(args.output_dir + '-MM', exist_ok=True)

        mm_eval_examples = processor.get_dev_examples(args.data_dir)
        mm_eval_features = convert_examples_to_features(
            mm_eval_examples, label_list, args.max_seq_length, tokenizer,
            output_mode)
        mm_eval_data, mm_eval_labels = get_tensor_data(output_mode,
                                                       mm_eval_features)
        logging.info("***** Running mm evaluation *****")
        logging.info(" Num examples = %d", len(mm_eval_examples))

        mm_eval_sampler = SequentialSampler(mm_eval_data)
        mm_eval_dataloader = DataLoader(mm_eval_data,
                                        sampler=mm_eval_sampler,
                                        batch_size=args.batch_size)
    else:
        mm_eval_labels = None
        mm_eval_dataloader = None

    if not args.do_eval:
        # need the teacher model for training
        teacher_model = BertForSequenceClassification.from_pretrained(
            args.teacher_model, config=config)
        teacher_model.to(device)
        if n_gpu > 1:
            teacher_model = torch.nn.DataParallel(teacher_model)
    else:
        teacher_model = None

    if args.split:
        # Write the TWS-split checkpoint to a separate directory so the
        # original student checkpoint is not overwritten.
        split_model_dir = os.path.join(args.output_dir, 'binary_model_init')
        os.makedirs(split_model_dir, exist_ok=True)
        # Copy the config with shutil rather than a `cp` shell command, so a
        # failure raises and paths with spaces are handled.
        source_model_dir = os.path.join(args.student_model, CONFIG_NAME)
        target_model_dir = os.path.join(split_model_dir, CONFIG_NAME)
        shutil.copyfile(source_model_dir, target_model_dir)
        print("'{}' -> '{}'".format(source_model_dir, target_model_dir))
        # create the split model ckpt
        source_model_dir = os.path.join(args.student_model, WEIGHTS_NAME)
        target_model_dir = os.path.join(split_model_dir, WEIGHTS_NAME)
        target_model_dir = tws_split(source_model_dir, target_model_dir)
        args.student_model = split_model_dir  # over-write student_model dir
        print(
            "transformed binary model stored at: {}".format(target_model_dir))

    if args.is_binarybert:
        student_model = BertForSequenceClassification_binary.from_pretrained(
            args.student_model, config=student_config)
    else:
        student_model = QuantBertForSequenceClassification.from_pretrained(
            args.student_model, config=student_config)
    student_model.to(device)
    if n_gpu > 1:
        student_model = torch.nn.DataParallel(student_model)

    learner = KDLearner(args, device, student_model, teacher_model,
                        num_train_optimization_steps)

    if args.do_eval:
        """ evaluation """
        learner.evaluate(task_name,
                         eval_dataloader,
                         output_mode,
                         eval_labels,
                         num_labels,
                         eval_examples,
                         mm_eval_dataloader=mm_eval_dataloader,
                         mm_eval_labels=mm_eval_labels)
        return 0

    def run_training():
        """Run one ``KDLearner.train`` pass over the prepared GLUE data."""
        learner.train(train_examples,
                      task_name,
                      output_mode,
                      eval_labels,
                      num_labels,
                      train_dataloader,
                      eval_dataloader,
                      eval_examples,
                      tokenizer,
                      mm_eval_dataloader=mm_eval_dataloader,
                      mm_eval_labels=mm_eval_labels)

    """ perform training """
    # The distillation flags are always set *before* build(), for every recipe.
    if args.kd_type == 'joint_kd':
        learner.args.distill_logit = True
        learner.args.distill_rep_attn = True
        learner.build()
        run_training()
    elif args.kd_type == 'logit_kd':
        # only perform the logits kd
        learner.args.distill_logit = True
        learner.args.distill_rep_attn = False
        learner.build(lr=args.learning_rate)
        run_training()
    elif args.kd_type == 'two_stage':
        # stage 1: intermediate layer distillation
        learner.args.distill_logit = False
        learner.args.distill_rep_attn = True
        learner.build(lr=2.5 * args.learning_rate)
        run_training()
        # stage 2: prediction layer distillation, starting from stage-1 weights
        learner.student_model.load_state_dict(
            torch.load(os.path.join(learner.output_dir, 'pytorch_model.bin')))
        learner.args.distill_logit = True
        learner.args.distill_rep_attn = False
        learner.build(lr=args.learning_rate)  # prepare the optimizer again.
        run_training()
    else:
        assert args.kd_type == 'no_kd'
        # NO kd training, vanilla cross entropy with hard label
        learner.build(lr=args.learning_rate)  # prepare the optimizer again.
        run_training()

    del learner
    return 0
import os

# Special tokens of the BERT vocabulary. CLS is needed for cls_id below.
CLS = '[CLS]'
SEP = '[SEP]'
MASK = '[MASK]'

# Log to stdout and mirror every record into debug_layer_loss.log.
log_format = '%(asctime)s %(message)s'
logging.basicConfig(stream=sys.stdout,
                    level=logging.INFO,
                    format=log_format,
                    datefmt='%m/%d %I:%M:%S %p')
fh = logging.FileHandler('debug_layer_loss.log')
fh.setFormatter(logging.Formatter(log_format))
logging.getLogger().addHandler(fh)
logger = logging.getLogger()

# Checkpoint for both the tokenizer and the masked-LM model. Override it with
# the PRETRAINED_BERT_MODEL environment variable (e.g. 'bert-base-uncased').
pretrained_bert_model = os.environ.get(
    'PRETRAINED_BERT_MODEL',
    "/rscratch/bohan/ZQBert/zero-shot-qbert/Berts/mrpc_base_l12/")
tokenizer = BertTokenizer.from_pretrained(pretrained_bert_model)
model = BertForMaskedLM.from_pretrained(pretrained_bert_model)

# Vocabulary ids of the special tokens.
mask_id = tokenizer.convert_tokens_to_ids([MASK])[0]
sep_id = tokenizer.convert_tokens_to_ids([SEP])[0]
cls_id = tokenizer.convert_tokens_to_ids([CLS])[0]

model.eval()
cuda = torch.cuda.is_available()
if cuda:
    model = model.cuda()


def tokenize_batch(batch):
    """Convert every tokenized sentence in *batch* to vocabulary ids.

    Args:
        batch: iterable of sentences, each a sequence of token strings.

    Returns:
        A list containing one list of token ids per sentence, in order.
    """
    return [tokenizer.convert_tokens_to_ids(sent) for sent in batch]