Example #1
    def __init__(self, paths: List[str], tables_paths: List[str],
                 db_path: str):
        self.paths = paths
        self.db_path = db_path
        self.examples = []

        self.schemas, self.eval_foreign_key_maps = load_tables(tables_paths)
        original_schemas = load_original_schemas(tables_paths)

        for path in paths:
            with open(path) as f:
                raw_data = json.load(f)
            for entry in raw_data:
                if "sql" not in entry:
                    entry["sql"] = get_sql(original_schemas[entry["db_id"]],
                                           entry["query"])
                item = SpiderItem(
                    question=entry["question"],
                    slml_question=entry.get("slml_question", None),
                    query=entry["query"],
                    spider_sql=entry["sql"],
                    spider_schema=self.schemas[entry["db_id"]],
                    db_path=self.get_db_path(entry["db_id"]),
                    orig=entry,
                )
                self.examples.append(item)
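
A minimal usage sketch for this dataset class. The class name SpiderDataset and the file layout are assumptions, since the excerpt starts at __init__:

# Hypothetical usage; `SpiderDataset` is an assumed name for the class above.
dataset = SpiderDataset(
    paths=["data/spider/train_spider.json", "data/spider/train_others.json"],
    tables_paths=["data/spider/tables.json"],
    db_path="data/spider/database",
)
print(len(dataset.examples))          # number of SpiderItem entries
print(dataset.examples[0].question)   # the natural-language question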
Example #2
def go_spider_sql(db_id: str, sql: str) -> Optional[dict]:
    tables_fpath = "data/database/{}/tables.json".format(db_id)

    with open(tables_fpath, "r") as f:
        tables_data = json.load(f)

    schemas, db_names, tables = _get_schemas_from_json(tables_data)
    schema = schemas[db_id]
    table = tables[db_id]
    schema = Schema(schema, table)

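    # Rewrite COUNT(DISTINCT "table") to count a concrete column instead;
    # note this assumes every such table has an `id` column.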
    sql = re.sub(r"COUNT\(DISTINCT \"(.*)\"\)", r'COUNT(DISTINCT "\1".id)',
                 sql)
    sql = (
        sql.replace('"', "")
        # .replace("(", " ")
        # .replace(")", " ")
        .replace("ILIKE", "="))
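    # Unwrap AND-joined parenthesized WHERE groups (five down to one, most
    # groups first), apparently because the downstream parser does not
    # accept parentheses here.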
    sql = re.sub(
        r"WHERE \((.*)\) AND \((.*)\) AND \((.*)\) AND \((.*)\) AND \((.*)\)",
        r'WHERE \1 AND \2 AND \3 AND \4 AND \5', sql)
    sql = re.sub(r"WHERE \((.*)\) AND \((.*)\) AND \((.*)\) AND \((.*)\)",
                 r'WHERE \1 AND \2 AND \3 AND \4', sql)
    sql = re.sub(r"WHERE \((.*)\) AND \((.*)\) AND \((.*)\)",
                 r'WHERE \1 AND \2 AND \3', sql)
    sql = re.sub(r"WHERE \((.*)\) AND \((.*)\)", r'WHERE \1 AND \2', sql)
    sql = re.sub(r"WHERE \((.*)\)", r'WHERE \1', sql)
    # print(sql)
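    # Collect "term AS alias" pairs so each alias can later be expanded back
    # to its defining expression, which the downstream parser understands.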
    pattern: Pattern[str] = re.compile(
        r'(?P<term>(\w|\.|\"|\(|\)|(COUNT\s*)|(\s*DISTINCT\s*))+) AS (?P<alias>(\w|\.)+)'
    )
    aliases: List[dict] = []
    for match in pattern.finditer(sql):
        if match.group('term') is not None and match.group(
                'alias') is not None:
            aliases.append({
                "alias": match.group('alias'),
                "term": match.group('term')
            })
    # print(aliases)
    sql = re.sub(r" AS (\w|\.)+", r"", sql)
    for d in aliases:
        # replace the term with a unique placeholder first, so that no part
        # of it is re-substituted when the alias is expanded below
        unique_string = str(uuid.uuid4())
        while unique_string in sql:
            unique_string = str(uuid.uuid4())
        sql = sql.replace(d["term"], unique_string)
        sql = sql.replace(" " + d["alias"], " " + d["term"])
        sql = sql.replace(unique_string, d["term"])

    try:
        spider_sql = get_sql(schema, sql)
    except Exception as e:
        print(sql, e)
        spider_sql = None

    return spider_sql
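
A hedged usage sketch; the db_id and query are illustrative, and the function assumes a per-database tables.json under data/database/:

# Hypothetical call; assumes data/database/concert_singer/tables.json exists.
parsed = go_spider_sql("concert_singer",
                       "SELECT name FROM singer WHERE age > 20")
if parsed is not None:
    print(parsed["where"])  # Spider's structured form of the WHERE clause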
Example #3
def generate_every_db(db, schemas, tables, patterns_data):
    db_name = db["db_id"]
    col_types = db["column_types"]
    process_sql_schema = schema_mod.Schema(schemas[db_name], tables[db_name])

    if "number" in col_types:
        try:
            schema = Schema(db)
        except Exception:
            traceback.print_exc()
            print("skip db {}".format(db_name))
            return

        patterns = load_patterns(patterns_data, schema)
        questions_and_queries = []
        while len(questions_and_queries) < 10:
            pattern = random.choice(patterns)
            try:
                sql, questions = pattern.populate()
                #for q in questions:
                if len(questions) != 0:
                    question = random.choice(questions)
                    questions_and_queries.append((question, sql))
            except Exception:
                # this pattern failed to populate for this schema; try another
                pass

        return [{
            "db_id": db_name,
            "query": query,
            "query_toks": process_sql.tokenize(query),
            "query_toks_no_value": None,
            "question": question,
            "question_toks": nltk.word_tokenize(question),
            "sql": process_sql.get_sql(process_sql_schema, query),
        } for question, query in questions_and_queries]
    else:
        return []
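
A hedged sketch of a driver that applies this generator to every database in a tables.json file; the file paths and the patterns file are assumptions, and get_schemas_from_json is taken from the next example:

import json

# Hypothetical driver; paths and the patterns-file format are assumptions.
with open("data/spider/tables.json") as f:
    dbs = json.load(f)
with open("patterns.json") as f:
    patterns_data = json.load(f)

schemas, db_names, tables = get_schemas_from_json("data/spider/tables.json")

generated = []
for db in dbs:
    examples = generate_every_db(db, schemas, tables, patterns_data)
    if examples:  # may return None (bad schema) or [] (no numeric columns)
        generated.extend(examples)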
Example #4
    # This fragment is truncated at the top: `args` and `sql_path` come from
    # earlier in the original script (see the argparse sketch after the block).
    output_file = args.output
    table_file = args.tables

    schemas, db_names, tables = get_schemas_from_json(table_file)

    with open(sql_path) as inf:
        sql_data = json.load(inf)

    sql_data_new = []
    for data in tqdm.tqdm(sql_data):
        try:
            db_id = data["db_id"]
            schema = schemas[db_id]
            table = tables[db_id]
            schema = Schema(schema, table)
            sql = data["query"]
            sql_label = get_sql(schema, sql)
            data["sql"] = sql_label
            sql_data_new.append(data)
        except Exception:
            print("db_id: ", db_id)
            print("sql: ", sql)
            raise

    with open(output_file, 'wt') as out:
        json.dump(sql_data_new,
                  out,
                  sort_keys=True,
                  indent=4,
                  separators=(',', ': '))
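
The fragment above presupposes an `args` object and a `sql_path` variable; a minimal argparse setup consistent with the attribute names it reads might look like this (a sketch, not the original script's CLI):

import argparse

# Hypothetical CLI wrapper; the flag names are assumptions inferred from the
# attributes the fragment reads (plus an input path for `sql_path`).
parser = argparse.ArgumentParser()
parser.add_argument("--input", required=True)
parser.add_argument("--tables", required=True)
parser.add_argument("--output", required=True)
args = parser.parse_args()
sql_path = args.input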
Example #5
    def evaluate_one(self, db_name, gold, predicted):
        schema = self.schemas[db_name]
        g_sql = get_sql(schema, gold)
        hardness = self.eval_hardness(g_sql)
        self.scores[hardness]["count"] += 1
        self.scores["all"]["count"] += 1

        parse_error = False
        try:
            p_sql = get_sql(schema, predicted)
        except Exception:
            # If the predicted SQL cannot be parsed, fall back to an empty
            # SQL structure so it can still be scored against the gold SQL.
            p_sql = {
                "except": None,
                "from": {
                    "conds": [],
                    "table_units": []
                },
                "groupBy": [],
                "having": [],
                "intersect": None,
                "limit": None,
                "orderBy": [],
                "select": [False, []],
                "union": None,
                "where": [],
            }

            # TODO fix
            parse_error = True

        # rebuild sql for value evaluation
        kmap = self.kmaps[db_name]
        g_valid_col_units = build_valid_col_units(g_sql["from"]["table_units"],
                                                  schema)
        g_sql = rebuild_sql_val(g_sql)
        g_sql = rebuild_sql_col(g_valid_col_units, g_sql, kmap)
        p_valid_col_units = build_valid_col_units(p_sql["from"]["table_units"],
                                                  schema)
        p_sql = rebuild_sql_val(p_sql)
        p_sql = rebuild_sql_col(p_valid_col_units, p_sql, kmap)

        if self.etype in ["all", "exec"]:
            self.scores[hardness]["exec"] += eval_exec_match(
                self.db_paths[db_name], predicted, gold, p_sql, g_sql)

        if self.etype in ["all", "match"]:
            partial_scores = self.eval_partial_match(p_sql, g_sql)
            exact_score = self.eval_exact_match(p_sql, g_sql, partial_scores)
            self.scores[hardness]["exact"] += exact_score
            self.scores["all"]["exact"] += exact_score
            for type_ in PARTIAL_TYPES:
                if partial_scores[type_]["pred_total"] > 0:
                    self.scores[hardness]["partial"][type_][
                        "acc"] += partial_scores[type_]["acc"]
                    self.scores[hardness]["partial"][type_]["acc_count"] += 1
                if partial_scores[type_]["label_total"] > 0:
                    self.scores[hardness]["partial"][type_][
                        "rec"] += partial_scores[type_]["rec"]
                    self.scores[hardness]["partial"][type_]["rec_count"] += 1
                self.scores[hardness]["partial"][type_][
                    "f1"] += partial_scores[type_]["f1"]
                if partial_scores[type_]["pred_total"] > 0:
                    self.scores["all"]["partial"][type_][
                        "acc"] += partial_scores[type_]["acc"]
                    self.scores["all"]["partial"][type_]["acc_count"] += 1
                if partial_scores[type_]["label_total"] > 0:
                    self.scores["all"]["partial"][type_][
                        "rec"] += partial_scores[type_]["rec"]
                    self.scores["all"]["partial"][type_]["rec_count"] += 1
                self.scores["all"]["partial"][type_]["f1"] += partial_scores[
                    type_]["f1"]
        else:
            # etype == "exec": match metrics were not computed
            exact_score = None
            partial_scores = None

        return {
            "predicted": predicted,
            "gold": gold,
            "predicted_parse_error": parse_error,
            "hardness": hardness,
            "exact": exact_score,
            "partial": partial_scores,
        }
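
A hedged sketch of driving evaluate_one over a prediction file; the evaluator construction and the one-SQL-tab-db_id-per-line gold format are assumptions beyond the snippet:

# Hypothetical driver; `evaluator` is an instance of the class above,
# constructed with etype="all" so both metric families are filled in.
results = []
with open("gold.sql") as gf, open("pred.sql") as pf:
    for gold_line, pred_line in zip(gf, pf):
        gold_sql, db_name = gold_line.strip().rsplit("\t", 1)
        results.append(
            evaluator.evaluate_one(db_name, gold_sql, pred_line.strip()))

exact = sum(r["exact"] for r in results) / len(results)
print("exact match: {:.3f}".format(exact))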
Example #6
    def evaluate_one(self, db_name, gold, predicted):
        schema = self.schemas[db_name]
        g_sql = get_sql(schema, gold)
        hardness = self.eval_hardness(g_sql)
        self.scores[hardness]['count'] += 1
        self.scores['all']['count'] += 1

        parse_error = False
        try:
            p_sql = get_sql(schema, predicted)
        except Exception:
            # If the predicted SQL cannot be parsed, fall back to an empty
            # SQL structure so it can still be scored against the gold SQL.
            p_sql = {
                "except": None,
                "from": {
                    "conds": [],
                    "table_units": []
                },
                "groupBy": [],
                "having": [],
                "intersect": None,
                "limit": None,
                "orderBy": [],
                "select": [False, []],
                "union": None,
                "where": []
            }

            # TODO fix
            parse_error = True

        # rebuild sql for value evaluation
        kmap = self.kmaps[db_name]
        g_valid_col_units = build_valid_col_units(g_sql['from']['table_units'],
                                                  schema)
        g_sql_with_val = copy.deepcopy(
            rebuild_sql_col(g_valid_col_units, g_sql, kmap))
        g_sql = rebuild_sql_val(g_sql)
        g_sql = rebuild_sql_col(g_valid_col_units, g_sql, kmap)

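        # Rebuilding the predicted SQL can itself fail on a bad parse; fall
        # back to the empty structure and rebuild again in that case.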
        try:
            p_valid_col_units = build_valid_col_units(
                p_sql['from']['table_units'], schema)
            p_sql_with_val = copy.deepcopy(
                rebuild_sql_col(g_valid_col_units, p_sql, kmap))
        except Exception:
            p_sql = {
                "except": None,
                "from": {
                    "conds": [],
                    "table_units": []
                },
                "groupBy": [],
                "having": [],
                "intersect": None,
                "limit": None,
                "orderBy": [],
                "select": [False, []],
                "union": None,
                "where": []
            }

            # TODO fix
            parse_error = True
            p_valid_col_units = build_valid_col_units(
                p_sql['from']['table_units'], schema)
            p_sql_with_val = copy.deepcopy(
                rebuild_sql_col(g_valid_col_units, p_sql, kmap))
        p_sql = rebuild_sql_val(p_sql)
        p_sql = rebuild_sql_col(p_valid_col_units, p_sql, kmap)
        # initialize so the return dict below is valid even when etype == 'match'
        exec_score = 0
        is_error = False
        is_empty = False
        if self.etype in ["all", "exec"]:
            is_error = False
            is_empty = False
            try:
                if db_name not in self.cached_db:
                    with sqlite3.connect(self.db_paths[db_name]) as source:
                        dest = sqlite3.connect(':memory:')
                        source.backup(dest)
                        conn = dest
                        conn.text_factory = lambda x: str(x, 'latin1')
                        self.cached_db[db_name] = conn.cursor()
                exec_score, is_empty, is_error = eval_exec_match(
                    self.cached_db[db_name], predicted, gold, p_sql, g_sql)
            except timeout_decorator.timeout_decorator.TimeoutError:
                exec_score = False
                is_error = True
            except (sqlite3.Warning, sqlite3.Error, sqlite3.DatabaseError,
                    sqlite3.IntegrityError, sqlite3.ProgrammingError,
                    sqlite3.OperationalError, sqlite3.NotSupportedError) as e:
                exec_score = False
                is_error = True
            self.scores[hardness]['exec'] += exec_score
            self.scores['all']['exec'] += exec_score
            if not is_empty:
                self.scores[hardness]['exec (non empty)'].append(exec_score)
                self.scores['all']['exec (non empty)'].append(exec_score)
            self.scores[hardness]['is_empty'] += is_empty
            self.scores['all']['is_empty'] += is_empty
            self.scores[hardness]['is_error'] += is_error
            self.scores['all']['is_error'] += is_error

        if self.etype in ["all", "match"]:
            partial_scores = self.eval_partial_match(p_sql, g_sql,
                                                     p_sql_with_val,
                                                     g_sql_with_val)
            exact_score = self.eval_exact_match(p_sql, g_sql, {
                k: v
                for k, v in partial_scores.items() if 'with value' not in k
            })
            exact_score_with_val = self.eval_exact_match(
                p_sql, g_sql, partial_scores)
            self.scores[hardness]['exact'] += exact_score
            self.scores['all']['exact'] += exact_score
            self.scores[hardness]['exact (with val)'] += exact_score_with_val
            self.scores['all']['exact (with val)'] += exact_score_with_val
            for type_ in PARTIAL_TYPES:
                if partial_scores[type_]['pred_total'] > 0:
                    self.scores[hardness]['partial'][type_][
                        'acc'] += partial_scores[type_]['acc']
                    self.scores[hardness]['partial'][type_]['acc_count'] += 1
                if partial_scores[type_]['label_total'] > 0:
                    self.scores[hardness]['partial'][type_][
                        'rec'] += partial_scores[type_]['rec']
                    self.scores[hardness]['partial'][type_]['rec_count'] += 1
                self.scores[hardness]['partial'][type_][
                    'f1'] += partial_scores[type_]['f1']
                if partial_scores[type_]['pred_total'] > 0:
                    self.scores['all']['partial'][type_][
                        'acc'] += partial_scores[type_]['acc']
                    self.scores['all']['partial'][type_]['acc_count'] += 1
                if partial_scores[type_]['label_total'] > 0:
                    self.scores['all']['partial'][type_][
                        'rec'] += partial_scores[type_]['rec']
                    self.scores['all']['partial'][type_]['rec_count'] += 1
                self.scores['all']['partial'][type_]['f1'] += partial_scores[
                    type_]['f1']
        else:
            # etype == 'exec': match metrics were not computed
            exact_score = None
            exact_score_with_val = None
            partial_scores = None

        return {
            'predicted': predicted,
            'gold': gold,
            'predicted_parse_error': parse_error,
            'hardness': hardness,
            'exact': exact_score,
            'exact (with val)': exact_score_with_val,
            'partial': partial_scores,
            'exec': exec_score,
            'is_empty': is_empty,
            'is_error': is_error
        }
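
Example #6 copies each SQLite database into memory once and caches a cursor, so repeated executions skip disk I/O. The pattern in isolation, using only standard-library calls:

import sqlite3

def cached_cursor(db_path, cache, key):
    # Copy the on-disk database into an in-memory one the first time this
    # key is requested, then reuse the cached cursor for later queries.
    if key not in cache:
        with sqlite3.connect(db_path) as source:
            dest = sqlite3.connect(":memory:")
            source.backup(dest)
            dest.text_factory = lambda x: str(x, "latin1")
            cache[key] = dest.cursor()
    return cache[key]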
Example #7
    def evaluate_one(self, db_name, gold, predicted):
        schema = self.schemas[db_name]
        g_sql = get_sql(schema, gold)
        hardness = self.eval_hardness(g_sql)
        self.scores[hardness]['count'] += 1
        self.scores['all']['count'] += 1

        parse_error = False
        try:
            p_sql = get_sql(schema, predicted)
        except Exception:
            # If the predicted SQL cannot be parsed, fall back to an empty
            # SQL structure so it can still be scored against the gold SQL.
            p_sql = {
                "except": None,
                "from": {
                    "conds": [],
                    "table_units": []
                },
                "groupBy": [],
                "having": [],
                "intersect": None,
                "limit": None,
                "orderBy": [],
                "select": [False, []],
                "union": None,
                "where": []
            }

            # TODO fix
            parse_error = True

        # rebuild sql for value evaluation
        kmap = self.kmaps[db_name]
        g_valid_col_units = build_valid_col_units(g_sql['from']['table_units'],
                                                  schema)
        g_sql = rebuild_sql_val(g_sql)
        g_sql = rebuild_sql_col(g_valid_col_units, g_sql, kmap)
        p_valid_col_units = build_valid_col_units(p_sql['from']['table_units'],
                                                  schema)
        p_sql = rebuild_sql_val(p_sql)
        p_sql = rebuild_sql_col(p_valid_col_units, p_sql, kmap)

        if self.etype in ["all", "exec"]:
            # added by bailin
            exe_flag = eval_exec_match(self.db_paths[db_name], predicted, gold,
                                       p_sql, g_sql)
            if exe_flag:
                self.scores[hardness]['exec'] += 1.0
                self.scores['all']['exec'] += 1.0
        else:
            exe_flag = None

        if self.etype in ["all", "match"]:
            partial_scores = self.eval_partial_match(p_sql, g_sql)
            exact_score = self.eval_exact_match(p_sql, g_sql, partial_scores)
            self.scores[hardness]['exact'] += exact_score
            self.scores['all']['exact'] += exact_score
            for type_ in PARTIAL_TYPES:
                if partial_scores[type_]['pred_total'] > 0:
                    self.scores[hardness]['partial'][type_][
                        'acc'] += partial_scores[type_]['acc']
                    self.scores[hardness]['partial'][type_]['acc_count'] += 1
                if partial_scores[type_]['label_total'] > 0:
                    self.scores[hardness]['partial'][type_][
                        'rec'] += partial_scores[type_]['rec']
                    self.scores[hardness]['partial'][type_]['rec_count'] += 1
                self.scores[hardness]['partial'][type_][
                    'f1'] += partial_scores[type_]['f1']
                if partial_scores[type_]['pred_total'] > 0:
                    self.scores['all']['partial'][type_][
                        'acc'] += partial_scores[type_]['acc']
                    self.scores['all']['partial'][type_]['acc_count'] += 1
                if partial_scores[type_]['label_total'] > 0:
                    self.scores['all']['partial'][type_][
                        'rec'] += partial_scores[type_]['rec']
                    self.scores['all']['partial'][type_]['rec_count'] += 1
                self.scores['all']['partial'][type_]['f1'] += partial_scores[
                    type_]['f1']
        else:
            exact_score = None
            partial_scores = None

        return {
            'predicted': predicted,
            'gold': gold,
            'execution': exe_flag,
            'predicted_parse_error': parse_error,
            'hardness': hardness,
            'exact': exact_score,
            'partial': partial_scores
        }