示例#1
0
def transform_semQL_to_sql(schemas, sem_ql_prediction, output_dir):
    """Convert each SemQL prediction to SQL and write the results to disk.

    Two files are (re)created inside ``output_dir``:

    * ``output.txt`` -- one line per example containing the first element of
      the value returned by ``transform``.
    * ``ground_truth.txt`` -- one line per example with the example's
      ``query``, ``db_id`` and ``question`` fields, separated by tabs.

    When the regular transformation of an example raises, the example is
    transformed again with a fixed fallback ``origin`` so that both files
    still get exactly one line for it; the traceback and the example's
    fields are printed to stdout.

    Args:
        schemas: Mapping indexed by ``db_id``; the looked-up value is passed
            to ``transform`` as the schema of the example.
        sem_ql_prediction: Sequence of example dicts. The keys ``db_id``,
            ``query`` and ``question`` are read here. It is handed to
            ``alter_column0`` first (presumably modified in place, since the
            return value is not used -- verify against that helper).
        output_dir: Directory in which both output files are written.

    Returns:
        Tuple ``(count, exception_count)``: the number of examples written
        and how many of them needed the fallback transformation.

    Raises:
        Exceptions raised by the fallback transformation or while writing
        the files are not caught and propagate to the caller.
    """
    # TODO: find out if this adds any benefit for the trained models. If we run it with the ground truth (so no prediction, just SQL -> SemQL -> SQL) it is even slightly better without it.
    # alter_not_in(sem_ql_prediction, schemas=schemas)
    # alter_inter(sem_ql_prediction)
    alter_column0(sem_ql_prediction)

    # This origin seems to be the fallback query. Not sure how it was chosen;
    # most probably it is just a dummy query so that every example gets a result.
    fallback_origin = 'Root1(3) Root(5) Sel(0) N(0) A(3) C(0) T(0)'

    count = 0
    exception_count = 0
    output_path = os.path.join(output_dir, 'output.txt')
    ground_truth_path = os.path.join(output_dir, 'ground_truth.txt')

    with open(output_path, 'w', encoding='utf8') as d, \
            open(ground_truth_path, 'w', encoding='utf8') as g:
        for example in sem_ql_prediction:
            # Only the transformation (and building the SQL line from its
            # result) is guarded. The writes happen once, below, so a failure
            # can never leave a duplicated line in output.txt.
            try:
                result = transform(example, schemas[example['db_id']])
                sql_line = result[0] + '\n'
            except Exception:
                # Deliberately broad: any failure of the regular
                # transformation falls back to the dummy query.
                result = transform(example,
                                   schemas[example['db_id']],
                                   origin=fallback_origin)
                sql_line = result[0] + '\n'
                exception_count += 1
                print('Exception')
                print(traceback.format_exc())
                print(example['question'])
                print(example['query'])
                print(example['db_id'])
                print('===\n\n')

            d.write(sql_line)
            g.write("%s\t%s\t%s\n" % (example['query'],
                                      example["db_id"],
                                      example["question"]))
            count += 1

    return count, exception_count
示例#2
0
def _semql_to_sql(prediction, schemas):
    """Transform a single SemQL prediction and return the first result element.

    The prediction is wrapped in a one-element list for ``alter_column0``
    (its return value is ignored) before being handed to ``transform``
    together with the schema stored under the prediction's ``db_id``.
    """
    alter_column0([prediction])
    # Look the schema up only after alter_column0 has run, as before.
    schema = schemas[prediction['db_id']]
    return transform(prediction, schema)[0]