Ejemplo n.º 1
0
    def process_ex(self, ex):
        """Convert a single Spider example dict into an instance.

        When the example carries gold SQL (signalled by a ``query_toks`` key,
        which only training/dev examples have), its value-anonymized tokens in
        ``query_toks_no_value`` are first repaired by ``fix_number_value``. The
        anonymization also replaces numbers such as the ``3`` in ``LIMIT 3``,
        and the official evaluator expects a number there, not a value.

        The tokens are then made unambiguous by ``disambiguate_items`` with
        aliases removed. Examples of the rewrite:
        ``'name' -> 'singer@name'``, ``'singer AS T1' -> 'singer'``,
        ``'T1.name' -> 'singer@name'``.

        If disambiguation raises, the query and the error are printed and the
        instance is built without SQL.

        Returns:
            Whatever ``self.text_to_instance`` returns for the question, the
            database id and the (possibly ``None``) SQL tokens.
        """
        gold_sql = None

        if 'query_toks' in ex:
            ex = fix_number_value(ex)
            try:
                gold_sql = disambiguate_items(
                    ex['db_id'],
                    ex['query_toks_no_value'],
                    self._tables_file,
                    allow_aliases=False,
                )
            except Exception as err:
                # two train-set examples are wrongly formatted: report them and
                # fall back to building the instance without SQL
                print(f"error with {ex['query']}")
                print(err)

        return self.text_to_instance(
            utterance=ex['question'], db_id=ex['db_id'], sql=gold_sql)
Ejemplo n.º 2
0
    def process_ex(self, total_cnt, ex, cache_filepath):
        """Build one instance from a multi-turn interaction example.

        Each step's SQL is re-tokenized and disambiguated. For every step in
        ``ex['interaction']``, the lowercased ``step['query']`` is processed
        as follows:

        * ``query_toks``: ``sql_tokenize`` output of the query.
        * ``query_toks_no_value``: ``sql_tokenize`` output of the query after
          every single- or double-quoted literal is replaced by ``value``.
          Tokens are further split on spaces. Then ``fix_number_value`` is
          applied (best effort) and the tokens are rewritten by
          ``disambiguate_items`` without aliases.

        NOTE: the steps in ``ex`` are modified in place.

        Args:
            total_cnt: not used by this method.
            ex: interaction example with ``'interaction'`` (list of steps that
                have ``'utterance'`` and ``'query'``) and ``'database_id'``.
            cache_filepath: where the instance is dill-pickled when
                ``self._save_cache`` is set.

        Returns:
            Whatever ``self.text_to_instance`` returns for the utterances,
            database id and per-step SQL token lists.
        """
        utterances = []
        sql = []

        for step in ex['interaction']:
            utterances.append(step['utterance'])

            query = step['query'].lower()
            # Anonymize quoted literals so only the SQL structure remains.
            anonymized = re.sub(r"\'([^\']*)\'|\"([^\"]*)\"", r'value', query)
            # sql_tokenize tokens may contain spaces; split them into
            # separate tokens.
            step['query_toks_no_value'] = [
                piece
                for tok in sql_tokenize(anonymized)
                for piece in tok.split(' ')
            ]
            step['query_toks'] = sql_tokenize(query)

            try:
                # Best effort; the return value is ignored, so this presumably
                # relies on fix_number_value updating `step` in place.
                # TODO confirm.
                fix_number_value(step)
            except Exception:
                pass

            step['query_toks_no_value'] = disambiguate_items(
                ex['database_id'],
                step['query_toks_no_value'],
                self._tables_file,
                allow_aliases=False)

            sql.append(step['query_toks_no_value'])

        ins = self.text_to_instance(utterances=utterances,
                                    db_id=ex['database_id'],
                                    sql=sql)

        if self._save_cache:
            # Context manager guarantees the cache file is flushed and closed,
            # even if pickling fails part-way.
            with open(cache_filepath, 'wb') as cache_file:
                dill.dump(ins, cache_file)

        return ins
Ejemplo n.º 3
0
    def _read(self, file_path: str):
        """Yield instances for the Spider examples stored in the JSON file ``file_path``.

        Caching: when ``self._load_cache`` is set, each example is first looked
        up as a dill-pickled instance in ``../model/cache/<file name>/``. A
        non-``None`` cached instance is yielded as-is.

        Building from scratch: otherwise the instance is built from the
        example. Gold SQL tokens (only for examples with ``query_toks``) are
        repaired with ``fix_number_value`` and de-aliased with
        ``disambiguate_items`` before calling ``self.text_to_instance``. The
        fresh instance is pickled to the cache when ``self._save_cache`` is
        set. ``None`` instances are never yielded.

        Training file: for ``train_spider.json``, only examples whose index is
        a multiple of 10 keep their gold SQL. All others are built with
        ``sql=None`` and ``unsup=True``. While they are processed,
        ``self._keep_if_unparsable`` is forced to ``True``. The original value
        (``self._keep_if_unparsable_original``) is restored when reading ends.

        Limit: reading stops once the number of yielded instances equals
        ``self._loading_limit``. A value that never equals that count (e.g.
        ``-1`` or ``None``) presumably means "no limit" — TODO confirm.

        Raises:
            ConfigurationError: if ``file_path`` does not end with ``.json``.
        """
        if not file_path.endswith('.json'):
            raise ConfigurationError(
                f"Don't know how to read filetype of {file_path}")

        file_name = os.path.basename(file_path)
        cache_dir = os.path.join('../model/cache', file_name)
        is_train = file_name == "train_spider.json"

        if self._load_cache:
            logger.info('Trying to load cache from %s', cache_dir)
        if self._save_cache:
            os.makedirs(cache_dir, exist_ok=True)

        # JSON is UTF-8; don't depend on the platform's default encoding.
        # The file is parsed in one go, so it can be closed before iterating.
        with open(file_path, "r", encoding="utf-8") as data_file:
            json_obj = json.load(data_file)

        cnt = 0  # number of instances yielded so far
        try:
            for total_cnt, ex in enumerate(json_obj):
                if self._loading_limit == cnt:
                    break
                cache_filepath = os.path.join(cache_dir, f'instance-{total_cnt}.pt')

                unsup = is_train and total_cnt % 10 != 0
                if is_train:
                    # Set before the cache lookup, so the skip check below
                    # uses this example's setting, not the previous one's.
                    self._keep_if_unparsable = (
                        True if unsup else self._keep_if_unparsable_original)

                if self._load_cache:
                    try:
                        with open(cache_filepath, 'rb') as cache_file:
                            cached = dill.load(cache_file)
                    except Exception:
                        # could not load from cache - keep loading without cache
                        pass
                    else:
                        if cached is not None:
                            yield cached
                            cnt += 1
                            # This example is done. Falling through would rebuild
                            # it and yield it a second time.
                            continue
                        if not self._keep_if_unparsable:
                            # skip unparsed examples
                            continue
                        # A cached None must not be yielded (the fresh path below
                        # never yields None either); rebuild it instead.

                query_tokens = None
                # Only training/dev examples have 'query_toks'. The SQL of
                # unsupervised examples is discarded anyway, so skip it.
                if 'query_toks' in ex and not unsup:
                    # query_toks_no_value anonymizes numbers too
                    # (e.g. LIMIT 3 -> LIMIT 'value'), which the official
                    # evaluator rejects; fix_number_value restores them.
                    ex = fix_number_value(ex)

                    # Remove aliases so every column/table is explicit, e.g.
                    # 'name' -> 'singer@name', 'singer AS T1' -> 'singer',
                    # 'T1.name' -> 'singer@name'.
                    try:
                        query_tokens = disambiguate_items(
                            ex['db_id'],
                            ex['query_toks_no_value'],
                            self._tables_file,
                            allow_aliases=False)
                    except Exception as e:
                        # there are two examples in the train set that are wrongly formatted, skip them
                        print(f"error with {ex['query']}")
                        print(e)

                ins = self.text_to_instance(utterance=ex['question'],
                                            db_id=ex['db_id'],
                                            sql=query_tokens,
                                            unsup=unsup)
                if self._save_cache:
                    with open(cache_filepath, 'wb') as cache_file:
                        dill.dump(ins, cache_file)

                if ins is not None:
                    yield ins
                    cnt += 1
        finally:
            if is_train:
                # Don't leak the per-example override into later reads or
                # direct text_to_instance calls.
                self._keep_if_unparsable = self._keep_if_unparsable_original
import json
from dataset_readers.dataset_util.spider_utils import fix_number_value, disambiguate_items

# Diagnostic script: run the same preprocessing the dataset reader applies
# (fix_number_value + disambiguate_items without aliases) over every training
# example. Any example whose query cannot be disambiguated is printed, along
# with the error.
data_path = 'dataset/train_spider.json'
tables_file = 'dataset/tables.json'

# JSON is UTF-8; don't depend on the platform's default encoding.
# Load everything first so the file is not held open during processing.
with open(data_path, "r", encoding="utf-8") as data_file:
    json_obj = json.load(data_file)

for ex in json_obj:
    # only examples with gold SQL ('query_toks') can be checked
    if 'query_toks' in ex:
        ex = fix_number_value(ex)

        try:
            query_tokens = disambiguate_items(ex['db_id'],
                                              ex['query_toks_no_value'],
                                              tables_file,
                                              allow_aliases=False)
        except Exception as e:
            # there are two examples in the train set that are wrongly formatted, skip them
            print(f"error with {ex['query']}")
            print(e)