Ejemplo n.º 1
0
 def execute_query(cls, db_root, query, database_schemas, silent=False):
     """Annotate ``query`` with token lists and execute its recovered SQL.

     Side effects: when the query is not skipped, sets
     ``query['query_toks']`` and ``query['query_toks_no_value']`` in place.

     Args:
         db_root: directory containing ``<db_id>/<db_id>.sqlite`` files.
         query: dict read for ``'db_id'``, ``'recov'`` and ``'column_mapped'``.
         database_schemas: mapping from db id to the schema argument passed to
             ``editsql_postprocess.postprocess_one``.
         silent: forwarded to ``timed_execute``.

     Returns:
         ``None`` (without executing or mutating ``query``) when the recovered
         query contains ``'t5'``; otherwise the result of ``timed_execute``.
     """
     db_id = query['db_id']
     db_path = os.path.join(db_root, db_id, db_id + '.sqlite')
     query_recov = query['recov']
     if 't5' in query_recov:
         # Queries containing this marker are skipped and never executed.
         return None
     query['query_toks'] = query_recov.replace('.', ' . ').split()
     mapped = ' '.join(query['column_mapped'])
     postprocessed = editsql_postprocess.postprocess_one(mapped, database_schemas[db_id])
     # Mask literal values. 'limit 1' is rewritten before the generic ' 1' rule
     # so it becomes 'limit_value' rather than 'limit value'.
     query['query_toks_no_value'] = (
         postprocessed.replace('limit 1', 'limit_value')
         .replace(' 1', ' value')
         .replace('.', ' . ')
         .split(' ')
     )
     return timed_execute(db_path, query_recov, timeout=1, sleep=0.001, silent=silent)
Ejemplo n.º 2
0
    def compute_official_eval(self, dev, dev_preds):
        """Score predictions with the official exact-match and execution metrics.

        For each example, the predicted SQL string:

        * is re-spaced;
        * has its literal values re-cased from ``ex['g_values']``;
        * is parsed with ``self.build_sql``;
        * is compared against ``ex['g_sql']`` with ``self.evaluator``.

        When ``self.args.keep_values`` is set, both the gold and the predicted
        queries are also executed with ``self.execute`` and their results compared.

        Args:
            dev: sized iterable of example dicts (reads ``'id'``, ``'g_query'``,
                ``'db_id'``, ``'g_values'`` and ``'g_sql'``).
            dev_preds: mapping from example id to a dict with a ``'query'`` string.

        Returns:
            dict with ``official_em`` and ``official_ex`` averaged over ``dev``.
            Both stay 0 when ``dev`` is empty instead of dividing by zero.
        """
        # Undo tokenizer spacing artefacts so the SQL parser sees valid
        # operators and string literals. Loop-invariant, so built once.
        spacing = [
            ('` ` ', '"'),
            ("''", '"'),
            ('> =', '>='),
            ('< =', '<='),
            ("'% ", "'%"),
            (" %'", "%'"),
        ]
        metrics = dict(official_em=0, official_ex=0)
        for ex in tqdm.tqdm(dev, desc='official eval'):
            g_str = ex['g_query']
            db_name = ex['db_id']
            db = os.path.join('data', 'database', db_name, db_name + ".sqlite")
            p_str = dev_preds[ex['id']]['query']
            for f, t in spacing:
                p_str = p_str.replace(f, t)
            # Recover casing: replace the lower-cased form of each gold value
            # with its original spelling.
            for v in ex['g_values']:
                v = self.bert_tokenizer.convert_tokens_to_string(v).strip(
                    ' ' + string.punctuation)
                p_str = p_str.replace(v.lower(), v)
            schema = evaluation.Schema(evaluation.get_schema(db))

            p_sql = self.build_sql(schema, p_str, self.kmaps[db_name])

            # The official eval script modifies its arguments in place, so it
            # only ever receives copies.
            try:
                em = self.evaluator.eval_exact_match(
                    copy.deepcopy(p_sql), copy.deepcopy(ex['g_sql']))
            except Exception:
                # An evaluator failure is scored as a miss rather than aborting.
                em = False
            metrics['official_em'] += em

            if self.args.keep_values:
                g_ex = self.execute(db, g_str, ex['g_sql'])
                p_ex = self.execute(db, p_str, p_sql)
                # A prediction whose execution result is False scores 0.
                metrics['official_ex'] += 0 if p_ex is False else p_ex == g_ex

        n = len(dev)
        if n:
            metrics['official_em'] /= n
            metrics['official_ex'] /= n
        return metrics
Ejemplo n.º 3
0
    def build_sql(self, query, db_id):
        """Parse ``query`` for database ``db_id`` into the official evaluator's SQL dict.

        The database file is ``<self.db>/<db_id>/<db_id>.sqlite``. Column units
        are then rebuilt with the foreign-key map ``self.kmaps[db_id]``.

        Args:
            query: SQL query string.
            db_id: database identifier (directory and file stem).

        Returns:
            The structure produced by ``evaluation.rebuild_sql_col``. If
            ``query`` cannot be parsed, the parse error is printed and an empty
            query template is used in its place.
        """
        import copy

        db_path = os.path.join(self.db, db_id, db_id + ".sqlite")
        schema = evaluation.Schema(evaluation.get_schema(db_path))

        try:
            sql = evaluation.get_sql(schema, query)
        except Exception as e:
            # If the query is not valid, evaluate an empty SQL against the gold SQL.
            # Deep-copy the template: a shallow .copy() would share its nested
            # containers with every later rebuild step.
            sql = copy.deepcopy(evaluation.EMPTY_QUERY)
            print(e)
        valid_col_units = evaluation.build_valid_col_units(
            sql['from']['table_units'], schema)
        sql_val = evaluation.rebuild_sql_val(sql)
        sql_col = evaluation.rebuild_sql_col(valid_col_units, sql_val,
                                             self.kmaps[db_id])
        return sql_col
Ejemplo n.º 4
0
 def match(self, beam, ex):
     """Annotate ``beam`` with its decoded query and exact-match flag for ``ex``.

     Sets ``beam.query_toks``, ``beam.em``, ``beam.query`` and
     ``beam.g_query``, then returns the same ``beam`` object.

     ``beam.em`` starts as the official evaluator's verdict, which is False on
     any parse or evaluation error. It is then forced to True when
     ``beam.toks`` equals the whitespace-split ``ex['g_query_norm']``.
     """
     slots = beam.inds
     decoded = preprocess.SQLDataset.recover_slots(
         slots, ex['cands_query'], eos=self.sql_vocab.word2index('EOS'))
     db = ex['db_id']
     db_file = os.path.join('data', 'database', db, db + ".sqlite")
     gold = ex['g_query']
     pred = beam.post
     try:
         schema = evaluation.Schema(evaluation.get_schema(db_file))
         build = preprocess.SQLDataset.build_sql
         kmap = self.kmaps[db]
         pred_sql = build(schema, pred, kmap)
         gold_sql = build(schema, gold, kmap)
         # The official eval script mutates its arguments, so hand it copies.
         em = self.evaluator.eval_exact_match(
             copy.deepcopy(pred_sql), copy.deepcopy(gold_sql))
     except Exception:
         # Any parse or evaluation failure counts as "no exact match".
         em = False
     beam.query_toks = decoded
     beam.em = em
     beam.query = pred
     beam.g_query = gold
     # An exact token-level match with the normalized gold query overrides the evaluator.
     if beam.toks == ex['g_query_norm'].split():
         beam.em = True
     return beam
Ejemplo n.º 5
0
    def compute_official_eval(self, dev, dev_preds):
        """Score predictions with the official exact-match metric.

        For each example, the predicted SQL string is re-spaced, parsed with
        ``preprocess.SQLDataset.build_sql`` and compared against
        ``ex['g_sql']`` with ``self.evaluator``. This variant never executes
        queries, so ``official_ex`` is always reported as 0.

        Args:
            dev: sized iterable of example dicts (reads ``'id'``, ``'db_id'``
                and ``'g_sql'``).
            dev_preds: mapping from example id to a dict with a ``'query'`` string.

        Returns:
            dict with ``official_em`` averaged over ``dev`` and ``official_ex``.
            Both stay 0 when ``dev`` is empty instead of dividing by zero.
        """
        # Undo tokenizer spacing artefacts so the SQL parser sees valid
        # operators and string literals. Loop-invariant, so built once.
        spacing = [
            ('` ` ', '"'),
            ("''", '"'),
            ('> =', '>='),
            ('< =', '<='),
            ("'% ", "'%"),
            (" %'", "%'"),
        ]
        metrics = dict(official_em=0, official_ex=0)
        for ex in tqdm.tqdm(dev, desc='official eval'):
            db_name = ex['db_id']
            db = os.path.join('data', 'database', db_name, db_name + ".sqlite")
            p_str = dev_preds[ex['id']]['query']
            for f, t in spacing:
                p_str = p_str.replace(f, t)
            schema = evaluation.Schema(evaluation.get_schema(db))
            p_sql = preprocess.SQLDataset.build_sql(schema, p_str,
                                                    self.kmaps[db_name])

            # The official eval script modifies its arguments in place, so it
            # only ever receives copies.
            try:
                em = self.evaluator.eval_exact_match(
                    copy.deepcopy(p_sql), copy.deepcopy(ex['g_sql']))
            except Exception:
                # An evaluator failure is scored as a miss rather than aborting.
                em = False
            metrics['official_em'] += em

        n = len(dev)
        if n:
            metrics['official_em'] /= n
            metrics['official_ex'] /= n
        return metrics
Ejemplo n.º 6
0
def batch_execute(data,
                  silent=False,
                  sleep=0.1,
                  timeout=5,
                  n_proc=5,
                  desc='batch execute'):
    """Run many SQL queries concurrently through ``batch_execute_one``.

    Before execution, tokenizer-style operator spacing in each query is
    repaired: ``'< ='`` -> ``'<='``, ``'> ='`` -> ``'>='``, ``'! ='`` -> ``'!='``.

    Args:
        data: iterable of ``(db, query, sql)`` triples.
        silent: forwarded to ``batch_execute_one``.
        sleep: accepted for interface compatibility; not used here.
        timeout: accepted for interface compatibility; not used here.
        n_proc: number of worker threads for ``joblib.Parallel``.
        desc: progress-bar label.

    Returns:
        The list of ``batch_execute_one`` results, in the order of ``data``.
    """
    # NOTE(review): ``sleep`` and ``timeout`` are never forwarded to
    # ``batch_execute_one``; confirm whether that helper should receive them.
    spacing = [('< =', '<='), ('> =', '>='), ('! =', '!=')]
    proc = []
    for db, query, sql in data:
        # Apply the fixes cumulatively and pass on the repaired string, so that
        # every substitution actually reaches the executor.
        for f, t in spacing:
            query = query.replace(f, t)
        proc.append((query, sql, db, silent))

    par = joblib.Parallel(n_proc, backend='threading')
    return par(
        joblib.delayed(batch_execute_one)(*args)
        for args in tqdm.tqdm(proc, desc=desc))


if __name__ == '__main__':
    # Manual smoke test against the local data directory:
    # 1. load the foreign-key map from tables.json;
    # 2. build the flight_4 schema;
    # 3. run one sample query with a 2-second timeout and print the result.
    db_id = 'soccer_1'
    ftables = os.path.join('data', 'spider', 'tables.json')
    db_path = 'data/database/flight_4/flight_4.sqlite'
    query = "select T1.airline from routes as T1 join airports as T2 on T1.src_apid = T2.apid where T2.city != 'Imo'"

    kmaps = evaluation.build_foreign_key_map_from_json(ftables)
    schema = evaluation.Schema(evaluation.get_schema(db_path))

    result = timed_execute(db_path, query, timeout=2)
    print(result)