Ejemplo n.º 1
0
def handle_request0(request):
    """Answer a natural-language question about one uploaded CSV table.

    Expects a ``csv`` file and a ``q`` form field. All intermediate files
    share the prefix ``base = <table_id>_<uuid4>``:

    - the CSV is passed to ``add_csv.csv_stream_to_sqlite`` (``<base>.db``)
      and ``add_csv.csv_stream_to_json`` (``<base>.tables.jsonl``);
    - the question is passed to ``add_question.question_to_json``
      (``<base>.jsonl``);
    - the ``annotate_ws`` annotation is appended to ``<base>_tok.jsonl``.

    ``run_split(base)`` is then called and its return value becomes the
    response body.

    Without a ``debug`` form field:
    - the intermediate files (plus ``results_<base>.jsonl``) are removed,
      whether or not processing succeeded;
    - if the output has a ``result`` entry, only ``result[0]`` is returned,
      with ``query``/``nlu``/``table_id`` dropped and ``sql_with_params``
      split into ``sql`` and ``params``.

    With ``debug``, the files are kept, the output is returned as-is, and
    ``base`` is echoed back in the body.

    Args:
        request: the incoming request; only ``form`` and ``files`` are read.

    Returns:
        ``(jsonify(message), code)`` with code 200 on success, or
        ``({"error": str(e)}, 500)`` if any step raises.
    """
    debug = 'debug' in request.form
    base = ""
    try:
        if 'csv' not in request.files:
            raise Exception('please include a csv file')
        if 'q' not in request.form:
            raise Exception(
                'please include a q parameter with a question in it')
        upload = request.files['csv']
        q = request.form['q']
        # Collapse runs of non-word characters so the name is usable both as
        # a table id and as a file-name prefix.
        table_id = re.sub(r'\W+', '_', os.path.splitext(upload.filename)[0])

        # it would be easy to do all this in memory but I'm lazy
        stream = io.StringIO(upload.stream.read().decode("UTF8"),
                             newline=None)
        base = table_id + "_" + str(uuid.uuid4())
        try:
            add_csv.csv_stream_to_sqlite(table_id, stream, base + '.db')
            # The same text stream is consumed a second time; rewind it.
            stream.seek(0)
            record = add_csv.csv_stream_to_json(table_id, stream,
                                                base + '.tables.jsonl')
            add_question.question_to_json(table_id, q, base + '.jsonl')
            annotation = annotate_ws.annotate_example_ws(
                add_question.encode_question(table_id, q), record)
            with open(base + '_tok.jsonl', 'a+') as fout:
                fout.write(json.dumps(annotation) + '\n')

            message = run_split(base)
        finally:
            if not debug:
                # Runs on failure too, so a half-finished request does not
                # leave files behind. A file whose step never ran simply does
                # not exist, which is not an error here.
                for path in (base + '.db', base + '.jsonl',
                             base + '.tables.jsonl', base + '_tok.jsonl',
                             'results_' + base + '.jsonl'):
                    try:
                        os.remove(path)
                    except FileNotFoundError:
                        pass
        code = 200

        if not debug and 'result' in message:
            # Flatten the first result into the public response shape.
            message = message['result'][0]
            del message['query']
            del message['nlu']
            del message['table_id']
            sql_with_params = message.pop('sql_with_params')
            message['params'] = sql_with_params[1]
            message['sql'] = sql_with_params[0]

    except Exception as e:
        message = {"error": str(e)}
        code = 500

    if debug:
        message['base'] = base

    return jsonify(message), code
Ejemplo n.º 2
0
def _purge_cached_cases(cache_root='/cache'):
    """Delete every ``case_*`` entry directly under *cache_root*.

    Mirrors ``rm -rf <cache_root>/case_*`` without spawning a shell:
    directories are removed recursively, other entries are unlinked, and
    errors are ignored.
    """
    import glob

    for path in glob.glob(os.path.join(cache_root, 'case_*')):
        try:
            if os.path.isdir(path) and not os.path.islink(path):
                shutil.rmtree(path, ignore_errors=True)
            else:
                os.remove(path)
        except OSError:
            pass


def handle_request0(request):
    """Answer a natural-language question against uploaded CSV/SQLite data.

    Expects a ``q`` form field and at least one of:
    - CSV uploads under ``csv`` (or ``csv[]``), each loaded into
      ``/cache/<key>/data/data.sqlite`` via ``add_csv.csv_stream_to_sqlite``;
    - a ``sqlite`` upload, saved to that same path. It is written after the
      CSVs, so it replaces whatever they produced.

    The handler then:
    1. builds ``tables.json`` with ``/spider/preprocess/get_tables.py``;
    2. copies the tables and database into the ``original/`` layout;
    3. runs the schema-linking / value-extraction / SemQL inference pipeline;
    4. converts the prediction to SQL and executes it on the database.

    Progress is printed to stdout.

    Args:
        request: the incoming request; only ``form`` and ``files`` are read.

    Returns:
        ``(jsonify(message), code)``:
        - success: ``{"split": key, "result": {"sql": ..., "answer": ...}}``
          with 200;
        - failure: ``{"error": str(e)}`` with 500.
        With a ``debug`` form field, ``"base": ""`` is added to either body.
    """
    debug = 'debug' in request.form
    # NOTE(review): ``base`` is never assigned anything else in this handler,
    # so debug responses always contain ``"base": ""``; kept for response
    # compatibility.
    base = ""
    try:
        # Accept the array-style ``csv[]`` field name as well as ``csv``.
        csv_key = 'csv' if 'csv' in request.files else 'csv[]'
        print(request.files)
        if csv_key not in request.files and 'sqlite' not in request.files:
            raise Exception('please include a csv file or sqlite file')
        if 'q' not in request.form:
            raise Exception(
                'please include a q parameter with a question in it')
        csvs = request.files.getlist(csv_key)
        sqlite_file = request.files.get('sqlite')
        q = request.form['q']

        # brute force removal of any old requests
        # NOTE(review): this also removes the working directory of any other
        # request that is still being processed concurrently.
        if not TRIAL_RUN:
            _purge_cached_cases('/cache')
        key = "case_" + str(uuid.uuid4())
        data_dir = os.path.join('/cache', key)
        db_file = os.path.join(data_dir, 'data', 'data.sqlite')
        os.makedirs(os.path.join(data_dir, 'data'), exist_ok=True)
        os.makedirs(os.path.join(data_dir, 'original', 'database', 'data'),
                    exist_ok=True)
        print("Key", key)
        for upload in csvs:
            print("Working on", upload)
            table_id = re.sub(r'\W+', '_',
                              os.path.splitext(upload.filename)[0])
            stream = io.StringIO(upload.stream.read().decode("UTF8"),
                                 newline=None)
            add_csv.csv_stream_to_sqlite(table_id, stream, db_file)
        if sqlite_file:
            print("Working on", sqlite_file)
            sqlite_file.save(db_file)
        question_file = os.path.join(data_dir, 'question.json')
        tables_file = os.path.join(data_dir, 'tables.json')
        dummy_file = os.path.join(data_dir, 'dummy.json')
        add_question.question_to_json('data', q, question_file)

        # NOTE(review): the uploaded database lives under the id 'data'
        # (directory name above and question_to_json), while db_id comes from
        # args.database. Presumably args.database == 'data'; confirm.
        row = {
            'question': q,
            'query': 'DUMMY',
            'db_id': args.database,
            'question_toks': _tokenize_question(tokenizer, q)
        }

        print(
            colored(
                f"question has been tokenized to : { row['question_toks'] }",
                'cyan',
                attrs=['bold']))

        with open(dummy_file, 'w') as fout:
            fout.write('[]\n')

        # check=True turns a failed schema extraction into an immediate,
        # explicit error instead of a missing/stale tables.json later on.
        subprocess.run([
            "python", "/spider/preprocess/get_tables.py", data_dir,
            tables_file, dummy_file
        ], check=True)

        # valuenet expects different setup to irnet
        shutil.copyfile(tables_file,
                        os.path.join(data_dir, 'original', 'tables.json'))
        database_path = os.path.join(data_dir, 'original', 'database', 'data',
                                     'data.sqlite')
        shutil.copyfile(db_file, database_path)

        schemas_raw, schemas_dict = spider_utils.load_schema(data_dir)

        data, _ = merge_data_with_schema(schemas_raw, [row])

        pre_processed_data = process_datas(data, related_to_concept,
                                           is_a_concept)

        pre_processed_with_values = _pre_process_values(pre_processed_data[0])

        # NOTE(review): relies on the preprocessing steps storing 'values' on
        # this same ``row`` dict (row itself never sets it); confirm.
        print(
            f"we found the following potential values in the question: {row['values']}"
        )

        prediction, example = _inference_semql(pre_processed_with_values,
                                               schemas_dict, model)

        print(
            f"Results from schema linking (question token types): {example.src_sent}"
        )
        print(
            f"Results from schema linking (column types): {example.col_hot_type}"
        )

        print(
            colored(f"Predicted SemQL-Tree: {prediction['model_result']}",
                    'magenta',
                    attrs=['bold']))
        print()
        sql = _semql_to_sql(prediction, schemas_dict)

        print(colored(f"Transformed to SQL: {sql}", 'cyan', attrs=['bold']))
        print()
        result = _execute_query(sql, database_path)

        print(f"Executed on the database '{args.database}'. Results: ")
        for result_row in result:
            print(colored(result_row, 'green'))

        message = {
            "split": key,
            "result": {
                "sql": sql.strip(),
                "answer": result
            }
        }
        code = 200
    except Exception as e:
        message = {"error": str(e)}
        code = 500
    if debug:
        message['base'] = base
    return jsonify(message), code