示例#1
0
def extract_triples(predicate_resume, disc_model_name=None):
    """Dump the positively-predicted test candidates as URI triples.

    Loads a saved ``reRNN`` model, predicts every candidate returned by
    ``get_test_cids_with_span`` and writes one ``subject, predicate, object``
    row for each candidate that is predicted ``1`` and whose subject and
    object spans both resolve through ``get_dbpedia_node``.

    Args:
        predicate_resume: Mapping providing the keys ``"predicate_name"``,
            ``"predicate_URI"``, ``"subject_type"`` and ``"object_type"``.
        disc_model_name: Name of the saved model to load.  Defaults to
            ``"D" + predicate_name + "Latest"``.

    The rows are written to
    ``./results/triples_<predicate_name><UTC timestamp>.csv``; ``open`` does
    not create the ``./results`` directory, so it must already exist.
    """
    date_time = strftime("%Y-%m-%d_%H_%M_%S", gmtime())
    session = SnorkelSession()
    predicate_name = predicate_resume["predicate_name"]
    if disc_model_name is None:
        disc_model_name = "D" + predicate_name + "Latest"

    test_cands = get_test_cids_with_span(predicate_resume, session).all()
    lstm = reRNN()
    logging.info("Loading marginals ")
    lstm.load(disc_model_name)
    predictions = lstm.predictions(test_cands)

    # Only the last '/'-separated segment of each type is handed to the
    # node lookup.
    subject_type_end = predicate_resume["subject_type"].split('/')[-1]
    object_type_end = predicate_resume["object_type"].split('/')[-1]
    # Loop-invariant, so resolved once instead of once per positive candidate.
    predicate_uri = str(predicate_resume["predicate_URI"])

    triples_path = ("./results/" + "triples_" + predicate_name + date_time +
                    ".csv")
    # NOTE(review): 'w+b' with csv.writer is the Python 2 convention; under
    # Python 3 this would need text mode with newline=''.
    with open(triples_path, 'w+b') as f:
        writer = csv.writer(f, delimiter=',', quoting=csv.QUOTE_MINIMAL)
        # The header previously read "text, marginal, prediction", which did
        # not describe the triple rows written below.
        writer.writerow(["subject", "predicate", "object"])
        for cand, prediction in zip(test_cands, predictions):
            if prediction != 1:
                continue
            subject_uri = get_dbpedia_node(cand.subject.get_span(),
                                           subject_type_end)
            object_uri = get_dbpedia_node(cand.object.get_span(),
                                          object_type_end)
            # A triple is only emitted when both ends were resolved.
            if subject_uri is None or object_uri is None:
                continue
            writer.writerow([str(subject_uri), predicate_uri,
                             str(object_uri)])
def train_disc_model(predicate_resume, parallelism=8):
    """Train, save and evaluate the discriminative ``reRNN`` for a predicate.

    The model is trained on the split-0 candidates with the marginals loaded
    for ``get_train_cids_with_marginals_and_span``, saved via ``_save_model``,
    scored on the split-2 candidates against the gold test matrix, and its
    marginals for those test candidates are stored in the session.

    Args:
        predicate_resume: Mapping providing ``"candidate_subclass"`` plus
            whatever the query/label helpers called here read from it.
        parallelism: Converted with ``int()`` and passed as ``n_threads``.
    """
    logging.info("Start training disc ")
    session = SnorkelSession()
    marginal_cids = get_train_cids_with_marginals_and_span(predicate_resume,
                                                           session)
    logging.info("Loading marginals ")
    train_marginals = load_marginals(session, split=0,
                                     cids_query=marginal_cids)

    hyper_params = dict(
        lr=0.01,
        dim=50,
        n_epochs=10,
        dropout=0.25,
        print_freq=1,
        max_sentence_length=100,
    )

    cand_cls = predicate_resume["candidate_subclass"]

    def cands_in_split(split_id):
        # All candidates of this predicate's class in one split, by id.
        return (session.query(cand_cls)
                .filter(cand_cls.split == split_id)
                .order_by(cand_cls.id)
                .all())

    logging.info("Querying train cands")
    # NOTE(review): the marginals above come from a narrower-looking cid
    # query than this "all of split 0" candidate list -- confirm both select
    # the same candidates in the same order.
    train_cands = cands_in_split(0)
    logging.info("Querying dev cands")
    # The dev candidates and gold dev labels are fetched but not passed to
    # train() below.
    dev_cands = get_dev_cands_with_span(predicate_resume, session).all()
    logging.info("Querying gold labels")
    L_gold_dev = get_gold_dev_matrix(predicate_resume, session)

    logging.info("Training")
    lstm = reRNN(seed=1701, n_threads=int(parallelism))
    lstm.train(train_cands, train_marginals, **hyper_params)
    logging.info("Saving")
    _save_model(predicate_resume, lstm)

    # Evaluate the freshly trained model on the test split.
    test_cands = cands_in_split(2)
    L_gold_test = get_gold_test_matrix(predicate_resume, session)
    p, r, f1 = lstm.score(test_cands, L_gold_test)
    summary = "Prec: {0:.3f}, Recall: {1:.3f}, F1 Score: {2:.3f}".format(
        p, r, f1)
    print(summary)
    logging.info(summary)
    lstm.save_marginals(session, test_cands)
示例#3
0
def score_disc_model(predicate_resume,
                     L_gold_test,
                     session,
                     date_time,
                     disc_model_name=None):
    """Score a saved ``reRNN`` on the test split and dump four CSV reports.

    Args:
        predicate_resume: Mapping providing ``"predicate_name"``,
            ``"predicate_URI"``, ``"candidate_subclass"``, ``"subject_type"``
            and ``"object_type"``.
        L_gold_test: Gold labels passed to ``score`` and ``error_analysis``.
        session: Database session used to query the split-2 candidates.
        date_time: String appended to every output file name.
        disc_model_name: Name of the saved model to load.  Defaults to
            ``"D" + predicate_name + "Latest"``.

    Files written under ``./results/`` (the directory must already exist),
    each named ``<prefix><predicate_name><date_time>.csv``:

    * ``test_disc_1_`` -- precision, recall and F1.
    * ``test_disc_2_`` -- sizes of the four ``error_analysis`` result sets.
    * ``test_disc_3_`` -- per-candidate text span, marginal and prediction.
    * ``triples_``     -- subject/predicate/object URIs of the candidates
      predicted ``1`` whose two spans both resolve via ``get_dbpedia_node``.
    """
    predicate_name = predicate_resume["predicate_name"]
    if disc_model_name is None:
        disc_model_name = "D" + predicate_name + "Latest"
    candidate_subclass = predicate_resume["candidate_subclass"]
    test_cands = session.query(candidate_subclass).filter(
        candidate_subclass.split == 2).order_by(candidate_subclass.id).all()

    lstm = reRNN()
    logging.info("Loading marginals ")
    lstm.load(disc_model_name)

    def results_path(prefix):
        # Every report shares the same naming scheme.
        return "./results/" + prefix + predicate_name + date_time + ".csv"

    # --- Report 1: precision / recall / F1 --------------------------------
    p, r, f1 = lstm.score(test_cands, L_gold_test)
    score_line = "Prec: {0:.3f}, Recall: {1:.3f}, F1 Score: {2:.3f}".format(
        p, r, f1)
    print(score_line)
    logging.info(score_line)
    # NOTE(review): 'w+b' with csv.writer is the Python 2 convention; under
    # Python 3 these files would need text mode with newline=''.
    with open(results_path("test_disc_1_"), 'w+b') as f:
        writer = csv.writer(f, delimiter=',', quoting=csv.QUOTE_MINIMAL)
        writer.writerow(["Precision", "Recall", "F1"])
        writer.writerow(["{0:.3f}".format(value) for value in (p, r, f1)])

    # --- Report 2: error-analysis bucket sizes -----------------------------
    tp, fp, tn, fn = lstm.error_analysis(session, test_cands, L_gold_test)
    counts = [str(len(bucket)) for bucket in (tp, fp, tn, fn)]
    logging.info("TP: {}, FP: {}, TN: {}, FN: {}".format(*counts))
    with open(results_path("test_disc_2_"), 'w+b') as f:
        writer = csv.writer(f, delimiter=',', quoting=csv.QUOTE_MINIMAL)
        writer.writerow(["TP", "FP", "TN", "FN"])
        writer.writerow(counts)

    # --- Report 3: per-candidate text, marginal and prediction -------------
    predictions = lstm.predictions(test_cands)
    marginals = lstm.marginals(test_cands)
    with open(results_path("test_disc_3_"), 'w+b') as f:
        writer = csv.writer(f, delimiter=',', quoting=csv.QUOTE_MINIMAL)
        writer.writerow(["text", "marginal", "prediction"])
        for candidate, marginal, prediction in zip(test_cands, marginals,
                                                   predictions):
            subject, obj = candidate.subject, candidate.object
            # Slice from whichever argument starts first to the end of the
            # other one (char_end is inclusive, hence the + 1).
            if obj.char_start < subject.char_start:
                start, end = obj.char_start, subject.char_end + 1
            else:
                start, end = subject.char_start, obj.char_end + 1
            # NOTE(review): the text is wrapped in literal quotes before it
            # reaches csv.writer, which will quote it again and double these
            # characters; kept as-is so the output format does not change.
            text = "\"" + candidate.get_parent().text[start:end].encode(
                'ascii', 'ignore') + "\""
            writer.writerow([text, str(marginal), str(prediction)])

    # --- Report 4: triples for the positive predictions --------------------
    # Only the last '/'-separated segment of each type is handed to the
    # node lookup.
    subject_type_end = predicate_resume["subject_type"].split('/')[-1]
    object_type_end = predicate_resume["object_type"].split('/')[-1]
    # Loop-invariant, so resolved once instead of once per positive candidate.
    predicate_uri = str(predicate_resume["predicate_URI"])
    with open(results_path("triples_"), 'w+b') as f:
        writer = csv.writer(f, delimiter=',', quoting=csv.QUOTE_MINIMAL)
        # The header previously read "text, marginal, prediction", which did
        # not describe the triple rows written below.
        writer.writerow(["subject", "predicate", "object"])
        for cand, prediction in zip(test_cands, predictions):
            if prediction != 1:
                continue
            subject_uri = get_dbpedia_node(cand.subject.get_span(),
                                           subject_type_end)
            object_uri = get_dbpedia_node(cand.object.get_span(),
                                          object_type_end)
            # A triple is only emitted when both ends were resolved.
            if subject_uri is None or object_uri is None:
                continue
            writer.writerow([str(subject_uri), predicate_uri,
                             str(object_uri)])
# Notebook-export cells: train a reRNN on the training candidates and
# evaluate it on the test candidates.  dev_cands, train_cands,
# train_marginals, L_gold_dev, test_cands, L_gold_test and session are not
# defined in these cells and must come from earlier ones.

# Report how many dev candidates there are (Python 2 print statement).
print len(dev_cands)

# In[9]:

from snorkel.learning.disc_models.rnn import reRNN

# Settings forwarded to lstm.train() as keyword arguments.
train_kwargs = {
    'lr': 0.01,
    'dim': 50,
    'n_epochs': 10,
    'dropout': 0.25,
    'print_freq': 1,
    'max_sentence_length': 100
}

# Fixed seed; n_threads is explicitly left as None.
lstm = reRNN(seed=1701, n_threads=None)
# The dev candidates and their gold labels are supplied as X_dev / Y_dev
# alongside the training candidates and marginals.
lstm.train(train_cands,
           train_marginals,
           X_dev=dev_cands,
           Y_dev=L_gold_dev,
           **train_kwargs)

# In[10]:

# Score the trained model on the test candidates against the gold labels.
p, r, f1 = lstm.score(test_cands, L_gold_test)
print("Prec: {0:.3f}, Recall: {1:.3f}, F1 Score: {2:.3f}".format(p, r, f1))

# In[11]:

# Four result collections from the error analysis; presumably true/false
# positives and negatives, going by the names -- confirm against Snorkel.
tp, fp, tn, fn = lstm.error_analysis(session, test_cands, L_gold_test)
示例#5
0
    # Fragment of a larger function whose header is not visible here;
    # X_ends, X_train, X_dev, args and max_val are defined before this point.

    # Keep only the first element of every entry in X_ends.
    # NOTE(review): relies on Python 2's map() returning a list; on Python 3
    # np.array() would receive a map iterator instead.
    X_ends = np.array(map(lambda x: x[0], X_ends))
    train_marginals = np.loadtxt(args["train_marginals"])
    Y_dev = np.loadtxt(args["dev_labels"])
    word_dict = read_word_dict(args["train_word_dict"])

    # Clamp negative marginals to 0.
    train_marginals[train_marginals < 0] = 0
    # Settings forwarded to lstm.train() as keyword arguments.
    train_kwargs = {
        'lr': 0.001,
        'dim': 100,
        'n_epochs': 10,
        'dropout': 0.5,
        'print_freq': 1,
        'max_sentence_length': 2000,
    }

    lstm = reRNN(seed=100, n_threads=20)
    # Replace the model's symbol table with one whose d and s attributes are
    # set from the pre-built word dictionary and max_val.
    # presumably d is the word->id map and s the next free id -- TODO confirm
    lstm.word_dict = SymbolTable()
    lstm.word_dict.d = word_dict
    lstm.word_dict.s = max_val

    np.random.seed(200)
    # NOTE(review): the training fraction is read from sys.argv[1] while the
    # other inputs come from args -- confirm this is intended.
    training_size = int(len(X_train) * float(sys.argv[1]))
    # randint draws each index independently, so the subsample can contain
    # the same training example more than once.
    train_idx = np.random.randint(0, len(X_train), size=training_size)
    # The same index array selects matching rows from all three inputs.
    lstm.train(X_train[train_idx],
               X_ends[train_idx],
               train_marginals[train_idx],
               X_dev=X_dev,
               Y_dev=Y_dev,
               save_dir='{}'.format(args["save_dir"]),
               **train_kwargs)
def run(candidate1, candidate2, pairing_name, cand1_ngrams, cand2_ngrams,
        cand1Matcher, cand2Matcher, model_name, output_file_name,
        corpus_parser):
    """Extract candidate pairs, classify them with a saved model, write a CSV.

    Candidates of a two-argument relation are extracted from every sentence
    of every stored ``Document`` into split 4, a saved ``reRNN`` is loaded,
    and one row per candidate is written to ``output_file_name``.

    Args:
        candidate1: Name of the first relation argument; also used as a CSV
            column header.
        candidate2: Name of the second relation argument; also used as a CSV
            column header.
        pairing_name: Name of the candidate class to create.  The value
            ``'BiomarkerCondition'`` additionally triggers
            ``add_adj_candidate_BC``.
        cand1_ngrams: ``n_max`` for the first argument's ``Ngrams`` space.
        cand2_ngrams: ``n_max`` for the second argument's ``Ngrams`` space.
        cand1Matcher: Matcher applied to the first argument.
        cand2Matcher: Matcher applied to the second argument.
        model_name: Name of the saved model passed to ``lstm.load``.
        output_file_name: Path of the CSV file to write.
        corpus_parser: Not used by the code shown here.

    Returns:
        None.  Returns early, without writing a file, when no candidates
        were extracted.
    """
    # The original imported csv locally (mid-body); it stays function-local
    # so this block does not depend on a module-level import.
    import csv

    print("Started")
    session = SnorkelSession()

    candidate_pair = candidate_subclass(pairing_name, [candidate1, candidate2])

    # Collect every sentence of every stored document.
    sentences = set()
    for doc in session.query(Document).order_by(Document.name).all():
        sentences.update(doc.sentences)

    candidate_extractor = CandidateExtractor(
        candidate_pair,
        [Ngrams(n_max=cand1_ngrams), Ngrams(n_max=cand2_ngrams)],
        [cand1Matcher, cand2Matcher])

    # Everything is extracted into split 4, replacing earlier contents.
    candidate_extractor.apply(sentences, split=4, clear=True)
    cands = session.query(candidate_pair).filter(
        candidate_pair.split == 4).order_by(candidate_pair.id).all()
    session.commit()

    if not cands:
        print("No Candidates Found")
        return
    if pairing_name == 'BiomarkerCondition':
        # NOTE(review): `cands` is not re-queried after this call, so any
        # candidates it adds to split 4 are not predicted below -- confirm
        # that is intended.
        add_adj_candidate_BC(session, candidate_pair, cands, 4)
        session.commit()

    lstm = reRNN(seed=1701, n_threads=None)
    lstm.load(model_name)
    predictions = lstm.predictions(cands)

    # The file used to be opened without ever being closed in this function;
    # the with-block guarantees it is flushed and closed, even on error.
    with open(output_file_name, 'wb') as output_file:
        csvWriter = csv.writer(output_file)
        csvWriter.writerow(
            ['doc_id', 'sentence', candidate1, candidate2, 'prediction'])
        for cand, prediction in zip(cands, predictions):
            sentence = cand.get_parent()
            contexts = cand.get_contexts()
            # Drops the first 9 characters of the document's str() form;
            # presumably a fixed "Document "-style prefix -- TODO confirm.
            doc_string = 'PMC' + str(sentence.get_parent())[9:]
            csvWriter.writerow([
                unidecode(doc_string),
                unidecode(sentence.text),
                unidecode(contexts[0].get_span()),
                unidecode(contexts[1].get_span()),
                prediction,
            ])