def __init__(self, duorat: DuoratAPI, db_path: str, schema_path: Optional[str]):
    """Bind a DuoRAT model to one concrete database.

    Args:
        duorat: loaded DuoRAT model API; its preprocessor supplies the
            schema tokenizer used below.
        db_path: path to a ``.sqlite`` database, or a ``.csv`` file which
            is first converted into a sqlite database.
        schema_path: optional Spider-format tables file; when omitted the
            schema is introspected directly from the database.

    Raises:
        ValueError: if ``db_path`` has an unsupported extension, or if
            ``schema_path`` does not contain exactly one schema.
    """
    self.duorat = duorat
    self.db_path = db_path

    # Normalize the database input: sqlite passes through, CSV is converted.
    if self.db_path.endswith(".sqlite"):
        pass
    elif self.db_path.endswith(".csv"):
        self.db_path = convert_csv_to_sqlite(self.db_path)
    else:
        raise ValueError("expected either .sqlite or .csv file")

    # Get SQLSchema: prefer an explicit tables file when one is given...
    if schema_path:
        schemas, _ = load_tables([schema_path])
        if len(schemas) != 1:
            raise ValueError()
        self.schema: Dict = next(iter(schemas.values()))
    else:
        # ...otherwise dump the schema straight from the sqlite file.
        self.schema: Dict = dump_db_json_schema(self.db_path, "")

    # Clean up names and lift the raw dict into a SpiderSchema.
    self.schema: SpiderSchema = schema_dict_to_spider_schema(
        refine_schema_names(self.schema)
    )
    # Precompute the tokenized SQLSchema consumed by the model.
    self.preprocessed_schema: SQLSchema = preprocess_schema_uncached(
        schema=self.schema,
        db_path=self.db_path,
        tokenize=self.duorat.preproc._schema_tokenize,
    )
def validate_item(
    self, item: SpiderItem, section: str
) -> Tuple[bool, Optional[AbstractSyntaxTree]]:
    """Parse the item's SQL into an AST, validating it for preprocessing.

    The schema for the item's database is preprocessed lazily and cached
    in ``self.sql_schemas``.

    Args:
        item: the Spider example whose SQL should be converted.
        section: dataset section name (e.g. "train", "val").

    Returns:
        ``(True, ast)`` on a successful parse; ``(True, None)`` when the
        parse fails on a training section (the item is kept but carries
        no AST).

    Raises:
        NotImplementedError: for unsupported item/transition-system pairs.
        Exception: re-raises the parse failure on non-training sections.
    """
    # Lazily preprocess and cache the schema for this database.
    if item.spider_schema.db_id not in self.sql_schemas:
        self.sql_schemas[item.spider_schema.db_id] = preprocess_schema_uncached(
            schema=item.spider_schema,
            db_path=item.db_path,
            tokenize=self._schema_tokenize,
        )
    try:
        if isinstance(item, SpiderItem) and isinstance(
            self.transition_system, SpiderTransitionSystem
        ):
            asdl_ast = self.transition_system.surface_code_to_ast(
                code=item.spider_sql
            )
        else:
            raise NotImplementedError
        return True, asdl_ast
    except Exception:
        # BUG FIX: the original raised `e` on BOTH branches of this `if`,
        # leaving `return True, None` as unreachable dead code — failing
        # items crashed preprocessing even on training sections. Intended
        # behavior: fail hard outside training, tolerate parse failures on
        # training sections. Bare `raise` preserves the full traceback.
        if "train" not in section:
            raise
        return True, None
# Load the data config (jsonnet) and ensure it describes the Spider dataset.
data_config = json.loads(
    _jsonnet.evaluate_file(args.data_config, tla_codes={'prefix': '"data/"'})
)
if data_config['name'] != 'spider':
    raise ValueError()
del data_config['name']

# Optionally point the dataset at a caller-supplied questions file.
if args.questions:
    data_config['paths'] = [args.questions]
dataset = SpiderDataset(**data_config)

# Preprocess every database schema once, keyed by db_id.
sql_schemas = {
    db_id: preprocess_schema_uncached(
        schema=dataset.schemas[db_id],
        db_path=dataset.get_db_path(db_id),
        tokenize=api.preproc._schema_tokenize,
    )
    for db_id in dataset.schemas
}

# Start the output file from scratch if a previous run left one behind.
if args.output_spider and os.path.exists(args.output_spider):
    os.remove(args.output_spider)

output_items = []
# Run inference over the dataset, printing each prediction as it is made.
for item in tqdm.tqdm(dataset):
    db_id = item.spider_schema.db_id
    result = api.infer_query(item.question, item.spider_schema, sql_schemas[db_id])
    print("QUESTION:", item.question)
    print("SLML:")
    print(pretty_format_slml(result['slml_question']))
    print("PREDICTION:", result['query'])