def process_ex(self, ex):
    query_tokens = None
    if 'query_toks' in ex:
        # we only have 'query_toks' in the example for training/dev sets

        # fix for examples: we want to use the 'query_toks_no_value' field of the example, which
        # anonymizes values. However, it also anonymizes numbers (e.g. LIMIT 3 -> LIMIT 'value'),
        # which is not good since the official evaluator expects a number and not a value
        ex = fix_number_value(ex)

        # we want the query tokens to be non-ambiguous (i.e. to know for each column the table it
        # belongs to, and for each table alias its explicit name), so we remove all aliases and
        # make changes such as:
        # 'name' -> 'singer@name'
        # 'singer AS T1' -> 'singer'
        # 'T1.name' -> 'singer@name'
        try:
            query_tokens = disambiguate_items(ex['db_id'], ex['query_toks_no_value'],
                                              self._tables_file, allow_aliases=False)
        except Exception as e:
            # two examples in the train set are wrongly formatted - skip them
            print(f"error with {ex['query']}")
            print(e)

    ins = self.text_to_instance(
        utterance=ex['question'],
        db_id=ex['db_id'],
        sql=query_tokens)
    if ins is not None:
        return ins
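# To make the rewrite above concrete: a minimal, self-contained sketch of the
# alias-removal transformation the comments describe. This is NOT the real
# disambiguate_items, which resolves every column against the schemas in
# tables.json; the alias/column maps below are hard-coded assumptions purely
# for illustration.
def strip_aliases_sketch(tokens, alias_to_table, column_to_table):
    out = []
    skip = 0
    for tok in tokens:
        if skip:
            skip -= 1
            continue
        low = tok.lower()
        if low == 'as':                    # 'singer AS T1' -> 'singer': drop 'AS T1'
            skip = 1
            continue
        if '.' in tok:                     # 'T1.name' -> 'singer@name'
            alias, col = tok.split('.', 1)
            out.append(f"{alias_to_table[alias.lower()]}@{col.lower()}")
        elif low in column_to_table:       # bare 'name' -> 'singer@name'
            out.append(f"{column_to_table[low]}@{low}")
        else:
            out.append(tok)
    return out

print(strip_aliases_sketch(['SELECT', 'T1.name', 'FROM', 'singer', 'AS', 'T1'],
                           alias_to_table={'t1': 'singer'},
                           column_to_table={'name': 'singer'}))
# -> ['SELECT', 'singer@name', 'FROM', 'singer']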
def process_ex(self, total_cnt, ex, cache_filepath):
    utterances = []
    sql = []
    for step in ex['interaction']:
        utterances.append(step['utterance'])
        # anonymize quoted string literals before tokenizing, mirroring Spider's
        # 'query_toks_no_value' field
        step['query_toks_no_value'] = sql_tokenize(
            re.sub(r"\'([^\']*)\'|\"([^\"]*)\"", r'value', step['query'].lower()))
        step['query_toks'] = sql_tokenize(step['query'].lower())

        # some tokens may still contain spaces - split them into individual tokens
        query_tokens = []
        for tok in step['query_toks_no_value']:
            query_tokens += tok.split(' ')
        step['query_toks_no_value'] = query_tokens

        # restore numeric values in 'query_toks_no_value' (mutates step in place;
        # ignore the steps it cannot handle)
        try:
            fix_number_value(step)
        except Exception:
            pass
        step['query_toks_no_value'] = disambiguate_items(
            ex['database_id'], step['query_toks_no_value'],
            self._tables_file, allow_aliases=False)
        sql.append(step['query_toks_no_value'])

    ins = self.text_to_instance(
        utterances=utterances,
        db_id=ex['database_id'],
        sql=sql)
    if self._save_cache:
        dill.dump(ins, open(cache_filepath, 'wb'))
    return ins
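# Quick standalone check of the exact anonymization regex used above; the
# sample query is made up for illustration.
import re

query = "SELECT name FROM singer WHERE country = 'France' OR nickname = \"Frank\""
print(re.sub(r"\'([^\']*)\'|\"([^\"]*)\"", r'value', query.lower()))
# -> select name from singer where country = value or nickname = value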
def _read(self, file_path: str):
    if not file_path.endswith('.json'):
        raise ConfigurationError(f"Don't know how to read filetype of {file_path}")

    cache_dir = os.path.join('../model/cache', file_path.split("/")[-1])
    if self._load_cache:
        logger.info(f'Trying to load cache from {cache_dir}')
    if self._save_cache:
        os.makedirs(cache_dir, exist_ok=True)

    cnt = 0
    with open(file_path, "r") as data_file:
        json_obj = json.load(data_file)
        for total_cnt, ex in enumerate(json_obj):
            cache_filename = f'instance-{total_cnt}.pt'
            cache_filepath = os.path.join(cache_dir, cache_filename)
            if self._loading_limit == cnt:
                break

            if self._load_cache:
                try:
                    ins = dill.load(open(cache_filepath, 'rb'))
                    if ins is None and not self._keep_if_unparsable:
                        # skip unparsed examples
                        continue
                    yield ins
                    cnt += 1
                    continue  # served from cache - don't re-process (and re-yield) the example
                except Exception as e:
                    # could not load from cache - keep loading without cache
                    pass

            query_tokens = None
            if 'query_toks' in ex:
                # we only have 'query_toks' in the example for training/dev sets

                # fix for examples: we want to use the 'query_toks_no_value' field of the example,
                # which anonymizes values. However, it also anonymizes numbers
                # (e.g. LIMIT 3 -> LIMIT 'value'), which is not good since the official evaluator
                # expects a number and not a value
                ex = fix_number_value(ex)

                # we want the query tokens to be non-ambiguous (i.e. to know for each column the
                # table it belongs to, and for each table alias its explicit name), so we remove
                # all aliases and make changes such as:
                # 'name' -> 'singer@name'
                # 'singer AS T1' -> 'singer'
                # 'T1.name' -> 'singer@name'
                try:
                    query_tokens = disambiguate_items(ex['db_id'], ex['query_toks_no_value'],
                                                      self._tables_file, allow_aliases=False)
                except Exception as e:
                    # two examples in the train set are wrongly formatted - skip them
                    print(f"error with {ex['query']}")
                    print(e)

            # for train_spider.json, keep the gold SQL for every 10th example and treat
            # the rest as unsupervised (their gold SQL is dropped)
            unsup = False
            if file_path.split("/")[-1] == "train_spider.json":
                if total_cnt % 10 == 0:
                    self._keep_if_unparsable = self._keep_if_unparsable_original
                else:
                    self._keep_if_unparsable = True
                    query_tokens = None
                    unsup = True

            ins = self.text_to_instance(
                utterance=ex['question'],
                db_id=ex['db_id'],
                sql=query_tokens,
                unsup=unsup)
            if ins is not None:
                cnt += 1
            # cache even None instances, so unparsable examples can be skipped
            # cheaply on the next run
            if self._save_cache:
                dill.dump(ins, open(cache_filepath, 'wb'))
            if ins is not None:
                yield ins
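# Round-trip sketch of the cache protocol _read relies on: one dill pickle per
# example under instance-<i>.pt, with None cached for examples that failed to
# parse so they can be skipped cheaply on the next run. Strings stand in for
# AllenNLP Instance objects here; the directory layout mirrors _read.
import os
import dill

cache_dir = '../model/cache/train_spider.json'
os.makedirs(cache_dir, exist_ok=True)

for i, ins in enumerate(['instance-a', None, 'instance-c']):
    dill.dump(ins, open(os.path.join(cache_dir, f'instance-{i}.pt'), 'wb'))

keep_if_unparsable = False
for i in range(3):
    ins = dill.load(open(os.path.join(cache_dir, f'instance-{i}.pt'), 'rb'))
    if ins is None and not keep_if_unparsable:
        continue  # cached as unparsable - skip without re-processing
    print(ins)
# -> instance-a
# -> instance-c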
import json

from dataset_readers.dataset_util.spider_utils import fix_number_value, disambiguate_items

data_path = 'dataset/train_spider.json'
tables_file = 'dataset/tables.json'

with open(data_path, "r") as data_file:
    json_obj = json.load(data_file)
    for total_cnt, ex in enumerate(json_obj):
        if 'query_toks' in ex:
            ex = fix_number_value(ex)
            try:
                query_tokens = disambiguate_items(ex['db_id'], ex['query_toks_no_value'],
                                                  tables_file, allow_aliases=False)
            except Exception as e:
                # there are two examples in the train set that are wrongly formatted, skip them
                print(f"error with {ex['query']}")
                print(e)
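# Variation on the sanity-check script above: tally the failures instead of
# interleaving prints. Same assumptions about the dataset layout as above.
import json

from dataset_readers.dataset_util.spider_utils import fix_number_value, disambiguate_items

with open('dataset/train_spider.json') as f:
    examples = json.load(f)

failures = []
for ex in examples:
    if 'query_toks' not in ex:
        continue
    ex = fix_number_value(ex)
    try:
        disambiguate_items(ex['db_id'], ex['query_toks_no_value'],
                           'dataset/tables.json', allow_aliases=False)
    except Exception:
        failures.append(ex['query'])

print(f"{len(failures)} example(s) failed to disambiguate")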