import json
import os
import re
import sys

import torch


def demo(args):
    """
    Interactive command line demo.

    Specify a target database from the Spider dataset and query the database using natural language.
    The output includes:
        1. if the input question is translated to the SQL query, return the SQL query
        2. otherwise, return a confusion span in the question that caused the input to be untranslatable.
    """
    data_dir = 'data/'
    if args.demo_db is None:
        print('Error: must specify a database name to proceed')
        return
    db_name = args.demo_db
    db_path = os.path.join(args.db_dir, db_name, '{}.sqlite'.format(db_name))
    schema = SchemaGraph(db_name, db_path=db_path)
    if db_name == 'covid_19':
        # The covid_19 schema is read from a local CSV plus a column-type file.
        in_csv = os.path.join(data_dir, db_name, '{}.csv'.format(db_name))
        in_type = os.path.join(data_dir, db_name, '{}.types'.format(db_name))
        schema.load_data_from_csv_file(in_csv, in_type)
    else:
        # TODO: currently the demo is configured for the Spider dataset.
        in_json = os.path.join(args.data_dir, 'tables.json')
        with open(in_json) as f:
            tables = json.load(f)
        for table in tables:
            if table['db_id'] == db_name:
                break
        else:
            print('Error: database {} not found in {}'.format(db_name, in_json))
            return
        schema.load_data_from_spider_json(table)
    schema.pretty_print()

    # cs_args and ensemble_model_dirs are expected to be defined elsewhere in
    # this module.
    if args.ensemble_inference:
        t2sql = Text2SQLWrapper(args,
                                cs_args,
                                schema,
                                ensemble_model_dirs=ensemble_model_dirs)
    else:
        t2sql = Text2SQLWrapper(args, cs_args, schema)

    sys.stdout.write('Enter a natural language question: ')
    sys.stdout.write('> ')
    sys.stdout.flush()
    text = sys.stdin.readline()

    while text:
        output = t2sql.process(text, schema.name)
        translatable = output['translatable']
        sql_query = output['sql_query']
        confusion_span = output['confuse_span']
        replacement_span = output['replace_span']
        print('Translatable: {}'.format(translatable))
        print('SQL: {}'.format(sql_query))
        print('Confusion span: {}'.format(confusion_span))
        print('Replacement span: {}'.format(replacement_span))
        sys.stdout.flush()
        sys.stdout.write('\nEnter a natural language question: ')
        sys.stdout.write('> ')
        sys.stdout.flush()
        text = sys.stdin.readline()
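

# A minimal sketch of how demo() might be wired up from the command line; the
# flag names mirror the attributes read above, but this wiring is an
# assumption rather than the repo's actual entry point:
#
#   import argparse
#   parser = argparse.ArgumentParser()
#   parser.add_argument('--demo_db')
#   parser.add_argument('--db_dir', default='data/spider/database')
#   parser.add_argument('--data_dir', default='data/spider')
#   parser.add_argument('--ensemble_inference', action='store_true')
#   demo(parser.parse_args())

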
def load_schema_graphs(args):
    """
    Load database schema as a graph.
    """
    dataset_name = args.dataset_name
    if dataset_name in ['spider', 'spider_ut']:
        return load_schema_graphs_spider(
            args.data_dir,
            dataset_name,
            db_dir=args.db_dir,
            augment_with_wikisql=args.augment_with_wikisql)
    if dataset_name == 'wikisql':
        return load_schema_graphs_wikisql(args.data_dir)

    in_csv = os.path.join(args.data_dir, '{}-schema.csv'.format(dataset_name))
    schema_graphs = SchemaGraphs()
    schema_graph = SchemaGraph(dataset_name)
    schema_graph.load_data_from_finegan_dollak_csv_file(in_csv)
    schema_graphs.index_schema_graph(schema_graph)
    return schema_graphs
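

# A minimal usage sketch: any object exposing the four attributes read above
# works, e.g. an argparse.Namespace (the paths are illustrative assumptions):
#
#   from argparse import Namespace
#   schema_graphs = load_schema_graphs(Namespace(
#       dataset_name='spider',
#       data_dir='data/spider',
#       db_dir='data/spider/database',
#       augment_with_wikisql=False))

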
def load_schema_graphs_ask_data(data_dir):
    """
    Load the schema of each Ask Data demo dataset from its raw CSV file.
    """
    datasets = [
        'airbnb_san_francisco', 'airbnb_seattle', 'sports_salaries', 'wines'
    ]

    schema_graphs = SchemaGraphs()
    for dataset in datasets:
        in_csv = os.path.join(data_dir, 'raw/{}.csv'.format(dataset))
        schema_graph = SchemaGraph(dataset)
        schema_graph.load_data_from_csv_file(in_csv)
        schema_graphs.index_schema_graph(schema_graph)
        schema_graph.pretty_print()
    print('{} schema graphs loaded'.format(schema_graphs.size))
    return schema_graphs
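

# This loader expects each demo table at {data_dir}/raw/<dataset>.csv; a call
# might look like the following (the path is illustrative, not a repo fact):
#
#   schema_graphs = load_schema_graphs_ask_data('data/ask_data')

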
def load_schema_graphs_spider(data_dir,
                              dataset_name,
                              db_dir=None,
                              augment_with_wikisql=False):
    """
    Load indexed database schema.
    """
    in_json = os.path.join(data_dir, 'tables.json')
    schema_graphs = SchemaGraphs()

    with open(in_json) as f:
        content = json.load(f)
        for db_content in content:
            db_id = db_content['db_id']
            if dataset_name == 'spider':
                db_path = os.path.join(
                    db_dir, db_id,
                    '{}.sqlite'.format(db_id)) if db_dir else None
            else:
                # For spider_ut, strip a trailing '_m<N>' suffix from the
                # database id (e.g. 'concert_singer_m1' -> 'concert_singer')
                # so the path points at the base .sqlite file.
                db_base_id = db_id
                db_id_parts = db_id.rsplit('_', 1)
                if len(db_id_parts) > 1 and re.fullmatch(r'm\d+', db_id_parts[1]):
                    db_base_id = db_id_parts[0]
                db_path = os.path.join(
                    db_dir, db_base_id,
                    '{}.sqlite'.format(db_base_id)) if db_dir else None
            schema_graph = SchemaGraph(db_id, db_path)
            if db_dir is not None:
                schema_graph.compute_field_picklist()
            schema_graph.load_data_from_spider_json(db_content)
            schema_graphs.index_schema_graph(schema_graph)
        print('{} schema graphs loaded'.format(schema_graphs.size))

    if augment_with_wikisql:
        # WikiSQL schemas are expected in a 'wikisql1.1' directory that sits
        # next to data_dir (e.g. data/spider -> data/wikisql1.1).
        parent_dir = os.path.dirname(data_dir)
        wikisql_dir = os.path.join(parent_dir, 'wikisql1.1')
        wikisql_schema_graphs = load_schema_graphs_wikisql(wikisql_dir)
        for db_index in range(wikisql_schema_graphs.size):
            schema_graph = wikisql_schema_graphs.get_schema(db_index)
            schema_graphs.index_schema_graph(schema_graph)
        print('{} schema graphs loaded (+wikisql)'.format(schema_graphs.size))

    return schema_graphs
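

# For reference, each entry in Spider's tables.json has the shape consumed by
# load_data_from_spider_json above (standard Spider fields, abbreviated):
#
#   {
#       "db_id": "concert_singer",
#       "table_names_original": [...],
#       "column_names_original": [[0, "stadium_id"], ...],
#       "column_types": [...],
#       "primary_keys": [...],
#       "foreign_keys": [...]
#   }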


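# Flask demo server: build a single Text2SQLWrapper, register the schema of
# every database offered in the web form, then serve. `app`, `args`,
# `cs_args`, `utils`, `SQLForm`, and `get_model_dir` are assumed to be
# defined earlier in this module.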
if __name__ == '__main__':
    with torch.no_grad():
        args.model_id = utils.model_index[args.model]
        get_model_dir(args)

        t2sql = Text2SQLWrapper(args, cs_args, None)

        # Load the Spider table metadata once and register the schema of each
        # database exposed in the web form.
        in_json = os.path.join(args.data_dir, 'tables.json')
        with open(in_json) as f:
            tables = json.load(f)

        schemas = {}
        for db_name in SQLForm.database.kwargs['choices']:
            db_path = os.path.join(args.db_dir, db_name,
                                   '{}.sqlite'.format(db_name))
            schema = SchemaGraph(db_name, db_path=db_path)
            for table in tables:
                if table['db_id'] == db_name:
                    break
            else:
                # No entry in tables.json for this database; skip it.
                continue
            schema.load_data_from_spider_json(table)
            t2sql.add_schema(schema)
            schemas[db_name] = schema

        app.run(host='0.0.0.0', port=8080)