def test(embeddingsFile, existingModel, predictionsFile,
         documents_are_sequences, sentenceCNN, charCNN, uniLSTM, useBERT):
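    """Evaluate a trained model on the test set.

    Loads word embeddings and the saved model weights, runs the model over
    every sentence read from the module-level `test_folder`, writes one line
    per token (metadata, predicted label, gold label, probability) to
    `predictionsFile`, and prints precision, recall, and F1.
    """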
    embeddings, vocab = event_reader.load_embeddings(embeddingsFile,
                                                     vocab_size,
                                                     word_embedding_dim)

    testSentences, testBookIndex = event_reader.prepare_annotations_from_folder(
        test_folder, documents_are_sequences, useBERT)

    testC, testX, testP, testW, testY, testL = transform_examples(
        testSentences, vocab, useBERT)

    test_generator = single_generator(testC, testX, testP, testW, testY, testL)

    test_metadata = event_reader.convert_to_index(testSentences)

    model = event_cnn(embeddings, sentenceCNN, charCNN, uniLSTM, useBERT)

    model.load_weights(existingModel)
    out = open(predictionsFile, "w", encoding="utf-8")
    gold = []
    preds = []
    c = 0
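    # Run the model one sentence at a time and collect gold labels and
    # thresholded predictions for evaluation; c indexes tokens in
    # test_metadata across all sentences.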
    for step in range(len(testL)):
        batch, y = next(test_generator)

        probs = model.predict_on_batch(batch)

        _, length, _ = y.shape
        for i in range(length):
            out.write("%s\t%s\t%s\t%.20f\n" %
                      ('\t'.join([str(x) for x in test_metadata[c]]),
                       int(probs[0][i][0] > 0.5), y[0][i][0], probs[0][i][0]))

            preds.append(probs[0][i][0] >= 0.5)
            gold.append(y[0][i][0])
            c += 1

    f, p, r, correct, trials, trues = event_eval.check_f1_two_lists(
        gold, preds)

    print("precision: %.3f %s/%s" % (p, correct, trials))
    print("recall: %.3f %s/%s" % (r, correct, trues))
    print("F: %.3f" % f)

    out.close()


def predict_file_batch(filename, embeddingsFile, existingModel,
                       predictionsFile, documents_are_sequences, sentenceCNN,
                       charCNN, pad, uniLSTM, useBERT):
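    """Tag a single file with a trained model and write token/label output.

    Reads sentences from `filename`, runs the saved model over them (either
    as one padded batch when `pad` is set, or sentence by sentence through a
    generator otherwise), and writes one "token<TAB>label" line per token to
    `predictionsFile`, with blank lines separating sentences.
    """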
    embeddings, vocab = event_reader.load_embeddings(embeddingsFile,
                                                     vocab_size,
                                                     word_embedding_dim)

    testSentences, testBookIndex = event_reader.prepare_annotations_from_file(
        filename,
        documents_are_sequences=documents_are_sequences,
        useBERT=useBERT)

    testC, testX, testP, testW, testY, testL = transform_examples(
        testSentences, vocab, useBERT)

    test_metadata = event_reader.convert_to_index(testSentences)

    max_sequence_length = max(testL) if len(testL) > 0 else 0
    print("max sequence length:", max_sequence_length)
    model = event_cnn(embeddings, sentenceCNN, charCNN, uniLSTM, useBERT)

    model.load_weights(existingModel)
    out = open(predictionsFile, "w", encoding="utf-8")

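    # Two inference modes: pad all sequences to the same length and predict
    # in large batches, or feed one sentence at a time through the generator.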
    if pad:
        testC, testX, testP, testW, testY, testL = pad_all(
            testC, testX, testP, testW, testY, testL, max_sequence_length,
            useBERT)
        # Predict over all padded sequences in one batched call.
        probs = model.predict([testC, testW, testX, testP], batch_size=128)
        c = 0
        lastSent = None
        for step in range(len(testL)):
            for i in range(testL[step]):
                sid = test_metadata[c][1]
                if lastSent != sid and lastSent is not None:
                    out.write("\n")

                w_book, w_sid, w_word, _ = test_metadata[c]
                label = "O"
                if probs[step][i][0] > 0.5:
                    label = "EVENT"
                out.write("%s\t%s\n" % (w_word, label))
                lastSent = sid
                c += 1

    else:
        test_generator = single_generator(testC, testX, testP, testW, testY,
                                          testL)

        c = 0
        for step in range(len(testL)):
            batch, y = next(test_generator)

            probs = model.predict_on_batch(batch)

            _, length, _ = y.shape

            for i in range(length):
                w_book, w_sid, w_word, _ = test_metadata[c]
                label = "O"
                if probs[0][i][0] > 0.5:
                    label = "EVENT"
                out.write("%s\t%s\n" % (w_word, label))

                c += 1

            out.write("\n")

    out.close()


if __name__ == "__main__":

    outputFile = sys.argv[1]

    # Only the POS tagger is needed for the verb baseline below, so the
    # named-entity recognizer and parser are disabled at load time.
    nlp = spacy.load('en', disable=['ner', 'parser'])

    train_folder = "../data/bert/train"
    dev_folder = "../data/bert/dev"
    test_folder = "../data/bert/test"

    testSentences, _ = event_reader.prepare_annotations_from_folder(
        test_folder)
    test_metadata = event_reader.convert_to_index(testSentences)

    golds = []
    preds = []

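    # Simple part-of-speech baseline: predict EVENT for every token that
    # spaCy tags as a verb, and collect the gold labels for comparison.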
    for sentence in testSentences:
        # Build a Doc from the pre-tokenized words (legacy spaCy API) and
        # run only the part-of-speech tagger over it.
        tokens_list = [word[0] for word in sentence]
        tokens = nlp.tokenizer.tokens_from_list(tokens_list)
        nlp.tagger(tokens)
        for idx, token in enumerate(tokens):
            # Baseline rule: any verb (Penn Treebank tag starting with "V")
            # counts as an EVENT.
            pred = 1 if token.tag_.startswith("V") else 0
            preds.append(pred)
            golds.append(sentence[idx][1])