Example #1
    def __init__(self, word_emb, N_word, N_h=300, N_depth=2, 
                gpu=True, trainable_emb=False, 
                table_type="std", use_hs=True):
        super(SuperModel, self).__init__()
        self.gpu = gpu
        self.N_h = N_h
        self.N_depth = N_depth
        self.trainable_emb = trainable_emb
        self.table_type = table_type
        self.use_hs = use_hs
        self.SQL_TOK = ['<UNK>', '<END>', 'WHERE', 'AND', 'EQL', 'GT', 'LT', '<BEG>']

        # word embedding layer
        self.embed_layer = WordEmbedding(word_emb, N_word, gpu,
                self.SQL_TOK, trainable=trainable_emb)

        # initialize all modules
        self.multi_sql = MultiSqlPredictor(N_word=N_word, N_h=N_h, N_depth=N_depth, gpu=gpu, use_hs=use_hs)
        self.multi_sql.eval()

        self.key_word = KeyWordPredictor(N_word=N_word, N_h=N_h, N_depth=N_depth,
                                         gpu=gpu, use_hs=use_hs)
        self.key_word.eval()

        self.col = ColPredictor(N_word=N_word, N_h=N_h, N_depth=N_depth, gpu=gpu, use_hs=use_hs)
        self.col.eval()

        self.op = OpPredictor(N_word=N_word, N_h=N_h, N_depth=N_depth, gpu=gpu, use_hs=use_hs)
        self.op.eval()

        self.agg = AggPredictor(N_word=N_word, N_h=N_h, N_depth=N_depth, gpu=gpu, use_hs=use_hs)
        self.agg.eval()

        self.root_teminal = RootTeminalPredictor(N_word=N_word, N_h=N_h, N_depth=N_depth, gpu=gpu, use_hs=use_hs)
        self.root_teminal.eval()

        self.des_asc = DesAscLimitPredictor(N_word=N_word, N_h=N_h, N_depth=N_depth, gpu=gpu, use_hs=use_hs)
        self.des_asc.eval()

        self.having = HavingPredictor(N_word=N_word, N_h=N_h, N_depth=N_depth, gpu=gpu, use_hs=use_hs)
        self.having.eval()

        self.andor = AndOrPredictor(N_word=N_word, N_h=N_h, N_depth=N_depth, gpu=gpu, use_hs=use_hs)
        self.andor.eval()

        self.softmax = nn.Softmax(dim=1)
        self.CE = nn.CrossEntropyLoss()
        self.log_softmax = nn.LogSoftmax(dim=1)
        self.mlsml = nn.MultiLabelSoftMarginLoss()
        self.bce_logit = nn.BCEWithLogitsLoss()
        self.sigm = nn.Sigmoid()
        if gpu:
            self.cuda()
        self.path_not_found = 0
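
A minimal construction sketch for the constructor above, assuming word_emb maps tokens to N_word-dimensional vectors; the loader name and GloVe path are hypothetical stand-ins for the project's own utilities:

# Hypothetical setup -- substitute the project's actual embedding loader.
word_emb = load_word_emb('glove/glove.42B.300d.txt')  # token -> 300-d vectors
model = SuperModel(word_emb, N_word=300, N_h=300, N_depth=2,
                   gpu=False, trainable_emb=False)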
Example #2
                       map_location=map_to))

    elif args.train_component == "having":
        model = HavingPredictor(N_word=N_word,
                                N_h=N_h,
                                N_depth=N_depth,
                                gpu=GPU,
                                use_hs=use_hs)
        model.load_state_dict(
            torch.load("{}/having_models.dump".format(SAVED_MODELS_FOLDER),
                       map_location=map_to))

    elif args.train_component == "andor":
        model = AndOrPredictor(N_word=N_word,
                               N_h=N_h,
                               N_depth=N_depth,
                               gpu=GPU,
                               use_hs=use_hs)
        model.load_state_dict(
            torch.load("{}/andor_models.dump".format(SAVED_MODELS_FOLDER),
                       map_location=map_to))

    # model = SQLNet(word_emb, N_word=N_word, gpu=GPU, trainable_emb=args.train_emb)

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=learning_rate,
                                 weight_decay=0)
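    # Hedged sketch of the training step this optimizer usually feeds into;
    # the per-component loss call below is an assumption, not part of this snippet.
    #   optimizer.zero_grad()
    #   loss = model.loss(score, truth)   # hypothetical per-component loss
    #   loss.backward()
    #   optimizer.step()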
    print("finished building model")

    print_flag = False
    embed_layer = WordEmbedding(word_emb,
Example #3
class SuperModel(nn.Module):
    def __init__(self,
                 word_emb,
                 N_word,
                 N_h=300,
                 N_depth=2,
                 gpu=True,
                 trainable_emb=False,
                 table_type="std",
                 use_hs=True):
        super(SuperModel, self).__init__()
        self.gpu = gpu
        self.N_h = N_h
        self.N_depth = N_depth
        self.trainable_emb = trainable_emb
        self.table_type = table_type
        self.use_hs = use_hs
        self.SQL_TOK = [
            '<UNK>', '<END>', 'WHERE', 'AND', 'EQL', 'GT', 'LT', '<BEG>'
        ]

        # word embedding layer
        self.embed_layer = WordEmbedding(word_emb,
                                         N_word,
                                         gpu,
                                         self.SQL_TOK,
                                         trainable=trainable_emb)

        # initialize all modules
        self.multi_sql = MultiSqlPredictor(N_word=N_word,
                                           N_h=N_h,
                                           N_depth=N_depth,
                                           gpu=gpu,
                                           use_hs=use_hs)
        self.multi_sql.eval()

        self.key_word = KeyWordPredictor(N_word=N_word,
                                         N_h=N_h,
                                         N_depth=N_depth,
                                         gpu=gpu,
                                         use_hs=use_hs)
        self.key_word.eval()

        self.col = ColPredictor(N_word=N_word,
                                N_h=N_h,
                                N_depth=N_depth,
                                gpu=gpu,
                                use_hs=use_hs)
        self.col.eval()

        self.op = OpPredictor(N_word=N_word,
                              N_h=N_h,
                              N_depth=N_depth,
                              gpu=gpu,
                              use_hs=use_hs)
        self.op.eval()

        self.agg = AggPredictor(N_word=N_word,
                                N_h=N_h,
                                N_depth=N_depth,
                                gpu=gpu,
                                use_hs=use_hs)
        self.agg.eval()

        self.root_teminal = RootTeminalPredictor(N_word=N_word,
                                                 N_h=N_h,
                                                 N_depth=N_depth,
                                                 gpu=gpu,
                                                 use_hs=use_hs)
        self.root_teminal.eval()

        self.des_asc = DesAscLimitPredictor(N_word=N_word,
                                            N_h=N_h,
                                            N_depth=N_depth,
                                            gpu=gpu,
                                            use_hs=use_hs)
        self.des_asc.eval()

        self.having = HavingPredictor(N_word=N_word,
                                      N_h=N_h,
                                      N_depth=N_depth,
                                      gpu=gpu,
                                      use_hs=use_hs)
        self.having.eval()

        self.andor = AndOrPredictor(N_word=N_word,
                                    N_h=N_h,
                                    N_depth=N_depth,
                                    gpu=gpu,
                                    use_hs=use_hs)
        self.andor.eval()

        self.softmax = nn.Softmax(dim=1)
        self.CE = nn.CrossEntropyLoss()
        self.log_softmax = nn.LogSoftmax(dim=1)
        self.mlsml = nn.MultiLabelSoftMarginLoss()
        self.bce_logit = nn.BCEWithLogitsLoss()
        self.sigm = nn.Sigmoid()
        if gpu:
            self.cuda()
        self.path_not_found = 0

    def forward(self, q_seq, history, tables):
        # if self.part:
        #     return self.part_forward(q_seq,history,tables)
        # else:
        return self.full_forward(q_seq, history, tables)

    def full_forward(self, q_seq, history, tables):
        B = len(q_seq)
        # print("q_seq:{}".format(q_seq))
        # print("Batch size:{}".format(B))
        q_emb_var, q_len = self.embed_layer.gen_x_q_batch(q_seq)
        col_seq = to_batch_tables(tables, B, self.table_type)
        col_emb_var, col_name_len, col_len = self.embed_layer.gen_col_batch(
            col_seq)

        mkw_emb_var = self.embed_layer.gen_word_list_embedding(
            ["none", "except", "intersect", "union"], (B))
        mkw_len = np.full(q_len.shape, 4, dtype=np.int64)
        kw_emb_var = self.embed_layer.gen_word_list_embedding(
            ["where", "group by", "order by"], (B))
        kw_len = np.full(q_len.shape, 3, dtype=np.int64)

        stack = Stack()
        stack.push(("root", None))
        history = [["root"]] * B
        andor_cond = ""
        has_limit = False
        # sql = {}
        current_sql = {}
        sql_stack = []
        idx_stack = []
        kw_stack = []
        kw = ""
        nested_label = ""
        has_having = False

        # timer to prevent infinite recursion in SQL generation
        timeout = time.time() + 2
        failed = False
        while not stack.isEmpty():
            if time.time() > timeout:
                failed = True
                break
            vet = stack.pop()
            # print(vet)
            hs_emb_var, hs_len = self.embed_layer.gen_x_history_batch(history)
            if len(idx_stack) > 0 and stack.size() < idx_stack[-1]:
                # print("pop!!!!!!!!!!!!!!!!!!!!!!")
                idx_stack.pop()
                current_sql = sql_stack.pop()
                kw = kw_stack.pop()
                # current_sql = current_sql["sql"]
            # history.append(vet)
            # print("hs_emb:{} hs_len:{}".format(hs_emb_var.size(),hs_len.size()))
            if isinstance(vet, tuple) and vet[0] == "root":
                if history[0][-1] != "root":
                    history[0].append("root")
                    hs_emb_var, hs_len = self.embed_layer.gen_x_history_batch(
                        history)
                if vet[1] != "original":
                    idx_stack.append(stack.size())
                    sql_stack.append(current_sql)
                    kw_stack.append(kw)
                else:
                    idx_stack.append(stack.size())
                    sql_stack.append(sql_stack[-1])
                    kw_stack.append(kw)
                if "sql" in current_sql:
                    current_sql["nested_sql"] = {}
                    current_sql["nested_label"] = nested_label
                    current_sql = current_sql["nested_sql"]
                elif isinstance(vet[1], dict):
                    vet[1]["sql"] = {}
                    current_sql = vet[1]["sql"]
                elif vet[1] != "original":
                    current_sql["sql"] = {}
                    current_sql = current_sql["sql"]
                # print("q_emb_var:{} hs_emb_var:{} mkw_emb_var:{}".format(q_emb_var.size(),hs_emb_var.size(),mkw_emb_var.size()))
                if vet[1] == "nested" or vet[1] == "original":
                    stack.push("none")
                    history[0].append("none")
                else:
                    score = self.multi_sql.forward(q_emb_var, q_len,
                                                   hs_emb_var, hs_len,
                                                   mkw_emb_var, mkw_len)
                    label = np.argmax(score[0].data.cpu().numpy())
                    label = SQL_OPS[label]
                    history[0].append(label)
                    stack.push(label)
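                # NOTE: when vet[1] is "nested"/"original", `label` still holds
                # its value from a previous iteration; the first frame popped is
                # ("root", None), which takes the else branch and binds it.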
                if label != "none":
                    nested_label = label

            elif vet in ('intersect', 'except', 'union'):
                stack.push(("root", "nested"))
                stack.push(("root", "original"))
                # history[0].append("root")
            elif vet == "none":
                score = self.key_word.forward(q_emb_var, q_len, hs_emb_var,
                                              hs_len, kw_emb_var, kw_len)
                kw_num_score, kw_score = [x.data.cpu().numpy() for x in score]
                # print("kw_num_score:{}".format(kw_num_score))
                # print("kw_score:{}".format(kw_score))
                num_kw = np.argmax(kw_num_score[0])
                kw_score = list(np.argsort(-kw_score[0])[:num_kw])
                kw_score.sort(reverse=True)
                # print("num_kw:{}".format(num_kw))
                for kw in kw_score:
                    stack.push(KW_OPS[kw])
                stack.push("select")
            elif vet in ("select", "orderBy", "where", "groupBy", "having"):
                kw = vet
                current_sql[kw] = []
                history[0].append(vet)
                stack.push(("col", vet))
                # score = self.andor.forward(q_emb_var,q_len,hs_emb_var,hs_len)
                # label = score[0].data.cpu().numpy()
                # andor_cond = COND_OPS[label]
                # history.append("")
            # elif vet == "groupBy":
            #     score = self.having.forward(q_emb_var,q_len,hs_emb_var,hs_len,col_emb_var,col_len,)
            elif isinstance(vet, tuple) and vet[0] == "col":
                # print("q_emb_var:{} hs_emb_var:{} col_emb_var:{}".format(q_emb_var.size(), hs_emb_var.size(),col_emb_var.size()))
                score = self.col.forward(q_emb_var, q_len, hs_emb_var, hs_len,
                                         col_emb_var, col_len, col_name_len)
                col_num_score, col_score = [
                    x.data.cpu().numpy() for x in score
                ]
                col_num = np.argmax(col_num_score[0]) + 1  # double check
                cols = np.argsort(-col_score[0])[:col_num]
                # print(col_num)
                # print("col_num_score:{}".format(col_num_score))
                # print("col_score:{}".format(col_score))
                for col in cols:
                    if vet[1] == "where":
                        stack.push(("op", "where", col))
                    elif vet[1] == "groupBy":
                        history[0].append(index_to_column_name(col, tables))
                        current_sql[kw].append(
                            index_to_column_name(col, tables))
                    else:
                        stack.push(("agg", vet[1], col))
                # predict "and"/"or" when the where clause has multiple columns
                if col_num > 1 and vet[1] == "where":
                    score = self.andor.forward(q_emb_var, q_len, hs_emb_var,
                                               hs_len)
                    label = np.argmax(score[0].data.cpu().numpy())
                    andor_cond = COND_OPS[label]
                    current_sql[kw].append(andor_cond)
                if vet[1] == "groupBy" and col_num > 0:
                    score = self.having.forward(
                        q_emb_var, q_len, hs_emb_var, hs_len, col_emb_var,
                        col_len, col_name_len,
                        np.full(B, cols[0], dtype=np.int64))
                    label = np.argmax(score[0].data.cpu().numpy())
                    if label == 1:
                        has_having = True
                        # stack.insert(-col_num,"having")
                        stack.push("having")
                # history.append(index_to_column_name(cols[-1], tables[0]))
            elif isinstance(vet, tuple) and vet[0] == "agg":
                history[0].append(index_to_column_name(vet[2], tables))
                if vet[1] not in ("having", "orderBy"):  #DEBUG-ed 20180817
                    try:
                        current_sql[kw].append(
                            index_to_column_name(vet[2], tables))
                    except Exception as e:
                        # print(e)
                        traceback.print_exc()
                        print("history:{},current_sql:{} stack:{}".format(
                            history[0], current_sql, stack.items))
                        print("idx_stack:{}".format(idx_stack))
                        print("sql_stack:{}".format(sql_stack))
                        exit(1)
                hs_emb_var, hs_len = self.embed_layer.gen_x_history_batch(
                    history)

                score = self.agg.forward(q_emb_var, q_len, hs_emb_var, hs_len,
                                         col_emb_var, col_len, col_name_len,
                                         np.full(B, vet[2], dtype=np.int64))
                agg_num_score, agg_score = [
                    x.data.cpu().numpy() for x in score
                ]
                agg_num = np.argmax(agg_num_score[0])  # double check
                agg_idxs = np.argsort(-agg_score[0])[:agg_num]
                # print("agg:{}".format([AGG_OPS[agg] for agg in agg_idxs]))
                if len(agg_idxs) > 0:
                    history[0].append(AGG_OPS[agg_idxs[0]])
                    if vet[1] not in ("having", "orderBy"):
                        current_sql[kw].append(AGG_OPS[agg_idxs[0]])
                    elif vet[1] == "orderBy":
                        stack.push(("des_asc", vet[2],
                                    AGG_OPS[agg_idxs[0]]))  #DEBUG-ed 20180817
                    else:
                        stack.push(
                            ("op", "having", vet[2], AGG_OPS[agg_idxs[0]]))
                for agg in agg_idxs[1:]:
                    history[0].append(index_to_column_name(vet[2], tables))
                    history[0].append(AGG_OPS[agg])
                    if vet[1] not in ("having", "orderBy"):
                        current_sql[kw].append(
                            index_to_column_name(vet[2], tables))
                        current_sql[kw].append(AGG_OPS[agg])
                    elif vet[1] == "orderBy":
                        stack.push(("des_asc", vet[2], AGG_OPS[agg]))
                    else:
                        stack.push(
                            ("op", "having", vet[2], AGG_OPS[agg]))
                if len(agg_idxs) == 0:
                    if vet[1] not in ("having", "orderBy"):
                        current_sql[kw].append("none_agg")
                    elif vet[1] == "orderBy":
                        stack.push(("des_asc", vet[2], "none_agg"))
                    else:
                        stack.push(("op", "having", vet[2], "none_agg"))
                # current_sql[kw].append([AGG_OPS[agg] for agg in agg_idxs])
                # if vet[1] == "having":
                #     stack.push(("op","having",vet[2],agg_idxs))
                # if vet[1] == "orderBy":
                #     stack.push(("des_asc",vet[2],agg_idxs))
                # if vet[1] == "groupBy" and has_having:
                #     stack.push("having")
            elif isinstance(vet, tuple) and vet[0] == "op":
                if vet[1] == "where":
                    # current_sql[kw].append(index_to_column_name(vet[2], tables))
                    history[0].append(index_to_column_name(vet[2], tables))
                    hs_emb_var, hs_len = self.embed_layer.gen_x_history_batch(
                        history)

                score = self.op.forward(q_emb_var, q_len, hs_emb_var, hs_len,
                                        col_emb_var, col_len, col_name_len,
                                        np.full(B, vet[2], dtype=np.int64))

                op_num_score, op_score = [x.data.cpu().numpy() for x in score]
                # num_score 0 maps to 1 in truth; there must be at least one op
                op_num = np.argmax(op_num_score[0]) + 1
                ops = np.argsort(-op_score[0])[:op_num]
                # current_sql[kw].append([NEW_WHERE_OPS[op] for op in ops])
                if op_num > 0:
                    history[0].append(NEW_WHERE_OPS[ops[0]])
                    if vet[1] == "having":
                        stack.push(("root_teminal", vet[2], vet[3], ops[0]))
                    else:
                        stack.push(("root_teminal", vet[2], ops[0]))
                    # current_sql[kw].append(NEW_WHERE_OPS[ops[0]])
                for op in ops[1:]:
                    history[0].append(index_to_column_name(vet[2], tables))
                    history[0].append(NEW_WHERE_OPS[op])
                    # current_sql[kw].append(index_to_column_name(vet[2], tables))
                    # current_sql[kw].append(NEW_WHERE_OPS[op])
                    if vet[1] == "having":
                        stack.push(("root_teminal", vet[2], vet[3], op))
                    else:
                        stack.push(("root_teminal", vet[2], op))
                # stack.push(("root_teminal",vet[2]))
            elif isinstance(vet, tuple) and vet[0] == "root_teminal":
                score = self.root_teminal.forward(
                    q_emb_var, q_len, hs_emb_var, hs_len, col_emb_var, col_len,
                    col_name_len, np.full(B, vet[1], dtype=np.int64))

                label = np.argmax(score[0].data.cpu().numpy())
                label = ROOT_TERM_OPS[label]
                if len(vet) == 4:
                    current_sql[kw].append(index_to_column_name(
                        vet[1], tables))
                    current_sql[kw].append(vet[2])
                    current_sql[kw].append(NEW_WHERE_OPS[vet[3]])
                else:
                    # print("kw:{}".format(kw))
                    try:
                        current_sql[kw].append(
                            index_to_column_name(vet[1], tables))
                    except Exception as e:
                        # print(e)
                        traceback.print_exc()
                        print("history:{},current_sql:{} stack:{}".format(
                            history[0], current_sql, stack.items))
                        print("idx_stack:{}".format(idx_stack))
                        print("sql_stack:{}".format(sql_stack))
                        exit(1)
                    current_sql[kw].append(NEW_WHERE_OPS[vet[2]])
                if label == "root":
                    history[0].append("root")
                    current_sql[kw].append({})
                    # current_sql = current_sql[kw][-1]
                    stack.push(("root", current_sql[kw][-1]))
                else:
                    current_sql[kw].append("terminal")
            elif isinstance(vet, tuple) and vet[0] == "des_asc":
                current_sql[kw].append(index_to_column_name(vet[1], tables))
                current_sql[kw].append(vet[2])
                score = self.des_asc.forward(
                    q_emb_var, q_len, hs_emb_var, hs_len, col_emb_var, col_len,
                    col_name_len, np.full(B, vet[1], dtype=np.int64))
                label = np.argmax(score[0].data.cpu().numpy())
                dec_asc, has_limit = DEC_ASC_OPS[label]
                history[0].append(dec_asc)
                current_sql[kw].append(dec_asc)
                current_sql[kw].append(has_limit)
        # print("{}".format(current_sql))

        if failed:
            return None
        print("history:{}".format(history[0]))
        if len(sql_stack) > 0:
            current_sql = sql_stack[0]
        # print("{}".format(current_sql))
        return current_sql

    def gen_col(self, col, table, table_alias_dict):
        colname = table["column_names_original"][col[2]][1]
        table_idx = table["column_names_original"][col[2]][0]
        if table_idx not in table_alias_dict:
            return colname
        return "T{}.{}".format(table_alias_dict[table_idx], colname)

    def gen_group_by(self, sql, kw, table, table_alias_dict):
        ret = []
        for i in range(0, len(sql)):
            # if len(sql[i+1]) == 0:
            # if sql[i+1] == "none_agg":
            ret.append(self.gen_col(sql[i], table, table_alias_dict))
            # else:
            #     ret.append("{}({})".format(sql[i+1], self.gen_col(sql[i], table, table_alias_dict)))
            # for agg in sql[i+1]:
            #     ret.append("{}({})".format(agg,gen_col(sql[i],table,table_alias_dict)))
        return "{} {}".format(kw, ",".join(ret))

    def gen_select(self, sql, kw, table, table_alias_dict):
        ret = []
        for i in range(0, len(sql), 2):
            # if len(sql[i+1]) == 0:
            if sql[i + 1] == "none_agg" or not isinstance(
                    sql[i + 1], basestring):  #DEBUG-ed 20180817
                ret.append(self.gen_col(sql[i], table, table_alias_dict))
            else:
                ret.append("{}({})".format(
                    sql[i + 1], self.gen_col(sql[i], table, table_alias_dict)))
            # for agg in sql[i+1]:
            #     ret.append("{}({})".format(agg,gen_col(sql[i],table,table_alias_dict)))
        return "{} {}".format(kw, ",".join(ret))

    def gen_where(self, sql, table, table_alias_dict):
        if len(sql) == 0:
            return ""
        start_idx = 0
        andor = "and"
        if isinstance(sql[0], str):  # str was basestring under Python 2
            start_idx += 1
            andor = sql[0]
        ret = []
        for i in range(start_idx, len(sql), 3):
            col = self.gen_col(sql[i], table, table_alias_dict)
            op = sql[i + 1]
            val = sql[i + 2]
            where_item = ""
            if val == "terminal":
                where_item = "{} {} '{}'".format(col, op, val)
            else:
                val = self.gen_sql(val, table)
                where_item = "{} {} ({})".format(col, op, val)
            if op == "between":
                #TODO temprarily fixed
                where_item += " and 'terminal'"
            ret.append(where_item)
        return "where {}".format(" {} ".format(andor).join(ret))

    def gen_orderby(self, sql, table, table_alias_dict):
        ret = []
        limit = ""
        if sql[-1] is True:
            limit = "limit 1"
        for i in range(0, len(sql), 4):
            if sql[i + 1] == "none_agg" or not isinstance(
                    sql[i + 1], basestring):  #DEBUG-ed 20180817
                ret.append("{} {}".format(
                    self.gen_col(sql[i], table, table_alias_dict), sql[i + 2]))
            else:
                ret.append("{}({}) {}".format(
                    sql[i + 1], self.gen_col(sql[i], table, table_alias_dict),
                    sql[i + 2]))
        return "order by {} {}".format(",".join(ret), limit)

    def gen_having(self, sql, table, table_alias_dict):
        ret = []
        for i in range(0, len(sql), 4):
            if sql[i + 1] == "none_agg":
                col = self.gen_col(sql[i], table, table_alias_dict)
            else:
                col = "{}({})".format(
                    sql[i + 1], self.gen_col(sql[i], table, table_alias_dict))
            op = sql[i + 2]
            val = sql[i + 3]
            if val == "terminal":
                ret.append("{} {} '{}'".format(col, op, val))
            else:
                val = self.gen_sql(val, table)
                ret.append("{} {} ({})".format(col, op, val))
        return "having {}".format(",".join(ret))

    def find_shortest_path(self, start, end, graph):
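        # NOTE: the LIFO stack makes this a depth-first walk of the foreign-key
        # graph, so the returned join path is valid but not guaranteed shortest;
        # a FIFO queue (BFS) would guarantee minimal length.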
        stack = [[start, []]]
        visited = set()
        while len(stack) > 0:
            ele, history = stack.pop()
            if ele == end:
                return history
            for node in graph[ele]:
                if node[0] not in visited:
                    stack.append((node[0], history + [(node[0], node[1])]))
                    visited.add(node[0])
        print("table {} table {}".format(start, end))
        # print("could not find path!!!!!{}".format(self.path_not_found))
        self.path_not_found += 1
        # return []

    def gen_from(self, candidate_tables, table):
        def find(d, col):
            if d[col] == -1:
                return col
            return find(d, d[col])

        def union(d, c1, c2):
            r1 = find(d, c1)
            r2 = find(d, c2)
            if r1 == r2:
                return
            d[r1] = r2

        ret = ""
        if len(candidate_tables) <= 1:
            if len(candidate_tables) == 1:
                ret = "from {}".format(
                    table["table_names_original"][list(candidate_tables)[0]])
            else:
                ret = "from {}".format(table["table_names_original"][0])
            # TODO: temporary setting
            return {}, ret
        # print("candidate:{}".format(candidate_tables))
        table_alias_dict = {}
        uf_dict = {}
        for t in candidate_tables:
            uf_dict[t] = -1
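        # uf_dict and the find/union helpers are only referenced by the
        # commented-out union-find join strategy below; the live code path
        # joins tables along foreign-key graph paths instead.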
        idx = 1
        graph = defaultdict(list)
        for acol, bcol in table["foreign_keys"]:
            t1 = table["column_names"][acol][0]
            t2 = table["column_names"][bcol][0]
            graph[t1].append((t2, (acol, bcol)))
            graph[t2].append((t1, (bcol, acol)))
            # if t1 in candidate_tables and t2 in candidate_tables:
            #     r1 = find(uf_dict,t1)
            #     r2 = find(uf_dict,t2)
            #     if r1 == r2:
            #         continue
            #     union(uf_dict,t1,t2)
            #     if len(ret) == 0:
            #         ret = "from {} as T{} join {} as T{} on T{}.{}=T{}.{}".format(table["table_names"][t1],idx,table["table_names"][t2],
            #                                                                       idx+1,idx,table["column_names_original"][acol][1],idx+1,
            #                                                                       table["column_names_original"][bcol][1])
            #         table_alias_dict[t1] = idx
            #         table_alias_dict[t2] = idx+1
            #         idx += 2
            #     else:
            #         if t1 in table_alias_dict:
            #             old_t = t1
            #             new_t = t2
            #             acol,bcol = bcol,acol
            #         elif t2 in table_alias_dict:
            #             old_t = t2
            #             new_t = t1
            #         else:
            #             ret = "{} join {} as T{} join {} as T{} on T{}.{}=T{}.{}".format(ret,table["table_names"][t1], idx,
            #                                                                           table["table_names"][t2],
            #                                                                           idx + 1, idx,
            #                                                                           table["column_names_original"][acol][1],
            #                                                                           idx + 1,
            #                                                                           table["column_names_original"][bcol][1])
            #             table_alias_dict[t1] = idx
            #             table_alias_dict[t2] = idx + 1
            #             idx += 2
            #             continue
            #         ret = "{} join {} as T{} on T{}.{}=T{}.{}".format(ret,new_t,idx,idx,table["column_names_original"][acol][1],
            #                                                        table_alias_dict[old_t],table["column_names_original"][bcol][1])
            #         table_alias_dict[new_t] = idx
            #         idx += 1
        # visited = set()
        candidate_tables = list(candidate_tables)
        start = candidate_tables[0]
        table_alias_dict[start] = idx
        idx += 1
        ret = "from {} as T1".format(table["table_names_original"][start])
        try:
            for end in candidate_tables[1:]:
                if end in table_alias_dict:
                    continue
                path = self.find_shortest_path(start, end, graph)
                prev_table = start
                if not path:
                    table_alias_dict[end] = idx
                    idx += 1
                    ret = "{} join {} as T{}".format(
                        ret,
                        table["table_names_original"][end],
                        table_alias_dict[end],
                    )
                    continue
                for node, (acol, bcol) in path:
                    if node in table_alias_dict:
                        prev_table = node
                        continue
                    table_alias_dict[node] = idx
                    idx += 1
                    ret = "{} join {} as T{} on T{}.{} = T{}.{}".format(
                        ret, table["table_names_original"][node],
                        table_alias_dict[node], table_alias_dict[prev_table],
                        table["column_names_original"][acol][1],
                        table_alias_dict[node],
                        table["column_names_original"][bcol][1])
                    prev_table = node
        except Exception:
            traceback.print_exc()
            print("db:{}".format(table["db_id"]))
            # print(table["db_id"])
            return table_alias_dict, ret
        # if len(candidate_tables) != len(table_alias_dict):
        #     print("error in generate from clause!!!!!")
        return table_alias_dict, ret

    def gen_sql(self, sql, table):
        select_clause = ""
        from_clause = ""
        groupby_clause = ""
        orderby_clause = ""
        having_clause = ""
        where_clause = ""
        nested_clause = ""
        cols = {}
        candidate_tables = set()
        nested_sql = {}
        nested_label = ""
        parent_sql = sql
        # if "sql" in sql:
        #     sql = sql["sql"]
        if "nested_label" in sql:
            nested_label = sql["nested_label"]
            nested_sql = sql["nested_sql"]
            sql = sql["sql"]
        elif "sql" in sql:
            sql = sql["sql"]
        for key in sql:
            if key not in KW_WITH_COL:
                continue
            for item in sql[key]:
                if isinstance(item, tuple) and len(item) == 3:
                    if table["column_names"][item[2]][0] != -1:
                        candidate_tables.add(table["column_names"][item[2]][0])
        table_alias_dict, from_clause = self.gen_from(candidate_tables, table)
        ret = []
        if "select" in sql:
            select_clause = self.gen_select(sql["select"], "select", table,
                                            table_alias_dict)
            if len(select_clause) > 0:
                ret.append(select_clause)
            else:
                print("select not found:{}".format(parent_sql))
        else:
            print("select not found:{}".format(parent_sql))
        if len(from_clause) > 0:
            ret.append(from_clause)
        if "where" in sql:
            where_clause = self.gen_where(sql["where"], table,
                                          table_alias_dict)
            if len(where_clause) > 0:
                ret.append(where_clause)
        if "groupBy" in sql:  ## DEBUG-ed order
            groupby_clause = self.gen_group_by(sql["groupBy"], "group by",
                                               table, table_alias_dict)
            if len(groupby_clause) > 0:
                ret.append(groupby_clause)
        if "orderBy" in sql:
            orderby_clause = self.gen_orderby(sql["orderBy"], table,
                                              table_alias_dict)
            if len(orderby_clause) > 0:
                ret.append(orderby_clause)
        if "having" in sql:
            having_clause = self.gen_having(sql["having"], table,
                                            table_alias_dict)
            if len(having_clause) > 0:
                ret.append(having_clause)
        if len(nested_label) > 0:
            nested_clause = "{} {}".format(nested_label,
                                           self.gen_sql(nested_sql, table))
            if len(nested_clause) > 0:
                ret.append(nested_clause)
        return " ".join(ret)

    def check_acc(self, pred_sql, gt_sql):
        pass
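
Taken together, full_forward builds a nested dict of clauses and gen_sql serializes it back to a SQL string. A hedged end-to-end sketch, assuming a tokenized question and a Spider-style table_entry schema dict (both names are stand-ins, not from the snippet):

# Hedged usage sketch; inputs are assumptions matching the signatures above.
pred = model.full_forward([question_tokens], None, table_entry)
if pred is not None:  # full_forward returns None when decoding times out
    print(model.gen_sql(pred, table_entry))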
Example #5
class SyntaxSQL():
    def __init__(self,
                 embeddings,
                 N_word,
                 hidden_dim,
                 num_layers,
                 gpu,
                 num_augmentation=10000):
        self.embeddings = embeddings
        self.having_predictor = HavingPredictor(N_word=N_word,
                                                hidden_dim=hidden_dim,
                                                num_layers=num_layers,
                                                gpu=gpu).eval()
        self.keyword_predictor = KeyWordPredictor(N_word=N_word,
                                                  hidden_dim=hidden_dim,
                                                  num_layers=num_layers,
                                                  gpu=gpu).eval()
        self.andor_predictor = AndOrPredictor(N_word=N_word,
                                              hidden_dim=hidden_dim,
                                              num_layers=num_layers,
                                              gpu=gpu).eval()
        self.desasc_predictor = DesAscLimitPredictor(N_word=N_word,
                                                     hidden_dim=hidden_dim,
                                                     num_layers=num_layers,
                                                     gpu=gpu).eval()
        self.op_predictor = OpPredictor(N_word=N_word,
                                        hidden_dim=hidden_dim,
                                        num_layers=num_layers,
                                        gpu=gpu).eval()
        self.col_predictor = ColPredictor(N_word=N_word,
                                          hidden_dim=hidden_dim,
                                          num_layers=num_layers,
                                          gpu=gpu).eval()
        self.agg_predictor = AggPredictor(N_word=N_word,
                                          hidden_dim=hidden_dim,
                                          num_layers=num_layers,
                                          gpu=gpu).eval()
        self.limit_value_predictor = LimitValuePredictor(N_word=N_word,
                                                         hidden_dim=hidden_dim,
                                                         num_layers=num_layers,
                                                         gpu=gpu).eval()
        self.distinct_predictor = DistinctPredictor(N_word=N_word,
                                                    hidden_dim=hidden_dim,
                                                    num_layers=num_layers,
                                                    gpu=gpu).eval()
        self.value_predictor = ValuePredictor(N_word=N_word,
                                              hidden_dim=hidden_dim,
                                              num_layers=num_layers,
                                              gpu=gpu).eval()

        def get_model_path(model='having',
                           batch_size=64,
                           epoch=50,
                           num_augmentation=num_augmentation,
                           name_postfix=''):
            return f'saved_models/{model}__num_layers={num_layers}__lr=0.001__dropout=0.3__batch_size={batch_size}__embedding_dim={N_word}__hidden_dim={hidden_dim}__epoch={epoch}__num_augmentation={num_augmentation}__{name_postfix}.pt'
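
        # e.g. get_model_path('having') with (say) num_layers=2, N_word=300, hidden_dim=300
        # resolves to "saved_models/having__num_layers=2__lr=0.001__dropout=0.3__batch_size=64__embedding_dim=300__hidden_dim=300__epoch=50__num_augmentation=10000__.pt"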

        try:
            self.having_predictor.load(get_model_path('having'))
            self.keyword_predictor.load(
                get_model_path('keyword',
                               epoch=300,
                               num_augmentation=10000,
                               name_postfix='kw2'))
            self.andor_predictor.load(
                get_model_path('andor', batch_size=256, num_augmentation=0))
            self.desasc_predictor.load(get_model_path('desasc'))
            self.op_predictor.load(get_model_path('op',
                                                  num_augmentation=10000))
            self.col_predictor.load(
                get_model_path('column',
                               epoch=300,
                               num_augmentation=30000,
                               name_postfix='rep2aug'))
            self.distinct_predictor.load(
                get_model_path('distinct',
                               epoch=300,
                               num_augmentation=0,
                               name_postfix='dist2'))
            self.agg_predictor.load(get_model_path('agg', num_augmentation=0))
            self.limit_value_predictor.load(get_model_path('limitvalue'))
            self.value_predictor.load(
                get_model_path('value',
                               epoch=300,
                               num_augmentation=10000,
                               name_postfix='val2'))
        except FileNotFoundError as ex:
            print(ex)
        self.current_keyword = ''
        self.sql = None
        self.gpu = gpu
        if gpu:
            self.embeddings = self.embeddings.cuda()

    def generate_select(self):
        self.current_keyword = 'select'
        self.generate_columns()

    def generate_where(self):
        self.current_keyword = 'where'
        self.generate_columns()

    def generate_ascdesc(self, column):
        history = self.sql.generate_history()
        hs_emb_var, hs_len = self.embeddings.get_history_emb(
            [history['having'][-1]])
        col_idx = self.sql.database.get_idx_from_column(column)
        ascdesc = self.desasc_predictor.predict(self.q_emb_var, self.q_len,
                                                hs_emb_var, hs_len,
                                                self.col_emb_var, self.col_len,
                                                self.col_name_len, col_idx)
        ascdesc = SQL_ORDERBY_OPS[int(ascdesc)]
        self.sql.ORDERBY_OP += [ascdesc]
        if 'LIMIT' in ascdesc:
            limit_value = self.limit_value_predictor.predict(
                self.q_emb_var, self.q_len, hs_emb_var, hs_len,
                self.col_emb_var, self.col_len, self.col_name_len, col_idx)[0]
            self.sql.LIMIT_VALUE = limit_value

    def generate_orderby(self):
        self.current_keyword = 'orderby'
        self.generate_columns()

    def generate_groupby(self):
        self.current_keyword = 'groupby'
        self.generate_columns()

    def generate_having(self, column):
        history = self.sql.generate_history()
        hs_emb_var, hs_len = self.embeddings.get_history_emb(
            [history['having'][-1]])
        col_idx = self.sql.database.get_idx_from_column(column)
        having = self.having_predictor.predict(self.q_emb_var, self.q_len,
                                               hs_emb_var, hs_len,
                                               self.col_emb_var, self.col_len,
                                               self.col_name_len, col_idx)
        if having:
            self.current_keyword = 'having'
            self.generate_columns()

    def generate_keywords(self):
        self.generate_select()
        KEYWORDS = [
            self.generate_where, self.generate_groupby, self.generate_orderby
        ]
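        # NOTE: this dispatch order (where, group by, order by) is assumed to
        # match the keyword predictor's label indexing; it differs from the
        # embedding order used in GetSQL (['where', 'order by', 'group by']).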
        history = self.sql.generate_history()
        hs_emb_var, hs_len = self.embeddings.get_history_emb(
            history['keyword'])
        num_kw, kws = self.keyword_predictor.predict(self.q_emb_var,
                                                     self.q_len, hs_emb_var,
                                                     hs_len, self.kw_emb_var,
                                                     self.kw_len)
        if num_kw[0] == 0:
            return
        key_words = sorted(kws[0])
        for key_word in key_words:
            KEYWORDS[int(key_word)]()

    def generate_andor(self, column):
        history = self.sql.generate_history()
        hs_emb_var, hs_len = self.embeddings.get_history_emb(
            [history['andor'][-1]])
        andor = self.andor_predictor.predict(self.q_emb_var, self.q_len,
                                             hs_emb_var, hs_len)
        andor = SQL_COND_OPS[int(andor)]
        if self.current_keyword == 'where':
            self.sql.WHERE[-1].cond_op = andor
        elif self.current_keyword == 'having':
            self.sql.HAVING[-1].cond_op = andor

    def generate_op(self, column):
        history = self.sql.generate_history()
        hs_emb_var, hs_len = self.embeddings.get_history_emb(
            [history['op'][-1]])
        col_idx = self.sql.database.get_idx_from_column(column)
        op = self.op_predictor.predict(self.q_emb_var, self.q_len, hs_emb_var,
                                       hs_len, self.col_emb_var, self.col_len,
                                       self.col_name_len, col_idx)
        op = SQL_OPS[int(op)]
        if self.current_keyword == 'where':
            self.sql.WHERE[-1].op = op
        else:
            self.sql.HAVING[-1].op = op
        return op

    def generate_distinct(self, column):
        history = self.sql.generate_history()
        hs_emb_var, hs_len = self.embeddings.get_history_emb(
            [history['distinct'][-1]])
        col_idx = self.sql.database.get_idx_from_column(column)
        distinct = self.distinct_predictor.predict(self.q_emb_var, self.q_len,
                                                   hs_emb_var, hs_len,
                                                   self.col_emb_var,
                                                   self.col_len,
                                                   self.col_name_len, col_idx)
        distinct = SQL_DISTINCT_OP[int(distinct)]
        if self.current_keyword == 'select':
            self.sql.COLS[-1].distinct = distinct
        elif self.current_keyword == 'orderby':
            self.sql.ORDERBY[-1].distinct = ''
        elif self.current_keyword == 'having':
            self.sql.HAVING[-1].distinct = distinct

    def generate_agg(self, column, early_return=False, force_agg=False):
        history = self.sql.generate_history()
        hs_emb_var, hs_len = self.embeddings.get_history_emb(
            [history['agg'][-1]])
        col_idx = self.sql.database.get_idx_from_column(column)
        agg = self.agg_predictor.predict(self.q_emb_var,
                                         self.q_len,
                                         hs_emb_var,
                                         hs_len,
                                         self.col_emb_var,
                                         self.col_len,
                                         self.col_name_len,
                                         col_idx,
                                         force_agg=force_agg)
        agg = SQL_AGG[int(agg)]
        if early_return is True:
            return agg
        if self.current_keyword == 'select':
            self.sql.COLS[-1].agg = agg
        elif self.current_keyword == 'orderby':
            self.sql.ORDERBY[-1].agg = agg
        elif self.current_keyword == 'having':
            self.sql.HAVING[-1].agg = agg

    def generate_between(self, column):
        ban_prediction = None
        for i in range(2):
            history = self.sql.generate_history()
            hs_emb_var, hs_len = self.embeddings.get_history_emb(
                [history['value'][-1]])
            tokens = word_tokenize(str.lower(self.question))
            int_tokens = [
                text2int(token.replace('-', '').replace('.', '')).isdigit()
                for token in tokens
            ]
            num_tokens, start_index = self.value_predictor.predict(
                self.q_emb_var, self.q_len, hs_emb_var, hs_len,
                self.col_emb_var, self.col_len, self.col_name_len,
                ban_prediction, int_tokens)
            num_tokens, start_index = int(num_tokens[0]), int(start_index[0])
            try:
                value = ' '.join(tokens[start_index:start_index + num_tokens])
                if self.current_keyword == 'where':
                    if i == 0:
                        self.sql.WHERE[-1].value = value
                        ban_prediction = (num_tokens, start_index)
                    else:
                        self.sql.WHERE[-1].valueless = value
                elif self.current_keyword == 'having':
                    if i == 0:
                        self.sql.HAVING[-1].value = value
                        ban_prediction = (num_tokens, start_index)
                    else:
                        self.sql.HAVING[-1].valueless = value
            except Exception as e:
                print(e)

    def generate_value(self, column):
        history = self.sql.generate_history()
        hs_emb_var, hs_len = self.embeddings.get_history_emb(
            [history['value'][-1]])
        num_tokens, start_index = self.value_predictor.predict(
            self.q_emb_var, self.q_len, hs_emb_var, hs_len, self.col_emb_var,
            self.col_len, self.col_name_len)
        num_tokens, start_index = int(num_tokens[0]), int(start_index[0])
        tokens = word_tokenize(str.lower(self.question))
        try:
            value = ' '.join(tokens[start_index:start_index + num_tokens])
            value = text2int(value)

            if self.current_keyword == 'where':
                self.sql.WHERE[-1].value = value
            elif self.current_keyword == 'having':
                self.sql.HAVING[-1].value = value
        except Exception:
            pass

    def generate_columns(self):
        history = self.sql.generate_history()
        hs_emb_var, hs_len = self.embeddings.get_history_emb(
            [history['col'][-1]])
        num_cols, cols = self.col_predictor.predict(self.q_emb_var, self.q_len,
                                                    hs_emb_var, hs_len,
                                                    self.col_emb_var,
                                                    self.col_len,
                                                    self.col_name_len)
        num_cols, cols = num_cols[0], cols[0]

        def exclude_all_from_columns():
            excluded_idx = [
                len(table.columns) for table in self.sql.database.tables
            ]
            _, cols_new = self.col_predictor.predict(self.q_emb_var,
                                                     self.q_len,
                                                     hs_emb_var,
                                                     hs_len,
                                                     self.col_emb_var,
                                                     self.col_len,
                                                     self.col_name_len,
                                                     exclude_idx=excluded_idx)
            return self.sql.database.get_column_from_idx(cols_new[0][0])

        for i, col in enumerate(cols):
            column = self.sql.database.get_column_from_idx(col)
            if self.current_keyword in ('where', 'having'):
                if self.current_keyword == 'where':
                    if column.column_name == '*':
                        column = exclude_all_from_columns()
                    self.sql.WHERE += [Condition(column)]
                else:
                    self.sql.HAVING += [Condition(column)]
                op = self.generate_op(column)
                if op == 'BETWEEN':
                    self.generate_between(column)
                else:
                    self.generate_value(column)
                if num_cols > 1 and i < (num_cols - 1):
                    self.generate_andor(column)
            if self.current_keyword in ('orderby', 'select', 'having'):
                force_agg = False
                if self.current_keyword == 'orderby':
                    self.sql.ORDERBY += [ColumnSelect(column)]
                    if column.column_name == '*' and self.generate_agg(
                            column, early_return=True) == '':
                        column = exclude_all_from_columns()
                        self.sql.ORDERBY[-1] = ColumnSelect(column)
                elif self.current_keyword == 'select':
                    force_agg = len(set(cols)) < len(cols)
                    self.sql.COLS += [ColumnSelect(column)]
                self.generate_agg(column, force_agg=force_agg)
                self.generate_distrinct(column)
            if self.current_keyword == 'groupby':
                if column.column_name == '*':
                    column = exclude_all_from_columns()
                self.sql.GROUPBY += [ColumnSelect(column)]
        if self.current_keyword == 'groupby' and len(cols) > 0:
            self.generate_having(column)
        if self.current_keyword == 'orderby':
            self.generate_ascdesc(column)

    def GetSQL(self, question, database):
        self.sql = SQLStatement(query=None, database=database)
        self.question = question
        self.q_emb_var, self.q_len = self.embeddings(question)
        columns = self.sql.database.to_list()
        columns_all_splitted = []
        for i, column in enumerate(columns):
            columns_tmp = []
            for word in column:
                columns_tmp.extend(word.split('_'))
            columns_all_splitted += [columns_tmp]
        self.col_emb_var, self.col_len, self.col_name_len = self.embeddings.get_columns_emb(
            [columns_all_splitted])
        _, num_cols_in_db, col_name_lens, embedding_dim = self.col_emb_var.shape
        self.col_emb_var = self.col_emb_var.reshape(num_cols_in_db,
                                                    col_name_lens,
                                                    embedding_dim)
        self.col_name_len = self.col_name_len.reshape(-1)
        self.kw_emb_var, self.kw_len = self.embeddings.get_history_emb(
            [['where', 'order by', 'group by']])
        self.generate_keywords()
        return self.sql
Ejemplo n.º 6
0
class SyntaxSQL():
    """
    Main class for the SyntaxSQL model. 
    This takes all the sub modules, and uses them to run a question through the syntax tree
    """
    def __init__(self,
                 embeddings,
                 N_word,
                 hidden_dim,
                 num_layers,
                 gpu,
                 num_augmentation=10000):
        self.embeddings = embeddings
        self.having_predictor = HavingPredictor(N_word=N_word,
                                                hidden_dim=hidden_dim,
                                                num_layers=num_layers,
                                                gpu=gpu).eval()
        self.keyword_predictor = KeyWordPredictor(N_word=N_word,
                                                  hidden_dim=hidden_dim,
                                                  num_layers=num_layers,
                                                  gpu=gpu).eval()
        self.andor_predictor = AndOrPredictor(N_word=N_word,
                                              hidden_dim=hidden_dim,
                                              num_layers=num_layers,
                                              gpu=gpu).eval()
        self.desasc_predictor = DesAscLimitPredictor(N_word=N_word,
                                                     hidden_dim=hidden_dim,
                                                     num_layers=num_layers,
                                                     gpu=gpu).eval()
        self.op_predictor = OpPredictor(N_word=N_word,
                                        hidden_dim=hidden_dim,
                                        num_layers=num_layers,
                                        gpu=gpu).eval()
        self.col_predictor = ColPredictor(N_word=N_word,
                                          hidden_dim=hidden_dim,
                                          num_layers=num_layers,
                                          gpu=gpu).eval()
        self.agg_predictor = AggPredictor(N_word=N_word,
                                          hidden_dim=hidden_dim,
                                          num_layers=num_layers,
                                          gpu=gpu).eval()
        self.limit_value_predictor = LimitValuePredictor(N_word=N_word,
                                                         hidden_dim=hidden_dim,
                                                         num_layers=num_layers,
                                                         gpu=gpu).eval()
        self.distinct_predictor = DistinctPredictor(N_word=N_word,
                                                    hidden_dim=hidden_dim,
                                                    num_layers=num_layers,
                                                    gpu=gpu).eval()
        self.value_predictor = ValuePredictor(N_word=N_word,
                                              hidden_dim=hidden_dim,
                                              num_layers=num_layers,
                                              gpu=gpu).eval()

        def get_model_path(model='having',
                           batch_size=64,
                           epoch=50,
                           num_augmentation=num_augmentation,
                           name_postfix=''):
            return f'saved_models/{model}__num_layers={num_layers}__lr=0.001__dropout=0.3__batch_size={batch_size}__embedding_dim={N_word}__hidden_dim={hidden_dim}__epoch={epoch}__num_augmentation={num_augmentation}__{name_postfix}.pt'
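        # Illustrative example of a resulting path (values depend on the
        # constructor arguments; shown here for num_layers=2, N_word=300,
        # hidden_dim=300 and the defaults above):
        # saved_models/having__num_layers=2__lr=0.001__dropout=0.3__batch_size=64__embedding_dim=300__hidden_dim=300__epoch=50__num_augmentation=10000__.pt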

        try:
            self.having_predictor.load(get_model_path('having'))
            self.keyword_predictor.load(
                get_model_path('keyword',
                               epoch=300,
                               num_augmentation=10000,
                               name_postfix='kw2'))
            self.andor_predictor.load(
                get_model_path('andor', batch_size=256, num_augmentation=0))
            self.desasc_predictor.load(get_model_path('desasc'))
            self.op_predictor.load(get_model_path('op',
                                                  num_augmentation=10000))
            self.col_predictor.load(
                get_model_path('column',
                               epoch=300,
                               num_augmentation=30000,
                               name_postfix='rep2aug'))
            self.distinct_predictor.load(
                get_model_path('distinct',
                               epoch=300,
                               num_augmentation=0,
                               name_postfix='dist2'))
            self.agg_predictor.load(get_model_path('agg', num_augmentation=0))
            self.limit_value_predictor.load(get_model_path('limitvalue'))
            self.value_predictor.load(
                get_model_path('value',
                               epoch=300,
                               num_augmentation=10000,
                               name_postfix='val2'))

        except FileNotFoundError as ex:
            print(ex)

        self.current_keyword = ''
        self.sql = None
        self.gpu = gpu

        if gpu:
            self.embeddings = self.embeddings.cuda()

    def generate_select(self):
        # All statements should start with a select statement
        self.current_keyword = 'select'
        self.generate_columns()

    def generate_where(self):
        self.current_keyword = 'where'
        self.generate_columns()

    def generate_ascdesc(self, column):
        # Get the history, from the current sql
        history = self.sql.generate_history()
        hs_emb_var, hs_len = self.embeddings.get_history_emb(
            [history['having'][-1]])

        col_idx = self.sql.database.get_idx_from_column(column)

        ascdesc = self.desasc_predictor.predict(self.q_emb_var, self.q_len,
                                                hs_emb_var, hs_len,
                                                self.col_emb_var, self.col_len,
                                                self.col_name_len, col_idx)

        ascdesc = SQL_ORDERBY_OPS[int(ascdesc)]

        self.sql.ORDERBY_OP += [ascdesc]

        if 'LIMIT' in ascdesc:
            limit_value = self.limit_value_predictor.predict(
                self.q_emb_var, self.q_len, hs_emb_var, hs_len,
                self.col_emb_var, self.col_len, self.col_name_len, col_idx)[0]
            self.sql.LIMIT_VALUE = limit_value

    def generate_orderby(self):
        self.current_keyword = 'orderby'
        self.generate_columns()

    def generate_groupby(self):
        self.current_keyword = 'groupby'
        self.generate_columns()

    def generate_having(self, column):
        # Get the history, from the current sql
        history = self.sql.generate_history()
        hs_emb_var, hs_len = self.embeddings.get_history_emb(
            [history['having'][-1]])

        col_idx = self.sql.database.get_idx_from_column(column)

        having = self.having_predictor.predict(self.q_emb_var, self.q_len,
                                               hs_emb_var, hs_len,
                                               self.col_emb_var, self.col_len,
                                               self.col_name_len, col_idx)
        if having:
            self.current_keyword = 'having'
            self.generate_columns()

    def generate_keywords(self):
        self.generate_select()

        KEYWORDS = [
            self.generate_where, self.generate_groupby, self.generate_orderby
        ]

        # Get the history, from the current sql
        history = self.sql.generate_history()
        hs_emb_var, hs_len = self.embeddings.get_history_emb(
            history['keyword'])

        num_kw, kws = self.keyword_predictor.predict(self.q_emb_var,
                                                     self.q_len, hs_emb_var,
                                                     hs_len, self.kw_emb_var,
                                                     self.kw_len)

        if num_kw[0] == 0:
            return

        # We want the keywords applied in a consistent order as much as possible.
        # Keywords are added to a FIFO queue, so sort them.
        key_words = sorted(kws[0])
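        # e.g. kws[0] == [2, 0] runs generate_where (index 0) before
        # generate_orderby (index 2)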

        # Add other states to the list
        for key_word in key_words:
            KEYWORDS[int(key_word)]()

    def generate_andor(self, column):
        # Get the history, from the current sql
        history = self.sql.generate_history()
        hs_emb_var, hs_len = self.embeddings.get_history_emb(
            [history['andor'][-1]])

        andor = self.andor_predictor.predict(self.q_emb_var, self.q_len,
                                             hs_emb_var, hs_len)
        andor = SQL_COND_OPS[int(andor)]

        if self.current_keyword == 'where':
            self.sql.WHERE[-1].cond_op = andor
        elif self.current_keyword == 'having':
            self.sql.HAVING[-1].cond_op = andor

    def generate_op(self, column):
        # Get the history, from the current sql
        history = self.sql.generate_history()
        hs_emb_var, hs_len = self.embeddings.get_history_emb(
            [history['op'][-1]])

        col_idx = self.sql.database.get_idx_from_column(column)

        op = self.op_predictor.predict(self.q_emb_var, self.q_len, hs_emb_var,
                                       hs_len, self.col_emb_var, self.col_len,
                                       self.col_name_len, col_idx)
        op = SQL_OPS[int(op)]

        # Pick the current clause from the current keyword
        if self.current_keyword == 'where':
            self.sql.WHERE[-1].op = op
        else:
            self.sql.HAVING[-1].op = op

        return op

    def generate_distinct(self, column):
        # Get the history, from the current sql
        history = self.sql.generate_history()
        hs_emb_var, hs_len = self.embeddings.get_history_emb(
            [history['distinct'][-1]])

        col_idx = self.sql.database.get_idx_from_column(column)

        distinct = self.distinct_predictor.predict(self.q_emb_var, self.q_len,
                                                   hs_emb_var, hs_len,
                                                   self.col_emb_var,
                                                   self.col_len,
                                                   self.col_name_len, col_idx)

        distinct = SQL_DISTINCT_OP[int(distinct)]

        if self.current_keyword == 'select':
            self.sql.COLS[-1].distinct = distinct
        elif self.current_keyword == 'orderby':
            # ORDER BY columns are rendered without DISTINCT
            self.sql.ORDERBY[-1].distinct = ''
        elif self.current_keyword == 'having':
            self.sql.HAVING[-1].distinct = distinct

    def generate_agg(self, column, early_return=False, force_agg=False):

        # Get the history, from the current sql
        history = self.sql.generate_history()
        hs_emb_var, hs_len = self.embeddings.get_history_emb(
            [history['agg'][-1]])

        col_idx = self.sql.database.get_idx_from_column(column)

        agg = self.agg_predictor.predict(self.q_emb_var,
                                         self.q_len,
                                         hs_emb_var,
                                         hs_len,
                                         self.col_emb_var,
                                         self.col_len,
                                         self.col_name_len,
                                         col_idx,
                                         force_agg=force_agg)

        agg = SQL_AGG[int(agg)]

        if early_return:
            return agg

        if self.current_keyword == 'select':
            self.sql.COLS[-1].agg = agg
        elif self.current_keyword == 'orderby':
            self.sql.ORDERBY[-1].agg = agg
        elif self.current_keyword == 'having':
            self.sql.HAVING[-1].agg = agg

    def generate_between(self, column):
        ban_prediction = None
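        # On the second pass, the first prediction is banned so that BETWEEN
        # receives two distinct values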

        # Make two predictions
        for i in range(2):

            # Get the history, from the current sql
            history = self.sql.generate_history()
            hs_emb_var, hs_len = self.embeddings.get_history_emb(
                [history['value'][-1]])
            tokens = word_tokenize(self.question.lower())

            # Mask tokens that are numeric once signs and decimal points are
            # stripped and spelled-out numbers are converted by text2int
            int_tokens = [
                text2int(token.replace('-', '').replace('.', '')).isdigit()
                for token in tokens
            ]

            num_tokens, start_index = self.value_predictor.predict(
                self.q_emb_var, self.q_len, hs_emb_var, hs_len,
                self.col_emb_var, self.col_len, self.col_name_len,
                ban_prediction, int_tokens)
            num_tokens, start_index = int(num_tokens[0]), int(start_index[0])

            try:
                value = ' '.join(tokens[start_index:start_index + num_tokens])

                if self.current_keyword == 'where':
                    if i == 0:
                        self.sql.WHERE[-1].value = value
                        ban_prediction = (num_tokens, start_index)
                    else:
                        self.sql.WHERE[-1].valueless = value

                elif self.current_keyword == 'having':
                    if i == 0:
                        self.sql.HAVING[-1].value = value
                        ban_prediction = (num_tokens, start_index)
                    else:
                        self.sql.HAVING[-1].valueless = value

            # The value might not exist in the question, so just log it and
            # move on
            except Exception as e:
                print(e)

    def generate_value(self, column):

        # Get the history, from the current sql
        history = self.sql.generate_history()
        hs_emb_var, hs_len = self.embeddings.get_history_emb(
            [history['value'][-1]])

        num_tokens, start_index = self.value_predictor.predict(
            self.q_emb_var, self.q_len, hs_emb_var, hs_len, self.col_emb_var,
            self.col_len, self.col_name_len)

        num_tokens, start_index = int(num_tokens[0]), int(start_index[0])
        tokens = word_tokenize(self.question.lower())

        try:
            value = ' '.join(tokens[start_index:start_index + num_tokens])
            value = text2int(value)

            if self.current_keyword == 'where':
                self.sql.WHERE[-1].value = value
            elif self.current_keyword == 'having':
                self.sql.HAVING[-1].value = value

        # The value might not exist in the question, so just ignore it
        except Exception:
            pass

    def generate_columns(self):

        # Get the history, from the current sql
        history = self.sql.generate_history()
        hs_emb_var, hs_len = self.embeddings.get_history_emb(
            [history['col'][-1]])

        num_cols, cols = self.col_predictor.predict(self.q_emb_var, self.q_len,
                                                    hs_emb_var, hs_len,
                                                    self.col_emb_var,
                                                    self.col_len,
                                                    self.col_name_len)

        # Predictions are returned as one-element lists
        num_cols, cols = num_cols[0], cols[0]

        def exclude_all_from_columns():
            # Do not permit * as valid column in where/having clauses
            excluded_idx = [
                len(table.columns) for table in self.sql.database.tables
            ]

            _, cols_new = self.col_predictor.predict(self.q_emb_var,
                                                     self.q_len,
                                                     hs_emb_var,
                                                     hs_len,
                                                     self.col_emb_var,
                                                     self.col_len,
                                                     self.col_name_len,
                                                     exclude_idx=excluded_idx)

            return self.sql.database.get_column_from_idx(cols_new[0][0])

        for i, col in enumerate(cols):
            column = self.sql.database.get_column_from_idx(col)

            if self.current_keyword in ('where', 'having'):

                # Add the column to the corresponding clause
                if self.current_keyword == 'where':
                    if column.column_name == '*':
                        column = exclude_all_from_columns()

                    self.sql.WHERE += [Condition(column)]
                else:
                    self.sql.HAVING += [Condition(column)]

                # We need the value and comparison operation in where/having clauses
                op = self.generate_op(column)

                if op == 'BETWEEN':
                    self.generate_between(column)
                else:
                    self.generate_value(column)

                # If we predict multiple columns in where or having, we need to also predict and/or
                if num_cols > 1 and i < (num_cols - 1):
                    self.generate_andor(column)

            if self.current_keyword in ('orderby', 'select', 'having'):
                force_agg = False
                if self.current_keyword == 'orderby':
                    self.sql.ORDERBY += [ColumnSelect(column)]
                    if column.column_name == '*' and self.generate_agg(
                            column, early_return=True) == '':
                        column = exclude_all_from_columns()
                        self.sql.ORDERBY[-1] = ColumnSelect(column)

                elif self.current_keyword == 'select':
                    # A column predicted more than once must differ by its
                    # aggregator, so force one to be chosen
                    force_agg = len(set(cols)) < len(cols)
                    self.sql.COLS += [ColumnSelect(column)]

                # Each column should have an aggregator
                self.generate_agg(column, force_agg=force_agg)
                self.generate_distinct(column)

            if self.current_keyword == 'groupby':
                if column.column_name == '*':
                    column = exclude_all_from_columns()

                self.sql.GROUPBY += [ColumnSelect(column)]

        if self.current_keyword == 'groupby' and len(cols) > 0:
            self.generate_having(column)
        if self.current_keyword == 'orderby':
            self.generate_ascdesc(column)

    def GetSQL(self, question, database):
        # Generate representation of the database in form of SQL clauses
        self.sql = SQLStatement(query=None, database=database)
        self.question = question

        self.q_emb_var, self.q_len = self.embeddings(question)

        columns = self.sql.database.to_list()

        # Get all columns from the database and split them
        columns_all_splitted = []
        for i, column in enumerate(columns):
            columns_tmp = []
            for word in column:
                columns_tmp.extend(word.split('_'))
            columns_all_splitted += [columns_tmp]

        # Get embedding for the columns and keywords
        self.col_emb_var, self.col_len, self.col_name_len = self.embeddings.get_columns_emb(
            [columns_all_splitted])
        _, num_cols_in_db, col_name_lens, embedding_dim = self.col_emb_var.shape

        self.col_emb_var = self.col_emb_var.reshape(num_cols_in_db,
                                                    col_name_lens,
                                                    embedding_dim)
        self.col_name_len = self.col_name_len.reshape(-1)

        self.kw_emb_var, self.kw_len = self.embeddings.get_history_emb(
            [['where', 'order by', 'group by']])

        # Recursively generate the SQL history, starting with the keywords:
        # SELECT first, then the rest.
        self.generate_keywords()

        return self.sql
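
# A minimal usage sketch (illustrative; not part of the original listing).
# The `embeddings` wrapper and the Spider-style `database` object are assumed
# to be constructed elsewhere, and the hyperparameters mirror the defaults
# used throughout this file. `GloveEmbedding` and `Database` are hypothetical
# names standing in for whatever this repository actually provides.
#
#     embeddings = GloveEmbedding(...)       # hypothetical embedding loader
#     database = Database(...)               # hypothetical schema wrapper
#     syntax_sql = SyntaxSQL(embeddings, N_word=300, hidden_dim=300,
#                            num_layers=2, gpu=False)
#     sql = syntax_sql.GetSQL('How many singers are older than 30?', database)
#     print(sql)  # the SQLStatement built clause by clause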
Ejemplo n.º 7
0
def train_feedback(nlq, db_name, correct_query, toy, word_emb):
    """
    Arguments:
        nlq: english question (tokenization is done here) - get from Flask (User)
        db_name: name of the database the query targets - get from Flask (User)
        correct_query: the ground truth query supplied by the user(s) - get from Flask
        toy: uses a small example of word embeddings to debug faster
    """

    ITER = 21

    SAVED_MODELS_FOLDER = "saved_models"
    OUTPUT_PATH = "output_inference.txt"
    HISTORY_TYPE = "full"
    GPU_ENABLE = False
    TRAIN_EMB = False
    TABLE_TYPE = "std"
    DATA_ROOT = "generated_data"

    use_hs = True
    if HISTORY_TYPE == "no":
        HISTORY_TYPE = "full"
        use_hs = False
    """
    Model Hyperparameters
    """
    N_word = 300  # word embedding dimension
    B_word = 42  # 42B tokens in the GloVe pretrained embeddings
    N_h = 300  # hidden size dimension
    N_depth = 2  # network depth (number of stacked LSTM layers)

    if toy:
        USE_SMALL = True
        # GPU=True
        GPU = GPU_ENABLE
        BATCH_SIZE = 20
    else:
        USE_SMALL = False
        # GPU=True
        GPU = GPU_ENABLE
        BATCH_SIZE = 64
    # TRAIN_ENTRY=(False, True, False)  # (AGG, SEL, COND)
    # TRAIN_AGG, TRAIN_SEL, TRAIN_COND = TRAIN_ENTRY
    learning_rate = 1e-4

    # GENERATE CORRECT QUERY DATASET
    table_data_path = "./data/spider/tables.json"
    table_dict = get_table_dict(table_data_path)
    train_data_path = "./data/spider/train_spider.json"
    train_data = json.load(open(train_data_path))
    sql = correct_query  # e.g. "SELECT name ,  country ,  age FROM singer ORDER BY age DESC"
    db_id = db_name  # e.g. "concert_singer"
    table_file = table_data_path  # e.g. "tables.json"

    schemas, db_names, tables = get_schemas_from_json(table_file)
    schema = schemas[db_id]
    table = tables[db_id]
    schema = Schema(schema, table)
    sql_label = get_sql(schema, sql)
    correct_query_data = {
        "multi_sql_dataset": [],
        "keyword_dataset": [],
        "col_dataset": [],
        "op_dataset": [],
        "agg_dataset": [],
        "root_tem_dataset": [],
        "des_asc_dataset": [],
        "having_dataset": [],
        "andor_dataset": []
    }
    parser_item_with_long_history(
        tokenize(nlq),  # item["question_toks"]
        sql_label,  # item["sql"]
        table_dict[db_name],  # table_dict[item["db_id"]]
        [],
        correct_query_data)
    # print("\nCorrect query dataset: {}".format(correct_query_data))

    for train_component in TRAIN_COMPONENTS:
        print("\nTRAIN COMPONENT: {}".format(train_component))
        # Check that the component to be trained is an actual component
        # (always true here, since we iterate over TRAIN_COMPONENTS)
        if train_component not in TRAIN_COMPONENTS:
            print("Invalid train component")
            exit(1)
        """
        Read in the data
        """
        train_data = load_train_dev_dataset(train_component, "train",
                                            HISTORY_TYPE, DATA_ROOT)
        # print("train_data type: {}".format(type(train_data)))
        dev_data = load_train_dev_dataset(train_component, "dev", HISTORY_TYPE,
                                          DATA_ROOT)
        # sql_data, table_data, val_sql_data, val_table_data, \
        #         test_sql_data, test_table_data, \
        #         TRAIN_DB, DEV_DB, TEST_DB = load_dataset(args.dataset, use_small=USE_SMALL)

        if GPU_ENABLE:
            map_to = "gpu"
        else:
            map_to = "cpu"

        # Selecting which Model to Train
        model = None
        if train_component == "multi_sql":
            model = MultiSqlPredictor(N_word=N_word,
                                      N_h=N_h,
                                      N_depth=N_depth,
                                      gpu=GPU,
                                      use_hs=use_hs)
            model.load_state_dict(
                torch.load(
                    "{}/multi_sql_models.dump".format(SAVED_MODELS_FOLDER),
                    map_location=map_to))

        elif train_component == "keyword":
            model = KeyWordPredictor(N_word=N_word,
                                     N_h=N_h,
                                     N_depth=N_depth,
                                     gpu=GPU,
                                     use_hs=use_hs)
            model.load_state_dict(
                torch.load(
                    "{}/keyword_models.dump".format(SAVED_MODELS_FOLDER),
                    map_location=map_to))

        elif train_component == "col":
            model = ColPredictor(N_word=N_word,
                                 N_h=N_h,
                                 N_depth=N_depth,
                                 gpu=GPU,
                                 use_hs=use_hs)
            model.load_state_dict(
                torch.load("{}/col_models.dump".format(SAVED_MODELS_FOLDER),
                           map_location=map_to))

        elif train_component == "op":
            model = OpPredictor(N_word=N_word,
                                N_h=N_h,
                                N_depth=N_depth,
                                gpu=GPU,
                                use_hs=use_hs)
            model.load_state_dict(
                torch.load("{}/op_models.dump".format(SAVED_MODELS_FOLDER),
                           map_location=map_to))

        elif train_component == "agg":
            model = AggPredictor(N_word=N_word,
                                 N_h=N_h,
                                 N_depth=N_depth,
                                 gpu=GPU,
                                 use_hs=use_hs)
            model.load_state_dict(
                torch.load("{}/agg_models.dump".format(SAVED_MODELS_FOLDER),
                           map_location=map_to))

        elif train_component == "root_tem":
            model = RootTeminalPredictor(N_word=N_word,
                                         N_h=N_h,
                                         N_depth=N_depth,
                                         gpu=GPU,
                                         use_hs=use_hs)
            model.load_state_dict(
                torch.load(
                    "{}/root_tem_models.dump".format(SAVED_MODELS_FOLDER),
                    map_location=map_to))

        elif train_component == "des_asc":
            model = DesAscLimitPredictor(N_word=N_word,
                                         N_h=N_h,
                                         N_depth=N_depth,
                                         gpu=GPU,
                                         use_hs=use_hs)
            model.load_state_dict(
                torch.load(
                    "{}/des_asc_models.dump".format(SAVED_MODELS_FOLDER),
                    map_location=map_to))

        elif train_component == "having":
            model = HavingPredictor(N_word=N_word,
                                    N_h=N_h,
                                    N_depth=N_depth,
                                    gpu=GPU,
                                    use_hs=use_hs)
            model.load_state_dict(
                torch.load("{}/having_models.dump".format(SAVED_MODELS_FOLDER),
                           map_location=map_to))

        elif train_component == "andor":
            model = AndOrPredictor(N_word=N_word,
                                   N_h=N_h,
                                   N_depth=N_depth,
                                   gpu=GPU,
                                   use_hs=use_hs)
            model.load_state_dict(
                torch.load("{}/andor_models.dump".format(SAVED_MODELS_FOLDER),
                           map_location=map_to))

        # model = SQLNet(word_emb, N_word=N_word, gpu=GPU, trainable_emb=args.train_emb)

        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=learning_rate,
                                     weight_decay=0)
        print("finished build model")

        print_flag = False
        embed_layer = WordEmbedding(word_emb,
                                    N_word,
                                    gpu=GPU,
                                    SQL_TOK=SQL_TOK,
                                    trainable=TRAIN_EMB)

        print("start training")
        best_acc = 0.0
        for i in range(ITER):
            print('ITER %d @ %s' % (i + 1, datetime.datetime.now()))
            # arguments of epoch_train
            # model, optimizer, batch_size, component,embed_layer,data, table_type
            # print(' Loss = %s' % epoch_train(
            #                     model, optimizer, BATCH_SIZE,
            #                     args.train_component,
            #                     embed_layer,
            #                     train_data,
            #                     table_type=args.table_type))
            print('Total Loss = %s' %
                  epoch_feedback_train(model=model,
                                       optimizer=optimizer,
                                       batch_size=BATCH_SIZE,
                                       component=train_component,
                                       embed_layer=embed_layer,
                                       data=train_data,
                                       table_type=TABLE_TYPE,
                                       nlq=nlq,
                                       db_name=db_name,
                                       correct_query=correct_query,
                                       correct_query_data=correct_query_data))

            # Check improvement every 10 iterations
            if i % 10 == 0:
                acc = epoch_acc(model,
                                BATCH_SIZE,
                                train_component,
                                embed_layer,
                                dev_data,
                                table_type=TABLE_TYPE)
                if acc > best_acc:
                    best_acc = acc
                    print("Save model...")
                    torch.save(
                        model.state_dict(), SAVED_MODELS_FOLDER +
                        "/{}_models.dump".format(train_component))