def is_valid_example(e):
    """Return True when an annotated example passes every sanity check.

    The example is rejected (False) when:
      * any entry of ``e['table']['header']`` has an empty ``'words'`` value;
      * two headers detokenize to the same lowercased string;
      * a word of ``e['seq_output']`` is absent from ``e['seq_input']``;
      * a word of a condition value in ``e['query']['conds']`` is absent
        from ``e['question']``.

    The two vocabulary failures also print a diagnostic naming the
    offending word and the word list it was checked against.
    """
    header_tokens = e['table']['header']

    # Every header must carry at least one word.
    header_words = [h['words'] for h in header_tokens]
    if not all(header_words):
        return False

    # Header names must be unique once detokenized and lowercased.
    lowered = [detokenize(h).lower() for h in header_tokens]
    if len(set(lowered)) != len(lowered):
        return False

    # Each output word has to be present among the input words.
    seq_input_words = e['seq_input']['words']
    known = set(seq_input_words)
    for word in e['seq_output']['words']:
        if word in known:
            continue
        print('query word "{}" is not in input vocabulary.\n{}'.format(word, seq_input_words))
        return False

    # Each condition-value word has to be present in the question.
    question_words = e['question']['words']
    known = set(question_words)
    for _col, _op, cond in e['query']['conds']:
        for word in cond['words']:
            if word in known:
                continue
            print('cond word "{}" is not in input vocabulary.\n{}'.format(word, question_words))
            return False

    return True
def annotate_example(example, table):
    """Build the annotated record for one example and its table.

    The returned dict holds, in this order: ``table_id``, the annotated
    ``question``, the ``table`` (annotated headers), a deep copy of the
    example's ``sql`` under ``query`` with every condition value replaced
    by its annotation, and the annotated ``seq_input``, ``seq_output`` and
    ``where_output`` strings assembled from SYM-prefixed marker tokens.

    Args:
        example: dict providing ``table_id``, ``question`` and ``sql``.
        table: dict providing ``header``.

    Returns:
        The annotation dict described above.
    """
    ann = {'table_id': example['table_id']}
    ann['question'] = annotate(example['question'])
    ann['table'] = {'header': [annotate(h) for h in table['header']]}

    # Work on a copy so the caller's sql dict keeps its raw condition values.
    sql = copy.deepcopy(example['sql'])
    ann['query'] = sql
    for cond in sql['conds']:
        cond[-1] = annotate(str(cond[-1]))

    select_clause = 'SYMSELECT SYMAGG {} SYMCOL {}'.format(
        Query.agg_ops[sql['agg']], table['header'][sql['sel']])

    cond_clauses = []
    for col, op, cond in sql['conds']:
        cond_clauses.append('SYMCOL {} SYMOP {} SYMCOND {}'.format(
            table['header'][col], Query.cond_ops[op], detokenize(cond)))

    if cond_clauses:
        where_clause = 'SYMWHERE ' + ' SYMAND '.join(cond_clauses) + ' SYMEND'
    else:
        where_clause = 'SYMEND'

    # Pieces of the input sequence, computed in the order they are consumed.
    syms_text = ' '.join(['SYM' + s for s in Query.syms])
    table_text = ' '.join(['SYMCOL ' + s for s in table['header']])
    question_text = example['question']
    aggops_text = ' '.join(Query.agg_ops)
    condops_text = ' '.join(Query.cond_ops)
    template = 'SYMSYMS {syms} SYMAGGOPS {aggops} SYMCONDOPS {condops} SYMTABLE {table} SYMQUESTION {question} SYMEND'
    inp = template.format(
        syms=syms_text,
        table=table_text,
        question=question_text,
        aggops=aggops_text,
        condops=condops_text,
    )
    ann['seq_input'] = annotate(inp)

    if where_clause:
        out = '{q1} {q2}'.format(q1=select_clause, q2=where_clause)
    else:
        out = select_clause
    ann['seq_output'] = annotate(out)
    ann['where_output'] = annotate(where_clause)

    # Both generated sequences must contain the end marker.
    assert 'symend' in ann['seq_output']['words']
    assert 'symend' in ann['where_output']['words']
    return ann
def from_tokenized_dict(cls, d):
    """Construct an instance from a dict whose condition values are tokenized.

    Each ``(col, op, val)`` entry of ``d['conds']`` is rebuilt as
    ``[col, op, detokenize(val)]``; the result is passed to ``cls`` together
    with ``d['sel']`` and ``d['agg']``.
    """
    detokenized = [[col, op, detokenize(val)] for col, op, val in d['conds']]
    return cls(d['sel'], d['agg'], detokenized)
def from_generated_dict(cls, d):
    """Construct an instance from a generated dict with tokenized condition values.

    Each ``(col, op, val)`` entry of ``d['conds']`` is rebuilt as
    ``[col, op, detokenize(val)]``; the result is passed to ``cls`` together
    with ``d['sel']`` and ``d['agg']``.

    The previously computed but never used ``len(val['words'])`` has been
    dropped; it had no effect on the returned object.
    """
    conds = [[col, op, detokenize(val)] for col, op, val in d['conds']]
    return cls(d['sel'], d['agg'], conds)
def from_partial_sequence(cls, agg_col, agg_op, sequence, table, lowercase=True):
    """Construct an instance whose conditions are parsed from a token sequence.

    ``sequence`` is a dict of three parallel lists, ``'words'``, ``'gloss'``
    and ``'after'``.  Anything from the first ``'symend'`` word onward is
    ignored.  The tokens after the first ``'symwhere'`` are read as repeated
    groups of the form::

        symcol COLUMN-TOKENS symop OPERATOR symcond VALUE-TOKENS

    separated by ``'symand'``.  If there is no ``'symwhere'`` the condition
    list is empty.  The caller's ``sequence`` is not modified.

    Args:
        cls: class to instantiate; ``cls.cond_ops`` is searched for the
            upper-cased operator token.
        agg_col: passed through as the first constructor argument.
        agg_op: passed through as the second constructor argument.
        sequence: token dict as described above.
        table: dict whose ``'header'`` entries are detokenized and matched
            (ignoring whitespace) against the parsed column tokens.
        lowercase: when true, headers and all token fields are lowercased
            before matching.

    Returns:
        ``cls(agg_col, agg_op, conditions)`` where each condition is
        ``[column_index, operator_index, value_string]``.

    Raises:
        Exception: if a group does not start with ``'symcol'``, lacks
            ``'symop'`` or the operator token after it, names an operator
            not in ``cls.cond_ops``, names a column not in the table, or
            lacks ``'symcond'``.
    """
    sequence = deepcopy(sequence)
    if 'symend' in sequence['words']:
        # Truncate every parallel list at the end marker.
        end = sequence['words'].index('symend')
        for k, v in sequence.items():
            sequence[k] = v[:end]

    terms = [
        {'gloss': g, 'word': w, 'after': a}
        for g, w, a in zip(sequence['gloss'], sequence['words'], sequence['after'])
    ]
    headers = [detokenize(h) for h in table['header']]

    # Lowercase headers and every token field so matching is case-insensitive.
    if lowercase:
        headers = [h.lower() for h in headers]
        terms = [{k: v.lower() for k, v in t.items()} for t in terms]

    headers_no_whitespace = [re.sub(re_whitespace, '', h) for h in headers]

    def find_column(name):
        # Compare with whitespace stripped on both sides.
        return headers_no_whitespace.index(re.sub(re_whitespace, '', name))

    def flatten(tokens):
        # Turn a list of per-token dicts back into parallel lists.
        return {
            'words': [t['word'] for t in tokens],
            'after': [t['after'] for t in tokens],
            'gloss': [t['gloss'] for t in tokens],
        }

    # Index of the first 'symwhere'; past the end when there is none.
    where_index = next(
        (i for i, t in enumerate(terms) if t['word'] == 'symwhere'),
        len(terms))
    where_terms = terms[where_index + 1:]

    # Parse one condition per iteration.
    conditions = []
    while where_terms:
        t = where_terms.pop(0)
        flat = flatten(where_terms)
        if t['word'] != 'symcol':
            raise Exception('Missing conditional column {}'.format(
                flat['words']))
        try:
            op_index = flat['words'].index('symop')
            col_tokens = flatten(where_terms[:op_index])
            # Looked up inside the guard so that a sequence ending right
            # after 'symop' is reported like any other malformed input
            # instead of escaping as a bare IndexError.
            cond_op = where_terms[op_index + 1]['word']
        except Exception:
            raise Exception('Missing conditional operator {}'.format(
                flat['words']))
        try:
            cond_op = cls.cond_ops.index(cond_op.upper())
        except Exception:
            raise Exception('Invalid cond op {}'.format(cond_op))
        try:
            cond_col = find_column(detokenize(col_tokens))
        except Exception:
            raise Exception('Cannot find conditional column {}'.format(
                col_tokens['words']))
        try:
            val_index = flat['words'].index('symcond')
        except Exception:
            raise Exception('Cannot find conditional value {}'.format(
                flat['words']))

        # The value runs up to the next 'symand' or the end of the tokens.
        where_terms = where_terms[val_index + 1:]
        flat = flatten(where_terms)
        if 'symand' in flat['words']:
            val_end_index = flat['words'].index('symand')
        else:
            val_end_index = len(where_terms)
        cond_val = detokenize(flatten(where_terms[:val_end_index]))
        conditions.append([cond_col, cond_op, cond_val])
        where_terms = where_terms[val_end_index + 1:]

    return cls(agg_col, agg_op, conditions)