def test_sql_entity_replacer():
    import sys
    sys.path += ['../../']
    from grammars.grammar import get_grammar
    from grammars.utils import get_logical_form_preprocessor, get_logical_form_tokenizer
    preprocessor = get_logical_form_preprocessor('geo', 'sql')
    grammar = get_grammar('geo', 'sql')
    sql = preprocessor(
        'SELECT CITYalias0.POPULATION FROM CITY AS CITYalias0 WHERE CITYalias0.CITY_NAME = "erie" AND CITYalias0.STATE_NAME = "pennsylvania" ;')

    # Test Replace
    sql_tokenizer = get_logical_form_tokenizer('geo', 'sql')
    sql_tokens = sql_tokenizer.tokenize(sql)
    question = 'what is the population of erie pennsylvania'
    question_tokenizer = WordTokenizer(SpacyWordSplitter())
    question_tokens = question_tokenizer.tokenize(question)

    from geo_gnn_entity_matcher import GeoGNNEntityMatcher
    base_path = os.path.join('../../', 'data', 'geo')
    entity_path = os.path.join(base_path, 'geo_entities.json')
    matcher = GeoGNNEntityMatcher(entity_path, max_ngram=6)
    candidates = matcher.match(question_tokens)
    for can_idx, can in enumerate(candidates):
        can['index'] = can_idx

    is_valid, replaced_tokens = replace_sql_entity(grammar, sql, sql_tokens, candidates)
    print(sql_tokens)
    print("Is Valid: ", is_valid)
    print(replaced_tokens)

def test_lambda_calculus_entity_replacer():
    import sys
    sys.path += ['../../']
    from grammars.grammar import get_grammar
    from grammars.utils import get_logical_form_preprocessor, get_logical_form_tokenizer
    preprocessor = get_logical_form_preprocessor('geo', 'lambda')
    grammar = get_grammar('geo', 'lambda')
    lc = preprocessor(
        "(lambda $0:e (and:<t*,t> (major:<lo,t> $0) (city:<c,t> $0) (loc:<lo,<lo,t>> $0 alaska:s)))"
    )

    # Test Replace
    lc_tokenizer = get_logical_form_tokenizer('geo', 'lambda')
    lc_tokens = lc_tokenizer.tokenize(lc)
    question = 'what are the major cities in alaska'
    question_tokenizer = WordTokenizer(SpacyWordSplitter())
    question_tokens = question_tokenizer.tokenize(question)

    from geo_gnn_entity_matcher import GeoLambdaCalculusGNNEntityMatcher
    base_path = os.path.join('../../', 'data', 'geo')
    entity_path = os.path.join(base_path, 'geo_entities.json')
    matcher = GeoLambdaCalculusGNNEntityMatcher(entity_path, max_ngram=6)
    candidates = matcher.match(question_tokens)
    for can_idx, can in enumerate(candidates):
        can['index'] = can_idx

    is_valid, replaced_tokens = replace_lambda_calculus_entity(grammar, lc, lc_tokens, candidates)
    print(lc_tokens)
    print(is_valid)
    print(replaced_tokens)

def test_prolog_entity_replacer():
    import sys
    sys.path += ['../../']
    from grammars.grammar import get_grammar
    from grammars.utils import get_logical_form_preprocessor, get_logical_form_tokenizer
    preprocessor = get_logical_form_preprocessor('geo', 'prolog', normalize_var_with_de_brujin_index=True)
    grammar = get_grammar('geo', 'prolog')
    prolog = preprocessor(
        "answer(A,(capital(A),loc(A,B),state(B),loc(C,B),city(C),const(C,cityid(durham,_))))"
    ).lower()

    # Test Replace
    prolog_tokenizer = get_logical_form_tokenizer('geo', 'prolog')
    prolog_tokens = prolog_tokenizer.tokenize(prolog)
    question = 'what is the capital of states that have cities named durham ?'
    question_tokenizer = WordTokenizer(SpacyWordSplitter())
    question_tokens = question_tokenizer.tokenize(question)

    from geo_gnn_entity_matcher import GeoGNNEntityMatcher
    base_path = os.path.join('../../', 'data', 'geo')
    entity_path = os.path.join(base_path, 'geo_entities.json')
    matcher = GeoGNNEntityMatcher(entity_path, max_ngram=6)
    candidates = matcher.match(question_tokens)
    for can_idx, can in enumerate(candidates):
        can['index'] = can_idx

    # Prolog reuses the FunQL-based replacer.
    is_valid, replaced_tokens = replace_funql_entity(grammar, prolog, prolog_tokens, candidates)
    print(prolog_tokens)
    print(is_valid)
    print(replaced_tokens)

def test_funql_entity_replacer():
    import sys
    sys.path += ['../../']
    from grammars.grammar import get_grammar
    from grammars.utils import get_logical_form_preprocessor, get_logical_form_tokenizer
    preprocessor = get_logical_form_preprocessor('geo', 'funql')
    grammar = get_grammar('geo', 'funql')
    # Note: this first expression is immediately overwritten below.
    funql = preprocessor(
        "answer(count(intersection(state(loc_2(countryid('usa'))), traverse_1(shortest(river(all))))))")

    # Test Replace
    funql = preprocessor("answer(len(longest(river(loc_2(stateid('california'))))))")
    funql_tokenizer = get_logical_form_tokenizer('geo', 'funql')
    funql_tokens = funql_tokenizer.tokenize(funql)
    question = 'how long is the longest river in california'
    question_tokenizer = WordTokenizer(SpacyWordSplitter())
    question_tokens = question_tokenizer.tokenize(question)

    from geo_gnn_entity_matcher import GeoGNNEntityMatcher
    base_path = os.path.join('../../', 'data', 'geo')
    entity_path = os.path.join(base_path, 'geo_entities.json')
    matcher = GeoGNNEntityMatcher(entity_path, max_ngram=6)
    candidates = matcher.match(question_tokens)
    for can_idx, can in enumerate(candidates):
        can['index'] = can_idx

    is_valid, replaced_tokens = replace_funql_entity(grammar, funql, funql_tokens, candidates)
    print(funql_tokens)
    print(is_valid)
    print(replaced_tokens)

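# The four *_entity_replacer tests above share one pipeline: preprocess the
# logical form, tokenize it, match n-gram entity candidates in the question,
# then run a language-specific replacer. A minimal sketch of that shared shape
# (`replace_fn` stands for any of the replace_*_entity functions; the helper
# itself is illustrative, not part of the project, and assumes the same
# module-level imports the tests use, e.g. WordTokenizer and SpacyWordSplitter):
def run_replacer_test(task, language, logical_form, question, matcher, replace_fn):
    from grammars.grammar import get_grammar
    from grammars.utils import get_logical_form_preprocessor, get_logical_form_tokenizer
    grammar = get_grammar(task, language)
    processed = get_logical_form_preprocessor(task, language)(logical_form)
    tokens = get_logical_form_tokenizer(task, language).tokenize(processed)
    question_tokens = WordTokenizer(SpacyWordSplitter()).tokenize(question)
    candidates = matcher.match(question_tokens)
    for can_idx, can in enumerate(candidates):
        can['index'] = can_idx  # the replacers expect an index on each candidate
    return replace_fn(grammar, processed, tokens, candidates)
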
def test_prolog_entity_extractor():
    import sys
    sys.path += ['../../']
    from grammars.grammar import get_grammar
    from grammars.utils import get_logical_form_preprocessor
    preprocessor = get_logical_form_preprocessor('geo', 'prolog')
    grammar = get_grammar('geo', 'prolog')
    prolog = preprocessor("answer(A,(capital(A),loc(A,B),state(B),loc(C,B),city(C),const(C,cityid(durham,_))))")
    # Prolog reuses the FunQL-based extractor.
    entities = funql_entity_extractor(grammar, prolog)
    print(entities)

def test_sql_entity_extractor():
    import sys
    sys.path += ['../../']
    from grammars.grammar import get_grammar
    from grammars.utils import get_logical_form_preprocessor
    preprocessor = get_logical_form_preprocessor('geo', 'sql')
    grammar = get_grammar('geo', 'sql')
    print(grammar.copy_terminal_set)
    sql = preprocessor('SELECT CITYalias0.POPULATION FROM CITY AS CITYalias0 WHERE CITYalias0.CITY_NAME = "erie" AND CITYalias0.STATE_NAME = "pennsylvania" ;')
    entities = sql_entity_extractor(grammar, sql)
    print(entities)

def test_funql_entity_extractor():
    import sys
    sys.path += ['../../']
    from grammars.grammar import get_grammar
    from grammars.utils import get_logical_form_preprocessor
    preprocessor = get_logical_form_preprocessor('geo', 'funql')
    grammar = get_grammar('geo', 'funql')
    funql = preprocessor(
        "answer(count(intersection(state(loc_2(countryid('usa'))), traverse_1(shortest(river(all))))))")
    entities = funql_entity_extractor(grammar, funql)
    print(entities)

def test_atis_funql_entity_extractor():
    # Renamed from test_funql_entity_extractor to avoid shadowing the geo test above.
    import sys
    sys.path += ['../../']
    from grammars.grammar import get_grammar
    from grammars.utils import get_logical_form_preprocessor
    preprocessor = get_logical_form_preprocessor('atis', 'funql')
    grammar = get_grammar('atis', 'funql')
    funql = preprocessor(
        "answer(intersection(_meal_2(meal_description(snack)),_airline_2(airline_code(ff))))"
    )
    entities = funql_entity_extractor(grammar, funql)
    print(entities)

def test_lambda_calculus_entity_extractor():
    import sys
    sys.path += ['../../']
    from grammars.grammar import get_grammar
    from grammars.utils import get_logical_form_preprocessor
    preprocessor = get_logical_form_preprocessor('atis', 'lambda')
    grammar = get_grammar('atis', 'lambda')
    lc = preprocessor(
        "( _lambda $v1 e ( _exists $v0 ( _and ( _flight $v0 ) ( _from $v0 washington:_ci ) ( _to $v0 toronto:_ci ) ( _day $v0 saturday:_da ) ( _= ( _fare $v0 ) $v1 ) ) ) )"
    )
    entities = lambda_calculus_entity_extractor(grammar, lc)
    print(entities)

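# A small driver that runs each extractor test above in turn; illustrative
# only, and it assumes all of these tests live in the same module:
def run_extractor_tests():
    for test in (test_prolog_entity_extractor, test_sql_entity_extractor,
                 test_funql_entity_extractor, test_atis_funql_entity_extractor,
                 test_lambda_calculus_entity_extractor):
        print('Running', test.__name__)
        test()
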
    # Tail of the evaluation routine: `pattern`, `predictions`, `correct`,
    # `valid_executions`, and `grammar_valid_count` are defined earlier in the
    # enclosing function.
    with open('prolog_result.log', 'r') as f:
        for line in f:
            line = line.strip()
            if len(line) > 0:
                match = pattern.match(line)
                if match:
                    valid_executions += 1
                    index, is_correct = int(match.group(1)), match.group(2) == 'y'
                    predictions[index]['execution_correct'] = is_correct
                    if is_correct:
                        correct += 1
    print("Total: %d, Grammar Valid: %d, Valid Executions: %d, Correct: %d, Accuracy: %f" %
          (len(predictions), grammar_valid_count, valid_executions, correct,
           correct / len(predictions)))
    # with open(path, 'w') as f:
    #     f.write(json.dumps(predictions))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--predictions',
                        help='file that stores the prediction results',
                        required=True)
    parser.add_argument("--recover_variable", action='store_true', default=False)
    parser.add_argument("--timeout",
                        help='timeout limit for expression',
                        default=120, type=int)
    args = parser.parse_args()

    from grammars.grammar import get_grammar
    grammar = get_grammar('geo', 'prolog')
    evaluate(args.predictions, grammar, args.recover_variable, args.timeout)

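# Example invocation (the script name is hypothetical):
#   python evaluate_prolog.py --predictions predictions.json --recover_variable --timeout 60
# `pattern` is expected to match result lines of the form "<index> <y|n>" in
# prolog_result.log; one plausible (assumed, not confirmed) definition:
#   pattern = re.compile(r'^(\d+)\s+([yn])\b')
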
def test_entity_linking():
    base_path = os.path.join('../../', 'data', 'geo')
    entity_path = os.path.join(base_path, 'geo_entities.json')
    matcher = GeoSQLGNNEntityMatcher(entity_path, max_ngram=6)
    tokenizer = WordTokenizer(SpacyWordSplitter())

    import sys
    sys.path += ['../../']
    from grammars.grammar import get_grammar
    from grammars.utils import get_logical_form_preprocessor
    preprocessor = get_logical_form_preprocessor('geo', 'sql', normalize_var_with_de_brujin_index=False)
    grammar = get_grammar('geo', 'sql')

    train_data = os.path.join(base_path, 'geo_sql_question_based_train.tsv')
    empty_count = 0
    max_number_of_candidates = 0
    numbers = list()
    invalid_count = 0
    with open(train_data, 'r') as f:
        for lidx, line in enumerate(f):
            line = line.strip()
            sentence, sql = line.split('\t')
            tokens = tokenizer.tokenize(sentence)
            candidates = matcher.match(tokens)
            if len(candidates) > max_number_of_candidates:
                max_number_of_candidates = len(candidates)

            # Check whether any two candidates share the same value and type.
            has_duplicate_entity = False
            for cidx1, can1 in enumerate(candidates):
                for cidx2, can2 in enumerate(candidates):
                    if cidx1 == cidx2:
                        continue
                    if can1['value'] == can2['value'] and can1['type'] == can2['type']:
                        has_duplicate_entity = True
                        break
                if has_duplicate_entity:
                    break

            if len(candidates) == 0:
                empty_count += 1
            numbers.append(len(candidates))

            # Validate: every gold entity must be covered by some candidate.
            processed_sql = preprocessor(sql).lower()
            golden_entities = sql_entity_extractor(grammar, processed_sql)
            valid = True
            for ge in golden_entities:
                for candidate in candidates:
                    compare_value = candidate['value'] if 'formatted_value' not in candidate \
                        else candidate['formatted_value']
                    if compare_value == ge or candidate.get('abbreviation', "") == ge:
                        break
                else:
                    # No candidate covered this gold entity.
                    valid = False
            if not valid:
                invalid_count += 1
                print(lidx)
                print(sentence)
                print(sql)
                print("Number of Candidates: ", len(candidates))
                print("Has Duplicate Candidates: ", has_duplicate_entity)
                print(candidates)
                print(golden_entities)
                print("Is Valid: ", valid)
                print('===\n\n')

    print("Largest number of candidates: ", max_number_of_candidates)
    print("Number of empty candidates: ", empty_count)
    print("Averaged candidates: ", np.mean(numbers))
    print("Invalid Count: ", invalid_count)

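# The quadratic duplicate scan in test_entity_linking can be expressed more
# compactly. A sketch of an equivalent helper (same (value, type) equality
# criterion as the nested loops above; the helper name is illustrative):
from itertools import combinations

def has_duplicate_candidates(candidates):
    return any(a['value'] == b['value'] and a['type'] == b['type']
               for a, b in combinations(candidates, 2))
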
def build_data_reader(FLAGS):
    splitter = SpacyWordSplitter(pos_tags=True)
    question_tokenizer = WordTokenizer(SpacyWordSplitter(pos_tags=True))
    reader = None
    if FLAGS.model == 'parsing':
        # Parsing
        grammar = get_grammar(FLAGS.task, FLAGS.language)
        assert grammar is not None
        logical_form_preprocessor = get_logical_form_preprocessor(FLAGS.task, FLAGS.language)
        if FLAGS.do_train:
            max_target_length = FLAGS.max_decode_length
        else:
            max_target_length = 0
        reader = GrammarBasedDataReader(
            question_tokenizer,
            grammar,
            logical_form_preprocessor=logical_form_preprocessor,
            maximum_target_length=max_target_length)
    elif FLAGS.model in ['copy_parsing', 'copy_parsing_2']:
        # Parsing with copy mechanism
        grammar = get_grammar(FLAGS.task, FLAGS.language)
        assert grammar is not None
        logical_form_preprocessor = get_logical_form_preprocessor(FLAGS.task, FLAGS.language)
        if FLAGS.do_train:
            max_target_length = FLAGS.max_decode_length
        else:
            max_target_length = 0
        entity_matcher = get_entity_matcher(FLAGS.task, FLAGS.language)
        utterance_preprocessor = get_utterance_preprocessor(FLAGS.task, FLAGS.language)
        reader = GrammarCopyBasedDataReader(
            question_tokenizer,
            grammar,
            logical_form_preprocessor=logical_form_preprocessor,
            utterance_preprocessor=utterance_preprocessor,
            copy_link_finder=entity_matcher,
            maximum_target_length=max_target_length)
    elif FLAGS.model == 'translation':
        # Translation
        logical_form_tokenizer = get_logical_form_tokenizer(FLAGS.task, FLAGS.language)
        reader = Seq2SeqDataReader(
            question_tokenizer=question_tokenizer,
            logical_form_tokenizer=logical_form_tokenizer,
            is_parsing=False)
        return reader
    elif FLAGS.model == 'seq_parsing':
        # Parsing without grammar
        logical_form_tokenizer = get_logical_form_tokenizer(FLAGS.task, FLAGS.language)
        reader = Seq2SeqDataReader(
            question_tokenizer=question_tokenizer,
            logical_form_tokenizer=logical_form_tokenizer,
            is_parsing=True)
    elif FLAGS.model == 'recombination_seq_parsing':
        logical_form_preprocessor = get_logical_form_preprocessor(
            FLAGS.task, FLAGS.language, normalize_var_with_de_brujin_index=True)
        logical_form_tokenizer = get_logical_form_tokenizer(FLAGS.task, FLAGS.language)
        if FLAGS.do_train:
            max_target_length = FLAGS.max_decode_length
        else:
            max_target_length = 0
        reader = Seq2SeqDataReader(
            question_tokenizer=question_tokenizer,
            logical_form_tokenizer=logical_form_tokenizer,
            logical_form_preprocessor=logical_form_preprocessor,
            is_parsing=True,
            maximum_target_length=max_target_length)
        return reader
    elif FLAGS.model == 'recombination_copy_seq_parsing':
        logical_form_preprocessor = get_logical_form_preprocessor(
            FLAGS.task, FLAGS.language, normalize_var_with_de_brujin_index=True)
        logical_form_tokenizer = get_logical_form_tokenizer(FLAGS.task, FLAGS.language)
        if FLAGS.do_train:
            max_target_length = FLAGS.max_decode_length
        else:
            max_target_length = 0
        entity_matcher = get_seq2seq_entity_matcher(FLAGS.task, FLAGS.language)
        if FLAGS.language.startswith('sql'):
            exclude_target_words = [
                'select', 'from', 'and', 'in', 'where', 'group', 'order',
                'having', 'limit', 'not'
            ]
        else:
            exclude_target_words = None
        reader = Seq2SeqDataReader(
            question_tokenizer=question_tokenizer,
            logical_form_tokenizer=logical_form_tokenizer,
            logical_form_preprocessor=logical_form_preprocessor,
            is_parsing=True,
            enable_copy=True,
            maximum_target_length=max_target_length,
            entity_matcher=entity_matcher,
            exclude_target_words=exclude_target_words)
        return reader
    elif FLAGS.model in ['gnn_parsing', 'gnn_parsing2']:
        logical_form_preprocessor = get_logical_form_preprocessor(
            FLAGS.task, FLAGS.language, normalize_var_with_de_brujin_index=True)
        logical_form_tokenizer = get_logical_form_tokenizer(FLAGS.task, FLAGS.language)
        if FLAGS.do_train:
            max_target_length = FLAGS.max_decode_length
            allow_drop = True
        else:
            max_target_length = 0
            allow_drop = False
        grammar = get_grammar(FLAGS.task, FLAGS.language)
        entity_matcher = get_gnn_entity_matcher(FLAGS.task, FLAGS.language)
        entity_replacer = get_gnn_entity_replacer(FLAGS.task, FLAGS.language)
        reader = GNNCopyTransformerDataReader(
            entity_matcher=entity_matcher,
            entity_replacer=entity_replacer,
            target_grammar=grammar,
            source_tokenizer=question_tokenizer,
            target_tokenizer=logical_form_tokenizer,
            logical_form_preprocessor=logical_form_preprocessor,
            relative_position_clipped_range=FLAGS.gnn_relative_position_clipped_range,
            nlabels=FLAGS.gnn_transformer_num_edge_labels,
            allow_drop=allow_drop)
        return reader
    return reader

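# A minimal usage sketch for build_data_reader: any object exposing the flag
# attributes read above will do, so a SimpleNamespace is enough for a smoke
# test (the flag values below are illustrative only):
def _smoke_test_build_data_reader():
    from types import SimpleNamespace
    flags = SimpleNamespace(model='parsing', task='geo', language='funql',
                            do_train=True, max_decode_length=300)
    return build_data_reader(flags)
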
        # Emit one Haskell binding plus a tagged print per compared expression;
        # `idx`, `tq`, and `timeout_limits` come from the enclosing loop.
        code.append(
            """ compare_result_t%d <- (timeout %d (return $! (%s)))""" %
            (idx, timeout_limits, tq))
        code.append(""" print ("t", %d, compare_result_t%d)""" % (idx, idx))

    code = "\n".join(code)
    code = script_template % (code)
    with open('Main.hs', 'w') as f:
        f.write(code)
    # Copy the generated file into the evaluator project.
    shutil.copyfile('./Main.hs', './lambda_calculus/evaluator/app/Main.hs')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--predictions',
                        help='file that stores the prediction results',
                        required=True)
    parser.add_argument("--timeout",
                        help='timeout limit for expression',
                        default=120, type=int)
    args = parser.parse_args()

    from grammars.grammar import get_grammar
    grammar = get_grammar('geo', 'lambda')
    from grammars.utils import get_logical_form_tokenizer
    tokenizer = get_logical_form_tokenizer('geo', 'lambda')
    evaluate(args.predictions, grammar, tokenizer, args.timeout)
    # calculuate_result(args.predictions)
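
# Example invocation (the script name is hypothetical):
#   python evaluate_lambda_calculus.py --predictions predictions.json --timeout 120
# Each appended pair of lines evaluates one comparison in Main.hs under a
# timeout and prints a tagged result; the generated Haskell looks roughly like
# (illustrative, assuming `gold0 == pred0` is the comparison expression `tq`):
#   compare_result_t0 <- (timeout 120 (return $! (gold0 == pred0)))
#   print ("t", 0, compare_result_t0)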