Example #1
    def __init__(self, duorat: DuoratAPI, db_path: str, schema_path: Optional[str]):
        """Bind a DuoratAPI instance to a single database.

        Accepts either a ``.sqlite`` database or a ``.csv`` file (the
        latter is converted to sqlite on the fly).  The Spider-style
        schema is taken from ``schema_path`` when given, otherwise it is
        dumped directly from the database; either way it is then
        preprocessed into an ``SQLSchema``.
        """
        self.duorat = duorat
        self.db_path = db_path

        # Normalize the input: convert CSV to sqlite, reject anything else.
        if self.db_path.endswith(".csv"):
            self.db_path = convert_csv_to_sqlite(self.db_path)
        elif not self.db_path.endswith(".sqlite"):
            raise ValueError("expected either .sqlite or .csv file")

        # Obtain the schema, either from an explicit file or from the db.
        if schema_path:
            schemas, _ = load_tables([schema_path])
            # Exactly one schema is expected in the provided file.
            if len(schemas) != 1:
                raise ValueError()
            self.schema: Dict = next(iter(schemas.values()))
        else:
            # No schema file: derive a raw dict schema from the database,
            # then refine it into a SpiderSchema.
            raw_schema: Dict = dump_db_json_schema(self.db_path, "")
            self.schema: SpiderSchema = schema_dict_to_spider_schema(
                refine_schema_names(raw_schema)
            )

        self.preprocessed_schema: SQLSchema = preprocess_schema_uncached(
            schema=self.schema,
            db_path=self.db_path,
            tokenize=self.duorat.preproc._schema_tokenize,
        )
Example #2
    def validate_item(
            self, item: SpiderItem,
            section: str) -> Tuple[bool, Optional[AbstractSyntaxTree]]:
        """Check that an item's SQL can be parsed into an AST.

        Returns ``(True, ast)`` when parsing succeeds.  On a parse
        failure, items from a "train" section are tolerated and reported
        as ``(False, None)`` so the caller can skip them; for any other
        section the error is re-raised.
        """
        # Lazily cache the preprocessed schema, keyed by database id.
        if item.spider_schema.db_id not in self.sql_schemas:
            self.sql_schemas[
                item.spider_schema.db_id] = preprocess_schema_uncached(
                    schema=item.spider_schema,
                    db_path=item.db_path,
                    tokenize=self._schema_tokenize,
                )

        try:
            if isinstance(item, SpiderItem) and isinstance(
                    self.transition_system, SpiderTransitionSystem):
                asdl_ast = self.transition_system.surface_code_to_ast(
                    code=item.spider_sql)
            else:
                raise NotImplementedError
            return True, asdl_ast
        except Exception as e:
            # BUG FIX: the original handler raised in *both* branches and
            # left `return True, None` unreachable after the raise, so a
            # single unparsable training example aborted preprocessing.
            # Tolerate failures in training data; re-raise everywhere else.
            if "train" not in section:
                raise e
            return False, None
Example #3
    # Load the dataset configuration (jsonnet), injecting the data prefix
    # as a top-level argument.  Only the Spider dataset is supported here.
    data_config = json.loads(
        _jsonnet.evaluate_file(args.data_config,
                               tla_codes={'prefix': '"data/"'}))
    if data_config['name'] != 'spider':
        raise ValueError()
    # 'name' is config metadata, not a SpiderDataset constructor argument.
    del data_config['name']
    if args.questions:
        # Optional override: evaluate on a user-supplied questions file
        # instead of the paths from the config.
        data_config['paths'] = [args.questions]
    dataset = SpiderDataset(**data_config)

    # Preprocess every database schema once up front, keyed by db_id,
    # so per-item inference below can reuse them.
    sql_schemas = {}
    for db_id in dataset.schemas:
        spider_schema = dataset.schemas[db_id]
        sql_schemas[db_id] = preprocess_schema_uncached(
            schema=spider_schema,
            db_path=dataset.get_db_path(db_id),
            tokenize=api.preproc._schema_tokenize,
        )

    # Start from a clean output file if a previous run left one behind.
    if args.output_spider and os.path.exists(args.output_spider):
        os.remove(args.output_spider)

    output_items = []
    # Run inference for each question and print the intermediate
    # SLML representation alongside the predicted SQL query.
    for item in tqdm.tqdm(dataset):
        db_id = item.spider_schema.db_id
        result = api.infer_query(item.question, item.spider_schema,
                                 sql_schemas[db_id])
        print("QUESTION:", item.question)
        print("SLML:")
        print(pretty_format_slml(result['slml_question']))
        print("PREDICTION:", result['query'])