Example #1
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--reference_file', type=str, required=True)
    parser.add_argument('--prediction_file', type=str, required=True)
    parser.add_argument('--error_file', type=str, default='')
    args = parser.parse_args()

    predictions = {}
    with open(args.prediction_file, "r", encoding='utf-8') as f:
        for line in f:
            jobj = json.loads(line)
            predictions[jobj['id']] = normalize_answers(jobj['predictions'])

    with open(args.reference_file, "r", encoding='utf-8') as f:
        reference = json.load(f)

    references = {}
    for index, data in enumerate(reference):
        references[str(index)] = normalize_answers(data)

    num_correct = 0
    sums = defaultdict(float)
    counts = defaultdict(float)
    error_qids = write_open(args.error_file) if args.error_file else None

    for key, gold_answer in references.items():
        pred_answer = predictions[key]
        is_correct = gold_answer == pred_answer
        answer_types = []
        if len(gold_answer) == 1:
            answer_types.append('single_answer')
        elif len(gold_answer) > 1:
            answer_types.append('multi_answer')
        else:
            answer_types.append('no_answer')

        if len(pred_answer) > 1:
            answer_types.append('predicted_multi_answer')
        elif len(pred_answer) == 0:
            answer_types.append('predicted_no_answer')
        else:
            answer_types.append('predicted_single_answer')

        if is_correct:
            num_correct += 1
            for at in answer_types:
                sums[at] += 1
        elif error_qids is not None:
            error_qids.write(f'{key}\n')
        for at in answer_types:
            counts[at] += 1

    print('Correct: ', num_correct, len(references),
          num_correct / float(len(references)))
    for at, count in counts.items():
        sm = sums[at]
        print(f'{at} = {sm/count}, over {count}')

    if error_qids is not None:
        error_qids.close()
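The examples on this page rely on small I/O helpers (write_open, read_open, jsonl_lines) from the project's utility module that are not shown in the snippets. A minimal sketch of what gzip-aware versions might look like, assuming they simply switch on the .gz extension and create parent directories as needed:

import gzip
import os


def write_open(path, mode='wt', encoding='utf-8'):
    # Hypothetical sketch: create parent directories, pick gzip vs. plain text by extension.
    dirname = os.path.dirname(path)
    if dirname:
        os.makedirs(dirname, exist_ok=True)
    if path.endswith('.gz'):
        return gzip.open(path, mode, encoding=encoding)
    return open(path, mode, encoding=encoding)


def jsonl_lines(path, encoding='utf-8'):
    # Hypothetical sketch: iterate over lines of a (possibly gzipped) JSONL file.
    opener = gzip.open if path.endswith('.gz') else open
    with opener(path, 'rt', encoding=encoding) as f:
        for line in f:
            yield line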
Example #2
def evaluate(args: SeqPairHypers, eval_dataset: SeqPairLoader, model):
    eval_dataset.reset(uneven_batches=True, files_per_dataloader=-1)
    loader = eval_dataset.get_dataloader()
    model.eval()
    with torch.no_grad(), write_open(os.path.join(args.output_dir, f'results{args.global_rank}.jsonl.gz')) as f:
        for batch in loader:
            ids = batch[0]
            inputs = eval_dataset.batch_dict(batch)
            if 'labels' in inputs:
                del inputs['labels']
            logits = model(**inputs)[0].detach().cpu().numpy()
            for ex_id, pred in zip(ids, logits):
                assert isinstance(ex_id, str)
                pred = [float(p) for p in pred]
                assert len(pred) == args.num_labels
                assert all(isinstance(p, float) for p in pred)
                f.write(json.dumps({'id': ex_id, 'predictions': pred}) + '\n')
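Since each distributed rank writes its own results{rank}.jsonl.gz shard, a downstream step has to merge them back into a single id-to-predictions map (the {'id', 'predictions'} format consumed in Example #4). A hedged sketch of such a merge, assuming the shards sit directly under output_dir:

import glob
import gzip
import json
import os


def read_rank_results(output_dir):
    # Hypothetical sketch: collect the per-rank shards into one id -> logits map.
    id2preds = {}
    for shard in sorted(glob.glob(os.path.join(output_dir, 'results*.jsonl.gz'))):
        with gzip.open(shard, 'rt', encoding='utf-8') as f:
            for line in f:
                jobj = json.loads(line)
                id2preds[jobj['id']] = jobj['predictions']
    return id2preds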
Example #3
def write_agg_classify(data_dir,
                       split,
                       *,
                       exclude_header=False,
                       cell_sep_token='*'):
    with write_open(os.path.join(data_dir,
                                 f'{split}_agg_classify.jsonl.gz')) as out:
        for line in jsonl_lines(os.path.join(data_dir,
                                             f'{split}_agg.jsonl.gz')):
            jobj = json.loads(line)
            if not exclude_header:
                agg_inst = {
                    'id': jobj['id'],
                    'text_a': jobj['question'],
                    'text_b': f' {cell_sep_token} '.join(jobj['header']),
                    'label': jobj['agg_index']
                }
            else:
                agg_inst = {
                    'id': jobj['id'],
                    'text': jobj['question'],
                    'label': jobj['agg_index']
                }
            out.write(json.dumps(agg_inst) + '\n')
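A hypothetical invocation of the converter above; the emitted line in the comment is made up, but illustrates the sequence-pair shape (question as text_a, header cells joined by the separator as text_b, aggregation index as the label):

# Hypothetical usage sketch; the path and values are illustrative only.
write_agg_classify('data/wikisql', 'train', cell_sep_token='*')
# Each output line is roughly:
# {"id": "7", "text_a": "what is the total number of gold medals?",
#  "text_b": "Rank * Nation * Gold", "label": 5}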
Example #4
def main():
    opts = Options()
    fill_from_args(opts)

    id2qinfo = defaultdict(QInfo)
    for line in jsonl_lines(opts.gt):
        jobj = json.loads(line)
        id2qinfo[jobj['id']].fill_from_gt(jobj, blind_gt=opts.blind_gt)

    sums = defaultdict(float)
    counts = defaultdict(float)
    for line in jsonl_lines(opts.agg_preds):
        jobj = json.loads(line)
        qid = jobj['id']
        qinfo = id2qinfo[qid]
        preds = np.array(jobj['predictions'], dtype=np.float32)
        predicted = np.argmax(preds)
        gt = qinfo.gt_agg_index
        qinfo.agg_pred = predicted
        qinfo.agg_confs = preds
        correct = 1 if predicted == gt else 0
        counts[f'accuracy_{gt}'] += 1
        sums[f'accuracy_{gt}'] += correct
        counts['accuracy'] += 1
        sums['accuracy'] += correct
    if not opts.blind_gt:
        metric_names = list(sums.keys())
        metric_names.sort()
        for n in metric_names:
            print(f'{n} = {sums[n]/counts[n]} over {counts[n]}')

    for line in jsonl_lines(opts.cell_preds):
        jobj = json.loads(line)
        qid = jobj['qid']
        cell_preds = np.array(jobj['cells'], dtype=np.float32)
        qinfo = id2qinfo[qid]
        qinfo.cell_confs = cell_preds

    if opts.lookup_preds:
        for line in jsonl_lines(opts.lookup_preds):
            jobj = json.loads(line)
            qid = jobj['qid']
            qinfo = id2qinfo[qid]
            if qinfo.compute_agg_pred() == 0:
                cell_preds = np.array(jobj['cells'], dtype=np.float32)
                qinfo.cell_confs = cell_preds

    err_analysis_count = 0  # make non-zero to show cases where no threshold is possible
    agg_ops = ['', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG']
    per_agg_thresholds = np.zeros(len(agg_ops), dtype=np.float32)
    if opts.use_threshold <= -1000:
        for qinfo in id2qinfo.values():
            qinfo.compute_threshold_range()
            if qinfo.threshold_range is None and qinfo.agg_pred != 0 and err_analysis_count > 0:
                err_analysis_count -= 1
                print(f'No threshold possible: {qinfo.question}\nagg {agg_ops[qinfo.gt_agg_index]} over {qinfo.col_gt},{qinfo.row_gt} yielding {qinfo.answers_gt}')
                print(f'Predicted agg {agg_ops[qinfo.agg_pred]} over {np.argmax(qinfo.cell_confs[0])} yielding {qinfo.agg_answers}')
                print([f'{h}:{qinfo.col_vals[hi] is not None}' for hi, h in enumerate(qinfo.header)])
                for ri, row in enumerate(qinfo.rows):
                    to_show = [f'{cell}:{qinfo.cell_confs[ri,ci]}' for ci, cell in enumerate(row)]
                    print(to_show)

        max_accuracy, best_threshold = find_best_threshold(id2qinfo.values())
        print(f'can get {max_accuracy} with threshold {best_threshold}')
        print(f'    {accuracy_at_threshold(id2qinfo.values(), best_threshold-0.1)} with threshold {best_threshold - 0.1}')
        print(f'    {accuracy_at_threshold(id2qinfo.values(), best_threshold+0.1)} with threshold {best_threshold + 0.1}')

        for ai in range(0, per_agg_thresholds.shape[0]):
            acc, bt = find_best_threshold(id2qinfo.values(), for_agg_index=ai)
            print(f'for {agg_ops[ai]} can get {acc} with threshold {bt}')
            per_agg_thresholds[ai] = bt
    else:
        best_threshold = opts.use_threshold
        per_agg_thresholds[:] = opts.use_threshold

    missed_lookup = 0
    lookup = 0
    non_lookup = 0
    lookup_by_agg = 0
    pred_out = write_open(opts.prediction_file) if opts.prediction_file else None
    for qinfo in id2qinfo.values():
        if qinfo.gt_agg_index == 0:
            lookup += 1
            if qinfo.agg_pred != 0 and qinfo.threshold_range is not None:
                #print(f'Aggregation gets right answer anyway? {qinfo.question}\nagg {agg_ops[qinfo.gt_agg_index]} over {qinfo.col_gt},{qinfo.row_gt} yielding {qinfo.answers_gt}')
                #print(f'Predicted agg {agg_ops[qinfo.agg_pred]} over {np.argmax(qinfo.cell_confs[0])} yielding {qinfo.agg_answers}')
                if qinfo.threshold_range[0] <= best_threshold <= qinfo.threshold_range[1]:
                    lookup_by_agg += 1
        else:
            non_lookup += 1
        if qinfo.gt_agg_index == 0 and qinfo.agg_pred != 0:
            missed_lookup += 1
        if pred_out is not None:
            this_threshold = per_agg_thresholds[qinfo.agg_pred] if opts.threshold_per_agg else best_threshold
            pred_out.write(json.dumps({
                'id': qinfo.qid,
                'predictions': qinfo.answer_at_threshold(this_threshold)
            }) + '\n')
    if pred_out is not None:
        pred_out.close()
    if not opts.blind_gt:
        print(f'Lookup count = {lookup}, Non-lookup = {non_lookup}, '
              f'Lookup mispredicted as non-lookup = {missed_lookup}, but correct anyway = {lookup_by_agg}')
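find_best_threshold and accuracy_at_threshold are not shown here. Based on how threshold_range is used above (a prediction counts as correct when the global confidence threshold falls inside a question's precomputed range), a minimal sketch of the accuracy computation, under that assumption, might be:

def accuracy_at_threshold(qinfos, threshold):
    # Hypothetical sketch: a question is correct iff the threshold lies inside its
    # (low, high) threshold_range; questions with no feasible range count as wrong.
    qinfos = list(qinfos)
    correct = sum(1 for q in qinfos
                  if q.threshold_range is not None
                  and q.threshold_range[0] <= threshold <= q.threshold_range[1])
    return correct / len(qinfos) if qinfos else 0.0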
Example #5
def convert(queries,
            tables,
            outfile,
            *,
            skip_aggregation=True,
            show_aggregation=False):
    """Creates examples for the training and dev sets."""
    tid2rows = dict()
    for line in tables:
        jobj = json.loads(line)
        tid = jobj['id']
        header = jobj['header']
        rows_orig = jobj['rows']
        rows = []
        for r in rows_orig:
            rows.append([str(cv) for cv in r])
        tid2rows[tid] = [[str(h) for h in header]] + rows
    with write_open(outfile) as out:
        for qid, line in enumerate(queries):
            jobj = json.loads(line)
            agg_index = jobj['sql']['agg']
            if skip_aggregation and agg_index != 0:  # skip aggregation queries
                continue
            table_id = jobj['table_id']
            rows = tid2rows[table_id]
            qtext = jobj['question']
            target_column = jobj['sql']['sel']
            condition_columns = [
                colndx for colndx, comp, val in jobj['sql']['conds']
            ]
            answers = jobj['answer']
            rowids = jobj['rowids'] if 'rowids' in jobj else None
            jobj = dict()
            jobj['id'] = f'{qid}'
            jobj['question'] = qtext
            jobj['header'] = rows[0]
            jobj['rows'] = rows[1:]
            jobj['target_column'] = target_column
            jobj['condition_columns'] = condition_columns
            jobj['table_id'] = table_id
            jobj['agg_index'] = agg_index
            if rowids is not None:
                jobj['target_rows'] = rowids
            if agg_index == 0:
                answers = [str(ans) for ans in answers]
                clean_answers = []
                for r in rows[1:]:
                    if cell_value_is_answer(r[target_column], answers):
                        clean_answers.append(r[target_column])
                if not clean_answers:
                    logger.info(f'no answers found! {answers} in {rows}')
                if len(clean_answers) != len(answers):
                    logger.info(
                        f'answers changed from {answers} to {clean_answers}')
                jobj['answers'] = list(set(clean_answers))
            else:
                jobj['answers'] = answers
            if show_aggregation and rowids and len(
                    rowids) > 1 and agg_index != 0:
                print(json.dumps(jobj))
            out.write(json.dumps(jobj) + '\n')
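cell_value_is_answer is a project helper not shown in this snippet. A minimal sketch, assuming it is a whitespace- and case-insensitive string match between the cell value and any of the gold answers:

import re

_ws = re.compile(r'\s+')


def cell_value_is_answer(cell_value, answers):
    # Hypothetical sketch: normalize whitespace and case, then require an exact match.
    norm = _ws.sub(' ', str(cell_value)).strip().lower()
    return any(norm == _ws.sub(' ', str(a)).strip().lower() for a in answers)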
Example #6
def main():
    class Options:
        def __init__(self):
            self.data_dir = ''

    opts = Options()
    fill_from_args(opts)

    for split in ['train', 'dev', 'test']:
        orig = os.path.join(opts.data_dir, f'{split}.jsonl')
        db_file = os.path.join(opts.data_dir, f'{split}.db')
        ans_file = os.path.join(opts.data_dir, f"{split}_ans.jsonl.gz")
        tbl_file = os.path.join(opts.data_dir, f"{split}.tables.jsonl")
        engine = DBEngine(db_file)
        exact_match = []
        with open(orig) as fs, write_open(ans_file) as fo:
            grades = []
            for ls in tqdm(fs, total=count_lines(orig)):
                eg = json.loads(ls)
                sql = eg['sql']
                qg = Query.from_dict(sql, ordered=False)
                gold = engine.execute_query(eg['table_id'], qg, lower=True)
                assert isinstance(gold, list)
                #if len(gold) != 1:
                #    print(f'for {sql} : {gold}')
                eg['answer'] = gold
                eg['rowids'] = engine.execute_query_rowid(eg['table_id'],
                                                          qg,
                                                          lower=True)
                # CONSIDER: if it is not an agg query, somehow identify the particular cell
                fo.write(json.dumps(eg) + '\n')
Example #7
def main():
    opts = Options()
    fill_from_args(opts)

    # The escaped characters include: double quote (" => \") and backslash (\ => \\).
    # Newlines are represented as quoted line breaks.
    csv_base_dir = os.path.join(opts.wtq_dir, 'csv')
    id2rows = dict()
    for dir in os.listdir(csv_base_dir):
        full_dir = os.path.join(csv_base_dir, dir)
        for file in os.listdir(full_dir):
            with read_open(os.path.join(full_dir, file)) as csvfile:
                rows = []
                for row in csv.reader(csvfile,
                                      doublequote=False,
                                      escapechar='\\'):
                    rows.append(row)
                id2rows[f'csv/{dir}/{file}'] = rows

    # List items are separated by | (e.g., when|was|taylor|swift|born|?).
    # The following characters are escaped: newline (=> \n), backslash (\ => \\), and pipe (| => \p)
    # Note that pipes become \p so that doing x.split('|') will work.
    data_dir = os.path.join(opts.wtq_dir, 'data')
    with read_open(os.path.join(opts.id2split)) as ids_file:
        id2split = json.load(ids_file)
    splits = {
        split_name:
        write_open(os.path.join(data_dir, f'{split_name}_lookup.jsonl.gz'))
        for split_name in ['train', 'dev', 'test']
    }

    matched_by_substring = 0
    for infile in [
            'training.tsv', 'pristine-seen-tables.tsv',
            'pristine-unseen-tables.tsv'
    ]:
        for ndx, line in enumerate(read_lines(os.path.join(data_dir, infile))):
            parts = line.strip().split('\t')
            assert len(parts) == 4
            if ndx == 0:
                continue
            id = parts[0]
            if id not in id2split:
                continue
            split_name = id2split[id]
            query = tsv_unescape(parts[1])
            table_id = parts[2]
            answers = [tsv_unescape(p) for p in parts[3].split('|')]
            norm_answers = [normalize(answer) for answer in answers]
            all_rows = id2rows[table_id]
            header = all_rows[0]
            rows = all_rows[1:]

            # force 'answers' to contain only string equal matches to cell values
            target_columns = set()
            matched_answers = set()
            for rndx, row in enumerate(all_rows):
                for cndx, cell in enumerate(row):
                    if normalize(cell) in norm_answers:
                        target_columns.add(cndx)
                        matched_answers.add(cell)
            if opts.match_cell_substring < 1.0 and len(target_columns) == 0:
                for rndx, row in enumerate(all_rows):
                    for cndx, cell in enumerate(row):
                        ncell = normalize(cell)
                        if any([
                                answer in ncell and len(answer) / len(ncell) >=
                                opts.match_cell_substring
                                for answer in norm_answers
                        ]):
                            target_columns.add(cndx)
                            matched_answers.add(cell)
                if len(target_columns) > 0:
                    matched_by_substring += 1

            out = splits[split_name]
            if len(target_columns) == 0:
                print(f'{query} {answers} not found in table: \n{all_rows}')
            elif len(target_columns) > 1:
                print(
                    f'{query} {answers} multiple columns match answer: \n{all_rows}'
                )
            else:
                jobj = dict()
                jobj['id'] = id
                jobj['table_id'] = table_id
                jobj['question'] = query
                jobj['header'] = header
                jobj['target_column'] = list(target_columns)[0]
                answers = list(matched_answers)
                answers.sort()
                jobj['answers'] = answers
                jobj['rows'] = rows
                out.write(json.dumps(jobj) + '\n')

    print(f'matched by substring: {matched_by_substring}')
    for f in splits.values():
        f.close()
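tsv_unescape undoes the escaping convention described in the comments above (newlines become \n, backslashes become \\, and pipes become \p). A minimal sketch based only on that description:

def tsv_unescape(x):
    # Hypothetical sketch: reverse the TSV escaping described above
    # (\n -> newline, \p -> '|', \\ -> '\').
    return x.replace(r'\n', '\n').replace(r'\p', '|').replace('\\\\', '\\')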
Example #8

def get_pred_map(file):
    qid2pred = dict()
    with open(file, "r", encoding='utf-8') as f:
        for line in f:
            jobj = json.loads(line)
            qid = jobj['id']
            preds = jobj['cell_predictions']
            qid2pred[qid] = preds
    return qid2pred


pred_a = get_pred_map(args.pred_a)
pred_b = get_pred_map(args.pred_b)

whitespace = re.compile(r"\s+")

with read_open(args.gt) as gt_f, write_open(args.out) as out_f:
    for line in gt_f:
        jobj = json.loads(line)
        qid = str(jobj['id'])
        correct_a, ans_a = _evaluate_pred_list(pred_a[qid], jobj, answer_in_header=args.answer_in_header_a)
        correct_b, ans_b = _evaluate_pred_list(pred_b[qid], jobj, answer_in_header=args.answer_in_header_b)
        if correct_a and not correct_b:
            out_f.write('\n\n'+'='*80+'\n')
            out_f.write(tabulate(jobj['rows'], headers=jobj['header'], showindex=True)+'\n')
            out_f.write(jobj['question']+'\n')
            out_f.write(f"{whitespace.sub(' ', ans_a)} ({pred_a[qid][:10]})\n")
            out_f.write(f"{whitespace.sub(' ', ans_b)} ({pred_b[qid][:10]})\n")
def main():
    opts = Options()
    fill_from_args(opts)

    if opts.gt:
        id2gt = dict()
        lookup_subset = set()
        for line in jsonl_lines(opts.gt):
            jobj = json.loads(line)
            qid = jobj['id']
            tbl = jobj['rows']
            correct_cells = np.zeros((len(tbl), len(tbl[0])), dtype=bool)  # np.bool is removed in recent NumPy
            target_rows = jobj['target_rows'] if 'target_rows' in jobj else [jobj['target_row']]
            target_cols = jobj['target_columns'] if 'target_columns' in jobj else [jobj['target_column']]
            # TODO: also support getting correct cells from answers list
            for r in target_rows:
                for c in target_cols:
                    correct_cells[r, c] = True
            #if correct_cells.sum() == 0:
            #    print(f'No answer! {target_rows}, {target_cols}, {jobj["agg_index"]}')
            id2gt[qid] = correct_cells
            if 'agg_index' not in jobj or jobj['agg_index'] == 0:
                lookup_subset.add(qid)
    else:
        id2gt = None
        lookup_subset = None

    sums = defaultdict(float)
    counts = defaultdict(float)
    table_count = 0
    no_answer_count = 0
    col_predictions = gather_predictions(opts.col, softmax=opts.softmax)
    row_predictions = gather_predictions(opts.row, softmax=False)
    if opts.cell_prediction_output:
        cell_prediction_output = write_open(opts.cell_prediction_output)
    else:
        cell_prediction_output = None
    with write_open(opts.output) as out:
        for qid, col_preds in col_predictions.items():
            col_preds = to_ndarray(col_preds)
            row_preds = to_ndarray(row_predictions[qid])
            cell_preds = row_preds.reshape((-1, 1)) + col_preds.reshape(
                (1, -1))
            if id2gt is not None:
                correct_cells = id2gt[qid]
                if correct_cells.sum() > 0:
                    avg_p = average_precision_score(
                        y_true=correct_cells.reshape(-1),
                        y_score=cell_preds.reshape(-1))
                    sums['auc'] += avg_p
                    counts['auc'] += 1
                    if qid in lookup_subset:
                        sums['auc (lookup)'] += avg_p
                        counts['auc (lookup)'] += 1
                    else:
                        sums['auc (aggregation)'] += avg_p
                        counts['auc (aggregation)'] += 1
                else:
                    no_answer_count += 1
            table_count += 1
            out.write(
                json.dumps({
                    'qid': qid,
                    'cells': cell_preds.tolist(),
                    'rows': row_preds.tolist(),
                    'cols': col_preds.tolist()
                }) + '\n')
            if cell_prediction_output is not None:
                cell_prediction_output.write(json.dumps({
                    'id': qid,
                    'cell_predictions': to_cell_predictions(cell_preds, top_k=20)
                }) + '\n')
    if cell_prediction_output is not None:
        cell_prediction_output.close()
    for n, v in sums.items():
        print(f'{n} = {v/counts[n]}')
    print(f'Over {table_count} tables')
    if id2gt is not None and no_answer_count > 0:
        print(f'{no_answer_count} tables with no correct answer')
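to_cell_predictions is not shown; given how it is called above, it presumably keeps only the top-scoring cells. A hedged sketch that returns the top_k [row, column, score] triples from the cell score matrix:

import numpy as np


def to_cell_predictions(cell_preds, *, top_k=20):
    # Hypothetical sketch: flatten the row x column score matrix and keep the
    # top_k highest scoring cells as [row, col, score] triples.
    flat = cell_preds.reshape(-1)
    top = np.argsort(-flat)[:top_k]
    n_cols = cell_preds.shape[1]
    return [[int(i // n_cols), int(i % n_cols), float(flat[i])] for i in top]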
Example #10
                else:
                    self.neg_count += 1
            if not is_pos:
                self.all_neg_count += 1
        return insts


class Options(Config):
    def __init__(self):
        super().__init__()
        self.input_dir = ''
        self.style = 'lookup'
        self.output_dir = ''


if __name__ == "__main__":
    opts = Options()
    fill_from_args(opts)
    for split in ['train', 'dev', 'test']:
        cols = ColumnConvert(opts)
        rows = RowConvert(opts)
        with write_open(os.path.join(opts.output_dir, split, 'row.jsonl.gz')) as rout, \
                write_open(os.path.join(opts.output_dir, split, 'col.jsonl.gz')) as cout:
            for line in jsonl_lines(
                    os.path.join(opts.input_dir,
                                 f'{split}_{opts.style}.jsonl.gz')):
                for r in rows.convert(line):
                    rout.write(json.dumps(r.to_dict()) + '\n')
                for c in cols.convert(line):
                    cout.write(json.dumps(c.to_dict()) + '\n')
Example #11
def convert_queries(query_tsv, table_dir, out_dir):
    # NOTE: there are some tables with a large number of rows,
    # when creating instances, we limit the total number of rows in the table
    per_question_row_limit = 50
    dev_percent = 20
    test_percent = 20

    rand = random.Random(1234)
    tid2rows = dict()
    row_search_out = write_open(os.path.join(out_dir, 'row_pseudo_docs.jsonl'))
    print('Table\tRow Count\tHeader')
    for subdir in ['auto', 'monarch', 'regents']:
        subdir_full = os.path.join(table_dir, subdir)
        for file in os.listdir(subdir_full):
            assert file[-4:] == '.tsv'
            table_id = subdir + '-' + file[:-4]
            all_rows = []
            with read_open(os.path.join(subdir_full, file)) as csvfile:
                for rndx, parts in enumerate(
                        csv.reader(csvfile, doublequote=False,
                                   delimiter='\t')):
                    row = [c.strip() for c in parts]
                    all_rows.append(row)
                    if rndx > 0:
                        pdoc = dict()
                        # pdoc['contents'] = table_id + ' ' + ' | '.join([h+' : '+c for h, c in zip(all_rows[0], row)])
                        pdoc['contents'] = table_id + ' ' + ' | '.join(row)
                        pdoc['id'] = table_id + ':' + str(rndx - 1)
                        row_search_out.write(json.dumps(pdoc) + '\n')
            tid2rows[table_id] = all_rows
            print(f'{table_id}\t{len(all_rows)-1}\t{all_rows[0]}')
    row_search_out.close()  # close the pseudo-doc file once all tables are written
    print('\n\n')

    split_files = []
    for split in [
            'dev_lookup.jsonl.gz', 'test_lookup.jsonl.gz',
            'train_lookup.jsonl.gz'
    ]:
        split_files.append(write_open(os.path.join(out_dir, split)))
    question_over_row_limit_count = 0
    with read_open(query_tsv) as csvfile:
        for ndx, parts in enumerate(
                csv.reader(csvfile, doublequote=False, delimiter='\t')):
            if ndx == 0:
                # QUESTION        QUESTION-ALIGNMENT      CHOICE 1        CHOICE 2        CHOICE 3        CHOICE 4
                # CORRECT CHOICE  RELEVANT TABLE  RELEVANT ROW    RELEVANT COL
                continue
            if len(parts) != 10:
                print(f'bad line: {parts}')
                exit(1)
            qid = f'q{ndx}'
            qtext = parts[0].strip()
            # the part of table[relevant_row] used to construct the question
            question_alignment = [int(c.strip()) for c in parts[1].split(',')]
            choices = [c.strip() for c in parts[2:6]]
            answer = choices[int(parts[6]) - 1]
            table_id = parts[7]
            target_row = int(parts[8]) - 1  # -1 for header
            target_column = int(parts[9])
            all_rows = tid2rows[table_id]
            header = all_rows[0]
            rows = all_rows[1:]
            if target_column in question_alignment:
                question_alignment.remove(target_column)
            # t_ans = re.sub(r'\W+', '', rows[target_row][target_column].lower())
            # q_ans = re.sub(r'\W+', '', answer.lower())
            # if t_ans not in q_ans and q_ans not in t_ans:
            #    print(f'{answer} != {rows[target_row][target_column]} in ({table_id}, {target_row}, {target_column})')
            answer = rows[target_row][target_column]
            if 0 < per_question_row_limit < len(rows):
                pos_row = rows[target_row]
                neg_rows = rows[:target_row] + rows[target_row + 1:]
                rand.shuffle(neg_rows)
                rows = neg_rows[:per_question_row_limit]
                target_row = rand.randint(0, len(rows) - 1)
                rows[target_row] = pos_row
                question_over_row_limit_count += 1
            jobj = dict()
            jobj['id'] = qid
            jobj['question'] = qtext
            jobj['header'] = header
            jobj['rows'] = rows
            jobj['target_column'] = target_column
            jobj['answers'] = [answer]
            jobj['table_id'] = table_id
            # extra
            jobj['target_row'] = target_row
            jobj['choices'] = choices  # but note that these could fail to match any cell...
            jobj['condition_columns'] = question_alignment
            line = json.dumps(jobj) + '\n'
            # CONSIDER: this split is bad - we really should split on table rather than just randomly
            if ndx % 100 < dev_percent:
                split_files[0].write(line)
            elif ndx % 100 < dev_percent + test_percent:
                split_files[1].write(line)
            else:
                split_files[2].write(line)
    for split_file in split_files:
        split_file.close()
    print(f'{question_over_row_limit_count} questions over row limit')