Exemplo n.º 1
0
def main():
    """Load the trained video/audio/AV networks and visualise the CAM of one clip.

    The three checkpoints ``vmodel_final.pt``, ``amodel_final.pt`` and
    ``avmodel_final.pt`` are loaded from the working directory, ``val`` is run
    on a loader holding a single clip, and the resulting activation map is
    passed to ``findcam`` together with the raw frames of that clip.

    Relies on the module-level ``device`` and ``batch_size``.
    """
    # initialise the models
    vmodel = VideoNet().to(device)
    amodel = AudioNet().to(device)
    avmodel = AVNet().to(device)
    # map_location remaps the stored tensors onto `device`, so checkpoints
    # written on a GPU machine can also be loaded on a CPU-only host.
    vmodel.load_state_dict(torch.load('vmodel_final.pt', map_location=device))
    amodel.load_state_dict(torch.load('amodel_final.pt', map_location=device))
    avmodel.load_state_dict(torch.load('avmodel_final.pt', map_location=device))
    print('loaded model')

    list_vid = os.listdir('data/train/full_vid')  # ensure no extra files like .DS_Store are present
    train_list, val_list = utils.split_data(list_vid, 0.8, 0.2)
    # log the list for reference
    # NOTE(review): these two calls rewrite the list files on every run of this
    # visualisation script even though the split is not used below -- confirm
    # that overwriting them is intended.
    utils.log_list(train_list, 'data/train_list.txt')
    utils.log_list(val_list, 'data/val_list.txt')
    # uncomment following to read previous list
    # train_list = utils.read_list('data/train_list.txt')
    # val_list = utils.read_list('data/val_list.txt')

    # Only this clip is visualised; the split above is deliberately overridden.
    train_list = ['video_001.mp4']
    clip_name = train_list[0]

    composed = transforms.Compose([Resize(256), RandomCrop(224)])
    # composed = transforms.Compose([Resize(256)])
    val_loader = torch.utils.data.DataLoader(
        AVDataset(train_list[:1], transform=composed),
        batch_size=batch_size,
        shuffle=False,
        num_workers=4)

    # The first value returned by `val` is not needed for the visualisation.
    _, p, cam = val(vmodel, amodel, avmodel, val_loader)
    print(p, cam.shape)

    # Imported lazily so scikit-video is only required when this script runs.
    import skvideo.io
    vids = skvideo.io.vread(os.path.join('data/train', 'snippet', clip_name))
    # print('vids',vids)
    # findcam receives the frames with an extra leading axis and the absolute
    # value of the activation map as a NumPy array.
    findcam(np.expand_dims(vids, 0), np.abs(cam.cpu().numpy()))
Exemplo n.º 2
0
def main():
    """Train the video, audio and audio-visual networks jointly.

    Builds the three networks on ``device``, splits the videos found in
    ``data/train/full_vid`` 80/20 into train and validation lists (both are
    written to disk for reference), wraps them in data loaders and hands
    everything to ``train``.

    Relies on the module-level ``device``, ``LR``, ``batch_size``,
    ``test_batch_size`` and ``nepochs``.
    """
    # Build the three networks on the target device.
    vmodel = VideoNet().to(device)
    amodel = AudioNet().to(device)
    avmodel = AVNet().to(device)
    # To warm-start from pretrained weights instead, load the state dicts here:
    #     vmodel.load_state_dict(torch.load('./pretrained/tfvmodel.pt'))
    #     amodel.load_state_dict(torch.load('./pretrained/tfamodel.pt'))
    #     avmodel.load_state_dict(torch.load('./pretrained/tfavmodel.pt'))
    #     print('loaded model')

    # A single optimiser updates the parameters of all three networks.
    trainable = []
    for net in (vmodel, amodel, avmodel):
        trainable.extend(net.parameters())
    optimiser = optim.Adam(trainable, lr=LR)

    # The directory must hold videos only (no stray files such as .DS_Store).
    videos = os.listdir('data/train/full_vid')
    train_list, val_list = utils.split_data(videos, 0.8, 0.2)

    # Keep a record of the split so a run can be reproduced later.
    for names, log_path in ((train_list, 'data/train_list.txt'),
                            (val_list, 'data/val_list.txt')):
        utils.log_list(names, log_path)
    # To reuse a previously logged split instead:
    #     train_list = utils.read_list('data/train_list.txt')
    #     val_list = utils.read_list('data/val_list.txt')

    # The same transform pipeline is shared by both datasets.
    composed = transforms.Compose([Resize(256), RandomCrop(224)])
    train_set = AVDataset(train_list, transform=composed)
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=batch_size, shuffle=True, num_workers=6)
    val_set = AVDataset(val_list, transform=composed)
    val_loader = torch.utils.data.DataLoader(
        val_set, batch_size=test_batch_size, shuffle=True, num_workers=6)

    train(vmodel, amodel, avmodel, optimiser, nepochs, train_loader, val_loader)
def main():
    """Optionally fine-tune, then evaluate, the multi-task BERT entity classifier.

    Steps, in order:

    1. Load the pickled task instances from ``args.data_file`` and build the
       multi-task instances and the list of subtasks.
    2. If ``args.retrain`` is set, create a fresh ``bert-base-cased`` model with
       four extra special tokens; otherwise reload the tokenizer, the model and
       the per-subtask classifier heads from ``args.save_directory``.
    3. Split the data into train/dev/test, log their statistics and wrap them
       in dataloaders.
    4. If ``args.retrain`` is set, fine-tune with gradient accumulation,
       validate on dev several times per epoch, then save the model, the
       classifier heads, the tokenizer and the loss-trajectory plot.
    5. Pick the best decision threshold per subtask on dev, evaluate on test
       and write ``model_config.json`` and ``results.json`` to
       ``args.output_dir``.

    Relies on the module-level ``args``, ``device`` and ``POSSIBLE_BATCH_SIZE``.
    """
    # Read all the data instances
    task_instances_dict, tag_statistics, question_keys_and_tags = load_from_pickle(
        args.data_file)
    data, subtasks_list = get_multitask_instances_for_valid_tasks(
        task_instances_dict, tag_statistics)

    if args.retrain:
        logging.info("Creating and training the model from 'bert-base-cased' ")
        # Create the save_directory if not exists
        make_dir_if_not_exists(args.save_directory)

        # Initialize tokenizer and model with pretrained weights
        tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
        config = BertConfig.from_pretrained('bert-base-cased')
        config.subtasks = subtasks_list
        model = MultiTaskBertForCovidEntityClassification.from_pretrained(
            'bert-base-cased', config=config)

        # Add new tokens in tokenizer
        new_special_tokens_dict = {
            "additional_special_tokens": ["<E>", "</E>", "<URL>", "@USER"]
        }
        tokenizer.add_special_tokens(new_special_tokens_dict)

        # Add the new embeddings in the weights: one freshly initialised row
        # per added special token is appended to the word-embedding matrix.
        # NOTE(review): transformers also offers
        # `model.resize_token_embeddings(len(tokenizer))` for this; the manual
        # approach is kept so the uniform(-0.1, 0.1) initialisation is unchanged.
        print("Embeddings type:",
              model.bert.embeddings.word_embeddings.weight.data.type())
        print("Embeddings shape:",
              model.bert.embeddings.word_embeddings.weight.data.size())
        embedding_size = model.bert.embeddings.word_embeddings.weight.size(1)
        new_embeddings = torch.FloatTensor(
            len(new_special_tokens_dict["additional_special_tokens"]),
            embedding_size).uniform_(-0.1, 0.1)
        print("new_embeddings shape:", new_embeddings.size())
        new_embedding_weight = torch.cat(
            (model.bert.embeddings.word_embeddings.weight.data,
             new_embeddings), 0)
        model.bert.embeddings.word_embeddings.weight.data = new_embedding_weight
        print("Embeddings shape:",
              model.bert.embeddings.word_embeddings.weight.data.size())
        # Update model config vocab size
        model.config.vocab_size = model.config.vocab_size + len(
            new_special_tokens_dict["additional_special_tokens"])
    else:
        # Load the tokenizer and model from the save_directory
        tokenizer = BertTokenizer.from_pretrained(args.save_directory)
        model = MultiTaskBertForCovidEntityClassification.from_pretrained(
            args.save_directory)
        # TODO save and load the subtask classifier weights separately
        # Load from individual state dicts. map_location remaps the stored
        # tensors onto `device`, so heads saved on a GPU also load on CPU.
        for subtask in model.subtasks:
            model.classifiers[subtask].load_state_dict(
                torch.load(
                    os.path.join(args.save_directory,
                                 f"{subtask}_classifier.bin"),
                    map_location=device))
    model.to(device)
    # Explicitly move the classifiers to device
    for classifier in model.classifiers.values():
        classifier.to(device)
    entity_start_token_id = tokenizer.convert_tokens_to_ids(["<E>"])[0]

    logging.info(
        f"Task dataset for task: {args.task} loaded from {args.data_file}.")

    model_config = dict()
    results = dict()

    # Split the data into train, dev and test and shuffle the train segment
    train_data, dev_data, test_data = split_multitask_instances_in_train_dev_test(
        data)
    random.shuffle(train_data)  # shuffle happens in-place
    logging.info("Train Data:")
    total_train_size, pos_subtasks_train_size, neg_subtasks_train_size = log_multitask_data_statistics(
        train_data, model.subtasks)
    logging.info("Dev Data:")
    total_dev_size, pos_subtasks_dev_size, neg_subtasks_dev_size = log_multitask_data_statistics(
        dev_data, model.subtasks)
    logging.info("Test Data:")
    total_test_size, pos_subtasks_test_size, neg_subtasks_test_size = log_multitask_data_statistics(
        test_data, model.subtasks)
    logging.info("\n")
    model_config["train_data"] = {
        "size": total_train_size,
        "pos": pos_subtasks_train_size,
        "neg": neg_subtasks_train_size
    }
    model_config["dev_data"] = {
        "size": total_dev_size,
        "pos": pos_subtasks_dev_size,
        "neg": neg_subtasks_dev_size
    }
    model_config["test_data"] = {
        "size": total_test_size,
        "pos": pos_subtasks_test_size,
        "neg": neg_subtasks_test_size
    }

    # Extract subtasks data for dev and test
    dev_subtasks_data = split_data_based_on_subtasks(dev_data, model.subtasks)
    test_subtasks_data = split_data_based_on_subtasks(test_data,
                                                      model.subtasks)

    # Load the instances into pytorch dataset
    train_dataset = COVID19TaskDataset(train_data)
    dev_dataset = COVID19TaskDataset(dev_data)
    test_dataset = COVID19TaskDataset(test_data)
    logging.info("Loaded the datasets into Pytorch datasets")

    tokenize_collator = TokenizeCollator(tokenizer, model.subtasks,
                                         entity_start_token_id)
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=POSSIBLE_BATCH_SIZE,
                                  shuffle=True,
                                  num_workers=0,
                                  collate_fn=tokenize_collator)
    dev_dataloader = DataLoader(dev_dataset,
                                batch_size=POSSIBLE_BATCH_SIZE,
                                shuffle=False,
                                num_workers=0,
                                collate_fn=tokenize_collator)
    test_dataloader = DataLoader(test_dataset,
                                 batch_size=POSSIBLE_BATCH_SIZE,
                                 shuffle=False,
                                 num_workers=0,
                                 collate_fn=tokenize_collator)
    logging.info("Created train and test dataloaders with batch aggregation")

    # Only retrain if needed
    if args.retrain:
        print('DO RETRAIN')
        ##################################################################################################
        # NOTE: Training Tutorial Reference
        # https://mccormickml.com/2019/07/22/BERT-fine-tuning/#41-bertforsequenceclassification
        ##################################################################################################

        # Create an optimizer training schedule for the BERT text classification model
        # Recommended Schedule for BERT fine-tuning as per the paper
        # Batch size: 16, 32
        # Learning rate (Adam): 5e-5, 3e-5, 2e-5
        # Number of epochs: 2, 3, 4
        optimizer = AdamW(model.parameters(), lr=2e-5, eps=1e-8)
        logging.info("Created model optimizer")
        epochs = args.n_epochs

        # Number of mini-batches per epoch.
        n_steps = len(train_dataloader)

        # Number of mini-batches whose gradients are accumulated before one
        # optimizer update. Kept an integer >= 1 so the modulo test below is
        # well defined even when args.batch_size is not an exact multiple of
        # (or is smaller than) POSSIBLE_BATCH_SIZE.
        accumulation_steps = max(1, args.batch_size // POSSIBLE_BATCH_SIZE)

        # The scheduler is stepped once per optimizer update (not once per
        # mini-batch), so its horizon must count updates; otherwise the linear
        # decay never reaches its end when gradients are accumulated.
        updates_per_epoch = (n_steps + accumulation_steps -
                             1) // accumulation_steps
        total_steps = updates_per_epoch * epochs

        # Create the learning rate scheduler.
        # NOTE: num_warmup_steps = 0 is the Default value in run_glue.py
        scheduler = get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=0, num_training_steps=total_steps)
        # Per-epoch summary statistics (average loss and elapsed time).
        training_stats = []

        logging.info(f"Initiating training loop for {args.n_epochs} epochs...")
        # Measure the total training time for the whole run.
        total_start_time = time.time()

        # Validate on dev roughly `dev_log_frequency` times per epoch. The
        # interval is clamped to >= 1 so that an epoch with fewer batches than
        # `dev_log_frequency` does not trigger a modulo-by-zero.
        dev_log_frequency = 5
        dev_steps = max(1, n_steps // dev_log_frequency)

        # Loss trajectory for epochs
        epoch_train_loss = list()
        # Dev validation trajectory
        dev_subtasks_validation_statistics = {
            subtask: list()
            for subtask in model.subtasks
        }
        for epoch in range(epochs):
            pbar = tqdm(train_dataloader)
            logging.info(f"Initiating Epoch {epoch+1}:")
            # Reset the total loss for each epoch.
            total_train_loss = 0
            train_loss_trajectory = list()

            # Reset timer for each epoch
            start_time = time.time()
            model.train()

            for step, batch in enumerate(pbar):
                # Upload labels of each subtask to device
                for subtask in model.subtasks:
                    batch["gold_labels"][subtask] = batch["gold_labels"][
                        subtask].to(device)
                # Forward
                input_dict = {
                    "input_ids":
                    batch["input_ids"].to(device),
                    "entity_start_positions":
                    batch["entity_start_positions"].to(device),
                    "labels":
                    batch["gold_labels"]
                }
                loss, logits = model(**input_dict)
                # loss = loss / accumulation_steps
                # Accumulate loss
                total_train_loss += loss.item()

                # Backward: compute gradients
                loss.backward()

                # Update once per accumulation window, and also on the last
                # batch of the epoch so that a trailing partial window is
                # applied instead of leaking into the next epoch.
                is_last_batch = (step + 1) == n_steps
                if (step + 1) % accumulation_steps == 0 or is_last_batch:

                    # Calculate elapsed time in minutes and print loss on the tqdm bar
                    elapsed = format_time(time.time() - start_time)
                    avg_train_loss = total_train_loss / (step + 1)
                    # keep track of changing avg_train_loss
                    train_loss_trajectory.append(avg_train_loss)
                    pbar.set_description(
                        f"Epoch:{epoch+1}|Batch:{step}/{len(train_dataloader)}|Time:{elapsed}|Avg. Loss:{avg_train_loss:.4f}|Loss:{loss.item():.4f}"
                    )

                    # Clip the norm of the gradients to 1.0.
                    # This is to help prevent the "exploding gradients" problem.
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

                    # Update parameters
                    optimizer.step()

                    # Clean the model's previous gradients
                    model.zero_grad()  # Reset gradients tensors

                    # Update the learning rate.
                    scheduler.step()
                    # No manual pbar.update(): iterating over `pbar` already
                    # advances the bar once per batch.
                if (step + 1) % dev_steps == 0:
                    # Perform validation with the model and log the performance
                    logging.info("Running Validation...")
                    # Put the model in evaluation mode--the dropout layers behave differently
                    # during evaluation.
                    model.eval()
                    dev_predicted_labels, dev_prediction_scores, dev_gold_labels = make_predictions_on_dataset(
                        dev_dataloader, model, device, args.task + "_dev",
                        True)
                    for subtask in model.subtasks:
                        dev_subtask_data = dev_subtasks_data[subtask]
                        dev_subtask_prediction_scores = dev_prediction_scores[
                            subtask]
                        dev_F1, dev_P, dev_R, dev_TP, dev_FP, dev_FN = get_TP_FP_FN(
                            dev_subtask_data, dev_subtask_prediction_scores)
                        logging.info(
                            f"Subtask:{subtask:>15}\tN={dev_TP + dev_FN}\tF1={dev_F1}\tP={dev_P}\tR={dev_R}\tTP={dev_TP}\tFP={dev_FP}\tFN={dev_FN}"
                        )
                        dev_subtasks_validation_statistics[subtask].append(
                            (epoch + 1, step + 1, dev_TP + dev_FN, dev_F1,
                             dev_P, dev_R, dev_TP, dev_FP, dev_FN))

                    # Put the model back in train setting
                    model.train()

            # Calculate the average loss over all of the batches (guarding
            # against an empty training dataloader).
            avg_train_loss = total_train_loss / max(1, n_steps)

            training_time = format_time(time.time() - start_time)

            # Record all statistics from this epoch.
            training_stats.append({
                'epoch': epoch + 1,
                'Training Loss': avg_train_loss,
                'Training Time': training_time
            })

            # Save the loss trajectory
            epoch_train_loss.append(train_loss_trajectory)

        logging.info(
            f"Training complete with total Train time:{format_time(time.time()- total_start_time)}"
        )
        log_list(training_stats)

        # Save the model and the Tokenizer here:
        logging.info(
            f"Saving the model and tokenizer in {args.save_directory}")
        model.save_pretrained(args.save_directory)
        # Save each subtask classifiers weights to individual state dicts
        for subtask, classifier in model.classifiers.items():
            classifier_save_file = os.path.join(args.save_directory,
                                                f"{subtask}_classifier.bin")
            logging.info(
                f"Saving the model's {subtask} classifier weights at {classifier_save_file}"
            )
            torch.save(classifier.state_dict(), classifier_save_file)
        tokenizer.save_pretrained(args.save_directory)

        # Plot the train loss trajectory in a plot
        train_loss_trajectory_plot_file = os.path.join(
            args.output_dir, "train_loss_trajectory.png")
        logging.info(
            f"Saving the Train loss trajectory at {train_loss_trajectory_plot_file}"
        )
        plot_train_loss(epoch_train_loss, train_loss_trajectory_plot_file)

        # TODO: Plot the validation performance
        # Save dev_subtasks_validation_statistics
    else:
        logging.info("No training needed. Directly going to evaluation!")

    # Save the model name in the model_config file
    model_config["model"] = "MultiTaskBertForCovidEntityClassification"
    model_config["epochs"] = args.n_epochs

    # The training loop leaves the model in train mode; switch to evaluation
    # mode so the final predictions are made the same way as the in-training
    # validation above (dropout behaves differently during evaluation).
    model.eval()

    # Find best threshold for each subtask based on dev set performance
    thresholds = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
    # NOTE(review): the results of this test-set pass are not used below; the
    # call is kept in case make_predictions_on_dataset has side effects.
    test_predicted_labels, test_prediction_scores, test_gold_labels = make_predictions_on_dataset(
        test_dataloader, model, device, args.task, True)
    dev_predicted_labels, dev_prediction_scores, dev_gold_labels = make_predictions_on_dataset(
        dev_dataloader, model, device, args.task + "_dev", True)

    best_dev_thresholds = {subtask: 0.5 for subtask in model.subtasks}
    best_dev_F1s = {subtask: 0.0 for subtask in model.subtasks}
    dev_subtasks_t_F1_P_Rs = {subtask: list() for subtask in model.subtasks}

    for subtask in model.subtasks:
        dev_subtask_data = dev_subtasks_data[subtask]
        dev_subtask_prediction_scores = dev_prediction_scores[subtask]
        for t in thresholds:
            dev_F1, dev_P, dev_R, dev_TP, dev_FP, dev_FN = get_TP_FP_FN(
                dev_subtask_data, dev_subtask_prediction_scores, THRESHOLD=t)
            dev_subtasks_t_F1_P_Rs[subtask].append(
                (t, dev_F1, dev_P, dev_R, dev_TP + dev_FN, dev_TP, dev_FP,
                 dev_FN))
            # Strict '>' keeps the lowest threshold among ties and the 0.5
            # default when no threshold yields a positive F1.
            if dev_F1 > best_dev_F1s[subtask]:
                best_dev_thresholds[subtask] = t
                best_dev_F1s[subtask] = dev_F1

        logging.info(f"Subtask:{subtask:>15}")
        log_list(dev_subtasks_t_F1_P_Rs[subtask])
        logging.info(
            f"Best Dev Threshold for subtask: {best_dev_thresholds[subtask]}\t Best dev F1: {best_dev_F1s[subtask]}"
        )

    # Save the best dev threshold and dev_F1 in results dict
    results["best_dev_threshold"] = best_dev_thresholds
    results["best_dev_F1s"] = best_dev_F1s
    results["dev_t_F1_P_Rs"] = dev_subtasks_t_F1_P_Rs

    # Evaluate on Test
    logging.info("Testing on test dataset")

    predicted_labels, prediction_scores, gold_labels = make_predictions_on_dataset(
        test_dataloader, model, device, args.task)

    # Test
    for subtask in model.subtasks:
        logging.info(f"Testing the trained classifier on subtask: {subtask}")
        results[subtask] = dict()
        cm = metrics.confusion_matrix(gold_labels[subtask],
                                      predicted_labels[subtask])
        classification_report = metrics.classification_report(
            gold_labels[subtask], predicted_labels[subtask], output_dict=True)
        logging.info(cm)
        logging.info(
            metrics.classification_report(gold_labels[subtask],
                                          predicted_labels[subtask]))
        results[subtask]["CM"] = cm.tolist(
        )  # Storing it as list of lists instead of numpy.ndarray
        results[subtask]["Classification Report"] = classification_report

        # SQuAD style EM and F1 evaluation for all test cases and for positive test cases (i.e. for cases where annotators had a gold annotation)
        EM_score, F1_score, total = get_raw_scores(test_subtasks_data[subtask],
                                                   prediction_scores[subtask])
        logging.info("Word overlap based SQuAD evaluation style metrics:")
        logging.info(f"Total number of cases: {total}")
        logging.info(f"EM_score: {EM_score}")
        logging.info(f"F1_score: {F1_score}")
        results[subtask]["SQuAD_EM"] = EM_score
        results[subtask]["SQuAD_F1"] = F1_score
        results[subtask]["SQuAD_total"] = total
        pos_EM_score, pos_F1_score, pos_total = get_raw_scores(
            test_subtasks_data[subtask],
            prediction_scores[subtask],
            positive_only=True)
        logging.info(f"Total number of Positive cases: {pos_total}")
        logging.info(f"Pos. EM_score: {pos_EM_score}")
        logging.info(f"Pos. F1_score: {pos_F1_score}")
        results[subtask]["SQuAD_Pos. EM"] = pos_EM_score
        results[subtask]["SQuAD_Pos. F1"] = pos_F1_score
        results[subtask]["SQuAD_Pos. EM_F1_total"] = pos_total

        # New evaluation suggested by Alan: test scores at the threshold that
        # was selected on the dev set for this subtask.
        F1, P, R, TP, FP, FN = get_TP_FP_FN(
            test_subtasks_data[subtask],
            prediction_scores[subtask],
            THRESHOLD=best_dev_thresholds[subtask])
        logging.info("New evaluation scores:")
        logging.info(f"F1: {F1}")
        logging.info(f"Precision: {P}")
        logging.info(f"Recall: {R}")
        logging.info(f"True Positive: {TP}")
        logging.info(f"False Positive: {FP}")
        logging.info(f"False Negative: {FN}")
        results[subtask]["F1"] = F1
        results[subtask]["P"] = P
        results[subtask]["R"] = R
        results[subtask]["TP"] = TP
        results[subtask]["FP"] = FP
        results[subtask]["FN"] = FN
        N = TP + FN
        results[subtask]["N"] = N

    # Save model_config and results
    model_config_file = os.path.join(args.output_dir, "model_config.json")
    results_file = os.path.join(args.output_dir, "results.json")
    logging.info(f"Saving model config at {model_config_file}")
    save_in_json(model_config, model_config_file)
    logging.info(f"Saving results at {results_file}")
    save_in_json(results, results_file)
Exemplo n.º 4
0
def make_instances_from_dataset(dataset):
    """Build per-question candidate-chunk instances from annotated tweets.

    Each element of ``dataset`` is read as a mapping with the keys ``'id'``,
    ``'annotation'``, ``'text'``, ``'candidate_chunks_offsets'`` and
    ``'tags'``. The question set is derived from the annotation keys of the
    first element that look like ``"part2-<tag>.Response"``; a ``"gender"``
    question is split into ``"gender_male"`` and ``"gender_female"`` that
    share the same annotation key.

    Returns:
        A 3-tuple ``(task_instances_dict, tag_statistics,
        question_keys_and_tags)`` where

        * ``task_instances_dict`` maps each question tag to a list of tuples
          ``(text, candidate_chunk, candidate_chunk_id, chunk_start_text_id,
          chunk_end_text_id, tokenized_tweet,
          tokenized_tweet_with_masked_q_token, tagged_chunks,
          question_label)``;
        * ``tag_statistics`` is ``(gold_labels_stats,
          gold_labels_unique_tweets)``: per tag, the number of instances and
          the number of distinct tokenized tweets for every label;
        * ``question_keys_and_tags`` is the sorted list of
          ``(question_tag, question_key)`` pairs.

    Raises:
        SystemExit: via ``exit()`` when the tokenized tweet does not match the
            raw text (ignoring whitespace), or when a yes/no style question
            does not have exactly one tagged value.
    """
    # Question tags that receive the synthetic "AUTHOR OF THE TWEET" chunk.
    author_chunk_tags = ["name", "close_contact", "who_cure", "opinion"]
    near_author_chunk_tags = ["where", "recent_travel"]
    # Question tags whose label combines the "name" chunk match with a
    # question-specific tagged value.
    special_question_tags = [
        "relation", "gender_male", "gender_female", "believe",
        "binary-relation", "binary-symptoms", "symptoms", "opinion"
    ]

    # All the questions with interesting annotations start with prefix
    # "part2-" and end with suffix ".Response". Collect (<tag>, <dict-key>)
    # tuples from the first annotation in the dataset.
    dummy_annotation = dataset[0]['annotation']
    question_keys_and_tags = [
        (key.replace("part2-", "").replace(".Response", ""), key)
        for key in dummy_annotation.keys()
        if key.startswith("part2-") and key.endswith(".Response")
    ]
    # Sort the question keys to have a fixed ordering
    question_keys_and_tags.sort(key=lambda tup: tup[0])
    question_tags = [
        question_tag for question_tag, question_key in question_keys_and_tags
    ]
    question_keys = [
        question_key for question_tag, question_key in question_keys_and_tags
    ]
    if "gender" in question_tags:
        # Split "gender" into two binary questions sharing one annotation key
        gender_index = question_tags.index("gender")
        question_tags[gender_index] = "gender_female"
        question_tags.insert(gender_index, "gender_male")
        question_keys.insert(gender_index, question_keys[gender_index])
        question_keys_and_tags = list(zip(question_tags, question_keys))

    # Instances for each task are stored separately, keyed by question tag
    task_instances_dict = {
        question_tag: list()
        for question_tag, question_key in question_keys_and_tags
    }
    # Dictionary to store total statistics of labels for each question_tag
    gold_labels_stats = {
        question_tag: dict()
        for question_tag, question_key in question_keys_and_tags
    }
    # Dictionary to store unique tweets for each gold tag within each question tag
    gold_labels_unique_tweets = {
        question_tag: dict()
        for question_tag, question_key in question_keys_and_tags
    }
    skipped_chunks = 0
    # NOTE(review): diagnostics about ignored chunks are collected here but
    # never returned or logged by this function.
    ignore_ones = []
    for annotated_data in dataset:
        # Take one annotated tweet and generate instances from its chunks
        tweet_id = annotated_data['id']
        annotation = annotated_data['annotation']
        text = annotated_data['text'].strip()
        candidate_chunks_offsets = annotated_data['candidate_chunks_offsets']
        tags = annotated_data['tags']
        tweet_tokens = get_tweet_tokens_from_tags(tags)
        # The tokenized tweet must equal the tweet once whitespace is removed.
        # Checked explicitly (not with assert) so it also runs under -O.
        text_without_spaces = re.sub(r"\s+", "", text)
        tweet_tokens_without_spaces = re.sub(r"\s+", "", tweet_tokens)
        if text_without_spaces != tweet_tokens_without_spaces:
            logging.error(
                f"Tweet and tokenized tweets don't match in id:{tweet_id}")
            logging.error(f"Tweets without spaces: {text_without_spaces}")
            logging.error(
                f"Tokens without spaces: {tweet_tokens_without_spaces}")
            exit()
        tweet_tokens = tweet_tokens.split()
        tweet_tokens_char_mapping = find_text_to_tweet_tokens_mapping(
            text, tweet_tokens)
        # NOTE: tweet_tokens_char_mapping is a list of list where each inner list maps every tweet_token's char to the original text char position
        # Get candidate_chunk's offsets in terms of tweet_tokens
        candidate_chunks_offsets_from_tweet_tokens = list()
        ignore_flags = list()
        for chunk_start_idx, chunk_end_idx in candidate_chunks_offsets:
            ignore_flag = False
            chunk_start_token_idx = None
            chunk_end_token_idx = None
            for token_idx, tweet_token_char_mapping in enumerate(
                    tweet_tokens_char_mapping):
                if chunk_start_idx in tweet_token_char_mapping:
                    # token that contains the chunk's first character
                    chunk_start_token_idx = token_idx
                if (chunk_end_idx - 1) in tweet_token_char_mapping:
                    # token that contains the chunk's last character
                    # (token end index is exclusive)
                    chunk_end_token_idx = token_idx + 1
            if chunk_start_token_idx is None or chunk_end_token_idx is None:
                logging.error(
                    f"Tweet id:{tweet_id}\nCouldn't find chunk tokens for chunk offsets [{chunk_start_idx}, {chunk_end_idx}]:{text[chunk_start_idx:chunk_end_idx]};"
                )
                logging.error(
                    f"Found chunk start and end token idx [{chunk_start_token_idx}, {chunk_end_token_idx}]"
                )
                logging.error("Ignoring this chunk")
                ignore_flag = True
                # Record what was ignored and the full token/char mapping
                ignore_info = {}
                ignore_info['id'] = tweet_id
                ignore_info['problem_chunk_id'] = [
                    chunk_start_idx, chunk_end_idx
                ]
                ignore_info['problem_chunk_text'] = text[
                    chunk_start_idx:chunk_end_idx]
                ignore_info['whole_mapping'] = []
                for i in tweet_tokens_char_mapping:
                    if len(i) == 1:
                        ignore_info['whole_mapping'].append([i, text[i[0]]])
                    else:
                        ignore_info['whole_mapping'].append(
                            [i, text[i[0]:i[-1] + 1]])
                ignore_info['flag'] = 'NOT_FIND_ERROR'
                ignore_ones.append(ignore_info)
            ignore_flags.append(ignore_flag)
            candidate_chunks_offsets_from_tweet_tokens.append(
                (chunk_start_token_idx, chunk_end_token_idx))

        # TODO: Verify that each candidate chunk taken from the raw text
        # matches the chunk rebuilt from the tweet tokens (ignoring whitespace)

        # Drop the chunks that could not be aligned to tokens
        candidate_chunks_offsets = [
            e for ignore_flag, e in zip(ignore_flags, candidate_chunks_offsets)
            if not ignore_flag
        ]
        candidate_chunks_offsets_from_tweet_tokens = [
            e for ignore_flag, e in zip(
                ignore_flags, candidate_chunks_offsets_from_tweet_tokens)
            if not ignore_flag
        ]

        # Convert annotation to token_idxs from tweet_char_offsets
        # First get the mapping from chunk_char_offsets to chunk_token_idxs
        chunk_char_offsets_to_token_idxs_mapping = {
            (offset[0], offset[1]): (c[0], c[1])
            for offset, c in zip(candidate_chunks_offsets,
                                 candidate_chunks_offsets_from_tweet_tokens)
        }
        annotation_tweet_tokens = dict()
        for key, value in annotation.items():
            if value == "NO_CONSENSUS":
                new_assignments = ["Not Specified"]
            else:
                new_assignments = list()
                for assignment in value:
                    if isinstance(assignment, list):
                        # A [start, end] char offset pair: replace it with the
                        # chunk text rebuilt from tweet_tokens. A pair that is
                        # not a retained candidate chunk raises KeyError here.
                        gold_chunk_token_idxs = chunk_char_offsets_to_token_idxs_mapping[
                            tuple(assignment)]
                        new_assignment = ' '.join(tweet_tokens[
                            gold_chunk_token_idxs[0]:gold_chunk_token_idxs[1]])
                        new_assignments.append(new_assignment)
                    else:
                        new_assignments.append(assignment)
            annotation_tweet_tokens[key] = new_assignments

        # change the URLs to special URL tag
        final_tweet_tokens = [
            URL_TOKEN if e.startswith("http") or 'twitter.com' in e
            or e.startswith('www.') else e for e in tweet_tokens
        ]
        final_candidate_chunks_with_token_id = [
            (f"{c[0]}_{c[1]}", ' '.join(tweet_tokens[c[0]:c[1]]), c)
            for c in candidate_chunks_offsets_from_tweet_tokens
        ]

        for question_tag, question_key in question_keys_and_tags:

            # NOTE(review): the appends below mutate the per-tweet candidate
            # list, so once an author chunk is added it stays visible to every
            # later question tag of the same tweet (repeats are dropped by the
            # de-duplication below and counted in skipped_chunks). Confirm this
            # carry-over is intended before changing it: it affects which
            # instances are generated.
            if question_tag in author_chunk_tags:
                # add "AUTHOR OF THE TWEET" as a candidate chunk
                final_candidate_chunks_with_token_id.append(
                    ["author_chunk", "AUTHOR OF THE TWEET", [0, 0]])
            elif question_tag in near_author_chunk_tags:
                # add the "near_author_chunk" candidate; its chunk text is
                # also "AUTHOR OF THE TWEET"
                final_candidate_chunks_with_token_id.append(
                    ["near_author_chunk", "AUTHOR OF THE TWEET", [0, 0]])

            # If there are more then one candidate slot with the same candidate chunk then simply keep the first occurrence. Remove the rest.
            current_candidate_chunks = set()
            for candidate_chunk_with_id in final_candidate_chunks_with_token_id:
                candidate_chunk_id = candidate_chunk_with_id[0]
                candidate_chunk = candidate_chunk_with_id[1]

                if candidate_chunk.lower() == 'coronavirus':
                    continue

                # Map the chunk's token span back to character positions
                chunk_start_id = candidate_chunk_with_id[2][0]
                chunk_start_text_id = tweet_tokens_char_mapping[
                    chunk_start_id][0]
                chunk_end_id = candidate_chunk_with_id[2][1]
                chunk_end_text_id = tweet_tokens_char_mapping[chunk_end_id -
                                                              1][-1] + 1

                if candidate_chunk == "AUTHOR OF THE TWEET":
                    # No need to verify or fix this candidate_chunk
                    pass
                else:
                    if chunk_end_id > len(tweet_tokens):
                        # Incorrect chunk end id. Skip this chunk
                        continue
                    # Rebuild the chunk text from the URL-normalised tokens
                    candidate_chunk = ' '.join(
                        final_tweet_tokens[chunk_start_id:chunk_end_id])

                if candidate_chunk in current_candidate_chunks:
                    # Skip this chunk. Already processed before
                    skipped_chunks += 1
                    continue
                # Add to the known set and keep going
                current_candidate_chunks.add(candidate_chunk)

                # Find gold labels for the current question and candidate chunk
                if question_tag in special_question_tags:
                    # If the question is a yes/no question. It is for the name candidate chunk
                    special_tagged_chunks = get_tagged_label_for_key_from_annotation(
                        question_key, annotation_tweet_tokens)
                    if len(special_tagged_chunks) != 1:
                        logging.error(
                            f"for question_tag {question_tag} the special_tagged_chunks = {special_tagged_chunks}"
                        )
                        exit()
                    tagged_label = special_tagged_chunks[0]
                    if tagged_label == "No":
                        tagged_label = "Not Specified"
                    if question_tag in ["gender_male", "gender_female"]:
                        gender = "Male" if question_tag == "gender_male" else "Female"
                        if gender == tagged_label:
                            special_question_label = get_label_from_tagged_label(
                                tagged_label)
                        else:
                            special_question_label = 0
                    else:
                        special_question_label = get_label_from_tagged_label(
                            tagged_label)

                    if question_tag == "opinion":
                        # Only the author chunk can be a positive candidate
                        tagged_chunks = []
                        if candidate_chunk == "AUTHOR OF THE TWEET":
                            question_label = 1
                            tagged_chunks.append("AUTHOR OF THE TWEET")
                        else:
                            question_label = 0
                    else:
                        question_label, tagged_chunks = get_label_for_key_from_annotation(
                            "part2-name.Response", annotation_tweet_tokens,
                            candidate_chunk)
                    # Positive only when both the chunk and the tagged value are positive
                    question_label = question_label & special_question_label
                    if question_label == 0:
                        tagged_chunks = []
                else:
                    question_label, tagged_chunks = get_label_for_key_from_annotation(
                        question_key, annotation_tweet_tokens, candidate_chunk)

                # Add instance
                tokenized_tweet = ' '.join(final_tweet_tokens)
                # text :: candidate_chunk :: candidate_chunk_id :: chunk_start_text_id :: chunk_end_text_id :: tokenized_tweet :: tokenized_tweet_with_masked_q_token :: tagged_chunks :: question_label
                task_instances_dict[question_tag].append(
                    (text, candidate_chunk, candidate_chunk_id,
                     chunk_start_text_id, chunk_end_text_id, tokenized_tweet,
                     ' '.join(final_tweet_tokens[:chunk_start_id] + [Q_TOKEN] +
                              final_tweet_tokens[chunk_end_id:]),
                     tagged_chunks, question_label))
                # Update statistics for data analysis
                gold_labels_stats[question_tag].setdefault(question_label, 0)
                gold_labels_stats[question_tag][question_label] += 1
                gold_labels_unique_tweets[question_tag].setdefault(
                    question_label, set())
                gold_labels_unique_tweets[question_tag][question_label].add(
                    tokenized_tweet)
    logging.info(
        f"Total skipped chunks:{skipped_chunks}\t n_question tags:{len(question_keys_and_tags)}"
    )

    # Convert gold_labels_unique_tweets from set of tweets to counts
    for question_tag, question_key in question_keys_and_tags:
        label_unique_tweets = gold_labels_unique_tweets[question_tag]
        label_unique_tweets_counts = dict()
        for label, tweets in label_unique_tweets.items():
            label_unique_tweets_counts[label] = len(tweets)
        gold_labels_unique_tweets[question_tag] = label_unique_tweets_counts

    # Log the label-wise total statistics
    logging.info("Gold label instances statistics:")
    log_list(gold_labels_stats.items())
    logging.info("Gold label tweets statistics:")
    log_list(gold_labels_unique_tweets.items())
    tag_statistics = (gold_labels_stats, gold_labels_unique_tweets)

    # TODO: return instances header to save in pickle for later
    # TODO: Think of somehow saving the data statistics somewhere. Maybe save that in pickle as well
    return task_instances_dict, tag_statistics, question_keys_and_tags
def main():
    """Train and evaluate a logistic-regression baseline for one sub-task.

    Reads the module-level ``args`` (``data_file``, ``sub_task``, ``task``,
    ``output_dir``). Loads the pickled instances, splits them into
    train/dev/test, extracts n-gram features from the train split, fits a
    ``LogisticRegression`` classifier, picks the decision threshold with the
    best dev F1, evaluates on the test split and logs the top predictions and
    features.

    Writes into ``args.output_dir``: ``"Precision Recall Curve.png"``,
    ``"model_and_features.pkl"``, ``"model_config.json"`` and
    ``"results.json"``.
    """
    task_instances_dict, tag_statistics, question_keys_and_tags = load_from_pickle(
        args.data_file)
    data = extract_instances_for_current_subtask(task_instances_dict,
                                                 args.sub_task)
    logging.info(
        f"Task dataset for task: {args.task} loaded from {args.data_file}.")

    model_config = dict()
    results = dict()

    # Split the data into train, dev and test and shuffle the train segment
    train_data, dev_data, test_data = split_instances_in_train_dev_test(data)
    random.shuffle(train_data)  # shuffle happens in-place
    logging.info("Train Data:")
    total_train_size, pos_train_size, neg_train_size = log_data_statistics(
        train_data)
    logging.info("Dev Data:")
    total_dev_size, pos_dev_size, neg_dev_size = log_data_statistics(dev_data)
    logging.info("Test Data:")
    total_test_size, pos_test_size, neg_test_size = log_data_statistics(
        test_data)
    logging.info("\n")
    model_config["train_data"] = {
        "size": total_train_size,
        "pos": pos_train_size,
        "neg": neg_train_size
    }
    model_config["dev_data"] = {
        "size": total_dev_size,
        "pos": pos_dev_size,
        "neg": neg_dev_size
    }
    model_config["test_data"] = {
        "size": total_test_size,
        "pos": pos_test_size,
        "neg": neg_test_size
    }

    # Extract n-gram features from the train data
    # TODO: update the feature extractor
    feature2i, i2feature = create_ngram_features_from(train_data)
    logging.info(
        f"Total number of features extracted from train = {len(feature2i)}, {len(i2feature)}"
    )
    model_config["features"] = {"size": len(feature2i)}

    # Extract Feature vectors and labels from train, dev and test data
    train_X, train_Y = convert_data_to_feature_vector_and_labels(
        train_data, feature2i)
    dev_X, dev_Y = convert_data_to_feature_vector_and_labels(
        dev_data, feature2i)
    test_X, test_Y = convert_data_to_feature_vector_and_labels(
        test_data, feature2i)
    logging.info(
        f"Train Data Features = {train_X.shape} and Labels = {len(train_Y)}")
    logging.info(
        f"Dev Data Features = {dev_X.shape} and Labels = {len(dev_Y)}")
    logging.info(
        f"Test Data Features = {test_X.shape} and Labels = {len(test_Y)}")
    model_config["train_data"]["features_shape"] = train_X.shape
    model_config["train_data"]["labels_shape"] = len(train_Y)
    model_config["dev_data"]["features_shape"] = dev_X.shape
    model_config["dev_data"]["labels_shape"] = len(dev_Y)
    model_config["test_data"]["features_shape"] = test_X.shape
    model_config["test_data"]["labels_shape"] = len(test_Y)

    # Train logistic regression classifier
    logging.info("Training the Logistic Regression classifier")
    lr = LogisticRegression(solver='lbfgs')
    lr.fit(train_X, train_Y)
    model_config["model"] = "LogisticRegression(solver='lbfgs')"

    # Find best threshold based on dev set performance
    thresholds = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]

    # Column 1 of predict_proba is the probability of the second class
    dev_prediction_probs = lr.predict_proba(dev_X)[:, 1]
    dev_t_F1_P_Rs = list()
    # Falls back to 0.5 when no threshold yields a dev F1 above 0.0
    best_threshold_based_on_F1 = 0.5
    best_dev_F1 = 0.0
    for t in thresholds:
        dev_F1, dev_P, dev_R, dev_TP, dev_FP, dev_FN = get_TP_FP_FN(
            dev_data, dev_prediction_probs, THRESHOLD=t)
        dev_t_F1_P_Rs.append(
            (t, dev_F1, dev_P, dev_R, dev_TP + dev_FN, dev_TP, dev_FP, dev_FN))
        if dev_F1 > best_dev_F1:
            best_threshold_based_on_F1 = t
            best_dev_F1 = dev_F1
    log_list(dev_t_F1_P_Rs)
    logging.info(
        f"Best Threshold: {best_threshold_based_on_F1}\t Best dev F1: {best_dev_F1}"
    )
    # Save the best dev threshold and dev_F1 in results dict
    results["best_dev_threshold"] = best_threshold_based_on_F1
    results["best_dev_F1"] = best_dev_F1
    results["dev_t_F1_P_Rs"] = dev_t_F1_P_Rs

    # Test
    logging.info("Testing the trained classifier")
    # The confusion matrix and report use lr.predict, not the tuned threshold
    predictions = lr.predict(test_X)
    probs = lr.predict_proba(test_X)
    test_Y_prediction_probs = probs[:, 1]
    cm = metrics.confusion_matrix(test_Y, predictions)
    classification_report = metrics.classification_report(test_Y,
                                                          predictions,
                                                          output_dict=True)
    logging.info(cm)
    logging.info(metrics.classification_report(test_Y, predictions))
    results["CM"] = cm.tolist(
    )  # Storing it as list of lists instead of numpy.ndarray
    results["Classification Report"] = classification_report

    # SQuAD style EM and F1 evaluation for all test cases and for positive test cases (i.e. for cases where annotators had a gold annotation)
    EM_score, F1_score, total = get_raw_scores(test_data,
                                               test_Y_prediction_probs)
    logging.info("Word overlap based SQuAD evaluation style metrics:")
    logging.info(f"Total number of cases: {total}")
    logging.info(f"EM_score: {EM_score}")
    logging.info(f"F1_score: {F1_score}")
    results["SQuAD_EM"] = EM_score
    results["SQuAD_F1"] = F1_score
    results["SQuAD_total"] = total
    pos_EM_score, pos_F1_score, pos_total = get_raw_scores(
        test_data, test_Y_prediction_probs, positive_only=True)
    logging.info(f"Total number of Positive cases: {pos_total}")
    logging.info(f"Pos. EM_score: {pos_EM_score}")
    logging.info(f"Pos. F1_score: {pos_F1_score}")
    results["SQuAD_Pos. EM"] = pos_EM_score
    results["SQuAD_Pos. F1"] = pos_F1_score
    results["SQuAD_Pos. EM_F1_total"] = pos_total

    # New evaluation suggested by Alan, using the threshold tuned on dev
    F1, P, R, TP, FP, FN = get_TP_FP_FN(test_data,
                                        test_Y_prediction_probs,
                                        THRESHOLD=best_threshold_based_on_F1)
    logging.info("New evaluation scores:")
    logging.info(f"F1: {F1}")
    logging.info(f"Precision: {P}")
    logging.info(f"Recall: {R}")
    logging.info(f"True Positive: {TP}")
    logging.info(f"False Positive: {FP}")
    logging.info(f"False Negative: {FN}")
    results["F1"] = F1
    results["P"] = P
    results["R"] = R
    results["TP"] = TP
    results["FP"] = FP
    results["FN"] = FN
    N = TP + FN
    results["N"] = N

    # Top predictions in the Test case
    sorted_prediction_ids = np.argsort(-test_Y_prediction_probs)
    K = 30
    logging.info("Top {} predictions:".format(K))
    # Slice instead of range(K) so a test split with fewer than K instances
    # does not raise IndexError
    for instance_id in sorted_prediction_ids[:K]:
        # text :: candidate_chunk :: candidate_chunk_id :: chunk_start_text_id :: chunk_end_text_id :: tokenized_tweet :: tokenized_tweet_with_masked_q_token :: tagged_chunks :: question_label
        list_to_print = [
            test_data[instance_id][0], test_data[instance_id][6],
            test_data[instance_id][1],
            str(test_Y_prediction_probs[instance_id]),
            str(test_Y[instance_id]),
            str(test_data[instance_id][-1]),
            str(test_data[instance_id][-2])
        ]
        logging.info("\t".join(list_to_print))

    # Top feature analysis
    coefs = lr.coef_[0]
    K = 10
    sorted_feature_ids = np.argsort(-coefs)
    logging.info("Top {} features:".format(K))
    # Slice for the same reason: there may be fewer than K features
    for feature_id in sorted_feature_ids[:K]:
        logging.info(f"{i2feature[feature_id]}\t{coefs[feature_id]}")

    # Plot the precision recall curve
    # NOTE(review): plot_precision_recall_curve was removed from scikit-learn
    # in 1.2 (replaced by PrecisionRecallDisplay.from_estimator) - confirm the
    # pinned scikit-learn version before upgrading.
    save_figure_file = os.path.join(args.output_dir,
                                    "Precision Recall Curve.png")
    logging.info(f"Saving precision recall curve at {save_figure_file}")
    disp = plot_precision_recall_curve(lr, test_X, test_Y)
    disp.ax_.set_title('2-class Precision-Recall curve')
    disp.ax_.figure.savefig(save_figure_file)

    # Save the model and features in pickle file
    model_and_features_save_file = os.path.join(args.output_dir,
                                                "model_and_features.pkl")
    logging.info(
        f"Saving LR model and features at {model_and_features_save_file}")
    save_in_pickle((lr, feature2i, i2feature), model_and_features_save_file)

    # Save model_config and results
    model_config_file = os.path.join(args.output_dir, "model_config.json")
    results_file = os.path.join(args.output_dir, "results.json")
    logging.info(f"Saving model config at {model_config_file}")
    save_in_json(model_config, model_config_file)
    logging.info(f"Saving results at {results_file}")
    save_in_json(results, results_file)
Exemplo n.º 6
0
def _question_keys_and_tags_from_annotation(annotation):
    """Return the ordered ``(question_tag, annotation_key)`` pairs of an annotation.

    Only keys that start with ``"part2-"`` and end with ``".Response"`` are
    considered; the tag is the key with both markers removed.  Pairs are
    sorted by tag so the ordering is fixed.  A ``"gender"`` question is
    expanded in place into two questions, ``"gender_male"`` followed by
    ``"gender_female"``, which both point at the original annotation key.
    """
    question_keys_and_tags = [
        (key.replace("part2-", "").replace(".Response", ""), key)
        for key in annotation.keys()
        if key.startswith("part2-") and key.endswith(".Response")
    ]
    question_keys_and_tags.sort(key=lambda tup: tup[0])

    question_tags = [tag for tag, _ in question_keys_and_tags]
    if "gender" in question_tags:
        gender_index = question_tags.index("gender")
        gender_key = question_keys_and_tags[gender_index][1]
        question_keys_and_tags[gender_index:gender_index + 1] = [
            ("gender_male", gender_key),
            ("gender_female", gender_key),
        ]
    return question_keys_and_tags


def _gold_label_for_chunk(question_tag, question_key, annotation,
                          candidate_chunk):
    """Return ``(question_label, tagged_chunks)`` for one question/chunk pair.

    Most questions are answered directly by
    ``get_label_for_key_from_annotation`` on their own annotation key.  The
    tweet-level questions listed below instead combine two labels with a
    bitwise AND:

    * a tweet-level label derived from the single tagged value stored under
      ``question_key`` ("No" is treated as "Not Specified"; for the two
      gender questions the value must equal "Male" / "Female" respectively,
      otherwise the label is 0), and
    * a chunk-level label: for "opinion" it is 1 only for the
      "AUTHOR OF THE TWEET" pseudo-chunk, for every other tweet-level
      question it is the chunk's label under "part2-name.Response".

    When the combined label is 0 the returned ``tagged_chunks`` is empty.
    """
    if question_tag not in ("relation", "gender_male", "gender_female",
                            "believe", "binary-relation", "binary-symptoms",
                            "symptoms", "opinion"):
        question_label, tagged_chunks = get_label_for_key_from_annotation(
            question_key, annotation, candidate_chunk)
        return question_label, tagged_chunks

    special_tagged_chunks = get_tagged_label_for_key_from_annotation(
        question_key, annotation)
    # Tweet-level questions are expected to carry exactly one tagged value.
    assert len(special_tagged_chunks) == 1
    tagged_label = special_tagged_chunks[0]
    if tagged_label == "No":
        tagged_label = "Not Specified"

    if question_tag in ("gender_male", "gender_female"):
        gender = "Male" if question_tag == "gender_male" else "Female"
        if gender == tagged_label:
            special_question_label = get_label_from_tagged_label(tagged_label)
        else:
            special_question_label = 0
    else:
        special_question_label = get_label_from_tagged_label(tagged_label)

    if question_tag == "opinion":
        tagged_chunks = []
        if candidate_chunk == "AUTHOR OF THE TWEET":
            question_label = 1
            tagged_chunks.append("AUTHOR OF THE TWEET")
        else:
            question_label = 0
    else:
        question_label, tagged_chunks = get_label_for_key_from_annotation(
            "part2-name.Response", annotation, candidate_chunk)

    question_label = question_label & special_question_label
    if question_label == 0:
        tagged_chunks = []
    return question_label, tagged_chunks


def make_instances_from_dataset(dataset):
    """Build per-question classification instances from annotated tweets.

    Every entry of ``dataset`` is read through the keys ``'tokenization'``,
    ``'text'``, ``'consensus_annotation'`` and ``'candidate_chunks_with_id'``.
    The set of questions is discovered from the consensus annotation of the
    first entry only.  For each tweet, each question and each distinct
    candidate chunk one instance tuple is emitted::

        (text, candidate_chunk, candidate_chunk_id,
         chunk_start_text_id, chunk_end_text_id,
         tokenized_tweet, tokenized_tweet_with_masked_chunk,
         tagged_chunks, question_label)

    The input ``dataset`` is not modified: pseudo-chunks are appended to a
    per-tweet copy of the candidate list.

    Args:
        dataset: indexable, iterable collection of annotation dicts (an empty
            list fails at ``dataset[0]``).

    Returns:
        A 3-tuple ``(task_instances_dict, tag_statistics,
        question_keys_and_tags)`` where ``task_instances_dict`` maps each
        question tag to its list of instance tuples, ``tag_statistics`` is
        ``(gold_labels_stats, gold_labels_unique_tweets)`` - per question tag,
        the number of instances and the number of distinct tokenized tweets
        per label - and ``question_keys_and_tags`` is the ordered list of
        ``(question_tag, annotation_key)`` pairs.

    Raises:
        SystemExit: if a candidate chunk rebuilt from the tokens does not
            match the corresponding span of the tweet text (whitespace
            ignored).
        AssertionError: if a tweet-level question does not have exactly one
            tagged value.
    """
    # The questions are taken from the first annotation in the dataset.
    question_keys_and_tags = _question_keys_and_tags_from_annotation(
        dataset[0]['consensus_annotation'])

    # Instances for each question are stored separately.
    task_instances_dict = {
        question_tag: list()
        for question_tag, _ in question_keys_and_tags
    }
    # Number of instances per gold label, for each question tag.
    gold_labels_stats = {
        question_tag: dict()
        for question_tag, _ in question_keys_and_tags
    }
    # Unique tokenized tweets per gold label, for each question tag.
    gold_labels_unique_tweets = {
        question_tag: dict()
        for question_tag, _ in question_keys_and_tags
    }
    skipped_chunks = 0

    for annotated_data in dataset:
        tweet_tokens = annotated_data['tokenization']
        text = annotated_data['text'].strip()
        tweet_tokens_char_mapping = find_text_to_tweet_tokens_mapping(
            text, tweet_tokens)
        # Replace URL-like tokens with the special URL token.
        final_tweet_tokens = [
            URL_TOKEN if e.startswith("http") or 'twitter.com' in e
            or e.startswith('www.') else e for e in tweet_tokens
        ]
        tokenized_tweet = ' '.join(final_tweet_tokens)
        annotation = annotated_data['consensus_annotation']
        # Work on a copy: the pseudo-chunks appended below must not leak into
        # the caller's dataset (the original list was mutated in place, so a
        # second call on the same dataset saw ever-growing chunk lists).
        candidate_chunks_with_id = list(
            annotated_data['candidate_chunks_with_id'])

        for question_tag, question_key in question_keys_and_tags:
            # NOTE(review): pseudo-chunks accumulate across the questions of
            # one tweet, so questions processed after e.g. "close_contact"
            # also see the author chunk. This matches the previous behaviour
            # and is kept on purpose - confirm whether it is intended.
            if question_tag in ("name", "close_contact", "who_cure",
                                "opinion"):
                # add "AUTHOR OF THE TWEET" as a candidate chunk
                candidate_chunks_with_id.append([
                    "author_chunk", "AUTHOR OF THE TWEET", [0, 0],
                    "author_chunk"
                ])
            elif question_tag in ("where", "recent_travel"):
                # add "NEAR AUTHOR OF THE TWEET" as a candidate chunk
                candidate_chunks_with_id.append([
                    "near_author_chunk", "NEAR AUTHOR OF THE TWEET", [0, 0],
                    "near_author_chunk"
                ])

            # If several candidates share the same chunk text, only the first
            # occurrence is kept for this question.
            current_candidate_chunks = set()
            for candidate_chunk_with_id in candidate_chunks_with_id:
                candidate_chunk_id = candidate_chunk_with_id[0]
                candidate_chunk = candidate_chunk_with_id[1]

                if candidate_chunk.lower() == 'coronavirus':
                    continue

                chunk_start_id = candidate_chunk_with_id[2][0]
                chunk_end_id = candidate_chunk_with_id[2][1]

                if candidate_chunk in ("AUTHOR OF THE TWEET",
                                       "NEAR AUTHOR OF THE TWEET"):
                    # Pseudo-chunks have no span in the tweet text. Give them
                    # fixed offsets instead of silently reusing the offsets of
                    # whichever chunk was processed before (or failing with
                    # UnboundLocalError when there was none).
                    chunk_start_text_id = 0
                    chunk_end_text_id = 0
                else:
                    # Verify that the candidate chunk aligns with the tokens;
                    # on mismatch the tokens win.
                    chunk_from_tokens = ' '.join(
                        tweet_tokens[chunk_start_id:chunk_end_id])
                    if chunk_from_tokens != candidate_chunk:
                        logging.debug(
                            f"Prev:{candidate_chunk}||New:{chunk_from_tokens}|"
                        )
                        candidate_chunk = chunk_from_tokens
                    # NOTE(review): chunk_end_id is used as an exclusive end
                    # below, so ">=" also skips chunks that end on the last
                    # token - kept as-is, confirm whether ">" was meant.
                    if chunk_end_id >= len(tweet_tokens):
                        # Incorrect chunk end id. Skip this chunk
                        continue
                    # Character offsets of the chunk inside the tweet text
                    chunk_start_text_id = tweet_tokens_char_mapping[
                        chunk_start_id][0]
                    chunk_end_text_id = tweet_tokens_char_mapping[
                        chunk_end_id - 1][-1]
                    chunk_from_text = text[
                        chunk_start_text_id:chunk_end_text_id + 1]
                    if re.sub(r"\s+", "", candidate_chunk) != re.sub(
                            r"\s+", "", chunk_from_text):
                        logging.warning(
                            "Conflict in given candidate chunk and tweet_text")
                        logging.warning(
                            f"Given candidate chunk = {candidate_chunk}")
                        logging.warning(f"Text in tweet = {chunk_from_text}")
                        # Abort as before, but with a non-zero exit status and
                        # without depending on the site-provided exit().
                        raise SystemExit(
                            "Aborting: candidate chunk does not match the "
                            "tweet text")
                    # The emitted chunk uses the URL-normalised tokens.
                    candidate_chunk = ' '.join(
                        final_tweet_tokens[chunk_start_id:chunk_end_id])

                if candidate_chunk in current_candidate_chunks:
                    # Skip this chunk. Already processed before
                    skipped_chunks += 1
                    continue
                current_candidate_chunks.add(candidate_chunk)

                question_label, tagged_chunks = _gold_label_for_chunk(
                    question_tag, question_key, annotation, candidate_chunk)

                # Add instance
                # text :: candidate_chunk :: candidate_chunk_id :: chunk_start_text_id :: chunk_end_text_id :: tokenized_tweet :: tokenized_tweet_with_masked_q_token :: tagged_chunks :: question_label
                task_instances_dict[question_tag].append(
                    (text, candidate_chunk, candidate_chunk_id,
                     chunk_start_text_id, chunk_end_text_id, tokenized_tweet,
                     ' '.join(final_tweet_tokens[:chunk_start_id] + [Q_TOKEN] +
                              final_tweet_tokens[chunk_end_id:]),
                     tagged_chunks, question_label))

                # Update statistics for data analysis
                label_counts = gold_labels_stats[question_tag]
                label_counts[question_label] = label_counts.get(
                    question_label, 0) + 1
                gold_labels_unique_tweets[question_tag].setdefault(
                    question_label, set()).add(tokenized_tweet)

    logging.info(
        f"Total skipped chunks:{skipped_chunks}\t n_question tags:{len(question_keys_and_tags)}"
    )

    # Convert gold_labels_unique_tweets from sets of tweets to counts
    for question_tag, _ in question_keys_and_tags:
        gold_labels_unique_tweets[question_tag] = {
            label: len(tweets)
            for label, tweets in gold_labels_unique_tweets[question_tag].items()
        }

    # Log the label-wise total statistics
    logging.info("Gold label instances statistics:")
    log_list(gold_labels_stats.items())
    logging.info("Gold label tweets statistics:")
    log_list(gold_labels_unique_tweets.items())
    tag_statistics = (gold_labels_stats, gold_labels_unique_tweets)

    # TODO: return instances header to save in pickle for later
    # TODO: Think of somehow saving the data statistics somewhere. Maybe save that in pickle as well
    return task_instances_dict, tag_statistics, question_keys_and_tags
Exemplo n.º 7
0
def main():
    # Read all the data instances
    task_instances_dict, tag_statistics, question_keys_and_tags = load_from_pickle(
        args.data_file)
    data, subtasks_list = get_multitask_instances_for_valid_tasks(
        task_instances_dict, tag_statistics)
    data = add_marker_for_loss_ignore(
        data, 1.0 if args.loss_for_no_consensus else 0.0)

    if args.retrain:
        if args.large_bert:
            model_name = "bert-large-cased"
        elif args.covid_bert:
            model_name = "digitalepidemiologylab/covid-twitter-bert"
        else:
            model_name = "bert-base-cased"

        logging.info("Creating and training the model from '" + model_name +
                     "'")
        # Create the save_directory if not exists
        make_dir_if_not_exists(args.save_directory)

        # Initialize tokenizer and model with pretrained weights
        tokenizer = BertTokenizer.from_pretrained(model_name)
        config = BertConfig.from_pretrained(model_name)
        config.subtasks = subtasks_list
        model = MultiTaskBertForCovidEntityClassification.from_pretrained(
            model_name, config=config)

        # Add new tokens in tokenizer
        new_special_tokens_dict = {
            "additional_special_tokens": ["<E>", "</E>", "<URL>", "@USER"]
        }
        tokenizer.add_special_tokens(new_special_tokens_dict)

        # Add the new embeddings in the weights
        print("Embeddings type:",
              model.bert.embeddings.word_embeddings.weight.data.type())
        print("Embeddings shape:",
              model.bert.embeddings.word_embeddings.weight.data.size())
        embedding_size = model.bert.embeddings.word_embeddings.weight.size(1)
        new_embeddings = torch.FloatTensor(
            len(new_special_tokens_dict["additional_special_tokens"]),
            embedding_size).uniform_(-0.1, 0.1)
        # new_embeddings = torch.FloatTensor(2, embedding_size).uniform_(-0.1, 0.1)
        print("new_embeddings shape:", new_embeddings.size())
        new_embedding_weight = torch.cat(
            (model.bert.embeddings.word_embeddings.weight.data,
             new_embeddings), 0)
        model.bert.embeddings.word_embeddings.weight.data = new_embedding_weight
        print("Embeddings shape:",
              model.bert.embeddings.word_embeddings.weight.data.size())
        # Update model config vocab size
        model.config.vocab_size = model.config.vocab_size + len(
            new_special_tokens_dict["additional_special_tokens"])
    else:
        # Load the tokenizer and model from the save_directory
        tokenizer = BertTokenizer.from_pretrained(args.save_directory)
        model = MultiTaskBertForCovidEntityClassification.from_pretrained(
            args.save_directory)
        # Load from individual state dicts
        for subtask in model.subtasks:
            model.classifiers[subtask].load_state_dict(
                torch.load(
                    os.path.join(args.save_directory,
                                 f"{subtask}_classifier.bin")))
    model.to(device)
    if args.wandb:
        wandb.watch(model)

    # Explicitly move the classifiers to device
    for subtask, classifier in model.classifiers.items():
        classifier.to(device)
    for subtask, classifier in model.context_vectors.items():
        classifier.to(device)

    entity_start_token_id = tokenizer.convert_tokens_to_ids(["<E>"])[0]
    entity_end_token_id = tokenizer.convert_tokens_to_ids(["</E>"])[0]

    logging.info(
        f"Task dataset for task: {args.task} loaded from {args.data_file}.")

    model_config = dict()
    results = dict()

    # Split the data into train, dev and test and shuffle the train segment
    train_data, dev_data = split_multitask_instances_in_train_dev(data)
    random.shuffle(train_data)  # shuffle happens in-place
    logging.info("Train Data:")
    total_train_size, pos_subtasks_train_size, neg_subtasks_train_size = log_multitask_data_statistics(
        train_data, model.subtasks)
    logging.info("Dev Data:")
    total_dev_size, pos_subtasks_dev_size, neg_subtasks_dev_size = log_multitask_data_statistics(
        dev_data, model.subtasks)
    #logging.info("Test Data:")
    #total_test_size, pos_subtasks_test_size, neg_subtasks_test_size = log_multitask_data_statistics(test_data, model.subtasks)
    logging.info("\n")
    model_config["train_data"] = {
        "size": total_train_size,
        "pos": pos_subtasks_train_size,
        "neg": neg_subtasks_train_size
    }
    model_config["dev_data"] = {
        "size": total_dev_size,
        "pos": pos_subtasks_dev_size,
        "neg": neg_subtasks_dev_size
    }
    #model_config["test_data"] = {"size":total_test_size, "pos":pos_subtasks_test_size, "neg":neg_subtasks_test_size}

    # Extract subtasks data for dev and test
    train_subtasks_data = split_data_based_on_subtasks(train_data,
                                                       model.subtasks)
    dev_subtasks_data = split_data_based_on_subtasks(dev_data, model.subtasks)
    #test_subtasks_data = split_data_based_on_subtasks(test_data, model.subtasks)

    # Load the instances into pytorch dataset
    train_dataset = COVID19TaskDataset(train_data)
    dev_dataset = COVID19TaskDataset(dev_data)
    #test_dataset = COVID19TaskDataset(test_data)
    logging.info("Loaded the datasets into Pytorch datasets")

    tokenize_collator = TokenizeCollator(tokenizer, model.subtasks,
                                         entity_start_token_id,
                                         entity_end_token_id)
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=POSSIBLE_BATCH_SIZE,
                                  shuffle=True,
                                  num_workers=0,
                                  collate_fn=tokenize_collator)
    dev_dataloader = DataLoader(dev_dataset,
                                batch_size=POSSIBLE_BATCH_SIZE,
                                shuffle=False,
                                num_workers=0,
                                collate_fn=tokenize_collator)
    #test_dataloader = DataLoader(test_dataset, batch_size=POSSIBLE_BATCH_SIZE, shuffle=False, num_workers=0, collate_fn=tokenize_collator)
    logging.info("Created train and test dataloaders with batch aggregation")

    # Only retrain if needed
    if args.retrain:
        optimizer = AdamW(model.parameters(), lr=2e-5, eps=1e-8)
        logging.info("Created model optimizer")
        #if args.sentence_level_classify:
        #    args.n_epochs += 2
        epochs = args.n_epochs

        # Total number of training steps is [number of batches] x [number of epochs].
        total_steps = len(train_dataloader) * epochs

        # Create the learning rate scheduler.
        # NOTE: num_warmup_steps = 0 is the Default value in run_glue.py
        scheduler = get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=0, num_training_steps=total_steps)
        # We'll store a number of quantities such as training and validation loss, validation accuracy, and timings.
        training_stats = []
        print("\n\n\n ====== Training for task", args.task,
              "=============\n\n\n")
        logging.info(f"Initiating training loop for {args.n_epochs} epochs...")
        print(model.state_dict().keys())

        total_start_time = time.time()

        # Find the accumulation steps
        accumulation_steps = args.batch_size / POSSIBLE_BATCH_SIZE

        # Dev validation trajectory
        epoch_train_loss = list()
        train_subtasks_validation_statistics = {
            subtask: list()
            for subtask in model.subtasks
        }
        dev_subtasks_validation_statistics = {
            subtask: list()
            for subtask in model.subtasks
        }
        best_dev_F1 = 0
        for epoch in range(epochs):

            logging.info(f"Initiating Epoch {epoch+1}:")

            # Reset the total loss for each epoch.
            total_train_loss = 0
            train_loss_trajectory = list()

            # Reset timer for each epoch
            start_time = time.time()
            model.train()

            dev_log_frequency = 5
            n_steps = len(train_dataloader)
            dev_steps = int(n_steps / dev_log_frequency)
            for step, batch in enumerate(train_dataloader):
                # Upload labels of each subtask to device
                for subtask in model.subtasks:
                    subtask_labels = batch["gold_labels"][subtask]
                    subtask_labels = subtask_labels.to(device)
                    batch["gold_labels"][subtask] = subtask_labels
                    batch["label_ignore_loss"][subtask] = batch[
                        "label_ignore_loss"][subtask].to(device)

                # Forward
                input_dict = {
                    "input_ids":
                    batch["input_ids"].to(device),
                    "entity_start_positions":
                    batch["entity_start_positions"].to(device),
                    "entity_end_positions":
                    batch["entity_end_positions"].to(device),
                    "labels":
                    batch["gold_labels"],
                    "label_weight":
                    batch["label_ignore_loss"]
                }

                input_ids = batch["input_ids"]
                entity_start_positions = batch["entity_start_positions"]
                gold_labels = batch["gold_labels"]
                batch_data = batch["batch_data"]
                loss, logits = model(**input_dict)

                # Accumulate loss
                total_train_loss += loss.item()

                # Backward: compute gradients
                loss.backward()

                if (step + 1) % accumulation_steps == 0:
                    # Calculate elapsed time in minutes and print loss on the tqdm bar
                    elapsed = format_time(time.time() - start_time)
                    avg_train_loss = total_train_loss / (step + 1)

                    # keep track of changing avg_train_loss
                    train_loss_trajectory.append(avg_train_loss)
                    if (step + 1) % (accumulation_steps * 20) == 0:
                        print(
                            f"Epoch:{epoch+1}|Batch:{step}/{len(train_dataloader)}|Time:{elapsed}|Avg. Loss:{avg_train_loss:.4f}|Loss:{loss.item():.4f}"
                        )

                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    optimizer.step()

                    # Clean the model's previous gradients
                    model.zero_grad()
                    scheduler.step()

            # Calculate the average loss over all of the batches.
            avg_train_loss = total_train_loss / len(train_dataloader)

            # Perform validation with the model and log the performance
            print("\n")
            logging.info("Running Validation...")
            # Put the model in evaluation mode--the dropout layers behave differently during evaluation.
            model.eval()
            dev_predicted_labels, dev_prediction_scores, dev_gold_labels = make_predictions_on_dataset(
                dev_dataloader, model, device, args.task + "_dev", True)

            wandb_log_dict = {"Train Loss": avg_train_loss}
            print("Dev Set:")
            collect_TP_FP_FN = {"TP": 0, "FP": 0, "FN": 0}
            for subtask in model.subtasks:
                dev_subtask_data = dev_subtasks_data[subtask]
                dev_subtask_prediction_scores = dev_prediction_scores[subtask]
                dev_F1, dev_P, dev_R, dev_TP, dev_FP, dev_FN = get_TP_FP_FN(
                    dev_subtask_data,
                    dev_subtask_prediction_scores,
                    task=subtask)
                if subtask not in IGNORE_TASKS:
                    collect_TP_FP_FN["TP"] += dev_TP
                    collect_TP_FP_FN["FP"] += dev_FP
                    collect_TP_FP_FN["FN"] += dev_FN
                else:
                    print("IGNORE: ", end="")

                print(
                    f"Subtask:{subtask:>15}\tN={dev_TP + dev_FN}\tF1={dev_F1}\tP={dev_P}\tR={dev_R}\tTP={dev_TP}\tFP={dev_FP}\tFN={dev_FN}"
                )
                dev_subtasks_validation_statistics[subtask].append(
                    (epoch + 1, step + 1, dev_TP + dev_FN, dev_F1, dev_P,
                     dev_R, dev_TP, dev_FP, dev_FN))

                wandb_log_dict["Dev_ " + subtask + "_F1"] = dev_F1
                wandb_log_dict["Dev_ " + subtask + "_P"] = dev_P
                wandb_log_dict["Dev_ " + subtask + "_R"] = dev_R

            dev_macro_P = collect_TP_FP_FN["TP"] / (collect_TP_FP_FN["TP"] +
                                                    collect_TP_FP_FN["FP"])
            dev_macro_R = collect_TP_FP_FN["TP"] / (collect_TP_FP_FN["TP"] +
                                                    collect_TP_FP_FN["FN"])
            dev_macro_F1 = (2 * dev_macro_P * dev_macro_R) / (dev_macro_P +
                                                              dev_macro_R)
            print(collect_TP_FP_FN)
            print("dev_macro_P:", dev_macro_P, "\ndev_macro_R:", dev_macro_R,
                  "\ndev_macro_F1:", dev_macro_F1, "\n")
            wandb_log_dict["Dev_macro_F1"] = dev_macro_F1
            wandb_log_dict["Dev_macro_P"] = dev_macro_P
            wandb_log_dict["Dev_macro_R"] = dev_macro_R

            if args.wandb:
                wandb.log(wandb_log_dict)

            if dev_macro_F1 > best_dev_F1:
                best_dev_F1 = dev_macro_F1
                print("NEW BEST F1:", best_dev_F1, " Saving checkpoint now.")
                torch.save(model.state_dict(), args.output_dir + "/ckpt.pth")
                #print(model.state_dict().keys())
                #model.save_pretrained(args.save_directory)
            model.train()

            training_time = format_time(time.time() - start_time)

            # Record all statistics from this epoch.
            training_stats.append({
                'epoch': epoch + 1,
                'Training Loss': avg_train_loss,
                'Training Time': training_time
            })

            # Save the loss trajectory
            epoch_train_loss.append(train_loss_trajectory)
            print("\n\n")

        logging.info(
            f"Training complete with total Train time:{format_time(time.time()- total_start_time)}"
        )
        log_list(training_stats)

        model.load_state_dict(torch.load(args.output_dir + "/ckpt.pth"))
        model.eval()
        # Save the model and the Tokenizer here:
        #logging.info(f"Saving the model and tokenizer in {args.save_directory}")
        #model.save_pretrained(args.save_directory)

        # Save each subtask classifiers weights to individual state dicts
        #for subtask, classifier in model.classifiers.items():
        #    classifier_save_file = os.path.join(args.save_directory, f"{subtask}_classifier.bin")
        #    logging.info(f"Saving the model's {subtask} classifier weights at {classifier_save_file}")
        #    torch.save(classifier.state_dict(), classifier_save_file)
        #tokenizer.save_pretrained(args.save_directory)

        # Plot the train loss trajectory in a plot
        #train_loss_trajectory_plot_file = os.path.join(args.output_dir, "train_loss_trajectory.png")
        #logging.info(f"Saving the Train loss trajectory at {train_loss_trajectory_plot_file}")
        #print(epoch_train_loss)

        # TODO: Plot the validation performance
        # Save dev_subtasks_validation_statistics
    else:
        raise
        logging.info("No training needed. Directly going to evaluation!")

    # Save the model name in the model_config file
    model_config["model"] = "MultiTaskBertForCovidEntityClassification"
    model_config["epochs"] = args.n_epochs

    # Find best threshold for each subtask based on dev set performance
    thresholds = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
    #test_predicted_labels, test_prediction_scores, test_gold_labels = make_predictions_on_dataset(test_dataloader, model, device, args.task, True)
    dev_predicted_labels, dev_prediction_scores, dev_gold_labels = make_predictions_on_dataset(
        dev_dataloader, model, device, args.task + "_dev", True)

    best_test_thresholds = {subtask: 0.5 for subtask in model.subtasks}
    best_dev_thresholds = {subtask: 0.5 for subtask in model.subtasks}
    best_test_F1s = {subtask: 0.0 for subtask in model.subtasks}
    best_dev_F1s = {subtask: 0.0 for subtask in model.subtasks}
    #test_subtasks_t_F1_P_Rs = {subtask: list() for subtask in model.subtasks}
    dev_subtasks_t_F1_P_Rs = {subtask: list() for subtask in model.subtasks}

    for subtask in model.subtasks:
        dev_subtask_data = dev_subtasks_data[subtask]
        dev_subtask_prediction_scores = dev_prediction_scores[subtask]
        for t in thresholds:
            dev_F1, dev_P, dev_R, dev_TP, dev_FP, dev_FN = get_TP_FP_FN(
                dev_subtask_data,
                dev_subtask_prediction_scores,
                THRESHOLD=t,
                task=subtask)
            dev_subtasks_t_F1_P_Rs[subtask].append(
                (t, dev_F1, dev_P, dev_R, dev_TP + dev_FN, dev_TP, dev_FP,
                 dev_FN))
            if dev_F1 > best_dev_F1s[subtask]:
                best_dev_thresholds[subtask] = t
                best_dev_F1s[subtask] = dev_F1

        logging.info(f"Subtask:{subtask:>15}")
        log_list(dev_subtasks_t_F1_P_Rs[subtask])
        logging.info(
            f"Best Dev Threshold for subtask: {best_dev_thresholds[subtask]}\t Best dev F1: {best_dev_F1s[subtask]}"
        )

    # Save the best dev threshold and dev_F1 in results dict
    results["best_dev_threshold"] = best_dev_thresholds
    results["best_dev_F1s"] = best_dev_F1s
    results["dev_t_F1_P_Rs"] = dev_subtasks_t_F1_P_Rs

    # Evaluate on Test
    logging.info("Testing on eval dataset")
    predicted_labels, prediction_scores, gold_labels = make_predictions_on_dataset(
        dev_dataloader, model, device, args.task)

    # Test
    for subtask in model.subtasks:
        logging.info(f"\nTesting the trained classifier on subtask: {subtask}")

        results[subtask] = dict()
        cm = metrics.confusion_matrix(gold_labels[subtask],
                                      predicted_labels[subtask])
        classification_report = metrics.classification_report(
            gold_labels[subtask], predicted_labels[subtask], output_dict=True)
        logging.info(cm)
        logging.info(
            metrics.classification_report(gold_labels[subtask],
                                          predicted_labels[subtask]))
        results[subtask]["CM"] = cm.tolist(
        )  # Storing it as list of lists instead of numpy.ndarray
        results[subtask]["Classification Report"] = classification_report

        # SQuAD style EM and F1 evaluation for all test cases and for positive test cases (i.e. for cases where annotators had a gold annotation)
        EM_score, F1_score, total = get_raw_scores(dev_subtasks_data[subtask],
                                                   prediction_scores[subtask])
        logging.info("Word overlap based SQuAD evaluation style metrics:")
        logging.info(f"Total number of cases: {total}")
        logging.info(f"EM_score: {EM_score}")
        logging.info(f"F1_score: {F1_score}")
        results[subtask]["SQuAD_EM"] = EM_score
        results[subtask]["SQuAD_F1"] = F1_score
        results[subtask]["SQuAD_total"] = total
        pos_EM_score, pos_F1_score, pos_total = get_raw_scores(
            dev_subtasks_data[subtask],
            prediction_scores[subtask],
            positive_only=True)
        logging.info(f"Total number of Positive cases: {pos_total}")
        logging.info(f"Pos. EM_score: {pos_EM_score}")
        logging.info(f"Pos. F1_score: {pos_F1_score}")
        results[subtask]["SQuAD_Pos. EM"] = pos_EM_score
        results[subtask]["SQuAD_Pos. F1"] = pos_F1_score
        results[subtask]["SQuAD_Pos. EM_F1_total"] = pos_total

        # New evaluation suggested by Alan
        F1, P, R, TP, FP, FN = get_TP_FP_FN(
            dev_subtasks_data[subtask],
            prediction_scores[subtask],
            THRESHOLD=best_dev_thresholds[subtask],
            task=subtask)
        logging.info("New evaluation scores:")
        logging.info(f"F1: {F1}")
        logging.info(f"Precision: {P}")
        logging.info(f"Recall: {R}")
        logging.info(f"True Positive: {TP}")
        logging.info(f"False Positive: {FP}")
        logging.info(f"False Negative: {FN}")
        results[subtask]["F1"] = F1
        results[subtask]["P"] = P
        results[subtask]["R"] = R
        results[subtask]["TP"] = TP
        results[subtask]["FP"] = FP
        results[subtask]["FN"] = FN
        N = TP + FN
        results[subtask]["N"] = N

    # Save model_config and results
    model_config_file = os.path.join(args.output_dir, "model_config.json")
    results_file = os.path.join(args.output_dir, "results.json")
    logging.info(f"Saving model config at {model_config_file}")
    save_in_json(model_config, model_config_file)
    logging.info(f"Saving results at {results_file}")
    save_in_json(results, results_file)