class MTDNNModel(object):
    def __init__(self, opt, state_dict=None, num_train_step=-1):
        self.config = opt
        self.updates = state_dict['updates'] if state_dict and 'updates' in state_dict else 0
        self.train_loss = AverageMeter()
        self.network = SANBertNetwork(opt)
        if state_dict:
            new_state = set(self.network.state_dict().keys())
            # Drop checkpoint entries that no longer exist in the network.
            for k in list(state_dict['state'].keys()):
                if k not in new_state:
                    print('deleting state:', k)
                    del state_dict['state'][k]
            # Add parameters that are missing from the checkpoint.
            for k, v in list(self.network.state_dict().items()):
                if k not in state_dict['state']:
                    print('adding missing state:', k)
                    state_dict['state'][k] = v
            self.network.load_state_dict(state_dict['state'])
        self.mnetwork = nn.DataParallel(self.network) if opt['multi_gpu_on'] else self.network
        self.total_param = sum(p.nelement() for p in self.network.parameters()
                               if p.requires_grad)
        no_decay = ['bias', 'gamma', 'beta', 'LayerNorm.bias', 'LayerNorm.weight']
        # Match substrings of the parameter names (as in the BERT code) so that
        # biases and LayerNorm weights are actually excluded from weight decay;
        # an exact `n not in no_decay` check never matches a full parameter name.
        optimizer_parameters = [{
            'params': [
                p for n, p in self.network.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay_rate': 0.01
        }, {
            'params': [
                p for n, p in self.network.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            'weight_decay_rate': 0.0
        }]
        # Note: Adamax and Adam below are modified versions based on the BERT code.
        if opt['optimizer'] == 'sgd':
            self.optimizer = optim.SGD(optimizer_parameters,
                                       opt['learning_rate'],
                                       weight_decay=opt['weight_decay'])
        elif opt['optimizer'] == 'adamax':
            self.optimizer = Adamax(optimizer_parameters,
                                    opt['learning_rate'],
                                    warmup=opt['warmup'],
                                    t_total=num_train_step,
                                    max_grad_norm=opt['grad_clipping'],
                                    schedule=opt['warmup_schedule'])
            if opt.get('have_lr_scheduler', False):
                opt['have_lr_scheduler'] = False
        elif opt['optimizer'] == 'adadelta':
            self.optimizer = optim.Adadelta(optimizer_parameters,
                                            opt['learning_rate'],
                                            rho=0.95)
        elif opt['optimizer'] == 'adam':
            self.optimizer = Adam(optimizer_parameters,
                                  lr=opt['learning_rate'],
                                  warmup=opt['warmup'],
                                  t_total=num_train_step,
                                  max_grad_norm=opt['grad_clipping'],
                                  schedule=opt['warmup_schedule'])
            if opt.get('have_lr_scheduler', False):
                opt['have_lr_scheduler'] = False
        else:
            raise RuntimeError('Unsupported optimizer: %s' % opt['optimizer'])

        if state_dict and 'optimizer' in state_dict:
            self.optimizer.load_state_dict(state_dict['optimizer'])

        if opt.get('have_lr_scheduler', False):
            if opt.get('scheduler_type', 'rop') == 'rop':
                self.scheduler = ReduceLROnPlateau(self.optimizer,
                                                   mode='max',
                                                   factor=opt['lr_gamma'],
                                                   patience=3)
            elif opt.get('scheduler_type', 'rop') == 'exp':
                self.scheduler = ExponentialLR(self.optimizer,
                                               gamma=opt.get('lr_gamma', 0.95))
            else:
                milestones = [
                    int(step)
                    for step in opt.get('multi_step_lr', '10,20,30').split(',')
                ]
                self.scheduler = MultiStepLR(self.optimizer,
                                             milestones=milestones,
                                             gamma=opt.get('lr_gamma'))
        else:
            self.scheduler = None

        self.ema = None
        if opt['ema_opt'] > 0:
            self.ema = EMA(self.config['ema_gamma'], self.network)
        self.para_swapped = False

    def setup_ema(self):
        if self.config['ema_opt']:
            self.ema.setup()

    def update_ema(self):
        if self.config['ema_opt']:
            self.ema.update()

    def eval(self):
        if self.config['ema_opt']:
            self.ema.swap_parameters()
            self.para_swapped = True

    def train(self):
        if self.para_swapped:
            self.ema.swap_parameters()
            self.para_swapped = False

    def update(self, batch_meta, batch_data):
        self.network.train()
        labels = batch_data[batch_meta['label']]
        if batch_meta['pairwise']:
            labels = labels.contiguous().view(-1, batch_meta['pairwise_size'])[:, 0]
        if self.config['cuda']:
            # `async` is a reserved keyword in Python 3.7+; use `non_blocking`.
            y = Variable(labels.cuda(non_blocking=True), requires_grad=False)
        else:
            y = Variable(labels, requires_grad=False)
        task_id = batch_meta['task_id']
        task_type = batch_meta['task_type']
        inputs = batch_data[:batch_meta['input_len']]
        if len(inputs) == 3:
            inputs.append(None)
            inputs.append(None)
        inputs.append(task_id)
        logits = self.mnetwork(*inputs)
        if batch_meta['pairwise']:
            logits = logits.view(-1, batch_meta['pairwise_size'])
        if task_type > 0:
            # Regression task: optionally clamp scores to be non-negative, then MSE.
            if self.config['answer_relu']:
                logits = F.relu(logits)
            loss = F.mse_loss(logits.squeeze(1), y)
        else:
            loss = F.cross_entropy(logits, y)
        if (self.config['mediqa_pairloss'] is not None
                and batch_meta['dataset_name'] in mediqa_name_list):
            # Pairwise ranking loss for MediQA: scores come in (first, second) pairs.
            logits = logits.squeeze().view(-1, 2)
            rank_y = batch_data[batch_meta['rank_label']].view(-1, 2)
            if self.config['mediqa_pairloss'] == 'hinge':
                first_logit, second_logit = logits.split(1, dim=1)
                rank_y = (2 * rank_y - 1).to(torch.float32)
                rank_y = rank_y[:, 0]
                pairwise_loss = F.margin_ranking_loss(
                    first_logit.squeeze(1),
                    second_logit.squeeze(1),
                    rank_y,
                    margin=self.config['hinge_lambda'])
            else:
                pairwise_loss = F.cross_entropy(logits, rank_y[:, 1])
            loss += pairwise_loss
        self.train_loss.update(loss.item(), logits.size(0))
        self.optimizer.zero_grad()
        loss.backward()
        if self.config['global_grad_clipping'] > 0:
            torch.nn.utils.clip_grad_norm_(self.network.parameters(),
                                           self.config['global_grad_clipping'])
        self.optimizer.step()
        self.updates += 1
        self.update_ema()

    def predict(self, batch_meta, batch_data):
        self.network.eval()
        task_id = batch_meta['task_id']
        task_type = batch_meta['task_type']
        inputs = batch_data[:batch_meta['input_len']]
        if len(inputs) == 3:
            inputs.append(None)
            inputs.append(None)
        inputs.append(task_id)
        score = self.mnetwork(*inputs)
        gold_label = batch_meta['label']
        if batch_meta['pairwise']:
            score = score.contiguous().view(-1, batch_meta['pairwise_size'])
            if task_type < 1:
                score = F.softmax(score, dim=1)
            score = score.data.cpu()
            score = score.numpy()
            predict = np.zeros(score.shape, dtype=int)
            if task_type < 1:
                positive = np.argmax(score, axis=1)
                for idx, pos in enumerate(positive):
                    predict[idx, pos] = 1
            predict = predict.reshape(-1).tolist()
            score = score.reshape(-1).tolist()
            return score, predict, batch_meta['true_label']
        else:
            if task_type < 1:
                score = F.softmax(score, dim=1)
            score = score.data.cpu()
            score = score.numpy()
            if task_type < 1:
                predict = np.argmax(score, axis=1).tolist()
            else:
                # Regression task: threshold the scores to get binary labels.
                predict = np.greater(
                    score, 2.0 + self.config['mediqa_score_offset']).astype(int)
                gold_label = np.greater(
                    batch_meta['label'],
                    2.00001 + self.config['mediqa_score_offset']).astype(int)
                predict = predict.reshape(-1).tolist()
                gold_label = gold_label.reshape(-1).tolist()
            score = score.reshape(-1).tolist()
            return score, predict, gold_label

    def save(self, filename):
        network_state = dict(
            [(k, v.cpu()) for k, v in self.network.state_dict().items()])
        ema_state = dict(
            [(k, v.cpu()) for k, v in self.ema.model.state_dict().items()]
        ) if self.ema is not None else dict()
        params = {
            'state': network_state,
            'optimizer': self.optimizer.state_dict(),
            'ema': ema_state,
            'config': self.config,
        }
        torch.save(params, filename)
        logger.info('model saved to {}'.format(filename))

    def cuda(self):
        self.network.cuda()
        if self.config['ema_opt']:
            self.ema.cuda()
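

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original code): how MTDNNModel above is
# typically driven. The `opt` keys and the batch layout are assumptions read
# off the constructor and the `update()`/`predict()` signatures; a real config
# is built elsewhere (e.g. by the training script's argument parser) and may
# require additional keys for SANBertNetwork.
# ---------------------------------------------------------------------------
def _mtdnn_usage_sketch(train_batches, dev_batches, num_train_step):
    opt = {
        'optimizer': 'adamax',          # 'sgd' | 'adamax' | 'adadelta' | 'adam'
        'learning_rate': 5e-5,
        'warmup': 0.1,
        'warmup_schedule': 'warmup_linear',
        'grad_clipping': 0.0,
        'global_grad_clipping': 1.0,
        'multi_gpu_on': False,
        'cuda': torch.cuda.is_available(),
        'ema_opt': 0,                   # > 0 enables the exponential moving average
        'have_lr_scheduler': False,
        'answer_relu': False,
        'mediqa_pairloss': None,        # or 'hinge' for the pairwise ranking loss
        'mediqa_score_offset': 0.0,
    }
    model = MTDNNModel(opt, num_train_step=num_train_step)
    if opt['cuda']:
        model.cuda()
    model.setup_ema()
    for batch_meta, batch_data in train_batches:
        model.update(batch_meta, batch_data)   # one optimizer step per batch
    model.eval()                               # swap in EMA weights if enabled
    for batch_meta, batch_data in dev_batches:
        scores, predictions, golds = model.predict(batch_meta, batch_data)
    model.save('mt_dnn_checkpoint.pt')
    model.train()                              # swap the live weights back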
def main(*_, **kwargs):
    use_cuda = torch.cuda.is_available() and kwargs["device"] >= 0
    device = torch.device("cuda:" + str(kwargs["device"]) if use_cuda else "cpu")
    if use_cuda:
        torch.cuda.set_device(device)
    kwargs["use_cuda"] = use_cuda

    neptune.create_experiment(
        name="bert-span-parser",
        upload_source_files=[],
        params={k: str(v) if isinstance(v, bool) else v for k, v in kwargs.items()},
    )

    logger.info("Settings: {}", json.dumps(kwargs, indent=2, ensure_ascii=False))

    # For reproducibility
    os.environ["PYTHONHASHSEED"] = str(kwargs["seed"])
    random.seed(kwargs["seed"])
    np.random.seed(kwargs["seed"])
    torch.manual_seed(kwargs["seed"])
    torch.cuda.manual_seed_all(kwargs["seed"])
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Prepare and load data
    tokenizer = BertTokenizer.from_pretrained(kwargs["bert_model"], do_lower_case=False)

    logger.info("Loading data...")
    train_treebank = load_trees(kwargs["train_file"])
    dev_treebank = load_trees(kwargs["dev_file"])
    test_treebank = load_trees(kwargs["test_file"])
    logger.info(
        "Loaded {:,} train, {:,} dev, and {:,} test examples!",
        len(train_treebank),
        len(dev_treebank),
        len(test_treebank),
    )

    logger.info("Preprocessing data...")
    train_parse = [tree.convert() for tree in train_treebank]
    train_sentences = [
        [(leaf.tag, leaf.word) for leaf in tree.leaves()] for tree in train_parse
    ]
    dev_sentences = [
        [(leaf.tag, leaf.word) for leaf in tree.leaves()] for tree in dev_treebank
    ]
    test_sentences = [
        [(leaf.tag, leaf.word) for leaf in tree.leaves()] for tree in test_treebank
    ]
    logger.info("Data preprocessed!")

    logger.info("Preparing data for training...")
    tags = []
    labels = []
    for tree in train_parse:
        nodes = [tree]
        while nodes:
            node = nodes.pop()
            if isinstance(node, InternalParseNode):
                labels.append(node.label)
                nodes.extend(reversed(node.children))
            else:
                tags.append(node.tag)
    tag_encoder = LabelEncoder()
    tag_encoder.fit(tags, reserved_labels=["[PAD]", "[UNK]"])
    label_encoder = LabelEncoder()
    label_encoder.fit(labels, reserved_labels=[()])
    logger.info("Data prepared!")

    # Settings
    num_train_optimization_steps = kwargs["num_epochs"] * (
        (len(train_parse) - 1) // kwargs["batch_size"] + 1
    )
    kwargs["batch_size"] //= kwargs["gradient_accumulation_steps"]

    logger.info("Creating dataloaders for training...")
    train_dataloader, train_features = create_dataloader(
        sentences=train_sentences,
        batch_size=kwargs["batch_size"],
        tag_encoder=tag_encoder,
        tokenizer=tokenizer,
        is_eval=False,
    )
    dev_dataloader, dev_features = create_dataloader(
        sentences=dev_sentences,
        batch_size=kwargs["batch_size"],
        tag_encoder=tag_encoder,
        tokenizer=tokenizer,
        is_eval=True,
    )
    test_dataloader, test_features = create_dataloader(
        sentences=test_sentences,
        batch_size=kwargs["batch_size"],
        tag_encoder=tag_encoder,
        tokenizer=tokenizer,
        is_eval=True,
    )
    logger.info("Dataloaders created!")

    # Initialize model
    model = ChartParser.from_pretrained(
        kwargs["bert_model"],
        tag_encoder=tag_encoder,
        label_encoder=label_encoder,
        lstm_layers=kwargs["lstm_layers"],
        lstm_dim=kwargs["lstm_dim"],
        tag_embedding_dim=kwargs["tag_embedding_dim"],
        label_hidden_dim=kwargs["label_hidden_dim"],
        dropout_prob=kwargs["dropout_prob"],
    )
    model.to(device)

    # Prepare optimizer
    param_optimizers = list(model.named_parameters())

    if kwargs["freeze_bert"]:
        for p in model.bert.parameters():
            p.requires_grad = False
        param_optimizers = [(n, p) for n, p in param_optimizers if p.requires_grad]

    # Hack: drop the pooler, which is unused and would otherwise produce None
    # gradients that break apex.
    param_optimizers = [n for n in param_optimizers if "pooler" not in n[0]]

    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in param_optimizers if not any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.01,
        },
        {
            "params": [
                p for n, p in param_optimizers if any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.0,
        },
    ]

    optimizer = BertAdam(
        optimizer_grouped_parameters,
        lr=kwargs["learning_rate"],
        warmup=kwargs["warmup_proportion"],
        t_total=num_train_optimization_steps,
    )

    if kwargs["fp16"]:
        model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

    pretrained_model_file = os.path.join(kwargs["output_dir"], MODEL_FILENAME)

    if kwargs["do_eval"]:
        assert os.path.isfile(pretrained_model_file), "Pretrained model file does not exist!"

        logger.info("Loading pretrained model from {}", pretrained_model_file)

        # Load model from file
        params = torch.load(pretrained_model_file, map_location=device)
        model.load_state_dict(params["model"])
        logger.info(
            "Loaded pretrained model (Epoch: {:,}, Fscore: {:.2f})",
            params["epoch"],
            params["fscore"],
        )

        eval_score = eval(
            model=model,
            eval_dataloader=test_dataloader,
            eval_features=test_features,
            eval_trees=test_treebank,
            eval_sentences=test_sentences,
            tag_encoder=tag_encoder,
            device=device,
        )

        neptune.send_metric("test_eval_precision", eval_score.precision())
        neptune.send_metric("test_eval_recall", eval_score.recall())
        neptune.send_metric("test_eval_fscore", eval_score.fscore())

        tqdm.write("Evaluation score: {}".format(str(eval_score)))
    else:
        # Training phase
        global_steps = 0
        start_epoch = 0
        best_dev_fscore = 0

        if kwargs["preload"] or kwargs["resume"]:
            assert os.path.isfile(pretrained_model_file), "Pretrained model file does not exist!"

            logger.info("Resuming model from {}", pretrained_model_file)

            # Load model from file
            params = torch.load(pretrained_model_file, map_location=device)
            model.load_state_dict(params["model"])

            if kwargs["resume"]:
                optimizer.load_state_dict(params["optimizer"])
                torch.cuda.set_rng_state_all(
                    [state.cpu() for state in params["torch_cuda_random_state_all"]]
                )
                torch.set_rng_state(params["torch_random_state"].cpu())
                np.random.set_state(params["np_random_state"])
                random.setstate(params["random_state"])
                global_steps = params["global_steps"]
                start_epoch = params["epoch"] + 1
                best_dev_fscore = params["fscore"]
        else:
            assert not os.path.isfile(
                pretrained_model_file
            ), "Please remove or move the pretrained model file to another place!"

        for epoch in trange(start_epoch, kwargs["num_epochs"], desc="Epoch"):
            model.train()

            train_loss = 0
            num_train_steps = 0

            for step, (indices, *_) in enumerate(tqdm(train_dataloader, desc="Iteration")):
                ids, attention_masks, tags, sections, trees, sentences = prepare_batch_input(
                    indices=indices,
                    features=train_features,
                    trees=train_parse,
                    sentences=train_sentences,
                    tag_encoder=tag_encoder,
                    device=device,
                )

                loss = model(
                    ids=ids,
                    attention_masks=attention_masks,
                    tags=tags,
                    sections=sections,
                    sentences=sentences,
                    gold_trees=trees,
                )

                if kwargs["gradient_accumulation_steps"] > 1:
                    loss /= kwargs["gradient_accumulation_steps"]

                if kwargs["fp16"]:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                train_loss += loss.item()
                num_train_steps += 1

                if (step + 1) % kwargs["gradient_accumulation_steps"] == 0:
                    optimizer.step()
                    optimizer.zero_grad()
                    global_steps += 1

            # Write logs
            neptune.send_metric("train_loss", epoch, train_loss / num_train_steps)
            neptune.send_metric("global_steps", epoch, global_steps)

            tqdm.write(
                "Epoch: {:,} - Train loss: {:.4f} - Global steps: {:,}".format(
                    epoch, train_loss / num_train_steps, global_steps
                )
            )

            # Evaluate
            eval_score = eval(
                model=model,
                eval_dataloader=dev_dataloader,
                eval_features=dev_features,
                eval_trees=dev_treebank,
                eval_sentences=dev_sentences,
                tag_encoder=tag_encoder,
                device=device,
            )

            neptune.send_metric("eval_precision", epoch, eval_score.precision())
            neptune.send_metric("eval_recall", epoch, eval_score.recall())
            neptune.send_metric("eval_fscore", epoch, eval_score.fscore())

            tqdm.write("Epoch: {:,} - Evaluation score: {}".format(epoch, str(eval_score)))

            # Save best model
            if eval_score.fscore() > best_dev_fscore:
                best_dev_fscore = eval_score.fscore()
                tqdm.write("** Saving model...")

                os.makedirs(kwargs["output_dir"], exist_ok=True)

                torch.save(
                    {
                        "epoch": epoch,
                        "global_steps": global_steps,
                        "fscore": best_dev_fscore,
                        "random_state": random.getstate(),
                        "np_random_state": np.random.get_state(),
                        "torch_random_state": torch.get_rng_state(),
                        "torch_cuda_random_state_all": torch.cuda.get_rng_state_all(),
                        "optimizer": optimizer.state_dict(),
                        "model": (
                            model.module if hasattr(model, "module") else model
                        ).state_dict(),
                    },
                    pretrained_model_file,
                )

            tqdm.write("** Best evaluation fscore: {:.2f}".format(best_dev_fscore))
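

# ---------------------------------------------------------------------------
# Minimal sketch (assumption, not part of the parser code): the gradient
# accumulation pattern used in the training loop above, isolated with a toy
# model so the loss scaling and the deferred optimizer.step() are easy to see.
# ---------------------------------------------------------------------------
def _gradient_accumulation_sketch(accumulation_steps=4):
    toy_model = torch.nn.Linear(8, 1)
    toy_optimizer = torch.optim.SGD(toy_model.parameters(), lr=0.1)
    for step in range(16):
        x = torch.randn(2, 8)
        y = torch.randn(2, 1)
        loss = torch.nn.functional.mse_loss(toy_model(x), y)
        # Scale so the accumulated gradient matches one large-batch update.
        loss = loss / accumulation_steps
        loss.backward()
        if (step + 1) % accumulation_steps == 0:
            toy_optimizer.step()
            toy_optimizer.zero_grad()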
best_acc = 0
for epoch in range(args.epoch_num):
    # Train
    model.train()
    t = trange(args.steps_per_epoch, desc='Epoch {} -Train'.format(epoch))
    loss_avg = utils.RunningAverage()
    # Re-created each epoch so `next` can be used and the iterators reset.
    train_iters = [iter(tmp) for tmp in train_bls]
    for i in t:
        task_id = train_task_ids[i]
        batch_data = next(train_iters[task_id])
        batch_data = tuple(tmp.to(args.device) for tmp in batch_data)
        loss = model(batch_data, task_id, True)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        loss_avg.update(loss.item())
        t.set_postfix(loss='{:5.4f}'.format(loss.item()),
                      avg_loss='{:5.4f}'.format(loss_avg()))

    acc = eval(model, ner_dev_data, dev_data, dev_bl, graph, entity_linking, args)
    utils.save_checkpoint(
        {
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optim_dict': optimizer.state_dict()
        },
        is_best=acc > best_acc,
        checkpoint=args.model_dir)
    best_acc = max(best_acc, acc)
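

# ---------------------------------------------------------------------------
# Sketch (assumption): the loop above picks a task via `train_task_ids[i]` at
# each step, but the schedule's construction is not shown here. One common way
# to build it is to sample task ids roughly in proportion to dataset size and
# shuffle once per epoch, e.g.:
# ---------------------------------------------------------------------------
def _build_task_schedule(dataset_sizes, steps_per_epoch, seed=0):
    import random
    rng = random.Random(seed)
    total = sum(dataset_sizes)
    schedule = []
    for task_id, size in enumerate(dataset_sizes):
        # Allocate steps proportionally to how much data each task has.
        schedule.extend([task_id] * max(1, round(steps_per_epoch * size / total)))
    rng.shuffle(schedule)
    return schedule[:steps_per_epoch]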
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--device",
                        default=None,
                        type=str,
                        required=True,
                        help="The GPU device you will run on.")
    parser.add_argument(
        "--features_file",
        default=None,
        type=str,
        required=True,
        help="The train features file. Should contain the .csv files (after tokenization) for the task. "
        "Format: example_id,input_ids,input_mask,segment_ids,label\n")
    parser.add_argument(
        "--teacher_model",
        default=None,
        type=str,
        help="The teacher model dir. Should contain the config/vocab/checkpoint file.")
    parser.add_argument(
        "--general_student_model",
        default=None,
        type=str,
        required=True,
        help="The student model (after general distillation) dir. "
        "Should contain the config/vocab/checkpoint file.")
    parser.add_argument(
        "--output_student_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory for the task-specific distilled student models.")
    parser.add_argument("--cache_file_dir",
                        default='./cache',
                        type=str,
                        required=True,
                        help="The directory where the features are cached.")
    parser.add_argument(
        "--distill_model",
        default='simplified',
        type=str,
        help="The distillation model type; choose between 'standard' and 'simplified'.")
    parser.add_argument(
        "--max_seq_length",
        default=256,
        type=int,
        help="The maximum total input sequence length after WordPiece tokenization.")
    parser.add_argument("--do_lower_case",
                        action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=64,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument('--weight_decay',
                        '--wd',
                        default=1e-2,
                        type=float,
                        metavar='W',
                        help='weight decay')
    parser.add_argument("--num_train_epochs",
                        default=2,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--alpha",
        default=0.5,
        type=float,
        help="The weight of the soft loss in the standard KD method. "
        "Only used when '--distill_model' is set to 'standard'.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help="Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="Random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help="Number of update steps to accumulate before performing a backward/update pass.")
    parser.add_argument(
        '--train_loss_step',
        type=int,
        default=1000,
        help="How many training steps between recording the training loss.")
    parser.add_argument('--save_model_step',
                        type=int,
                        default=3000,
                        help="How many training steps between saving a student model.")
    parser.add_argument('--temperature',
                        type=float,
                        default=1.,
                        help="The temperature used in the soft loss.")
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit.")
    parser.add_argument(
        '--fp16_opt_level',
        type=str,
        default='O1',
        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. "
        "See details at https://nvidia.github.io/apex/amp.html")

    args = parser.parse_args()
    logger.info('The args: {}'.format(args))

    # Prepare device
    os.environ["CUDA_VISIBLE_DEVICES"] = args.device
    device = torch.device(
        "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    n_gpu = torch.cuda.device_count()
    logger.info("device: {} n_gpu: {}".format(device, n_gpu))

    # Prepare seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    # Prepare task settings
    if os.path.exists(args.output_student_dir) and os.listdir(args.output_student_dir):
        raise ValueError(
            "Output directory ({}) already exists and is not empty.".format(
                args.output_student_dir))
    if not os.path.exists(args.output_student_dir):
        os.makedirs(args.output_student_dir)
    if not os.path.exists(args.cache_file_dir):
        os.makedirs(args.cache_file_dir)

    # Used to save the vocab file for all output models.
    tokenizer = BertTokenizer.from_pretrained(args.general_student_model,
                                              do_lower_case=args.do_lower_case)

    # Model
    teacher_model = TinyBertForSequenceClassification.from_pretrained(
        args.teacher_model, num_labels=2)
    if args.fp16:
        teacher_model.half()
    teacher_model.to(device)
    student_model = TinyBertForSequenceClassification.from_pretrained(
        args.general_student_model, num_labels=2)
    student_model.to(device)

    # Train config
    num_examples, train_dataloader = distill_dataloader(
        args, RandomSampler, batch_size=args.train_batch_size)
    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
                args.gradient_accumulation_steps))
    num_train_optimization_steps = int(
        num_examples / args.train_batch_size /
        args.gradient_accumulation_steps) * args.num_train_epochs

    logger.info("***** Running Distilling *****")
    logger.info("  Num examples = %d", num_examples)
    logger.info("  Batch size = %d", args.train_batch_size)
    logger.info("  Num steps = %d", num_train_optimization_steps)

    # Prepare optimizer
    param_optimizer = list(student_model.named_parameters())
    size = 0
    for n, p in student_model.named_parameters():
        logger.info('n: {}'.format(n))
        size += p.nelement()
    logger.info('Total parameters of student_model: {}'.format(size))

    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay': args.weight_decay
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay': 0.0
    }]
    schedule = 'warmup_linear'
    optimizer = BertAdam(optimizer_grouped_parameters,
                         schedule=schedule,
                         lr=args.learning_rate,
                         warmup=args.warmup_proportion,
                         t_total=num_train_optimization_steps)

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        student_model, optimizer = amp.initialize(
            student_model, optimizer, opt_level=args.fp16_opt_level)
        logger.info('FP16 is activated, use amp')
    else:
        logger.info('FP16 is not activated, only use BertAdam')

    if n_gpu > 1:
        student_model = torch.nn.DataParallel(student_model)
        teacher_model = torch.nn.DataParallel(teacher_model)

    # Prepare loss functions
    loss_mse = MSELoss()

    def soft_cross_entropy(predicts, targets):
        student_likelihood = torch.nn.functional.log_softmax(predicts, dim=-1)
        targets_prob = torch.nn.functional.softmax(targets, dim=-1)
        return (-targets_prob * student_likelihood).mean()

    # Train
    global_step = 0
    output_loss_file = os.path.join(args.output_student_dir, "train_loss.txt")
    tr_loss = 0.
    tr_att_loss = 0.
    tr_rep_loss = 0.
    tr_cls_loss = 0.
    for epoch in trange(int(args.num_train_epochs), desc="Epoch"):
        student_model.train()
        for step, batch in enumerate(
                tqdm(train_dataloader, desc="Iteration", ascii=True)):
            batch = tuple(t.to(device) for t in batch)
            input_ids, input_mask, segment_ids, label_ids = batch
            # Skip the last incomplete batch.
            if input_ids.size()[0] != args.train_batch_size:
                continue

            student_logits, student_atts, student_reps = student_model(
                input_ids, segment_ids, input_mask, is_student=True)
            with torch.no_grad():
                teacher_logits, teacher_atts, teacher_reps = teacher_model(
                    input_ids, segment_ids, input_mask)

            soft_loss = soft_cross_entropy(student_logits / args.temperature,
                                           teacher_logits / args.temperature)
            hard_loss = torch.nn.functional.cross_entropy(student_logits,
                                                          label_ids,
                                                          reduction='mean')

            if args.distill_model == 'standard':
                cls_loss = args.alpha * soft_loss + (1 - args.alpha) * hard_loss
                tr_cls_loss += cls_loss.item()
                loss = cls_loss
            elif args.distill_model == 'simplified':
                teacher_layer_num = len(teacher_atts)
                student_layer_num = len(student_atts)
                assert teacher_layer_num % student_layer_num == 0
                layers_per_block = int(teacher_layer_num / student_layer_num)
                new_teacher_atts = [
                    teacher_atts[i * layers_per_block + layers_per_block - 1]
                    for i in range(student_layer_num)
                ]

                att_loss = 0.
                rep_loss = 0.
                # Attention loss: zero out the -inf padding positions before MSE.
                for student_att, teacher_att in zip(student_atts, new_teacher_atts):
                    student_att = torch.where(
                        student_att <= -1e2,
                        torch.zeros_like(student_att).to(device), student_att)
                    teacher_att = torch.where(
                        teacher_att <= -1e2,
                        torch.zeros_like(teacher_att).to(device), teacher_att)
                    tmp_loss = loss_mse(student_att, teacher_att)
                    att_loss += tmp_loss

                # Hidden-states loss (includes the embedding layer, hence +1).
                new_teacher_reps = [
                    teacher_reps[i * layers_per_block]
                    for i in range(student_layer_num + 1)
                ]
                new_student_reps = student_reps
                for student_rep, teacher_rep in zip(new_student_reps, new_teacher_reps):
                    tmp_loss = loss_mse(student_rep, teacher_rep)
                    rep_loss += tmp_loss

                tr_att_loss += att_loss.item()
                tr_rep_loss += rep_loss.item()

                # Classification loss
                cls_loss = soft_loss + hard_loss
                tr_cls_loss += cls_loss.item()

                # Total loss
                loss = rep_loss + att_loss + cls_loss
            else:
                raise NotImplementedError

            if n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu.
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()

            if (step + 1) % args.gradient_accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()
                global_step += 1

                if global_step % args.train_loss_step == 0:
                    loss = tr_loss / args.train_loss_step
                    cls_loss = tr_cls_loss / args.train_loss_step
                    att_loss = tr_att_loss / args.train_loss_step
                    rep_loss = tr_rep_loss / args.train_loss_step

                    loss_dict = {}
                    loss_dict['global_step'] = global_step
                    loss_dict['cls_loss'] = cls_loss
                    loss_dict['att_loss'] = att_loss
                    loss_dict['rep_loss'] = rep_loss
                    loss_dict['loss'] = loss
                    write_loss_to_file(loss_dict, output_loss_file)

                    tr_loss = 0.
                    tr_att_loss = 0.
                    tr_rep_loss = 0.
                    tr_cls_loss = 0.

                if global_step % args.save_model_step == 0:
                    logger.info("***** Save model *****")
                    model_to_save = student_model.module if hasattr(
                        student_model, 'module') else student_model
                    model_name = WEIGHTS_NAME
                    checkpoint_name = 'checkpoint-' + str(global_step)
                    # The original referenced the undefined `args.output_dir`;
                    # checkpoints belong under `args.output_student_dir`.
                    output_model_dir = os.path.join(args.output_student_dir,
                                                    checkpoint_name)
                    if not os.path.exists(output_model_dir):
                        os.makedirs(output_model_dir)
                    output_model_file = os.path.join(output_model_dir, model_name)
                    output_config_file = os.path.join(output_model_dir, CONFIG_NAME)
                    torch.save(model_to_save.state_dict(), output_model_file)
                    model_to_save.config.to_json_file(output_config_file)
                    tokenizer.save_vocabulary(output_model_dir)

    if os.path.exists(args.cache_file_dir):
        import shutil
        shutil.rmtree(args.cache_file_dir)
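

# ---------------------------------------------------------------------------
# Standalone sketch (not part of the distillation script): the two ideas the
# 'simplified' branch above relies on, shown on toy tensors. The layer counts
# (12-layer teacher, 4-layer student) are illustrative assumptions.
# ---------------------------------------------------------------------------
def _distillation_mapping_sketch(temperature=1.0):
    # Soft cross-entropy between student and teacher logits, as defined above.
    student_logits = torch.randn(3, 2)
    teacher_logits = torch.randn(3, 2)
    log_probs = torch.nn.functional.log_softmax(student_logits / temperature, dim=-1)
    teacher_probs = torch.nn.functional.softmax(teacher_logits / temperature, dim=-1)
    soft_loss = (-teacher_probs * log_probs).mean()

    # Layer mapping: with 12 teacher layers and 4 student layers, each student
    # attention layer i is matched to the last teacher layer of its block
    # (i * 3 + 2), while hidden states are taken every 3 layers starting from
    # the embedding output.
    teacher_layer_num, student_layer_num = 12, 4
    layers_per_block = teacher_layer_num // student_layer_num
    att_indices = [i * layers_per_block + layers_per_block - 1
                   for i in range(student_layer_num)]        # [2, 5, 8, 11]
    rep_indices = [i * layers_per_block
                   for i in range(student_layer_num + 1)]    # [0, 3, 6, 9, 12]
    return soft_loss, att_indices, rep_indices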