batch_data = next(train_iters[task_id]) batch_data = tuple(tmp.to(args.device) for tmp in batch_data) loss = model(batch_data, task_id, True) loss.backward() optimizer.step() optimizer.zero_grad() loss_avg.update(loss.item()) t.set_postfix(loss='{:5.4f}'.format(loss.item()), avg_loss='{:5.4f}'.format(loss_avg())) acc = eval(model, ner_dev_data, dev_data, dev_bl, graph, entity_linking, args) utils.save_checkpoint( { 'epoch': epoch + 1, 'state_dict': model.state_dict(), 'optim_dict': optimizer.state_dict() }, is_best=acc > best_acc, checkpoint=args.model_dir) best_acc = max(best_acc, acc) if args.do_eval: logging.info('num of batch for dev: {}'.format(len(dev_bl))) utils.load_checkpoint(os.path.join(args.model_dir, 'best.pth.tar'), model) eval(model, ner_dev_data, dev_data, dev_bl, graph, entity_linking, args) if args.do_predict: utils.load_checkpoint(os.path.join(args.model_dir, 'best.pth.tar'), model)
class MTDNNModel(object): def __init__(self, opt, state_dict=None, num_train_step=-1): self.config = opt self.updates = state_dict[ 'updates'] if state_dict and 'updates' in state_dict else 0 self.train_loss = AverageMeter() self.network = SANBertNetwork(opt) # pdb.set_trace() if state_dict: new_state = set(self.network.state_dict().keys()) # change to a safer approach old_keys = [k for k in state_dict['state'].keys()] for k in old_keys: if k not in new_state: print('deleting state:', k) del state_dict['state'][k] for k, v in list(self.network.state_dict().items()): if k not in state_dict['state']: print('adding missing state:', k) state_dict['state'][k] = v # pdb.set_trace() self.network.load_state_dict(state_dict['state']) self.mnetwork = nn.DataParallel( self.network) if opt['multi_gpu_on'] else self.network self.total_param = sum([ p.nelement() for p in self.network.parameters() if p.requires_grad ]) no_decay = [ 'bias', 'gamma', 'beta', 'LayerNorm.bias', 'LayerNorm.weight' ] optimizer_parameters = [{ 'params': [ p for n, p in self.network.named_parameters() if n not in no_decay ], 'weight_decay_rate': 0.01 }, { 'params': [p for n, p in self.network.named_parameters() if n in no_decay], 'weight_decay_rate': 0.0 }] # note that adamax are modified based on the BERT code if opt['optimizer'] == 'sgd': self.optimizer = optim.SGD(optimizer_parameters, opt['learning_rate'], weight_decay=opt['weight_decay']) elif opt['optimizer'] == 'adamax': self.optimizer = Adamax(optimizer_parameters, opt['learning_rate'], warmup=opt['warmup'], t_total=num_train_step, max_grad_norm=opt['grad_clipping'], schedule=opt['warmup_schedule']) if opt.get('have_lr_scheduler', False): opt['have_lr_scheduler'] = False elif opt['optimizer'] == 'adadelta': self.optimizer = optim.Adadelta(optimizer_parameters, opt['learning_rate'], rho=0.95) elif opt['optimizer'] == 'adam': self.optimizer = Adam(optimizer_parameters, lr=opt['learning_rate'], warmup=opt['warmup'], t_total=num_train_step, max_grad_norm=opt['grad_clipping'], schedule=opt['warmup_schedule']) if opt.get('have_lr_scheduler', False): opt['have_lr_scheduler'] = False else: raise RuntimeError('Unsupported optimizer: %s' % opt['optimizer']) if state_dict and 'optimizer' in state_dict: self.optimizer.load_state_dict(state_dict['optimizer']) if opt.get('have_lr_scheduler', False): if opt.get('scheduler_type', 'rop') == 'rop': self.scheduler = ReduceLROnPlateau(self.optimizer, mode='max', factor=opt['lr_gamma'], patience=3) elif opt.get('scheduler_type', 'rop') == 'exp': self.scheduler = ExponentialLR(self.optimizer, gamma=opt.get('lr_gamma', 0.95)) else: milestones = [ int(step) for step in opt.get('multi_step_lr', '10,20,30').split(',') ] self.scheduler = MultiStepLR(self.optimizer, milestones=milestones, gamma=opt.get('lr_gamma')) else: self.scheduler = None self.ema = None if opt['ema_opt'] > 0: self.ema = EMA(self.config['ema_gamma'], self.network) self.para_swapped = False def setup_ema(self): if self.config['ema_opt']: self.ema.setup() def update_ema(self): if self.config['ema_opt']: self.ema.update() def eval(self): if self.config['ema_opt']: self.ema.swap_parameters() self.para_swapped = True def train(self): if self.para_swapped: self.ema.swap_parameters() self.para_swapped = False def update(self, batch_meta, batch_data): self.network.train() labels = batch_data[batch_meta['label']] # print('data size:',batch_data[batch_meta['token_id']].size()) if batch_meta['pairwise']: labels = labels.contiguous().view(-1, batch_meta['pairwise_size'])[:, 0] if self.config['cuda']: y = Variable(labels.cuda(async=True), requires_grad=False) else: y = Variable(labels, requires_grad=False) task_id = batch_meta['task_id'] task_type = batch_meta['task_type'] inputs = batch_data[:batch_meta['input_len']] if len(inputs) == 3: inputs.append(None) inputs.append(None) inputs.append(task_id) # pdb.set_trace() logits = self.mnetwork(*inputs) if batch_meta['pairwise']: logits = logits.view(-1, batch_meta['pairwise_size']) # pdb.set_trace() if task_type > 0: if self.config['answer_relu']: logits = F.relu(logits) loss = F.mse_loss(logits.squeeze(1), y) else: loss = F.cross_entropy(logits, y) if self.config['mediqa_pairloss'] is not None and batch_meta[ 'dataset_name'] in mediqa_name_list: # print(logits) # print(batch_data[batch_meta['rank_label']].size()) # input('ha') logits = logits.squeeze().view(-1, 2) # print(batch_data[batch_meta['rank_label']]) rank_y = batch_data[batch_meta['rank_label']].view(-1, 2) # print(rank_y) if self.config['mediqa_pairloss'] == 'hinge': # print(logits) first_logit, second_logit = logits.split(1, dim=1) # print(first_logit,second_logit) # pdb.set_trace() rank_y = (2 * rank_y - 1).to(torch.float32) rank_y = rank_y[:, 0] pairwise_loss = F.margin_ranking_loss( first_logit.squeeze(1), second_logit.squeeze(1), rank_y, margin=self.config['hinge_lambda']) else: # pdb.set_trace() pairwise_loss = F.cross_entropy(logits, rank_y[:, 1]) # print('pairwise_loss:',pairwise_loss,'mse loss:',loss) loss += pairwise_loss self.train_loss.update(loss.item(), logits.size(0)) self.optimizer.zero_grad() loss.backward() if self.config['global_grad_clipping'] > 0: torch.nn.utils.clip_grad_norm_(self.network.parameters(), self.config['global_grad_clipping']) self.optimizer.step() self.updates += 1 self.update_ema() def predict(self, batch_meta, batch_data): self.network.eval() task_id = batch_meta['task_id'] task_type = batch_meta['task_type'] inputs = batch_data[:batch_meta['input_len']] if len(inputs) == 3: inputs.append(None) inputs.append(None) inputs.append(task_id) score = self.mnetwork(*inputs) gold_label = batch_meta['label'] if batch_meta['pairwise']: score = score.contiguous().view(-1, batch_meta['pairwise_size']) if task_type < 1: score = F.softmax(score, dim=1) score = score.data.cpu() score = score.numpy() predict = np.zeros(score.shape, dtype=int) if task_type < 1: positive = np.argmax(score, axis=1) for idx, pos in enumerate(positive): predict[idx, pos] = 1 predict = predict.reshape(-1).tolist() score = score.reshape(-1).tolist() return score, predict, batch_meta['true_label'] else: if task_type < 1: score = F.softmax(score, dim=1) # pdb.set_trace() score = score.data.cpu() score = score.numpy() if task_type < 1: predict = np.argmax(score, axis=1).tolist() else: predict = np.greater( score, 2.0 + self.config['mediqa_score_offset']).astype(int) gold_label = np.greater( batch_meta['label'], 2.00001 + self.config['mediqa_score_offset']).astype(int) predict = predict.reshape(-1).tolist() gold_label = gold_label.reshape(-1).tolist() # print('predict:',predict,score) score = score.reshape(-1).tolist() return score, predict, gold_label def save(self, filename): network_state = dict([(k, v.cpu()) for k, v in self.network.state_dict().items()]) ema_state = dict([ (k, v.cpu()) for k, v in self.ema.model.state_dict().items() ]) if self.ema is not None else dict() params = { 'state': network_state, 'optimizer': self.optimizer.state_dict(), 'ema': ema_state, 'config': self.config, } torch.save(params, filename) logger.info('model saved to {}'.format(filename)) def cuda(self): self.network.cuda() if self.config['ema_opt']: self.ema.cuda()
def main(*_, **kwargs): use_cuda = torch.cuda.is_available() and kwargs["device"] >= 0 device = torch.device("cuda:" + str(kwargs["device"]) if use_cuda else "cpu") if use_cuda: torch.cuda.set_device(device) kwargs["use_cuda"] = use_cuda neptune.create_experiment( name="bert-span-parser", upload_source_files=[], params={ k: str(v) if isinstance(v, bool) else v for k, v in kwargs.items() }, ) logger.info("Settings: {}", json.dumps(kwargs, indent=2, ensure_ascii=False)) # For reproducibility os.environ["PYTHONHASHSEED"] = str(kwargs["seed"]) random.seed(kwargs["seed"]) np.random.seed(kwargs["seed"]) torch.manual_seed(kwargs["seed"]) torch.cuda.manual_seed_all(kwargs["seed"]) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False # Prepare and load data tokenizer = BertTokenizer.from_pretrained(kwargs["bert_model"], do_lower_case=False) logger.info("Loading data...") train_treebank = load_trees(kwargs["train_file"]) dev_treebank = load_trees(kwargs["dev_file"]) test_treebank = load_trees(kwargs["test_file"]) logger.info( "Loaded {:,} train, {:,} dev, and {:,} test examples!", len(train_treebank), len(dev_treebank), len(test_treebank), ) logger.info("Preprocessing data...") train_parse = [tree.convert() for tree in train_treebank] train_sentences = [[(leaf.tag, leaf.word) for leaf in tree.leaves()] for tree in train_parse] dev_sentences = [[(leaf.tag, leaf.word) for leaf in tree.leaves()] for tree in dev_treebank] test_sentences = [[(leaf.tag, leaf.word) for leaf in tree.leaves()] for tree in test_treebank] logger.info("Data preprocessed!") logger.info("Preparing data for training...") tags = [] labels = [] for tree in train_parse: nodes = [tree] while nodes: node = nodes.pop() if isinstance(node, InternalParseNode): labels.append(node.label) nodes.extend(reversed(node.children)) else: tags.append(node.tag) tag_encoder = LabelEncoder() tag_encoder.fit(tags, reserved_labels=["[PAD]", "[UNK]"]) label_encoder = LabelEncoder() label_encoder.fit(labels, reserved_labels=[()]) logger.info("Data prepared!") # Settings num_train_optimization_steps = kwargs["num_epochs"] * ( (len(train_parse) - 1) // kwargs["batch_size"] + 1) kwargs["batch_size"] //= kwargs["gradient_accumulation_steps"] logger.info("Creating dataloaders for training...") train_dataloader, train_features = create_dataloader( sentences=train_sentences, batch_size=kwargs["batch_size"], tag_encoder=tag_encoder, tokenizer=tokenizer, is_eval=False, ) dev_dataloader, dev_features = create_dataloader( sentences=dev_sentences, batch_size=kwargs["batch_size"], tag_encoder=tag_encoder, tokenizer=tokenizer, is_eval=True, ) test_dataloader, test_features = create_dataloader( sentences=test_sentences, batch_size=kwargs["batch_size"], tag_encoder=tag_encoder, tokenizer=tokenizer, is_eval=True, ) logger.info("Dataloaders created!") # Initialize model model = ChartParser.from_pretrained( kwargs["bert_model"], tag_encoder=tag_encoder, label_encoder=label_encoder, lstm_layers=kwargs["lstm_layers"], lstm_dim=kwargs["lstm_dim"], tag_embedding_dim=kwargs["tag_embedding_dim"], label_hidden_dim=kwargs["label_hidden_dim"], dropout_prob=kwargs["dropout_prob"], ) model.to(device) # Prepare optimizer param_optimizers = list(model.named_parameters()) if kwargs["freeze_bert"]: for p in model.bert.parameters(): p.requires_grad = False param_optimizers = [(n, p) for n, p in param_optimizers if p.requires_grad] # Hack to remove pooler, which is not used thus it produce None grad that break apex param_optimizers = [n for n in param_optimizers if "pooler" not in n[0]] no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"] optimizer_grouped_parameters = [ { "params": [ p for n, p in param_optimizers if not any(nd in n for nd in no_decay) ], "weight_decay": 0.01, }, { "params": [ p for n, p in param_optimizers if any(nd in n for nd in no_decay) ], "weight_decay": 0.0, }, ] optimizer = BertAdam( optimizer_grouped_parameters, lr=kwargs["learning_rate"], warmup=kwargs["warmup_proportion"], t_total=num_train_optimization_steps, ) if kwargs["fp16"]: model, optimizer = amp.initialize(model, optimizer, opt_level="O1") pretrained_model_file = os.path.join(kwargs["output_dir"], MODEL_FILENAME) if kwargs["do_eval"]: assert os.path.isfile( pretrained_model_file), "Pretrained model file does not exist!" logger.info("Loading pretrained model from {}", pretrained_model_file) # Load model from file params = torch.load(pretrained_model_file, map_location=device) model.load_state_dict(params["model"]) logger.info( "Loaded pretrained model (Epoch: {:,}, Fscore: {:.2f})", params["epoch"], params["fscore"], ) eval_score = eval( model=model, eval_dataloader=test_dataloader, eval_features=test_features, eval_trees=test_treebank, eval_sentences=test_sentences, tag_encoder=tag_encoder, device=device, ) neptune.send_metric("test_eval_precision", eval_score.precision()) neptune.send_metric("test_eval_recall", eval_score.recall()) neptune.send_metric("test_eval_fscore", eval_score.fscore()) tqdm.write("Evaluation score: {}".format(str(eval_score))) else: # Training phase global_steps = 0 start_epoch = 0 best_dev_fscore = 0 if kwargs["preload"] or kwargs["resume"]: assert os.path.isfile( pretrained_model_file), "Pretrained model file does not exist!" logger.info("Resuming model from {}", pretrained_model_file) # Load model from file params = torch.load(pretrained_model_file, map_location=device) model.load_state_dict(params["model"]) if kwargs["resume"]: optimizer.load_state_dict(params["optimizer"]) torch.cuda.set_rng_state_all([ state.cpu() for state in params["torch_cuda_random_state_all"] ]) torch.set_rng_state(params["torch_random_state"].cpu()) np.random.set_state(params["np_random_state"]) random.setstate(params["random_state"]) global_steps = params["global_steps"] start_epoch = params["epoch"] + 1 best_dev_fscore = params["fscore"] else: assert not os.path.isfile( pretrained_model_file ), "Please remove or move the pretrained model file to another place!" for epoch in trange(start_epoch, kwargs["num_epochs"], desc="Epoch"): model.train() train_loss = 0 num_train_steps = 0 for step, (indices, *_) in enumerate( tqdm(train_dataloader, desc="Iteration")): ids, attention_masks, tags, sections, trees, sentences = prepare_batch_input( indices=indices, features=train_features, trees=train_parse, sentences=train_sentences, tag_encoder=tag_encoder, device=device, ) loss = model( ids=ids, attention_masks=attention_masks, tags=tags, sections=sections, sentences=sentences, gold_trees=trees, ) if kwargs["gradient_accumulation_steps"] > 1: loss /= kwargs["gradient_accumulation_steps"] if kwargs["fp16"]: with amp.scale_loss(loss, optimizer) as scaled_loss: scaled_loss.backward() else: loss.backward() train_loss += loss.item() num_train_steps += 1 if (step + 1) % kwargs["gradient_accumulation_steps"] == 0: optimizer.step() optimizer.zero_grad() global_steps += 1 # Write logs neptune.send_metric("train_loss", epoch, train_loss / num_train_steps) neptune.send_metric("global_steps", epoch, global_steps) tqdm.write( "Epoch: {:,} - Train loss: {:.4f} - Global steps: {:,}".format( epoch, train_loss / num_train_steps, global_steps)) # Evaluate eval_score = eval( model=model, eval_dataloader=dev_dataloader, eval_features=dev_features, eval_trees=dev_treebank, eval_sentences=dev_sentences, tag_encoder=tag_encoder, device=device, ) neptune.send_metric("eval_precision", epoch, eval_score.precision()) neptune.send_metric("eval_recall", epoch, eval_score.recall()) neptune.send_metric("eval_fscore", epoch, eval_score.fscore()) tqdm.write("Epoch: {:,} - Evaluation score: {}".format( epoch, str(eval_score))) # Save best model if eval_score.fscore() > best_dev_fscore: best_dev_fscore = eval_score.fscore() tqdm.write("** Saving model...") os.makedirs(kwargs["output_dir"], exist_ok=True) torch.save( { "epoch": epoch, "global_steps": global_steps, "fscore": best_dev_fscore, "random_state": random.getstate(), "np_random_state": np.random.get_state(), "torch_random_state": torch.get_rng_state(), "torch_cuda_random_state_all": torch.cuda.get_rng_state_all(), "optimizer": optimizer.state_dict(), "model": (model.module if hasattr(model, "module") else model).state_dict(), }, pretrained_model_file, ) tqdm.write( "** Best evaluation fscore: {:.2f}".format(best_dev_fscore))