def __init__(self, dataset):
    """Set up a Spider evaluator covering every schema in ``dataset``.

    Builds one foreign-key map per database (keyed by db_id) from each
    schema's original dict, then constructs an ``evaluation.Evaluator``
    over the dataset's database directory in 'all' mode.
    """
    self.dataset = dataset
    # One foreign-key map per database, keyed by db_id.
    fk_maps = {}
    for db_id, schema in self.dataset.schemas.items():
        fk_maps[db_id] = evaluation.build_foreign_key_map(schema.orig)
    self.foreign_key_maps = fk_maps
    self.evaluator = evaluation.Evaluator(
        self.dataset.db_path, self.foreign_key_maps, 'all')
    # Accumulates per-example evaluation results.
    self.results = []
def load_tables(paths):
    """Load Spider schemas from one or more tables.json files.

    Args:
        paths: iterable of paths to tables.json-style files; each file
            holds a list of schema dicts with a unique 'db_id'.

    Returns:
        (schemas, eval_foreign_key_maps): two dicts keyed by db_id, the
        first mapping to converted Spider schemas, the second to the
        foreign-key maps used by the official evaluator.
    """
    schemas = {}
    eval_foreign_key_maps = {}
    for path in paths:
        # Use a context manager so the handle is closed promptly
        # (json.load(open(path)) leaked the file handle).
        with open(path) as f:
            schema_dicts = json.load(f)
        for schema_dict in schema_dicts:
            db_id = schema_dict["db_id"]
            assert db_id not in schemas, 'duplicate db_id: %s' % db_id
            schemas[db_id] = schema_dict_to_spider_schema(schema_dict)
            eval_foreign_key_maps[db_id] = evaluation.build_foreign_key_map(
                schema_dict)
    return schemas, eval_foreign_key_maps
def load_tables(paths):
    """Load schemas from a single tables.json-style file.

    NOTE(review): despite the plural name, ``paths`` is a single file
    path here — the file is opened directly rather than iterated. The
    name is kept for interface compatibility with the other variants.

    Schemas missing the '*_names_original' fields fall back to the
    processed names so downstream code can rely on both being present.

    Returns:
        (schemas, eval_foreign_key_maps): dicts keyed by db_id.
    """
    schemas = {}
    eval_foreign_key_maps = {}
    with open(paths, 'r', encoding='UTF-8') as f:
        schema_dicts = json.load(f)
    for schema_dict in schema_dicts:
        db_id = schema_dict["db_id"]
        if 'column_names_original' not in schema_dict:
            # Reuse the processed names when the original-cased variants
            # were not provided in this dump.
            schema_dict["column_names_original"] = schema_dict["column_names"]
            schema_dict["table_names_original"] = schema_dict["table_names"]
        # Duplicate db_ids silently overwrite earlier entries here
        # (the uniqueness assert was intentionally disabled upstream).
        schemas[db_id] = schema_dict_to_spider_schema(schema_dict)
        eval_foreign_key_maps[db_id] = evaluation.build_foreign_key_map(
            schema_dict)
    return schemas, eval_foreign_key_maps
def main():
    """Evaluate inferred SQL against the Spider dev set; write a CSV report.

    Reads predictions either from infer.py JSONL outputs (--infer) or from
    plain one-SQL-per-line files (--sql), scores each against the gold dev
    query with the official evaluator, and writes one CSV row per dev item
    with per-system correctness flags and outputs.
    """
    parser = argparse.ArgumentParser()
    # Outputs of infer.py (JSON lines carrying beams and example indices).
    parser.add_argument('--infer', nargs='*', default=())
    # Files containing inferred SQL, one per line,
    # in order of the items in the dev set.
    parser.add_argument('--sql', nargs='*', default=())
    # The name to output for each of the inputs, in the CSV header.
    parser.add_argument('--names', nargs='*', default=())
    parser.add_argument('--out', required=True)
    # Dataset location; previously a hard-coded constant.
    parser.add_argument('--spider-root', default='data/spider-20190205')
    args = parser.parse_args()
    assert len(args.names) == len(args.infer) + len(args.sql)

    spider_root = args.spider_root
    with open(os.path.join(spider_root, 'tables.json')) as f:
        foreign_key_maps = {
            db['db_id']: evaluation.build_foreign_key_map(db)
            for db in json.load(f)
        }

    # 1. Create the evaluator
    evaluator = evaluation.Evaluator(
        os.path.join(spider_root, 'database'), foreign_key_maps, 'match')

    # 2. Read the ground truth SQL
    with open(os.path.join(spider_root, 'dev.json')) as f:
        dev = json.load(f)

    # 3. Perform evaluation
    difficulty = {}          # dev index -> hardness bucket
    inferred_per_file = []   # one list (dev-length) per prediction source
    correct_per_file = []    # parallel 0/1 lists per prediction source

    for infer_path in args.infer:
        inferred = [None] * len(dev)
        correct = [None] * len(dev)
        inferred_per_file.append(inferred)
        correct_per_file.append(correct)
        with open(infer_path) as infer_file:
            for line in infer_file:
                item = json.loads(line)
                item_inferred = item['beams'][0]['inferred_code']
                i = item['index']
                eval_result = evaluator.evaluate_one(
                    db_name=dev[i]['db_id'],
                    gold=dev[i]['query'],
                    predicted=item_inferred)
                difficulty[i] = eval_result['hardness']
                inferred[i] = item_inferred
                correct[i] = 1 if eval_result['exact'] else 0

    for sql_path in args.sql:
        inferred = [None] * len(dev)
        correct = [None] * len(dev)
        inferred_per_file.append(inferred)
        correct_per_file.append(correct)
        with open(sql_path) as sql_file:
            # Plain SQL files are aligned with the dev set by line number.
            for i, line in enumerate(sql_file):
                predicted = line.strip()
                eval_result = evaluator.evaluate_one(
                    db_name=dev[i]['db_id'],
                    gold=dev[i]['query'],
                    predicted=predicted)
                difficulty[i] = eval_result['hardness']
                inferred[i] = predicted
                correct[i] = 1 if eval_result['exact'] else 0

    with open(args.out, 'w') as f:
        writer = csv.writer(f)
        writer.writerow(
            ['DB', 'Difficulty', 'Question', 'Gold'] +
            ['{} correct'.format(c) for c in args.names] +
            ['{} output'.format(c) for c in args.names])
        for i, dev_item in enumerate(dev):
            writer.writerow([
                dev_item['db_id'],
                # Blank if this row was never evaluated (e.g. no inputs).
                difficulty.get(i, ''),
                dev_item['question'],
                dev_item['query'],
            ] + [x[i] for x in correct_per_file] +
                [x[i] for x in inferred_per_file])
def load_tables(paths):
    """Load Spider schemas and build the evaluation foreign-key maps.

    For each tables.json file in ``paths``, constructs ``Table`` and
    ``Column`` objects, links columns to tables, registers primary and
    foreign keys, and builds a bidirectional table-level foreign-key
    graph per database.

    Args:
        paths: iterable of paths to tables.json-style files.

    Returns:
        (schemas, eval_foreign_key_maps): dicts keyed by db_id.

    Raises:
        AssertionError: if the same db_id appears in more than one file.
    """
    schemas = {}
    eval_foreign_key_maps = {}
    for path in paths:
        # Close each file promptly instead of leaking the handle
        # (was json.load(open(path))).
        with open(path) as f:
            schema_dicts = json.load(f)
        for schema_dict in schema_dicts:
            tables = tuple(
                Table(
                    id=i,
                    name=name.split(),
                    unsplit_name=name,
                    orig_name=orig_name,
                ) for i, (name, orig_name) in enumerate(
                    zip(schema_dict['table_names'],
                        schema_dict['table_names_original'])))
            columns = tuple(
                Column(
                    id=i,
                    # table_id == -1 marks the special table-less column '*'.
                    table=tables[table_id] if table_id >= 0 else None,
                    name=col_name.split(),
                    unsplit_name=col_name,
                    orig_name=orig_col_name,
                    type=col_type,
                ) for i, ((table_id, col_name), (_, orig_col_name),
                          col_type) in enumerate(
                    zip(schema_dict['column_names'],
                        schema_dict['column_names_original'],
                        schema_dict['column_types'])))

            # Link columns to tables
            for column in columns:
                if column.table:
                    column.table.columns.append(column)

            # Register primary keys
            for column_id in schema_dict['primary_keys']:
                column = columns[column_id]
                column.table.primary_keys.append(column)

            # Register foreign keys; edges are added in both directions so
            # the graph can be walked from either table.
            foreign_key_graph = nx.DiGraph()
            for source_column_id, dest_column_id in schema_dict['foreign_keys']:
                source_column = columns[source_column_id]
                dest_column = columns[dest_column_id]
                source_column.foreign_key_for = dest_column
                foreign_key_graph.add_edge(
                    source_column.table.id,
                    dest_column.table.id,
                    columns=(source_column_id, dest_column_id))
                foreign_key_graph.add_edge(
                    dest_column.table.id,
                    source_column.table.id,
                    columns=(dest_column_id, source_column_id))

            db_id = schema_dict['db_id']
            assert db_id not in schemas, 'duplicate db_id: %s' % db_id
            schemas[db_id] = Schema(db_id, tables, columns,
                                    foreign_key_graph, schema_dict)
            eval_foreign_key_maps[db_id] = evaluation.build_foreign_key_map(
                schema_dict)
    return schemas, eval_foreign_key_maps
def load_tables(paths, with_value=False):
    """Load Spider schemas, optionally enriched with cached column values.

    When ``with_value`` is True, looks for a sibling
    '<path基>_column_values.json' cache next to each tables.json file and,
    if found, attaches value statistics (ranges, tokens, cell values,
    vocab ids/weights) to each Column. Also splices in "phantom" synonym
    columns when the schema dict carries 'column_synonyms_original'.

    Returns:
        (schemas, eval_foreign_key_maps): dicts keyed by db_id.
    """
    schemas = {}
    eval_foreign_key_maps = {}
    for path in paths:
        schema_dicts = json.load(open(path))
        if with_value:
            try:
                # Cache file lives next to the tables.json it was built from.
                cached_value_path = path.replace('.json', '_column_values.json')
                with open(cached_value_path, 'r') as f:
                    cached_values = json.load(f)
                print('loaded cached column value from %s' % cached_value_path)
            except FileNotFoundError:
                # Best-effort: proceed without values rather than failing.
                print('cached value file not found')
                cached_values = {}
        for schema_dict in schema_dicts:
            if with_value:
                # [per-column values, nxt_masks, col_value_maps]; empty
                # triple when this db has no cache entry.
                cached_value = cached_values.get(schema_dict['db_id'],
                                                 [[], {}, []])
            tables = tuple(
                Table(
                    id=i,
                    name=name.split(),
                    unsplit_name=name,
                    orig_name=orig_name,
                ) for i, (name, orig_name) in enumerate(
                    zip(schema_dict['table_names'],
                        schema_dict['table_names_original'])))
            if with_value and len(cached_value[0]) != 0:
                # Value-enriched columns: zip the cached per-column tuples
                # alongside the regular schema fields.
                columns = tuple(
                    Column(id=i,
                           table=tables[table_id] if table_id >= 0 else None,
                           name=col_name.split(),
                           unsplit_name=col_name,
                           orig_name=orig_col_name,
                           type=col_type,
                           value_range=value_range,
                           tokens=set(tokens),
                           cell_values=set(cell_values),
                           value_vocab_ids=value_vocab_ids,
                           value_vocab_weights=value_vocab_weights)
                    for i, ((table_id, col_name), (_, orig_col_name), col_type,
                            (value_range, tokens, cell_values, value_vocab_ids,
                             value_vocab_weights)) in enumerate(
                        zip(schema_dict['column_names'],
                            schema_dict['column_names_original'],
                            schema_dict['column_types'], cached_value[0])))
            else:
                columns = tuple(
                    Column(
                        id=i,
                        # table_id == -1 is the table-less '*' column.
                        table=tables[table_id] if table_id >= 0 else None,
                        name=col_name.split(),
                        unsplit_name=col_name,
                        orig_name=orig_col_name,
                        type=col_type,
                    ) for i, ((table_id, col_name), (_, orig_col_name),
                              col_type) in enumerate(
                        zip(schema_dict['column_names'],
                            schema_dict['column_names_original'],
                            schema_dict['column_types'])))
            # Link columns to tables
            for column in columns:
                if column.table:
                    column.table.columns.append(column)
            for column_id in schema_dict['primary_keys']:
                # Register primary keys
                column = columns[column_id]
                column.table.primary_keys.append(column)
            foreign_key_graph = nx.DiGraph()
            for source_column_id, dest_column_id in schema_dict[
                    'foreign_keys']:
                # Register foreign keys
                source_column = columns[source_column_id]
                dest_column = columns[dest_column_id]
                source_column.foreign_key_for = dest_column
                # Edges added in both directions so the graph can be
                # traversed from either table.
                foreign_key_graph.add_edge(source_column.table.id,
                                           dest_column.table.id,
                                           columns=(source_column_id,
                                                    dest_column_id))
                foreign_key_graph.add_edge(dest_column.table.id,
                                           source_column.table.id,
                                           columns=(dest_column_id,
                                                    source_column_id))
            # HACK: Introduce column synonyms as "phantom" columns
            if 'column_synonyms_original' in schema_dict:
                synonym_columns = []
                # Phantom ids continue past the real column ids.
                col_id = len(columns)
                for orig_name, synonyms in schema_dict[
                        'column_synonyms_original'].items():
                    orig_column = next(
                        (c for c in columns if c.orig_name == orig_name), None)
                    if not orig_column:
                        # Synonym refers to a column we never built; skip.
                        continue
                    for syn_name in synonyms:
                        # Shallow-copy the real column, then overwrite the
                        # identity fields; the copy shares the table object.
                        syn_column = copy(orig_column)
                        syn_column.synonym_for = orig_column
                        syn_column.orig_name = syn_name
                        syn_column.unsplit_name = postprocess_original_name(
                            syn_name)
                        syn_column.name = syn_column.unsplit_name.split()
                        syn_column.id = col_id
                        col_id += 1
                        syn_column.table.columns.append(syn_column)
                        if orig_column.foreign_key_for is not None:
                            # Mirror the original column's FK edges using the
                            # phantom column's id.
                            foreign_key_graph.add_edge(
                                orig_column.table.id,
                                orig_column.foreign_key_for.table.id,
                                columns=(syn_column.id,
                                         orig_column.foreign_key_for.id))
                            foreign_key_graph.add_edge(
                                orig_column.foreign_key_for.table.id,
                                orig_column.table.id,
                                columns=(orig_column.foreign_key_for.id,
                                         syn_column.id))
                        synonym_columns.append(syn_column)
                columns = tuple(list(columns) + synonym_columns)
            db_id = schema_dict['db_id']
            hidden = schema_dict.get('hidden', False)
            assert db_id not in schemas
            if with_value:
                schemas[db_id] = Schema(
                    db_id,
                    tables,
                    columns,
                    foreign_key_graph,
                    schema_dict,
                    hidden,
                    # JSON forces string keys; restore the int column ids.
                    nxt_masks={int(k): v for k, v in cached_value[1].items()},
                    col_value_maps=cached_value[2])
            else:
                schemas[db_id] = Schema(db_id, tables, columns,
                                        foreign_key_graph, schema_dict, hidden)
            eval_foreign_key_maps[db_id] = evaluation.build_foreign_key_map(
                schema_dict)
    return schemas, eval_foreign_key_maps