Code Example #1
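From the tranX semantic parser: `get_action_infos` walks the gold action sequence while replaying it on a `Hypothesis`, recording for each action its timestep, the frontier production and field of the partial tree, and, for `GenTokenAction`s, whether the generated token also appears in the source query and can therefore be copied (with `force_copy` turning a failed lookup into an error).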
def get_action_infos(src_query, tgt_actions, force_copy=False):
    action_infos = []
    hyp = Hypothesis()
    for t, action in enumerate(tgt_actions):
        action_info = ActionInfo(action)
        action_info.t = t
        if hyp.frontier_node:
            action_info.parent_t = hyp.frontier_node.created_time
            action_info.frontier_prod = hyp.frontier_node.production
            action_info.frontier_field = hyp.frontier_field.field

        if isinstance(action, GenTokenAction):
            try:
                tok_src_idx = src_query.index(str(action.token))
                action_info.copy_from_src = True
                action_info.src_token_position = tok_src_idx
            except ValueError:
                if force_copy:
                    raise ValueError(
                        'cannot copy primitive token %s from source' %
                        action.token)

        hyp.apply_action(action)
        action_infos.append(action_info)

    return action_infos
Code Example #2
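A loader for tab-separated (utterance, lambda-calculus logical form) pairs. Each logical form is parsed, optionally canonicalized by reordering predicates alphabetically, converted to an ASDL AST, and replayed through the transition system's action sequence as a sanity check before being packaged into an `Example`. Note: `grammar` is assumed to be defined at module level in the source file.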
def load_dataset(transition_system, dataset_file, reorder_predicates=True):
    examples = []
    for idx, line in enumerate(open(dataset_file)):
        src_query, tgt_code = line.strip().split('\t')

        src_query_tokens = src_query.split(' ')

        lf = parse_lambda_expr(tgt_code)
        assert lf.to_string() == tgt_code

        if reorder_predicates:
            ordered_lf = get_canonical_order_of_logical_form(
                lf, order_by='alphabet')
            assert ordered_lf == lf  # reordering must preserve logical-form equality
            lf = ordered_lf

        gold_source = lf.to_string()

        tgt_ast = logical_form_to_ast(grammar, lf)
        reconstructed_lf = ast_to_logical_form(tgt_ast)
        assert lf == reconstructed_lf

        tgt_actions = transition_system.get_actions(tgt_ast)

        print(idx)
        print('Utterance: %s' % src_query)
        print('Reference: %s' % tgt_code)
        # print('===== Actions =====')
        # sanity check
        hyp = Hypothesis()
        for action in tgt_actions:
            assert action.__class__ in transition_system.get_valid_continuation_types(
                hyp)
            if isinstance(action, ApplyRuleAction):
                assert action.production in transition_system.get_valid_continuating_productions(
                    hyp)
            hyp = hyp.clone_and_apply_action(action)
            # print(action)

        assert hyp.frontier_node is None and hyp.frontier_field is None

        src_from_hyp = transition_system.ast_to_surface_code(hyp.tree)
        assert src_from_hyp == gold_source

        tgt_action_infos = get_action_infos(src_query_tokens, tgt_actions)

        # print(' '.join(src_query_tokens))
        print('***')
        print(lf.to_string())
        print()
        example = Example(idx=idx,
                          src_sent=src_query_tokens,
                          tgt_actions=tgt_action_infos,
                          tgt_code=gold_source,
                          tgt_ast=tgt_ast,
                          meta=None)

        examples.append(example)

    return examples
Code Example #3
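A consistency check over the WikiSQL training set: every SQL query is converted to an ASDL AST, replayed as a `Hypothesis`, and reconstructed back into a `Query` that must equal the original. The commented-out block sketches an execution-based comparison via `DBEngine` that is disabled here.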
def check():
    data_file = '/Users/yinpengcheng/Research/SemanticParsing/WikiSQL/data/train.jsonl'
    engine = DBEngine('/Users/yinpengcheng/Research/SemanticParsing/WikiSQL/data/train.db')
    grammar = ASDLGrammar.from_text(open('sql_asdl.txt').read())
    transition_system = SqlTransitionSystem(grammar)
    from asdl.hypothesis import Hypothesis
    for line in open(data_file):
        example = json.loads(line)
        query = Query.from_dict(example['sql'])
        asdl_ast = sql_query_to_asdl_ast(query, grammar)
        asdl_ast.sanity_check()
        actions = transition_system.get_actions(asdl_ast)
        hyp = Hypothesis()

        for action in actions:
            hyp.apply_action(action)

        # if asdl_ast_to_sql_query(hyp.tree) != asdl_ast_to_sql_query(asdl_ast):
        #     hyp_query = asdl_ast_to_sql_query(hyp.tree)
        #     # make sure the execution result is the same
        #     hyp_query_result = engine.execute_query(example['table_id'], hyp_query)
        #     ref_result = engine.execute_query(example['table_id'], query)
        #     assert hyp_query_result == ref_result
        query_reconstr = asdl_ast_to_sql_query(asdl_ast)
        assert query == query_reconstr
        print(query)
Code Example #4
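A loader for a regex dataset: each specification line is parsed into an AST with `regex_expr_to_ast`, round-tripped through both `ast_to_surface_code` and the derived action sequence, and packaged into an `Example`.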
def load_regex_dataset(transition_system, split):
    prefix = 'data/regex/'
    src_file = join(prefix, "src-{}.txt".format(split))
    spec_file = join(prefix, "spec-{}.txt".format(split))

    examples = []
    for idx, (src_line, spec_line) in enumerate(zip(open(src_file),
                                                    open(spec_file))):
        print(idx)

        src_line = src_line.rstrip()
        spec_line = spec_line.rstrip()
        src_toks = src_line.split()

        spec_toks = spec_line.rstrip().split()
        spec_ast = regex_expr_to_ast(transition_system.grammar, spec_toks)

        # sanity check
        reconstructed_expr = transition_system.ast_to_surface_code(spec_ast)
        print(spec_line, reconstructed_expr)
        assert spec_line == reconstructed_expr

        tgt_actions = transition_system.get_actions(spec_ast)

        # sanity check
        hyp = Hypothesis()
        for action in tgt_actions:
            assert action.__class__ in transition_system.get_valid_continuation_types(hyp)
            if isinstance(action, ApplyRuleAction):
                assert action.production in transition_system.get_valid_continuating_productions(hyp)
            hyp = hyp.clone_and_apply_action(action)

        assert hyp.frontier_node is None and hyp.frontier_field is None
        assert is_equal_ast(hyp.tree, spec_ast)

        expr_from_hyp = transition_system.ast_to_surface_code(hyp.tree)
        assert expr_from_hyp == spec_line

        tgt_action_infos = get_action_infos(src_toks, tgt_actions)
        example = Example(idx=idx,
                          src_sent=src_toks,
                          tgt_actions=tgt_action_infos,
                          tgt_code=spec_line,
                          tgt_ast=spec_ast,
                          meta=None)

        examples.append(example)
    return examples
Code Example #5
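A minimal variant of `get_action_infos` with no copy handling: it only records the timestep and the frontier production/field for each action.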
def get_action_infos(tgt_actions):
    action_infos = []
    hyp = Hypothesis()
    for t, action in enumerate(tgt_actions):
        action_info = ActionInfo(action)
        action_info.t = t
        if hyp.frontier_node:
            action_info.parent_t = hyp.frontier_node.created_time
            action_info.frontier_prod = hyp.frontier_node.production
            action_info.frontier_field = hyp.frontier_field.field

        hyp.apply_action(action)
        action_infos.append(action_info)

    return action_infos
Code Example #6
File: dataset.py  Project: Amirutha/tranX
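A loader for a Lisp-style dataset in which source and target are separated by '~'. The target expression's parentheses are space-padded and whitespace-normalized before parsing, and parentheses are stripped from the source query.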
def load_dataset(transition_system, dataset_file):
    examples = []
    for idx, line in enumerate(open(dataset_file)):
        print(line)
        src_query, tgt_code = line.strip().split('~')

        tgt_code = tgt_code.replace("("," ( ")
        tgt_code = tgt_code.replace(")"," ) ")
        tgt_code = " ".join(tgt_code.split())
        src_query = src_query.replace("(","")
        src_query = src_query.replace(")","")
        src_query_tokens = src_query.split(' ')

        tgt_ast = lisp_expr_to_ast(transition_system.grammar, tgt_code)
        reconstructed_lisp_expr = ast_to_lisp_expr(tgt_ast)
        assert tgt_code == reconstructed_lisp_expr

        tgt_actions = transition_system.get_actions(tgt_ast)

        # sanity check
        hyp = Hypothesis()
        for action in tgt_actions:
            assert action.__class__ in transition_system.get_valid_continuation_types(hyp)
            if isinstance(action, ApplyRuleAction):
                assert action.production in transition_system.get_valid_continuating_productions(hyp)
            hyp = hyp.clone_and_apply_action(action)

        assert hyp.frontier_node is None and hyp.frontier_field is None

        assert is_equal_ast(hyp.tree, tgt_ast)

        expr_from_hyp = transition_system.ast_to_surface_code(hyp.tree)
        assert expr_from_hyp == tgt_code

        tgt_action_infos = get_action_infos(src_query_tokens, tgt_actions)

        print(idx)
        example = Example(idx=idx,
                          src_sent=src_query_tokens,
                          tgt_actions=tgt_action_infos,
                          tgt_code=tgt_code,
                          tgt_ast=tgt_ast,
                          meta=None)

        examples.append(example)

    return examples
Code Example #7
File: dataset.py  Project: chubbymaggie/tranX
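The lambda-calculus loader again (cf. Example #2), here without predicate reordering and with per-action debug printing enabled. As before, `grammar` is assumed to be a module-level variable.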
def load_dataset(transition_system, dataset_file):
    examples = []
    for idx, line in enumerate(open(dataset_file)):
        src_query, tgt_code = line.strip().split('\t')

        src_query_tokens = src_query.split(' ')

        lf = parse_lambda_expr(tgt_code)
        gold_source = lf.to_string()
        assert gold_source == tgt_code
        tgt_ast = logical_form_to_ast(grammar, lf)
        reconstructed_lf = ast_to_logical_form(tgt_ast)
        assert lf == reconstructed_lf

        tgt_actions = transition_system.get_actions(tgt_ast)

        print('Utterance: %s' % src_query)
        print('Reference: %s' % tgt_code)
        print('===== Actions =====')
        # sanity check
        hyp = Hypothesis()
        for action in tgt_actions:
            assert action.__class__ in transition_system.get_valid_continuation_types(hyp)
            if isinstance(action, ApplyRuleAction):
                assert action.production in transition_system.get_valid_continuating_productions(hyp)
            hyp = hyp.clone_and_apply_action(action)
            print(action)

        assert hyp.frontier_node is None and hyp.frontier_field is None

        src_from_hyp = transition_system.ast_to_surface_code(hyp.tree)
        assert src_from_hyp == gold_source

        tgt_action_infos = get_action_infos(src_query_tokens, tgt_actions)

        print(idx)
        example = Example(idx=idx,
                          src_sent=src_query_tokens,
                          tgt_actions=tgt_action_infos,
                          tgt_code=gold_source,
                          tgt_ast=tgt_ast,
                          meta=None)

        examples.append(example)

    return examples
Code Example #8
File: dataset.py  Project: zkcpku/structVAE
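The same loader as Example #7 with the per-action debug printing removed.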
def load_dataset(transition_system, dataset_file):
    examples = []
    for idx, line in enumerate(open(dataset_file)):
        src_query, tgt_code = line.strip().split('\t')

        src_query_tokens = src_query.split(' ')

        lf = parse_lambda_expr(tgt_code)
        gold_source = lf.to_string()
        assert gold_source == tgt_code
        tgt_ast = logical_form_to_ast(grammar, lf)
        reconstructed_lf = ast_to_logical_form(tgt_ast)
        assert lf == reconstructed_lf

        tgt_actions = transition_system.get_actions(tgt_ast)

        # sanity check
        hyp = Hypothesis()
        for action in tgt_actions:
            assert action.__class__ in transition_system.get_valid_continuation_types(
                hyp)
            if isinstance(action, ApplyRuleAction):
                assert action.production in transition_system.get_valid_continuating_productions(
                    hyp)
            hyp = hyp.clone_and_apply_action(action)

        assert hyp.frontier_node is None and hyp.frontier_field is None

        src_from_hyp = transition_system.ast_to_surface_code(hyp.tree)
        assert src_from_hyp == gold_source

        tgt_action_infos = get_action_infos(src_query_tokens, tgt_actions)

        print(idx)
        example = Example(idx=idx,
                          src_sent=src_query_tokens,
                          tgt_actions=tgt_action_infos,
                          tgt_code=gold_source,
                          tgt_ast=tgt_ast,
                          meta=None)

        examples.append(example)

    return examples
Code Example #9
File: dataset.py  Project: chubbymaggie/tranX
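The analogous loader for Prolog-style logical forms, using `prolog_expr_to_ast` and `ast_to_prolog_expr` for the round-trip checks.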
def load_dataset(transition_system, dataset_file):
    examples = []
    for idx, line in enumerate(open(dataset_file)):
        src_query, tgt_code = line.strip().split('\t')

        src_query_tokens = src_query.split(' ')

        tgt_ast = prolog_expr_to_ast(transition_system.grammar, tgt_code)
        reconstructed_prolog_expr = ast_to_prolog_expr(tgt_ast)
        assert tgt_code == reconstructed_prolog_expr

        tgt_actions = transition_system.get_actions(tgt_ast)

        # sanity check
        hyp = Hypothesis()
        for action in tgt_actions:
            assert action.__class__ in transition_system.get_valid_continuation_types(hyp)
            if isinstance(action, ApplyRuleAction):
                assert action.production in transition_system.get_valid_continuating_productions(hyp)
            hyp = hyp.clone_and_apply_action(action)

        assert hyp.frontier_node is None and hyp.frontier_field is None

        assert is_equal_ast(hyp.tree, tgt_ast)

        expr_from_hyp = transition_system.ast_to_surface_code(hyp.tree)
        assert expr_from_hyp == tgt_code

        tgt_action_infos = get_action_infos(src_query_tokens, tgt_actions)

        print(idx)
        example = Example(idx=idx,
                          src_sent=src_query_tokens,
                          tgt_actions=tgt_action_infos,
                          tgt_code=tgt_code,
                          tgt_ast=tgt_ast,
                          meta=None)

        examples.append(example)

    return examples
Code Example #10
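A WikiSQL consistency check that, unlike Example #3, actually executes both the reconstructed query and the reference query through `DBEngine` and asserts that their results match.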
from asdl.hypothesis import Hypothesis

# NOTE: `data_file`, `grammar`, `transition_system`, `engine`, and `tmp` are
# assumed to be defined earlier in the enclosing script.
for idx, line in enumerate(open(data_file, encoding='utf-8')):
    example = json.loads(line)
    query = Query.from_dict(example['sql']).lower()
    print(query)

    asdl_ast = sql_query_to_asdl_ast(query, grammar)
    asdl_ast.sanity_check()
    print(asdl_ast.to_string())

    actions = transition_system.get_actions(asdl_ast)
    tmp.append(actions)

    hyp = Hypothesis()
    for action in actions:
        hyp.apply_action(action)
    print(hyp.tree)

    # make sure the query reconstructed from the hypothesis tree executes to
    # the same result as the reference query
    hyp_query = asdl_ast_to_sql_query(hyp.tree)
    hyp_query_result = engine.execute_query(example['table_id'], hyp_query)
    ref_result = engine.execute_query(example['table_id'], query)
    assert hyp_query_result == ref_result

    query_reconstr = asdl_ast_to_sql_query(asdl_ast)
Code Example #11
File: dataset.py  Project: tomsonsgs/TRAN-MMA-master
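A span-aware variant of `get_action_infos`: consecutive `GenTokenAction`s (terminated by a stop signal) are grouped and matched against the source query as a contiguous sub-sequence via `find_sub_sequence`, falling back to per-token lookup when no contiguous match exists. The commented-out block at the end is the earlier per-token implementation (cf. Example #1). A sketch of `find_sub_sequence` follows the example.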
def get_action_infos(src_query,
                     tgt_actions,
                     force_copy=False,
                     copy_method='token'):
    action_infos = []
    hyp = Hypothesis()
    t = 0
    while t < len(tgt_actions):
        action = tgt_actions[t]

        if type(action) is GenTokenAction:
            begin_t = t
            t += 1
            while t < len(tgt_actions) and type(
                    tgt_actions[t]) is GenTokenAction:
                t += 1
            end_t = t

            gen_token_actions = tgt_actions[begin_t:end_t]
            assert gen_token_actions[-1].is_stop_signal()

            tokens = [action.token for action in gen_token_actions[:-1]]

            try:
                tok_src_start_idx, tok_src_end_idx = find_sub_sequence(
                    src_query, tokens)
                tok_src_idxs = list(range(tok_src_start_idx, tok_src_end_idx))
            except IndexError:
                print('\tCannot find [%s] in [%s]' %
                      (' '.join(tokens), ' '.join(src_query)),
                      file=sys.stderr)
                tok_src_idxs = [src_query.index(token) for token in tokens]

            tok_src_idxs.append(-1)  # for </primitive>

            for tok_src_idx, gen_token_action in zip(tok_src_idxs,
                                                     gen_token_actions):
                action_info = ActionInfo(gen_token_action)

                if not gen_token_action.is_stop_signal():
                    action_info.copy_from_src = True
                    action_info.src_token_position = tok_src_idx

                    assert src_query[tok_src_idx] == gen_token_action.token

                if hyp.frontier_node:
                    action_info.parent_t = hyp.frontier_node.created_time
                    action_info.frontier_prod = hyp.frontier_node.production
                    action_info.frontier_field = hyp.frontier_field.field

                hyp.apply_action(gen_token_action)
                action_infos.append(action_info)
        else:
            action_info = ActionInfo(action)
            if hyp.frontier_node:
                action_info.parent_t = hyp.frontier_node.created_time
                action_info.frontier_prod = hyp.frontier_node.production
                action_info.frontier_field = hyp.frontier_field.field

            hyp.apply_action(action)
            action_infos.append(action_info)
            t += 1

    # for t, action in enumerate(tgt_actions):
    #     action_info = ActionInfo(action)
    #     action_info.t = t
    #     if hyp.frontier_node:
    #         action_info.parent_t = hyp.frontier_node.created_time
    #         action_info.frontier_prod = hyp.frontier_node.production
    #         action_info.frontier_field = hyp.frontier_field.field
    #
    #     if type(action) is GenTokenAction:
    #         try:
    #             tok_src_idx = src_query.index(str(action.token))
    #             action_info.copy_from_src = True
    #             action_info.src_token_position = tok_src_idx
    #         except ValueError:
    #             if force_copy and not action.is_stop_signal():
    #                 raise ValueError('cannot copy primitive token %s from source' % action.token)
    #
    #     hyp.apply_action(action)
    #     action_infos.append(action_info)

    return action_infos
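`find_sub_sequence` is used above but not shown in the snippet. Below is a minimal sketch consistent with its call site, assuming it returns the half-open (start, end) span of the first occurrence of `query` inside `sequence` and raises IndexError when there is no contiguous match:

def find_sub_sequence(sequence, query):
    # scan for the first window of len(query) tokens equal to `query`
    for i in range(len(sequence) - len(query) + 1):
        if sequence[i:i + len(query)] == query:
            return i, i + len(query)
    raise IndexError('cannot find sub-sequence')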
Code Example #12
File: dataset.py  Project: tomsonsgs/TRAN-MMA-master
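The full WikiSQL loader: each entry is turned into a tokenized `Query`, converted to an ASDL AST, annotated with action infos under `force_copy=True`, and wrapped together with its table schema into a `WikiSqlExample`. The large commented-out block performs a detokenization and execution-equivalence check against `DBEngine`.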
def load_dataset(transition_system, dataset_file, table_file):
    examples = []
    engine = DBEngine(dataset_file[:-len('jsonl')] + 'db')

    # load table
    tables = dict()
    for line in open(table_file):
        table_entry = json.loads(line)
        tables[table_entry['id']] = table_entry

    for idx, line in tqdm.tqdm(enumerate(open(dataset_file))):
        # if idx > 100: break
        entry = json.loads(line)
        del entry['seq_input']
        del entry['seq_output']
        del entry['where_output']

        query = Query.from_tokenized_dict(entry['query'])
        query = query.lower()

        tokenized_conditions = []
        for col, op, val_entry in entry['query']['conds']:
            val = []
            for word, after in zip(val_entry['words'], val_entry['after']):
                val.append(word)

            tokenized_conditions.append([col, op, ' '.join(val)])
        tokenized_query = Query(sel_index=entry['query']['sel'],
                                agg_index=entry['query']['agg'],
                                conditions=tokenized_conditions)

        asdl_ast = sql_query_to_asdl_ast(tokenized_query,
                                         transition_system.grammar)
        asdl_ast.sanity_check()
        actions = transition_system.get_actions(asdl_ast)
        hyp = Hypothesis()

        question_tokens = entry['question']['words']
        tgt_action_infos = get_action_infos(question_tokens,
                                            actions,
                                            force_copy=True)

        for action, action_info in zip(actions, tgt_action_infos):
            assert action == action_info.action
            hyp.apply_action(action)

        reconstructed_query_from_hyp = asdl_ast_to_sql_query(hyp.tree)
        reconstructed_query = asdl_ast_to_sql_query(asdl_ast)

        assert tokenized_query == reconstructed_query
        # now we make sure the tokenized query executes to the same results
        # as the original one!
        # detokenized_conds_from_reconstr_query = []
        # error = False
        # for i, (col, op, val) in enumerate(reconstructed_query_from_hyp.conditions):
        #     val_tokens = val.split(' ')
        #     cond_entry = entry['query']['conds'][i]
        #
        #     assert col == cond_entry[0]
        #     assert op == cond_entry[1]
        #
        #     detokenized_cond_val = my_detokenize(val_tokens, entry['question'])
        #     raw_cond_val = detokenize(cond_entry[2])
        #     if detokenized_cond_val.lower() != raw_cond_val.lower():
        #         # print(idx + 1, detokenized_cond_val, raw_cond_val, file=sys.stderr)
        #         error = True
        #
        #     detokenized_conds_from_reconstr_query.append((col, op, detokenized_cond_val))
        #
        # detokenized_reconstr_query_from_hyp = Query(sel_index=reconstructed_query_from_hyp.sel_index,
        #                                             agg_index=reconstructed_query_from_hyp.agg_index,
        #                                             conditions=detokenized_conds_from_reconstr_query)
        #
        # # make sure the execution result is the same
        # hyp_query_result = engine.execute_query(entry['table_id'], detokenized_reconstr_query_from_hyp)
        # ref_result = engine.execute_query(entry['table_id'], query)
        #
        # if hyp_query_result != ref_result:
        #     print('[%d]: %s, %s' % (idx, query, detokenized_reconstr_query_from_hyp), file=sys.stderr)

        header = [
            TableColumn(name=detokenize(col_name),
                        tokens=col_name['words'],
                        type=col_type) for (col_name, col_type) in
            zip(entry['table']['header'], tables[entry['table_id']]['types'])
        ]
        table = WikiSqlTable(header=header)

        example = WikiSqlExample(idx=idx,
                                 question=question_tokens,
                                 table=table,
                                 tgt_actions=tgt_action_infos,
                                 tgt_code=query,
                                 tgt_ast=asdl_ast,
                                 meta=entry)

        examples.append(example)

        # print(query)

    return examples
Code Example #13
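A loader for the StReg (structured regex) dataset, analogous to Example #4 but additionally attaching the string examples, constant map, and worker record for each instance as metadata.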
def load_regex_dataset(transition_system, split):
    prefix = 'data/streg/'
    src_file = join(prefix, "src-{}.txt".format(split))
    spec_file = join(prefix, "targ-{}.txt".format(split))
    map_file = join(prefix, "map-{}.txt".format(split))
    exs_file = join(prefix, "exs-{}.txt".format(split))
    rec_file = join(prefix, "rec-{}.pkl".format(split))

    exs_info = StReg.load_examples(exs_file)
    map_info = StReg.load_map_file(map_file)
    rec_info = StReg.load_rec(rec_file)

    examples = []
    for idx, (src_line, spec_line, str_exs, cmap, rec) in enumerate(
            zip(open(src_file), open(spec_file), exs_info, map_info,
                rec_info)):
        print(idx)

        src_line = src_line.rstrip()
        spec_line = spec_line.rstrip()
        src_toks = src_line.split()

        spec_toks = spec_line.rstrip().split()
        spec_ast = streg_expr_to_ast(transition_system.grammar, spec_toks)

        # sanity check
        reconstructed_expr = transition_system.ast_to_surface_code(spec_ast)
        # print("Spec", spec_line)
        # print("Rcon", reconstructed_expr)
        assert spec_line == reconstructed_expr

        tgt_actions = transition_system.get_actions(spec_ast)

        # sanity check
        hyp = Hypothesis()
        for action in tgt_actions:
            assert action.__class__ in transition_system.get_valid_continuation_types(hyp)
            if isinstance(action, ApplyRuleAction):
                assert action.production in transition_system.get_valid_continuating_productions(hyp)
            hyp = hyp.clone_and_apply_action(action)

        assert hyp.frontier_node is None and hyp.frontier_field is None
        assert is_equal_ast(hyp.tree, spec_ast)

        expr_from_hyp = transition_system.ast_to_surface_code(hyp.tree)
        assert expr_from_hyp == spec_line

        tgt_action_infos = get_action_infos(src_toks, tgt_actions)

        example = Example(idx=idx,
                          src_sent=src_toks,
                          tgt_actions=tgt_action_infos,
                          tgt_code=spec_line,
                          tgt_ast=spec_ast,
                          meta={
                              "str_exs": str_exs,
                              "const_map": cmap,
                              "worker_info": rec
                          })
        examples.append(example)
    return examples