def main():
    """Optionally fine-tune, then evaluate, the multi-task COVID entity classifier.

    Steps:
      1. Load the pickled task instances from ``args.data_file``.
      2. When ``args.retrain`` is set, build a fresh tokenizer/model from
         'bert-base-cased' and add the ``<E>``, ``</E>``, ``<URL>`` and
         ``@USER`` special tokens; otherwise load the tokenizer, model and the
         per-subtask classifier weights from ``args.save_directory``.
      3. Split the data into train/dev/test. When retraining, fine-tune with
         gradient accumulation, validate on dev periodically, and save the
         model, tokenizer, classifier weights and a train-loss plot.
      4. Pick a decision threshold per subtask on the dev split and report
         test metrics (confusion matrix, classification report, SQuAD-style
         EM/F1 and thresholded TP/FP/FN scores).
      5. Write ``model_config.json`` and ``results.json`` to ``args.output_dir``.

    Relies on module-level names defined elsewhere in the file, notably
    ``args``, ``device`` and ``POSSIBLE_BATCH_SIZE``.
    """
    # Read all the data instances
    task_instances_dict, tag_statistics, question_keys_and_tags = load_from_pickle(
        args.data_file)
    data, subtasks_list = get_multitask_instances_for_valid_tasks(
        task_instances_dict, tag_statistics)

    if args.retrain:
        logging.info("Creating and training the model from 'bert-base-cased' ")
        # Create the save_directory if not exists
        make_dir_if_not_exists(args.save_directory)

        # Initialize tokenizer and model with pretrained weights
        tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
        config = BertConfig.from_pretrained('bert-base-cased')
        config.subtasks = subtasks_list
        model = MultiTaskBertForCovidEntityClassification.from_pretrained(
            'bert-base-cased', config=config)

        # Add new tokens in tokenizer. add_special_tokens returns how many
        # tokens were actually new to the vocabulary; only those need new rows
        # in the embedding matrix, otherwise the matrix and the tokenizer
        # vocabulary get out of sync.
        new_special_tokens_dict = {
            "additional_special_tokens": ["<E>", "</E>", "<URL>", "@USER"]
        }
        num_added_tokens = tokenizer.add_special_tokens(new_special_tokens_dict)

        # Append uniformly initialised rows for the new tokens to the word
        # embedding matrix.
        word_embeddings = model.bert.embeddings.word_embeddings
        print("Embeddings type:", word_embeddings.weight.data.type())
        print("Embeddings shape:", word_embeddings.weight.data.size())
        embedding_size = word_embeddings.weight.size(1)
        new_embeddings = torch.FloatTensor(num_added_tokens,
                                           embedding_size).uniform_(-0.1, 0.1)
        print("new_embeddings shape:", new_embeddings.size())
        word_embeddings.weight.data = torch.cat(
            (word_embeddings.weight.data, new_embeddings), 0)
        # Keep the module's bookkeeping attribute consistent with its weight.
        word_embeddings.num_embeddings = word_embeddings.weight.size(0)
        print("Embeddings shape:", word_embeddings.weight.data.size())
        # Update model config vocab size
        model.config.vocab_size = model.config.vocab_size + num_added_tokens
    else:
        # Load the tokenizer and model from the save_directory
        tokenizer = BertTokenizer.from_pretrained(args.save_directory)
        model = MultiTaskBertForCovidEntityClassification.from_pretrained(
            args.save_directory)
        # The subtask classifier heads are saved as separate state dicts (see
        # the save step after training below), so restore them one by one.
        # map_location lets a checkpoint written on GPU load on a CPU-only host.
        for subtask in model.subtasks:
            classifier_file = os.path.join(args.save_directory,
                                           f"{subtask}_classifier.bin")
            model.classifiers[subtask].load_state_dict(
                torch.load(classifier_file, map_location=device))
    model.to(device)
    # Explicitly move the classifiers to device
    for classifier in model.classifiers.values():
        classifier.to(device)
    entity_start_token_id = tokenizer.convert_tokens_to_ids(["<E>"])[0]

    logging.info(
        f"Task dataset for task: {args.task} loaded from {args.data_file}.")

    model_config = dict()
    results = dict()
    # Make sure the output directory exists before anything is written to it.
    make_dir_if_not_exists(args.output_dir)

    # Split the data into train, dev and test and shuffle the train segment
    train_data, dev_data, test_data = split_multitask_instances_in_train_dev_test(
        data)
    random.shuffle(train_data)  # shuffle happens in-place
    logging.info("Train Data:")
    total_train_size, pos_subtasks_train_size, neg_subtasks_train_size = log_multitask_data_statistics(
        train_data, model.subtasks)
    logging.info("Dev Data:")
    total_dev_size, pos_subtasks_dev_size, neg_subtasks_dev_size = log_multitask_data_statistics(
        dev_data, model.subtasks)
    logging.info("Test Data:")
    total_test_size, pos_subtasks_test_size, neg_subtasks_test_size = log_multitask_data_statistics(
        test_data, model.subtasks)
    logging.info("\n")
    model_config["train_data"] = {
        "size": total_train_size,
        "pos": pos_subtasks_train_size,
        "neg": neg_subtasks_train_size
    }
    model_config["dev_data"] = {
        "size": total_dev_size,
        "pos": pos_subtasks_dev_size,
        "neg": neg_subtasks_dev_size
    }
    model_config["test_data"] = {
        "size": total_test_size,
        "pos": pos_subtasks_test_size,
        "neg": neg_subtasks_test_size
    }

    # Extract subtasks data for dev and test
    dev_subtasks_data = split_data_based_on_subtasks(dev_data, model.subtasks)
    test_subtasks_data = split_data_based_on_subtasks(test_data,
                                                      model.subtasks)

    # Load the instances into pytorch dataset
    train_dataset = COVID19TaskDataset(train_data)
    dev_dataset = COVID19TaskDataset(dev_data)
    test_dataset = COVID19TaskDataset(test_data)
    logging.info("Loaded the datasets into Pytorch datasets")

    tokenize_collator = TokenizeCollator(tokenizer, model.subtasks,
                                         entity_start_token_id)
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=POSSIBLE_BATCH_SIZE,
                                  shuffle=True,
                                  num_workers=0,
                                  collate_fn=tokenize_collator)
    dev_dataloader = DataLoader(dev_dataset,
                                batch_size=POSSIBLE_BATCH_SIZE,
                                shuffle=False,
                                num_workers=0,
                                collate_fn=tokenize_collator)
    test_dataloader = DataLoader(test_dataset,
                                 batch_size=POSSIBLE_BATCH_SIZE,
                                 shuffle=False,
                                 num_workers=0,
                                 collate_fn=tokenize_collator)
    logging.info("Created train and test dataloaders with batch aggregation")

    # Only retrain if needed
    if args.retrain:
        print('DO RETRAIN')
        # NOTE: Training Tutorial Reference
        # https://mccormickml.com/2019/07/22/BERT-fine-tuning/#41-bertforsequenceclassification
        # Recommended schedule for BERT fine-tuning as per the paper:
        # batch size 16/32, Adam learning rate 5e-5/3e-5/2e-5, 2-4 epochs.
        # NOTE: AdamW is a class from the huggingface library (as opposed to
        # pytorch).
        optimizer = AdamW(model.parameters(), lr=2e-5, eps=1e-8)
        logging.info("Created model optimizer")
        epochs = args.n_epochs

        # Number of micro-batches (each of POSSIBLE_BATCH_SIZE) whose gradients
        # are accumulated before one optimizer step. Kept an integer >= 1 so
        # the modulo test in the loop fires on a regular schedule.
        accumulation_steps = max(1, args.batch_size // POSSIBLE_BATCH_SIZE)
        n_steps = len(train_dataloader)
        # The optimizer and scheduler step once per accumulation window, plus
        # once for a trailing partial window at the end of each epoch, so the
        # linear schedule is sized in optimizer updates (not micro-batches)
        # in order to actually decay to zero by the end of training.
        updates_per_epoch = (n_steps + accumulation_steps -
                             1) // accumulation_steps
        total_steps = updates_per_epoch * epochs

        # Create the learning rate scheduler.
        # NOTE: num_warmup_steps = 0 is the Default value in run_glue.py
        scheduler = get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=0, num_training_steps=total_steps)
        # Per-epoch training loss and timing.
        training_stats = []

        logging.info(f"Initiating training loop for {args.n_epochs} epochs...")
        # Measure the total training time for the whole run.
        total_start_time = time.time()

        # Validate on dev about this many times per epoch; at least one batch
        # apart so very small training sets do not cause a modulo by zero.
        dev_log_frequency = 5
        dev_steps = max(1, n_steps // dev_log_frequency)

        # Loss trajectory for epochs
        epoch_train_loss = list()
        # Dev validation trajectory (TODO: plot / save it)
        dev_subtasks_validation_statistics = {
            subtask: list()
            for subtask in model.subtasks
        }
        for epoch in range(epochs):
            pbar = tqdm(train_dataloader)
            logging.info(f"Initiating Epoch {epoch+1}:")
            # Reset the total loss for each epoch.
            total_train_loss = 0
            train_loss_trajectory = list()

            # Reset timer for each epoch
            start_time = time.time()
            model.train()

            for step, batch in enumerate(pbar):
                # Upload labels of each subtask to device
                for subtask in model.subtasks:
                    batch["gold_labels"][subtask] = batch["gold_labels"][
                        subtask].to(device)
                # Forward
                input_dict = {
                    "input_ids":
                    batch["input_ids"].to(device),
                    "entity_start_positions":
                    batch["entity_start_positions"].to(device),
                    "labels":
                    batch["gold_labels"]
                }
                loss, _ = model(**input_dict)
                # The loss is not divided by accumulation_steps, so gradients
                # are summed (not averaged) over the accumulated micro-batches
                # and the norm clipping below applies to that sum.
                total_train_loss += loss.item()

                # Backward: compute gradients
                loss.backward()

                # Step at the end of every accumulation window, and also on
                # the last batch so a trailing partial window is neither
                # carried into the next epoch nor silently dropped.
                if (step + 1) % accumulation_steps == 0 or step + 1 == n_steps:

                    # Calculate elapsed time and print loss on the tqdm bar
                    elapsed = format_time(time.time() - start_time)
                    avg_train_loss = total_train_loss / (step + 1)
                    # keep track of changing avg_train_loss
                    train_loss_trajectory.append(avg_train_loss)
                    pbar.set_description(
                        f"Epoch:{epoch+1}|Batch:{step}/{len(train_dataloader)}|Time:{elapsed}|Avg. Loss:{avg_train_loss:.4f}|Loss:{loss.item():.4f}"
                    )

                    # Clip the norm of the gradients to 1.0 to help prevent
                    # exploding gradients.
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

                    # Update parameters, reset gradients, advance the LR.
                    optimizer.step()
                    model.zero_grad()
                    scheduler.step()
                    # No pbar.update() here: iterating over pbar already
                    # advances the bar, an extra update() double counts.
                if (step + 1) % dev_steps == 0:
                    # Perform validation with the model and log the performance
                    logging.info("Running Validation...")
                    # Evaluation mode: dropout layers behave differently.
                    model.eval()
                    _, dev_prediction_scores, _ = make_predictions_on_dataset(
                        dev_dataloader, model, device, args.task + "_dev",
                        True)
                    for subtask in model.subtasks:
                        dev_F1, dev_P, dev_R, dev_TP, dev_FP, dev_FN = get_TP_FP_FN(
                            dev_subtasks_data[subtask],
                            dev_prediction_scores[subtask])
                        logging.info(
                            f"Subtask:{subtask:>15}\tN={dev_TP + dev_FN}\tF1={dev_F1}\tP={dev_P}\tR={dev_R}\tTP={dev_TP}\tFP={dev_FP}\tFN={dev_FN}"
                        )
                        dev_subtasks_validation_statistics[subtask].append(
                            (epoch + 1, step + 1, dev_TP + dev_FN, dev_F1,
                             dev_P, dev_R, dev_TP, dev_FP, dev_FN))
                    # Put the model back in train setting
                    model.train()

            # Calculate the average loss over all of the batches (guarding
            # against an empty training loader).
            avg_train_loss = total_train_loss / max(1, n_steps)

            training_time = format_time(time.time() - start_time)

            # Record all statistics from this epoch.
            training_stats.append({
                'epoch': epoch + 1,
                'Training Loss': avg_train_loss,
                'Training Time': training_time
            })

            # Save the loss trajectory
            epoch_train_loss.append(train_loss_trajectory)

        logging.info(
            f"Training complete with total Train time:{format_time(time.time()- total_start_time)}"
        )
        log_list(training_stats)

        # Save the model and the Tokenizer here:
        logging.info(
            f"Saving the model and tokenizer in {args.save_directory}")
        model.save_pretrained(args.save_directory)
        # Save each subtask classifiers weights to individual state dicts
        for subtask, classifier in model.classifiers.items():
            classifier_save_file = os.path.join(args.save_directory,
                                                f"{subtask}_classifier.bin")
            logging.info(
                f"Saving the model's {subtask} classifier weights at {classifier_save_file}"
            )
            torch.save(classifier.state_dict(), classifier_save_file)
        tokenizer.save_pretrained(args.save_directory)

        # Plot the train loss trajectory in a plot
        train_loss_trajectory_plot_file = os.path.join(
            args.output_dir, "train_loss_trajectory.png")
        logging.info(
            f"Saving the Train loss trajectory at {train_loss_trajectory_plot_file}"
        )
        plot_train_loss(epoch_train_loss, train_loss_trajectory_plot_file)

        # TODO: Plot the validation performance
        # Save dev_subtasks_validation_statistics
    else:
        logging.info("No training needed. Directly going to evaluation!")

    # Save the model name in the model_config file
    model_config["model"] = "MultiTaskBertForCovidEntityClassification"
    model_config["epochs"] = args.n_epochs

    # The training loop above leaves the model in train mode (dropout active);
    # switch to eval mode so dev threshold tuning and test scores are not
    # computed with dropout.
    model.eval()

    # Find best threshold for each subtask based on dev set performance
    thresholds = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
    # NOTE(review): the outputs of this test-set pass are never used (the test
    # set is predicted again below). It is kept only in case the trailing True
    # flag triggers a side effect in make_predictions_on_dataset -- confirm
    # and drop it if not.
    make_predictions_on_dataset(test_dataloader, model, device, args.task,
                                True)
    _, dev_prediction_scores, _ = make_predictions_on_dataset(
        dev_dataloader, model, device, args.task + "_dev", True)

    best_dev_thresholds = {subtask: 0.5 for subtask in model.subtasks}
    best_dev_F1s = {subtask: 0.0 for subtask in model.subtasks}
    dev_subtasks_t_F1_P_Rs = {subtask: list() for subtask in model.subtasks}
    for subtask in model.subtasks:
        dev_subtask_data = dev_subtasks_data[subtask]
        dev_subtask_prediction_scores = dev_prediction_scores[subtask]
        for t in thresholds:
            dev_F1, dev_P, dev_R, dev_TP, dev_FP, dev_FN = get_TP_FP_FN(
                dev_subtask_data, dev_subtask_prediction_scores, THRESHOLD=t)
            dev_subtasks_t_F1_P_Rs[subtask].append(
                (t, dev_F1, dev_P, dev_R, dev_TP + dev_FN, dev_TP, dev_FP,
                 dev_FN))
            # Strictly greater: ties keep the earlier (lower) threshold.
            if dev_F1 > best_dev_F1s[subtask]:
                best_dev_thresholds[subtask] = t
                best_dev_F1s[subtask] = dev_F1

        logging.info(f"Subtask:{subtask:>15}")
        log_list(dev_subtasks_t_F1_P_Rs[subtask])
        logging.info(
            f"Best Dev Threshold for subtask: {best_dev_thresholds[subtask]}\t Best dev F1: {best_dev_F1s[subtask]}"
        )

    # Save the best dev threshold and dev_F1 in results dict
    results["best_dev_threshold"] = best_dev_thresholds
    results["best_dev_F1s"] = best_dev_F1s
    results["dev_t_F1_P_Rs"] = dev_subtasks_t_F1_P_Rs

    # Evaluate on Test
    logging.info("Testing on test dataset")
    predicted_labels, prediction_scores, gold_labels = make_predictions_on_dataset(
        test_dataloader, model, device, args.task)

    for subtask in model.subtasks:
        logging.info(f"Testing the trained classifier on subtask: {subtask}")
        results[subtask] = dict()
        cm = metrics.confusion_matrix(gold_labels[subtask],
                                      predicted_labels[subtask])
        classification_report = metrics.classification_report(
            gold_labels[subtask], predicted_labels[subtask], output_dict=True)
        logging.info(cm)
        logging.info(
            metrics.classification_report(gold_labels[subtask],
                                          predicted_labels[subtask]))
        # Store as list of lists instead of numpy.ndarray for JSON output.
        results[subtask]["CM"] = cm.tolist()
        results[subtask]["Classification Report"] = classification_report

        # SQuAD style EM and F1 evaluation for all test cases and for positive
        # test cases (i.e. cases where annotators had a gold annotation).
        EM_score, F1_score, total = get_raw_scores(test_subtasks_data[subtask],
                                                   prediction_scores[subtask])
        logging.info("Word overlap based SQuAD evaluation style metrics:")
        logging.info(f"Total number of cases: {total}")
        logging.info(f"EM_score: {EM_score}")
        logging.info(f"F1_score: {F1_score}")
        results[subtask]["SQuAD_EM"] = EM_score
        results[subtask]["SQuAD_F1"] = F1_score
        results[subtask]["SQuAD_total"] = total
        pos_EM_score, pos_F1_score, pos_total = get_raw_scores(
            test_subtasks_data[subtask],
            prediction_scores[subtask],
            positive_only=True)
        logging.info(f"Total number of Positive cases: {pos_total}")
        logging.info(f"Pos. EM_score: {pos_EM_score}")
        logging.info(f"Pos. F1_score: {pos_F1_score}")
        results[subtask]["SQuAD_Pos. EM"] = pos_EM_score
        results[subtask]["SQuAD_Pos. F1"] = pos_F1_score
        results[subtask]["SQuAD_Pos. EM_F1_total"] = pos_total

        # Thresholded TP/FP/FN evaluation ("New evaluation suggested by
        # Alan"), using the threshold selected on the dev split.
        F1, P, R, TP, FP, FN = get_TP_FP_FN(
            test_subtasks_data[subtask],
            prediction_scores[subtask],
            THRESHOLD=best_dev_thresholds[subtask])
        logging.info("New evaluation scores:")
        logging.info(f"F1: {F1}")
        logging.info(f"Precision: {P}")
        logging.info(f"Recall: {R}")
        logging.info(f"True Positive: {TP}")
        logging.info(f"False Positive: {FP}")
        logging.info(f"False Negative: {FN}")
        results[subtask]["F1"] = F1
        results[subtask]["P"] = P
        results[subtask]["R"] = R
        results[subtask]["TP"] = TP
        results[subtask]["FP"] = FP
        results[subtask]["FN"] = FN
        results[subtask]["N"] = TP + FN

    # Save model_config and results
    model_config_file = os.path.join(args.output_dir, "model_config.json")
    results_file = os.path.join(args.output_dir, "results.json")
    logging.info(f"Saving model config at {model_config_file}")
    save_in_json(model_config, model_config_file)
    logging.info(f"Saving results at {results_file}")
    save_in_json(results, results_file)
Exemplo n.º 2
0
def main():
    """Evaluate the multi-task COVID entity classifier on the whole dataset.

    Loads the pickled task instances from ``args.data_file`` and treats all of
    them as test data (no train/dev split). The tokenizer and model are either
    built from 'bert-base-cased' with the ``<E>``, ``</E>``, ``<URL>`` and
    ``@USER`` special tokens added (when ``args.retrain`` is set), or loaded
    from ``args.save_directory`` together with the per-subtask classifier
    weights. This function performs no training, so in the ``retrain`` case
    the subtask classifiers are evaluated as freshly initialised (presumably
    random) heads.

    Test metrics are computed with a fixed 0.5 decision threshold per subtask.
    ``model_config.json`` and ``results.json`` are written to
    ``args.output_dir``.

    Relies on module-level names defined elsewhere in the file, notably
    ``args``, ``device`` and ``POSSIBLE_BATCH_SIZE``.
    """
    # Read all the data instances
    task_instances_dict, tag_statistics, question_keys_and_tags = load_from_pickle(
        args.data_file)
    data, subtasks_list = get_multitask_instances_for_valid_tasks(
        task_instances_dict, tag_statistics)

    if args.retrain:
        logging.info("Creating and training the model from 'bert-base-cased' ")
        # Create the save_directory if not exists
        make_dir_if_not_exists(args.save_directory)

        # Initialize tokenizer and model with pretrained weights
        tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
        config = BertConfig.from_pretrained('bert-base-cased')
        config.subtasks = subtasks_list
        model = MultiTaskBertForCovidEntityClassification.from_pretrained(
            'bert-base-cased', config=config)

        # Add new tokens in tokenizer. add_special_tokens returns how many
        # tokens were actually new to the vocabulary; only those need new rows
        # in the embedding matrix, otherwise the matrix and the tokenizer
        # vocabulary get out of sync.
        new_special_tokens_dict = {
            "additional_special_tokens": ["<E>", "</E>", "<URL>", "@USER"]
        }
        num_added_tokens = tokenizer.add_special_tokens(new_special_tokens_dict)

        # Append uniformly initialised rows for the new tokens to the word
        # embedding matrix.
        word_embeddings = model.bert.embeddings.word_embeddings
        print("Embeddings type:", word_embeddings.weight.data.type())
        print("Embeddings shape:", word_embeddings.weight.data.size())
        embedding_size = word_embeddings.weight.size(1)
        new_embeddings = torch.FloatTensor(num_added_tokens,
                                           embedding_size).uniform_(-0.1, 0.1)
        print("new_embeddings shape:", new_embeddings.size())
        word_embeddings.weight.data = torch.cat(
            (word_embeddings.weight.data, new_embeddings), 0)
        # Keep the module's bookkeeping attribute consistent with its weight.
        word_embeddings.num_embeddings = word_embeddings.weight.size(0)
        print("Embeddings shape:", word_embeddings.weight.data.size())
        # Update model config vocab size
        model.config.vocab_size = model.config.vocab_size + num_added_tokens
    else:
        # Load the tokenizer and model from the save_directory
        tokenizer = BertTokenizer.from_pretrained(args.save_directory)
        model = MultiTaskBertForCovidEntityClassification.from_pretrained(
            args.save_directory)
        # The subtask classifier heads are stored as separate state dicts, so
        # restore them one by one. map_location lets a checkpoint written on
        # GPU load on a CPU-only host.
        for subtask in model.subtasks:
            classifier_file = os.path.join(args.save_directory,
                                           f"{subtask}_classifier.bin")
            model.classifiers[subtask].load_state_dict(
                torch.load(classifier_file, map_location=device))
    model.to(device)
    # Explicitly move the classifiers to device
    for classifier in model.classifiers.values():
        classifier.to(device)
    entity_start_token_id = tokenizer.convert_tokens_to_ids(["<E>"])[0]

    logging.info(
        f"Task dataset for task: {args.task} loaded from {args.data_file}.")

    model_config = dict()
    results = dict()
    # Make sure the output directory exists before anything is written to it.
    make_dir_if_not_exists(args.output_dir)

    # The whole dataset is used as test data in this script.
    test_data = data
    logging.info("Test Data:")
    total_test_size, pos_subtasks_test_size, neg_subtasks_test_size = log_multitask_data_statistics(
        test_data, model.subtasks)
    logging.info("\n")
    model_config["test_data"] = {
        "size": total_test_size,
        "pos": pos_subtasks_test_size,
        "neg": neg_subtasks_test_size
    }

    # Extract subtasks data for test
    test_subtasks_data = split_data_based_on_subtasks(test_data,
                                                      model.subtasks)

    # Load the instances into pytorch dataset
    test_dataset = COVID19TaskDataset(test_data)
    logging.info("Loaded the datasets into Pytorch datasets")

    tokenize_collator = TokenizeCollator(tokenizer, model.subtasks,
                                         entity_start_token_id)
    test_dataloader = DataLoader(test_dataset,
                                 batch_size=POSSIBLE_BATCH_SIZE,
                                 shuffle=False,
                                 num_workers=0,
                                 collate_fn=tokenize_collator)
    logging.info("Created test dataloader with batch aggregation")

    # Save the model name in the model_config file
    model_config["model"] = "MultiTaskBertForCovidEntityClassification"
    model_config["epochs"] = args.n_epochs

    # Evaluate without dropout. from_pretrained typically already returns the
    # model in eval mode; making it explicit keeps this independent of that.
    model.eval()

    # NOTE(review): the outputs of this test-set pass are never used (the test
    # set is predicted again below). It is kept only in case the trailing True
    # flag triggers a side effect in make_predictions_on_dataset -- confirm
    # and drop it if not.
    make_predictions_on_dataset(test_dataloader, model, device, args.task,
                                True)

    # There is no dev split here to tune thresholds on, so every subtask uses
    # the default 0.5 decision threshold.
    decision_thresholds = {subtask: 0.5 for subtask in model.subtasks}

    # Evaluate on Test
    logging.info("Testing on test dataset")
    predicted_labels, prediction_scores, gold_labels = make_predictions_on_dataset(
        test_dataloader, model, device, args.task)

    for subtask in model.subtasks:
        logging.info(f"Testing the trained classifier on subtask: {subtask}")
        results[subtask] = dict()
        cm = metrics.confusion_matrix(gold_labels[subtask],
                                      predicted_labels[subtask])
        classification_report = metrics.classification_report(
            gold_labels[subtask], predicted_labels[subtask], output_dict=True)
        logging.info(cm)
        logging.info(
            metrics.classification_report(gold_labels[subtask],
                                          predicted_labels[subtask]))
        # Store as list of lists instead of numpy.ndarray for JSON output.
        results[subtask]["CM"] = cm.tolist()
        results[subtask]["Classification Report"] = classification_report

        # SQuAD style EM and F1 evaluation for all test cases and for positive
        # test cases (i.e. cases where annotators had a gold annotation).
        EM_score, F1_score, total = get_raw_scores(test_subtasks_data[subtask],
                                                   prediction_scores[subtask])
        logging.info("Word overlap based SQuAD evaluation style metrics:")
        logging.info(f"Total number of cases: {total}")
        logging.info(f"EM_score: {EM_score}")
        logging.info(f"F1_score: {F1_score}")
        results[subtask]["SQuAD_EM"] = EM_score
        results[subtask]["SQuAD_F1"] = F1_score
        results[subtask]["SQuAD_total"] = total
        pos_EM_score, pos_F1_score, pos_total = get_raw_scores(
            test_subtasks_data[subtask],
            prediction_scores[subtask],
            positive_only=True)
        logging.info(f"Total number of Positive cases: {pos_total}")
        logging.info(f"Pos. EM_score: {pos_EM_score}")
        logging.info(f"Pos. F1_score: {pos_F1_score}")
        results[subtask]["SQuAD_Pos. EM"] = pos_EM_score
        results[subtask]["SQuAD_Pos. F1"] = pos_F1_score
        results[subtask]["SQuAD_Pos. EM_F1_total"] = pos_total

        # Thresholded TP/FP/FN evaluation ("New evaluation suggested by Alan").
        F1, P, R, TP, FP, FN = get_TP_FP_FN(
            test_subtasks_data[subtask],
            prediction_scores[subtask],
            THRESHOLD=decision_thresholds[subtask])
        logging.info("New evaluation scores:")
        logging.info(f"F1: {F1}")
        logging.info(f"Precision: {P}")
        logging.info(f"Recall: {R}")
        logging.info(f"True Positive: {TP}")
        logging.info(f"False Positive: {FP}")
        logging.info(f"False Negative: {FN}")
        results[subtask]["F1"] = F1
        results[subtask]["P"] = P
        results[subtask]["R"] = R
        results[subtask]["TP"] = TP
        results[subtask]["FP"] = FP
        results[subtask]["FN"] = FN
        results[subtask]["N"] = TP + FN

    # Save model_config and results
    model_config_file = os.path.join(args.output_dir, "model_config.json")
    results_file = os.path.join(args.output_dir, "results.json")
    logging.info(f"Saving model config at {model_config_file}")
    save_in_json(model_config, model_config_file)
    logging.info(f"Saving results at {results_file}")
    save_in_json(results, results_file)
Exemplo n.º 3
0
def main():
    """Train and evaluate an n-gram Logistic Regression baseline for one subtask.

    Steps:
      1. Load instances from ``args.data_file`` and keep those for ``args.sub_task``.
      2. Split them into train/dev/test, shuffling the train split in place.
      3. Build n-gram features from the train split only.
      4. Fit a ``LogisticRegression`` classifier.
      5. Evaluate on the test split at a fixed probability threshold of 0.5.

    Writes to ``args.output_dir``:
      * ``model_and_features.pkl`` -- pickled ``(lr, feature2i, i2feature)``
      * ``model_config.json``      -- split sizes, feature counts, model description
      * ``results.json``           -- confusion matrix, sklearn report, F1/P/R/TP/FP/FN/N

    Relies on a module-level ``args`` object defined outside this block
    (presumably parsed command-line arguments).
    """
    task_instances_dict = load_from_pickle(args.data_file)
    data = extract_instances_for_current_subtask(task_instances_dict,
                                                 args.sub_task)
    logging.info(
        f"Task dataset for task: {args.task} loaded from {args.data_file}.")

    model_config = dict()
    results = dict()

    # Split the data into train, dev and test and shuffle the train segment
    train_data, dev_data, test_data = split_instances_in_train_dev_test(data)
    random.shuffle(train_data)  # shuffle happens in-place
    logging.info("Train Data:")
    total_train_size, pos_train_size, neg_train_size = log_data_statistics(
        train_data)
    logging.info("Dev Data:")
    total_dev_size, pos_dev_size, neg_dev_size = log_data_statistics(dev_data)
    logging.info("Test Data:")
    total_test_size, pos_test_size, neg_test_size = log_data_statistics(
        test_data)
    model_config["train_data"] = {
        "size": total_train_size,
        "pos": pos_train_size,
        "neg": neg_train_size
    }
    model_config["dev_data"] = {
        "size": total_dev_size,
        "pos": pos_dev_size,
        "neg": neg_dev_size
    }
    model_config["test_data"] = {
        "size": total_test_size,
        "pos": pos_test_size,
        "neg": neg_test_size
    }

    # Extract n-gram features from the train data only, so dev/test
    # vocabulary does not leak into the feature space.
    # TODO: update the feature extractor
    feature2i, i2feature = create_ngram_features_from(train_data)
    logging.info(
        f"Total number of features extracted from train = {len(feature2i)}, {len(i2feature)}"
    )
    model_config["features"] = {"size": len(feature2i)}

    # Extract Feature vectors and labels from train, dev and test data
    train_X, train_Y = convert_data_to_feature_vector_and_labels(
        train_data, feature2i)
    dev_X, dev_Y = convert_data_to_feature_vector_and_labels(
        dev_data, feature2i)
    test_X, test_Y = convert_data_to_feature_vector_and_labels(
        test_data, feature2i)
    logging.info(
        f"Train Data Features = {train_X.shape} and Labels = {len(train_Y)}")
    logging.info(
        f"Dev Data Features = {dev_X.shape} and Labels = {len(dev_Y)}")
    logging.info(
        f"Test Data Features = {test_X.shape} and Labels = {len(test_Y)}")
    model_config["train_data"]["features_shape"] = train_X.shape
    model_config["train_data"]["labels_shape"] = len(train_Y)
    model_config["dev_data"]["features_shape"] = dev_X.shape
    model_config["dev_data"]["labels_shape"] = len(dev_Y)
    model_config["test_data"]["features_shape"] = test_X.shape
    model_config["test_data"]["labels_shape"] = len(test_Y)

    # Train logistic regression classifier
    logging.info("Training the Logistic Regression classifier")
    lr_params = {"solver": "lbfgs", "max_iter": 1000}
    lr = LogisticRegression(**lr_params)
    lr.fit(train_X, train_Y)
    # Derive the recorded description from the params actually used so
    # model_config.json cannot drift from the trained estimator
    # (the previous hard-coded string omitted max_iter).
    model_config["model"] = "LogisticRegression({})".format(", ".join(
        f"{name}={value!r}" for name, value in lr_params.items()))

    # NOTE: this variant evaluates at a fixed 0.5 threshold; no dev-set
    # threshold search is performed (dev_X/dev_Y are only logged above).

    # Test
    logging.info("Testing the trained classifier")
    predictions = lr.predict(test_X)
    probs = lr.predict_proba(test_X)
    test_Y_prediction_probs = probs[:, 1]  # probability of the positive class
    cm = metrics.confusion_matrix(test_Y, predictions)
    classification_report = metrics.classification_report(test_Y,
                                                          predictions,
                                                          output_dict=True)
    logging.info(cm)
    logging.info(metrics.classification_report(test_Y, predictions))
    # Store as list of lists instead of numpy.ndarray so it is JSON-serializable
    results["CM"] = cm.tolist()
    results["Classification Report"] = classification_report

    # evaluation script
    F1, P, R, TP, FP, FN = get_TP_FP_FN(test_data,
                                        test_Y_prediction_probs,
                                        THRESHOLD=0.5)
    logging.info("New evaluation scores:")
    logging.info(f"F1: {F1}")
    logging.info(f"Precision: {P}")
    logging.info(f"Recall: {R}")
    logging.info(f"True Positive: {TP}")
    logging.info(f"False Positive: {FP}")
    logging.info(f"False Negative: {FN}")
    results["F1"] = F1
    results["P"] = P
    results["R"] = R
    results["TP"] = TP
    results["FP"] = FP
    results["FN"] = FN
    N = TP + FN
    results["N"] = N

    # Save the model and features in pickle file
    model_and_features_save_file = os.path.join(args.output_dir,
                                                "model_and_features.pkl")
    logging.info(
        f"Saving LR model and features at {model_and_features_save_file}")
    save_in_pickle((lr, feature2i, i2feature), model_and_features_save_file)

    # Save model_config and results
    model_config_file = os.path.join(args.output_dir, "model_config.json")
    results_file = os.path.join(args.output_dir, "results.json")
    logging.info(f"Saving model config at {model_config_file}")
    save_in_json(model_config, model_config_file)
    logging.info(f"Saving results at {results_file}")
    save_in_json(results, results_file)
def main():
    """Train an n-gram Logistic Regression, tune its threshold on dev, and evaluate on test.

    Steps:
      1. Load ``(task_instances_dict, tag_statistics, question_keys_and_tags)``
         from ``args.data_file`` and keep the instances for ``args.sub_task``.
      2. Split them into train/dev/test, shuffling the train split in place.
      3. Build n-gram features from the train split and fit a
         ``LogisticRegression``.
      4. Pick the probability threshold in {0.1, ..., 0.9} that maximises
         dev F1, as computed by ``get_TP_FP_FN``.
      5. Report on the test split:
         * sklearn metrics at the default ``predict`` decision;
         * SQuAD-style EM/F1, over all cases and positive-only;
         * F1/P/R/TP/FP/FN at the best dev threshold.
      6. Log the top-scoring test predictions and the highest-weight features.

    Writes to ``args.output_dir``:
      * ``Precision Recall Curve.png``
      * ``model_and_features.pkl`` -- pickled ``(lr, feature2i, i2feature)``
      * ``model_config.json``
      * ``results.json``

    Relies on a module-level ``args`` object defined outside this block
    (presumably parsed command-line arguments).
    """
    task_instances_dict, tag_statistics, question_keys_and_tags = load_from_pickle(
        args.data_file)
    data = extract_instances_for_current_subtask(task_instances_dict,
                                                 args.sub_task)
    logging.info(
        f"Task dataset for task: {args.task} loaded from {args.data_file}.")

    model_config = dict()
    results = dict()

    # Split the data into train, dev and test and shuffle the train segment
    train_data, dev_data, test_data = split_instances_in_train_dev_test(data)
    random.shuffle(train_data)  # shuffle happens in-place
    logging.info("Train Data:")
    total_train_size, pos_train_size, neg_train_size = log_data_statistics(
        train_data)
    logging.info("Dev Data:")
    total_dev_size, pos_dev_size, neg_dev_size = log_data_statistics(dev_data)
    logging.info("Test Data:")
    total_test_size, pos_test_size, neg_test_size = log_data_statistics(
        test_data)
    logging.info("\n")
    model_config["train_data"] = {
        "size": total_train_size,
        "pos": pos_train_size,
        "neg": neg_train_size
    }
    model_config["dev_data"] = {
        "size": total_dev_size,
        "pos": pos_dev_size,
        "neg": neg_dev_size
    }
    model_config["test_data"] = {
        "size": total_test_size,
        "pos": pos_test_size,
        "neg": neg_test_size
    }

    # Extract n-gram features from the train data
    # Returned ngrams will be dict of dict
    # TODO: update the feature extractor
    feature2i, i2feature = create_ngram_features_from(train_data)
    logging.info(
        f"Total number of features extracted from train = {len(feature2i)}, {len(i2feature)}"
    )
    model_config["features"] = {"size": len(feature2i)}

    # Extract Feature vectors and labels from train, dev and test data
    train_X, train_Y = convert_data_to_feature_vector_and_labels(
        train_data, feature2i)
    dev_X, dev_Y = convert_data_to_feature_vector_and_labels(
        dev_data, feature2i)
    test_X, test_Y = convert_data_to_feature_vector_and_labels(
        test_data, feature2i)
    logging.info(
        f"Train Data Features = {train_X.shape} and Labels = {len(train_Y)}")
    logging.info(
        f"Dev Data Features = {dev_X.shape} and Labels = {len(dev_Y)}")
    logging.info(
        f"Test Data Features = {test_X.shape} and Labels = {len(test_Y)}")
    model_config["train_data"]["features_shape"] = train_X.shape
    model_config["train_data"]["labels_shape"] = len(train_Y)
    model_config["dev_data"]["features_shape"] = dev_X.shape
    model_config["dev_data"]["labels_shape"] = len(dev_Y)
    model_config["test_data"]["features_shape"] = test_X.shape
    model_config["test_data"]["labels_shape"] = len(test_Y)

    # Train logistic regression classifier
    logging.info("Training the Logistic Regression classifier")
    lr = LogisticRegression(solver='lbfgs')
    lr.fit(train_X, train_Y)
    model_config["model"] = "LogisticRegression(solver='lbfgs')"

    # Find best threshold based on dev set performance.
    # Ties keep the earlier (lower) threshold because the comparison is strict;
    # if every F1 is 0 the default 0.5 is kept.
    thresholds = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]

    dev_prediction_probs = lr.predict_proba(dev_X)[:, 1]
    dev_t_F1_P_Rs = list()
    best_threshold_based_on_F1 = 0.5
    best_dev_F1 = 0.0
    for t in thresholds:
        dev_F1, dev_P, dev_R, dev_TP, dev_FP, dev_FN = get_TP_FP_FN(
            dev_data, dev_prediction_probs, THRESHOLD=t)
        dev_t_F1_P_Rs.append(
            (t, dev_F1, dev_P, dev_R, dev_TP + dev_FN, dev_TP, dev_FP, dev_FN))
        if dev_F1 > best_dev_F1:
            best_threshold_based_on_F1 = t
            best_dev_F1 = dev_F1
    log_list(dev_t_F1_P_Rs)
    logging.info(
        f"Best Threshold: {best_threshold_based_on_F1}\t Best dev F1: {best_dev_F1}"
    )
    # Save the best dev threshold and dev_F1 in results dict
    results["best_dev_threshold"] = best_threshold_based_on_F1
    results["best_dev_F1"] = best_dev_F1
    results["dev_t_F1_P_Rs"] = dev_t_F1_P_Rs

    # Test (sklearn metrics use lr.predict, i.e. the estimator's default
    # decision rule, not the tuned dev threshold)
    logging.info("Testing the trained classifier")
    predictions = lr.predict(test_X)
    probs = lr.predict_proba(test_X)
    test_Y_prediction_probs = probs[:, 1]  # probability of the positive class
    cm = metrics.confusion_matrix(test_Y, predictions)
    classification_report = metrics.classification_report(test_Y,
                                                          predictions,
                                                          output_dict=True)
    logging.info(cm)
    logging.info(metrics.classification_report(test_Y, predictions))
    # Store as list of lists instead of numpy.ndarray so it is JSON-serializable
    results["CM"] = cm.tolist()
    results["Classification Report"] = classification_report

    # SQuAD style EM and F1 evaluation for all test cases and for positive test cases (i.e. for cases where annotators had a gold annotation)
    EM_score, F1_score, total = get_raw_scores(test_data,
                                               test_Y_prediction_probs)
    logging.info("Word overlap based SQuAD evaluation style metrics:")
    logging.info(f"Total number of cases: {total}")
    logging.info(f"EM_score: {EM_score}")
    logging.info(f"F1_score: {F1_score}")
    results["SQuAD_EM"] = EM_score
    results["SQuAD_F1"] = F1_score
    results["SQuAD_total"] = total
    pos_EM_score, pos_F1_score, pos_total = get_raw_scores(
        test_data, test_Y_prediction_probs, positive_only=True)
    logging.info(f"Total number of Positive cases: {pos_total}")
    logging.info(f"Pos. EM_score: {pos_EM_score}")
    logging.info(f"Pos. F1_score: {pos_F1_score}")
    results["SQuAD_Pos. EM"] = pos_EM_score
    results["SQuAD_Pos. F1"] = pos_F1_score
    results["SQuAD_Pos. EM_F1_total"] = pos_total

    # New evaluation suggested by Alan, at the threshold chosen on dev
    F1, P, R, TP, FP, FN = get_TP_FP_FN(test_data,
                                        test_Y_prediction_probs,
                                        THRESHOLD=best_threshold_based_on_F1)
    logging.info("New evaluation scores:")
    logging.info(f"F1: {F1}")
    logging.info(f"Precision: {P}")
    logging.info(f"Recall: {R}")
    logging.info(f"True Positive: {TP}")
    logging.info(f"False Positive: {FP}")
    logging.info(f"False Negative: {FN}")
    results["F1"] = F1
    results["P"] = P
    results["R"] = R
    results["TP"] = TP
    results["FP"] = FP
    results["FN"] = FN
    N = TP + FN
    results["N"] = N

    # Top predictions in the Test case. K is clamped to the number of test
    # instances so a small test split does not raise IndexError.
    sorted_prediction_ids = np.argsort(-test_Y_prediction_probs)
    K = min(30, len(sorted_prediction_ids))
    logging.info("Top {} predictions:".format(K))
    for instance_id in sorted_prediction_ids[:K]:
        instance = test_data[instance_id]
        # text :: candidate_chunk :: candidate_chunk_id :: chunk_start_text_id :: chunk_end_text_id :: tokenized_tweet :: tokenized_tweet_with_masked_q_token :: tagged_chunks :: question_label
        list_to_print = [
            # Flatten newlines so each prediction stays on one tab-separated line
            instance[0].replace("\n", " "),
            instance[6],
            instance[1],
            str(test_Y_prediction_probs[instance_id]),
            str(test_Y[instance_id]),
            str(instance[-1]),
            str(instance[-2])
        ]
        logging.info("\t".join(list_to_print))

    # Top feature analysis (positive-class coefficients, highest first).
    # K is clamped in case fewer features than requested were extracted.
    coefs = lr.coef_[0]
    sorted_feature_ids = np.argsort(-coefs)
    K = min(10, len(sorted_feature_ids))
    logging.info("Top {} features:".format(K))
    for feature_id in sorted_feature_ids[:K]:
        logging.info(f"{i2feature[feature_id]}\t{coefs[feature_id]}")

    # Plot the precision recall curve
    save_figure_file = os.path.join(args.output_dir,
                                    "Precision Recall Curve.png")
    logging.info(f"Saving precision recall curve at {save_figure_file}")
    # NOTE(review): plot_precision_recall_curve is deprecated in scikit-learn 1.0
    # and removed in 1.2; PrecisionRecallDisplay.from_estimator replaces it.
    disp = plot_precision_recall_curve(lr, test_X, test_Y)
    disp.ax_.set_title('2-class Precision-Recall curve')
    disp.ax_.figure.savefig(save_figure_file)

    # Save the model and features in pickle file
    model_and_features_save_file = os.path.join(args.output_dir,
                                                "model_and_features.pkl")
    logging.info(
        f"Saving LR model and features at {model_and_features_save_file}")
    save_in_pickle((lr, feature2i, i2feature), model_and_features_save_file)

    # Save model_config and results
    model_config_file = os.path.join(args.output_dir, "model_config.json")
    results_file = os.path.join(args.output_dir, "results.json")
    logging.info(f"Saving model config at {model_config_file}")
    save_in_json(model_config, model_config_file)
    logging.info(f"Saving results at {results_file}")
    save_in_json(results, results_file)
Exemplo n.º 5
0
def main():
    # Read all the data instances
    task_instances_dict, tag_statistics, question_keys_and_tags = load_from_pickle(
        args.data_file)
    data, subtasks_list = get_multitask_instances_for_valid_tasks(
        task_instances_dict, tag_statistics)
    data = add_marker_for_loss_ignore(
        data, 1.0 if args.loss_for_no_consensus else 0.0)

    if args.retrain:
        if args.large_bert:
            model_name = "bert-large-cased"
        elif args.covid_bert:
            model_name = "digitalepidemiologylab/covid-twitter-bert"
        else:
            model_name = "bert-base-cased"

        logging.info("Creating and training the model from '" + model_name +
                     "'")
        # Create the save_directory if not exists
        make_dir_if_not_exists(args.save_directory)

        # Initialize tokenizer and model with pretrained weights
        tokenizer = BertTokenizer.from_pretrained(model_name)
        config = BertConfig.from_pretrained(model_name)
        config.subtasks = subtasks_list
        model = MultiTaskBertForCovidEntityClassification.from_pretrained(
            model_name, config=config)

        # Add new tokens in tokenizer
        new_special_tokens_dict = {
            "additional_special_tokens": ["<E>", "</E>", "<URL>", "@USER"]
        }
        tokenizer.add_special_tokens(new_special_tokens_dict)

        # Add the new embeddings in the weights
        print("Embeddings type:",
              model.bert.embeddings.word_embeddings.weight.data.type())
        print("Embeddings shape:",
              model.bert.embeddings.word_embeddings.weight.data.size())
        embedding_size = model.bert.embeddings.word_embeddings.weight.size(1)
        new_embeddings = torch.FloatTensor(
            len(new_special_tokens_dict["additional_special_tokens"]),
            embedding_size).uniform_(-0.1, 0.1)
        # new_embeddings = torch.FloatTensor(2, embedding_size).uniform_(-0.1, 0.1)
        print("new_embeddings shape:", new_embeddings.size())
        new_embedding_weight = torch.cat(
            (model.bert.embeddings.word_embeddings.weight.data,
             new_embeddings), 0)
        model.bert.embeddings.word_embeddings.weight.data = new_embedding_weight
        print("Embeddings shape:",
              model.bert.embeddings.word_embeddings.weight.data.size())
        # Update model config vocab size
        model.config.vocab_size = model.config.vocab_size + len(
            new_special_tokens_dict["additional_special_tokens"])
    else:
        # Load the tokenizer and model from the save_directory
        tokenizer = BertTokenizer.from_pretrained(args.save_directory)
        model = MultiTaskBertForCovidEntityClassification.from_pretrained(
            args.save_directory)
        # Load from individual state dicts
        for subtask in model.subtasks:
            model.classifiers[subtask].load_state_dict(
                torch.load(
                    os.path.join(args.save_directory,
                                 f"{subtask}_classifier.bin")))
    model.to(device)
    if args.wandb:
        wandb.watch(model)

    # Explicitly move the classifiers to device
    for subtask, classifier in model.classifiers.items():
        classifier.to(device)
    for subtask, classifier in model.context_vectors.items():
        classifier.to(device)

    entity_start_token_id = tokenizer.convert_tokens_to_ids(["<E>"])[0]
    entity_end_token_id = tokenizer.convert_tokens_to_ids(["</E>"])[0]

    logging.info(
        f"Task dataset for task: {args.task} loaded from {args.data_file}.")

    model_config = dict()
    results = dict()

    # Split the data into train, dev and test and shuffle the train segment
    train_data, dev_data = split_multitask_instances_in_train_dev(data)
    random.shuffle(train_data)  # shuffle happens in-place
    logging.info("Train Data:")
    total_train_size, pos_subtasks_train_size, neg_subtasks_train_size = log_multitask_data_statistics(
        train_data, model.subtasks)
    logging.info("Dev Data:")
    total_dev_size, pos_subtasks_dev_size, neg_subtasks_dev_size = log_multitask_data_statistics(
        dev_data, model.subtasks)
    #logging.info("Test Data:")
    #total_test_size, pos_subtasks_test_size, neg_subtasks_test_size = log_multitask_data_statistics(test_data, model.subtasks)
    logging.info("\n")
    model_config["train_data"] = {
        "size": total_train_size,
        "pos": pos_subtasks_train_size,
        "neg": neg_subtasks_train_size
    }
    model_config["dev_data"] = {
        "size": total_dev_size,
        "pos": pos_subtasks_dev_size,
        "neg": neg_subtasks_dev_size
    }
    #model_config["test_data"] = {"size":total_test_size, "pos":pos_subtasks_test_size, "neg":neg_subtasks_test_size}

    # Extract subtasks data for dev and test
    train_subtasks_data = split_data_based_on_subtasks(train_data,
                                                       model.subtasks)
    dev_subtasks_data = split_data_based_on_subtasks(dev_data, model.subtasks)
    #test_subtasks_data = split_data_based_on_subtasks(test_data, model.subtasks)

    # Load the instances into pytorch dataset
    train_dataset = COVID19TaskDataset(train_data)
    dev_dataset = COVID19TaskDataset(dev_data)
    #test_dataset = COVID19TaskDataset(test_data)
    logging.info("Loaded the datasets into Pytorch datasets")

    tokenize_collator = TokenizeCollator(tokenizer, model.subtasks,
                                         entity_start_token_id,
                                         entity_end_token_id)
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=POSSIBLE_BATCH_SIZE,
                                  shuffle=True,
                                  num_workers=0,
                                  collate_fn=tokenize_collator)
    dev_dataloader = DataLoader(dev_dataset,
                                batch_size=POSSIBLE_BATCH_SIZE,
                                shuffle=False,
                                num_workers=0,
                                collate_fn=tokenize_collator)
    #test_dataloader = DataLoader(test_dataset, batch_size=POSSIBLE_BATCH_SIZE, shuffle=False, num_workers=0, collate_fn=tokenize_collator)
    logging.info("Created train and test dataloaders with batch aggregation")

    # Only retrain if needed
    if args.retrain:
        optimizer = AdamW(model.parameters(), lr=2e-5, eps=1e-8)
        logging.info("Created model optimizer")
        #if args.sentence_level_classify:
        #    args.n_epochs += 2
        epochs = args.n_epochs

        # Total number of training steps is [number of batches] x [number of epochs].
        total_steps = len(train_dataloader) * epochs

        # Create the learning rate scheduler.
        # NOTE: num_warmup_steps = 0 is the Default value in run_glue.py
        scheduler = get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=0, num_training_steps=total_steps)
        # We'll store a number of quantities such as training and validation loss, validation accuracy, and timings.
        training_stats = []
        print("\n\n\n ====== Training for task", args.task,
              "=============\n\n\n")
        logging.info(f"Initiating training loop for {args.n_epochs} epochs...")
        print(model.state_dict().keys())

        total_start_time = time.time()

        # Find the accumulation steps
        accumulation_steps = args.batch_size / POSSIBLE_BATCH_SIZE

        # Dev validation trajectory
        epoch_train_loss = list()
        train_subtasks_validation_statistics = {
            subtask: list()
            for subtask in model.subtasks
        }
        dev_subtasks_validation_statistics = {
            subtask: list()
            for subtask in model.subtasks
        }
        best_dev_F1 = 0
        for epoch in range(epochs):

            logging.info(f"Initiating Epoch {epoch+1}:")

            # Reset the total loss for each epoch.
            total_train_loss = 0
            train_loss_trajectory = list()

            # Reset timer for each epoch
            start_time = time.time()
            model.train()

            dev_log_frequency = 5
            n_steps = len(train_dataloader)
            dev_steps = int(n_steps / dev_log_frequency)
            for step, batch in enumerate(train_dataloader):
                # Upload labels of each subtask to device
                for subtask in model.subtasks:
                    subtask_labels = batch["gold_labels"][subtask]
                    subtask_labels = subtask_labels.to(device)
                    batch["gold_labels"][subtask] = subtask_labels
                    batch["label_ignore_loss"][subtask] = batch[
                        "label_ignore_loss"][subtask].to(device)

                # Forward
                input_dict = {
                    "input_ids":
                    batch["input_ids"].to(device),
                    "entity_start_positions":
                    batch["entity_start_positions"].to(device),
                    "entity_end_positions":
                    batch["entity_end_positions"].to(device),
                    "labels":
                    batch["gold_labels"],
                    "label_weight":
                    batch["label_ignore_loss"]
                }

                input_ids = batch["input_ids"]
                entity_start_positions = batch["entity_start_positions"]
                gold_labels = batch["gold_labels"]
                batch_data = batch["batch_data"]
                loss, logits = model(**input_dict)

                # Accumulate loss
                total_train_loss += loss.item()

                # Backward: compute gradients
                loss.backward()

                if (step + 1) % accumulation_steps == 0:
                    # Calculate elapsed time in minutes and print loss on the tqdm bar
                    elapsed = format_time(time.time() - start_time)
                    avg_train_loss = total_train_loss / (step + 1)

                    # keep track of changing avg_train_loss
                    train_loss_trajectory.append(avg_train_loss)
                    if (step + 1) % (accumulation_steps * 20) == 0:
                        print(
                            f"Epoch:{epoch+1}|Batch:{step}/{len(train_dataloader)}|Time:{elapsed}|Avg. Loss:{avg_train_loss:.4f}|Loss:{loss.item():.4f}"
                        )

                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    optimizer.step()

                    # Clean the model's previous gradients
                    model.zero_grad()
                    scheduler.step()

            # Calculate the average loss over all of the batches.
            avg_train_loss = total_train_loss / len(train_dataloader)

            # Perform validation with the model and log the performance
            print("\n")
            logging.info("Running Validation...")
            # Put the model in evaluation mode--the dropout layers behave differently during evaluation.
            model.eval()
            dev_predicted_labels, dev_prediction_scores, dev_gold_labels = make_predictions_on_dataset(
                dev_dataloader, model, device, args.task + "_dev", True)

            wandb_log_dict = {"Train Loss": avg_train_loss}
            print("Dev Set:")
            collect_TP_FP_FN = {"TP": 0, "FP": 0, "FN": 0}
            for subtask in model.subtasks:
                dev_subtask_data = dev_subtasks_data[subtask]
                dev_subtask_prediction_scores = dev_prediction_scores[subtask]
                dev_F1, dev_P, dev_R, dev_TP, dev_FP, dev_FN = get_TP_FP_FN(
                    dev_subtask_data,
                    dev_subtask_prediction_scores,
                    task=subtask)
                if subtask not in IGNORE_TASKS:
                    collect_TP_FP_FN["TP"] += dev_TP
                    collect_TP_FP_FN["FP"] += dev_FP
                    collect_TP_FP_FN["FN"] += dev_FN
                else:
                    print("IGNORE: ", end="")

                print(
                    f"Subtask:{subtask:>15}\tN={dev_TP + dev_FN}\tF1={dev_F1}\tP={dev_P}\tR={dev_R}\tTP={dev_TP}\tFP={dev_FP}\tFN={dev_FN}"
                )
                dev_subtasks_validation_statistics[subtask].append(
                    (epoch + 1, step + 1, dev_TP + dev_FN, dev_F1, dev_P,
                     dev_R, dev_TP, dev_FP, dev_FN))

                wandb_log_dict["Dev_ " + subtask + "_F1"] = dev_F1
                wandb_log_dict["Dev_ " + subtask + "_P"] = dev_P
                wandb_log_dict["Dev_ " + subtask + "_R"] = dev_R

            dev_macro_P = collect_TP_FP_FN["TP"] / (collect_TP_FP_FN["TP"] +
                                                    collect_TP_FP_FN["FP"])
            dev_macro_R = collect_TP_FP_FN["TP"] / (collect_TP_FP_FN["TP"] +
                                                    collect_TP_FP_FN["FN"])
            dev_macro_F1 = (2 * dev_macro_P * dev_macro_R) / (dev_macro_P +
                                                              dev_macro_R)
            print(collect_TP_FP_FN)
            print("dev_macro_P:", dev_macro_P, "\ndev_macro_R:", dev_macro_R,
                  "\ndev_macro_F1:", dev_macro_F1, "\n")
            wandb_log_dict["Dev_macro_F1"] = dev_macro_F1
            wandb_log_dict["Dev_macro_P"] = dev_macro_P
            wandb_log_dict["Dev_macro_R"] = dev_macro_R

            if args.wandb:
                wandb.log(wandb_log_dict)

            if dev_macro_F1 > best_dev_F1:
                best_dev_F1 = dev_macro_F1
                print("NEW BEST F1:", best_dev_F1, " Saving checkpoint now.")
                torch.save(model.state_dict(), args.output_dir + "/ckpt.pth")
                #print(model.state_dict().keys())
                #model.save_pretrained(args.save_directory)
            model.train()

            training_time = format_time(time.time() - start_time)

            # Record all statistics from this epoch.
            training_stats.append({
                'epoch': epoch + 1,
                'Training Loss': avg_train_loss,
                'Training Time': training_time
            })

            # Save the loss trajectory
            epoch_train_loss.append(train_loss_trajectory)
            print("\n\n")

        logging.info(
            f"Training complete with total Train time:{format_time(time.time()- total_start_time)}"
        )
        log_list(training_stats)

        model.load_state_dict(torch.load(args.output_dir + "/ckpt.pth"))
        model.eval()
        # Save the model and the Tokenizer here:
        #logging.info(f"Saving the model and tokenizer in {args.save_directory}")
        #model.save_pretrained(args.save_directory)

        # Save each subtask classifiers weights to individual state dicts
        #for subtask, classifier in model.classifiers.items():
        #    classifier_save_file = os.path.join(args.save_directory, f"{subtask}_classifier.bin")
        #    logging.info(f"Saving the model's {subtask} classifier weights at {classifier_save_file}")
        #    torch.save(classifier.state_dict(), classifier_save_file)
        #tokenizer.save_pretrained(args.save_directory)

        # Plot the train loss trajectory in a plot
        #train_loss_trajectory_plot_file = os.path.join(args.output_dir, "train_loss_trajectory.png")
        #logging.info(f"Saving the Train loss trajectory at {train_loss_trajectory_plot_file}")
        #print(epoch_train_loss)

        # TODO: Plot the validation performance
        # Save dev_subtasks_validation_statistics
    else:
        raise
        logging.info("No training needed. Directly going to evaluation!")

    # Save the model name in the model_config file
    model_config["model"] = "MultiTaskBertForCovidEntityClassification"
    model_config["epochs"] = args.n_epochs

    # Find best threshold for each subtask based on dev set performance
    thresholds = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
    #test_predicted_labels, test_prediction_scores, test_gold_labels = make_predictions_on_dataset(test_dataloader, model, device, args.task, True)
    dev_predicted_labels, dev_prediction_scores, dev_gold_labels = make_predictions_on_dataset(
        dev_dataloader, model, device, args.task + "_dev", True)

    best_test_thresholds = {subtask: 0.5 for subtask in model.subtasks}
    best_dev_thresholds = {subtask: 0.5 for subtask in model.subtasks}
    best_test_F1s = {subtask: 0.0 for subtask in model.subtasks}
    best_dev_F1s = {subtask: 0.0 for subtask in model.subtasks}
    #test_subtasks_t_F1_P_Rs = {subtask: list() for subtask in model.subtasks}
    dev_subtasks_t_F1_P_Rs = {subtask: list() for subtask in model.subtasks}

    for subtask in model.subtasks:
        dev_subtask_data = dev_subtasks_data[subtask]
        dev_subtask_prediction_scores = dev_prediction_scores[subtask]
        for t in thresholds:
            dev_F1, dev_P, dev_R, dev_TP, dev_FP, dev_FN = get_TP_FP_FN(
                dev_subtask_data,
                dev_subtask_prediction_scores,
                THRESHOLD=t,
                task=subtask)
            dev_subtasks_t_F1_P_Rs[subtask].append(
                (t, dev_F1, dev_P, dev_R, dev_TP + dev_FN, dev_TP, dev_FP,
                 dev_FN))
            if dev_F1 > best_dev_F1s[subtask]:
                best_dev_thresholds[subtask] = t
                best_dev_F1s[subtask] = dev_F1

        logging.info(f"Subtask:{subtask:>15}")
        log_list(dev_subtasks_t_F1_P_Rs[subtask])
        logging.info(
            f"Best Dev Threshold for subtask: {best_dev_thresholds[subtask]}\t Best dev F1: {best_dev_F1s[subtask]}"
        )

    # Save the best dev threshold and dev_F1 in results dict
    results["best_dev_threshold"] = best_dev_thresholds
    results["best_dev_F1s"] = best_dev_F1s
    results["dev_t_F1_P_Rs"] = dev_subtasks_t_F1_P_Rs

    # Evaluate on Test
    logging.info("Testing on eval dataset")
    predicted_labels, prediction_scores, gold_labels = make_predictions_on_dataset(
        dev_dataloader, model, device, args.task)

    # Test
    for subtask in model.subtasks:
        logging.info(f"\nTesting the trained classifier on subtask: {subtask}")

        results[subtask] = dict()
        cm = metrics.confusion_matrix(gold_labels[subtask],
                                      predicted_labels[subtask])
        classification_report = metrics.classification_report(
            gold_labels[subtask], predicted_labels[subtask], output_dict=True)
        logging.info(cm)
        logging.info(
            metrics.classification_report(gold_labels[subtask],
                                          predicted_labels[subtask]))
        results[subtask]["CM"] = cm.tolist(
        )  # Storing it as list of lists instead of numpy.ndarray
        results[subtask]["Classification Report"] = classification_report

        # SQuAD style EM and F1 evaluation for all test cases and for positive test cases (i.e. for cases where annotators had a gold annotation)
        EM_score, F1_score, total = get_raw_scores(dev_subtasks_data[subtask],
                                                   prediction_scores[subtask])
        logging.info("Word overlap based SQuAD evaluation style metrics:")
        logging.info(f"Total number of cases: {total}")
        logging.info(f"EM_score: {EM_score}")
        logging.info(f"F1_score: {F1_score}")
        results[subtask]["SQuAD_EM"] = EM_score
        results[subtask]["SQuAD_F1"] = F1_score
        results[subtask]["SQuAD_total"] = total
        pos_EM_score, pos_F1_score, pos_total = get_raw_scores(
            dev_subtasks_data[subtask],
            prediction_scores[subtask],
            positive_only=True)
        logging.info(f"Total number of Positive cases: {pos_total}")
        logging.info(f"Pos. EM_score: {pos_EM_score}")
        logging.info(f"Pos. F1_score: {pos_F1_score}")
        results[subtask]["SQuAD_Pos. EM"] = pos_EM_score
        results[subtask]["SQuAD_Pos. F1"] = pos_F1_score
        results[subtask]["SQuAD_Pos. EM_F1_total"] = pos_total

        # New evaluation suggested by Alan
        F1, P, R, TP, FP, FN = get_TP_FP_FN(
            dev_subtasks_data[subtask],
            prediction_scores[subtask],
            THRESHOLD=best_dev_thresholds[subtask],
            task=subtask)
        logging.info("New evaluation scores:")
        logging.info(f"F1: {F1}")
        logging.info(f"Precision: {P}")
        logging.info(f"Recall: {R}")
        logging.info(f"True Positive: {TP}")
        logging.info(f"False Positive: {FP}")
        logging.info(f"False Negative: {FN}")
        results[subtask]["F1"] = F1
        results[subtask]["P"] = P
        results[subtask]["R"] = R
        results[subtask]["TP"] = TP
        results[subtask]["FP"] = FP
        results[subtask]["FN"] = FN
        N = TP + FN
        results[subtask]["N"] = N

    # Save model_config and results
    model_config_file = os.path.join(args.output_dir, "model_config.json")
    results_file = os.path.join(args.output_dir, "results.json")
    logging.info(f"Saving model config at {model_config_file}")
    save_in_json(model_config, model_config_file)
    logging.info(f"Saving results at {results_file}")
    save_in_json(results, results_file)