def demo(args):
    """
    Interactive command line demo. Specify a target database from the Spider dataset
    and query the database using natural language. The output includes:
        1. if the input question is translated to the SQL query, return the SQL query
        2. otherwise, return a confusion span in the question that caused the input
           to be untranslatable.
    """
    data_dir = 'data/'

    if args.demo_db is None:
        print('Error: must specify a database name to proceed')
        return
    else:
        db_name = args.demo_db
    db_path = os.path.join(args.db_dir, db_name, '{}.sqlite'.format(db_name))

    schema = SchemaGraph(db_name, db_path=db_path)

    if db_name == 'covid_19':
        in_csv = os.path.join(data_dir, db_name, '{}.csv'.format(db_name))
        in_type = os.path.join(data_dir, db_name, '{}.types'.format(db_name))
        schema.load_data_from_csv_file(in_csv, in_type)
    else:
        # TODO: currently the demo is configured for the Spider dataset.
        import json
        in_json = os.path.join(args.data_dir, 'tables.json')
        with open(in_json) as f:
            tables = json.load(f)
        for table in tables:
            if table['db_id'] == db_name:
                break
        schema.load_data_from_spider_json(table)
    schema.pretty_print()

    if args.ensemble_inference:
        t2sql = Text2SQLWrapper(args, cs_args, schema, ensemble_model_dirs=ensemble_model_dirs)
    else:
        t2sql = Text2SQLWrapper(args, cs_args, schema)

    sys.stdout.write('Enter a natural language question: ')
    sys.stdout.write('> ')
    sys.stdout.flush()
    text = sys.stdin.readline()

    while text:
        output = t2sql.process(text, schema.name)
        translatable = output['translatable']
        sql_query = output['sql_query']
        confusion_span = output['confuse_span']
        replacement_span = output['replace_span']

        print('Translatable: {}'.format(translatable))
        print('SQL: {}'.format(sql_query))
        print('Confusion span: {}'.format(confusion_span))
        print('Replacement span: {}'.format(replacement_span))

        sys.stdout.flush()
        sys.stdout.write('\nEnter a natural language question: ')
        sys.stdout.write('> ')
        text = sys.stdin.readline()
def load_schema_graphs(args):
    """
    Load database schema as a graph.
    """
    dataset_name = args.dataset_name
    if dataset_name in ['spider', 'spider_ut']:
        return load_schema_graphs_spider(args.data_dir, dataset_name,
                                         db_dir=args.db_dir,
                                         augment_with_wikisql=args.augment_with_wikisql)
    if dataset_name == 'wikisql':
        return load_schema_graphs_wikisql(args.data_dir)

    in_csv = os.path.join(args.data_dir, '{}-schema.csv'.format(dataset_name))
    schema_graphs = SchemaGraphs()
    schema_graph = SchemaGraph(dataset_name)
    schema_graph.load_data_from_finegan_dollak_csv_file(in_csv)
    schema_graphs.index_schema_graph(schema_graph)
    return schema_graphs
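# Illustrative usage sketch (not part of the original module): load_schema_graphs
# only needs an args object exposing the attributes referenced above
# (dataset_name, data_dir, db_dir, augment_with_wikisql). The directory paths
# below are placeholders, not paths from the original source.
#
#   from argparse import Namespace
#
#   args = Namespace(dataset_name='spider',
#                    data_dir='data/spider',
#                    db_dir='data/spider/database',
#                    augment_with_wikisql=False)
#   schema_graphs = load_schema_graphs(args)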
def load_schema_graphs_ask_data(data_dir):
    """
    Load the schemas of the Ask Data demo tables from their raw CSV files.
    """
    datasets = [
        'airbnb_san_francisco',
        'airbnb_seattle',
        'sports_salaries',
        'wines'
    ]
    schema_graphs = SchemaGraphs()
    for dataset in datasets:
        in_csv = os.path.join(data_dir, 'raw/{}.csv'.format(dataset))
        schema_graph = SchemaGraph(dataset)
        schema_graph.load_data_from_csv_file(in_csv)
        schema_graphs.index_schema_graph(schema_graph)
        schema_graph.pretty_print()
    print('{} schema graphs loaded'.format(schema_graphs.size))
def load_schema_graphs_spider(data_dir, dataset_name, db_dir=None, augment_with_wikisql=False):
    """
    Load indexed database schema.
    """
    in_json = os.path.join(data_dir, 'tables.json')
    schema_graphs = SchemaGraphs()
    with open(in_json) as f:
        content = json.load(f)
    for db_content in content:
        db_id = db_content['db_id']
        if dataset_name == 'spider':
            db_path = os.path.join(db_dir, db_id, '{}.sqlite'.format(db_id)) if db_dir else None
        else:
            # Strip a trailing "_m<digits>" suffix from the database id, if present,
            # to recover the base database directory name.
            db_id_parts = db_id.rsplit('_', 1)
            if len(db_id_parts) > 1:
                m_suffix_pattern = re.compile(r'm\d+')
                m_suffix = db_id_parts[1]
                if re.fullmatch(m_suffix_pattern, m_suffix):
                    db_base_id = db_id_parts[0]
                else:
                    db_base_id = db_id
            else:
                db_base_id = db_id_parts[0]
            db_path = os.path.join(db_dir, db_base_id, '{}.sqlite'.format(db_base_id)) if db_dir else None
        schema_graph = SchemaGraph(db_id, db_path)
        if db_dir is not None:
            schema_graph.compute_field_picklist()
        schema_graph.load_data_from_spider_json(db_content)
        schema_graphs.index_schema_graph(schema_graph)
    print('{} schema graphs loaded'.format(schema_graphs.size))

    if augment_with_wikisql:
        parent_dir = os.path.dirname(data_dir)
        wikisql_dir = os.path.join(parent_dir, 'wikisql1.1')
        wikisql_schema_graphs = load_schema_graphs_wikisql(wikisql_dir)
        for db_id in range(wikisql_schema_graphs.size):
            schema_graph = wikisql_schema_graphs.get_schema(db_id)
            schema_graphs.index_schema_graph(schema_graph)
        print('{} schema graphs loaded (+wikisql)'.format(schema_graphs.size))

    return schema_graphs
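# Illustrative sketch (an assumption, not taken from the original source): once
# loaded, individual schema graphs can be retrieved from the SchemaGraphs
# collection by index, as the WikiSQL augmentation loop above does. The paths
# are placeholders.
#
#   schema_graphs = load_schema_graphs_spider('data/spider', 'spider',
#                                             db_dir='data/spider/database')
#   for i in range(schema_graphs.size):
#       schema_graphs.get_schema(i).pretty_print()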
        table_values=table_values,
        execute_result=execute_result,
        sql=sql)


if __name__ == '__main__':
    with torch.no_grad():
        args.model_id = utils.model_index[args.model]
        get_model_dir(args)
        t2sql = Text2SQLWrapper(args, cs_args, None)

        schemas = {}
        # Load the Spider table metadata once, then register every database
        # exposed by the web form with the Text2SQL wrapper.
        import json
        in_json = os.path.join(args.data_dir, 'tables.json')
        with open(in_json) as f:
            tables = json.load(f)
        for db_name in SQLForm.database.kwargs['choices']:
            db_path = os.path.join(args.db_dir, db_name, '{}.sqlite'.format(db_name))
            schema = SchemaGraph(db_name, db_path=db_path)
            for table in tables:
                if table['db_id'] == db_name:
                    break
            schema.load_data_from_spider_json(table)
            t2sql.add_schema(schema)
            schemas[db_name] = schema

        app.run(host='0.0.0.0', port=8080)