Example #1
def __init__(self, dataset):
    self.dataset = dataset
    # Build one foreign-key map per database; the Spider evaluator uses
    # these to normalize column references during comparison.
    self.foreign_key_maps = {
        db_id: evaluation.build_foreign_key_map(schema.orig)
        for db_id, schema in self.dataset.schemas.items()
    }
    # 'all' selects both exact-set-match and execution evaluation.
    self.evaluator = evaluation.Evaluator(self.dataset.db_path,
                                          self.foreign_key_maps, 'all')
    self.results = []
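For context, a minimal sketch of how an evaluator built this way is driven. The evaluate_one signature and result keys mirror Example #4 below; the class name EvalHarness and the queries are placeholders:

harness = EvalHarness(dataset)  # hypothetical name for the class above
result = harness.evaluator.evaluate_one(
    db_name='concert_singer',                # a Spider dev-set db_id
    gold='SELECT count(*) FROM singer',
    predicted='SELECT count(*) FROM singer')
harness.results.append(result)
print(result['hardness'], result['exact'])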
Example #2
def load_tables(paths):
    schemas = {}
    eval_foreign_key_maps = {}

    for path in paths:
        # Each file is a Spider-style tables.json: a list of schema dicts.
        with open(path) as f:
            schema_dicts = json.load(f)
        for schema_dict in schema_dicts:
            db_id = schema_dict["db_id"]
            assert db_id not in schemas
            schemas[db_id] = schema_dict_to_spider_schema(schema_dict)
            eval_foreign_key_maps[db_id] = evaluation.build_foreign_key_map(
                schema_dict)

    return schemas, eval_foreign_key_maps
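A minimal usage sketch, assuming a Spider-style tables.json at a hypothetical path:

schemas, fk_maps = load_tables(['data/spider/tables.json'])
print(len(schemas), 'databases loaded')
print(sorted(fk_maps)[:5])  # first few db_ids with foreign-key maps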
Example #3
def load_tables(paths):
    schemas = {}
    eval_foreign_key_maps = {}

    with open(paths, 'r', encoding='UTF-8') as f:
        schema_dicts = json.load(f)

    for schema_dict in schema_dicts:
        db_id = schema_dict["db_id"]
        # Fall back to the tokenized names when the *_original fields
        # are missing from the schema file.
        if 'column_names_original' not in schema_dict:
            schema_dict["column_names_original"] = schema_dict["column_names"]
            schema_dict["table_names_original"] = schema_dict["table_names"]

        schemas[db_id] = schema_dict_to_spider_schema(schema_dict)
        eval_foreign_key_maps[db_id] = evaluation.build_foreign_key_map(
            schema_dict)

    return schemas, eval_foreign_key_maps
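Unlike Example #2, this variant takes a single file path rather than a list, and it drops the duplicate-db_id assertion, so a repeated db_id silently overwrites the earlier entry.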
Example #4
import argparse
import csv
import json
import os

import evaluation  # project-local module wrapping the Spider evaluator


def main():
    parser = argparse.ArgumentParser()
    # Outputs of infer.py
    parser.add_argument('--infer', nargs='*', default=())
    # Files containing inferred SQL, one per line
    # in order of the items in the dev set.
    parser.add_argument('--sql', nargs='*', default=())
    # The name to output for each of the inputs, in the CSV header.
    parser.add_argument('--names', nargs='*', default=())
    parser.add_argument('--out', required=True)
    args = parser.parse_args()
    assert len(args.names) == len(args.infer) + len(args.sql)

    SPIDER_ROOT = 'data/spider-20190205'
    with open(os.path.join(SPIDER_ROOT, 'tables.json')) as f:
        foreign_key_maps = {
            db['db_id']: evaluation.build_foreign_key_map(db)
            for db in json.load(f)
        }

    # 1. Create the evaluator
    evaluator = evaluation.Evaluator(os.path.join(SPIDER_ROOT, 'database'),
                                     foreign_key_maps, 'match')

    # 2. Read the ground truth SQL
    with open(os.path.join(SPIDER_ROOT, 'dev.json')) as f:
        dev = json.load(f)

    # 3. Perform evaluation
    difficulty = {}
    inferred_per_file = []
    correct_per_file = []

    for infer_path in args.infer:
        inferred = [None] * len(dev)
        correct = [None] * len(dev)
        inferred_per_file.append(inferred)
        correct_per_file.append(correct)

        with open(infer_path) as f:
            for line in f:
                item = json.loads(line)
                # Use the top beam's prediction for this dev-set example.
                item_inferred = item['beams'][0]['inferred_code']
                i = item['index']

                eval_result = evaluator.evaluate_one(db_name=dev[i]['db_id'],
                                                     gold=dev[i]['query'],
                                                     predicted=item_inferred)

                difficulty[i] = eval_result['hardness']
                inferred[i] = item_inferred
                correct[i] = 1 if eval_result['exact'] else 0

    for sql_path in args.sql:
        inferred = [None] * len(dev)
        correct = [None] * len(dev)
        inferred_per_file.append(inferred)
        correct_per_file.append(correct)

        with open(sql_path) as f:
            for i, line in enumerate(f):
                eval_result = evaluator.evaluate_one(db_name=dev[i]['db_id'],
                                                     gold=dev[i]['query'],
                                                     predicted=line.strip())

                difficulty[i] = eval_result['hardness']
                inferred[i] = line.strip()
                correct[i] = 1 if eval_result['exact'] else 0

    with open(args.out, 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['DB', 'Difficulty', 'Question', 'Gold'] +
                        ['{} correct'.format(c) for c in args.names] +
                        ['{} output'.format(c) for c in args.names])

        for i, dev_item in enumerate(dev):
            writer.writerow(
                [dev_item['db_id'], difficulty[i], dev_item['question'],
                 dev_item['query']] +
                [x[i] for x in correct_per_file] +
                [x[i] for x in inferred_per_file])
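A hedged invocation sketch; the script name is a placeholder, while the flags come straight from the argparse setup above:

# Example invocation (script name is hypothetical):
#   python consolidate_eval.py \
#       --infer logs/run1/infer.jsonl \
#       --sql baselines/seq2seq.sql \
#       --names run1 seq2seq \
#       --out comparison.csv
# --names must supply one label per --infer and --sql input, in that order
# (inferred files first), to satisfy the assert at the top of main().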
Example #5
def load_tables(paths):
    schemas = {}
    eval_foreign_key_maps = {}

    for path in paths:
        with open(path) as f:
            schema_dicts = json.load(f)
        for schema_dict in schema_dicts:
            tables = tuple(
                Table(
                    id=i,
                    name=name.split(),
                    unsplit_name=name,
                    orig_name=orig_name,
                ) for i, (name, orig_name) in enumerate(
                    zip(schema_dict['table_names'],
                        schema_dict['table_names_original'])))
            columns = tuple(
                Column(
                    id=i,
                    table=tables[table_id] if table_id >= 0 else None,
                    name=col_name.split(),
                    unsplit_name=col_name,
                    orig_name=orig_col_name,
                    type=col_type,
                ) for i, ((table_id, col_name), (_, orig_col_name),
                          col_type) in enumerate(
                              zip(schema_dict['column_names'],
                                  schema_dict['column_names_original'],
                                  schema_dict['column_types'])))

            # Link columns to tables
            for column in columns:
                if column.table:
                    column.table.columns.append(column)

            for column_id in schema_dict['primary_keys']:
                # Register primary keys
                column = columns[column_id]
                column.table.primary_keys.append(column)

            foreign_key_graph = nx.DiGraph()
            for source_column_id, dest_column_id in schema_dict[
                    'foreign_keys']:
                # Register foreign keys; edges are added in both directions
                # so join paths can be found regardless of FK orientation.
                source_column = columns[source_column_id]
                dest_column = columns[dest_column_id]
                source_column.foreign_key_for = dest_column
                foreign_key_graph.add_edge(source_column.table.id,
                                           dest_column.table.id,
                                           columns=(source_column_id,
                                                    dest_column_id))
                foreign_key_graph.add_edge(dest_column.table.id,
                                           source_column.table.id,
                                           columns=(dest_column_id,
                                                    source_column_id))

            db_id = schema_dict['db_id']
            assert db_id not in schemas
            schemas[db_id] = Schema(db_id, tables, columns, foreign_key_graph,
                                    schema_dict)
            eval_foreign_key_maps[db_id] = evaluation.build_foreign_key_map(
                schema_dict)

    return schemas, eval_foreign_key_maps
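A short sketch of what the foreign-key graph enables, e.g. recovering a join path between two tables. It assumes the Schema object exposes the graph as .foreign_key_graph (the constructor argument above) and that table ids 0 and 1 are connected; nx.shortest_path is standard networkx:

import networkx as nx

schemas, _ = load_tables(['data/spider/tables.json'])  # hypothetical path
graph = schemas['concert_singer'].foreign_key_graph    # assumed attribute
path = nx.shortest_path(graph, source=0, target=1)     # table ids
for src, dst in zip(path, path[1:]):
    # Each edge stores a (source_column_id, dest_column_id) pair, so the
    # table-level path maps back to concrete join conditions.
    print(graph.edges[src, dst]['columns'])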
Example #6
def load_tables(paths, with_value=False):
    schemas = {}
    eval_foreign_key_maps = {}

    for path in paths:
        with open(path) as f:
            schema_dicts = json.load(f)
        if with_value:
            try:
                cached_value_path = path.replace('.json',
                                                 '_column_values.json')
                with open(cached_value_path, 'r') as f:
                    cached_values = json.load(f)
                print('loaded cached column value from %s' % cached_value_path)
            except FileNotFoundError:
                print('cached value file not found')
                cached_values = {}

        for schema_dict in schema_dicts:
            if with_value:
                cached_value = cached_values.get(schema_dict['db_id'],
                                                 [[], {}, []])
            tables = tuple(
                Table(
                    id=i,
                    name=name.split(),
                    unsplit_name=name,
                    orig_name=orig_name,
                ) for i, (name, orig_name) in enumerate(
                    zip(schema_dict['table_names'],
                        schema_dict['table_names_original'])))
            if with_value and len(cached_value[0]) != 0:
                columns = tuple(
                    Column(id=i,
                           table=tables[table_id] if table_id >= 0 else None,
                           name=col_name.split(),
                           unsplit_name=col_name,
                           orig_name=orig_col_name,
                           type=col_type,
                           value_range=value_range,
                           tokens=set(tokens),
                           cell_values=set(cell_values),
                           value_vocab_ids=value_vocab_ids,
                           value_vocab_weights=value_vocab_weights) for i,
                    ((table_id, col_name), (_, orig_col_name), col_type,
                     (value_range, tokens, cell_values, value_vocab_ids,
                      value_vocab_weights)) in enumerate(
                          zip(schema_dict['column_names'],
                              schema_dict['column_names_original'],
                              schema_dict['column_types'], cached_value[0])))
            else:
                columns = tuple(
                    Column(
                        id=i,
                        table=tables[table_id] if table_id >= 0 else None,
                        name=col_name.split(),
                        unsplit_name=col_name,
                        orig_name=orig_col_name,
                        type=col_type,
                    ) for i, ((table_id, col_name), (_, orig_col_name),
                              col_type) in enumerate(
                                  zip(schema_dict['column_names'],
                                      schema_dict['column_names_original'],
                                      schema_dict['column_types'])))

            # Link columns to tables
            for column in columns:
                if column.table:
                    column.table.columns.append(column)

            for column_id in schema_dict['primary_keys']:
                # Register primary keys
                column = columns[column_id]
                column.table.primary_keys.append(column)

            foreign_key_graph = nx.DiGraph()
            for source_column_id, dest_column_id in schema_dict[
                    'foreign_keys']:
                # Register foreign keys
                source_column = columns[source_column_id]
                dest_column = columns[dest_column_id]
                source_column.foreign_key_for = dest_column
                foreign_key_graph.add_edge(source_column.table.id,
                                           dest_column.table.id,
                                           columns=(source_column_id,
                                                    dest_column_id))
                foreign_key_graph.add_edge(dest_column.table.id,
                                           source_column.table.id,
                                           columns=(dest_column_id,
                                                    source_column_id))

            # HACK: Introduce column synonyms as "phantom" columns
            if 'column_synonyms_original' in schema_dict:
                synonym_columns = []
                col_id = len(columns)
                for orig_name, synonyms in schema_dict[
                        'column_synonyms_original'].items():
                    orig_column = next(
                        (c for c in columns if c.orig_name == orig_name), None)
                    if not orig_column:
                        continue
                    for syn_name in synonyms:
                        syn_column = copy(orig_column)
                        syn_column.synonym_for = orig_column
                        syn_column.orig_name = syn_name
                        syn_column.unsplit_name = postprocess_original_name(
                            syn_name)
                        syn_column.name = syn_column.unsplit_name.split()
                        syn_column.id = col_id
                        col_id += 1
                        syn_column.table.columns.append(syn_column)
                        if orig_column.foreign_key_for is not None:
                            foreign_key_graph.add_edge(
                                orig_column.table.id,
                                orig_column.foreign_key_for.table.id,
                                columns=(syn_column.id,
                                         orig_column.foreign_key_for.id))
                            foreign_key_graph.add_edge(
                                orig_column.foreign_key_for.table.id,
                                orig_column.table.id,
                                columns=(orig_column.foreign_key_for.id,
                                         syn_column.id))
                        synonym_columns.append(syn_column)
                columns = tuple(list(columns) + synonym_columns)

            db_id = schema_dict['db_id']
            hidden = schema_dict.get('hidden', False)
            assert db_id not in schemas
            if with_value:
                schemas[db_id] = Schema(
                    db_id,
                    tables,
                    columns,
                    foreign_key_graph,
                    schema_dict,
                    hidden,
                    nxt_masks={int(k): v
                               for k, v in cached_value[1].items()},
                    col_value_maps=cached_value[2])
            else:
                schemas[db_id] = Schema(db_id, tables, columns,
                                        foreign_key_graph, schema_dict, hidden)
            eval_foreign_key_maps[db_id] = evaluation.build_foreign_key_map(
                schema_dict)

    return schemas, eval_foreign_key_maps
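For reference, the cache layout this function expects can be read off the destructuring above; a sketch of the *_column_values.json shape (field names come from the code, everything else is inferred):

# Inferred shape of <path>_column_values.json, keyed by db_id:
# {
#   "<db_id>": [
#     [  # entry 0: one record per column, aligned with column_names
#       [value_range, tokens, cell_values,
#        value_vocab_ids, value_vocab_weights],
#       ...
#     ],
#     {"<column_id>": ...},  # entry 1: nxt_masks (keys cast to int)
#     [...]                  # entry 2: col_value_maps
#   ]
# }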