def extract_triples(predicate_resume, disc_model_name=None):
    date_time = strftime("%Y-%m-%d_%H_%M_%S", gmtime())
    session = SnorkelSession()
    if disc_model_name is None:
        disc_model_name = "D" + predicate_resume["predicate_name"] + "Latest"

    test_cands_query = get_test_cids_with_span(predicate_resume, session)
    test_cands = test_cands_query.all()

    lstm = reRNN()
    logging.info("Loading discriminative model")
    lstm.load(disc_model_name)
    predictions = lstm.predictions(test_cands)

    dump_file_path = "./results/" + "triples_" + predicate_resume["predicate_name"] + date_time + ".csv"
    subject_type_end = predicate_resume["subject_type"].split('/')[-1]
    object_type_end = predicate_resume["object_type"].split('/')[-1]
    with open(dump_file_path, 'w+b') as f:
        writer = csv.writer(f, delimiter=',', quoting=csv.QUOTE_MINIMAL)
        writer.writerow(["subject", "predicate", "object"])
        # One row per positively-predicted candidate whose spans both resolve to DBpedia URIs.
        for i, c in enumerate(test_cands):
            if predictions[i] == 1:
                subject_uri = get_dbpedia_node(c.subject.get_span(), subject_type_end)
                object_uri = get_dbpedia_node(c.object.get_span(), object_type_end)
                predicate_uri = predicate_resume["predicate_URI"]
                if subject_uri is not None and object_uri is not None:
                    writer.writerow([str(subject_uri), str(predicate_uri), str(object_uri)])
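# Hypothetical usage sketch (not from the original source): `extract_triples` expects a
# `predicate_resume` dict with at least the keys referenced above. The keys and values
# below are illustrative assumptions only.
#
# example_predicate_resume = {
#     "predicate_name":     "birthPlace",
#     "predicate_URI":      "http://dbpedia.org/ontology/birthPlace",
#     "subject_type":       "http://dbpedia.org/ontology/Person",
#     "object_type":        "http://dbpedia.org/ontology/Place",
#     "candidate_subclass": BirthPlaceCandidate,  # created with snorkel's candidate_subclass()
# }
# extract_triples(example_predicate_resume)                    # loads "DbirthPlaceLatest"
# extract_triples(example_predicate_resume, "DbirthPlace_v2")  # or an explicitly named model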
def train_disc_model(predicate_resume, parallelism=8):
    logging.info("Start training disc model")
    session = SnorkelSession()

    train_cids_query = get_train_cids_with_marginals_and_span(predicate_resume, session)
    logging.info("Loading marginals")
    train_marginals = load_marginals(session, split=0, cids_query=train_cids_query)

    train_kwargs = {
        'lr':                  0.01,
        'dim':                 50,
        'n_epochs':            10,
        'dropout':             0.25,
        'print_freq':          1,
        'max_sentence_length': 100
    }

    logging.info("Querying train cands")
    candidate_subclass = predicate_resume["candidate_subclass"]
    train_cands = session.query(candidate_subclass)\
        .filter(candidate_subclass.split == 0)\
        .order_by(candidate_subclass.id).all()
    # Alternative: get_train_cands_with_marginals_and_span(predicate_resume, session).all()

    logging.info("Querying dev cands")
    dev_cands = get_dev_cands_with_span(predicate_resume, session).all()
    logging.info("Querying gold labels")
    L_gold_dev = get_gold_dev_matrix(predicate_resume, session)
    # Note: dev candidates/labels are loaded but not currently passed to lstm.train().

    logging.info("Training")
    lstm = reRNN(seed=1701, n_threads=int(parallelism))
    lstm.train(train_cands, train_marginals, **train_kwargs)
    logging.info("Saving")
    _save_model(predicate_resume, lstm)

    # Evaluate the trained model on the held-out test split (split == 2)
    test_cands = session.query(candidate_subclass)\
        .filter(candidate_subclass.split == 2)\
        .order_by(candidate_subclass.id).all()
    L_gold_test = get_gold_test_matrix(predicate_resume, session)
    p, r, f1 = lstm.score(test_cands, L_gold_test)
    print("Prec: {0:.3f}, Recall: {1:.3f}, F1 Score: {2:.3f}".format(p, r, f1))
    logging.info("Prec: {0:.3f}, Recall: {1:.3f}, F1 Score: {2:.3f}".format(p, r, f1))
    lstm.save_marginals(session, test_cands)
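# Hypothetical usage sketch (illustrative only): training reads split-0 candidates and their
# marginals from the Snorkel database, so the labeling/generative step must already have run.
# `example_predicate_resume` is the assumed dict sketched above.
#
# train_disc_model(example_predicate_resume, parallelism=4)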
def score_disc_model(predicate_resume, L_gold_test, session, date_time, disc_model_name=None):
    if disc_model_name is None:
        disc_model_name = "D" + predicate_resume["predicate_name"] + "Latest"

    candidate_subclass = predicate_resume["candidate_subclass"]
    test_cands_query = session.query(candidate_subclass)\
        .filter(candidate_subclass.split == 2)\
        .order_by(candidate_subclass.id)
    test_cands = test_cands_query.all()

    lstm = reRNN()
    logging.info("Loading discriminative model")
    lstm.load(disc_model_name)
    #lstm.save_marginals(session, test_cands)

    p, r, f1 = lstm.score(test_cands, L_gold_test)
    print("Prec: {0:.3f}, Recall: {1:.3f}, F1 Score: {2:.3f}".format(p, r, f1))
    logging.info("Prec: {0:.3f}, Recall: {1:.3f}, F1 Score: {2:.3f}".format(p, r, f1))

    # Precision / recall / F1 summary
    dump_file_path1 = "./results/" + "test_disc_1_" + predicate_resume["predicate_name"] + date_time + ".csv"
    with open(dump_file_path1, 'w+b') as f:
        writer = csv.writer(f, delimiter=',', quoting=csv.QUOTE_MINIMAL)
        writer.writerow(["Precision", "Recall", "F1"])
        writer.writerow(["{0:.3f}".format(p), "{0:.3f}".format(r), "{0:.3f}".format(f1)])

    # Confusion-matrix counts
    tp, fp, tn, fn = lstm.error_analysis(session, test_cands, L_gold_test)
    logging.info("TP: {}, FP: {}, TN: {}, FN: {}".format(len(tp), len(fp), len(tn), len(fn)))
    dump_file_path2 = "./results/" + "test_disc_2_" + predicate_resume["predicate_name"] + date_time + ".csv"
    with open(dump_file_path2, 'w+b') as f:
        writer = csv.writer(f, delimiter=',', quoting=csv.QUOTE_MINIMAL)
        writer.writerow(["TP", "FP", "TN", "FN"])
        writer.writerow([str(len(tp)), str(len(fp)), str(len(tn)), str(len(fn))])

    # Per-candidate text span, marginal probability and hard prediction
    predictions = lstm.predictions(test_cands)
    marginals = lstm.marginals(test_cands)
    dump_file_path3 = "./results/" + "test_disc_3_" + predicate_resume["predicate_name"] + date_time + ".csv"
    with open(dump_file_path3, 'w+b') as f:
        writer = csv.writer(f, delimiter=',', quoting=csv.QUOTE_MINIMAL)
        writer.writerow(["text", "marginal", "prediction"])
        for i, candidate in enumerate(test_cands):
            # Take the character span covering both arguments, whichever comes first.
            start = candidate.subject.char_start
            end = candidate.object.char_end + 1
            if candidate.object.char_start < candidate.subject.char_start:
                start = candidate.object.char_start
                end = candidate.subject.char_end + 1
            text = "\"" + candidate.get_parent().text[start:end].encode('ascii', 'ignore') + "\""
            writer.writerow([text, str(marginals[i]), str(predictions[i])])

    # Extracted triples for positively-predicted candidates
    dump_file_path4 = "./results/" + "triples_" + predicate_resume["predicate_name"] + date_time + ".csv"
    subject_type_end = predicate_resume["subject_type"].split('/')[-1]
    object_type_end = predicate_resume["object_type"].split('/')[-1]
    with open(dump_file_path4, 'w+b') as f:
        writer = csv.writer(f, delimiter=',', quoting=csv.QUOTE_MINIMAL)
        writer.writerow(["subject", "predicate", "object"])
        for i, c in enumerate(test_cands):
            if predictions[i] == 1:
                subject_uri = get_dbpedia_node(c.subject.get_span(), subject_type_end)
                object_uri = get_dbpedia_node(c.object.get_span(), object_type_end)
                predicate_uri = predicate_resume["predicate_URI"]
                if subject_uri is not None and object_uri is not None:
                    writer.writerow([str(subject_uri), str(predicate_uri), str(object_uri)])
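# Hypothetical usage sketch (illustrative only): scoring needs the gold test label matrix and
# a timestamp string used in the result CSV names; all names below reuse the assumed example.
#
# from time import strftime, gmtime
# session = SnorkelSession()
# L_gold_test = get_gold_test_matrix(example_predicate_resume, session)
# score_disc_model(example_predicate_resume, L_gold_test, session,
#                  strftime("%Y-%m-%d_%H_%M_%S", gmtime()))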
print(len(dev_cands))


# In[9]:

from snorkel.learning.disc_models.rnn import reRNN

train_kwargs = {
    'lr':                  0.01,
    'dim':                 50,
    'n_epochs':            10,
    'dropout':             0.25,
    'print_freq':          1,
    'max_sentence_length': 100
}

lstm = reRNN(seed=1701, n_threads=None)
lstm.train(train_cands, train_marginals, X_dev=dev_cands, Y_dev=L_gold_dev, **train_kwargs)


# In[10]:

p, r, f1 = lstm.score(test_cands, L_gold_test)
print("Prec: {0:.3f}, Recall: {1:.3f}, F1 Score: {2:.3f}".format(p, r, f1))


# In[11]:

tp, fp, tn, fn = lstm.error_analysis(session, test_cands, L_gold_test)
# Keep only the end indices for each candidate; map() returns a list under Python 2.
X_ends = np.array(map(lambda x: x[0], X_ends))

train_marginals = np.loadtxt(args["train_marginals"])
Y_dev = np.loadtxt(args["dev_labels"])
word_dict = read_word_dict(args["train_word_dict"])

# Clip negative marginals to 0 so they are valid probabilities.
train_marginals[train_marginals < 0] = 0

train_kwargs = {
    'lr':                  0.001,
    'dim':                 100,
    'n_epochs':            10,
    'dropout':             0.5,
    'print_freq':          1,
    'max_sentence_length': 2000,
}

lstm = reRNN(seed=100, n_threads=20)
lstm.word_dict = SymbolTable()
lstm.word_dict.d = word_dict
lstm.word_dict.s = max_val

# Randomly sample training indices (with replacement); the fraction of the training set
# to use is taken from the first CLI argument.
np.random.seed(200)
training_size = int(len(X_train) * float(sys.argv[1]))
train_idx = np.random.randint(0, len(X_train), size=training_size)

lstm.train(X_train[train_idx], X_ends[train_idx], train_marginals[train_idx],
           X_dev=X_dev, Y_dev=Y_dev,
           save_dir='{}'.format(args["save_dir"]),
           **train_kwargs)
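# Hypothetical invocation sketch (the script name is an assumption): the first positional
# argument is the fraction of the training set to sample, e.g. to train on half the data:
#
#   python train_lstm.py 0.5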
def run(candidate1, candidate2, pairing_name, cand1_ngrams, cand2_ngrams,
        cand1Matcher, cand2Matcher, model_name, output_file_name, corpus_parser):
    print "Started"
    session = SnorkelSession()

    # Define the candidate pair subclass for this pairing
    candidate_pair = candidate_subclass(pairing_name, [candidate1, candidate2])

    # Collect all sentences from the parsed corpus
    sentences = set()
    docs = session.query(Document).order_by(Document.name).all()
    for doc in docs:
        for s in doc.sentences:
            sentences.add(s)

    cand_1_ngrams = Ngrams(n_max=cand1_ngrams)
    cand_2_ngrams = Ngrams(n_max=cand2_ngrams)
    # condition_ngrams = Ngrams(n_max=7)
    # medium_ngrams = Ngrams(n_max=5)
    # type_ngrams = Ngrams(n_max=5)  # <--- Q: should we cut these down?
    # level_ngrams = Ngrams(n_max=1)
    # unit_ngrams = Ngrams(n_max=1)

    # Construct our Matchers
    # cMatcher = matchers.getConditionMatcher()
    # mMatcher = matchers.getMediumMatcher()
    # tMatcher = matchers.getTypeMatcher()
    # lMatcher = matchers.getLevelMatcher()
    # uMatcher = matchers.getUnitMatcher()

    # Building the CandidateExtractor
    candidate_extractor = CandidateExtractor(candidate_pair,
                                             [cand_1_ngrams, cand_2_ngrams],
                                             [cand1Matcher, cand2Matcher])
    # candidate_extractor_BC = CandidateExtractor(BiomarkerCondition, [biomarker_ngrams, condition_ngrams], [bMatcher, cMatcher])
    # candidate_extractor_BM = CandidateExtractor(BiomarkerMedium, [biomarker_ngrams, medium_ngrams], [bMatcher, mMatcher])
    # candidate_extractor_BT = CandidateExtractor(BiomarkerType, [biomarker_ngrams, type_ngrams], [bMatcher, tMatcher])
    # candidate_extractor_BLU = CandidateExtractor(BiomarkerLevelUnit, [biomarker_ngrams, level_ngrams, unit_ngrams], [bMatcher, lMatcher, uMatcher])

    # Extract candidates into split 4 and load them back ordered by id
    candidate_extractor.apply(sentences, split=4, clear=True)
    cands = session.query(candidate_pair).filter(candidate_pair.split == 4).order_by(candidate_pair.id).all()
    session.commit()
    # cands_BD = grabCandidates(candidate_extractor_BD, BiomarkerDrug)
    # cands_BM = grabCandidates(candidate_extractor_BM, BiomarkerMedium)
    # cands_BT = grabCandidates(candidate_extractor_BT, BiomarkerType)
    # cands_BLU = grabCandidates(candidate_extractor_BLU, BiomarkerLevelUnit)

    if len(cands) == 0:
        print "No Candidates Found"
        return

    if pairing_name == 'BiomarkerCondition':
        # session.rollback()
        # print "Number of dev BC candidates without adj. boosting: ", len(cands_BC[1])
        add_adj_candidate_BC(session, candidate_pair, cands, 4)
        # fix_specificity(session, BiomarkerCondition, cands_BC[1])
        # print "Number of dev BC candidates with adj. boosting: ", session.query(BiomarkerCondition).filter(BiomarkerCondition.split == 4).count()
        session.commit()

    # Load the trained discriminative model and predict over the extracted candidates
    lstm = reRNN(seed=1701, n_threads=None)
    lstm.load(model_name)
    predictions = lstm.predictions(cands)

    # Write one row per candidate pair with the model's prediction
    import csv
    with open(output_file_name, 'wb') as output_file:
        csvWriter = csv.writer(output_file)
        csvWriter.writerow(['doc_id', 'sentence', candidate1, candidate2, 'prediction'])
        for i in range(len(cands)):
            doc_string = 'PMC' + str(cands[i].get_parent().get_parent())[9:]
            sentence_string = cands[i].get_parent().text
            cand_1_string = cands[i].get_contexts()[0].get_span()
            cand_2_string = cands[i].get_contexts()[1].get_span()
            prediction = predictions[i]
            csvWriter.writerow([unidecode(doc_string),
                                unidecode(sentence_string),
                                unidecode(cand_1_string),
                                unidecode(cand_2_string),
                                prediction])
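# Hypothetical usage sketch (all argument values are assumptions): `run` extracts candidate
# pairs from the parsed corpus, classifies them with a saved reRNN model, and writes a CSV.
#
# run('Biomarker', 'Condition', 'BiomarkerCondition',
#     cand1_ngrams=3, cand2_ngrams=7,
#     cand1Matcher=biomarker_matcher,              # some snorkel Matcher instance (assumed)
#     cand2Matcher=matchers.getConditionMatcher(), # getter referenced in the comments above
#     model_name='BiomarkerConditionModel',        # assumed saved reRNN checkpoint name
#     output_file_name='biomarker_condition_predictions.csv',
#     corpus_parser=corpus_parser)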