def __init__(self, word_emb, N_word, N_h=300, N_depth=2, gpu=True, trainable_emb=False, table_type="std", use_hs=True):
    """Build the SuperModel: one frozen sub-predictor per SQL decision,
    plus the shared word-embedding layer and the loss/activation helpers.

    Args:
        word_emb: pre-loaded word-embedding lookup passed to WordEmbedding.
        N_word: word embedding dimension.
        N_h: hidden size of every sub-predictor.
        N_depth: LSTM depth of every sub-predictor.
        gpu: move the whole model to CUDA when True.
        trainable_emb: whether the embedding layer is fine-tuned.
        table_type: table serialization mode forwarded to batching code.
        use_hs: whether sub-predictors consume the decoding history.
    """
    super(SuperModel, self).__init__()
    self.gpu = gpu
    self.N_h = N_h
    self.N_depth = N_depth
    self.trainable_emb = trainable_emb
    self.table_type = table_type
    self.use_hs = use_hs
    self.SQL_TOK = ['<UNK>', '<END>', 'WHERE', 'AND', 'EQL', 'GT', 'LT', '<BEG>']

    # Word embedding layer shared by all sub-modules.
    self.embed_layer = WordEmbedding(word_emb, N_word, gpu,
                                     self.SQL_TOK, trainable=trainable_emb)

    def frozen(module):
        # Every sub-predictor is used for inference only: switch it to
        # eval() mode right away, exactly as the original code did.
        module.eval()
        return module

    # Instantiate every decision module (all share the same hyper-parameters).
    self.multi_sql = frozen(MultiSqlPredictor(N_word=N_word, N_h=N_h, N_depth=N_depth, gpu=gpu, use_hs=use_hs))
    self.key_word = frozen(KeyWordPredictor(N_word=N_word, N_h=N_h, N_depth=N_depth, gpu=gpu, use_hs=use_hs))
    self.col = frozen(ColPredictor(N_word=N_word, N_h=N_h, N_depth=N_depth, gpu=gpu, use_hs=use_hs))
    self.op = frozen(OpPredictor(N_word=N_word, N_h=N_h, N_depth=N_depth, gpu=gpu, use_hs=use_hs))
    self.agg = frozen(AggPredictor(N_word=N_word, N_h=N_h, N_depth=N_depth, gpu=gpu, use_hs=use_hs))
    self.root_teminal = frozen(RootTeminalPredictor(N_word=N_word, N_h=N_h, N_depth=N_depth, gpu=gpu, use_hs=use_hs))
    self.des_asc = frozen(DesAscLimitPredictor(N_word=N_word, N_h=N_h, N_depth=N_depth, gpu=gpu, use_hs=use_hs))
    self.having = frozen(HavingPredictor(N_word=N_word, N_h=N_h, N_depth=N_depth, gpu=gpu, use_hs=use_hs))
    self.andor = frozen(AndOrPredictor(N_word=N_word, N_h=N_h, N_depth=N_depth, gpu=gpu, use_hs=use_hs))

    # Activation / loss helpers (kept identical to the original wiring).
    self.softmax = nn.Softmax(dim=1)  # dim=1
    self.CE = nn.CrossEntropyLoss()
    self.log_softmax = nn.LogSoftmax()
    self.mlsml = nn.MultiLabelSoftMarginLoss()
    self.bce_logit = nn.BCEWithLogitsLoss()
    self.sigm = nn.Sigmoid()
    if gpu:
        self.cuda()
    # Counter of table-join paths find_shortest_path failed to resolve.
    self.path_not_found = 0
# Report what kind of embedding object was loaded (dict, tuple, ...).
print("word_emb type = {}".format(type(word_emb)))
# print("random element from word_emb = {}".format(word_emb[random.choice(list(word_emb.keys()))]))
print("finished loading word embedding")
#word_emb = load_concat_wemb('glove/glove.42B.300d.txt', "/data/projects/paraphrase/generation/para-nmt-50m/data/paragram_sl999_czeng.txt")

# Selecting which Model to Train
model = None
# torch.load's map_location accepts "cpu" or a CUDA device string such as
# "cuda"; the previous value "gpu" is not a valid torch device string and
# raised a RuntimeError whenever GPU_ENABLE was on.
if GPU_ENABLE:
    map_to = "cuda"
else:
    map_to = "cpu"
if args.train_component == "multi_sql":
    model = MultiSqlPredictor(N_word=N_word, N_h=N_h, N_depth=N_depth, gpu=GPU, use_hs=use_hs)
    model.load_state_dict(
        torch.load("{}/multi_sql_models.dump".format(SAVED_MODELS_FOLDER), map_location=map_to))
elif args.train_component == "keyword":
    model = KeyWordPredictor(N_word=N_word, N_h=N_h, N_depth=N_depth, gpu=GPU, use_hs=use_hs)
    model.load_state_dict(
        torch.load("{}/keyword_models.dump".format(SAVED_MODELS_FOLDER), map_location=map_to))
class SuperModel(nn.Module):
    """End-to-end SQL generator that chains pretrained sub-predictors
    (multi_sql, key_word, col, op, agg, root_teminal, des_asc, having,
    andor) to decode a SQL query for a natural-language question.

    All sub-modules are put into eval() mode on construction, so this class
    is inference-only; each predictor is trained separately elsewhere.
    NOTE(review): `basestring` is used below, so this code targets Python 2
    (or expects a compatibility shim) — confirm before running on Python 3.
    """

    def __init__(self, word_emb, N_word, N_h=300, N_depth=2, gpu=True, trainable_emb=False, table_type="std", use_hs=True):
        """Instantiate the embedding layer, all sub-predictors and loss helpers.

        Args:
            word_emb: pre-loaded word-embedding lookup for WordEmbedding.
            N_word: word embedding dimension.
            N_h: hidden size shared by all sub-predictors.
            N_depth: recurrent depth shared by all sub-predictors.
            gpu: move the whole model to CUDA when True.
            trainable_emb: whether embeddings are fine-tuned.
            table_type: table serialization mode used when batching tables.
            use_hs: whether sub-predictors consume the decoding history.
        """
        super(SuperModel, self).__init__()
        self.gpu = gpu
        self.N_h = N_h
        self.N_depth = N_depth
        self.trainable_emb = trainable_emb
        self.table_type = table_type
        self.use_hs = use_hs
        self.SQL_TOK = ['<UNK>', '<END>', 'WHERE', 'AND', 'EQL', 'GT', 'LT', '<BEG>']

        # word embedding layer
        self.embed_layer = WordEmbedding(word_emb, N_word, gpu, self.SQL_TOK, trainable=trainable_emb)

        # initial all modules — each is frozen with eval() immediately.
        self.multi_sql = MultiSqlPredictor(N_word=N_word, N_h=N_h, N_depth=N_depth, gpu=gpu, use_hs=use_hs)
        self.multi_sql.eval()

        self.key_word = KeyWordPredictor(N_word=N_word, N_h=N_h, N_depth=N_depth, gpu=gpu, use_hs=use_hs)
        self.key_word.eval()

        self.col = ColPredictor(N_word=N_word, N_h=N_h, N_depth=N_depth, gpu=gpu, use_hs=use_hs)
        self.col.eval()

        self.op = OpPredictor(N_word=N_word, N_h=N_h, N_depth=N_depth, gpu=gpu, use_hs=use_hs)
        self.op.eval()

        self.agg = AggPredictor(N_word=N_word, N_h=N_h, N_depth=N_depth, gpu=gpu, use_hs=use_hs)
        self.agg.eval()

        self.root_teminal = RootTeminalPredictor(N_word=N_word, N_h=N_h, N_depth=N_depth, gpu=gpu, use_hs=use_hs)
        self.root_teminal.eval()

        self.des_asc = DesAscLimitPredictor(N_word=N_word, N_h=N_h, N_depth=N_depth, gpu=gpu, use_hs=use_hs)
        self.des_asc.eval()

        self.having = HavingPredictor(N_word=N_word, N_h=N_h, N_depth=N_depth, gpu=gpu, use_hs=use_hs)
        self.having.eval()

        self.andor = AndOrPredictor(N_word=N_word, N_h=N_h, N_depth=N_depth, gpu=gpu, use_hs=use_hs)
        self.andor.eval()

        # NOTE(review): constructed without dim; newer PyTorch warns and the
        # implicit dim may differ from dim=1 — confirm before upgrading torch.
        self.softmax = nn.Softmax()  #dim=1
        self.CE = nn.CrossEntropyLoss()
        self.log_softmax = nn.LogSoftmax()
        self.mlsml = nn.MultiLabelSoftMarginLoss()
        self.bce_logit = nn.BCEWithLogitsLoss()
        self.sigm = nn.Sigmoid()
        if gpu:
            self.cuda()
        # Counter of table-join paths find_shortest_path failed to resolve.
        self.path_not_found = 0

    def forward(self, q_seq, history, tables):
        """Decode SQL for a batch of questions; delegates to full_forward."""
        # if self.part:
        #     return self.part_forward(q_seq,history,tables)
        # else:
        return self.full_forward(q_seq, history, tables)

    def full_forward(self, q_seq, history, tables):
        """Run the stack-driven decoding loop.

        Pops decision tokens off a stack ("root", keyword names, ("col", kw),
        ("agg", ...), ("op", ...), ("root_teminal", ...), ("des_asc", ...)),
        queries the matching sub-predictor for each, and assembles the result
        into the nested `current_sql` dict. Returns None on timeout.
        """
        B = len(q_seq)
        # print("q_seq:{}".format(q_seq))
        # print("Batch size:{}".format(B))
        q_emb_var, q_len = self.embed_layer.gen_x_q_batch(q_seq)
        col_seq = to_batch_tables(tables, B, self.table_type)
        col_emb_var, col_name_len, col_len = self.embed_layer.gen_col_batch(col_seq)

        # Fixed candidate vocabularies for the multi-sql and keyword predictors.
        mkw_emb_var = self.embed_layer.gen_word_list_embedding(["none", "except", "intersect", "union"], (B))
        mkw_len = np.full(q_len.shape, 4, dtype=np.int64)
        kw_emb_var = self.embed_layer.gen_word_list_embedding(["where", "group by", "order by"], (B))
        kw_len = np.full(q_len.shape, 3, dtype=np.int64)

        stack = Stack()
        stack.push(("root", None))
        # NOTE(review): [["root"]] * B shares ONE inner list across the batch;
        # only history[0] is read/written below, so decoding is batch-size-1
        # in practice — confirm callers rely on that.
        history = [["root"]] * B
        andor_cond = ""
        has_limit = False
        # sql = {}
        current_sql = {}
        sql_stack = []   # saved current_sql contexts for nested queries
        idx_stack = []   # stack sizes at which each saved context expires
        kw_stack = []    # saved active keyword per context
        kw = ""
        nested_label = ""
        has_having = False

        timeout = time.time() + 2  # set timer to prevent infinite recursion in SQL generation
        failed = False
        while not stack.isEmpty():
            if time.time() > timeout:
                failed = True
                break
            vet = stack.pop()
            # print(vet)
            hs_emb_var, hs_len = self.embed_layer.gen_x_history_batch(history)
            # Restore the parent query context once the nested portion of the
            # stack has been fully consumed.
            if len(idx_stack) > 0 and stack.size() < idx_stack[-1]:
                # print("pop!!!!!!!!!!!!!!!!!!!!!!")
                idx_stack.pop()
                current_sql = sql_stack.pop()
                kw = kw_stack.pop()
                # current_sql = current_sql["sql"]
            # history.append(vet)
            # print("hs_emb:{} hs_len:{}".format(hs_emb_var.size(),hs_len.size()))
            if isinstance(vet, tuple) and vet[0] == "root":
                if history[0][-1] != "root":
                    history[0].append("root")
                    hs_emb_var, hs_len = self.embed_layer.gen_x_history_batch(history)
                # Save the enclosing context before descending into this root.
                if vet[1] != "original":
                    idx_stack.append(stack.size())
                    sql_stack.append(current_sql)
                    kw_stack.append(kw)
                else:
                    idx_stack.append(stack.size())
                    sql_stack.append(sql_stack[-1])
                    kw_stack.append(kw)
                # Decide where the sub-query being decoded should be written.
                if "sql" in current_sql:
                    current_sql["nested_sql"] = {}
                    current_sql["nested_label"] = nested_label
                    current_sql = current_sql["nested_sql"]
                elif isinstance(vet[1], dict):
                    vet[1]["sql"] = {}
                    current_sql = vet[1]["sql"]
                elif vet[1] != "original":
                    current_sql["sql"] = {}
                    current_sql = current_sql["sql"]
                # print("q_emb_var:{} hs_emb_var:{} mkw_emb_var:{}".format(q_emb_var.size(),hs_emb_var.size(),mkw_emb_var.size()))
                if vet[1] == "nested" or vet[1] == "original":
                    stack.push("none")
                    history[0].append("none")
                else:
                    # Predict none/except/intersect/union for this root.
                    score = self.multi_sql.forward(q_emb_var, q_len, hs_emb_var, hs_len, mkw_emb_var, mkw_len)
                    label = np.argmax(score[0].data.cpu().numpy())
                    label = SQL_OPS[label]
                    history[0].append(label)
                    stack.push(label)
                if label != "none":
                    nested_label = label
            elif vet in ('intersect', 'except', 'union'):
                # A set operation expands into two sub-queries.
                stack.push(("root", "nested"))
                stack.push(("root", "original"))
                # history[0].append("root")
            elif vet == "none":
                # Predict which of where/group by/order by appear.
                score = self.key_word.forward(q_emb_var, q_len, hs_emb_var, hs_len, kw_emb_var, kw_len)
                kw_num_score, kw_score = [x.data.cpu().numpy() for x in score]
                # print("kw_num_score:{}".format(kw_num_score))
                # print("kw_score:{}".format(kw_score))
                num_kw = np.argmax(kw_num_score[0])
                kw_score = list(np.argsort(-kw_score[0])[:num_kw])
                kw_score.sort(reverse=True)
                # print("num_kw:{}".format(num_kw))
                for kw in kw_score:
                    stack.push(KW_OPS[kw])
                stack.push("select")
            elif vet in ("select", "orderBy", "where", "groupBy", "having"):
                kw = vet
                current_sql[kw] = []
                history[0].append(vet)
                stack.push(("col", vet))
                # score = self.andor.forward(q_emb_var,q_len,hs_emb_var,hs_len)
                # label = score[0].data.cpu().numpy()
                # andor_cond = COND_OPS[label]
                # history.append("")
            elif isinstance(vet, tuple) and vet[0] == "col":
                # print("q_emb_var:{} hs_emb_var:{} col_emb_var:{}".format(q_emb_var.size(), hs_emb_var.size(),col_emb_var.size()))
                score = self.col.forward(q_emb_var, q_len, hs_emb_var, hs_len, col_emb_var, col_len, col_name_len)
                col_num_score, col_score = [x.data.cpu().numpy() for x in score]
                col_num = np.argmax(col_num_score[0]) + 1  # double check
                cols = np.argsort(-col_score[0])[:col_num]
                # print(col_num)
                # print("col_num_score:{}".format(col_num_score))
                # print("col_score:{}".format(col_score))
                for col in cols:
                    if vet[1] == "where":
                        stack.push(("op", "where", col))
                    elif vet[1] != "groupBy":
                        stack.push(("agg", vet[1], col))
                    elif vet[1] == "groupBy":
                        history[0].append(index_to_column_name(col, tables))
                        current_sql[kw].append(index_to_column_name(col, tables))
                # predict and or or when there is multi col in where condition
                if col_num > 1 and vet[1] == "where":
                    score = self.andor.forward(q_emb_var, q_len, hs_emb_var, hs_len)
                    label = np.argmax(score[0].data.cpu().numpy())
                    andor_cond = COND_OPS[label]
                    current_sql[kw].append(andor_cond)
                if vet[1] == "groupBy" and col_num > 0:
                    # Decide whether the GROUP BY carries a HAVING clause.
                    score = self.having.forward(q_emb_var, q_len, hs_emb_var, hs_len, col_emb_var, col_len, col_name_len, np.full(B, cols[0], dtype=np.int64))
                    label = np.argmax(score[0].data.cpu().numpy())
                    if label == 1:
                        has_having = (label == 1)
                        # stack.insert(-col_num,"having")
                        stack.push("having")
                # history.append(index_to_column_name(cols[-1], tables[0]))
            elif isinstance(vet, tuple) and vet[0] == "agg":
                history[0].append(index_to_column_name(vet[2], tables))
                if vet[1] not in ("having", "orderBy"):  #DEBUG-ed 20180817
                    try:
                        current_sql[kw].append(index_to_column_name(vet[2], tables))
                    except Exception as e:
                        # print(e)
                        traceback.print_exc()
                        print("history:{},current_sql:{} stack:{}".format(history[0], current_sql, stack.items))
                        print("idx_stack:{}".format(idx_stack))
                        print("sql_stack:{}".format(sql_stack))
                        exit(1)
                hs_emb_var, hs_len = self.embed_layer.gen_x_history_batch(history)
                score = self.agg.forward(q_emb_var, q_len, hs_emb_var, hs_len, col_emb_var, col_len, col_name_len, np.full(B, vet[2], dtype=np.int64))
                agg_num_score, agg_score = [x.data.cpu().numpy() for x in score]
                agg_num = np.argmax(agg_num_score[0])  # double check
                agg_idxs = np.argsort(-agg_score[0])[:agg_num]
                # print("agg:{}".format([AGG_OPS[agg] for agg in agg_idxs]))
                if len(agg_idxs) > 0:
                    history[0].append(AGG_OPS[agg_idxs[0]])
                    if vet[1] not in ("having", "orderBy"):
                        current_sql[kw].append(AGG_OPS[agg_idxs[0]])
                    elif vet[1] == "orderBy":
                        stack.push(("des_asc", vet[2], AGG_OPS[agg_idxs[0]]))  #DEBUG-ed 20180817
                    else:
                        stack.push(("op", "having", vet[2], AGG_OPS[agg_idxs[0]]))
                for agg in agg_idxs[1:]:
                    history[0].append(index_to_column_name(vet[2], tables))
                    history[0].append(AGG_OPS[agg])
                    if vet[1] not in ("having", "orderBy"):
                        current_sql[kw].append(index_to_column_name(vet[2], tables))
                        current_sql[kw].append(AGG_OPS[agg])
                    elif vet[1] == "orderBy":
                        stack.push(("des_asc", vet[2], AGG_OPS[agg]))
                    else:
                        # NOTE(review): pushes the whole agg_idxs array here,
                        # unlike the first-agg branch which pushes one op —
                        # looks inconsistent; confirm intended.
                        stack.push(("op", "having", vet[2], agg_idxs))
                if len(agg_idxs) == 0:
                    if vet[1] not in ("having", "orderBy"):
                        current_sql[kw].append("none_agg")
                    elif vet[1] == "orderBy":
                        stack.push(("des_asc", vet[2], "none_agg"))
                    else:
                        stack.push(("op", "having", vet[2], "none_agg"))
                # current_sql[kw].append([AGG_OPS[agg] for agg in agg_idxs])
                # if vet[1] == "having":
                #     stack.push(("op","having",vet[2],agg_idxs))
                # if vet[1] == "orderBy":
                #     stack.push(("des_asc",vet[2],agg_idxs))
                # if vet[1] == "groupBy" and has_having:
                #     stack.push("having")
            elif isinstance(vet, tuple) and vet[0] == "op":
                if vet[1] == "where":
                    # current_sql[kw].append(index_to_column_name(vet[2], tables))
                    history[0].append(index_to_column_name(vet[2], tables))
                    hs_emb_var, hs_len = self.embed_layer.gen_x_history_batch(history)
                score = self.op.forward(q_emb_var, q_len, hs_emb_var, hs_len, col_emb_var, col_len, col_name_len, np.full(B, vet[2], dtype=np.int64))
                op_num_score, op_score = [x.data.cpu().numpy() for x in score]
                op_num = np.argmax(op_num_score[0]) + 1  # num_score 0 maps to 1 in truth, must have at least one op
                ops = np.argsort(-op_score[0])[:op_num]
                # current_sql[kw].append([NEW_WHERE_OPS[op] for op in ops])
                if op_num > 0:
                    history[0].append(NEW_WHERE_OPS[ops[0]])
                    if vet[1] == "having":
                        stack.push(("root_teminal", vet[2], vet[3], ops[0]))
                    else:
                        stack.push(("root_teminal", vet[2], ops[0]))
                    # current_sql[kw].append(NEW_WHERE_OPS[ops[0]])
                for op in ops[1:]:
                    history[0].append(index_to_column_name(vet[2], tables))
                    history[0].append(NEW_WHERE_OPS[op])
                    # current_sql[kw].append(index_to_column_name(vet[2], tables))
                    # current_sql[kw].append(NEW_WHERE_OPS[op])
                    if vet[1] == "having":
                        stack.push(("root_teminal", vet[2], vet[3], op))
                    else:
                        stack.push(("root_teminal", vet[2], op))
                # stack.push(("root_teminal",vet[2]))
            elif isinstance(vet, tuple) and vet[0] == "root_teminal":
                # Decide whether the condition value is a literal ("terminal")
                # or a nested sub-query ("root").
                score = self.root_teminal.forward(q_emb_var, q_len, hs_emb_var, hs_len, col_emb_var, col_len, col_name_len, np.full(B, vet[1], dtype=np.int64))
                label = np.argmax(score[0].data.cpu().numpy())
                label = ROOT_TERM_OPS[label]
                if len(vet) == 4:
                    # HAVING: (col, agg, op) were all carried on the tuple.
                    current_sql[kw].append(index_to_column_name(vet[1], tables))
                    current_sql[kw].append(vet[2])
                    current_sql[kw].append(NEW_WHERE_OPS[vet[3]])
                else:
                    # print("kw:{}".format(kw))
                    try:
                        current_sql[kw].append(index_to_column_name(vet[1], tables))
                    except Exception as e:
                        # print(e)
                        traceback.print_exc()
                        print("history:{},current_sql:{} stack:{}".format(history[0], current_sql, stack.items))
                        print("idx_stack:{}".format(idx_stack))
                        print("sql_stack:{}".format(sql_stack))
                        exit(1)
                    current_sql[kw].append(NEW_WHERE_OPS[vet[2]])
                if label == "root":
                    history[0].append("root")
                    current_sql[kw].append({})
                    # current_sql = current_sql[kw][-1]
                    stack.push(("root", current_sql[kw][-1]))
                else:
                    current_sql[kw].append("terminal")
            elif isinstance(vet, tuple) and vet[0] == "des_asc":
                current_sql[kw].append(index_to_column_name(vet[1], tables))
                current_sql[kw].append(vet[2])
                score = self.des_asc.forward(q_emb_var, q_len, hs_emb_var, hs_len, col_emb_var, col_len, col_name_len, np.full(B, vet[1], dtype=np.int64))
                label = np.argmax(score[0].data.cpu().numpy())
                dec_asc, has_limit = DEC_ASC_OPS[label]
                history[0].append(dec_asc)
                current_sql[kw].append(dec_asc)
                current_sql[kw].append(has_limit)
        # print("{}".format(current_sql))
        if failed:
            return None
        print("history:{}".format(history[0]))
        if len(sql_stack) > 0:
            current_sql = sql_stack[0]
        # print("{}".format(current_sql))
        return current_sql

    def gen_col(self, col, table, table_alias_dict):
        """Render one column reference, alias-qualified when its table was
        assigned an alias in gen_from (col is a (name, kw, index) tuple —
        only col[2] is used)."""
        colname = table["column_names_original"][col[2]][1]
        table_idx = table["column_names_original"][col[2]][0]
        if table_idx not in table_alias_dict:
            return colname
        return "T{}.{}".format(table_alias_dict[table_idx], colname)

    def gen_group_by(self, sql, kw, table, table_alias_dict):
        """Render a GROUP BY clause from a flat list of column tuples."""
        ret = []
        for i in range(0, len(sql)):
            # if len(sql[i+1]) == 0:
            # if sql[i+1] == "none_agg":
            ret.append(self.gen_col(sql[i], table, table_alias_dict))
            # else:
            #     ret.append("{}({})".format(sql[i+1], self.gen_col(sql[i], table, table_alias_dict)))
        return "{} {}".format(kw, ",".join(ret))

    def gen_select(self, sql, kw, table, table_alias_dict):
        """Render a SELECT clause; sql alternates (col, agg) pairs."""
        ret = []
        for i in range(0, len(sql), 2):
            # if len(sql[i+1]) == 0:
            if sql[i + 1] == "none_agg" or not isinstance(sql[i + 1], basestring):  #DEBUG-ed 20180817
                ret.append(self.gen_col(sql[i], table, table_alias_dict))
            else:
                ret.append("{}({})".format(sql[i + 1], self.gen_col(sql[i], table, table_alias_dict)))
        return "{} {}".format(kw, ",".join(ret))

    def gen_where(self, sql, table, table_alias_dict):
        """Render a WHERE clause; sql is an optional leading and/or string
        followed by (col, op, val) triples, where val is either the literal
        "terminal" or a nested sql dict."""
        if len(sql) == 0:
            return ""
        start_idx = 0
        andor = "and"
        if isinstance(sql[0], basestring):
            start_idx += 1
            andor = sql[0]
        ret = []
        for i in range(start_idx, len(sql), 3):
            col = self.gen_col(sql[i], table, table_alias_dict)
            op = sql[i + 1]
            val = sql[i + 2]
            where_item = ""
            if val == "terminal":
                where_item = "{} {} '{}'".format(col, op, val)
            else:
                # Nested sub-query value.
                val = self.gen_sql(val, table)
                where_item = "{} {} ({})".format(col, op, val)
            if op == "between":
                #TODO temprarily fixed
                where_item += " and 'terminal'"
            ret.append(where_item)
        return "where {}".format(" {} ".format(andor).join(ret))

    def gen_orderby(self, sql, table, table_alias_dict):
        """Render ORDER BY; sql holds (col, agg, direction, has_limit)
        groups of four, with a trailing True meaning LIMIT 1."""
        ret = []
        limit = ""
        if sql[-1] == True:
            limit = "limit 1"
        for i in range(0, len(sql), 4):
            if sql[i + 1] == "none_agg" or not isinstance(sql[i + 1], basestring):  #DEBUG-ed 20180817
                ret.append("{} {}".format(self.gen_col(sql[i], table, table_alias_dict), sql[i + 2]))
            else:
                ret.append("{}({}) {}".format(sql[i + 1], self.gen_col(sql[i], table, table_alias_dict), sql[i + 2]))
        return "order by {} {}".format(",".join(ret), limit)

    def gen_having(self, sql, table, table_alias_dict):
        """Render HAVING; sql holds (col, agg, op, val) groups of four."""
        ret = []
        for i in range(0, len(sql), 4):
            if sql[i + 1] == "none_agg":
                col = self.gen_col(sql[i], table, table_alias_dict)
            else:
                col = "{}({})".format(sql[i + 1], self.gen_col(sql[i], table, table_alias_dict))
            op = sql[i + 2]
            val = sql[i + 3]
            if val == "terminal":
                ret.append("{} {} '{}'".format(col, op, val))
            else:
                val = self.gen_sql(val, table)
                ret.append("{} {} ({})".format(col, op, val))
        return "having {}".format(",".join(ret))

    def find_shortest_path(self, start, end, graph):
        """DFS for a join path between two tables in the foreign-key graph.

        Returns the path as a list of (table, (acol, bcol)) steps, or falls
        through (returning None) when no path exists, bumping
        self.path_not_found. NOTE: despite the name, a DFS with a plain list
        as the stack does not guarantee the *shortest* path.
        """
        stack = [[start, []]]
        visited = set()
        while len(stack) > 0:
            ele, history = stack.pop()
            if ele == end:
                return history
            for node in graph[ele]:
                if node[0] not in visited:
                    stack.append((node[0], history + [(node[0], node[1])]))
                    visited.add(node[0])
        print("table {} table {}".format(start, end))
        # print("could not find path!!!!!{}".format(self.path_not_found))
        self.path_not_found += 1
        # return []

    def gen_from(self, candidate_tables, table):
        """Build the FROM clause (with T1/T2/... aliases and join conditions)
        covering every table in candidate_tables.

        Returns (table_alias_dict, from_clause). For zero or one candidate
        table no aliases are produced.
        """
        # Local union-find helpers — currently unused (the union-find join
        # strategy they belonged to was replaced by the path search below).
        def find(d, col):
            if d[col] == -1:
                return col
            return find(d, d[col])

        def union(d, c1, c2):
            r1 = find(d, c1)
            r2 = find(d, c2)
            if r1 == r2:
                return
            d[r1] = r2

        ret = ""
        if len(candidate_tables) <= 1:
            if len(candidate_tables) == 1:
                ret = "from {}".format(table["table_names_original"][list(candidate_tables)[0]])
            else:
                ret = "from {}".format(table["table_names_original"][0])
                #TODO: temporarily settings
            return {}, ret
        # print("candidate:{}".format(candidate_tables))
        table_alias_dict = {}
        uf_dict = {}
        for t in candidate_tables:
            uf_dict[t] = -1
        idx = 1
        # Undirected foreign-key graph: table -> [(other_table, (acol, bcol))].
        graph = defaultdict(list)
        for acol, bcol in table["foreign_keys"]:
            t1 = table["column_names"][acol][0]
            t2 = table["column_names"][bcol][0]
            graph[t1].append((t2, (acol, bcol)))
            graph[t2].append((t1, (bcol, acol)))
        # (A large commented-out union-find join builder previously lived
        # here; removed as dead code.)
        # visited = set()
        candidate_tables = list(candidate_tables)
        start = candidate_tables[0]
        table_alias_dict[start] = idx
        idx += 1
        ret = "from {} as T1".format(table["table_names_original"][start])
        try:
            for end in candidate_tables[1:]:
                if end in table_alias_dict:
                    continue
                path = self.find_shortest_path(start, end, graph)
                prev_table = start
                if not path:
                    # No FK path: join without an ON condition.
                    table_alias_dict[end] = idx
                    idx += 1
                    ret = "{} join {} as T{}".format(ret, table["table_names_original"][end], table_alias_dict[end],)
                    continue
                for node, (acol, bcol) in path:
                    if node in table_alias_dict:
                        prev_table = node
                        continue
                    table_alias_dict[node] = idx
                    idx += 1
                    ret = "{} join {} as T{} on T{}.{} = T{}.{}".format(ret, table["table_names_original"][node], table_alias_dict[node], table_alias_dict[prev_table], table["column_names_original"][acol][1], table_alias_dict[node], table["column_names_original"][bcol][1])
                    prev_table = node
        except:
            # NOTE(review): bare except — best-effort return of whatever was
            # built so far; consider narrowing.
            traceback.print_exc()
            print("db:{}".format(table["db_id"]))
            # print(table["db_id"])
            return table_alias_dict, ret
        # if len(candidate_tables) != len(table_alias_dict):
        #     print("error in generate from clause!!!!!")
        return table_alias_dict, ret

    def gen_sql(self, sql, table):
        """Serialize the nested sql dict produced by full_forward into a
        SQL string (recursing into nested/sub queries)."""
        select_clause = ""
        from_clause = ""
        groupby_clause = ""
        orderby_clause = ""
        having_clause = ""
        where_clause = ""
        nested_clause = ""
        cols = {}
        candidate_tables = set()
        nested_sql = {}
        nested_label = ""
        parent_sql = sql
        # if "sql" in sql:
        #     sql = sql["sql"]
        if "nested_label" in sql:
            nested_label = sql["nested_label"]
            nested_sql = sql["nested_sql"]
            sql = sql["sql"]
        elif "sql" in sql:
            sql = sql["sql"]
        # Collect every table referenced by a column tuple to drive gen_from.
        for key in sql:
            if key not in KW_WITH_COL:
                continue
            for item in sql[key]:
                if isinstance(item, tuple) and len(item) == 3:
                    if table["column_names"][item[2]][0] != -1:
                        candidate_tables.add(table["column_names"][item[2]][0])
        table_alias_dict, from_clause = self.gen_from(candidate_tables, table)
        ret = []
        if "select" in sql:
            select_clause = self.gen_select(sql["select"], "select", table, table_alias_dict)
            if len(select_clause) > 0:
                ret.append(select_clause)
            else:
                print("select not found:{}".format(parent_sql))
        else:
            print("select not found:{}".format(parent_sql))
        if len(from_clause) > 0:
            ret.append(from_clause)
        if "where" in sql:
            where_clause = self.gen_where(sql["where"], table, table_alias_dict)
            if len(where_clause) > 0:
                ret.append(where_clause)
        if "groupBy" in sql:  ## DEBUG-ed order
            groupby_clause = self.gen_group_by(sql["groupBy"], "group by", table, table_alias_dict)
            if len(groupby_clause) > 0:
                ret.append(groupby_clause)
        if "orderBy" in sql:
            orderby_clause = self.gen_orderby(sql["orderBy"], table, table_alias_dict)
            if len(orderby_clause) > 0:
                ret.append(orderby_clause)
        if "having" in sql:
            having_clause = self.gen_having(sql["having"], table, table_alias_dict)
            if len(having_clause) > 0:
                ret.append(having_clause)
        if len(nested_label) > 0:
            nested_clause = "{} {}".format(nested_label, self.gen_sql(nested_sql, table))
            if len(nested_clause) > 0:
                ret.append(nested_clause)
        return " ".join(ret)

    def check_acc(self, pred_sql, gt_sql):
        # Placeholder — accuracy checking not implemented.
        pass
# Tail of a component-name validation guard (the `if` itself is outside this
# chunk): abort on an unknown component.
print("Invalid train component")
exit(1)
# Load train/dev splits for the selected component.
train_data = load_train_dev_dataset(args.train_component, "train", args.history_type, args.data_root)
dev_data = load_train_dev_dataset(args.train_component, "dev", args.history_type, args.data_root)
# sql_data, table_data, val_sql_data, val_table_data, \
# test_sql_data, test_table_data, \
# TRAIN_DB, DEV_DB, TEST_DB = load_dataset(args.dataset, use_small=USE_SMALL)
model = None
start_time = default_timer()
# Select which predictor to train (chain continues past this chunk).
if args.train_component == "multi_sql":
    model = MultiSqlPredictor(N_word=N_word, N_h=N_h, N_depth=N_depth, gpu=GPU, use_hs=use_hs)
elif args.train_component == "keyword":
    model = KeyWordPredictor(N_word=N_word, N_h=N_h, N_depth=N_depth, gpu=GPU, use_hs=use_hs)
elif args.train_component == "col":
    model = ColPredictor(N_word=N_word, N_h=N_h, N_depth=N_depth, gpu=GPU, use_hs=use_hs)
elif args.train_component == "op":
# Load GloVe embeddings sized by B_word (billions of tokens) / N_word (dims).
word_emb = load_word_emb('glove/glove.%dB.%dd.txt'%(B_word,N_word), load_used=False, use_small=USE_SMALL)
print("finished load word embedding: {}".format(time.time() - start_time))
#word_emb = load_concat_wemb('glove/glove.42B.300d.txt', "/data/projects/paraphrase/generation/para-nmt-50m/data/paragram_sl999_czeng.txt")
model = None
if BERT:
    # Optional BERT question encoder, wrapped in a closure so predictors can
    # call it as a plain function.
    bert_model = BertModel.from_pretrained('bert-base-uncased')
    if GPU:
        bert_model.cuda()
    def berter(q, q_len):
        return encode_question(bert_model, q, q_len)
    bert = berter
else:
    bert_model = None
    bert = None
# Select which predictor to train; all receive the same hyper-parameters plus
# the (possibly None) bert encoder. Chain is cut at the final elif.
if args.train_component == "multi_sql":
    model = MultiSqlPredictor(N_word=N_word,N_h=N_h,N_depth=N_depth, gpu=GPU, use_hs=use_hs, bert=bert)
elif args.train_component == "keyword":
    model = KeyWordPredictor(N_word=N_word,N_h=N_h,N_depth=N_depth, gpu=GPU, use_hs=use_hs, bert=bert)
elif args.train_component == "col":
    model = ColPredictor(N_word=N_word,N_h=N_h,N_depth=N_depth, gpu=GPU, use_hs=use_hs, bert=bert)
elif args.train_component == "op":
    model = OpPredictor(N_word=N_word,N_h=N_h,N_depth=N_depth, gpu=GPU, use_hs=use_hs, bert=bert)
elif args.train_component == "agg":
    model = AggPredictor(N_word=N_word,N_h=N_h,N_depth=N_depth, gpu=GPU, use_hs=use_hs, bert=bert)
elif args.train_component == "root_tem":
    model = RootTeminalPredictor(N_word=N_word,N_h=N_h,N_depth=N_depth, gpu=GPU, use_hs=use_hs, bert=bert)
elif args.train_component == "des_asc":
    model = DesAscLimitPredictor(N_word=N_word,N_h=N_h,N_depth=N_depth, gpu=GPU, use_hs=use_hs, bert=bert)
elif args.train_component == "having":
    model = HavingPredictor(N_word=N_word,N_h=N_h,N_depth=N_depth, gpu=GPU, use_hs=use_hs, bert=bert)
elif args.train_component == "andor":
# Validate the requested component before doing any expensive loading.
if args.train_component not in TRAIN_COMPONENTS:
    print("Invalid train component")
    exit(1)
# Load train/dev splits for the selected component.
train_data = load_train_dev_dataset(args.train_component, "train", args.history_type, args.data_root)
dev_data = load_train_dev_dataset(args.train_component, "dev", args.history_type, args.data_root)
# sql_data, table_data, val_sql_data, val_table_data, \
# test_sql_data, test_table_data, \
# TRAIN_DB, DEV_DB, TEST_DB = load_dataset(args.dataset, use_small=USE_SMALL)
# Load GloVe embeddings sized by B_word (billions of tokens) / N_word (dims).
word_emb = load_word_emb('glove/glove.%dB.%dd.txt'%(B_word,N_word), \
    load_used=args.train_emb, use_small=USE_SMALL)
print("finished load word embedding")
#word_emb = load_concat_wemb('glove/glove.42B.300d.txt', "/data/projects/paraphrase/generation/para-nmt-50m/data/paragram_sl999_czeng.txt")
model = None
# Select which predictor to train; chain is cut at the final elif.
if args.train_component == "multi_sql":
    model = MultiSqlPredictor(N_word=N_word,N_h=N_h,N_depth=N_depth,gpu=GPU, use_hs=use_hs)
elif args.train_component == "keyword":
    model = KeyWordPredictor(N_word=N_word,N_h=N_h,N_depth=N_depth,gpu=GPU, use_hs=use_hs)
elif args.train_component == "col":
    model = ColPredictor(N_word=N_word,N_h=N_h,N_depth=N_depth,gpu=GPU, use_hs=use_hs)
elif args.train_component == "op":
    model = OpPredictor(N_word=N_word,N_h=N_h,N_depth=N_depth,gpu=GPU, use_hs=use_hs)
elif args.train_component == "agg":
    model = AggPredictor(N_word=N_word,N_h=N_h,N_depth=N_depth,gpu=GPU, use_hs=use_hs)
elif args.train_component == "root_tem":
    model = RootTeminalPredictor(N_word=N_word,N_h=N_h,N_depth=N_depth,gpu=GPU, use_hs=use_hs)
elif args.train_component == "des_asc":
    model = DesAscLimitPredictor(N_word=N_word,N_h=N_h,N_depth=N_depth,gpu=GPU, use_hs=use_hs)
elif args.train_component == "having":
    model = HavingPredictor(N_word=N_word,N_h=N_h,N_depth=N_depth,gpu=GPU, use_hs=use_hs)
elif args.train_component == "andor":
def train_feedback(nlq, db_name, correct_query, toy, word_emb):
    """Fine-tune every saved component model on one user-corrected example.

    Arguments:
        nlq: english question (tokenization is done here) - get from Flask (User)
        db_name: name of the database the query targets - get from Flask (User)
        correct_query: the ground truth query supplied by the user(s) - get from Flask
        toy: uses a small example of word embeddings to debug faster
        word_emb: pre-loaded word-embedding lookup for the WordEmbedding layer

    Side effects: loads each component checkpoint from SAVED_MODELS_FOLDER,
    trains it with epoch_feedback_train, and overwrites the checkpoint file
    whenever dev accuracy improves. Calls exit(1) on an unknown component.
    """
    ITER = 21
    SAVED_MODELS_FOLDER = "saved_models"
    HISTORY_TYPE = "full"
    GPU_ENABLE = False
    TRAIN_EMB = False
    TABLE_TYPE = "std"
    DATA_ROOT = "generated_data"
    use_hs = True
    if HISTORY_TYPE == "no":
        HISTORY_TYPE = "full"
        use_hs = False

    # Model hyperparameters.
    N_word = 300  # word embedding dimension
    B_word = 42  # 42B tokens in the Glove pretrained embeddings
    N_h = 300  # hidden size dimension
    N_depth = 2
    if toy:
        USE_SMALL = True
        BATCH_SIZE = 20
    else:
        USE_SMALL = False
        BATCH_SIZE = 64
    GPU = GPU_ENABLE  # identical in both toy branches, so hoisted out
    learning_rate = 1e-4

    # Build the supervision dataset from the user-supplied correct query.
    # (The previous eager json.load of train_spider.json was dropped: its
    # result was unconditionally overwritten inside the training loop.)
    table_data_path = "./data/spider/tables.json"
    table_dict = get_table_dict(table_data_path)
    schemas, db_names, tables = get_schemas_from_json(table_data_path)
    schema = Schema(schemas[db_name], tables[db_name])
    sql_label = get_sql(schema, correct_query)
    correct_query_data = {
        "multi_sql_dataset": [],
        "keyword_dataset": [],
        "col_dataset": [],
        "op_dataset": [],
        "agg_dataset": [],
        "root_tem_dataset": [],
        "des_asc_dataset": [],
        "having_dataset": [],
        "andor_dataset": []
    }
    parser_item_with_long_history(
        tokenize(nlq),           # item["question_toks"]
        sql_label,               # item["sql"]
        table_dict[db_name],     # table_dict[item["db_id"]]
        [],
        correct_query_data)
    # print("\nCorrect query dataset: {}".format(correct_query_data))

    # map_location for torch.load must be "cpu" or a CUDA device string such
    # as "cuda" — the former value "gpu" is not a valid torch device string
    # and raised a RuntimeError whenever GPU_ENABLE was turned on.
    map_to = "cuda" if GPU_ENABLE else "cpu"

    # One (constructor, checkpoint file) pair per trainable component,
    # replacing the nine copy-pasted if/elif load branches.
    component_models = {
        "multi_sql": (MultiSqlPredictor, "multi_sql_models.dump"),
        "keyword": (KeyWordPredictor, "keyword_models.dump"),
        "col": (ColPredictor, "col_models.dump"),
        "op": (OpPredictor, "op_models.dump"),
        "agg": (AggPredictor, "agg_models.dump"),
        "root_tem": (RootTeminalPredictor, "root_tem_models.dump"),
        "des_asc": (DesAscLimitPredictor, "des_asc_models.dump"),
        "having": (HavingPredictor, "having_models.dump"),
        "andor": (AndOrPredictor, "andor_models.dump"),
    }

    for train_component in TRAIN_COMPONENTS:
        print("\nTRAIN COMPONENT: {}".format(train_component))
        # Check if the component to be trained is an actual component.
        if train_component not in component_models:
            print("Invalid train component")
            exit(1)

        # Read in the data.
        train_data = load_train_dev_dataset(train_component, "train",
                                            HISTORY_TYPE, DATA_ROOT)
        dev_data = load_train_dev_dataset(train_component, "dev",
                                          HISTORY_TYPE, DATA_ROOT)

        # Build the component model and restore its saved weights.
        predictor_cls, dump_name = component_models[train_component]
        model = predictor_cls(N_word=N_word, N_h=N_h, N_depth=N_depth,
                              gpu=GPU, use_hs=use_hs)
        model.load_state_dict(
            torch.load("{}/{}".format(SAVED_MODELS_FOLDER, dump_name),
                       map_location=map_to))

        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate,
                                     weight_decay=0)
        print("finished build model")

        embed_layer = WordEmbedding(word_emb, N_word, gpu=GPU,
                                    SQL_TOK=SQL_TOK, trainable=TRAIN_EMB)
        print("start training")
        best_acc = 0.0
        for i in range(ITER):
            print('ITER %d @ %s' % (i + 1, datetime.datetime.now()))
            print('Total Loss = %s' % epoch_feedback_train(
                model=model,
                optimizer=optimizer,
                batch_size=BATCH_SIZE,
                component=train_component,
                embed_layer=embed_layer,
                data=train_data,
                table_type=TABLE_TYPE,
                nlq=nlq,
                db_name=db_name,
                correct_query=correct_query,
                correct_query_data=correct_query_data))
            # Check improvement every 10 iterations; checkpoint on a new best.
            if i % 10 == 0:
                acc = epoch_acc(model, BATCH_SIZE, train_component,
                                embed_layer, dev_data, table_type=TABLE_TYPE)
                if acc > best_acc:
                    best_acc = acc
                    print("Save model...")
                    torch.save(
                        model.state_dict(),
                        SAVED_MODELS_FOLDER + "/{}_models.dump".format(train_component))