def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--pretrained_bert_model",
                        default=None,
                        type=str,
                        required=True,
                        help="Folder containing the downloaded pretrained "
                             "model (bert-base-cased/uncased)")
    parser.add_argument("--glove_embs",
                        default=None,
                        type=str,
                        required=True,
                        help="GloVe word embeddings file")
    parser.add_argument("--glue_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="GLUE data dir")
    parser.add_argument("--task_name",
                        default=None,
                        type=str,
                        required=True,
                        help="Task (e.g. CoLA, SST-2) whose training set "
                             "should be augmented")
    parser.add_argument("--N",
                        default=30,
                        type=int,
                        help="How many times the corpus is expanded")
    parser.add_argument("--M",
                        default=15,
                        type=int,
                        help="Choose from the M most likely words at the "
                             "corresponding position")
    parser.add_argument("--p",
                        default=0.4,
                        type=float,
                        help="Threshold probability p for replacing the current word")
    args = parser.parse_args()
    # logger.info(args)

    default_params = {
        "CoLA": {"N": 30},
        "MNLI": {"N": 10},
        "MRPC": {"N": 30},
        "SST-2": {"N": 20},
        "STS-b": {"N": 30},
        "QQP": {"N": 10},
        "QNLI": {"N": 20},
        "RTE": {"N": 30}
    }

    if args.task_name in default_params:
        args.N = default_params[args.task_name]["N"]

    # Prepare the data augmentor: a masked-LM BERT plus GloVe retrieval.
    tokenizer = BertTokenizer.from_pretrained(args.pretrained_bert_model)
    model = BertForMaskedLM.from_pretrained(args.pretrained_bert_model)
    model.eval()

    emb_norm, vocab, ids_to_tokens = prepare_embedding_retrieval(
        args.glove_embs)

    data_augmentor = DataAugmentor(model, tokenizer, emb_norm, vocab,
                                   ids_to_tokens, args.M, args.N, args.p)

    # Do data augmentation
    processor = AugmentProcessor(data_augmentor, args.glue_dir,
                                 args.task_name)
    processor.read_augment_write()
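# For orientation: a minimal sketch of the masked-LM replacement step that a
# DataAugmentor of this kind performs (TinyBERT-style augmentation), with M
# and p playing the roles of the --M and --p flags above. The helper name
# `mlm_replacement` and its exact scoring are assumptions for illustration,
# not this repository's implementation.
import random

import torch


def mlm_replacement(sentence, position, tokenizer, model, M=15, p=0.4):
    """Mask one token, let BERT propose the M most likely fillers, and
    replace the original word with probability p (else keep it)."""
    tokens = tokenizer.tokenize(sentence)
    original = tokens[position]
    tokens[position] = '[MASK]'
    ids = tokenizer.convert_tokens_to_ids(['[CLS]'] + tokens + ['[SEP]'])
    with torch.no_grad():
        # Old pytorch_pretrained_bert-style API: BertForMaskedLM returns raw
        # prediction scores of shape (1, seq_len, vocab) when no labels given.
        scores = model(torch.tensor([ids]))
    # position + 1 accounts for the prepended [CLS] token.
    top_ids = torch.topk(scores[0, position + 1], M).indices.tolist()
    candidates = tokenizer.convert_ids_to_tokens(top_ids)
    return random.choice(candidates) if random.random() < p else original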
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir",
                        default='data',
                        type=str,
                        # required=True,
                        help="The input data dir. Should contain the .tsv "
                             "files (or other data files) for the task.")
    parser.add_argument("--model_dir",
                        default='models/finetuned_teacher/',
                        type=str,
                        help="The teacher model dir.")
    parser.add_argument("--tasks",
                        default='RTE,MRPC,STS-B,CoLA,SST-2,QNLI',
                        type=str,
                        # required=True,
                        help="Comma-separated names of the tasks to run "
                             "inference on.")
    parser.add_argument("--output_dir",
                        default='output',
                        type=str,
                        # required=True,
                        help="The output directory where the model predictions "
                             "and checkpoints will be written.")
    parser.add_argument("--max_seq_length",
                        default=None,
                        type=int,
                        help="The maximum total input sequence length after "
                             "WordPiece tokenization. Sequences longer than "
                             "this will be truncated, and sequences shorter "
                             "than this will be padded.")
    # Note: with default=True and store_true this flag cannot be turned off
    # from the command line; set default=False when using a cased model.
    parser.add_argument("--do_lower_case",
                        default=True,
                        action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for inference.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--root_dir", default='./', type=str)
    parser.add_argument("--log_dir", default='', type=str)
    parser.add_argument("--tensorboard_dir", default='', type=str)
    parser.add_argument("--model_save_dir", default='', type=str)
    args = parser.parse_args()
    logger.info('The args: {}'.format(args))

    args.data_dir = os.path.join(args.root_dir, args.data_dir)
    args.model_dir = os.path.join(args.root_dir, args.model_dir)
    args.output_dir = os.path.join(args.model_save_dir, args.output_dir)

    processors = {
        "cola": ColaProcessor,
        "mnli": MnliProcessor,
        "mnli-mm": MnliMismatchedProcessor,
        "mrpc": MrpcProcessor,
        "sst-2": Sst2Processor,
        "sts-b": StsbProcessor,
        "qqp": QqpProcessor,
        "qnli": QnliProcessor,
        "rte": RteProcessor,
        "wnli": WnliProcessor
    }

    output_modes = {
        "cola": "classification",
        "mnli": "classification",
        "mrpc": "classification",
        "sst-2": "classification",
        "sts-b": "regression",
        "qqp": "classification",
        "qnli": "classification",
        "rte": "classification",
        "wnli": "classification"
    }

    default_params = {
        "cola": {"max_seq_length": 64},
        "mnli": {"max_seq_length": 128},
        "mrpc": {"max_seq_length": 128},
        "sst-2": {"max_seq_length": 64},
        "sts-b": {"max_seq_length": 128},
        "qqp": {"max_seq_length": 128},
        "qnli": {"max_seq_length": 128},
        "rte": {"max_seq_length": 128}
    }

    infer_files = {
        "cola": "CoLA.tsv",
        "mnli": "MNLI-m.tsv",
        "mrpc": "MRPC.tsv",
        "sst-2": "SST-2.tsv",
        "sts-b": "STS-B.tsv",
        "qqp": "QQP.tsv",
        "qnli": "QNLI.tsv",
        "rte": "RTE.tsv",
        "wnli": "WNLI.tsv"
    }

    # Prepare devices
    device = torch.device(
        "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    n_gpu = torch.cuda.device_count()
    logger.info("device: {} n_gpu: {}".format(device, n_gpu))

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    tasks = args.tasks.lower()
    for task_name in tasks.split(','):
        data_dir = os.path.join(args.data_dir, task_name)
        model_dir = os.path.join(args.model_dir, task_name)
        # Resolve the sequence length per task so one task's default does
        # not leak into the next iteration.
        max_seq_length = args.max_seq_length
        if max_seq_length is None and task_name in default_params:
            max_seq_length = default_params[task_name]["max_seq_length"]

        processor = processors[task_name]()
        output_mode = output_modes[task_name]
        label_list = processor.get_labels()
        num_labels = len(label_list)

        output_file = os.path.join(args.output_dir, infer_files[task_name])
        tokenizer = BertTokenizer.from_pretrained(
            model_dir, do_lower_case=args.do_lower_case)
        examples = processor.get_test_examples(data_dir)
        features = convert_examples_to_features(examples, label_list,
                                                max_seq_length, tokenizer,
                                                output_mode)
        data, labels = get_tensor_data(output_mode, features)
        sampler = SequentialSampler(data)
        dataloader = DataLoader(data,
                                sampler=sampler,
                                batch_size=args.batch_size)

        model = BertForSequenceClassification.from_pretrained(
            model_dir, num_labels=num_labels, do_quantize=0)
        model.to(device)

        logger.info("***** Running evaluation *****")
        logger.info("  Num examples = %d", len(examples))
        logger.info("  Batch size = %d", args.batch_size)
        model.eval()
        do_infer(model, task_name, dataloader, device, output_mode,
                 output_file, label_list)

        if task_name == "mnli":
            processor = processors["mnli-mm"]()
            examples = processor.get_test_examples(data_dir)
            output_file = os.path.join(args.output_dir, 'MNLI-mm.tsv')
            features = convert_examples_to_features(examples, label_list,
                                                    max_seq_length, tokenizer,
                                                    output_mode)
            data, labels = get_tensor_data(output_mode, features)
            logger.info("***** Running mm evaluation *****")
            logger.info("  Num examples = %d", len(examples))
            sampler = SequentialSampler(data)
            dataloader = DataLoader(data,
                                    sampler=sampler,
                                    batch_size=args.batch_size)
            do_infer(model, task_name, dataloader, device, output_mode,
                     output_file, label_list)
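# get_tensor_data and do_infer are defined elsewhere in the repository; the
# two sketches below are assumptions consistent with how they are called in
# this script, not the actual implementations.
import torch
from torch.utils.data import TensorDataset


def get_tensor_data(output_mode, features):
    """Pack InputFeatures into a TensorDataset of
    (input_ids, input_mask, segment_ids, label_ids, seq_lengths)."""
    label_dtype = torch.long if output_mode == "classification" else torch.float
    all_label_ids = torch.tensor([f.label_id for f in features],
                                 dtype=label_dtype)
    all_input_ids = torch.tensor([f.input_ids for f in features],
                                 dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in features],
                                  dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in features],
                                   dtype=torch.long)
    all_seq_lengths = torch.tensor([f.seq_length for f in features],
                                   dtype=torch.long)
    tensor_data = TensorDataset(all_input_ids, all_input_mask,
                                all_segment_ids, all_label_ids,
                                all_seq_lengths)
    return tensor_data, all_label_ids


def do_infer(model, task_name, dataloader, device, output_mode, output_file,
             label_list):
    """Run the model over the test set and write a GLUE-style submission
    TSV (header "index\tprediction")."""
    predictions = []
    model.eval()
    for input_ids, input_mask, segment_ids, _, _ in dataloader:
        with torch.no_grad():
            outputs = model(input_ids.to(device), segment_ids.to(device),
                            input_mask.to(device))
        # Some model variants return (logits, atts, reps); keep the logits.
        logits = outputs[0] if isinstance(outputs, tuple) else outputs
        if output_mode == "classification":
            predictions.extend(logits.argmax(dim=-1).tolist())
        else:  # regression (STS-B)
            predictions.extend(logits.squeeze(-1).tolist())
    with open(output_file, 'w') as writer:
        writer.write("index\tprediction\n")
        for index, pred in enumerate(predictions):
            value = label_list[pred] if output_mode == "classification" else pred
            writer.write(f"{index}\t{value}\n")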
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir",
                        default='data',
                        type=str,
                        help="The input data dir. Should contain the .tsv "
                             "files (or other data files) for the task.")
    parser.add_argument("--model_dir",
                        default='models/tinybert',
                        type=str,
                        help="The model dir.")
    parser.add_argument("--teacher_model",
                        default=None,
                        type=str,
                        help="The teacher model dir.")
    parser.add_argument("--student_model",
                        default=None,
                        type=str,
                        help="The student model dir.")
    parser.add_argument("--task_name",
                        default='sst-2',
                        type=str,
                        help="The name of the task to train.")
    parser.add_argument("--output_dir",
                        default='output',
                        type=str,
                        help="The output directory where the model predictions "
                             "and checkpoints will be written.")
    parser.add_argument("--learning_rate",
                        default=2e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    # store_false: augmented data is used by default; pass --aug_train to
    # disable it.
    parser.add_argument('--aug_train',
                        action='store_false',
                        help="Whether to use augmented data or not")
    parser.add_argument('--pred_distill',
                        action='store_true',
                        help="Whether to distill with the task (prediction) layer")
    parser.add_argument('--intermediate_distill',
                        action='store_true',
                        help="Whether to distill with intermediate layers")
    parser.add_argument('--save_fp_model',
                        action='store_true',
                        help="Whether to save the fp32 model")
    parser.add_argument('--save_quantized_model',
                        action='store_true',
                        help="Whether to save the quantized model")
    parser.add_argument("--weight_bits",
                        default=2,
                        type=int,
                        choices=[2, 8],
                        help="Quantization bits for weights.")
    parser.add_argument("--input_bits",
                        default=8,
                        type=int,
                        help="Quantization bits for activations.")
    parser.add_argument("--clip_val",
                        default=2.5,
                        type=float,
                        help="Initial clip value.")
    args = parser.parse_args()
    assert args.pred_distill or args.intermediate_distill, \
        "At least one of 'pred_distill' and 'intermediate_distill' must be True"

    summaryWriter = SummaryWriter(args.output_dir)
    logger.info('The args: {}'.format(args))

    task_name = args.task_name.lower()
    data_dir = os.path.join(args.data_dir, task_name)
    output_dir = os.path.join(args.output_dir, task_name)
    processed_data_dir = os.path.join(args.data_dir, 'preprocessed', task_name)
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    if args.student_model is None:
        args.student_model = os.path.join(args.model_dir, task_name)
    if args.teacher_model is None:
        args.teacher_model = os.path.join(args.model_dir, task_name)

    processors = {
        "cola": ColaProcessor,
        "mnli": MnliProcessor,
        "mnli-mm": MnliMismatchedProcessor,
        "mrpc": MrpcProcessor,
        "sst-2": Sst2Processor,
        "sts-b": StsbProcessor,
        "qqp": QqpProcessor,
        "qnli": QnliProcessor,
        "rte": RteProcessor
    }

    output_modes = {
        "cola": "classification",
        "mnli": "classification",
        "mrpc": "classification",
        "sst-2": "classification",
        "sts-b": "regression",
        "qqp": "classification",
        "qnli": "classification",
        "rte": "classification"
    }

    default_params = {
        "cola": {"max_seq_length": 64, "batch_size": 16, "eval_step": 50},
        "mnli": {"max_seq_length": 128, "batch_size": 32, "eval_step": 1000},
        "mrpc": {"max_seq_length": 128, "batch_size": 32, "eval_step": 200},
        "sst-2": {"max_seq_length": 64, "batch_size": 32, "eval_step": 200},
        "sts-b": {"max_seq_length": 128, "batch_size": 32, "eval_step": 50},
        "qqp": {"max_seq_length": 128, "batch_size": 32, "eval_step": 1000},
        "qnli": {"max_seq_length": 128, "batch_size": 32, "eval_step": 1000},
        "rte": {"max_seq_length": 128, "batch_size": 32, "eval_step": 100}
    }

    acc_tasks = ["mnli", "mrpc", "sst-2", "qqp", "qnli", "rte"]
    corr_tasks = ["sts-b"]
    mcc_tasks = ["cola"]

    # Prepare devices
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()

    # Prepare seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if task_name in default_params:
        args.batch_size = default_params[task_name]["batch_size"]
        if n_gpu > 0:
            args.batch_size = int(args.batch_size * n_gpu)
        args.max_seq_length = default_params[task_name]["max_seq_length"]
        args.eval_step = default_params[task_name]["eval_step"]

    processor = processors[task_name]()
    output_mode = output_modes[task_name]
    label_list = processor.get_labels()
    num_labels = len(label_list)

    tokenizer = BertTokenizer.from_pretrained(args.student_model,
                                              do_lower_case=True)

    # Load cached features when available; otherwise build them from scratch.
    if args.aug_train:
        try:
            train_file = os.path.join(processed_data_dir, 'aug_data')
            train_features = pickle.load(open(train_file, 'rb'))
        except Exception:
            train_examples = processor.get_aug_examples(data_dir)
            train_features = convert_examples_to_features(
                train_examples, label_list, args.max_seq_length, tokenizer,
                output_mode)
    else:
        try:
            train_file = os.path.join(processed_data_dir, 'train_data')
            train_features = pickle.load(open(train_file, 'rb'))
        except Exception:
            train_examples = processor.get_train_examples(data_dir)
            train_features = convert_examples_to_features(
                train_examples, label_list, args.max_seq_length, tokenizer,
                output_mode)

    num_train_optimization_steps = int(
        len(train_features) / args.batch_size * args.num_train_epochs)
    train_data, _ = get_tensor_data(output_mode, train_features)
    train_sampler = RandomSampler(train_data)
    train_dataloader = DataLoader(train_data,
                                  sampler=train_sampler,
                                  batch_size=args.batch_size)

    try:
        dev_file = os.path.join(processed_data_dir, 'dev_data')
        eval_features = pickle.load(open(dev_file, 'rb'))
    except Exception:
        eval_examples = processor.get_dev_examples(data_dir)
        eval_features = convert_examples_to_features(eval_examples,
                                                     label_list,
                                                     args.max_seq_length,
                                                     tokenizer, output_mode)

    eval_data, eval_labels = get_tensor_data(output_mode, eval_features)
    eval_sampler = SequentialSampler(eval_data)
    eval_dataloader = DataLoader(eval_data,
                                 sampler=eval_sampler,
                                 batch_size=args.batch_size)

    if task_name == "mnli":
        processor = processors["mnli-mm"]()
        try:
            dev_mm_file = os.path.join(processed_data_dir, 'dev-mm_data')
            mm_eval_features = pickle.load(open(dev_mm_file, 'rb'))
        except Exception:
            mm_eval_examples = processor.get_dev_examples(data_dir)
            mm_eval_features = convert_examples_to_features(
                mm_eval_examples, label_list, args.max_seq_length, tokenizer,
                output_mode)
        mm_eval_data, mm_eval_labels = get_tensor_data(output_mode,
                                                       mm_eval_features)
        logger.info("  Num examples = %d", len(mm_eval_features))
        mm_eval_sampler = SequentialSampler(mm_eval_data)
        mm_eval_dataloader = DataLoader(mm_eval_data,
                                        sampler=mm_eval_sampler,
                                        batch_size=args.batch_size)

    # Evaluate the full-precision teacher once as a reference point.
    teacher_model = BertForSequenceClassification.from_pretrained(
        args.teacher_model)
    teacher_model.to(device)
    teacher_model.eval()
    if n_gpu > 1:
        teacher_model = torch.nn.DataParallel(teacher_model)

    result = do_eval(teacher_model, task_name, eval_dataloader, device,
                     output_mode, eval_labels, num_labels)
    if task_name in acc_tasks:
        if task_name in ['sst-2', 'mnli', 'qnli', 'rte']:
            fp32_performance = f"acc:{result['acc']}"
        elif task_name in ['mrpc', 'qqp']:
            fp32_performance = f"f1/acc:{result['f1']}/{result['acc']}"
    if task_name in corr_tasks:
        fp32_performance = f"pearson/spearmanr:{result['pearson']}/{result['spearmanr']}"
    if task_name in mcc_tasks:
        fp32_performance = f"mcc:{result['mcc']}"
    if task_name == "mnli":
        result = do_eval(teacher_model, 'mnli-mm', mm_eval_dataloader, device,
                         output_mode, mm_eval_labels, num_labels)
        fp32_performance += f" mm-acc:{result['acc']}"
    fp32_performance = task_name + ' fp32 ' + fp32_performance

    student_config = BertConfig.from_pretrained(args.teacher_model,
                                                quantize_act=True,
                                                weight_bits=args.weight_bits,
                                                input_bits=args.input_bits,
                                                clip_val=args.clip_val)
    student_model = QuantBertForSequenceClassification.from_pretrained(
        args.student_model, config=student_config, num_labels=num_labels)
    student_model.to(device)

    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_features))
    logger.info("  Batch size = %d", args.batch_size)
    logger.info("  Num steps = %d", num_train_optimization_steps)

    if n_gpu > 1:
        student_model = torch.nn.DataParallel(student_model)

    # Prepare optimizer: no weight decay for biases and LayerNorm parameters.
    param_optimizer = list(student_model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay': 0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay': 0.0
    }]
    schedule = 'warmup_linear'
    optimizer = BertAdam(optimizer_grouped_parameters,
                         schedule=schedule,
                         lr=args.learning_rate,
                         warmup=0.1,
                         t_total=num_train_optimization_steps)

    loss_mse = MSELoss()
    global_step = 0
    best_dev_acc = 0.0
    previous_best = None

    tr_loss = 0.
    tr_att_loss = 0.
    tr_rep_loss = 0.
    tr_cls_loss = 0.
    for epoch_ in range(int(args.num_train_epochs)):
        nb_tr_examples, nb_tr_steps = 0, 0
        for step, batch in enumerate(train_dataloader):
            student_model.train()
            batch = tuple(t.to(device) for t in batch)
            input_ids, input_mask, segment_ids, label_ids, seq_lengths = batch

            att_loss = 0.
            rep_loss = 0.
            cls_loss = 0.
            loss = 0.
            student_logits, student_atts, student_reps = student_model(
                input_ids, segment_ids, input_mask)
            with torch.no_grad():
                teacher_logits, teacher_atts, teacher_reps = teacher_model(
                    input_ids, segment_ids, input_mask)

            if args.pred_distill:
                if output_mode == "classification":
                    cls_loss = soft_cross_entropy(student_logits,
                                                  teacher_logits)
                elif output_mode == "regression":
                    cls_loss = loss_mse(student_logits, teacher_logits)
                loss = cls_loss
                tr_cls_loss += cls_loss.item()

            if args.intermediate_distill:
                for student_att, teacher_att in zip(student_atts,
                                                    teacher_atts):
                    # Zero out the large negative entries used for attention
                    # padding before comparing the attention maps.
                    student_att = torch.where(
                        student_att <= -1e2,
                        torch.zeros_like(student_att).to(device), student_att)
                    teacher_att = torch.where(
                        teacher_att <= -1e2,
                        torch.zeros_like(teacher_att).to(device), teacher_att)
                    tmp_loss = loss_mse(student_att, teacher_att)
                    att_loss += tmp_loss
                for student_rep, teacher_rep in zip(student_reps,
                                                    teacher_reps):
                    tmp_loss = loss_mse(student_rep, teacher_rep)
                    rep_loss += tmp_loss
                loss += rep_loss + att_loss
                tr_att_loss += att_loss.item()
                tr_rep_loss += rep_loss.item()

            if n_gpu > 1:
                loss = loss.mean()

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            global_step += 1
            tr_loss += loss.item()
            nb_tr_examples += label_ids.size(0)
            nb_tr_steps += 1

            if global_step % args.eval_step == 0 or \
                    global_step == num_train_optimization_steps - 1:
                logger.info("***** Running evaluation *****")
                logger.info("  {} step of {} steps".format(
                    global_step, num_train_optimization_steps))
                if previous_best is not None:
                    logger.info(
                        f"{fp32_performance}\nPrevious best = {previous_best}")

                student_model.eval()

                loss = tr_loss / (step + 1)
                cls_loss = tr_cls_loss / (step + 1)
                att_loss = tr_att_loss / (step + 1)
                rep_loss = tr_rep_loss / (step + 1)

                result = do_eval(student_model, task_name, eval_dataloader,
                                 device, output_mode, eval_labels, num_labels)
                result['global_step'] = global_step
                result['cls_loss'] = cls_loss
                result['att_loss'] = att_loss
                result['rep_loss'] = rep_loss
                result['loss'] = loss
                summaryWriter.add_scalar('total_loss', loss, global_step)
                summaryWriter.add_scalars(
                    'distill_loss', {
                        'att_loss': att_loss,
                        'rep_loss': rep_loss,
                        'cls_loss': cls_loss
                    }, global_step)

                if task_name == 'cola':
                    summaryWriter.add_scalar('mcc', result['mcc'],
                                             global_step)
                elif task_name in [
                        'sst-2', 'mnli', 'mnli-mm', 'qnli', 'rte', 'wnli'
                ]:
                    summaryWriter.add_scalar('acc', result['acc'],
                                             global_step)
                elif task_name in ['mrpc', 'qqp']:
                    summaryWriter.add_scalars(
                        'performance', {
                            'acc': result['acc'],
                            'f1': result['f1'],
                            'acc_and_f1': result['acc_and_f1']
                        }, global_step)
                else:
                    summaryWriter.add_scalar('corr', result['corr'],
                                             global_step)

                save_model = False
                if task_name in acc_tasks and result['acc'] > best_dev_acc:
                    if task_name in ['sst-2', 'mnli', 'qnli', 'rte']:
                        previous_best = f"acc:{result['acc']}"
                    elif task_name in ['mrpc', 'qqp']:
                        previous_best = f"f1/acc:{result['f1']}/{result['acc']}"
                    best_dev_acc = result['acc']
                    save_model = True

                if task_name in corr_tasks and result['corr'] > best_dev_acc:
                    previous_best = f"pearson/spearmanr:{result['pearson']}/{result['spearmanr']}"
                    best_dev_acc = result['corr']
                    save_model = True

                if task_name in mcc_tasks and result['mcc'] > best_dev_acc:
                    previous_best = f"mcc:{result['mcc']}"
                    best_dev_acc = result['mcc']
                    save_model = True

                if save_model:
                    # Test mnli-mm
                    if task_name == "mnli":
                        result = do_eval(student_model, 'mnli-mm',
                                         mm_eval_dataloader, device,
                                         output_mode, mm_eval_labels,
                                         num_labels)
                        previous_best += f" mm-acc:{result['acc']}"
                    logger.info(fp32_performance)
                    logger.info(previous_best)

                    if args.save_fp_model:
                        logger.info(
                            "******************** Save full precision model ********************"
                        )
                        model_to_save = student_model.module if hasattr(
                            student_model, 'module') else student_model
                        output_model_file = os.path.join(
                            output_dir, WEIGHTS_NAME)
                        output_config_file = os.path.join(
                            output_dir, CONFIG_NAME)
                        torch.save(model_to_save.state_dict(),
                                   output_model_file)
                        model_to_save.config.to_json_file(output_config_file)
                        tokenizer.save_vocabulary(output_dir)

                    if args.save_quantized_model:
                        logger.info(
                            "******************** Save quantized model ********************"
                        )
                        output_quant_dir = os.path.join(output_dir, 'quant')
                        if not os.path.exists(output_quant_dir):
                            os.makedirs(output_quant_dir)
                        model_to_save = student_model.module if hasattr(
                            student_model, 'module') else student_model
                        quant_model = copy.deepcopy(model_to_save)
                        # Bake the quantized values into the weights before
                        # saving the checkpoint.
                        for name, module in quant_model.named_modules():
                            if hasattr(module, 'weight_quantizer'):
                                module.weight.data = module.weight_quantizer.apply(
                                    module.weight, module.weight_clip_val,
                                    module.weight_bits, True)

                        output_model_file = os.path.join(
                            output_quant_dir, WEIGHTS_NAME)
                        output_config_file = os.path.join(
                            output_quant_dir, CONFIG_NAME)
                        torch.save(quant_model.state_dict(),
                                   output_model_file)
                        model_to_save.config.to_json_file(output_config_file)
                        tokenizer.save_vocabulary(output_quant_dir)
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir",
                        default='data/',
                        type=str,
                        help="The data directory.")
    parser.add_argument("--model_dir",
                        default='models/',
                        type=str,
                        help="The models directory.")
    parser.add_argument("--teacher_model",
                        default=None,
                        type=str,
                        help="The teacher model directory.")
    parser.add_argument("--student_model",
                        default=None,
                        type=str,
                        help="The student model directory.")
    parser.add_argument("--output_dir",
                        default='output',
                        type=str,
                        help="The output directory where the model predictions "
                             "and checkpoints will be written.")
    parser.add_argument('--version_2_with_negative',
                        action='store_true',
                        help="SQuAD v2.0 if set, else SQuAD v1.1")

    # default hyper-parameters
    parser.add_argument("--max_seq_length",
                        default=384,
                        type=int,
                        help="The maximum total input sequence length after "
                             "WordPiece tokenization. Sequences longer than "
                             "this will be truncated, and sequences shorter "
                             "than this will be padded.")
    parser.add_argument("--doc_stride",
                        default=128,
                        type=int,
                        help="When splitting up a long document into chunks, "
                             "how much stride to take between chunks.")
    parser.add_argument("--max_query_length",
                        default=64,
                        type=int,
                        help="The maximum number of tokens for the question. "
                             "Questions longer than this will be truncated to "
                             "this length.")
    parser.add_argument("--n_best_size",
                        default=20,
                        type=int,
                        help="The total number of n-best predictions to "
                             "generate in the nbest_predictions.json output file.")
    parser.add_argument("--max_answer_length",
                        default=30,
                        type=int,
                        help="The maximum length of an answer that can be "
                             "generated. This is needed because the start and "
                             "end predictions are not conditioned on one another.")
    parser.add_argument("--verbose_logging", default=0, type=int)
    parser.add_argument('--null_score_diff_threshold',
                        type=float,
                        default=0.0,
                        help="If null_score - best_non_null is greater than "
                             "the threshold, predict null.")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--do_lower_case',
                        # action='store_true',
                        default=True,
                        help="do lower case")
    parser.add_argument("--per_gpu_batch_size",
                        default=16,
                        type=int,
                        help="Per-GPU batch size for training.")
    parser.add_argument("--learning_rate",
                        default=2e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument('--eval_step',
                        type=int,
                        default=200,
                        help="Evaluate every X training steps")
    parser.add_argument('--pred_distill',
                        action='store_true',
                        help="Whether to distill with the task (prediction) layer")
    parser.add_argument('--intermediate_distill',
                        action='store_true',
                        help="Whether to distill with intermediate layers")
    parser.add_argument('--save_fp_model',
                        action='store_true',
                        help="Whether to save the fp32 model")
    parser.add_argument('--save_quantized_model',
                        action='store_true',
                        help="Whether to save the quantized model")
    parser.add_argument("--weight_bits",
                        default=2,
                        type=int,
                        choices=[2, 8],
                        help="Quantization bits for weights.")
    parser.add_argument("--input_bits",
                        default=8,
                        type=int,
                        help="Quantization bits for activations.")
    parser.add_argument("--clip_val",
                        default=2.5,
                        type=float,
                        help="Initial clip value.")
    args = parser.parse_args()

    summaryWriter = SummaryWriter(args.output_dir)

    if args.teacher_model is None:
        args.teacher_model = args.model_dir
    if args.student_model is None:
        args.student_model = args.model_dir

    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.n_gpu = torch.cuda.device_count()

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    # max(1, n_gpu) keeps the batch size valid when running on CPU.
    args.batch_size = max(1, args.n_gpu) * args.per_gpu_batch_size

    logger.info(f'The args: {args}')

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    tokenizer = BertTokenizer.from_pretrained(args.teacher_model,
                                              do_lower_case=True)

    # Prepare training data: reuse cached features when present.
    input_file = 'train-v2.0' if args.version_2_with_negative else 'train-v1.1'
    input_file = os.path.join(args.data_dir, input_file)
    if os.path.exists(input_file):
        train_features = pickle.load(open(input_file, 'rb'))
    else:
        input_file = 'train-v2.0.json' if args.version_2_with_negative \
            else 'train-v1.1.json'
        input_file = os.path.join(args.data_dir, input_file)
        _, train_examples = read_squad_examples(
            input_file=input_file,
            is_training=True,
            version_2_with_negative=args.version_2_with_negative)
        train_features = convert_examples_to_features(
            examples=train_examples,
            tokenizer=tokenizer,
            max_seq_length=args.max_seq_length,
            doc_stride=args.doc_stride,
            max_query_length=args.max_query_length,
            is_training=True)

    num_train_optimization_steps = int(
        len(train_features) / args.batch_size * args.num_train_epochs)
    logger.info("***** Running training *****")
    logger.info("  Num split examples = %d", len(train_features))
    logger.info("  Batch size = %d", args.batch_size)
    logger.info("  Num steps = %d", num_train_optimization_steps)

    all_input_ids = torch.tensor([f.input_ids for f in train_features],
                                 dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in train_features],
                                  dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in train_features],
                                   dtype=torch.long)
    all_start_positions = torch.tensor(
        [f.start_position for f in train_features], dtype=torch.long)
    all_end_positions = torch.tensor(
        [f.end_position for f in train_features], dtype=torch.long)
    train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
                               all_start_positions, all_end_positions)
    train_sampler = RandomSampler(train_data)
    train_dataloader = DataLoader(train_data,
                                  sampler=train_sampler,
                                  batch_size=args.batch_size)

    input_file = 'dev-v2.0.json' if args.version_2_with_negative \
        else 'dev-v1.1.json'
    args.dev_file = os.path.join(args.data_dir, input_file)
    dev_dataset, eval_examples = read_squad_examples(
        input_file=args.dev_file,
        is_training=False,
        version_2_with_negative=args.version_2_with_negative)

    eval_features = convert_examples_to_features(
        examples=eval_examples,
        tokenizer=tokenizer,
        max_seq_length=args.max_seq_length,
        doc_stride=args.doc_stride,
        max_query_length=args.max_query_length,
        is_training=False)

    logger.info("***** Running predictions *****")
    logger.info("  Num orig examples = %d", len(eval_examples))
    logger.info("  Num split examples = %d", len(eval_features))
    logger.info("  Batch size = %d", args.batch_size)

    all_input_ids = torch.tensor([f.input_ids for f in eval_features],
                                 dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in eval_features],
                                  dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in eval_features],
                                   dtype=torch.long)
    all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
    eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
                              all_example_index)
    eval_sampler = SequentialSampler(eval_data)
    eval_dataloader = DataLoader(eval_data,
                                 sampler=eval_sampler,
                                 batch_size=args.batch_size)

    # Evaluate the full-precision teacher once as a reference point.
    teacher_model = BertForQuestionAnswering.from_pretrained(
        args.teacher_model)
    teacher_model.to(args.device)
    teacher_model.eval()
    if args.n_gpu > 1:
        teacher_model = torch.nn.DataParallel(teacher_model)
    result = do_eval(args, teacher_model, eval_dataloader, eval_features,
                     eval_examples, args.device, dev_dataset)
    em, f1 = result['exact_match'], result['f1']
    logger.info(f"Full precision teacher exact_match={em}, f1={f1}")

    student_config = BertConfig.from_pretrained(args.student_model,
                                                quantize_act=True,
                                                weight_bits=args.weight_bits,
                                                input_bits=args.input_bits,
                                                clip_val=args.clip_val)
    student_model = QuantBertForQuestionAnswering.from_pretrained(
        args.student_model, config=student_config)
    student_model.to(args.device)
    if args.n_gpu > 1:
        student_model = torch.nn.DataParallel(student_model)

    # Prepare optimizer: no weight decay for biases and LayerNorm parameters.
    param_optimizer = list(student_model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay': 0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay': 0.0
    }]
    schedule = 'warmup_linear'
    optimizer = BertAdam(optimizer_grouped_parameters,
                         schedule=schedule,
                         lr=args.learning_rate,
                         warmup=0.1,
                         t_total=num_train_optimization_steps)
    loss_mse = MSELoss()

    # Train and evaluate
    global_step = 0
    best_dev_f1 = 0.0
    flag_loss = float('inf')
    previous_best = None
    tr_loss = 0.
    tr_att_loss = 0.
    tr_rep_loss = 0.
    tr_cls_loss = 0.
    for epoch_ in range(int(args.num_train_epochs)):
        for step, batch in enumerate(train_dataloader):
            student_model.train()
            batch = tuple(t.to(args.device) for t in batch)
            input_ids, input_mask, segment_ids, start_positions, end_positions = batch
            att_loss = 0.
            rep_loss = 0.
            cls_loss = 0.
            loss = 0.

            student_logits, student_atts, student_reps = student_model(
                input_ids, segment_ids, input_mask)
            with torch.no_grad():
                teacher_logits, teacher_atts, teacher_reps = teacher_model(
                    input_ids, segment_ids, input_mask)

            if args.pred_distill:
                # Distill the start- and end-position logits separately.
                soft_start_ce_loss = soft_cross_entropy(
                    student_logits[0], teacher_logits[0])
                soft_end_ce_loss = soft_cross_entropy(student_logits[1],
                                                      teacher_logits[1])
                cls_loss = soft_start_ce_loss + soft_end_ce_loss
                loss += cls_loss
                tr_cls_loss += cls_loss.item()

            if args.intermediate_distill:
                for student_att, teacher_att in zip(student_atts,
                                                    teacher_atts):
                    # Zero out the large negative entries used for attention
                    # padding before comparing the attention maps.
                    student_att = torch.where(
                        student_att <= -1e2,
                        torch.zeros_like(student_att).to(args.device),
                        student_att)
                    teacher_att = torch.where(
                        teacher_att <= -1e2,
                        torch.zeros_like(teacher_att).to(args.device),
                        teacher_att)
                    tmp_loss = loss_mse(student_att, teacher_att)
                    att_loss += tmp_loss
                for student_rep, teacher_rep in zip(student_reps,
                                                    teacher_reps):
                    tmp_loss = loss_mse(student_rep, teacher_rep)
                    rep_loss += tmp_loss
                loss += rep_loss + att_loss
                tr_att_loss += att_loss.item()
                tr_rep_loss += rep_loss.item()

            if args.n_gpu > 1:
                loss = loss.mean()

            loss.backward()
            tr_loss += loss.item()
            optimizer.step()
            optimizer.zero_grad()
            global_step += 1

            save_model = False
            if global_step % args.eval_step == 0 or \
                    global_step == num_train_optimization_steps - 1:
                logger.info("***** Running evaluation *****")
                logger.info(f"  Epoch = {epoch_} iter {global_step} step")
                if previous_best is not None:
                    logger.info(f"Previous best = {previous_best}")

                student_model.eval()
                result = do_eval(args, student_model, eval_dataloader,
                                 eval_features, eval_examples, args.device,
                                 dev_dataset)
                em, f1 = result['exact_match'], result['f1']
                logger.info(f'{em}/{f1}')
                if f1 > best_dev_f1:
                    previous_best = f"exact_match={em},f1={f1}"
                    best_dev_f1 = f1
                    save_model = True

                summaryWriter.add_scalars('performance', {
                    'exact_match': em,
                    'f1': f1
                }, global_step)
                loss = tr_loss / global_step
                cls_loss = tr_cls_loss / global_step
                att_loss = tr_att_loss / global_step
                rep_loss = tr_rep_loss / global_step
                summaryWriter.add_scalar('total_loss', loss, global_step)
                summaryWriter.add_scalars(
                    'distill_loss', {
                        'att_loss': att_loss,
                        'rep_loss': rep_loss,
                        'cls_loss': cls_loss
                    }, global_step)

            # save quantized model
            if save_model:
                logger.info(previous_best)
                if args.save_fp_model:
                    logger.info(
                        "******************** Save full precision model ********************"
                    )
                    model_to_save = student_model.module if hasattr(
                        student_model, 'module') else student_model
                    output_model_file = os.path.join(args.output_dir,
                                                     WEIGHTS_NAME)
                    output_config_file = os.path.join(args.output_dir,
                                                      CONFIG_NAME)
                    torch.save(model_to_save.state_dict(), output_model_file)
                    model_to_save.config.to_json_file(output_config_file)
                    tokenizer.save_vocabulary(args.output_dir)
                if args.save_quantized_model:
                    logger.info(
                        "******************** Save quantized model ********************"
                    )
                    output_quant_dir = os.path.join(args.output_dir, 'quant')
                    if not os.path.exists(output_quant_dir):
                        os.makedirs(output_quant_dir)
                    model_to_save = student_model.module if hasattr(
                        student_model, 'module') else student_model
                    quant_model = copy.deepcopy(model_to_save)
                    # Bake the quantized values into the weights before
                    # saving the checkpoint.
                    for name, module in quant_model.named_modules():
                        if hasattr(module, 'weight_quantizer'):
                            module.weight.data = module.weight_quantizer.apply(
                                module.weight, module.weight_clip_val,
                                module.weight_bits, True)
                    output_model_file = os.path.join(output_quant_dir,
                                                     WEIGHTS_NAME)
                    output_config_file = os.path.join(output_quant_dir,
                                                      CONFIG_NAME)
                    torch.save(quant_model.state_dict(), output_model_file)
                    model_to_save.config.to_json_file(output_config_file)
                    tokenizer.save_vocabulary(output_quant_dir)
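# The two helpers below are referenced throughout the distillation scripts
# above but defined elsewhere in the repository. Both are sketches consistent
# with their call sites, not the actual implementations.
import torch
import torch.nn.functional as F


def soft_cross_entropy(predicts, targets):
    """Prediction-layer distillation loss: cross-entropy between the
    student's log-probabilities and the teacher's soft targets."""
    student_likelihood = F.log_softmax(predicts, dim=-1)
    targets_prob = F.softmax(targets, dim=-1)
    return (-targets_prob * student_likelihood).mean()


def ternary_quantize(weight, clip_val, num_bits=2, layerwise=True):
    """Functional sketch of what weight_quantizer.apply(weight, clip_val,
    bits, layerwise) might compute for 2-bit weights, following Ternary
    Weight Networks; `layerwise` is kept only to mirror the call site, and
    this sketch always quantizes per-tensor."""
    w = weight.clamp(-float(clip_val), float(clip_val))
    delta = 0.7 * w.abs().mean()  # TWN threshold
    mask = w.abs() > delta
    # Per-tensor scale: mean magnitude of the surviving weights.
    alpha = w.abs()[mask].mean() if mask.any() else w.abs().mean()
    return alpha * torch.sign(w) * mask.to(w.dtype)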