Example #1
def process_examples(fexample, ftable, fout):
    print('annotating {}'.format(fexample))
    with open(fexample) as fs, open(ftable) as ft, open(fout, 'wt') as fo:
        print('loading tables')
        tables = {}
        for line in tqdm(ft, total=count_lines(ftable)):
            d = json.loads(line)
            tables[d['id']] = d
        print('loading examples')
        n_written = 0
        for line in tqdm(fs, total=count_lines(fexample)):
            d = json.loads(line)
            a = annotate_example(d, tables[d['table_id']])
            if not is_valid_example(a):
                # raise Exception(str(a))
                print('Invalid example: {}'.format(str(a)))
                continue

            gold = Query.from_tokenized_dict(a['query'])
            reconstruct = Query.from_sequence(a['seq_output'], a['table'], lowercase=True)
            if gold.lower() != reconstruct.lower():
                raise Exception('Expected:\n{}\nGot:\n{}'.format(gold, reconstruct))
            fo.write(json.dumps(a) + '\n')
            n_written += 1
        print('wrote {} examples'.format(n_written))
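
Most of these snippets call a count_lines helper (to give tqdm a progress total) that this page never shows. A minimal sketch of what it presumably looks like:

def count_lines(fname):
    # count the lines in a file so tqdm can display a total
    with open(fname) as f:
        return sum(1 for _ in f)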
Example #2
    def generate_query(self, db, max_cond=4):
        max_cond = min(len(self.header), max_cond)
        # sample a select column
        sel_index = random.choice(list(range(len(self.header))))
        # sample where conditions
        query = Query(-1, Query.agg_ops.index(''))
        results = self.execute_query(db, query)
        condition_options = list(range(len(self.header)))
        condition_options.remove(sel_index)
        for i in range(max_cond):
            if not results:
                break
            cond_index = random.choice(condition_options)
            if self.types[cond_index] == 'text':
                cond_op = Query.cond_ops.index('=')
            else:
                cond_op = random.choice(list(range(len(Query.cond_ops))))
            cond_val = random.choice([r[cond_index] for r in results])
            query.conditions.append((cond_index, cond_op, cond_val))
            new_results = self.execute_query(db, query)
            if [r[sel_index] for r in new_results] != [r[sel_index]
                                                       for r in results]:
                condition_options.remove(cond_index)
                results = new_results
            else:
                query.conditions.pop()
        # sample an aggregation operation
        if self.types[sel_index] == 'text':
            query.agg_index = Query.agg_ops.index('')
        else:
            query.agg_index = random.choice(list(range(len(Query.agg_ops))))
        query.sel_index = sel_index
        results = self.execute_query(db, query)
        return query, results
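
The sampler above indexes into Query.agg_ops and Query.cond_ops. In the WikiSQL reference implementation these are class attributes along the following lines (an abridged sketch, not the full class):

class Query:
    agg_ops = ['', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG']
    cond_ops = ['=', '>', '<', 'OP']

    def __init__(self, sel_index, agg_index, conditions=tuple()):
        self.sel_index = sel_index
        self.agg_index = agg_index
        self.conditions = list(conditions)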
Example #3
def eval_one_qelos(db_file, pred_file, source_file):
    engine = DBEngine(db_file)
    exact_match = []
    with open(source_file) as fs, open(pred_file) as fp:
        grades = []
        for ls, lp in tqdm(zip(fs, fp), total=count_lines(source_file)):
            eg = json.loads(ls)
            ep = json.loads(lp)
            qg = Query.from_dict(eg['sql'])
            gold = engine.execute_query(eg['table_id'], qg, lower=True)
            qp = None
            try:
                qp = Query.from_dict(ep)
                pred = engine.execute_query(eg['table_id'], qp, lower=True)
            except Exception as e:
                pred = repr(e)
            correct = pred == gold
            match = qp == qg
            grades.append(correct)
            exact_match.append(match)
        result = {
            'ex_accuracy': sum(grades) / len(grades),
            'lf_accuracy': sum(exact_match) / len(exact_match),
        }
        return result
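
A hypothetical invocation of eval_one_qelos (the file names are placeholders, not from the source):

result = eval_one_qelos('data/dev.db', 'dev_pred.jsonl', 'data/dev.jsonl')
print(json.dumps(result, indent=2))  # {'ex_accuracy': ..., 'lf_accuracy': ...}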
Example #4
def hello():
    pred_json = json.loads(request.args.get('pred'))
    gold_json = json.loads(request.args.get('gold'))
    dataset = request.args.get('dataset')

    engine = DBEngine(os.path.join(DATABASE_PATH, "{}.db".format(dataset)))

    exact_match = []
    grades = []

    for lp, ls in tqdm(zip(pred_json, gold_json), total=len(gold_json)):
        eg = ls
        ep = lp
        qg = Query.from_dict(eg['sql'])
        gold = engine.execute_query(eg['table_id'], qg, lower=True)
        pred = ep['error']
        qp = None
        if not ep['error']:
            try:
                qp = Query.from_dict(ep['query'])
                pred = engine.execute_query(eg['table_id'], qp, lower=True)
            except Exception as e:
                pred = repr(e)
        correct = pred == gold
        match = qp == qg
        grades.append(correct)
        exact_match.append(match)

    ex_accuracy = sum(grades) / len(grades)
    lf_accuracy = sum(exact_match) / len(exact_match)
    return json.dumps({"ex_accuracy": ex_accuracy, "lf_accuracy": lf_accuracy})
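
hello() reads its inputs from Flask's request object, so it presumably sits behind a route. A minimal wiring sketch, assuming Flask (the route name and port are guesses):

from flask import Flask, request

app = Flask(__name__)
app.add_url_rule('/evaluate', 'evaluate', hello)  # hypothetical route

if __name__ == '__main__':
    app.run(port=5000)  # arbitrary port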
Example #5
    def eval(self, gold, sql_gold, engine):
        self.exception_raised = False  # reset this flag for a new evaluation
        if self.agg == gold['query']['agg']:
            self.correct['agg'] = 1

        if self.sel == gold['query']['sel']:
            self.correct['sel'] = 1

        op_list_pred = [op for col, op, span in self.cond]
        op_list_gold = [op for col, op, span in gold['query']['conds']]

        col_list_pred = [col for col, op, span in self.cond]
        col_list_gold = [col for col, op, span in gold['query']['conds']]

        q = gold['question']['words']
        span_list_pred = [' '.join(q[span[0]:span[1] + 1])
                          for col, op, span in self.cond]
        span_list_gold = [' '.join(span['words'])
                          for col, op, span in gold['query']['conds']]

        where_pred = list(zip(col_list_pred, op_list_pred, span_list_pred))
        where_gold = list(zip(col_list_gold, op_list_gold, span_list_gold))
        where_pred.sort()
        where_gold.sort()
        if (where_pred == where_gold
                and len(col_list_pred) == len(col_list_gold)
                and len(op_list_pred) == len(op_list_gold)
                and len(span_list_pred) == len(span_list_gold)):
            self.correct['where'] = 1

        if (len(col_list_pred) == len(col_list_gold)
                and [it[0] for it in where_pred] == [it[0] for it in where_gold]):
            self.correct['col'] = 1

        if (len(op_list_pred) == len(op_list_gold)
                and [it[1] for it in where_pred] == [it[1] for it in where_gold]):
            self.correct['lay'] = 1

        if (len(span_list_pred) == len(span_list_gold)
                and [it[2] for it in where_pred] == [it[2] for it in where_gold]):
            self.correct['span'] = 1

        if all((self.correct[it] == 1 for it in ('agg', 'sel', 'where'))):
            self.correct['all'] = 1

        # execution
        table_id = gold['table_id']
        ans_gold = engine.execute_query(
            table_id, Query.from_dict(sql_gold), lower=True)
        try:
            sql_pred = {'agg': self.agg,
                        'sel': self.sel,
                        'conds': self.recover_cond_to_gloss(gold)}
            ans_pred = engine.execute_query(
                table_id, Query.from_dict(sql_pred), lower=True)
        except Exception as e:
            self.exception_raised = True
            ans_pred = repr(e)
        if set(ans_gold) == set(ans_pred):
            self.correct['exe'] = 1
Example #6
def main(anno_file_name, col_headers, raw_args=None, verbose=True):
    parser = argparse.ArgumentParser(description='evaluate.py')
    opts.translate_opts(parser)
    opt = parser.parse_args(raw_args)
    torch.cuda.set_device(opt.gpu)
    opt.db_file = os.path.join(opt.data_path, '{}.db'.format(opt.split))
    opt.pre_word_vecs = os.path.join(opt.data_path, 'embedding')
    dummy_parser = argparse.ArgumentParser(description='train.py')
    opts.model_opts(dummy_parser)
    opts.train_opts(dummy_parser)
    dummy_opt = dummy_parser.parse_known_args([])[0]
    opt.anno = anno_file_name

    engine = DBEngine(opt.db_file)

    js_list = table.IO.read_anno_json(opt.anno)

    prev_best = (None, None)
    sql_query = []
    ans_pred = None  # keep ans_pred defined even if no model file matches
    for fn_model in glob.glob(opt.model_path):

        opt.model = fn_model

        translator = Translator(opt, dummy_opt.__dict__)
        data = table.IO.TableDataset(js_list, translator.fields, None, False)
        test_data = table.IO.OrderedIterator(dataset=data,
                                             device=opt.gpu,
                                             batch_size=opt.batch_size,
                                             train=False,
                                             sort=True,
                                             sort_within_batch=False)

        # inference
        r_list = []
        for batch in test_data:
            r_list += translator.translate(batch)
        r_list.sort(key=lambda x: x.idx)
        pred = r_list[-1]
        sql_pred = {
            'agg': pred.agg,
            'sel': pred.sel,
            'conds': pred.recover_cond_to_gloss(js_list[-1])
        }
        if verbose:
            print('\n sql_pred: ', sql_pred, '\n')
            print('\n col_headers: ', col_headers, '\n')
        sql_query = Query(sql_pred['sel'], sql_pred['agg'], sql_pred['conds'])
        try:
            ans_pred = engine.execute_query(js_list[-1]['table_id'],
                                            Query.from_dict(sql_pred),
                                            lower=True,
                                            verbose=verbose)
        except Exception as e:
            ans_pred = None
    return sql_query.get_complete_query(col_headers), ans_pred
Example #7
    def test_query_all(self):
        self._prepare_query_test()
        files_2_3 = self.tie_core.query(
            Query([QUERY_TAG_2, QUERY_TAG_3], MatchType.all))
        self.assertEqual(
            [os.path.abspath(QUERY_FILE_2)], files_2_3,
            "Query all with 2 tags did not find the expected files")
        files_3 = self.tie_core.query(Query([QUERY_TAG_3], MatchType.all))
        self.assertEqual(
            [os.path.abspath(QUERY_FILE_2),
             os.path.abspath(QUERY_FILE_3)], files_3,
            "Query all with 1 tag did not find the expected files")
        _clean_after_query_test()
Example #8
def _query(core, tags, match_type, frontend: Frontend):
    query = Query(tags, match_type)
    result_files = core.query(query)
    if len(result_files) == 0:
        frontend.show_message("No files found for your query!")
    else:
        print_out_list(result_files)
Example #9
    def test_query_all_no_match(self):
        self._prepare_query_test()
        result = self.tie_core.query(
            Query([QUERY_TAG_2, "some random nonexistent tag"],
                  MatchType.all))
        self.assertEqual([], result, "Query with no result")
        _clean_after_query_test()
Example #10
def test():
    # convert query dict to text (without correct column references)
    details = {"sel": 5, "conds": [[3, 0, "SOUTH AUSTRALIA"]], "agg": 0}
    test_str = Query(details["sel"], details["agg"], details["conds"])
    print(test_str)

    db = records.Database('sqlite:///data/train.db')
    conn = db.get_connection()

    # convert query dict to text with table reference (still does not give the correct columns)
    # because header is not supplied
    table = Table.from_db(conn, "1-1000181-1")
    print(table.query_str(test_str))

    # convert query dict to text with table reference after supplying headers
    table_data = {
        "id": "1-1000181-1",
        "header": [
            "State/territory", "Text/background colour", "Format",
            "Current slogan", "Current series", "Notes"
        ],
        "types": [],
        "rows": []
    }
    t = Table(table_data["id"], table_data["header"], table_data["types"],
              table_data["rows"])
    print(t.query_str(test_str))
Example #11
    def get_query_from_json(self, json_line):
        """Return the Table and Query objects for a JSON input line."""
        q = Query.from_dict(json_line["sql"])
        t_id = json_line["table_id"]
        table = self.table_map[t_id]
        t = Table("", table["header"], table["types"], table["rows"])
        return t, q
Example #12
def toQueryStr(file_name, table_arr, type=0, test_batch_size=1000):
    path = os.path.join(DATA_DIR, '{}.jsonl'.format(file_name))
    print(path)
    with open(path, 'r') as pf:
        data = pf.readlines()
    
    idxs = np.arange(len(data))
    data = np.array(data, dtype=object)  # np.object was removed in NumPy 1.24

    np.random.seed(0)   # set random seed so that random things are reproducible
    np.random.shuffle(idxs)
    data = data[idxs]
    batched_data = chunked(data, test_batch_size)

    print("start processing")
    examples = []
    for batch_idx, batch_data in enumerate(batched_data):
        if len(batch_data) < test_batch_size:
            break # the last batch is smaller than the others, exclude.
        for d_idx, d in enumerate(batch_data): 
            line = json.loads(str(d))  # the encoding kwarg was removed in Python 3.9
            doc_token = line['question']
            code_arr = line['sql']
            query = Query(code_arr['sel'], code_arr['agg'], code_arr['conds'])

            id = line['table_id']
            code_str = ''
            for table in table_arr:
                if table.table_id == id:
                    code_str = table.query_str(query)
                    break
            isNegative = np.random.randint(2)
            if isNegative == 0:
                # label 0: pair the code with a random, mismatched question
                random_line_num = np.random.randint(len(data))
                line = json.loads(str(data[random_line_num]))
                doc_token = line['question']
            code_token = code_str
            example = (str(isNegative), "nothing", "nothing", doc_token, code_token)
            example = '<CODESPLIT>'.join(example)
            examples.append(example)
    data_path = os.path.join(DATA_DIR, 'train_valid/wiki_sql')
    if not os.path.exists(data_path):
        os.makedirs(data_path)
    
    output_file_name = "1.txt"
    if type == 0:
        output_file_name = 'train.txt'
    else:
        output_file_name = 'valid.txt'
    file_path = os.path.join(data_path, output_file_name)
    print(file_path)
    with open(file_path, 'w', encoding='utf-8') as f:
        f.writelines('\n'.join(examples))
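
Each line written above has five <CODESPLIT>-separated fields: the 0/1 label, two placeholder columns, the question, and the query string. An illustrative, made-up line (the exact query_str format depends on Table.query_str):

1<CODESPLIT>nothing<CODESPLIT>nothing<CODESPLIT>what is the current slogan?<CODESPLIT>SELECT Current slogan WHERE State/territory equals south australia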
Example #13
def main(argv):
    del argv  # Unused.

    db_file = join(FLAGS.data_root, FLAGS.db_file)
    parsed_std_sql_file = join(FLAGS.data_root, FLAGS.parsed_std_sql_file)
    parsed_pred_sql_file = join(FLAGS.data_root, FLAGS.parsed_pred_sql_file)

    engine = DBEngine(db_file)
    exact_match = []

    with open(parsed_std_sql_file) as fs, open(parsed_pred_sql_file) as fp:
        grades = []
        for ls, lp in tqdm(zip(fs, fp),
                           total=count_lines(parsed_std_sql_file)):
            eg = json.loads(ls)
            ep = json.loads(lp)

            try:
                qg = Query.from_dict(eg['sql'])
                gold = engine.execute_query(eg['table_id'], qg, lower=True)
            except Exception as e:
                gold = repr(e)

            qp = None
            try:
                qp = Query.from_dict(ep['sql'])
                pred = engine.execute_query(eg['table_id'], qp, lower=True)
            except Exception as e:
                pred = repr(e)
            correct = pred == gold
            match = qp == qg
            if pred == gold and qp != qg:
                print(qp)
                print(qg)
            grades.append(correct)
            exact_match.append(match)
        print(
            json.dumps(
                {
                    'ex_accuracy': sum(grades) / len(grades),
                    'lf_accuracy': sum(exact_match) / len(exact_match),
                },
                indent=2))
Example #14
    def test_query_any(self):
        self._prepare_query_test()
        files = self.tie_core.query(
            Query([QUERY_TAG_2, QUERY_TAG_3], MatchType.any))
        self.assertEqual([
            os.path.abspath(QUERY_FILE_1),
            os.path.abspath(QUERY_FILE_2),
            os.path.abspath(QUERY_FILE_3)
        ], files, "Query any did not find the expected files")
        _clean_after_query_test()
Example #15
def main():
    app = QtGui.QApplication(sys.argv)

    path = find_data_file('/data.db')
    engine = sql.create_engine('sqlite:///' + path)
    q = Query(engine)
    form = App(q)

    form.show()
    app.exec_()
Example #16
def toQueryStr(file_name, table_dic, type=0, test_batch_size=1000):
    path = os.path.join(DATA_DIR, '{}.jsonl'.format(file_name))
    print(path)
    with open(path, 'r') as pf:
        data = pf.readlines()

    idxs = np.arange(len(data))
    data = np.array(data, dtype=object)  # np.object was removed in NumPy 1.24

    np.random.seed(0)  # set random seed so that random things are reproducible
    np.random.shuffle(idxs)
    data = data[idxs]
    batched_data = chunked(data, test_batch_size)

    print("start processing")
    examples = []
    for batch_idx, batch_data in enumerate(batched_data):
        if len(batch_data) < test_batch_size:
            break  # the last batch is smaller than the others, exclude.
        for d_idx, d in enumerate(batch_data):
            line = json.loads(str(d))  # the encoding kwarg was removed in Python 3.9
            doc_str = line['question']
            code_arr = line['sql']
            query = Query(code_arr['sel'], code_arr['agg'], code_arr['conds'])

            id = line['table_id']
            if id in table_dic:
                table = table_dic[id]
                code_str = table.query_str(query)
                example = dict()
                example['docstring_tokens'] = tokenize_docstring_from_string(
                    doc_str)
                example['code_tokens'] = tokenize_docstring_from_string(
                    code_str)
                examples.append(json.dumps(example))

    data_path = os.path.join(DATA_DIR, 'wiki_sql')
    if not os.path.exists(data_path):
        os.makedirs(data_path)

    output_file_name = "1.txt"
    if type == 0:
        output_file_name = 'train.jsonl'
    elif type == 2:
        output_file_name = 'test.jsonl'
    else:
        output_file_name = 'valid.jsonl'
    file_path = os.path.join(data_path, output_file_name)
    print(file_path)
    with open(file_path, 'w', encoding='utf-8') as f:
        f.writelines('\n'.join(examples))
Example #17
def _gen_files(save_path=save_path):
    """
    Prepare files used for Machine Comprehension Binary Classifier
    """
    for split in ['train', 'test', 'dev']:
        print('------%s-------' % split)
        n = 0
        fsplit = os.path.join(data_path, split) + '.jsonl'
        ftable = os.path.join(data_path, split) + '.tables.jsonl'
        """
	test.txt original column content w/o truncate
	test_model.txt column name truncate or pad to length 3
	"""
        with open(fsplit) as fs, open(ftable) as ft, \
                open(os.path.join(save_path, split + '.txt'), mode='w') as fw, \
                open(os.path.join(save_path, '%s_model.txt' % split), mode='w') as fw_model, \
                open(os.path.join(save_path, split + '.lon'), mode='w') as fsql, \
                open(os.path.join(save_path, '%s.ori.qu' % split), 'w') as qu_file, \
                open(os.path.join(save_path, '%s.ori.lon' % split), 'w') as lon_file:
            print('loading tables...')
            tables = {}
            for line in tqdm(ft, total=count_lines(ftable)):
                d = json.loads(line)
                tables[d['id']] = d
            print('loading tables done.')

            print('loading examples')
            f2v_all, v2f_all = [], []
            for line in tqdm(fs, total=count_lines(fsplit)):

                d = json.loads(line)
                Q = d['question']
                Q = _preclean(Q).replace('\t', '')

                qu_file.write(Q + '\n')

                q_sent = Query.from_dict(d['sql'])
                rows = tables[d['table_id']]['rows']
                S, col_names, val_names = q_sent.to_sentence(
                    tables[d['table_id']]['header'], rows,
                    tables[d['table_id']]['types'])
                S = _preclean(S)

                lon_file.write(S + '\n')

                rows = np.asarray(rows)
                # renamed from `fs` to avoid shadowing the open file handle
                fields = tables[d['table_id']]['header']
                all_fields = [_preclean(f) for f in fields]
                # all fields are sorted by length in descending order
                # for string match purposes
                headers = sorted(all_fields, key=len, reverse=True)

                f2v = defaultdict(list)  # field -> values
                v2f = defaultdict(list)  # value -> fields
                for row in rows:
                    for i in range(len(fields)):
                        cur_f = _preclean(str(fields[i]))
                        cur_row = _preclean(str(row[i]))
                        f2v[cur_f].append(cur_row)
                        if cur_f not in v2f[cur_row]:
                            v2f[cur_row].append(cur_f)
                f2v_all.append(f2v)
                v2f_all.append(v2f)

                #####################################
                ########## Annotate SQL #############
                #####################################
                q_sent = Query.from_dict(d['sql'])
                S, col_names, val_names = q_sent.to_sentence(
                    tables[d['table_id']]['header'], rows,
                    tables[d['table_id']]['types'])
                S = _preclean(S)

                S_noparen = q_sent.to_sentence_noparenthesis(
                    tables[d['table_id']]['header'], rows,
                    tables[d['table_id']]['types'])
                S_noparen = _preclean(S_noparen)

                col_names = [_preclean(col_name) for col_name in col_names]
                val_names = [_preclean(val_name) for val_name in val_names]

                HEAD = col_names[-1]
                S_head = _preclean(HEAD)

                #annotate for SQL
                name_pairs = []
                for col_name, val_name in zip(col_names, val_names):
                    if col_name == val_name:
                        name_pairs.append([_preclean(col_name), 'true'])
                    else:
                        name_pairs.append(
                            [_preclean(col_name),
                             _preclean(val_name)])

                # sort to compare with candidates
                name_pairs.sort(key=lambda x: x[1])
                fsql.write('#%d\n' % n)
                fw.write('#%d\n' % n)

                for f in col_names:
                    fsql.write(S.replace(f, '[' + f + ']') + '\n')
                    f = _truncate(f, END='<bos>', PAD='<pad>', max_len=-1)
                    s = (Q + '\t' + f + '\t 1')
                    assert len(s.split('\t')) == 3
                    fw.write(s + '\n')

                for f in [f for f in headers if f not in col_names]:
                    f = _truncate(f, END='<bos>', PAD='<pad>', max_len=-1)
                    s = (Q + '\t' + f + '\t 0')
                    assert len(s.split('\t')) == 3
                    fw.write(s + '\n')
                    fsql.write(S + '\n')

                for f in col_names:
                    f = f.replace(u'\xa0', u' ').replace('\t', '')
                    Q = Q.replace(u'\xa0', u' ').replace('\t', '')
                    f = _truncate(f, END='bos', PAD='pad', max_len=3)
                    s = (Q + '\t' + f + '\t 1')
                    assert len(s.split('\t')) == 3
                    fw_model.write(s + '\n')

                for f in [f for f in headers if f not in col_names]:
                    f = f.replace(u'\xa0', u' ').replace('\t', '')
                    Q = Q.replace(u'\xa0', u' ').replace('\t', '')
                    f = _truncate(f, END='bos', PAD='pad', max_len=3)
                    s = (Q + '\t' + f + '\t 0')
                    assert len(s.split('\t')) == 3
                    fw_model.write(s + '\n')

                n += 1
            fsql.write('#%d\n' % n)
            fw.write('#%d\n' % n)

        np.savez(os.path.join(save_path, '%s_dict.npz' % split),
                 f2v_all=f2v_all,
                 v2f_all=v2f_all)  # savez lives in numpy, not scipy
        print('num of records:%d' % n)
Example #18
    def test_query_all_empty_tags_list(self):
        self._prepare_query_test()
        result = self.tie_core.query(Query([], MatchType.all))
        self.assertEqual([], result, "Query with no tags")
        _clean_after_query_test()
Example #19
def main():
    parser = ArgumentParser()
    parser.add_argument('--din', default=data_path, help='data directory')
    parser.add_argument('--dout', default='annotated', help='output directory')
    args = parser.parse_args()

    if not os.path.isdir(args.dout):
        os.makedirs(args.dout)

    # for split in ['train', 'test', 'dev']:
    for split in ['dev']:
        with open(save_path + '%s.qu' % split, 'w') as qu_file, \
                open(save_path + '%s.lon' % split, 'w') as lon_file, \
                open(save_path + '%s.out' % split, 'w') as out, \
                open(save_path + '%s_sym_pairs.txt' % split, 'w') as sym_file, \
                open(save_path + '%s_ground_truth.txt' % split, 'w') as S_file:

            fsplit = os.path.join(args.din, split) + '.jsonl'
            ftable = os.path.join(args.din, split) + '.tables.jsonl'

            with open(fsplit) as fs, open(ftable) as ft:
                print('loading tables')
                tables = {}
                for line in tqdm(ft, total=count_lines(ftable)):
                    d = json.loads(line)
                    tables[d['id']] = d
                print('loading tables done.')

                print('loading examples')
                n, acc, acc_pair, acc_all, error = 0, 0, 0, 0, 0
                target = -1

                ADD_FIELDS = False
                step2 = True

                for line in tqdm(fs, total=count_lines(fsplit)):
                    ADD_TO_FILE = True
                    d = json.loads(line)
                    Q = d['question']

                    rows = tables[d['table_id']]['rows']
                    rows = np.asarray(rows)
                    # renamed from `fs` to avoid shadowing the open file handle
                    fields = tables[d['table_id']]['header']

                    all_fields = [_preclean(f) for f in fields]
                    # all fields are sorted by length in descending order
                    # for string match purposes
                    all_fields = sorted(all_fields, key=len, reverse=True)

                    smap = defaultdict(list)  # field -> values
                    reverse_map = defaultdict(list)  # value -> fields
                    for row in rows:
                        for i in range(len(fields)):
                            cur_f = _preclean(str(fields[i]))
                            cur_row = _preclean(str(row[i]))
                            smap[cur_f].append(cur_row)
                            if cur_f not in reverse_map[cur_row]:
                                reverse_map[cur_row].append(cur_f)

                    #----------------------------------------------------------
                    # all values are sorted by length in descending order
                    # for string match purpose
                    keys = sorted(reverse_map.keys(), key=len, reverse=True)

                    Q = _preclean(Q)
                    Q_ori = Q

                    #####################################
                    ########## Annotate question ########
                    #####################################
                    candidates, cond_fields = _match_pairs(
                        Q, Q_ori, keys, reverse_map)
                    Q_head, head2partial = _match_head(Q, Q_ori, smap,
                                                       all_fields, cond_fields)

                    Q, Qpairs = annotate_Q(Q, Q_ori, Q_head, candidates,
                                           all_fields, head2partial, n, target)
                    qu_file.write(Q + '\n')

                    validation_pairs = copy.copy(Qpairs)
                    validation_pairs.append((Q_head, '<f0>', 'head'))
                    for i, f in enumerate(all_fields):
                        validation_pairs.append((f, '<c' + str(i) + '>', 'c'))

                    #####################################
                    ########## Annotate SQL #############
                    #####################################
                    q_sent = Query.from_dict(d['sql'])
                    S, col_names, val_names = q_sent.to_sentence(
                        tables[d['table_id']]['header'], rows,
                        tables[d['table_id']]['types'])
                    S = _preclean(S)
                    S_ori = S

                    S_noparen = q_sent.to_sentence_noparenthesis(
                        tables[d['table_id']]['header'], rows,
                        tables[d['table_id']]['types'])
                    S_noparen = _preclean(S_noparen)

                    col_names = [_preclean(col_name) for col_name in col_names]
                    val_names = [_preclean(val_name) for val_name in val_names]

                    HEAD = col_names[-1]
                    S_head = _preclean(HEAD)

                    #annotate for SQL
                    name_pairs = []
                    for col_name, val_name in zip(col_names, val_names):
                        if col_name == val_name:
                            name_pairs.append([_preclean(col_name), 'true'])
                        else:
                            name_pairs.append(
                                [_preclean(col_name),
                                 _preclean(val_name)])

                    # sort to compare with candidates
                    name_pairs.sort(key=lambda x: x[1])
                    new_name_pairs = [[
                        '<f' + str(i + 1) + '>', '<v' + str(i + 1) + '>'
                    ] for i, (field, value) in enumerate(name_pairs)]

                    # only annotate S while identified (f,v) pairs are right
                    if _equal(name_pairs, candidates):
                        pairs = []
                        for (f, v), (new_f,
                                     new_v) in zip(name_pairs, new_name_pairs):
                            pairs.append((f, new_f, 'f'))
                            pairs.append((v, new_v, 'v'))
                        # sort (word,symbol) pairs by length in descending order
                        pairs.sort(key=lambda x: 100 - len(x[0]))

                        for p, new_p, t in pairs:
                            cp = _backslash(p)

                            if new_p in Q:
                                if t == 'v':
                                    S = S.replace(p + ' )', new_p + ' )')
                                if t == 'f':
                                    S = re.sub(
                                        '\( ' + cp + ' (equal|less|greater)',
                                        '( ' + new_p + r' \1', S)

                    # only annotate S while identified HEAD is right
                    if S_head == Q_head and '<f0>' in Q:
                        S = S.replace(S_head, '<f0>')

                    # annote unseen fields
                    if ADD_FIELDS:
                        for i, f in enumerate(all_fields):
                            cf = _backslash(f)
                            S = re.sub('(\s|^)' + cf + '(\s|$|s)',
                                       ' <c' + str(i) + '> ', S)

                    S = _clean(S)
                    lon_file.write(S + '\n')

                    ###############################
                    ######### VALIDATION ##########
                    ###############################
                    recover_S = S
                    for word, sym, t in validation_pairs:
                        recover_S = recover_S.replace(sym, word)
                        sym_file.write(sym + '=>' + word + '<>')
                    sym_file.write('\n')

                    S_file.write(S_noparen + '\n')
                    #------------------------------------------------------------------------
                    if _equal(name_pairs, candidates):
                        acc_pair += 1

                    if Q_head == S_head:
                        acc += 1

                    if _equal(name_pairs, candidates) and Q_head == S_head:
                        acc_all += 1

                    full_anno = True
                    for s in S.split():
                        if s[0] != '<' and s not in [
                                '(', ')', 'where', 'less', 'greater', 'equal',
                                'max', 'min', 'count', 'sum', 'avg', 'and',
                                'true'
                        ]:
                            error += 1
                            full_anno = False
                            break

                    n += 1

                print('total number of examples: ' + str(n))
                print('fully annotated: ' + str(1 - error * 1.0 / n))
                print('accurate all percent: ' + str(acc_all * 1.0 / n))
                print('accurate HEAD match percent: ' + str(acc * 1.0 / n))
                print('accurate fields pair match percent: ' +
                      str(acc_pair * 1.0 / n))
Example #20
    for split in ['train', 'dev', 'test']:
        fsplit = os.path.join(args.din, split) + '.jsonl'
        ftable = os.path.join(args.din, split) + '.tables.jsonl'
        fout = os.path.join(args.dout, split) + '.jsonl'

        print('annotating {}'.format(fsplit))
        with open(fsplit) as fs, open(ftable) as ft, open(fout, 'wt') as fo:
            print('loading tables')
            tables = {}
            for line in tqdm(ft, total=count_lines(ftable)):
                d = json.loads(line)
                tables[d['id']] = d
            print('loading examples')
            n_written = 0
            for line in tqdm(fs, total=count_lines(fsplit)):
                d = json.loads(line)
                a = annotate_example(d, tables[d['table_id']])
                if not is_valid_example(a):
                    raise Exception(str(a))

                gold = Query.from_tokenized_dict(a['query'])
                reconstruct = Query.from_sequence(a['seq_output'],
                                                  a['table'],
                                                  lowercase=True)
                if gold.lower() != reconstruct.lower():
                    raise Exception('Expected:\n{}\nGot:\n{}'.format(
                        gold, reconstruct))
                fo.write(json.dumps(a) + '\n')
                n_written += 1
            print('wrote {} examples'.format(n_written))
Example #21
    parser.add_argument('--topk', type=int, default=3, help='k of top_k')
    
    args = parser.parse_args()

    engine = DBEngine(args.db_file)
   
    temp = []
    
    with open(args.source_file) as fs, open(args.pred_file) as fp:
        grades = []
        exact_match = []
        
        for ls, lp in tqdm(zip(fs, fp), total=count_lines(args.source_file)):            
            eg = json.loads(ls)
            qg = Query.from_dict(eg['sql'], ordered=args.ordered)
            gold = engine.execute_query(eg['table_id'], qg, lower=True)
            
            pred_topk = []
            qp_topk = []
            
            ep = json.loads(lp)
            pred = ep.get('error', None)
            qp = None
            for i in range(args.topk):
                if not ep.get('error', None):
                    try:
                        
                        if ep['query'][str(i)]['conds'] == [[]]:
                            ep['query'][str(i)]['conds'] = []
            
Example #22
if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument('source_file', help='source file for the prediction')
    parser.add_argument('db_file', help='source database for the prediction')
    parser.add_argument('pred_file', help='predictions by the model')
    args = parser.parse_args()

    engine = DBEngine(args.db_file)
    exact_match = []
    with open(args.source_file) as fs, open(args.pred_file) as fp:
        grades = []
        for ls, lp in tqdm(zip(fs, fp), total=count_lines(args.pred_file)):
            eg = json.loads(ls)
            ep = json.loads(lp)
            qg = Query.from_dict(eg['sql'])
            # print(qg)
            gold = engine.execute_query(eg['table_id'], qg, lower=True)
            pred = ep['error']
            qp = None
            if not ep['error']:
                try:
                    qp = Query.from_dict(ep['query'])
                    pred = engine.execute_query(eg['table_id'], qp, lower=True)
                except Exception as e:
                    pred = repr(e)
            correct = pred == gold
            match = qp == qg
            # if not match:
            #     print(qp.to_dict())
            #     print(qg.to_dict())
            grades.append(correct)
            exact_match.append(match)
Example #23
    def eval(self, split, gold, sql_gold, engine=None):
        if self.agg == gold['query']['agg']:
            self.correct['agg'] = 1

        if self.sel == gold['query']['sel']:
            self.correct['sel'] = 1

        op_list_pred = [op for col, op, span in self.cond]
        op_list_gold = [op for col, op, span in gold['query']['conds']]

        col_list_pred = [col for col, op, span in self.cond]
        col_list_gold = [col for col, op, span in gold['query']['conds']]

        q = gold['question']['words']
        span_list_pred = [
            ' '.join(q[span[0]:span[1] + 1]) for col, op, span in self.cond
        ]
        span_list_gold = [
            ' '.join(span['words']) for col, op, span in gold['query']['conds']
        ]

        where_pred = list(zip(col_list_pred, op_list_pred, span_list_pred))
        where_gold = list(zip(col_list_gold, op_list_gold, span_list_gold))
        where_pred.sort()
        where_gold.sort()
        if (where_pred == where_gold
                and len(col_list_pred) == len(col_list_gold)
                and len(op_list_pred) == len(op_list_gold)
                and len(span_list_pred) == len(span_list_gold)):
            self.correct['where'] = 1

        if (len(col_list_pred) == len(col_list_gold)
                and [it[0] for it in where_pred] == [it[0] for it in where_gold]):
            self.correct['col'] = 1

        if (len(op_list_pred) == len(op_list_gold)
                and [it[1] for it in where_pred] == [it[1] for it in where_gold]):
            self.correct['lay'] = 1

        if (len(span_list_pred) == len(span_list_gold)
                and [it[2] for it in where_pred] == [it[2] for it in where_gold]):
            self.correct['span'] = 1

        if all((self.correct[it] == 1 for it in ('agg', 'sel', 'where'))):
            self.correct['all'] = 1

        # execution
        table_id = gold['table_id']
        ans_gold = '0'
        ans_pred = '1'
        if engine is not None:
            ans_gold = engine.execute_query(table_id,
                                            Query.from_dict(sql_gold),
                                            lower=True)
            try:
                sql_pred = {
                    'agg': self.agg,
                    'sel': self.sel,
                    'conds': self.recover_cond_to_gloss(gold)
                }
                ans_pred = engine.execute_query(table_id,
                                                Query.from_dict(sql_pred),
                                                lower=True)
            except Exception as e:
                ans_pred = repr(e)
        if set(ans_gold) == set(ans_pred):
            self.correct['exe'] = 1

        error_case = {}
        if split == 'finaltest':
            error_case['sel'] = self.correct['sel']
            error_case['where'] = self.correct['where']
            error_case['all'] = self.correct['all']
            error_case['table_id'] = gold['table_id']
            error_case['question_id'] = gold['id']
            error_case['question'] = gold['question']['words']
            error_case['table_head'] = [
                head['words'] for head in gold['table']['header']
            ]
            error_case['gold'] = sql_gold
            error_case['predict'] = {
                'agg': self.agg.item(),
                'sel': self.sel.item(),
                'conds': self.print_recover_cond_to_gloss(gold)
            }
            error_case['BIO'] = [(x.item(), y) for x, y in zip(
                list(self.BIO), list(gold['question']['words']))]
            error_case['BIO_col'] = self.BIO_col.tolist()

        return error_case
Example #24
if __name__ == "__main__":
    lines = 0
    file_type = sys.argv[1]
    header_mapping_dictionary = defaultdict(lambda: [])
    for idx, line in enumerate(open(file_type + ".tables.jsonl")):
        tableJson = json.loads(line)
        if len(header_mapping_dictionary[tableJson["id"]]) == 0:
            # json.dumps (rather than str) keeps the header list valid JSON
            header_mapping_dictionary[tableJson["id"]] = json.dumps(
                {"header": tableJson["header"]})

    for line in open(file_type + "_tok.jsonl"):
        linesJson = json.loads(line)
        jsonQuery = linesJson["sql"]
        text = repr(Query(jsonQuery["sel"],
                          jsonQuery["agg"],
                          conditions=jsonQuery["conds"]))
        # text has the form "SELECT col1 FROM tblName WHERE col2 op val"
        sql_pattern = '''SELECT\s+(?P<agg>COUNT|SUM|AVG|MAX|MIN)?\s+(?P<select_col>[^\s]+)\s+FROM\s+(?P<table_name>[^\s]+)\s*WHERE\s+(?P<where_col>[\w]*)\s+(?P<where_op>[\S><&|+\-\%*!=]{1,2})\s+(?P<where_cond>.*)'''
        matches = re.search(sql_pattern, text, re.IGNORECASE)

        index = text.find("table")

        first_part = text[0:index]
        second_part = text[index + 5:]

        if matches:
            agg = matches.group("agg").strip() if matches.group(
                "agg") else None
Example #25
    opts = Options()
    fill_from_args(opts)

    for split in ['train', 'dev', 'test']:
        orig = os.path.join(opts.data_dir, f'{split}.jsonl')
        db_file = os.path.join(opts.data_dir, f'{split}.db')
        ans_file = os.path.join(opts.data_dir, f"{split}_ans.jsonl.gz")
        tbl_file = os.path.join(opts.data_dir, f"{split}.tables.jsonl")
        engine = DBEngine(db_file)
        exact_match = []
        with open(orig) as fs, write_open(ans_file) as fo:
            grades = []
            for ls in tqdm(fs, total=count_lines(orig)):
                eg = json.loads(ls)
                sql = eg['sql']
                qg = Query.from_dict(sql, ordered=False)
                gold = engine.execute_query(eg['table_id'], qg, lower=True)
                assert isinstance(gold, list)
                eg['answer'] = gold
                eg['rowids'] = engine.execute_query_rowid(eg['table_id'],
                                                          qg,
                                                          lower=True)
                # CONSIDER: if it is not an agg query, somehow identify the particular cell
                fo.write(json.dumps(eg) + '\n')

        convert(jsonl_lines(ans_file),
                jsonl_lines(tbl_file),
                os.path.join(opts.data_dir, f"{split}_agg.jsonl.gz"),
                skip_aggregation=False)
Example #26
    parser.add_argument("--db_file")
    parser.add_argument("--pred_file")
    parser.add_argument("--ordered", action='store_true')
    args = parser.parse_args()

    engine = DBEngine(args.db_file)
    ex_acc_list = []
    lf_acc_list = []
    with open(args.source_file) as sf, open(args.pred_file) as pf:
        for source_line, pred_line in tqdm(zip(sf, pf), total=count_lines(args.source_file)):
            # load the gold and predicted examples for this line
            gold_example = json.loads(source_line)
            pred_example = json.loads(pred_line)

            # compute the gold logical form (lf) and execution result (ex)
            lf_gold_query = Query.from_dict(gold_example['sql'], ordered=args.ordered)
            ex_gold = engine.execute_query(gold_example['table_id'], lf_gold_query, lower=True)

            # if the prediction has no error, compute its lf and ex
            lf_pred_query = None
            ex_pred = pred_example.get('error', None)
            if not ex_pred:
                try:
                    lf_pred_query = Query.from_dict(pred_example['query'], ordered=args.ordered)
                    ex_pred = engine.execute_query(gold_example['table_id'], lf_pred_query, lower=True)
                except Exception as e:
                    ex_pred = repr(e)

            # record whether pred matches gold on execution (ex) and logical form (lf)
            ex_acc_list.append(ex_pred == ex_gold)
            lf_acc_list.append(lf_pred_query == lf_gold_query)  # uses Query.__eq__
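
The snippet stops after accumulating the per-example matches; a plausible continuation, mirroring Examples #3 and #13 (not part of the scraped code):

    ex_accuracy = sum(ex_acc_list) / len(ex_acc_list)
    lf_accuracy = sum(lf_acc_list) / len(lf_acc_list)
    print(json.dumps({'ex_accuracy': ex_accuracy, 'lf_accuracy': lf_accuracy}, indent=2))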
Example #27
import pytest

from lib.query import Query, LogicalOperator
from tests.test_lib.conftest import does_not_raise


@pytest.mark.parametrize('queries, expected_result, expected_exception', (
    (
        [
            Query(
                field_name='semester',
                value='1',
                operator=LogicalOperator.OR,
            ),
        ],
        {1, 2, 3},
        does_not_raise(),
    ),
    (
        [
            Query(
                field_name='last_name',
                value='one',
                operator=LogicalOperator.OR,
            ),
            Query(
                field_name='last_name',
                value='two',
                operator=LogicalOperator.OR,
            ),
        ],
Example #28
import os
import sys
import json
from tqdm import tqdm

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from lib.query import Query
from lib.dbengine import DBEngine

if __name__ == '__main__':
    for split in ['train', 'dev', 'test']:
        print('checking {}'.format(split))
        engine = DBEngine('data/{}.db'.format(split))
        n_lines = 0
        with open('data/{}.jsonl'.format(split)) as f:
            for l in f:
                n_lines += 1
        with open('data/{}.jsonl'.format(split)) as f:
            for l in tqdm(f, total=n_lines):
                d = json.loads(l)
                query = Query.from_dict(d['sql'])

                # make sure it's executable
                result = engine.execute_query(d['table_id'], query)
                if result:
                    for a, b, c in d['sql']['conds']:
                        if str(c).lower() not in d['question'].lower():
                            raise Exception(
                                'Could not find condition {} in question {} for query {}'
                                .format(c, d['question'], query))
                else:
                    raise Exception(
                        'Query {} did not execute to a valid result'.format(
                            query))
Example #29
            line_split = line.split('|')
            source_idx = int(line_split[-1])
            if len(line_split[:-1]) > 1:
                seq_in = '|'.join(line_split[:-1])
            else:
                seq_in = line_split[0]
            json_line = {}
            json_line['sample_id'] = source_idx
            json_line['query'] = ""
            json_line['error'] = ""
            if isinstance(seq_in, bytes):  # .decode is a Python 2 remnant; guard it
                seq_in = seq_in.decode('utf-8')
            seq_in = seq_in.split(' ')
            if '_EOS' in seq_in:
                seq_in = seq_in[:seq_in.index('_EOS')]
            json_line['seq'] = ' '.join(seq_in)
            sj = json.loads(data_source[source_idx])
            try:
                output_with_gloss = get_output_with_gloss(
                    seq_in, sj['seq_input'])
                q = Query.from_sequence(output_with_gloss,
                                        sj['table']).to_dict()
            except Exception as e:
                json_line['error'] = repr(e)
            else:
                json_line['query'] = q
            orig_list.append(json_line)
    sortedlist = sorted(orig_list, key=lambda k: k['sample_id'])
    with open(args.dout, 'wt') as fo:
        for item in sortedlist:
            fo.write(json.dumps(item) + '\n')