Example #1
    def __init__(self):
        self.args = parse_args()
        if self.args['mlm']:
            self.bert_model = BertForMaskedLM.from_pretrained(
                pretrained_bert_name, output_hidden_states=True)
        else:
            self.bert_model = BertForPreTraining.from_pretrained(
                pretrained_bert_name, output_hidden_states=True)

        #print(self.bert_model)
        self.bert_tokenizer = BertTokenizer.from_pretrained(
            pretrained_bert_name)
        self.bert_out_dim = self.bert_model.bert.encoder.layer[
            11].output.dense.out_features
        self.args['bert_output_dim'] = self.bert_out_dim
        print("BERT output dim {}".format(self.bert_out_dim))

        self.ner_path = self.args['ner_train_file']
        self.ner_reader = DataReader(self.ner_path,
                                     "NER",
                                     tokenizer=self.bert_tokenizer,
                                     batch_size=30)
        self.args['ner_label_vocab'] = self.ner_reader.label_voc
        self.ner_head = NerModel(self.args)

        param_optimizer = list(self.bert_model.named_parameters())
        no_decay = ['bias', 'gamma', 'beta']
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay_rate': 0.001
        }, {
            'params':
            [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            'weight_decay_rate': 0.0
        }]

        self.qas_head = QasModel(self.args)

        self.bert_optimizer = AdamW(optimizer_grouped_parameters, lr=2e-5)
        self.bert_scheduler = get_linear_schedule_with_warmup(
            self.bert_optimizer,
            num_warmup_steps=self.args['warmup_steps'],
            num_training_steps=self.args['t_total'])
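A minimal sketch of how the warmup_steps and t_total arguments consumed above could be derived; the dataset size, epoch count, and 10% warmup ratio below are assumptions for illustration, not values taken from the example.

# Hypothetical pre-computation of the scheduler arguments used in Example #1.
num_training_examples = 10000          # assumed dataset size
batch_size = 30                        # matches the DataReader batch size above
num_train_epochs = 3                   # assumed
gradient_accumulation_steps = 1        # assumed
steps_per_epoch = num_training_examples // batch_size
t_total = steps_per_epoch // gradient_accumulation_steps * num_train_epochs
args = {'t_total': t_total,
        'warmup_steps': int(0.1 * t_total)}  # warm up over the first 10% of updates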
Example #2
def inference(model_path, class_dict_path, args):
    with open(class_dict_path, "rb") as js:
        class_dict = json.load(js)

    save_folder = args.save_folder
    eval_save_path = os.path.join(save_folder, "conll_dev_out.txt")
    test_file_path = args.test_file_path

    model_tuple = (BertForTokenClassification, BertTokenizer,
                   "dmis-lab/biobert-v1.1", "BioBERT")
    model_name = model_tuple[2]
    tokenizer = BertTokenizer.from_pretrained(model_name)
    batch_size = args.batch_size

    eval_ner_dataset = NerDataset(test_file_path)
    eval_ner_dataset.label_vocab.set_w2ind(class_dict)
    test_dataset_loader = NerDatasetLoader(eval_ner_dataset,
                                           tokenizer,
                                           batch_size=batch_size)

    num_classes = len(class_dict)
    args.output_dim = num_classes

    model = NerModel(args, model_tuple)
    print("Loading weights from {}".format(model_path))
    load_model(model, model_path)
    model.to(device)
    pre, rec, f1, total_loss = evaluate(model, test_dataset_loader,
                                        eval_save_path)

    # Save result
    result_save_path = os.path.join(save_folder, "results.json")
    result = {
        "model_path": model_path,
        "precision": pre,
        "recall": rec,
        "test_loss": total_loss,
        "f1": f1
    }
    with open(result_save_path, "w") as j:
        json.dump(result, j)
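A hedged usage sketch for the inference helper above; the paths and the SimpleNamespace argument object are placeholders invented for illustration, and only the attributes this snippet actually reads are set.

# Hypothetical call to inference() with illustrative paths.
from types import SimpleNamespace

args = SimpleNamespace(
    save_folder="outputs/biobert_ner",       # conll_dev_out.txt and results.json go here
    test_file_path="data/ent_test.tsv",      # CoNLL-style file consumed by NerDataset
    batch_size=16)
inference("outputs/biobert_ner/best_model.pt",
          "outputs/biobert_ner/class_to_idx.json",
          args)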
Example #3
def train_ner(self):
    self.ner_reader = DataReader(self.ner_path,
                                 "NER",
                                 tokenizer=self.bert_tokenizer,
                                 batch_size=30)
    self.args['ner_label_vocab'] = self.ner_reader.label_voc
    self.ner_head = NerModel(self.args)
    print("Starting training")
    for j in range(10):
        for i in range(10):
            self.ner_head.optimizer.zero_grad()
            tokens, bert_batch_after_padding, data = self.ner_reader[0]
            sent_lens, masks, tok_inds, ner_inds, \
                bert_batch_ids, bert_seq_ids, bert2toks, cap_inds = data
            outputs = self.bert_model(bert_batch_ids,
                                      token_type_ids=bert_seq_ids)
            #bert_hiddens = self._get_bert_batch_hidden(outputs[-1],bert2toks)
            #loss, out_logits =  self.ner_head(bert_hiddens,ner_inds)
            loss, out_logits = self.get_ner(outputs[-1], bert2toks,
                                            ner_inds)
            #print(loss.item())
            loss.backward()
            self.ner_head.optimizer.step()
        self.eval_ner()
Example #4
    #dataset = {}
    #dataset['labels_list'] = get_labels_list("dataset/{}/tag.dict".format(modelname))
    #print("labels_list: {}".format(dataset['labels_list']))

    if mode == "train":
        dataset = NerDataset([modelname]).as_dict()
    else:
        dataset = NerDataset([modelname], labels_only=True).as_dict()

    if mode == "train":
        #dataset['train'] = NerPartDataset("dataset/{}/train.txt".format(modelname)).to_dataframe()
        #dataset['val'] = NerPartDataset("dataset/{}/valid.txt".format(modelname)).to_dataframe()
        #dataset['test'] = NerPartDataset("dataset/{}/test.txt".format(modelname)).to_dataframe()
        #dataset = NerDataset([modelname]).as_dict()
        model = NerModel(modelname=modelname, dataset=dataset)
        model.train()
        model.eval()

    if mode in {"infer"}:
        model = NerModel(modelname=modelname,
                         dataset=dataset,
                         use_saved_model=True)

    if mode in {"train", "infer"}:
        result = model.predict(test_sentences)
        print("result:", result)

    print("\nManually input")
    while True:
        input_text = input("Input text: ")
Example #5
        "Enter in username box after clicking on Submit button",
        "Enter in abstract before clicking on Submit button",
    ]

    if mode == "train":
        dataset = NerDataset(complex_dataset_names).as_dict()
    else:
        dataset = NerDataset(complex_dataset_names).as_dict()
        #dataset = NerDataset(complex_dataset_names, labels_only=True).as_dict()
        #dataset['labels_list'] = get_labels_list("dataset/{}/tag.dict".format(modelname))

    if mode == "train":
        if pretrained_type == "LM":
            model = NerModel(
                modelname=modelname,
                dataset=dataset,
                input_dir="lm_outputs_test/from_scratch/best_model",
                output_dir=output_dir)
        elif pretrained_type == "continue":
            model = NerModel(modelname=modelname,
                             dataset=dataset,
                             input_dir=output_dir,
                             output_dir=output_dir)
        elif pretrained_type == "English":
            model = NerModel(modelname=modelname,
                             dataset=dataset,
                             output_dir=output_dir)

        training_details = model.train()

    elif mode in {"test", "infer"}:
Example #6
def train(args):
    #     biobert_model_tuple = MODELS[-1]
    # model_tuple = (BertModel, BertTokenizer, "bert-base-uncased", "Bert-base")
    model_tuple = (BertForTokenClassification, BertTokenizer,
                   "dmis-lab/biobert-v1.1", "BioBERT")
    dataset_loaders = {}

    save_folder = args.save_folder
    if not os.path.isdir(save_folder): os.makedirs(save_folder)
    size = args.size
    batch_size = args.batch_size

    target_dataset_path = args.target_dataset_path

    dev_file_path = os.path.join(target_dataset_path, "ent_devel.tsv")
    test_file_path = os.path.join(target_dataset_path, "ent_test.tsv")
    train_file_path = args.train_file_path if not args.dev_only else dev_file_path

    target_dataset = os.path.split(target_dataset_path)[-1]
    train_dataset_name = os.path.split(os.path.split(train_file_path)[0])[-1]

    print("Target dataset: {}\nTrain {} dev {} test {}...\n".format(
        target_dataset, train_file_path, dev_file_path, test_file_path))
    model_name = model_tuple[2]
    tokenizer = BertTokenizer.from_pretrained(model_name)

    ner_dataset = NerDataset(train_file_path, size=size)
    dataset_loader = NerDatasetLoader(ner_dataset,
                                      tokenizer,
                                      batch_size=batch_size)
    dataset_loaders["train"] = dataset_loader
    num_classes = len(dataset_loader.dataset.label_vocab)
    args.output_dim = num_classes
    print("Label vocab: {}".format(ner_dataset.label_vocab.w2ind))
    eval_ner_dataset = NerDataset(dev_file_path, size=size)
    eval_ner_dataset.label_vocab = ner_dataset.label_vocab
    eval_ner_dataset.token_vocab = ner_dataset.token_vocab
    eval_dataset_loader = NerDatasetLoader(eval_ner_dataset,
                                           tokenizer,
                                           batch_size=batch_size)
    dataset_loaders["devel"] = eval_dataset_loader

    test_ner_dataset = NerDataset(test_file_path, size=size)
    test_ner_dataset.label_vocab = ner_dataset.label_vocab
    test_ner_dataset.token_vocab = ner_dataset.token_vocab
    test_dataset_loader = NerDatasetLoader(test_ner_dataset,
                                           tokenizer,
                                           batch_size=batch_size)
    dataset_loaders["test"] = test_dataset_loader

    model = NerModel(args, model_tuple)
    if args.load_dc_model:
        load_path = args.dc_model_weight_path
        print("Loading model from : {}".format(load_path))
        dc_weights = torch.load(load_path)
        model = load_weights_with_skip(model, dc_weights)
    trained_model, train_result, class_to_idx = train_model(
        model, dataset_loaders, save_folder, args)

    # Plot train/dev losses
    plot_save_path = os.path.join(save_folder, "loss_plot.png")
    plot_arrays([train_result["train_losses"], train_result["dev_losses"]],
                ["train", "dev"], "epochs", 'loss', plot_save_path)

    # Evaluate on test_set
    save_path = os.path.join(save_folder, "conll_testout.txt")
    test_pre, test_rec, test_f1, test_loss = evaluate(trained_model,
                                                      dataset_loaders["test"],
                                                      save_path)

    # Save result
    result_save_path = os.path.join(save_folder, "results.json")
    result = {
        "model_name": model_name,
        "train_size": len(ner_dataset),
        "target_dataset": target_dataset,
        "precision": test_pre,
        "recall": test_rec,
        "test_loss": test_loss,
        "train_dataset_name": train_dataset_name,
        "f1": test_f1,
        "train_result": train_result
    }
    with open(result_save_path, "w") as j:
        json.dump(result, j)

    class_to_idx_path = os.path.join(save_folder, "class_to_idx.json")
    with open(class_to_idx_path, "w") as j:
        json.dump(class_to_idx, j)
    return result
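As with the previous sketch, the call below is only illustrative: the directory layout and hyperparameter values are assumptions, and only the attributes this snippet reads are populated (train_model may require further fields).

# Hypothetical call to train() with illustrative paths and sizes.
from types import SimpleNamespace

args = SimpleNamespace(
    save_folder="outputs/bc2gm_run",
    target_dataset_path="datasets/BC2GM",            # must contain ent_devel.tsv / ent_test.tsv
    train_file_path="datasets/BC2GM/ent_train.tsv",
    dev_only=False,
    size=None,                                       # use the full training file
    batch_size=16,
    load_dc_model=False)                             # skip the auxiliary weight loading branch
result = train(args)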
Example #7
class BioMLT:
    def __init__(self):
        self.args = parse_args()
        if self.args['mlm']:
            self.bert_model = BertForMaskedLM.from_pretrained(
                pretrained_bert_name, output_hidden_states=True)
        else:
            self.bert_model = BertForPreTraining.from_pretrained(
                pretrained_bert_name, output_hidden_states=True)

        #print(self.bert_model)
        self.bert_tokenizer = BertTokenizer.from_pretrained(
            pretrained_bert_name)
        self.bert_out_dim = self.bert_model.bert.encoder.layer[
            11].output.dense.out_features
        self.args['bert_output_dim'] = self.bert_out_dim
        print("BERT output dim {}".format(self.bert_out_dim))

        self.ner_path = self.args['ner_train_file']
        self.ner_reader = DataReader(self.ner_path,
                                     "NER",
                                     tokenizer=self.bert_tokenizer,
                                     batch_size=30)
        self.args['ner_label_vocab'] = self.ner_reader.label_voc
        self.ner_head = NerModel(self.args)

        param_optimizer = list(self.bert_model.named_parameters())
        no_decay = ['bias', 'gamma', 'beta']
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay_rate': 0.001
        }, {
            'params':
            [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            'weight_decay_rate': 0.0
        }]

        self.qas_head = QasModel(self.args)

        self.bert_optimizer = AdamW(optimizer_grouped_parameters, lr=2e-5)
        self.bert_scheduler = get_linear_schedule_with_warmup(
            self.bert_optimizer,
            num_warmup_steps=self.args['warmup_steps'],
            num_training_steps=self.args['t_total'])

    ## We are now averaging over the bert layer outputs for the NER task
    ## We may want to do this for QAS as well?
    ## This is very slow, right?
    def _get_bert_batch_hidden(self, hiddens, bert2toks, layers=[-2, -3, -4]):
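        # hiddens: the per-layer outputs returned by BERT (each [batch, seq_len, hidden]);
        # bert2toks: for every WordPiece position, the index of the original word it came from.
        # Average the selected layers, then collapse the WordPiece-level vectors back to one
        # vector per original word so the NER head receives word-aligned features.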
        meanss = torch.mean(torch.stack([hiddens[i] for i in layers]), 0)
        batch_my_hiddens = []
        for means, bert2tok in zip(meanss, bert2toks):
            my_token_hids = []
            my_hiddens = []
            for i, b2t in enumerate(bert2tok):
                if i > 0 and b2t != bert2tok[i - 1]:
                    my_hiddens.append(torch.mean(torch.stack(my_token_hids),
                                                 0))
                    #my_token_hids = [means[i+1]]  ## we skip the CLS token
                    my_token_hids = [means[i]]  ## now we don't skip the CLS token
                else:
                    #my_token_hids.append(means[i+1])
                    my_token_hids = [means[i]]  ## now we don't skip the CLS token
            my_hiddens.append(torch.mean(torch.stack(my_token_hids), 0))
            sent_hiddens = torch.stack(my_hiddens)
            batch_my_hiddens.append(sent_hiddens)
        return torch.stack(batch_my_hiddens)

    def _get_token_to_bert_predictions(self, predictions, bert2toks):
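        # Inverse of the pooling above: replicate each word-level prediction onto all of its
        # WordPiece positions (via bert2toks) so downstream heads see BERT-length sequences.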
        #logging.info("Predictions shape {}".format(predictions.shape))

        #logging.info("Bert2toks shape {}".format(bert2toks.shape))
        bert_predictions = []
        for pred, b2t in zip(predictions, bert2toks):
            bert_preds = []
            for b in b2t:
                bert_preds.append(pred[b])
            stack = torch.stack(bert_preds)
            bert_predictions.append(stack)
        stackk = torch.stack(bert_predictions)
        return stackk

    def _get_squad_bert_batch_hidden(self, hiddens, layers=[-2, -3, -4]):
        return torch.mean(torch.stack([hiddens[i] for i in layers]), 0)

    def _get_squad_to_ner_bert_batch_hidden(self,
                                            hiddens,
                                            bert2toks,
                                            layers=[-2, -3, -4],
                                            device='cpu'):
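        # Same word-level pooling as _get_bert_batch_hidden, but each sentence is padded back
        # to the BERT sequence length with zero vectors so the batch can be stacked; the true
        # (unpadded) lengths are returned alongside the pooled hidden states.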
        pad_size = hiddens[-1].shape[1]
        hidden_dim = hiddens[-1].shape[2]
        pad_vector = torch.tensor([0.0 for i in range(hidden_dim)]).to(device)
        print(pad_vector.device)
        meanss = torch.mean(torch.stack([hiddens[i] for i in layers]), 0)
        batch_my_hiddens = []
        batch_lens = []
        for means, bert2tok in zip(meanss, bert2toks):
            my_token_hids = []
            my_hiddens = []
            for i, b2t in enumerate(bert2tok):
                if i > 0 and b2t != bert2tok[i - 1]:
                    my_hiddens.append(torch.mean(torch.stack(my_token_hids),
                                                 0))
                    #my_token_hids = [means[i+1]]  ## we skip the CLS token
                    my_token_hids = [means[i]]  ## now we don't skip the CLS token
                else:
                    #my_token_hids.append(means[i+1])
                    my_token_hids = [means[i]]  ## now we don't skip the CLS token
            my_hiddens.append(torch.mean(torch.stack(my_token_hids), 0))
            batch_lens.append(len(my_hiddens))
            for i in range(pad_size - len(my_hiddens)):
                my_hiddens.append(pad_vector)
            sent_hiddens = torch.stack(my_hiddens)
            batch_my_hiddens.append(sent_hiddens)
        #for sent_hidden in batch_my_hiddens:
        #logging.info("Squad squeezed sent shape {}".format(sent_hidden.shape))
        return torch.stack(batch_my_hiddens), torch.tensor(batch_lens)

    def load_model(self):
        if self.args['mlm']:
            logging.info("Attempting to load  model from {}".format(
                self.args['output_dir']))
            self.bert_model = BertForMaskedLM.from_pretrained(
                self.args['output_dir'])
        else:
            self.bert_model = BertForPreTraining.from_pretrained(
                self.args['output_dir'])
        self.bert_tokenizer = BertTokenizer.from_pretrained(
            self.args['output_dir'])
        sch_path = os.path.join(self.args['output_dir'], "scheduler.pt")
        opt_path = os.path.join(self.args['output_dir'], "optimizer.pt")
        if os.path.isfile(sch_path) and os.path.isfile(opt_path):
            self.bert_optimizer.load_state_dict(torch.load(opt_path))
            self.bert_scheduler.load_state_dict(torch.load(sch_path))
        else:
            logging.info("Could not load optimizer/scheduler state from {}".format(
                self.args['output_dir']))
            logging.info(
                "Initializing Masked LM from {}".format(pretrained_bert_name))
        #self.bert_model = BertForMaskedLM.from_pretrained(pretrained_bert_name)
        #self.bert_model = BertForPreTraining.from_pretrained(pretrained_bert_name)

    def save_model(self):
        out_dir = self.args['output_dir']
        if not os.path.isdir(out_dir):
            os.makedirs(out_dir)
        print("Saving model checkpoint to {}".format(out_dir))
        logger.info("Saving model checkpoint to {}".format(out_dir))
        # Save a trained model, configuration and tokenizer using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        model_to_save = self.bert_model
        model_to_save.save_pretrained(out_dir)
        self.bert_tokenizer.save_pretrained(out_dir)
        # Good practice: save your training arguments together with the trained model
        torch.save(self.args, os.path.join(out_dir, "training_args.bin"))
        torch.save(self.bert_optimizer.state_dict(),
                   os.path.join(out_dir, "optimizer.pt"))
        torch.save(self.bert_scheduler.state_dict(),
                   os.path.join(out_dir, "scheduler.pt"))

    def predict_qas(self, batch):
        ## batch_size = 1
        if len(batch[0].shape) == 1:
            batch = tuple(t.unsqueeze_(0) for t in batch)

        ## Not sure if these update in-place; have to check between PyTorch versions
        self.bert_model.eval()
        self.qas_head.eval()
        squad_inputs = {
            "input_ids": batch[0],
            "attention_mask": batch[1],
            "token_type_ids": batch[2],
            "start_positions": batch[3],
            "end_positions": batch[4],
        }
        bert_inputs = {
            "input_ids": batch[0],
            "attention_mask": batch[1],
            "token_type_ids": batch[2],
        }
        with torch.no_grad():
            outputs = self.bert_model(**bert_inputs)
            squad_inputs["bert_outputs"] = outputs[-1][-2]
            start_pred, end_pred = self.qas_head.predict(**squad_inputs)
            length = torch.sum(batch[1])
            tokens = self.bert_tokenizer.convert_ids_to_tokens(
                batch[0].squeeze(0).detach().cpu().numpy()[:length])
        logging.info("Example {}".format(tokens))
        logging.info("Answer {}".format(tokens[start_pred:end_pred + 1]))
        logging.info("Start Pred {}  start truth {}".format(
            start_pred, squad_inputs["start_positions"]))
        logging.info("End Pred {}  end truth {}".format(
            end_pred, squad_inputs["end_positions"]))

    def get_qas(self, bert_output, batch):

        #batch = tuple(t.unsqueeze_(0) for t in batch)
        squad_inputs = {
            "input_ids": batch[0],
            "attention_mask": batch[1],
            "token_type_ids": batch[2],
            "start_positions": batch[3],
            "end_positions": batch[4],
        }
        squad_inputs["bert_outputs"] = bert_output
        qas_outputs = self.qas_head(**squad_inputs)
        #print(qas_outputs[0].item())
        #qas_outputs[0].backward()
        #self.bert_optimizer.step()
        #self.qas_head.optimizer.step()
        return qas_outputs

    def get_ner(self, bert_output, bert2toks, ner_inds):
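        # Pool BERT's WordPiece hiddens down to word level, then let the NER head compute the
        # loss and emission logits against the gold label indices in ner_inds.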
        bert_hiddens = self._get_bert_batch_hidden(bert_output, bert2toks)
        loss, out_logits = self.ner_head(bert_hiddens, ner_inds)
        #logging.info("NER loss {} ".format(loss.item()))
        return (loss, out_logits)

    ## training a flat model (multi-task learning hard-sharing)
    def train_qas_ner(self):
        device = self.args['device']
        self.device = device
        args = hugging_parse_args()
        qas_train_dataset = squad_load_and_cache_examples(
            args, self.bert_tokenizer)
        print("Size of the dataset {}".format(len(qas_train_dataset)))
        args.train_batch_size = self.args['batch_size']
        qas_train_sampler = RandomSampler(qas_train_dataset)
        qas_train_dataloader = DataLoader(qas_train_dataset,
                                          sampler=qas_train_sampler,
                                          batch_size=args.train_batch_size)
        t_totals = len(
            qas_train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs
        #self.train_ner()
        epochs_trained = 0

        train_iterator = trange(epochs_trained,
                                int(args.num_train_epochs),
                                desc="Epoch")
        # Added here for reproducibility
        self.bert_model.to(device)
        self.qas_head.to(device)
        self.ner_head.to(device)

        for _ in train_iterator:
            epoch_iterator = tqdm(qas_train_dataloader, desc="Iteration")
            for step, batch in enumerate(epoch_iterator):
                self.bert_optimizer.zero_grad()
                self.qas_head.optimizer.zero_grad()
                self.ner_head.optimizer.zero_grad()
                batch = qas_train_dataset[0]
                batch = tuple(t.unsqueeze(0) for t in batch)
                tokens, bert_batch_after_padding, data = self.ner_reader[0]
                #logging.info(batch[-1])
                self.bert_model.train()
                # logging.info(self.bert_tokenizer.convert_ids_to_tokens(batch[0][0].detach().numpy()))
                batch = tuple(t.to(device) for t in batch)
                data = [d.to(device) for d in data]
                bert_inputs = {
                    "input_ids": batch[0],
                    "attention_mask": batch[1],
                    "token_type_ids": batch[2],
                }
                self.predict_ner()
                #logging.info("Input ids shape : {}".format(batch[0].shape))
                sent_lens, masks, tok_inds, ner_inds, \
                bert_batch_ids, bert_seq_ids, bert2toks, cap_inds = data
                outputs = self.bert_model(bert_batch_ids,
                                          token_type_ids=bert_seq_ids)
                # bert_hiddens = self._get_bert_batch_hidden(outputs[-1],bert2toks)
                # loss, out_logits =  self.ner_head(bert_hiddens,ner_inds)
                ner_loss, ner_out_logits = self.get_ner(
                    outputs[-1], bert2toks, ner_inds)
                outputs = self.bert_model(**bert_inputs)
                bert_outs_for_ner, lens = self._get_squad_to_ner_bert_batch_hidden(
                    outputs[-1], batch[-1], device=device)
                ner_outs = self.ner_head(bert_outs_for_ner)
                ner_outs_for_qas = self._get_token_to_bert_predictions(
                    ner_outs, batch[-1])
                logging.info("BERT OUTS FOR QAS {}".format(
                    ner_outs_for_qas.shape))
                bert_out = self._get_squad_bert_batch_hidden(outputs[-1])
                #logging.info("Bert out shape {}".format(bert_out.shape))
                qas_outputs = self.get_qas(bert_out, batch)
                # qas_outputs = self.qas_head(**squad_inputs)
                # print(qas_outputs[0].item())
                total_loss = ner_loss + qas_outputs[0]
                total_loss.backward()
                logging.info("TOtal loss {} {}  {}".format(
                    total_loss.item(), ner_loss.item(), qas_outputs[0].item()))
                self.ner_head.optimizer.step()
                self.bert_optimizer.step()
                self.qas_head.optimizer.step()
            self.predict_qas(batch)
            self.predict_ner()

    def predict_ner(self):
        self.eval_ner()

    def train_qas(self):
        device = self.args['device']
        args = hugging_parse_args()
        train_dataset = squad_load_and_cache_examples(args,
                                                      self.bert_tokenizer)
        print("Size of the dataset {}".format(len(train_dataset)))
        args.train_batch_size = self.args['batch_size']
        #train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
        train_sampler = RandomSampler(train_dataset)
        train_dataloader = DataLoader(train_dataset,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)
        t_totals = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [{
            "params": self.qas_head.parameters(),
            "weight_decay": 0.0
        }]
        #self.bert_squad_optimizer =AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)

        ## Scheduler for sub-components
        #scheduler = get_linear_schedule_with_warmup(
        #self.bert_squad_optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_totals)
        tr_loss, logging_loss = 0.0, 0.0
        steps_trained_in_current_epoch = 0
        epochs_trained = 0

        train_iterator = trange(epochs_trained,
                                int(args.num_train_epochs),
                                desc="Epoch")
        # Added here for reproducibility
        self.bert_model.to(device)
        self.qas_head.to(device)
        for _ in train_iterator:
            epoch_iterator = tqdm(train_dataloader, desc="Iteration")
            for step, batch in enumerate(epoch_iterator):
                self.bert_optimizer.zero_grad()
                self.qas_head.optimizer.zero_grad()
                batch = train_dataset[0]
                batch = tuple(t.unsqueeze(0) for t in batch)
                logging.info(batch[-1])
                self.bert_model.train()
                #logging.info(self.bert_tokenizer.convert_ids_to_tokens(batch[0][0].detach().numpy()))
                batch = tuple(t.to(device) for t in batch)
                bert_inputs = {
                    "input_ids": batch[0],
                    "attention_mask": batch[1],
                    "token_type_ids": batch[2],
                }
                logging.info("Input ids shape : {}".format(batch[0].shape))
                outputs = self.bert_model(**bert_inputs)
                bert_outs_for_ner, lens = self._get_squad_to_ner_bert_batch_hidden(
                    outputs[-1], batch[-1], device=device)
                ner_outs = self.ner_head(bert_outs_for_ner)
                ner_outs_for_qas = self._get_token_to_bert_predictions(
                    ner_outs, batch[-1])
                logging.info("BERT OUTS FOR NER {}".format(
                    ner_outs_for_qas.shape))
                bert_out = self._get_squad_bert_batch_hidden(outputs[-1])
                logging.info("Bert out shape {}".format(bert_out.shape))
                qas_outputs = self.get_qas(bert_out, batch)
                #qas_outputs = self.qas_head(**squad_inputs)
                #print(qas_outputs[0].item())
                qas_outputs[0].backward()
                logging.info("Loss {}".format(qas_outputs[0].item()))
                self.bert_optimizer.step()
                self.qas_head.optimizer.step()
            self.predict_qas(batch)

    def pretrain_mlm(self):
        device = self.args['device']
        epochs_trained = 0
        epoch_num = self.args['epoch_num']
        batch_size = self.args['batch_size']
        block_size = self.args['block_size']
        huggins_args = hugging_parse_args()
        self.huggins_args = huggins_args
        #file_list = pubmed_files()
        file_list = ["PMC6958785.txt", "PMC6961255.txt"]
        train_dataset = MyTextDataset(self.bert_tokenizer,
                                      huggins_args,
                                      file_list,
                                      block_size=block_size)
        print("Dataset size {} ".format(len(train_dataset)))
        print(train_dataset[0])
        train_sampler = RandomSampler(train_dataset)

        def collate(examples):
            return pad_sequence(examples,
                                batch_first=True,
                                padding_value=self.bert_tokenizer.pad_token_id)

        train_sampler = RandomSampler(train_dataset)
        train_dataloader = DataLoader(train_dataset,
                                      sampler=train_sampler,
                                      batch_size=batch_size,
                                      collate_fn=collate)
        t_totals = len(train_dataloader) // self.args['epoch_num']
        #self.dataset = reader.create_training_instances(file_list,bert_tokenizer)
        epoch_iterator = tqdm(train_dataloader, desc="Iteration")
        self.bert_scheduler = get_linear_schedule_with_warmup(
            self.bert_optimizer,
            num_warmup_steps=self.args['warmup_steps'],
            num_training_steps=t_totals)
        if self.args['load_model']:
            print("Model loaded")
            self.load_model()
        print("BERT after loading weights")
        print(self.bert_model.bert.encoder.layer[11].output.dense.weight)
        self.bert_model.to(device)
        self.bert_model.train()
        print("Model is being trained on {} ".format(
            next(self.bert_model.parameters()).device))
        train_iterator = trange(
            #epochs_trained, int(huggins_args.num_train_epochs), desc="Epoch")
            epochs_trained,
            int(epoch_num),
            desc="Epoch")
        #set_seed(args)  # Added here for reproducibility
        for _ in train_iterator:
            for step, batch in enumerate(epoch_iterator):
                #print("Batch shape {} ".format(batch.shape))
                #print("First input {} ".format(batch[0]))
                self.bert_optimizer.zero_grad()  ## update mask_tokens to apply curriculum learning
                inputs, labels = mask_tokens(batch, self.bert_tokenizer,
                                             huggins_args)
                tokens = self.bert_tokenizer.convert_ids_to_tokens(
                    inputs.cpu().detach().numpy()[0, :])
                label_tokens = self.bert_tokenizer.convert_ids_to_tokens(
                    labels.cpu().detach().numpy()[0, :])
                logging.info("Tokens {}".format(tokens))
                logging.info("Labels ".format(label_tokens))
                inputs = inputs.to(device)
                labels = labels.to(device)
                outputs = self.bert_model(inputs, masked_lm_labels=labels)
                loss = outputs[0]
                logging.info("Loss obtained for batch of {} is {} ".format(
                    batch.shape, loss.item()))
                loss.backward()
                self.bert_optimizer.step()
                if step == 2:
                    break
            self.save_model()
            logging.info("Training is finished moving to evaluation")
            self.mlm_evaluate()

    def mlm_evaluate(self, prefix=""):
        # Loop to handle MNLI double evaluation (matched, mis-matched)
        eval_output_dir = out_dir = self.args['output_dir']

        #eval_dataset = load_and_cache_examples(args, tokenizer, evaluate=True)

        eval_batch_size = 1
        file_list = ["PMC6958785.txt"]
        eval_dataset = MyTextDataset(self.bert_tokenizer,
                                     self.huggins_args,
                                     file_list,
                                     block_size=128)

        #args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
        # Note that DistributedSampler samples randomly

        def collate(examples):
            return pad_sequence(examples,
                                batch_first=True,
                                padding_value=self.bert_tokenizer.pad_token_id)

        eval_sampler = SequentialSampler(eval_dataset)
        eval_dataloader = DataLoader(eval_dataset,
                                     sampler=eval_sampler,
                                     batch_size=eval_batch_size,
                                     collate_fn=collate)

        # multi-gpu evaluate
        #if args.n_gpu > 1:
        #    model = torch.nn.DataParallel(model)

        # Eval!
        model = self.bert_model
        logger.info("***** Running evaluation on {} *****".format(file_list))
        logger.info("  Num examples = %d", len(eval_dataset))
        logger.info("  Batch size = %d", eval_batch_size)
        eval_loss = 0.0
        nb_eval_steps = 0
        model.eval()

        for batch in tqdm(eval_dataloader, desc="Evaluating"):
            inputs, labels = mask_tokens(batch, self.bert_tokenizer,
                                         self.huggins_args)
            inputs = inputs.to(self.args['device'])
            labels = labels.to(self.args['device'])

            with torch.no_grad():
                outputs = model(inputs, masked_lm_labels=labels)
                lm_loss = outputs[0]
                eval_loss += lm_loss.mean().item()
            nb_eval_steps += 1

        eval_loss = eval_loss / nb_eval_steps
        perplexity = torch.exp(torch.tensor(eval_loss))

        result = {"perplexity": perplexity}

        output_eval_file = os.path.join(eval_output_dir, prefix,
                                        "eval_results.txt")
        with open(output_eval_file, "w") as writer:
            logger.info("***** Eval results {} *****".format(prefix))
            for key in sorted(result.keys()):
                logger.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))

        return result

    ## Parallel reading is not implemented for training
    def pretrain(self):
        file_list = ["PMC6961255.txt"]
        reader = BertPretrainReader(file_list, self.bert_tokenizer)
        #dataset = reader.create_training_instances(file_list,self.bert_tokenizer)
        tokens = reader.dataset[1].tokens
        logging.info(tokens)
        input_ids = torch.tensor(
            self.bert_tokenizer.convert_tokens_to_ids(tokens)).unsqueeze(
                0)  # Batch size 1
        #token_type_ids= torch.tensor(dataset[1].segment_ids).unsqueeze(0)
        #print(input_ids.shape)
        #print(dataset[1].segment_ids)
        #next_label = torch.tensor([ 0 if dataset[1].is_random_next  else  1])
        token_ids, mask_labels, next_label, token_type_ids = reader[0]
        loss_fct = CrossEntropyLoss(ignore_index=-100)

        for i in range(10):
            self.bert_optimizer.zero_grad()
            #print("Input shape {}".format(token_ids.shape))
            outputs = self.bert_model(token_ids, token_type_ids=token_type_ids)
            prediction_scores, seq_relationship_scores = outputs[:2]
            vocab_dim = prediction_scores.shape[-1]
            masked_lm_loss = loss_fct(prediction_scores.view(-1, vocab_dim),
                                      mask_labels.view(-1))
            next_sent_loss = loss_fct(seq_relationship_scores.view(-1, 2),
                                      next_label.view(-1))
            loss = masked_lm_loss + next_sent_loss
            loss.backward()
            self.bert_optimizer.step()
        pred_tokens = self.bert_tokenizer.convert_ids_to_tokens(
            torch.argmax(prediction_scores, dim=2).detach().cpu().numpy()[0])
        logging.info("{} {} ".format("Real tokens", tokens))
        logging.info("{} {} ".format("Predictions", pred_tokens))

    def train_ner(self):
        self.ner_reader = DataReader(self.ner_path,
                                     "NER",
                                     tokenizer=self.bert_tokenizer,
                                     batch_size=30)
        self.args['ner_label_vocab'] = self.ner_reader.label_voc
        self.ner_head = NerModel(self.args)
        print("Starting training")
        for j in range(10):
            for i in range(10):
                self.ner_head.optimizer.zero_grad()
                tokens, bert_batch_after_padding, data = self.ner_reader[0]
                sent_lens, masks, tok_inds, ner_inds, \
                    bert_batch_ids, bert_seq_ids, bert2toks, cap_inds = data
                outputs = self.bert_model(bert_batch_ids,
                                          token_type_ids=bert_seq_ids)
                #bert_hiddens = self._get_bert_batch_hidden(outputs[-1],bert2toks)
                #loss, out_logits =  self.ner_head(bert_hiddens,ner_inds)
                loss, out_logits = self.get_ner(outputs[-1], bert2toks,
                                                ner_inds)
                #print(loss.item())
                loss.backward()
                self.ner_head.optimizer.step()
            self.eval_ner()

    def eval_ner(self):
        tokens, bert_batch_after_padding, data = self.ner_reader[0]
        data = [d.to(self.device) for d in data]
        sent_lens, masks, tok_inds, ner_inds, \
            bert_batch_ids, bert_seq_ids, bert2toks, cap_inds = data
        outputs = self.bert_model(bert_batch_ids, token_type_ids=bert_seq_ids)
        #bert_hiddens = self._get_bert_batch_hidden(outputs[-1],bert2toks)
        #loss, out_logits =  self.ner_head(bert_hiddens,ner_inds)
        loss, out_logits = self.get_ner(outputs[-1], bert2toks, ner_inds)
        tokens = tokens[0]
        logging.info("Tokens")
        logging.info(tokens)
        logging.info("NER INDS SHAPE {} ".format(ner_inds.shape))
        voc_size = len(self.ner_reader.label_voc)
        preds = torch.argmax(
            out_logits,
            dim=2).detach().cpu().numpy()[0, :len(tokens)] // voc_size
        ner_inds = ner_inds.detach().cpu().numpy()[0, :len(tokens)] // voc_size
        #logging.info("NER INDS {}".format(ner_inds))
        preds = self.ner_reader.label_voc.unmap(preds)
        ner_inds = self.ner_reader.label_voc.unmap(ner_inds)
        logging.info("Predictions {} \n Truth {} ".format(preds, ner_inds))