Example #1
def test_sql_entity_replacer():
    import sys
    sys.path += ['../../']
    from grammars.grammar import get_grammar
    from grammars.utils import get_logical_form_preprocessor, get_logical_form_tokenizer
    preprocessor = get_logical_form_preprocessor('geo', 'sql')
    grammar = get_grammar('geo', 'sql')
    sql = preprocessor(
        'SELECT CITYalias0.POPULATION FROM CITY AS CITYalias0 WHERE CITYalias0.CITY_NAME = "erie" AND CITYalias0.STATE_NAME = "pennsylvania" ;')

    # Test Replace
    sql_tokenizer = get_logical_form_tokenizer('geo', 'sql')
    sql_tokens = sql_tokenizer.tokenize(sql)

    question = 'what is the population of erie pennsylvania'
    question_tokenizer = WordTokenizer(SpacyWordSplitter())
    question_tokens = question_tokenizer.tokenize(question)

    from geo_gnn_entity_matcher import GeoGNNEntityMatcher
    base_path = os.path.join('../../', 'data', 'geo')
    entity_path = os.path.join(base_path, 'geo_entities.json')
    matcher = GeoGNNEntityMatcher(entity_path, max_ngram=6)
    candidates = matcher.match(question_tokens)
    for can_idx, can in enumerate(candidates):
        can['index'] = can_idx

    is_valid, replaced_tokens = replace_sql_entity(grammar, sql, sql_tokens, candidates)
    print(sql_tokens)
    print("Is Valid: ", is_valid)
    print(replaced_tokens)
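These test snippets rely on a few names that are imported at module level in the original file and therefore do not appear in the excerpts: `os`, the question tokenizer classes, and the entity replacer/extractor helpers themselves. A minimal sketch of the missing imports, assuming the AllenNLP 0.x tokenizer API that the `WordTokenizer(SpacyWordSplitter())` call pattern corresponds to:

import os

# Assumed: AllenNLP 0.x tokenizers. replace_sql_entity, replace_lambda_calculus_entity,
# replace_funql_entity and the *_entity_extractor helpers are assumed to be defined in
# (or imported by) the same module as these tests.
from allennlp.data.tokenizers import WordTokenizer
from allennlp.data.tokenizers.word_splitter import SpacyWordSplitter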
Example #2
def test_lambda_calculus_entity_replacer():
    import sys
    sys.path += ['../../']
    from grammars.grammar import get_grammar
    from grammars.utils import get_logical_form_preprocessor, get_logical_form_tokenizer
    preprocessor = get_logical_form_preprocessor('geo', 'lambda')
    grammar = get_grammar('geo', 'lambda')
    lc = preprocessor(
        "(lambda $0:e (and:<t*,t> (major:<lo,t> $0) (city:<c,t> $0) (loc:<lo,<lo,t>> $0 alaska:s)))",
    )

    # Test Replace
    lc_tokenizer = get_logical_form_tokenizer('geo', 'lambda')
    lc_tokens = lc_tokenizer.tokenize(lc)

    question = 'what are the major cities in alaska'
    question_tokenizer = WordTokenizer(SpacyWordSplitter())
    question_tokens = question_tokenizer.tokenize(question)

    from geo_gnn_entity_matcher import GeoLambdaCalculusGNNEntityMatcher
    base_path = os.path.join('../../', 'data', 'geo')
    entity_path = os.path.join(base_path, 'geo_entities.json')
    matcher = GeoLambdaCalculusGNNEntityMatcher(entity_path, max_ngram=6)
    candidates = matcher.match(question_tokens)
    for can_idx, can in enumerate(candidates):
        can['index'] = can_idx

    is_valid, replaced_tokens = replace_lambda_calculus_entity(grammar, lc, lc_tokens, candidates)
    print(lc_tokens)
    print(is_valid)
    print(replaced_tokens)
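The candidates returned by `matcher.match()` are consumed as plain dicts throughout these tests. The sketch below is only an inference from the fields the tests read (`value`, `type`, the optional `formatted_value` and `abbreviation` used in Example #11, and the `index` attached by the enumerate loop); it is not the matcher's documented schema:

# Illustrative candidate shape, inferred from how the tests access candidate fields.
example_candidate = {
    'value': 'alaska',            # matched entity value, compared against gold entities
    'type': 'state',              # entity type (assumed label)
    'formatted_value': 'alaska',  # optional; preferred over 'value' when present
    'abbreviation': 'ak',         # optional; also accepted when comparing to gold
    'index': 0,                   # assigned by the test loop after matching
}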
Example #3
def test_prolog_entity_replacer():
    import sys
    sys.path += ['../../']
    from grammars.grammar import get_grammar
    from grammars.utils import get_logical_form_preprocessor, get_logical_form_tokenizer
    preprocessor = get_logical_form_preprocessor('geo', 'prolog', normalize_var_with_de_brujin_index=True)
    grammar = get_grammar('geo', 'prolog')
    prolog = preprocessor(
        "answer(A,(capital(A),loc(A,B),state(B),loc(C,B),city(C),const(C,cityid(durham,_))))",
    ).lower()

    # Test Replace
    prolog_tokenizer = get_logical_form_tokenizer('geo', 'prolog')
    prolog_tokens = prolog_tokenizer.tokenize(prolog)

    question = 'what is the capital of states that have cities named durham ?'
    question_tokenizer = WordTokenizer(SpacyWordSplitter())
    question_tokens = question_tokenizer.tokenize(question)

    from geo_gnn_entity_matcher import GeoGNNEntityMatcher
    base_path = os.path.join('../../', 'data', 'geo')
    entity_path = os.path.join(base_path, 'geo_entities.json')
    matcher = GeoGNNEntityMatcher(entity_path, max_ngram=6)
    candidates = matcher.match(question_tokens)
    for can_idx, can in enumerate(candidates):
        can['index'] = can_idx

    is_valid, replaced_tokens = replace_funql_entity(grammar, prolog, prolog_tokens, candidates)
    print(prolog_tokens)
    print(is_valid)
    print(replaced_tokens)
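`normalize_var_with_de_brujin_index=True` presumably rewrites the Prolog variables (A, B, C, ...) into position-based names in order of first appearance. The repo's actual preprocessor is not shown in this excerpt; the following is only a generic, self-contained sketch of that idea, and the V0/V1 naming is an assumption:

import re

def normalize_vars(prolog):
    """Generic illustration (not the repo's implementation): rename single
    uppercase Prolog variables to V0, V1, ... in order of first appearance."""
    mapping = {}
    def rename(match):
        name = match.group(0)
        if name not in mapping:
            mapping[name] = 'V%d' % len(mapping)
        return mapping[name]
    return re.sub(r'\b[A-Z]\b', rename, prolog)

# normalize_vars("answer(A,(capital(A),loc(A,B),state(B)))")
# -> "answer(V0,(capital(V0),loc(V0,V1),state(V1)))"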
Example #4
def test_funql_entity_replacer():
    import sys
    sys.path += ['../../']
    from grammars.grammar import get_grammar
    from grammars.utils import get_logical_form_preprocessor, get_logical_form_tokenizer
    preprocessor = get_logical_form_preprocessor('geo', 'funql')
    grammar = get_grammar('geo', 'funql')
    funql = preprocessor(
        "answer(count(intersection(state(loc_2(countryid('usa'))), traverse_1(shortest(river(all))))))")
    # NOTE: the query above is immediately overwritten by the one below; only the
    # second query is actually exercised by this test.
    # Test Replace
    funql = preprocessor("answer(len(longest(river(loc_2(stateid('california'))))))")
    funql_tokenizer = get_logical_form_tokenizer('geo', 'funql')
    funql_tokens = funql_tokenizer.tokenize(funql)

    question = 'how long is the longest river in california'
    question_tokenizer = WordTokenizer(SpacyWordSplitter())
    question_tokens = question_tokenizer.tokenize(question)

    from geo_gnn_entity_matcher import GeoGNNEntityMatcher
    base_path = os.path.join('../../', 'data', 'geo')
    entity_path = os.path.join(base_path, 'geo_entities.json')
    matcher = GeoGNNEntityMatcher(entity_path, max_ngram=6)
    candidates = matcher.match(question_tokens)
    for can_idx, can in enumerate(candidates):
        can['index'] = can_idx

    is_valid, replaced_tokens = replace_funql_entity(grammar, funql, funql_tokens, candidates)
    print(funql_tokens)
    print(is_valid)
    print(replaced_tokens)
Example #5
def test_prolog_entity_extractor():
    import sys
    sys.path += ['../../']
    from grammars.grammar import get_grammar
    from grammars.utils import get_logical_form_preprocessor
    preprocessor = get_logical_form_preprocessor('geo', 'prolog')
    grammar = get_grammar('geo', 'prolog')
    prolog = preprocessor("answer(A,(capital(A),loc(A,B),state(B),loc(C,B),city(C),const(C,cityid(durham,_))))")
    entities = funql_entity_extractor(grammar, prolog)
    print(entities)
Example #6
def test_sql_entity_extractor():
    import sys
    sys.path += ['../../']
    from grammars.grammar import get_grammar
    from grammars.utils import get_logical_form_preprocessor
    preprocessor = get_logical_form_preprocessor('geo', 'sql')
    grammar = get_grammar('geo', 'sql')
    print(grammar.copy_terminal_set)
    sql = preprocessor('SELECT CITYalias0.POPULATION FROM CITY AS CITYalias0 WHERE CITYalias0.CITY_NAME = "erie" AND CITYalias0.STATE_NAME = "pennsylvania" ;')
    entities = sql_entity_extractor(grammar, sql)
    print(entities)
Example #7
def test_funql_entity_extractor():
    import sys
    sys.path += ['../../']
    from grammars.grammar import get_grammar
    from grammars.utils import get_logical_form_preprocessor
    preprocessor = get_logical_form_preprocessor('geo', 'funql')
    grammar = get_grammar('geo', 'funql')
    funql = preprocessor(
        "answer(count(intersection(state(loc_2(countryid('usa'))), traverse_1(shortest(river(all))))))")
    entities = funql_entity_extractor(grammar, funql)
    print(entities)
Example #8
def test_funql_entity_extractor():
    import sys
    sys.path += ['../../']
    from grammars.grammar import get_grammar
    from grammars.utils import get_logical_form_preprocessor
    preprocessor = get_logical_form_preprocessor('atis', 'funql')
    grammar = get_grammar('atis', 'funql')
    funql = preprocessor(
        "answer(intersection(_meal_2(meal_description(snack)),_airline_2(airline_code(ff))))"
    )
    entities = funql_entity_extractor(grammar, funql)
    print(entities)
Example #9
def test_lambda_calculus_entity_extractor():
    import sys
    sys.path += ['../../']
    from grammars.grammar import get_grammar
    from grammars.utils import get_logical_form_preprocessor
    preprocessor = get_logical_form_preprocessor('atis', 'lambda')
    grammar = get_grammar('atis', 'lambda')
    lc = preprocessor(
        "( _lambda $v1 e ( _exists $v0 ( _and ( _flight $v0 ) ( _from $v0 washington:_ci ) ( _to $v0 toronto:_ci ) ( _day $v0 saturday:_da ) ( _= ( _fare $v0 ) $v1 ) ) ) )"
    )
    entities = lambda_calculus_entity_extractor(grammar, lc)
    print(entities)
Example #10
    with open('prolog_result.log', 'r') as f:
        for line in f:
            line = line.strip()
            if len(line) > 0:
                match = pattern.match(line)
                if match:
                    valid_executions += 1
                    index, is_correct = int(
                        match.group(1)), match.group(2) == 'y'
                    predictions[index]['execution_correct'] = is_correct
                    if is_correct:
                        correct += 1
    print("Total: %d, Grammar Valid: %d, Valid Executions: %d, Correct: %d, Accuracy: %f" %
          (len(predictions), grammar_valid_count, valid_executions, correct, float(correct / len(predictions))))
    # with open(path, 'w') as f:
    #     f.write(json.dumps(predictions))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--predictions', help='file that stores the prediction results', required=True)
    parser.add_argument("--recover_variable",
                        action='store_true', default=False)
    parser.add_argument("--timeout",
                        help='timeout limit for expression', default=120, type=int)
    args = parser.parse_args()
    from grammars.grammar import get_grammar
    grammar = get_grammar('geo', 'prolog')
    evaluate(args.predictions, grammar, args.recover_variable, args.timeout)
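`pattern` is defined outside this excerpt. Given that `group(1)` is parsed as an integer prediction index and `group(2)` is compared against `'y'`, a plausible stand-in looks like the sketch below; the actual log-line format of `prolog_result.log` is an assumption:

import re

# Assumed log-line format: "<index> <y|n> ..." -- replace with the real pattern
# used by the evaluator if it differs.
pattern = re.compile(r'^(\d+)\s+([yn])\b')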
Example #11
def test_entity_linking():
    base_path = os.path.join('../../', 'data', 'geo')
    entity_path = os.path.join(base_path, 'geo_entities.json')

    matcher = GeoSQLGNNEntityMatcher(entity_path, max_ngram=6)
    tokenizer = WordTokenizer(SpacyWordSplitter())

    import sys
    sys.path += ['../../']
    from grammars.grammar import get_grammar
    from grammars.utils import get_logical_form_preprocessor
    preprocessor = get_logical_form_preprocessor(
        'geo', 'sql', normalize_var_with_de_brujin_index=False)
    grammar = get_grammar('geo', 'sql')

    train_data = os.path.join(base_path, 'geo_sql_question_based_train.tsv')
    empty_count = 0
    max_number_of_candidates = 0
    numbers = list()
    invalid_count = 0
    with open(train_data, 'r') as f:
        for lidx, line in enumerate(f):
            line = line.strip()
            sentence, sql = line.split('\t')
            tokens = tokenizer.tokenize(sentence)
            candidates = matcher.match(tokens)

            if len(candidates) > max_number_of_candidates:
                max_number_of_candidates = len(candidates)

            has_duplicate_entity = False
            for cidx1, can1 in enumerate(candidates):
                for cidx2, can2 in enumerate(candidates):
                    if cidx1 == cidx2:
                        continue
                    if can1['value'] == can2['value'] and can1['type'] == can2[
                            'type']:
                        has_duplicate_entity = True
                        break
                if has_duplicate_entity:
                    break

            if len(candidates) == 0:
                empty_count += 1
            numbers.append(len(candidates))

            # Validate
            processed_sql = preprocessor(sql).lower()
            golden_entities = sql_entity_extractor(grammar, processed_sql)

            valid = True
            for ge in golden_entities:
                for candidate in candidates:
                    compare_value = candidate['value'] if 'formatted_value' not in candidate \
                        else candidate['formatted_value']
                    if compare_value == ge or candidate.get(
                            'abbreviation', "") == ge:
                        break
                else:
                    valid = False
            if not valid:
                invalid_count += 1

            print(lidx)
            print(sentence)
            print(sql)
            print("Number of Candidates: ", len(candidates))
            print("Has Duplicate Candidates: ", has_duplicate_entity)
            print(candidates)
            print(golden_entities)
            print("Is Valid: ", valid)
            print('===\n\n')

    print("Largest number of candidates: ", max_number_of_candidates)
    print("Number of empty candidates: ", empty_count)
    print("Averaged candidates: ", np.mean(numbers))
    print("Invalid Count: ", invalid_count)
Example #12
def build_data_reader(FLAGS):
    splitter = SpacyWordSplitter(pos_tags=True)
    question_tokenizer = WordTokenizer(SpacyWordSplitter(pos_tags=True))
    reader = None
    if FLAGS.model == 'parsing':
        # Parsing
        grammar = get_grammar(FLAGS.task, FLAGS.language)
        assert grammar is not None
        logical_form_preprocessor = get_logical_form_preprocessor(
            FLAGS.task, FLAGS.language)
        if FLAGS.do_train:
            max_target_length = FLAGS.max_decode_length
        else:
            max_target_length = 0
        reader = GrammarBasedDataReader(
            question_tokenizer,
            grammar,
            logical_form_preprocessor=logical_form_preprocessor,
            maximum_target_length=max_target_length)
    elif FLAGS.model in ['copy_parsing', 'copy_parsing_2']:
        # Parsing
        grammar = get_grammar(FLAGS.task, FLAGS.language)
        assert grammar is not None
        logical_form_preprocessor = get_logical_form_preprocessor(
            FLAGS.task, FLAGS.language)
        if FLAGS.do_train:
            max_target_length = FLAGS.max_decode_length
        else:
            max_target_length = 0
        entity_matcher = get_entity_matcher(FLAGS.task, FLAGS.language)
        utterance_preprocessor = get_utterance_preprocessor(
            FLAGS.task, FLAGS.language)
        reader = GrammarCopyBasedDataReader(
            question_tokenizer,
            grammar,
            logical_form_preprocessor=logical_form_preprocessor,
            utterance_preprocessor=utterance_preprocessor,
            copy_link_finder=entity_matcher,
            maximum_target_length=max_target_length)
    elif FLAGS.model == 'translation':
        # Translation
        logical_form_tokenizer = get_logical_form_tokenizer(
            FLAGS.task, FLAGS.language)
        reader = Seq2SeqDataReader(
            question_tokenizer=question_tokenizer,
            logical_form_tokenizer=logical_form_tokenizer,
            is_parsing=False)
        return reader
    elif FLAGS.model == 'seq_parsing':
        # Parsing without grammar
        logical_form_tokenizer = get_logical_form_tokenizer(
            FLAGS.task, FLAGS.language)
        reader = Seq2SeqDataReader(
            question_tokenizer=question_tokenizer,
            logical_form_tokenizer=logical_form_tokenizer,
            is_parsing=True)
    elif FLAGS.model == 'recombination_seq_parsing':
        logical_form_preprocessor = get_logical_form_preprocessor(
            FLAGS.task,
            FLAGS.language,
            normalize_var_with_de_brujin_index=True)
        logical_form_tokenizer = get_logical_form_tokenizer(
            FLAGS.task, FLAGS.language)
        if FLAGS.do_train:
            max_target_length = FLAGS.max_decode_length
        else:
            max_target_length = 0
        reader = Seq2SeqDataReader(
            question_tokenizer=question_tokenizer,
            logical_form_tokenizer=logical_form_tokenizer,
            logical_form_preprocessor=logical_form_preprocessor,
            is_parsing=True,
            maximum_target_length=max_target_length)
        return reader
    elif FLAGS.model == 'recombination_copy_seq_parsing':
        logical_form_preprocessor = get_logical_form_preprocessor(
            FLAGS.task,
            FLAGS.language,
            normalize_var_with_de_brujin_index=True)
        logical_form_tokenizer = get_logical_form_tokenizer(
            FLAGS.task, FLAGS.language)
        if FLAGS.do_train:
            max_target_length = FLAGS.max_decode_length
        else:
            max_target_length = 0
        entity_matcher = get_seq2seq_entity_matcher(FLAGS.task, FLAGS.language)
        if FLAGS.language.startswith('sql'):
            exclude_target_words = [
                'select', 'from', 'and', 'in', 'where', 'group', 'order',
                'having', 'limit', 'not'
            ]
        else:
            exclude_target_words = None
        reader = Seq2SeqDataReader(
            question_tokenizer=question_tokenizer,
            logical_form_tokenizer=logical_form_tokenizer,
            logical_form_preprocessor=logical_form_preprocessor,
            is_parsing=True,
            enable_copy=True,
            maximum_target_length=max_target_length,
            entity_matcher=entity_matcher,
            exclude_target_words=exclude_target_words)
        return reader
    elif FLAGS.model in ['gnn_parsing', 'gnn_parsing2']:
        logical_form_preprocessor = get_logical_form_preprocessor(
            FLAGS.task,
            FLAGS.language,
            normalize_var_with_de_brujin_index=True)
        logical_form_tokenizer = get_logical_form_tokenizer(
            FLAGS.task, FLAGS.language)
        if FLAGS.do_train:
            max_target_length = FLAGS.max_decode_length
            allow_drop = True
        else:
            max_target_length = 0
            allow_drop = False
        grammar = get_grammar(FLAGS.task, FLAGS.language)
        entity_matcher = get_gnn_entity_matcher(FLAGS.task, FLAGS.language)
        entity_replacer = get_gnn_entity_replacer(FLAGS.task, FLAGS.language)
        reader = GNNCopyTransformerDataReader(
            entity_matcher=entity_matcher,
            entity_replacer=entity_replacer,
            target_grammar=grammar,
            source_tokenizer=question_tokenizer,
            target_tokenizer=logical_form_tokenizer,
            logical_form_preprocessor=logical_form_preprocessor,
            relative_position_clipped_range=FLAGS.
            gnn_relative_position_clipped_range,
            nlabels=FLAGS.gnn_transformer_num_edge_labels,
            allow_drop=allow_drop)
        return reader

    return reader
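`build_data_reader` only reads a handful of attributes from `FLAGS` (`model`, `task`, `language`, `do_train`, `max_decode_length`, plus the GNN-specific fields on the `gnn_parsing` branches). A hypothetical invocation for the plain `parsing` branch, using `argparse.Namespace` as a stand-in for however `FLAGS` is actually constructed:

from argparse import Namespace

# Hypothetical FLAGS; only the attributes read on the 'parsing' branch are set,
# and the values here are illustrative.
flags = Namespace(
    model='parsing',
    task='geo',
    language='funql',
    do_train=True,
    max_decode_length=100,
)
reader = build_data_reader(flags)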
Example #13
        code.append(
            """    compare_result_t%d <- (timeout %d (return $! (%s)))""" %
            (idx, timeout_limits, tq))
        code.append("""    print ("t", %d, compare_result_t%d)""" % (idx, idx))

    code = "\n".join(code)
    code = script_template % (code)
    with open('Main.hs', 'w') as f:
        f.write(code)

    # copy file
    shutil.copyfile('./Main.hs', './lambda_calculus/evaluator/app/Main.hs')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--predictions',
                        help='file that stores the prediction results',
                        required=True)
    parser.add_argument("--timeout",
                        help='timeout limit for expression',
                        default=120,
                        type=int)
    args = parser.parse_args()
    from grammars.grammar import get_grammar
    grammar = get_grammar('geo', 'lambda')
    from grammars.utils import get_logical_form_tokenizer
    tokenizer = get_logical_form_tokenizer('geo', 'lambda')
    evaluate(args.predictions, grammar, tokenizer, args.timeout)
    # calculuate_result(args.predictions)
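For reference, with `idx = 0`, a 120-second timeout, and a placeholder expression string, the two appended code templates expand to the Haskell statements shown in the comments of this small sketch:

# Illustrative expansion of the two code templates above.
idx, timeout_limits, tq = 0, 120, 'someExpression'
line1 = """    compare_result_t%d <- (timeout %d (return $! (%s)))""" % (idx, timeout_limits, tq)
line2 = """    print ("t", %d, compare_result_t%d)""" % (idx, idx)
# line1 == '    compare_result_t0 <- (timeout 120 (return $! (someExpression)))'
# line2 == '    print ("t", 0, compare_result_t0)'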