class SuperModel(nn.Module): def __init__(self, word_emb, N_word, N_h=300, N_depth=2, gpu=True, trainable_emb=False, table_type="std", use_hs=True): super(SuperModel, self).__init__() self.gpu = gpu self.N_h = N_h self.N_depth = N_depth self.trainable_emb = trainable_emb self.table_type = table_type self.use_hs = use_hs self.SQL_TOK = [ '<UNK>', '<END>', 'WHERE', 'AND', 'EQL', 'GT', 'LT', '<BEG>' ] # word embedding layer self.embed_layer = WordEmbedding(word_emb, N_word, gpu, self.SQL_TOK, trainable=trainable_emb) # initial all modules self.multi_sql = MultiSqlPredictor(N_word=N_word, N_h=N_h, N_depth=N_depth, gpu=gpu, use_hs=use_hs) self.multi_sql.eval() self.key_word = KeyWordPredictor(N_word=N_word, N_h=N_h, N_depth=N_depth, gpu=gpu, use_hs=use_hs) self.key_word.eval() self.col = ColPredictor(N_word=N_word, N_h=N_h, N_depth=N_depth, gpu=gpu, use_hs=use_hs) self.col.eval() self.op = OpPredictor(N_word=N_word, N_h=N_h, N_depth=N_depth, gpu=gpu, use_hs=use_hs) self.op.eval() self.agg = AggPredictor(N_word=N_word, N_h=N_h, N_depth=N_depth, gpu=gpu, use_hs=use_hs) self.agg.eval() self.root_teminal = RootTeminalPredictor(N_word=N_word, N_h=N_h, N_depth=N_depth, gpu=gpu, use_hs=use_hs) self.root_teminal.eval() self.des_asc = DesAscLimitPredictor(N_word=N_word, N_h=N_h, N_depth=N_depth, gpu=gpu, use_hs=use_hs) self.des_asc.eval() self.having = HavingPredictor(N_word=N_word, N_h=N_h, N_depth=N_depth, gpu=gpu, use_hs=use_hs) self.having.eval() self.andor = AndOrPredictor(N_word=N_word, N_h=N_h, N_depth=N_depth, gpu=gpu, use_hs=use_hs) self.andor.eval() self.softmax = nn.Softmax() #dim=1 self.CE = nn.CrossEntropyLoss() self.log_softmax = nn.LogSoftmax() self.mlsml = nn.MultiLabelSoftMarginLoss() self.bce_logit = nn.BCEWithLogitsLoss() self.sigm = nn.Sigmoid() if gpu: self.cuda() self.path_not_found = 0 def forward(self, q_seq, history, tables): # if self.part: # return self.part_forward(q_seq,history,tables) # else: return self.full_forward(q_seq, history, tables) def full_forward(self, q_seq, history, tables): B = len(q_seq) # print("q_seq:{}".format(q_seq)) # print("Batch size:{}".format(B)) q_emb_var, q_len = self.embed_layer.gen_x_q_batch(q_seq) col_seq = to_batch_tables(tables, B, self.table_type) col_emb_var, col_name_len, col_len = self.embed_layer.gen_col_batch( col_seq) mkw_emb_var = self.embed_layer.gen_word_list_embedding( ["none", "except", "intersect", "union"], (B)) mkw_len = np.full(q_len.shape, 4, dtype=np.int64) kw_emb_var = self.embed_layer.gen_word_list_embedding( ["where", "group by", "order by"], (B)) kw_len = np.full(q_len.shape, 3, dtype=np.int64) stack = Stack() stack.push(("root", None)) history = [["root"]] * B andor_cond = "" has_limit = False # sql = {} current_sql = {} sql_stack = [] idx_stack = [] kw_stack = [] kw = "" nested_label = "" has_having = False timeout = time.time( ) + 2 # set timer to prevent infinite recursion in SQL generation failed = False while not stack.isEmpty(): if time.time() > timeout: failed = True break vet = stack.pop() # print(vet) hs_emb_var, hs_len = self.embed_layer.gen_x_history_batch(history) if len(idx_stack) > 0 and stack.size() < idx_stack[-1]: # print("pop!!!!!!!!!!!!!!!!!!!!!!") idx_stack.pop() current_sql = sql_stack.pop() kw = kw_stack.pop() # current_sql = current_sql["sql"] # history.append(vet) # print("hs_emb:{} hs_len:{}".format(hs_emb_var.size(),hs_len.size())) if isinstance(vet, tuple) and vet[0] == "root": if history[0][-1] != "root": history[0].append("root") hs_emb_var, hs_len = self.embed_layer.gen_x_history_batch( 
history) if vet[1] != "original": idx_stack.append(stack.size()) sql_stack.append(current_sql) kw_stack.append(kw) else: idx_stack.append(stack.size()) sql_stack.append(sql_stack[-1]) kw_stack.append(kw) if "sql" in current_sql: current_sql["nested_sql"] = {} current_sql["nested_label"] = nested_label current_sql = current_sql["nested_sql"] elif isinstance(vet[1], dict): vet[1]["sql"] = {} current_sql = vet[1]["sql"] elif vet[1] != "original": current_sql["sql"] = {} current_sql = current_sql["sql"] # print("q_emb_var:{} hs_emb_var:{} mkw_emb_var:{}".format(q_emb_var.size(),hs_emb_var.size(),mkw_emb_var.size())) if vet[1] == "nested" or vet[1] == "original": stack.push("none") history[0].append("none") else: score = self.multi_sql.forward(q_emb_var, q_len, hs_emb_var, hs_len, mkw_emb_var, mkw_len) label = np.argmax(score[0].data.cpu().numpy()) label = SQL_OPS[label] history[0].append(label) stack.push(label) if label != "none": nested_label = label elif vet in ('intersect', 'except', 'union'): stack.push(("root", "nested")) stack.push(("root", "original")) # history[0].append("root") elif vet == "none": score = self.key_word.forward(q_emb_var, q_len, hs_emb_var, hs_len, kw_emb_var, kw_len) kw_num_score, kw_score = [x.data.cpu().numpy() for x in score] # print("kw_num_score:{}".format(kw_num_score)) # print("kw_score:{}".format(kw_score)) num_kw = np.argmax(kw_num_score[0]) kw_score = list(np.argsort(-kw_score[0])[:num_kw]) kw_score.sort(reverse=True) # print("num_kw:{}".format(num_kw)) for kw in kw_score: stack.push(KW_OPS[kw]) stack.push("select") elif vet in ("select", "orderBy", "where", "groupBy", "having"): kw = vet current_sql[kw] = [] history[0].append(vet) stack.push(("col", vet)) # score = self.andor.forward(q_emb_var,q_len,hs_emb_var,hs_len) # label = score[0].data.cpu().numpy() # andor_cond = COND_OPS[label] # history.append("") # elif vet == "groupBy": # score = self.having.forward(q_emb_var,q_len,hs_emb_var,hs_len,col_emb_var,col_len,) elif isinstance(vet, tuple) and vet[0] == "col": # print("q_emb_var:{} hs_emb_var:{} col_emb_var:{}".format(q_emb_var.size(), hs_emb_var.size(),col_emb_var.size())) score = self.col.forward(q_emb_var, q_len, hs_emb_var, hs_len, col_emb_var, col_len, col_name_len) col_num_score, col_score = [ x.data.cpu().numpy() for x in score ] col_num = np.argmax(col_num_score[0]) + 1 # double check cols = np.argsort(-col_score[0])[:col_num] # print(col_num) # print("col_num_score:{}".format(col_num_score)) # print("col_score:{}".format(col_score)) for col in cols: if vet[1] == "where": stack.push(("op", "where", col)) elif vet[1] != "groupBy": stack.push(("agg", vet[1], col)) elif vet[1] == "groupBy": history[0].append(index_to_column_name(col, tables)) current_sql[kw].append( index_to_column_name(col, tables)) #predict and or or when there is multi col in where condition if col_num > 1 and vet[1] == "where": score = self.andor.forward(q_emb_var, q_len, hs_emb_var, hs_len) label = np.argmax(score[0].data.cpu().numpy()) andor_cond = COND_OPS[label] current_sql[kw].append(andor_cond) if vet[1] == "groupBy" and col_num > 0: score = self.having.forward( q_emb_var, q_len, hs_emb_var, hs_len, col_emb_var, col_len, col_name_len, np.full(B, cols[0], dtype=np.int64)) label = np.argmax(score[0].data.cpu().numpy()) if label == 1: has_having = (label == 1) # stack.insert(-col_num,"having") stack.push("having") # history.append(index_to_column_name(cols[-1], tables[0])) elif isinstance(vet, tuple) and vet[0] == "agg": history[0].append(index_to_column_name(vet[2], 
tables)) if vet[1] not in ("having", "orderBy"): #DEBUG-ed 20180817 try: current_sql[kw].append( index_to_column_name(vet[2], tables)) except Exception as e: # print(e) traceback.print_exc() print("history:{},current_sql:{} stack:{}".format( history[0], current_sql, stack.items)) print("idx_stack:{}".format(idx_stack)) print("sql_stack:{}".format(sql_stack)) exit(1) hs_emb_var, hs_len = self.embed_layer.gen_x_history_batch( history) score = self.agg.forward(q_emb_var, q_len, hs_emb_var, hs_len, col_emb_var, col_len, col_name_len, np.full(B, vet[2], dtype=np.int64)) agg_num_score, agg_score = [ x.data.cpu().numpy() for x in score ] agg_num = np.argmax(agg_num_score[0]) # double check agg_idxs = np.argsort(-agg_score[0])[:agg_num] # print("agg:{}".format([AGG_OPS[agg] for agg in agg_idxs])) if len(agg_idxs) > 0: history[0].append(AGG_OPS[agg_idxs[0]]) if vet[1] not in ("having", "orderBy"): current_sql[kw].append(AGG_OPS[agg_idxs[0]]) elif vet[1] == "orderBy": stack.push(("des_asc", vet[2], AGG_OPS[agg_idxs[0]])) #DEBUG-ed 20180817 else: stack.push( ("op", "having", vet[2], AGG_OPS[agg_idxs[0]])) for agg in agg_idxs[1:]: history[0].append(index_to_column_name(vet[2], tables)) history[0].append(AGG_OPS[agg]) if vet[1] not in ("having", "orderBy"): current_sql[kw].append( index_to_column_name(vet[2], tables)) current_sql[kw].append(AGG_OPS[agg]) elif vet[1] == "orderBy": stack.push(("des_asc", vet[2], AGG_OPS[agg])) else: stack.push(("op", "having", vet[2], agg_idxs)) if len(agg_idxs) == 0: if vet[1] not in ("having", "orderBy"): current_sql[kw].append("none_agg") elif vet[1] == "orderBy": stack.push(("des_asc", vet[2], "none_agg")) else: stack.push(("op", "having", vet[2], "none_agg")) # current_sql[kw].append([AGG_OPS[agg] for agg in agg_idxs]) # if vet[1] == "having": # stack.push(("op","having",vet[2],agg_idxs)) # if vet[1] == "orderBy": # stack.push(("des_asc",vet[2],agg_idxs)) # if vet[1] == "groupBy" and has_having: # stack.push("having") elif isinstance(vet, tuple) and vet[0] == "op": if vet[1] == "where": # current_sql[kw].append(index_to_column_name(vet[2], tables)) history[0].append(index_to_column_name(vet[2], tables)) hs_emb_var, hs_len = self.embed_layer.gen_x_history_batch( history) score = self.op.forward(q_emb_var, q_len, hs_emb_var, hs_len, col_emb_var, col_len, col_name_len, np.full(B, vet[2], dtype=np.int64)) op_num_score, op_score = [x.data.cpu().numpy() for x in score] op_num = np.argmax( op_num_score[0] ) + 1 # num_score 0 maps to 1 in truth, must have at least one op ops = np.argsort(-op_score[0])[:op_num] # current_sql[kw].append([NEW_WHERE_OPS[op] for op in ops]) if op_num > 0: history[0].append(NEW_WHERE_OPS[ops[0]]) if vet[1] == "having": stack.push(("root_teminal", vet[2], vet[3], ops[0])) else: stack.push(("root_teminal", vet[2], ops[0])) # current_sql[kw].append(NEW_WHERE_OPS[ops[0]]) for op in ops[1:]: history[0].append(index_to_column_name(vet[2], tables)) history[0].append(NEW_WHERE_OPS[op]) # current_sql[kw].append(index_to_column_name(vet[2], tables)) # current_sql[kw].append(NEW_WHERE_OPS[op]) if vet[1] == "having": stack.push(("root_teminal", vet[2], vet[3], op)) else: stack.push(("root_teminal", vet[2], op)) # stack.push(("root_teminal",vet[2])) elif isinstance(vet, tuple) and vet[0] == "root_teminal": score = self.root_teminal.forward( q_emb_var, q_len, hs_emb_var, hs_len, col_emb_var, col_len, col_name_len, np.full(B, vet[1], dtype=np.int64)) label = np.argmax(score[0].data.cpu().numpy()) label = ROOT_TERM_OPS[label] if len(vet) == 4: 
current_sql[kw].append(index_to_column_name( vet[1], tables)) current_sql[kw].append(vet[2]) current_sql[kw].append(NEW_WHERE_OPS[vet[3]]) else: # print("kw:{}".format(kw)) try: current_sql[kw].append( index_to_column_name(vet[1], tables)) except Exception as e: # print(e) traceback.print_exc() print("history:{},current_sql:{} stack:{}".format( history[0], current_sql, stack.items)) print("idx_stack:{}".format(idx_stack)) print("sql_stack:{}".format(sql_stack)) exit(1) current_sql[kw].append(NEW_WHERE_OPS[vet[2]]) if label == "root": history[0].append("root") current_sql[kw].append({}) # current_sql = current_sql[kw][-1] stack.push(("root", current_sql[kw][-1])) else: current_sql[kw].append("terminal") elif isinstance(vet, tuple) and vet[0] == "des_asc": current_sql[kw].append(index_to_column_name(vet[1], tables)) current_sql[kw].append(vet[2]) score = self.des_asc.forward( q_emb_var, q_len, hs_emb_var, hs_len, col_emb_var, col_len, col_name_len, np.full(B, vet[1], dtype=np.int64)) label = np.argmax(score[0].data.cpu().numpy()) dec_asc, has_limit = DEC_ASC_OPS[label] history[0].append(dec_asc) current_sql[kw].append(dec_asc) current_sql[kw].append(has_limit) # print("{}".format(current_sql)) if failed: return None print("history:{}".format(history[0])) if len(sql_stack) > 0: current_sql = sql_stack[0] # print("{}".format(current_sql)) return current_sql def gen_col(self, col, table, table_alias_dict): colname = table["column_names_original"][col[2]][1] table_idx = table["column_names_original"][col[2]][0] if table_idx not in table_alias_dict: return colname return "T{}.{}".format(table_alias_dict[table_idx], colname) def gen_group_by(self, sql, kw, table, table_alias_dict): ret = [] for i in range(0, len(sql)): # if len(sql[i+1]) == 0: # if sql[i+1] == "none_agg": ret.append(self.gen_col(sql[i], table, table_alias_dict)) # else: # ret.append("{}({})".format(sql[i+1], self.gen_col(sql[i], table, table_alias_dict))) # for agg in sql[i+1]: # ret.append("{}({})".format(agg,gen_col(sql[i],table,table_alias_dict))) return "{} {}".format(kw, ",".join(ret)) def gen_select(self, sql, kw, table, table_alias_dict): ret = [] for i in range(0, len(sql), 2): # if len(sql[i+1]) == 0: if sql[i + 1] == "none_agg" or not isinstance( sql[i + 1], basestring): #DEBUG-ed 20180817 ret.append(self.gen_col(sql[i], table, table_alias_dict)) else: ret.append("{}({})".format( sql[i + 1], self.gen_col(sql[i], table, table_alias_dict))) # for agg in sql[i+1]: # ret.append("{}({})".format(agg,gen_col(sql[i],table,table_alias_dict))) return "{} {}".format(kw, ",".join(ret)) def gen_where(self, sql, table, table_alias_dict): if len(sql) == 0: return "" start_idx = 0 andor = "and" if isinstance(sql[0], basestring): start_idx += 1 andor = sql[0] ret = [] for i in range(start_idx, len(sql), 3): col = self.gen_col(sql[i], table, table_alias_dict) op = sql[i + 1] val = sql[i + 2] where_item = "" if val == "terminal": where_item = "{} {} '{}'".format(col, op, val) else: val = self.gen_sql(val, table) where_item = "{} {} ({})".format(col, op, val) if op == "between": #TODO temprarily fixed where_item += " and 'terminal'" ret.append(where_item) return "where {}".format(" {} ".format(andor).join(ret)) def gen_orderby(self, sql, table, table_alias_dict): ret = [] limit = "" if sql[-1] == True: limit = "limit 1" for i in range(0, len(sql), 4): if sql[i + 1] == "none_agg" or not isinstance( sql[i + 1], basestring): #DEBUG-ed 20180817 ret.append("{} {}".format( self.gen_col(sql[i], table, table_alias_dict), sql[i + 2])) else: 
ret.append("{}({}) {}".format( sql[i + 1], self.gen_col(sql[i], table, table_alias_dict), sql[i + 2])) return "order by {} {}".format(",".join(ret), limit) def gen_having(self, sql, table, table_alias_dict): ret = [] for i in range(0, len(sql), 4): if sql[i + 1] == "none_agg": col = self.gen_col(sql[i], table, table_alias_dict) else: col = "{}({})".format( sql[i + 1], self.gen_col(sql[i], table, table_alias_dict)) op = sql[i + 2] val = sql[i + 3] if val == "terminal": ret.append("{} {} '{}'".format(col, op, val)) else: val = self.gen_sql(val, table) ret.append("{} {} ({})".format(col, op, val)) return "having {}".format(",".join(ret)) def find_shortest_path(self, start, end, graph): stack = [[start, []]] visited = set() while len(stack) > 0: ele, history = stack.pop() if ele == end: return history for node in graph[ele]: if node[0] not in visited: stack.append((node[0], history + [(node[0], node[1])])) visited.add(node[0]) print("table {} table {}".format(start, end)) # print("could not find path!!!!!{}".format(self.path_not_found)) self.path_not_found += 1 # return [] def gen_from(self, candidate_tables, table): def find(d, col): if d[col] == -1: return col return find(d, d[col]) def union(d, c1, c2): r1 = find(d, c1) r2 = find(d, c2) if r1 == r2: return d[r1] = r2 ret = "" if len(candidate_tables) <= 1: if len(candidate_tables) == 1: ret = "from {}".format( table["table_names_original"][list(candidate_tables)[0]]) else: ret = "from {}".format(table["table_names_original"][0]) #TODO: temporarily settings return {}, ret # print("candidate:{}".format(candidate_tables)) table_alias_dict = {} uf_dict = {} for t in candidate_tables: uf_dict[t] = -1 idx = 1 graph = defaultdict(list) for acol, bcol in table["foreign_keys"]: t1 = table["column_names"][acol][0] t2 = table["column_names"][bcol][0] graph[t1].append((t2, (acol, bcol))) graph[t2].append((t1, (bcol, acol))) # if t1 in candidate_tables and t2 in candidate_tables: # r1 = find(uf_dict,t1) # r2 = find(uf_dict,t2) # if r1 == r2: # continue # union(uf_dict,t1,t2) # if len(ret) == 0: # ret = "from {} as T{} join {} as T{} on T{}.{}=T{}.{}".format(table["table_names"][t1],idx,table["table_names"][t2], # idx+1,idx,table["column_names_original"][acol][1],idx+1, # table["column_names_original"][bcol][1]) # table_alias_dict[t1] = idx # table_alias_dict[t2] = idx+1 # idx += 2 # else: # if t1 in table_alias_dict: # old_t = t1 # new_t = t2 # acol,bcol = bcol,acol # elif t2 in table_alias_dict: # old_t = t2 # new_t = t1 # else: # ret = "{} join {} as T{} join {} as T{} on T{}.{}=T{}.{}".format(ret,table["table_names"][t1], idx, # table["table_names"][t2], # idx + 1, idx, # table["column_names_original"][acol][1], # idx + 1, # table["column_names_original"][bcol][1]) # table_alias_dict[t1] = idx # table_alias_dict[t2] = idx + 1 # idx += 2 # continue # ret = "{} join {} as T{} on T{}.{}=T{}.{}".format(ret,new_t,idx,idx,table["column_names_original"][acol][1], # table_alias_dict[old_t],table["column_names_original"][bcol][1]) # table_alias_dict[new_t] = idx # idx += 1 # visited = set() candidate_tables = list(candidate_tables) start = candidate_tables[0] table_alias_dict[start] = idx idx += 1 ret = "from {} as T1".format(table["table_names_original"][start]) try: for end in candidate_tables[1:]: if end in table_alias_dict: continue path = self.find_shortest_path(start, end, graph) prev_table = start if not path: table_alias_dict[end] = idx idx += 1 ret = "{} join {} as T{}".format( ret, table["table_names_original"][end], table_alias_dict[end], ) 
continue for node, (acol, bcol) in path: if node in table_alias_dict: prev_table = node continue table_alias_dict[node] = idx idx += 1 ret = "{} join {} as T{} on T{}.{} = T{}.{}".format( ret, table["table_names_original"][node], table_alias_dict[node], table_alias_dict[prev_table], table["column_names_original"][acol][1], table_alias_dict[node], table["column_names_original"][bcol][1]) prev_table = node except: traceback.print_exc() print("db:{}".format(table["db_id"])) # print(table["db_id"]) return table_alias_dict, ret # if len(candidate_tables) != len(table_alias_dict): # print("error in generate from clause!!!!!") return table_alias_dict, ret def gen_sql(self, sql, table): select_clause = "" from_clause = "" groupby_clause = "" orderby_clause = "" having_clause = "" where_clause = "" nested_clause = "" cols = {} candidate_tables = set() nested_sql = {} nested_label = "" parent_sql = sql # if "sql" in sql: # sql = sql["sql"] if "nested_label" in sql: nested_label = sql["nested_label"] nested_sql = sql["nested_sql"] sql = sql["sql"] elif "sql" in sql: sql = sql["sql"] for key in sql: if key not in KW_WITH_COL: continue for item in sql[key]: if isinstance(item, tuple) and len(item) == 3: if table["column_names"][item[2]][0] != -1: candidate_tables.add(table["column_names"][item[2]][0]) table_alias_dict, from_clause = self.gen_from(candidate_tables, table) ret = [] if "select" in sql: select_clause = self.gen_select(sql["select"], "select", table, table_alias_dict) if len(select_clause) > 0: ret.append(select_clause) else: print("select not found:{}".format(parent_sql)) else: print("select not found:{}".format(parent_sql)) if len(from_clause) > 0: ret.append(from_clause) if "where" in sql: where_clause = self.gen_where(sql["where"], table, table_alias_dict) if len(where_clause) > 0: ret.append(where_clause) if "groupBy" in sql: ## DEBUG-ed order groupby_clause = self.gen_group_by(sql["groupBy"], "group by", table, table_alias_dict) if len(groupby_clause) > 0: ret.append(groupby_clause) if "orderBy" in sql: orderby_clause = self.gen_orderby(sql["orderBy"], table, table_alias_dict) if len(orderby_clause) > 0: ret.append(orderby_clause) if "having" in sql: having_clause = self.gen_having(sql["having"], table, table_alias_dict) if len(having_clause) > 0: ret.append(having_clause) if len(nested_label) > 0: nested_clause = "{} {}".format(nested_label, self.gen_sql(nested_sql, table)) if len(nested_clause) > 0: ret.append(nested_clause) return " ".join(ret) def check_acc(self, pred_sql, gt_sql): pass
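For reference, here is a minimal sketch of how `SuperModel` might be exercised end to end. The embedding loader name (`load_word_emb`), the GloVe path, and the exact structure expected by the `tables` argument are assumptions inferred from the constructor and `full_forward` above, and each sub-predictor would still need its trained checkpoint loaded; this is not the repository's confirmed entry point.

import json

# Hedged sketch, assuming a GloVe loader named load_word_emb and a Spider-style
# tables.json; sub-module checkpoints are not loaded here, so the predictions
# would come from untrained weights.
word_emb = load_word_emb("glove/glove.42B.300d.txt")              # hypothetical loader
table = next(t for t in json.load(open("data/spider/tables.json"))
             if t["db_id"] == "concert_singer")                   # pick one schema entry

model = SuperModel(word_emb, N_word=300, N_h=300, N_depth=2,
                   gpu=False, table_type="std", use_hs=True)

q_toks = ["how", "many", "singers", "are", "older", "than", "30"]
sql_tree = model.full_forward([q_toks], [["root"]], table)        # returns None on timeout
if sql_tree is not None:
    print(model.gen_sql(sql_tree, table))                         # render the SQL string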
class SyntaxSQL(): """ Main class for the SyntaxSQL model. This takes all the sub modules, and uses them to run a question through the syntax tree """ def __init__(self, embeddings, N_word, hidden_dim, num_layers, gpu, num_augmentation=10000): self.embeddings = embeddings self.having_predictor = HavingPredictor(N_word=N_word, hidden_dim=hidden_dim, num_layers=num_layers, gpu=gpu).eval() self.keyword_predictor = KeyWordPredictor(N_word=N_word, hidden_dim=hidden_dim, num_layers=num_layers, gpu=gpu).eval() self.andor_predictor = AndOrPredictor(N_word=N_word, hidden_dim=hidden_dim, num_layers=num_layers, gpu=gpu).eval() self.desasc_predictor = DesAscLimitPredictor(N_word=N_word, hidden_dim=hidden_dim, num_layers=num_layers, gpu=gpu).eval() self.op_predictor = OpPredictor(N_word=N_word, hidden_dim=hidden_dim, num_layers=num_layers, gpu=gpu).eval() self.col_predictor = ColPredictor(N_word=N_word, hidden_dim=hidden_dim, num_layers=num_layers, gpu=gpu).eval() self.agg_predictor = AggPredictor(N_word=N_word, hidden_dim=hidden_dim, num_layers=num_layers, gpu=gpu).eval() self.limit_value_predictor = LimitValuePredictor(N_word=N_word, hidden_dim=hidden_dim, num_layers=num_layers, gpu=gpu).eval() self.distinct_predictor = DistinctPredictor(N_word=N_word, hidden_dim=hidden_dim, num_layers=num_layers, gpu=gpu).eval() self.value_predictor = ValuePredictor(N_word=N_word, hidden_dim=hidden_dim, num_layers=num_layers, gpu=gpu).eval() def get_model_path(model='having', batch_size=64, epoch=50, num_augmentation=num_augmentation, name_postfix=''): return f'saved_models/{model}__num_layers={num_layers}__lr=0.001__dropout=0.3__batch_size={batch_size}__embedding_dim={N_word}__hidden_dim={hidden_dim}__epoch={epoch}__num_augmentation={num_augmentation}__{name_postfix}.pt' try: self.having_predictor.load(get_model_path('having')) self.keyword_predictor.load( get_model_path('keyword', epoch=300, num_augmentation=10000, name_postfix='kw2')) self.andor_predictor.load( get_model_path('andor', batch_size=256, num_augmentation=0)) self.desasc_predictor.load(get_model_path('desasc')) self.op_predictor.load(get_model_path('op', num_augmentation=10000)) self.col_predictor.load( get_model_path('column', epoch=300, num_augmentation=30000, name_postfix='rep2aug')) self.distinct_predictor.load( get_model_path('distinct', epoch=300, num_augmentation=0, name_postfix='dist2')) self.agg_predictor.load(get_model_path('agg', num_augmentation=0)) self.limit_value_predictor.load(get_model_path('limitvalue')) self.value_predictor.load( get_model_path('value', epoch=300, num_augmentation=10000, name_postfix='val2')) except FileNotFoundError as ex: print(ex) self.current_keyword = '' self.sql = None self.gpu = gpu if gpu: self.embeddings = self.embeddings.cuda() def generate_select(self): # All statements should start with a select statement self.current_keyword = 'select' self.generate_columns() def generate_where(self): self.current_keyword = 'where' self.generate_columns() def generate_ascdesc(self, column): # Get the history, from the current sql history = self.sql.generate_history() hs_emb_var, hs_len = self.embeddings.get_history_emb( [history['having'][-1]]) col_idx = self.sql.database.get_idx_from_column(column) ascdesc = self.desasc_predictor.predict(self.q_emb_var, self.q_len, hs_emb_var, hs_len, self.col_emb_var, self.col_len, self.col_name_len, col_idx) ascdesc = SQL_ORDERBY_OPS[int(ascdesc)] self.sql.ORDERBY_OP += [ascdesc] if 'LIMIT' in ascdesc: limit_value = self.limit_value_predictor.predict( self.q_emb_var, self.q_len, 
hs_emb_var, hs_len, self.col_emb_var, self.col_len, self.col_name_len, col_idx)[0] self.sql.LIMIT_VALUE = limit_value def generate_orderby(self): self.current_keyword = 'orderby' self.generate_columns() def generate_groupby(self): self.current_keyword = 'groupby' self.generate_columns() def generate_having(self, column): # Get the history, from the current sql history = self.sql.generate_history() hs_emb_var, hs_len = self.embeddings.get_history_emb( [history['having'][-1]]) col_idx = self.sql.database.get_idx_from_column(column) having = self.having_predictor.predict(self.q_emb_var, self.q_len, hs_emb_var, hs_len, self.col_emb_var, self.col_len, self.col_name_len, col_idx) if having: self.current_keyword = 'having' self.generate_columns() def generate_keywords(self): self.generate_select() KEYWORDS = [ self.generate_where, self.generate_groupby, self.generate_orderby ] # Get the history, from the current sql history = self.sql.generate_history() hs_emb_var, hs_len = self.embeddings.get_history_emb( history['keyword']) num_kw, kws = self.keyword_predictor.predict(self.q_emb_var, self.q_len, hs_emb_var, hs_len, self.kw_emb_var, self.kw_len) if num_kw[0] == 0: return # We want the keywords in the same order as much as possible # Keywords are added FIFO queue, so sort it key_words = sorted(kws[0]) # Add other states to the list for key_word in key_words: KEYWORDS[int(key_word)]() def generate_andor(self, column): # Get the history, from the current sql history = self.sql.generate_history() hs_emb_var, hs_len = self.embeddings.get_history_emb( [history['andor'][-1]]) andor = self.andor_predictor.predict(self.q_emb_var, self.q_len, hs_emb_var, hs_len) andor = SQL_COND_OPS[int(andor)] if self.current_keyword == 'where': self.sql.WHERE[-1].cond_op = andor elif self.current_keyword == 'having': self.sql.HAVING[-1].cond_op = andor def generate_op(self, column): # Get the history, from the current sql history = self.sql.generate_history() hs_emb_var, hs_len = self.embeddings.get_history_emb( [history['op'][-1]]) col_idx = self.sql.database.get_idx_from_column(column) op = self.op_predictor.predict(self.q_emb_var, self.q_len, hs_emb_var, hs_len, self.col_emb_var, self.col_len, self.col_name_len, col_idx) op = SQL_OPS[int(op)] # Pick the current clause from the current keyword if self.current_keyword == 'where': self.sql.WHERE[-1].op = op else: self.sql.HAVING[-1].op = op return op def generate_distrinct(self, column): # Get the history, from the current sql history = self.sql.generate_history() hs_emb_var, hs_len = self.embeddings.get_history_emb( [history['distinct'][-1]]) col_idx = self.sql.database.get_idx_from_column(column) distinct = self.distinct_predictor.predict(self.q_emb_var, self.q_len, hs_emb_var, hs_len, self.col_emb_var, self.col_len, self.col_name_len, col_idx) distinct = SQL_DISTINCT_OP[int(distinct)] if self.current_keyword == 'select': self.sql.COLS[-1].distinct = distinct elif self.current_keyword == 'orderby': self.sql.ORDERBY[-1].distinct = '' elif self.current_keyword == 'having': self.sql.HAVING[-1].distinct = distinct def generate_agg(self, column, early_return=False, force_agg=False): # Get the history, from the current sql history = self.sql.generate_history() hs_emb_var, hs_len = self.embeddings.get_history_emb( [history['agg'][-1]]) col_idx = self.sql.database.get_idx_from_column(column) agg = self.agg_predictor.predict(self.q_emb_var, self.q_len, hs_emb_var, hs_len, self.col_emb_var, self.col_len, self.col_name_len, col_idx, force_agg=force_agg) agg = SQL_AGG[int(agg)] 
if early_return is True: return agg if self.current_keyword == 'select': self.sql.COLS[-1].agg = agg elif self.current_keyword == 'orderby': self.sql.ORDERBY[-1].agg = agg elif self.current_keyword == 'having': self.sql.HAVING[-1].agg = agg def generate_between(self, column): ban_prediction = None # Make two predictions for i in range(2): # Get the history, from the current sql history = self.sql.generate_history() hs_emb_var, hs_len = self.embeddings.get_history_emb( [history['value'][-1]]) tokens = word_tokenize(str.lower(self.question)) # Create mask for integer tokens int_tokens = [ text2int(token.replace('-', '').replace('.', '')).isdigit() for token in tokens ] num_tokens, start_index = self.value_predictor.predict( self.q_emb_var, self.q_len, hs_emb_var, hs_len, self.col_emb_var, self.col_len, self.col_name_len, ban_prediction, int_tokens) num_tokens, start_index = int(num_tokens[0]), int(start_index[0]) try: value = ' '.join(tokens[start_index:start_index + num_tokens]) if self.current_keyword == 'where': if i == 0: self.sql.WHERE[-1].value = value ban_prediction = (num_tokens, start_index) else: self.sql.WHERE[-1].valueless = value elif self.current_keyword == 'having': if i == 0: self.sql.HAVING[-1].value = value ban_prediction = (num_tokens, start_index) else: self.sql.HAVING[-1].valueless = value # The value might not exist in the question, so just ignore it except Exception as e: print(e) def generate_value(self, column): # Get the history, from the current sql history = self.sql.generate_history() hs_emb_var, hs_len = self.embeddings.get_history_emb( [history['value'][-1]]) num_tokens, start_index = self.value_predictor.predict( self.q_emb_var, self.q_len, hs_emb_var, hs_len, self.col_emb_var, self.col_len, self.col_name_len) num_tokens, start_index = int(num_tokens[0]), int(start_index[0]) tokens = word_tokenize(str.lower(self.question)) try: value = ' '.join(tokens[start_index:start_index + num_tokens]) value = text2int(value) if self.current_keyword == 'where': self.sql.WHERE[-1].value = value elif self.current_keyword == 'having': self.sql.HAVING[-1].value = value # The value might not exist in the question, so just ignore it except: pass def generate_columns(self): # Get the history, from the current sql history = self.sql.generate_history() hs_emb_var, hs_len = self.embeddings.get_history_emb( [history['col'][-1]]) num_cols, cols = self.col_predictor.predict(self.q_emb_var, self.q_len, hs_emb_var, hs_len, self.col_emb_var, self.col_len, self.col_name_len) # Predictions are returned as lists, but it only has one element num_cols, cols = num_cols[0], cols[0] def exclude_all_from_columns(): # Do not permit * as valid column in where/having clauses excluded_idx = [ len(table.columns) for table in self.sql.database.tables ] _, cols_new = self.col_predictor.predict(self.q_emb_var, self.q_len, hs_emb_var, hs_len, self.col_emb_var, self.col_len, self.col_name_len, exclude_idx=excluded_idx) return self.sql.database.get_column_from_idx(cols_new[0][0]) for i, col in enumerate(cols): column = self.sql.database.get_column_from_idx(col) if self.current_keyword in ('where', 'having'): # Add the column to the corresponding clause if self.current_keyword == 'where': if column.column_name == '*': column = exclude_all_from_columns() self.sql.WHERE += [Condition(column)] else: self.sql.HAVING += [Condition(column)] # We need the value and comparison operation in where/having clauses op = self.generate_op(column) if op == 'BETWEEN': self.generate_between(column) else: 
self.generate_value(column) # If we predict multiple columns in where or having, we need to also predict and/or if num_cols > 1 and i < (num_cols - 1): self.generate_andor(column) if self.current_keyword in ('orderby', 'select', 'having'): force_agg = False if self.current_keyword == 'orderby': self.sql.ORDERBY += [ColumnSelect(column)] if column.column_name == '*' and self.generate_agg( column, early_return=True) == '': column = exclude_all_from_columns() self.sql.ORDERBY[-1] = ColumnSelect(column) elif self.current_keyword == 'select': force_agg = len(set(cols)) < len(cols) self.sql.COLS += [ColumnSelect(column)] # Each column should have an aggregator self.generate_agg(column, force_agg=force_agg) self.generate_distrinct(column) if self.current_keyword == 'groupby': if column.column_name == '*': column = exclude_all_from_columns() self.sql.GROUPBY += [ColumnSelect(column)] if self.current_keyword == 'groupby' and len(cols) > 0: self.generate_having(column) if self.current_keyword == 'orderby': self.generate_ascdesc(column) def GetSQL(self, question, database): # Generate representation of the database in form of SQL clauses self.sql = SQLStatement(query=None, database=database) self.question = question self.q_emb_var, self.q_len = self.embeddings(question) columns = self.sql.database.to_list() # Get all columns from the database and split them columns_all_splitted = [] for i, column in enumerate(columns): columns_tmp = [] for word in column: columns_tmp.extend(word.split('_')) columns_all_splitted += [columns_tmp] # Get embedding for the columns and keywords self.col_emb_var, self.col_len, self.col_name_len = self.embeddings.get_columns_emb( [columns_all_splitted]) _, num_cols_in_db, col_name_lens, embedding_dim = self.col_emb_var.shape self.col_emb_var = self.col_emb_var.reshape(num_cols_in_db, col_name_lens, embedding_dim) self.col_name_len = self.col_name_len.reshape(-1) self.kw_emb_var, self.kw_len = self.embeddings.get_history_emb( [['where', 'order by', 'group by']]) # Start recursively generating the sql history starting with the keywords, select and so on. self.generate_keywords() return self.sql
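A correspondingly small usage sketch for `SyntaxSQL` follows. `GloveEmbedding` and `Database` are placeholder names for whatever embedding wrapper and schema wrapper the surrounding code provides (the constructor only needs an `embeddings` object with the methods used above and a database exposing `to_list` and `get_column_from_idx`), so treat them as assumptions rather than the repository's actual class names.

# Hedged sketch: run one natural-language question through the full pipeline.
# GloveEmbedding and Database are hypothetical stand-ins for the real wrappers.
embeddings = GloveEmbedding("glove/glove.6B.300d.txt", N_word=300, gpu=False)   # assumption
syntax_sql = SyntaxSQL(embeddings, N_word=300, hidden_dim=300,
                       num_layers=2, gpu=False, num_augmentation=10000)

database = Database("data/spider/tables.json", db_id="concert_singer")          # assumption
sql = syntax_sql.GetSQL("How many singers are older than 30?", database)
print(sql)  # SQLStatement built up by the generate_* calls above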
def train_feedback(nlq, db_name, correct_query, toy, word_emb): """ Arguments: nlq: english question (tokenization is done here) - get from Flask (User) db_name: name of the database the query targets - get from Flask (User) correct_query: the ground truth query supplied by the user(s) - get from Flask toy: uses a small example of word embeddings to debug faster """ ITER = 21 SAVED_MODELS_FOLDER = "saved_models" OUTPUT_PATH = "output_inference.txt" HISTORY_TYPE = "full" GPU_ENABLE = False TRAIN_EMB = False TABLE_TYPE = "std" DATA_ROOT = "generated_data" use_hs = True if HISTORY_TYPE == "no": HISTORY_TYPE = "full" use_hs = False """ Model Hyperparameters """ N_word = 300 # word embedding dimension B_word = 42 # 42B tokens in the Glove pretrained embeddings N_h = 300 # hidden size dimension N_depth = 2 # if toy: USE_SMALL = True # GPU=True GPU = GPU_ENABLE BATCH_SIZE = 20 else: USE_SMALL = False # GPU=True GPU = GPU_ENABLE BATCH_SIZE = 64 # TRAIN_ENTRY=(False, True, False) # (AGG, SEL, COND) # TRAIN_AGG, TRAIN_SEL, TRAIN_COND = TRAIN_ENTRY learning_rate = 1e-4 # GENERATE CORRECT QUERY DATASET table_data_path = "./data/spider/tables.json" table_dict = get_table_dict(table_data_path) train_data_path = "./data/spider/train_spider.json" train_data = json.load(open(train_data_path)) sql = correct_query #"SELECT name , country , age FROM singer ORDER BY age DESC" db_id = db_name #"concert_singer" table_file = table_data_path # "tables.json" schemas, db_names, tables = get_schemas_from_json(table_file) schema = schemas[db_id] table = tables[db_id] schema = Schema(schema, table) sql_label = get_sql(schema, sql) correct_query_data = { "multi_sql_dataset": [], "keyword_dataset": [], "col_dataset": [], "op_dataset": [], "agg_dataset": [], "root_tem_dataset": [], "des_asc_dataset": [], "having_dataset": [], "andor_dataset": [] } parser_item_with_long_history( tokenize(nlq), #item["question_toks"], sql_label, #item["sql"], table_dict[db_name], #table_dict[item["db_id"]], [], correct_query_data) # print("\nCorrect query dataset: {}".format(correct_query_data)) for train_component in TRAIN_COMPONENTS: print("\nTRAIN COMPONENT: {}".format(train_component)) # Check if the compenent to be trained is an actual component if train_component not in TRAIN_COMPONENTS: print("Invalid train component") exit(1) """ Read in the data """ train_data = load_train_dev_dataset(train_component, "train", HISTORY_TYPE, DATA_ROOT) # print("train_data type: {}".format(type(train_data))) dev_data = load_train_dev_dataset(train_component, "dev", HISTORY_TYPE, DATA_ROOT) # sql_data, table_data, val_sql_data, val_table_data, \ # test_sql_data, test_table_data, \ # TRAIN_DB, DEV_DB, TEST_DB = load_dataset(args.dataset, use_small=USE_SMALL) if GPU_ENABLE: map_to = "gpu" else: map_to = "cpu" # Selecting which Model to Train model = None if train_component == "multi_sql": model = MultiSqlPredictor(N_word=N_word, N_h=N_h, N_depth=N_depth, gpu=GPU, use_hs=use_hs) model.load_state_dict( torch.load( "{}/multi_sql_models.dump".format(SAVED_MODELS_FOLDER), map_location=map_to)) elif train_component == "keyword": model = KeyWordPredictor(N_word=N_word, N_h=N_h, N_depth=N_depth, gpu=GPU, use_hs=use_hs) model.load_state_dict( torch.load( "{}/keyword_models.dump".format(SAVED_MODELS_FOLDER), map_location=map_to)) elif train_component == "col": model = ColPredictor(N_word=N_word, N_h=N_h, N_depth=N_depth, gpu=GPU, use_hs=use_hs) model.load_state_dict( torch.load("{}/col_models.dump".format(SAVED_MODELS_FOLDER), map_location=map_to)) elif 
train_component == "op": model = OpPredictor(N_word=N_word, N_h=N_h, N_depth=N_depth, gpu=GPU, use_hs=use_hs) model.load_state_dict( torch.load("{}/op_models.dump".format(SAVED_MODELS_FOLDER), map_location=map_to)) elif train_component == "agg": model = AggPredictor(N_word=N_word, N_h=N_h, N_depth=N_depth, gpu=GPU, use_hs=use_hs) model.load_state_dict( torch.load("{}/agg_models.dump".format(SAVED_MODELS_FOLDER), map_location=map_to)) elif train_component == "root_tem": model = RootTeminalPredictor(N_word=N_word, N_h=N_h, N_depth=N_depth, gpu=GPU, use_hs=use_hs) model.load_state_dict( torch.load( "{}/root_tem_models.dump".format(SAVED_MODELS_FOLDER), map_location=map_to)) elif train_component == "des_asc": model = DesAscLimitPredictor(N_word=N_word, N_h=N_h, N_depth=N_depth, gpu=GPU, use_hs=use_hs) model.load_state_dict( torch.load( "{}/des_asc_models.dump".format(SAVED_MODELS_FOLDER), map_location=map_to)) elif train_component == "having": model = HavingPredictor(N_word=N_word, N_h=N_h, N_depth=N_depth, gpu=GPU, use_hs=use_hs) model.load_state_dict( torch.load("{}/having_models.dump".format(SAVED_MODELS_FOLDER), map_location=map_to)) elif train_component == "andor": model = AndOrPredictor(N_word=N_word, N_h=N_h, N_depth=N_depth, gpu=GPU, use_hs=use_hs) model.load_state_dict( torch.load("{}/andor_models.dump".format(SAVED_MODELS_FOLDER), map_location=map_to)) # model = SQLNet(word_emb, N_word=N_word, gpu=GPU, trainable_emb=args.train_emb) optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=0) print("finished build model") print_flag = False embed_layer = WordEmbedding(word_emb, N_word, gpu=GPU, SQL_TOK=SQL_TOK, trainable=TRAIN_EMB) print("start training") best_acc = 0.0 for i in range(ITER): print('ITER %d @ %s' % (i + 1, datetime.datetime.now())) # arguments of epoch_train # model, optimizer, batch_size, component,embed_layer,data, table_type # print(' Loss = %s' % epoch_train( # model, optimizer, BATCH_SIZE, # args.train_component, # embed_layer, # train_data, # table_type=args.table_type)) print('Total Loss = %s' % epoch_feedback_train(model=model, optimizer=optimizer, batch_size=BATCH_SIZE, component=train_component, embed_layer=embed_layer, data=train_data, table_type=TABLE_TYPE, nlq=nlq, db_name=db_name, correct_query=correct_query, correct_query_data=correct_query_data)) # Check improvement every 10 iterations if i % 10 == 0: acc = epoch_acc(model, BATCH_SIZE, train_component, embed_layer, dev_data, table_type=TABLE_TYPE) if acc > best_acc: best_acc = acc print("Save model...") torch.save( model.state_dict(), SAVED_MODELS_FOLDER + "/{}_models.dump".format(train_component))