Ejemplo n.º 1
0
def load_schema_graphs_ask_data(data_dir):
    """
    Load and index the schema graphs of the CSV-backed datasets.

    For every dataset name in a fixed list, the file
    ``<data_dir>/raw/<dataset>.csv`` is loaded into a ``SchemaGraph``, which
    is then indexed in a ``SchemaGraphs`` collection and pretty-printed.

    :param data_dir: directory containing the ``raw`` sub-directory with one
        CSV file per dataset.
    :return: the ``SchemaGraphs`` collection holding every loaded graph.
    """
    datasets = [
        'airbnb_san_francisco', 'airbnb_seattle', 'sports_salaries', 'wines'
    ]

    schema_graphs = SchemaGraphs()
    for dataset in datasets:
        in_csv = os.path.join(data_dir, 'raw/{}.csv'.format(dataset))
        schema_graph = SchemaGraph(dataset)
        schema_graph.load_data_from_csv_file(in_csv)
        schema_graphs.index_schema_graph(schema_graph)
        schema_graph.pretty_print()
    print('{} schema graphs loaded'.format(schema_graphs.size))

    # Hand the collection back to the caller, as the other loaders do;
    # previously the loaded graphs were discarded and None was returned.
    return schema_graphs
Ejemplo n.º 2
0
def load_schema_graphs_spider(data_dir,
                              dataset_name,
                              db_dir=None,
                              augment_with_wikisql=False):
    """
    Load indexed database schema.

    Reads ``<data_dir>/tables.json`` and builds one ``SchemaGraph`` per entry,
    indexing each of them in a ``SchemaGraphs`` collection.

    :param data_dir: directory containing ``tables.json``.
    :param dataset_name: when equal to ``'spider'`` the SQLite file is looked
        up under the entry's ``db_id``; for any other value a trailing
        ``_m<digits>`` suffix is stripped from ``db_id`` first, and the
        remaining base id is used to locate the SQLite file.
    :param db_dir: root directory of the SQLite databases, laid out as
        ``<db_dir>/<id>/<id>.sqlite``. When falsy, no database path is passed
        to the schema graph.
    :param augment_with_wikisql: when true, the schema graphs found in the
        sibling directory ``wikisql1.1`` (next to ``data_dir``) are indexed
        into the same collection as well.
    :return: the ``SchemaGraphs`` collection holding every loaded graph.
    """
    in_json = os.path.join(data_dir, 'tables.json')
    schema_graphs = SchemaGraphs()
    # Raw string avoids the invalid '\d' escape warning; compiled once here
    # instead of on every loop iteration.
    m_suffix_pattern = re.compile(r'm\d+')

    # Only the JSON parsing needs the file handle; release it before the
    # (potentially long) graph construction below.
    with open(in_json) as f:
        content = json.load(f)

    for db_content in content:
        db_id = db_content['db_id']
        # The base id names the on-disk database. It equals db_id unless a
        # non-spider dataset carries an '_m<digits>' suffix.
        db_base_id = db_id
        if dataset_name != 'spider':
            db_id_parts = db_id.rsplit('_', 1)
            if (len(db_id_parts) > 1
                    and m_suffix_pattern.fullmatch(db_id_parts[1])):
                db_base_id = db_id_parts[0]
        db_path = os.path.join(
            db_dir, db_base_id,
            '{}.sqlite'.format(db_base_id)) if db_dir else None

        schema_graph = SchemaGraph(db_id, db_path)
        # NOTE(review): the picklist is computed before the JSON schema is
        # loaded; this ordering is kept exactly as in the original code.
        if db_dir is not None:
            schema_graph.compute_field_picklist()
        schema_graph.load_data_from_spider_json(db_content)
        schema_graphs.index_schema_graph(schema_graph)
    print('{} schema graphs loaded'.format(schema_graphs.size))

    if augment_with_wikisql:
        parent_dir = os.path.dirname(data_dir)
        wikisql_dir = os.path.join(parent_dir, 'wikisql1.1')
        wikisql_schema_graphs = load_schema_graphs_wikisql(wikisql_dir)
        # The WikiSQL collection is addressed by consecutive integer indices.
        for db_id in range(wikisql_schema_graphs.size):
            schema_graph = wikisql_schema_graphs.get_schema(db_id)
            schema_graphs.index_schema_graph(schema_graph)
        print('{} schema graphs loaded (+wikisql)'.format(schema_graphs.size))

    return schema_graphs
Ejemplo n.º 3
0
def load_schema_graphs(args):
    """
    Build the schema graph collection for the dataset selected in ``args``.

    Dispatches on ``args.dataset_name``:

    * ``'spider'`` / ``'spider_ut'`` -- delegates to
      ``load_schema_graphs_spider`` with ``args.data_dir``, ``args.db_dir``
      and ``args.augment_with_wikisql``;
    * ``'wikisql'`` -- delegates to ``load_schema_graphs_wikisql`` with
      ``args.data_dir``;
    * anything else -- loads the single file
      ``<args.data_dir>/<dataset_name>-schema.csv`` into one ``SchemaGraph``
      and returns a collection containing just that graph.

    :param args: object exposing the attributes listed above.
    :return: a ``SchemaGraphs`` collection.
    """
    name = args.dataset_name

    if name in ['spider', 'spider_ut']:
        spider_options = dict(
            db_dir=args.db_dir,
            augment_with_wikisql=args.augment_with_wikisql)
        return load_schema_graphs_spider(args.data_dir, name, **spider_options)

    if name == 'wikisql':
        return load_schema_graphs_wikisql(args.data_dir)

    # Fallback: a single schema described by one CSV file.
    csv_file_name = '{}-schema.csv'.format(name)
    csv_path = os.path.join(args.data_dir, csv_file_name)

    collection = SchemaGraphs()
    graph = SchemaGraph(name)
    graph.load_data_from_finegan_dollak_csv_file(csv_path)
    collection.index_schema_graph(graph)
    return collection
Ejemplo n.º 4
0
def load_schema_graphs_wikisql(data_dir, splits=('train', 'dev', 'test')):
    """
    Load and index one schema graph per table of the given splits.

    For every split, the file ``<data_dir>/<split>.tables.jsonl`` is read
    line by line; each non-blank line is parsed as a JSON table description
    and turned into a ``WikiSQLSchemaGraph`` that is indexed in a
    ``SchemaGraphs`` collection. Per-split and total counts are printed.

    :param data_dir: directory containing the ``*.tables.jsonl`` files.
    :param splits: iterable of split names to load. The default is an
        immutable tuple so it cannot be altered between calls.
    :return: the ``SchemaGraphs`` collection holding every loaded graph.
    """
    schema_graphs = SchemaGraphs()

    for split in splits:
        in_jsonl = os.path.join(data_dir, '{}.tables.jsonl'.format(split))
        db_count = 0
        with open(in_jsonl) as f:
            for line in f:
                line = line.strip()
                # Tolerate blank lines (e.g. a trailing empty line), which
                # json.loads would otherwise reject.
                if not line:
                    continue
                table = json.loads(line)
                db_name = table['id']
                schema_graph = WikiSQLSchemaGraph(db_name,
                                                  table,
                                                  caseless=False)
                schema_graph.id = table['id']
                schema_graph.load_data_from_wikisql_json(table)
                schema_graph.compute_field_picklist(table)
                schema_graphs.index_schema_graph(schema_graph)
                db_count += 1
        print('{} databases in {}'.format(db_count, split))
    print('{} databases loaded in total'.format(schema_graphs.size))

    return schema_graphs