def get_action_infos(src_query, tgt_actions, force_copy=False):
    """Annotate each gold action with its timestep, frontier (parent) info, and copy index."""
    action_infos = []
    hyp = Hypothesis()
    for t, action in enumerate(tgt_actions):
        action_info = ActionInfo(action)
        action_info.t = t
        if hyp.frontier_node:
            action_info.parent_t = hyp.frontier_node.created_time
            action_info.frontier_prod = hyp.frontier_node.production
            action_info.frontier_field = hyp.frontier_field.field

        if isinstance(action, GenTokenAction):
            # mark the token as copiable if it appears in the source utterance
            try:
                tok_src_idx = src_query.index(str(action.token))
                action_info.copy_from_src = True
                action_info.src_token_position = tok_src_idx
            except ValueError:
                if force_copy:
                    raise ValueError('cannot copy primitive token %s from source' % action.token)

        hyp.apply_action(action)
        action_infos.append(action_info)

    return action_infos
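# Usage sketch (not from the original file): how the copy-aware variant above
# is typically driven. `transition_system` and `tgt_ast` are hypothetical
# stand-ins for objects produced by the loaders below.
src_tokens = 'flights from denver to boston'.split()   # made-up utterance
gold_actions = transition_system.get_actions(tgt_ast)  # assumed in scope
for info in get_action_infos(src_tokens, gold_actions):
    # copy_from_src / src_token_position are only set for tokens found in the source
    print(info.t, info.action, getattr(info, 'copy_from_src', False))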
def load_dataset(transition_system, dataset_file, reorder_predicates=True):
    examples = []
    for idx, line in enumerate(open(dataset_file)):
        src_query, tgt_code = line.strip().split('\t')
        src_query_tokens = src_query.split(' ')

        lf = parse_lambda_expr(tgt_code)
        assert lf.to_string() == tgt_code

        if reorder_predicates:
            # canonicalize predicate order; the reordered form must stay equal to `lf`
            ordered_lf = get_canonical_order_of_logical_form(lf, order_by='alphabet')
            assert ordered_lf == lf
            lf = ordered_lf

        gold_source = lf.to_string()

        tgt_ast = logical_form_to_ast(grammar, lf)  # `grammar` is defined at module level
        reconstructed_lf = ast_to_logical_form(tgt_ast)
        assert lf == reconstructed_lf
        tgt_actions = transition_system.get_actions(tgt_ast)

        print(idx)
        print('Utterance: %s' % src_query)
        print('Reference: %s' % tgt_code)
        # print('===== Actions =====')

        # sanity check: the gold action sequence must be a valid continuation at every step
        hyp = Hypothesis()
        for action in tgt_actions:
            assert action.__class__ in transition_system.get_valid_continuation_types(hyp)
            if isinstance(action, ApplyRuleAction):
                assert action.production in transition_system.get_valid_continuating_productions(hyp)
            hyp = hyp.clone_and_apply_action(action)
            # print(action)
        assert hyp.frontier_node is None and hyp.frontier_field is None

        src_from_hyp = transition_system.ast_to_surface_code(hyp.tree)
        assert src_from_hyp == gold_source

        tgt_action_infos = get_action_infos(src_query_tokens, tgt_actions)

        # print(' '.join(src_query_tokens))
        print('***')
        print(lf.to_string())
        print()

        example = Example(idx=idx,
                          src_sent=src_query_tokens,
                          tgt_actions=tgt_action_infos,
                          tgt_code=gold_source,
                          tgt_ast=tgt_ast,
                          meta=None)
        examples.append(example)

    return examples
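# Illustrative only: the loader above expects one tab-separated
# (utterance, lambda-calculus logical form) pair per line, e.g. a made-up row:
#
#   what states border texas\t( lambda $0 e ( and ( state:t $0 ) ( next_to:t $0 texas:s ) ) )
#
# A hedged invocation with a hypothetical file path:
train_examples = load_dataset(transition_system, 'data/geo/train.txt', reorder_predicates=True)
print('loaded %d examples' % len(train_examples))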
def check():
    data_file = '/Users/yinpengcheng/Research/SemanticParsing/WikiSQL/data/train.jsonl'
    engine = DBEngine('/Users/yinpengcheng/Research/SemanticParsing/WikiSQL/data/train.db')
    grammar = ASDLGrammar.from_text(open('sql_asdl.txt').read())
    transition_system = SqlTransitionSystem(grammar)
    from asdl.hypothesis import Hypothesis

    for line in open(data_file):
        example = json.loads(line)
        query = Query.from_dict(example['sql'])
        asdl_ast = sql_query_to_asdl_ast(query, grammar)
        asdl_ast.sanity_check()
        actions = transition_system.get_actions(asdl_ast)

        hyp = Hypothesis()
        for action in actions:
            hyp.apply_action(action)

        # if asdl_ast_to_sql_query(hyp.tree) != asdl_ast_to_sql_query(asdl_ast):
        # hyp_query = asdl_ast_to_sql_query(hyp.tree)
        # make sure the execution result is the same
        # hyp_query_result = engine.execute_query(example['table_id'], hyp_query)
        # ref_result = engine.execute_query(example['table_id'], query)
        # assert hyp_query_result == ref_result

        query_reconstr = asdl_ast_to_sql_query(asdl_ast)
        assert query == query_reconstr
        print(query)
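# For reference, a hand-written (not dataset-extracted) sample of the `sql`
# dict that Query.from_dict consumes above: a selected column index, an
# aggregator id, and (column, operator, value) condition triples.
sample_sql = {'sel': 3, 'agg': 0, 'conds': [[5, 0, 'butler cc (ks)']]}
print(Query.from_dict(sample_sql))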
def load_regex_dataset(transition_system, split):
    prefix = 'data/regex/'
    src_file = join(prefix, "src-{}.txt".format(split))
    spec_file = join(prefix, "spec-{}.txt".format(split))

    examples = []
    for idx, (src_line, spec_line) in enumerate(zip(open(src_file), open(spec_file))):
        print(idx)
        src_line = src_line.rstrip()
        spec_line = spec_line.rstrip()
        src_toks = src_line.split()
        spec_toks = spec_line.split()
        spec_ast = regex_expr_to_ast(transition_system.grammar, spec_toks)

        # sanity check: the AST must round-trip back to the original spec
        reconstructed_expr = transition_system.ast_to_surface_code(spec_ast)
        print(spec_line, reconstructed_expr)
        assert spec_line == reconstructed_expr

        tgt_actions = transition_system.get_actions(spec_ast)

        # sanity check
        hyp = Hypothesis()
        for action in tgt_actions:
            assert action.__class__ in transition_system.get_valid_continuation_types(hyp)
            if isinstance(action, ApplyRuleAction):
                assert action.production in transition_system.get_valid_continuating_productions(hyp)
            hyp = hyp.clone_and_apply_action(action)
        assert hyp.frontier_node is None and hyp.frontier_field is None
        assert is_equal_ast(hyp.tree, spec_ast)

        expr_from_hyp = transition_system.ast_to_surface_code(hyp.tree)
        assert expr_from_hyp == spec_line

        tgt_action_infos = get_action_infos(src_toks, tgt_actions)
        example = Example(idx=idx,
                          src_sent=src_toks,
                          tgt_actions=tgt_action_infos,
                          tgt_code=spec_line,
                          tgt_ast=spec_ast,
                          meta=None)
        examples.append(example)

    return examples
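# Hedged usage: the 'src-{split}.txt' / 'spec-{split}.txt' naming under
# data/regex/ follows from the code above; the split names are assumptions.
train_examples = load_regex_dataset(transition_system, 'train')
val_examples = load_regex_dataset(transition_system, 'val')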
def get_action_infos(tgt_actions):
    # Variant without source-copy bookkeeping: only timestep and frontier
    # information is recorded for each action.
    action_infos = []
    hyp = Hypothesis()
    for t, action in enumerate(tgt_actions):
        action_info = ActionInfo(action)
        action_info.t = t
        if hyp.frontier_node:
            action_info.parent_t = hyp.frontier_node.created_time
            action_info.frontier_prod = hyp.frontier_node.production
            action_info.frontier_field = hyp.frontier_field.field

        hyp.apply_action(action)
        action_infos.append(action_info)

    return action_infos
def load_dataset(transition_system, dataset_file):
    examples = []
    for idx, line in enumerate(open(dataset_file)):
        print(line)
        src_query, tgt_code = line.strip().split('~')

        # normalize the target LISP expression: pad parentheses with spaces,
        # then collapse repeated whitespace
        tgt_code = tgt_code.replace("(", " ( ")
        tgt_code = tgt_code.replace(")", " ) ")
        tgt_code = " ".join(tgt_code.split())

        # parentheses in the source utterance are dropped
        src_query = src_query.replace("(", "")
        src_query = src_query.replace(")", "")
        src_query_tokens = src_query.split(' ')

        tgt_ast = lisp_expr_to_ast(transition_system.grammar, tgt_code)
        reconstructed_lisp_expr = ast_to_lisp_expr(tgt_ast)
        assert tgt_code == reconstructed_lisp_expr
        tgt_actions = transition_system.get_actions(tgt_ast)

        # sanity check
        hyp = Hypothesis()
        for action in tgt_actions:
            assert action.__class__ in transition_system.get_valid_continuation_types(hyp)
            if isinstance(action, ApplyRuleAction):
                assert action.production in transition_system.get_valid_continuating_productions(hyp)
            hyp = hyp.clone_and_apply_action(action)
        assert hyp.frontier_node is None and hyp.frontier_field is None
        assert is_equal_ast(hyp.tree, tgt_ast)

        expr_from_hyp = transition_system.ast_to_surface_code(hyp.tree)
        assert expr_from_hyp == tgt_code

        tgt_action_infos = get_action_infos(src_query_tokens, tgt_actions)

        print(idx)
        example = Example(idx=idx,
                          src_sent=src_query_tokens,
                          tgt_actions=tgt_action_infos,
                          tgt_code=tgt_code,
                          tgt_ast=tgt_ast,
                          meta=None)
        examples.append(example)

    return examples
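# Worked example of the normalization above, on a made-up '~'-separated line
# (the separator and parenthesis handling mirror the loader):
line = 'what jobs use ( java ) ~(job (ANS) (language ANS java))'
src_query, tgt_code = line.strip().split('~')
tgt_code = ' '.join(tgt_code.replace('(', ' ( ').replace(')', ' ) ').split())
src_query = src_query.replace('(', '').replace(')', '')
print(tgt_code)           # ( job ( ANS ) ( language ANS java ) )
print(src_query.split())  # ['what', 'jobs', 'use', 'java']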
def load_dataset(transition_system, dataset_file):
    examples = []
    for idx, line in enumerate(open(dataset_file)):
        src_query, tgt_code = line.strip().split('\t')
        src_query_tokens = src_query.split(' ')

        lf = parse_lambda_expr(tgt_code)
        gold_source = lf.to_string()
        assert gold_source == tgt_code

        tgt_ast = logical_form_to_ast(grammar, lf)
        reconstructed_lf = ast_to_logical_form(tgt_ast)
        assert lf == reconstructed_lf
        tgt_actions = transition_system.get_actions(tgt_ast)

        print('Utterance: %s' % src_query)
        print('Reference: %s' % tgt_code)
        print('===== Actions =====')

        # sanity check
        hyp = Hypothesis()
        for action in tgt_actions:
            assert action.__class__ in transition_system.get_valid_continuation_types(hyp)
            if isinstance(action, ApplyRuleAction):
                assert action.production in transition_system.get_valid_continuating_productions(hyp)
            hyp = hyp.clone_and_apply_action(action)
            print(action)
        assert hyp.frontier_node is None and hyp.frontier_field is None

        src_from_hyp = transition_system.ast_to_surface_code(hyp.tree)
        assert src_from_hyp == gold_source

        tgt_action_infos = get_action_infos(src_query_tokens, tgt_actions)

        print(idx)
        example = Example(idx=idx,
                          src_sent=src_query_tokens,
                          tgt_actions=tgt_action_infos,
                          tgt_code=gold_source,
                          tgt_ast=tgt_ast,
                          meta=None)
        examples.append(example)

    return examples
def load_dataset(transition_system, dataset_file):
    examples = []
    for idx, line in enumerate(open(dataset_file)):
        src_query, tgt_code = line.strip().split('\t')
        src_query_tokens = src_query.split(' ')

        lf = parse_lambda_expr(tgt_code)
        gold_source = lf.to_string()
        assert gold_source == tgt_code

        tgt_ast = logical_form_to_ast(grammar, lf)
        reconstructed_lf = ast_to_logical_form(tgt_ast)
        assert lf == reconstructed_lf
        tgt_actions = transition_system.get_actions(tgt_ast)

        # sanity check
        hyp = Hypothesis()
        for action in tgt_actions:
            assert action.__class__ in transition_system.get_valid_continuation_types(hyp)
            if isinstance(action, ApplyRuleAction):
                assert action.production in transition_system.get_valid_continuating_productions(hyp)
            hyp = hyp.clone_and_apply_action(action)
        assert hyp.frontier_node is None and hyp.frontier_field is None

        src_from_hyp = transition_system.ast_to_surface_code(hyp.tree)
        assert src_from_hyp == gold_source

        tgt_action_infos = get_action_infos(src_query_tokens, tgt_actions)

        print(idx)
        example = Example(idx=idx,
                          src_sent=src_query_tokens,
                          tgt_actions=tgt_action_infos,
                          tgt_code=gold_source,
                          tgt_ast=tgt_ast,
                          meta=None)
        examples.append(example)

    return examples
def load_dataset(transition_system, dataset_file):
    examples = []
    for idx, line in enumerate(open(dataset_file)):
        src_query, tgt_code = line.strip().split('\t')
        src_query_tokens = src_query.split(' ')

        tgt_ast = prolog_expr_to_ast(transition_system.grammar, tgt_code)
        reconstructed_prolog_expr = ast_to_prolog_expr(tgt_ast)
        assert tgt_code == reconstructed_prolog_expr
        tgt_actions = transition_system.get_actions(tgt_ast)

        # sanity check
        hyp = Hypothesis()
        for action in tgt_actions:
            assert action.__class__ in transition_system.get_valid_continuation_types(hyp)
            if isinstance(action, ApplyRuleAction):
                assert action.production in transition_system.get_valid_continuating_productions(hyp)
            hyp = hyp.clone_and_apply_action(action)
        assert hyp.frontier_node is None and hyp.frontier_field is None
        assert is_equal_ast(hyp.tree, tgt_ast)

        expr_from_hyp = transition_system.ast_to_surface_code(hyp.tree)
        assert expr_from_hyp == tgt_code

        tgt_action_infos = get_action_infos(src_query_tokens, tgt_actions)

        print(idx)
        example = Example(idx=idx,
                          src_sent=src_query_tokens,
                          tgt_actions=tgt_action_infos,
                          tgt_code=tgt_code,
                          tgt_ast=tgt_ast,
                          meta=None)
        examples.append(example)

    return examples
# script fragment: `data_file`, `grammar`, `transition_system`, `engine`, and
# `tmp` are assumed to be defined earlier at module level
from asdl.hypothesis import Hypothesis

for ids, line in enumerate(open(data_file, encoding='utf-8')):
    example = json.loads(line)
    print(example['sql'])
    query = Query.from_dict(example['sql']).lower()
    print(query)

    asdl_ast = sql_query_to_asdl_ast(query, grammar)
    asdl_ast.sanity_check()
    print(asdl_ast.to_string())
    # asdl_ast.sort_removedup_self()
    # print(asdl_ast.to_string())

    actions = transition_system.get_actions(asdl_ast)
    tmp.append(actions)

    hyp = Hypothesis()
    print(actions)
    for action in actions:
        hyp.apply_action(action)
    print(hyp.tree)

    # if asdl_ast_to_sql_query(hyp.tree) != asdl_ast_to_sql_query(asdl_ast):
    hyp_query = asdl_ast_to_sql_query(hyp.tree)
    # make sure the execution result is the same
    hyp_query_result = engine.execute_query(example['table_id'], hyp_query)
    ref_result = engine.execute_query(example['table_id'], query)
    print(query)
    print(ref_result)
    assert hyp_query_result == ref_result

    query_reconstr = asdl_ast_to_sql_query(asdl_ast)
def get_action_infos(src_query, tgt_actions, force_copy=False, copy_method='token'):
    # Variant that copies multi-token spans: consecutive GenTokenActions
    # (terminated by a stop signal) are matched against the source utterance
    # as a contiguous sub-sequence. Note: `copy_method` is currently unused.
    action_infos = []
    hyp = Hypothesis()
    t = 0
    while t < len(tgt_actions):
        action = tgt_actions[t]
        if type(action) is GenTokenAction:
            # collect the full run of GenTokenActions starting at t
            begin_t = t
            t += 1
            while t < len(tgt_actions) and type(tgt_actions[t]) is GenTokenAction:
                t += 1
            end_t = t

            gen_token_actions = tgt_actions[begin_t:end_t]
            assert gen_token_actions[-1].is_stop_signal()
            tokens = [action.token for action in gen_token_actions[:-1]]

            try:
                # try to locate the whole token span in the source utterance
                tok_src_start_idx, tok_src_end_idx = find_sub_sequence(src_query, tokens)
                tok_src_idxs = list(range(tok_src_start_idx, tok_src_end_idx))
            except IndexError:
                print('\tCannot find [%s] in [%s]' % (' '.join(tokens), ' '.join(src_query)),
                      file=sys.stderr)
                # fall back to matching each token independently
                tok_src_idxs = [src_query.index(token) for token in tokens]

            tok_src_idxs.append(-1)  # for </primitive>

            for tok_src_idx, gen_token_action in zip(tok_src_idxs, gen_token_actions):
                action_info = ActionInfo(gen_token_action)
                if not gen_token_action.is_stop_signal():
                    action_info.copy_from_src = True
                    action_info.src_token_position = tok_src_idx
                    assert src_query[tok_src_idx] == gen_token_action.token

                if hyp.frontier_node:
                    action_info.parent_t = hyp.frontier_node.created_time
                    action_info.frontier_prod = hyp.frontier_node.production
                    action_info.frontier_field = hyp.frontier_field.field

                hyp.apply_action(gen_token_action)
                action_infos.append(action_info)
        else:
            action_info = ActionInfo(action)
            if hyp.frontier_node:
                action_info.parent_t = hyp.frontier_node.created_time
                action_info.frontier_prod = hyp.frontier_node.production
                action_info.frontier_field = hyp.frontier_field.field

            hyp.apply_action(action)
            action_infos.append(action_info)
            t += 1

    return action_infos
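# `find_sub_sequence` is used above but not shown; from its call sites it is
# assumed to return the [start, end) indices of `query` inside `sequence` and
# to raise IndexError when absent. A minimal sketch under that assumption:
def find_sub_sequence(sequence, query):
    for i in range(len(sequence) - len(query) + 1):
        if sequence[i:i + len(query)] == query:
            return i, i + len(query)
    raise IndexError('sub-sequence not found')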
def load_dataset(transition_system, dataset_file, table_file):
    examples = []
    engine = DBEngine(dataset_file[:-len('jsonl')] + 'db')

    # load tables
    tables = dict()
    for line in open(table_file):
        table_entry = json.loads(line)
        tables[table_entry['id']] = table_entry

    for idx, line in tqdm.tqdm(enumerate(open(dataset_file))):
        # if idx > 100: break
        entry = json.loads(line)
        del entry['seq_input']
        del entry['seq_output']
        del entry['where_output']

        query = Query.from_tokenized_dict(entry['query'])
        query = query.lower()

        # rebuild each condition value from its tokenized form
        tokenized_conditions = []
        for col, op, val_entry in entry['query']['conds']:
            val = []
            for word, after in zip(val_entry['words'], val_entry['after']):
                val.append(word)
            tokenized_conditions.append([col, op, ' '.join(val)])

        tokenized_query = Query(sel_index=entry['query']['sel'],
                                agg_index=entry['query']['agg'],
                                conditions=tokenized_conditions)

        asdl_ast = sql_query_to_asdl_ast(tokenized_query, transition_system.grammar)
        asdl_ast.sanity_check()
        actions = transition_system.get_actions(asdl_ast)

        hyp = Hypothesis()
        question_tokens = entry['question']['words']
        tgt_action_infos = get_action_infos(question_tokens, actions, force_copy=True)

        for action, action_info in zip(actions, tgt_action_infos):
            assert action == action_info.action
            hyp.apply_action(action)

        reconstructed_query_from_hyp = asdl_ast_to_sql_query(hyp.tree)
        reconstructed_query = asdl_ast_to_sql_query(asdl_ast)
        assert tokenized_query == reconstructed_query

        # # now we make sure the tokenized query executes to the same results as the original one!
        # detokenized_conds_from_reconstr_query = []
        # error = False
        # for i, (col, op, val) in enumerate(reconstructed_query_from_hyp.conditions):
        #     val_tokens = val.split(' ')
        #     cond_entry = entry['query']['conds'][i]
        #
        #     assert col == cond_entry[0]
        #     assert op == cond_entry[1]
        #
        #     detokenized_cond_val = my_detokenize(val_tokens, entry['question'])
        #     raw_cond_val = detokenize(cond_entry[2])
        #     if detokenized_cond_val.lower() != raw_cond_val.lower():
        #         # print(idx + 1, detokenized_cond_val, raw_cond_val, file=sys.stderr)
        #         error = True
        #
        #     detokenized_conds_from_reconstr_query.append((col, op, detokenized_cond_val))
        #
        # detokenized_reconstr_query_from_hyp = Query(sel_index=reconstructed_query_from_hyp.sel_index,
        #                                             agg_index=reconstructed_query_from_hyp.agg_index,
        #                                             conditions=detokenized_conds_from_reconstr_query)
        #
        # # make sure the execution result is the same
        # hyp_query_result = engine.execute_query(entry['table_id'], detokenized_reconstr_query_from_hyp)
        # ref_result = engine.execute_query(entry['table_id'], query)
        #
        # if hyp_query_result != ref_result:
        #     print('[%d]: %s, %s' % (idx, query, detokenized_reconstr_query_from_hyp), file=sys.stderr)

        header = [TableColumn(name=detokenize(col_name), tokens=col_name['words'], type=col_type)
                  for (col_name, col_type) in zip(entry['table']['header'],
                                                  tables[entry['table_id']]['types'])]
        table = WikiSqlTable(header=header)

        example = WikiSqlExample(idx=idx,
                                 question=question_tokens,
                                 table=table,
                                 tgt_actions=tgt_action_infos,
                                 tgt_code=query,
                                 tgt_ast=asdl_ast,
                                 meta=entry)

        examples.append(example)
        # print(query)

    return examples
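# Hedged invocation: the DBEngine path arithmetic above implies 'train.db'
# sits next to 'train.jsonl'; both paths here are hypothetical.
wikisql_examples = load_dataset(transition_system,
                                'data/wikisql/train.jsonl',
                                'data/wikisql/train.tables.jsonl')
print('loaded %d WikiSQL examples' % len(wikisql_examples))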
def load_regex_dataset(transition_system, split):
    prefix = 'data/streg/'
    src_file = join(prefix, "src-{}.txt".format(split))
    spec_file = join(prefix, "targ-{}.txt".format(split))
    map_file = join(prefix, "map-{}.txt".format(split))
    exs_file = join(prefix, "exs-{}.txt".format(split))
    rec_file = join(prefix, "rec-{}.pkl".format(split))

    exs_info = StReg.load_examples(exs_file)
    map_info = StReg.load_map_file(map_file)
    rec_info = StReg.load_rec(rec_file)

    examples = []
    for idx, (src_line, spec_line, str_exs, cmap, rec) in enumerate(
            zip(open(src_file), open(spec_file), exs_info, map_info, rec_info)):
        print(idx)
        src_line = src_line.rstrip()
        spec_line = spec_line.rstrip()
        src_toks = src_line.split()
        spec_toks = spec_line.split()
        spec_ast = streg_expr_to_ast(transition_system.grammar, spec_toks)

        # sanity check
        reconstructed_expr = transition_system.ast_to_surface_code(spec_ast)
        # print("Spec", spec_line)
        # print("Rcon", reconstructed_expr)
        assert spec_line == reconstructed_expr

        tgt_actions = transition_system.get_actions(spec_ast)

        # sanity check
        hyp = Hypothesis()
        for action in tgt_actions:
            assert action.__class__ in transition_system.get_valid_continuation_types(hyp)
            if isinstance(action, ApplyRuleAction):
                assert action.production in transition_system.get_valid_continuating_productions(hyp)
            hyp = hyp.clone_and_apply_action(action)
        assert hyp.frontier_node is None and hyp.frontier_field is None
        assert is_equal_ast(hyp.tree, spec_ast)

        expr_from_hyp = transition_system.ast_to_surface_code(hyp.tree)
        assert expr_from_hyp == spec_line

        tgt_action_infos = get_action_infos(src_toks, tgt_actions)
        example = Example(idx=idx,
                          src_sent=src_toks,
                          tgt_actions=tgt_action_infos,
                          tgt_code=spec_line,
                          tgt_ast=spec_ast,
                          meta={"str_exs": str_exs,
                                "const_map": cmap,
                                "worker_info": rec})
        examples.append(example)

    return examples