import json
import os
import re

# The following project-internal names are assumed to be importable from the
# surrounding codebase: SchemaGraph, SchemaGraphs, WikiSQLSchemaGraph,
# EncoderDecoderLFramework, data_loader, tok, load_confusion_span_detector,
# load_semantic_parser.


def load_schema_graphs_ask_data(data_dir):
    """
    Load the AskData demo tables (one schema graph per CSV file).
    """
    datasets = [
        'airbnb_san_francisco',
        'airbnb_seattle',
        'sports_salaries',
        'wines'
    ]
    schema_graphs = SchemaGraphs()
    for dataset in datasets:
        in_csv = os.path.join(data_dir, 'raw/{}.csv'.format(dataset))
        schema_graph = SchemaGraph(dataset)
        schema_graph.load_data_from_csv_file(in_csv)
        schema_graphs.index_schema_graph(schema_graph)
        schema_graph.pretty_print()
    print('{} schema graphs loaded'.format(schema_graphs.size))
    # Return the index so callers can use it, matching the other loaders.
    return schema_graphs
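
# Usage sketch (the data_dir value below is hypothetical): point data_dir at a
# folder whose raw/ subdirectory holds the four CSVs listed above.
#
#   schema_graphs = load_schema_graphs_ask_data('data/ask_data')
#   schema_graph = schema_graphs.get_schema(0)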
def load_schema_graphs(args):
    """
    Load database schema as a graph.
    """
    dataset_name = args.dataset_name
    if dataset_name in ['spider', 'spider_ut']:
        return load_schema_graphs_spider(
            args.data_dir, dataset_name, db_dir=args.db_dir,
            augment_with_wikisql=args.augment_with_wikisql)
    if dataset_name == 'wikisql':
        return load_schema_graphs_wikisql(args.data_dir)

    # Fall back to a single-schema CSV in the Finegan-Dollak et al. format.
    in_csv = os.path.join(args.data_dir, '{}-schema.csv'.format(dataset_name))
    schema_graphs = SchemaGraphs()
    schema_graph = SchemaGraph(dataset_name)
    schema_graph.load_data_from_finegan_dollak_csv_file(in_csv)
    schema_graphs.index_schema_graph(schema_graph)
    return schema_graphs
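
# Usage sketch: `args` is the experiment configuration namespace used across
# the repo; only the attributes read above are required. Values below are
# hypothetical.
#
#   from types import SimpleNamespace
#   args = SimpleNamespace(dataset_name='spider',
#                          data_dir='data/spider',
#                          db_dir='data/spider/database',
#                          augment_with_wikisql=False)
#   schema_graphs = load_schema_graphs(args)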
def load_schema_graphs_wikisql(data_dir, splits=('train', 'dev', 'test')):
    schema_graphs = SchemaGraphs()
    for split in splits:
        in_jsonl = os.path.join(data_dir, '{}.tables.jsonl'.format(split))
        db_count = 0
        with open(in_jsonl) as f:
            for line in f:
                table = json.loads(line.strip())
                db_name = table['id']
                schema_graph = WikiSQLSchemaGraph(db_name, table, caseless=False)
                schema_graph.id = table['id']
                schema_graph.load_data_from_wikisql_json(table)
                schema_graph.compute_field_picklist(table)
                schema_graphs.index_schema_graph(schema_graph)
                db_count += 1
        print('{} databases in {}'.format(db_count, split))
    print('{} databases loaded in total'.format(schema_graphs.size))
    return schema_graphs
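
# Input format sketch: each line of '{split}.tables.jsonl' is one JSON-encoded
# WikiSQL table. A representative record looks roughly like this (abbreviated):
#
#   {"id": "1-1000181-1",
#    "header": ["State/territory", "Notes"],
#    "types": ["text", "text"],
#    "rows": [["Australian Capital Territory", "..."], ...]}
#
# compute_field_picklist presumably builds per-column value lists from 'rows'.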
def __init__(self, args, cs_args, schema, ensemble_model_dirs=None):
    self.args = args
    self.text_tokenize, _, _, self.tu = tok.get_tokenizers(args)

    # Vocabulary
    self.vocabs = data_loader.load_vocabs(args)

    # Confusion span detector
    self.confusion_span_detector = load_confusion_span_detector(cs_args)

    # Text-to-SQL model
    self.semantic_parsers = []
    self.model_ensemble = None
    if ensemble_model_dirs is None:
        sp = load_semantic_parser(args)
        sp.schema_graphs = SchemaGraphs()
        self.semantic_parsers.append(sp)
    else:
        sps = [EncoderDecoderLFramework(args) for _ in ensemble_model_dirs]
        for i, model_dir in enumerate(ensemble_model_dirs):
            checkpoint_path = os.path.join(model_dir, 'model-best.16.tar')
            sps[i].schema_graphs = SchemaGraphs()
            sps[i].load_checkpoint(checkpoint_path)
            sps[i].cuda()
            sps[i].eval()
        self.semantic_parsers = sps
        self.model_ensemble = [sp.mdl for sp in sps]

    if schema is not None:
        self.add_schema(schema)

    # When generating SQL in execution order, cache reordered SQLs to save time
    if args.process_sql_in_execution_order:
        self.pred_restored_cache = \
            self.semantic_parsers[0].load_pred_restored_cache()
    else:
        self.pred_restored_cache = None
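
# Construction sketch (directory names hypothetical; the enclosing class is
# referred to here as Text2SQLWrapper purely for illustration): passing
# ensemble_model_dirs loads one checkpoint per directory and enables ensemble
# decoding; passing None falls back to a single parser.
#
#   wrapper = Text2SQLWrapper(args, cs_args, schema=None,
#                             ensemble_model_dirs=['runs/model_a',
#                                                  'runs/model_b'])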
def load_schema_graphs_spider(data_dir, dataset_name, db_dir=None,
                              augment_with_wikisql=False):
    """
    Load indexed database schema.
    """
    in_json = os.path.join(data_dir, 'tables.json')
    schema_graphs = SchemaGraphs()
    with open(in_json) as f:
        content = json.load(f)
    for db_content in content:
        db_id = db_content['db_id']
        if dataset_name == 'spider':
            db_path = os.path.join(
                db_dir, db_id, '{}.sqlite'.format(db_id)) if db_dir else None
        else:
            # Variant ids may carry a trailing '_m<digits>' marker; strip it
            # to locate the base database file.
            db_id_parts = db_id.rsplit('_', 1)
            if len(db_id_parts) > 1:
                m_suffix_pattern = re.compile(r'm\d+')
                m_suffix = db_id_parts[1]
                if re.fullmatch(m_suffix_pattern, m_suffix):
                    db_base_id = db_id_parts[0]
                else:
                    db_base_id = db_id
            else:
                db_base_id = db_id_parts[0]
            db_path = os.path.join(
                db_dir, db_base_id,
                '{}.sqlite'.format(db_base_id)) if db_dir else None
        schema_graph = SchemaGraph(db_id, db_path)
        if db_dir is not None:
            schema_graph.compute_field_picklist()
        schema_graph.load_data_from_spider_json(db_content)
        schema_graphs.index_schema_graph(schema_graph)
    print('{} schema graphs loaded'.format(schema_graphs.size))

    if augment_with_wikisql:
        parent_dir = os.path.dirname(data_dir)
        wikisql_dir = os.path.join(parent_dir, 'wikisql1.1')
        wikisql_schema_graphs = load_schema_graphs_wikisql(wikisql_dir)
        for db_id in range(wikisql_schema_graphs.size):
            schema_graph = wikisql_schema_graphs.get_schema(db_id)
            schema_graphs.index_schema_graph(schema_graph)
        print('{} schema graphs loaded (+wikisql)'.format(schema_graphs.size))

    return schema_graphs
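
# Suffix handling sketch for the non-'spider' branch above: a trailing
# '_m<digits>' marker is stripped when resolving the on-disk database, while
# other ids resolve to themselves. E.g. (ids hypothetical):
#
#   'concert_singer_m1' -> db_base_id = 'concert_singer'
#   'concert_singer'    -> db_base_id = 'concert_singer'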