Example #1
def save_2index():
    """Build index mappings for the grammar (subgrammars and terminal
    tokens) and for the Spider schema, and save them as JSON files."""
    grammar = Grammar('sql_simple_transition_2.bnf')
    # A second Grammar instance shouldn't be necessary, but reusing the
    # first one for both subgrammar extractions below does not work.
    grammar2 = Grammar('sql_simple_transition_2.bnf')

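    # Map the <sellist> and <modifyop> subgrammar productions, plus the
    # grammar's terminal tokens, to integer indices and dump them as JSON.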
    sellist_gram = grammar.get_subgrammar('<sellist>')
    assert '<modifyop>' in grammar.gram_keys
    modifyopgram = grammar2.get_subgrammar('<modifyop>')
    dontuse = ['<start>']
    db_dict = []
    _, gram2ind = get_grammar2index(sellist_gram, db_dict, dontuse=dontuse)
    with open('grammar2index_grammar_2_sellist.json', 'w') as f:
        json.dump(gram2ind, f)
    _, gram2ind = get_grammar2index(modifyopgram, db_dict, dontuse=dontuse)
    with open('grammar2index_grammar_2_modifyop.json', 'w') as f:
        json.dump(gram2ind, f)
    _, w2ind = get_word2index(grammar, db_dict, dontuse=dontuse)
    with open('terminals2index_grammar_2.json', 'w') as f:
        json.dump(w2ind, f)
    with open('spider_tables_lowercase.json', 'r') as f:
        spider_db = json.load(f)
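    # Per-database index maps over the Spider schema:
    #   tab2ind:        db -> table -> index
    #   col2ind:        db -> table -> column -> index
    #   allcols_db2ind: db -> column -> index (over all columns in the db)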
    tab2ind = {}
    col2ind = {}
    allcols_db2ind = {}
    for db in spider_db.keys():
        dbp = db.lower()
        tab2ind[dbp] = {
            tab.lower(): i
            for i, tab in enumerate(spider_db[db].keys())
        }
        # NOTE: set iteration order is not deterministic across runs; sort
        # here if stable column indices are required.
        all_cols = list({
            c.lower()
            for t in spider_db[db]
            for c in spider_db[db][t]
        })
        allcols_db2ind[dbp] = {c.lower(): i for i, c in enumerate(all_cols)}
        col2ind[dbp] = {}
        for tab in spider_db[db].keys():
            tabp = tab.lower()
            col2ind[dbp][tabp] = {
                col.lower(): i
                for i, col in enumerate(spider_db[db][tab])
            }
    with open('spider_tab2index.json', 'w') as f:
        json.dump(tab2ind, f)
    with open('spider_col2index.json', 'w') as f:
        json.dump(col2ind, f)

    with open('spider_db_cols2ind.json', 'w') as f:
        json.dump(allcols_db2ind, f)
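A minimal sketch of reading back the schema index files written by save_2index, assuming the db -> table -> index and db -> table -> column -> index layout built above; the lookups are illustrative only.

import json

with open('spider_tab2index.json') as f:
    tab2ind = json.load(f)
with open('spider_col2index.json') as f:
    col2ind = json.load(f)

# Pick an arbitrary database, table and column and print their indices.
db = next(iter(tab2ind))
table, tab_idx = next(iter(tab2ind[db].items()))
col, col_idx = next(iter(col2ind[db][table].items()))
print(db, table, tab_idx, col, col_idx)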
Example #2
def resave_data(gfile='sql_simple_transition_2.bnf'):
    """Filter the train/test pickles against the grammar and overwrite
    them; a .bak copy of each pickle is written first."""
    cols = [0, 2]
    grm = Grammar(gfile)
    assert '<modifyop>' in grm.gr
    grm_terms = grm.terminal_toks
    #     grm_terms1 = [g.upper() for g in grm_terms if '[' not in g]
    #     grm_terms = [g for g in grm_terms if '[' in g]
    #     grm_terms.extend(grm_terms1)
    td = pd.read_pickle(trdf)
    te = pd.read_pickle(tedf)
    td.to_pickle(trdf + '.bak')
    te.to_pickle(tedf + '.bak')

    for i, c in enumerate(cols):
        # Both columns need a drop flag: the filtering below reads
        # 'drop_0' and 'drop_1', so neither iteration can be skipped.
        grm1 = Grammar(gfile)
        assert '<modifyop>' in grm1.gr
        max_ = TrainableRepresentation(grm1,
                                       x_dim=1,
                                       start=dnames[c],
                                       return_log=False).max_length
        td['drop_{}'.format(i)] = td[dnames[c]].apply(
            lambda x: check_(x, max_, grm_terms))
        te['drop_{}'.format(i)] = te[dnames[c]].apply(
            lambda x: check_(x, max_, grm_terms))
    print('Before shapes {}, {}'.format(td.shape, te.shape))
    td = td[~td['drop_0']]
    td = td[~td['drop_1']]
    te = te[~te['drop_0']]
    te = te[~te['drop_1']]
    del td['drop_0']
    del td['drop_1']
    del te['drop_0']
    del te['drop_1']
    print('After shapes {}, {}'.format(td.shape, te.shape))

    td.to_pickle(trdf)
    te.to_pickle(tedf)
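resave_data relies on a check_ helper that is not shown in these snippets. Below is a minimal sketch of what such a filter might do, assuming it flags a tokenised query that exceeds the representation's max_length or uses tokens outside the grammar's terminal set; the name check_sketch and the whitespace tokenisation are assumptions, not the project's actual implementation.

def check_sketch(x, max_len, grammar_terminals):
    # Return True (i.e. drop the row) if the query is too long for the
    # grammar representation or contains an out-of-grammar token.
    toks = x.split() if isinstance(x, str) else list(x)
    if len(toks) > max_len:
        return True
    return any(t not in grammar_terminals for t in toks)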
Example #3
    def __init__(self, fname, cuda=True):
        """Load the grammar from fname and build the per-node learners."""

        super().__init__()

        gr = Grammar(fname)
        self.gr = gr
        self.grammar = gr.graph
        self.terminals = gr.terminals
        self.terminal_toks = gr.terminal_toks
        self.ors = gr.ors
        # position of each OR node within self.ors
        self.or_loc = {l: i for i, l in enumerate(self.ors)}
        self.ands = gr.ands
        self.learners = self.create_learners(cuda=cuda)
Example #4
        #heng = torch.zeros(self.rnn_hid, 1, self.word_dim)
        #hsql = torch.zeros(self.rnn_hid, 1, self.embed_size)
        heng = torch.zeros(1, 1, self.rnn_hid)
        hsql = torch.zeros(1, 1, self.rnn_hid)

        #rets, _, inds, ccount = self.grammar_propagation(self.possibles, x, heng, hsql)

        return self.grammar_propagation(self.possibles, x, heng, hsql)

    # def get_tokens(self, probs):


if __name__ == '__main__':

    gram = Grammar('sql_simple_transition_2.bnf')
    # Assumption: spider_db is the lowercased Spider schema dict, loaded
    # from the same JSON file as in save_2index above.
    with open('spider_tables_lowercase.json', 'r') as f:
        spider_db = json.load(f)
    st = '<modifyop>'
    rsl = RecursiveDecoder(gram,
                           spider_db,
                           hidden_size=256,
                           word_dim=1024,
                           rnn_hid=256,
                           start=st)
    x = torch.randn(1024, 1, 11)
    _, _, _, toks, ccount = rsl.forward(x)
    print(ccount)
    print(toks)
    print(rsl.get_sql(toks))

    # if os.path.exists('model/RecNN.pkl'):
    #     model = torch.load('model/RecNN.pkl').to(device)