"""Inference script: load the trained agg/sel/cond predictors and prepare a
natural-language question (from argv) for SQL query generation.

NOTE(review): several names used below (load_word_emb, WordEmbedding,
process_sentence) are assumed to come from `from word_mapping import *` —
confirm they are defined there.
"""
import sys

import torch  # needed for torch.load below; was only implicitly in scope before

from model import Model  # explicit import; previously relied on a star import
from utils import gen_query_acc, gen_sql_query
from word_mapping import *

# --- configuration ---
filename = 'glove/glove.42B.300d.txt'          # 300-d GloVe vectors (matches N_word)
agg_checkpoint_name = 'saved_models/agg_predictor.pth'
select_checkpoint_name = 'saved_models/sel_predictor.pth'
cond_checkpoint_name = 'saved_models/cond_predictor.pth'
N_word = 300        # word-embedding dimensionality
batch_size = 10
hidden_dim = 100    # LSTM hidden size used by the predictors
n_epochs = 5
table_name = 'EMPLOYEE'

# --- model construction ---
word_embed = load_word_emb(filename)
word_emb = WordEmbedding(N_word, word_embed)
model = Model(hidden_dim, N_word, word_emb)

# Restore each sub-predictor's trained weights.
# map_location='cpu' lets checkpoints saved on a GPU load on CPU-only machines.
model.agg_predictor.load_state_dict(
    torch.load(agg_checkpoint_name, map_location='cpu'))
model.cond_predictor.load_state_dict(
    torch.load(cond_checkpoint_name, map_location='cpu'))
model.sel_predictor.load_state_dict(
    torch.load(select_checkpoint_name, map_location='cpu'))
model.eval()  # inference mode: disable dropout / batch-norm updates

# --- input handling ---
# Fail with a usage message instead of a bare IndexError when no question given.
if len(sys.argv) < 2:
    sys.exit('usage: python <script> "<natural language question>"')
sentence = sys.argv[1]
sentence = process_sentence(sentence)
# The downstream model API expects a batch: a list of token lists.
question = [sentence.split(' ')]
"""Training setup: load 50-d GloVe embeddings and build the training DataLoader.

NOTE(review): SQLDataset, collate_fn and load_word_emb are used below but never
imported on this line — presumably provided by another module via a star
import elsewhere in the project. Confirm before running this standalone.
"""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader  # was missing: DataLoader is used below

from model import Model  # duplicate import removed (was imported twice)
from net_utils import run_lstm
from utils import train_model, test_model

# --- hyperparameters ---
N_word = 50        # embedding dimensionality; must match the GloVe file below
batch_size = 10
hidden_dim = 100

# 50-dimensional GloVe vectors, consistent with N_word above.
word_embed = load_word_emb('glove/glove.6B.50d.txt')

train = SQLDataset('train')
# train, valid = SQLDataset('train'), SQLDataset('dev')
train_dataloader = DataLoader(
    train,
    batch_size=batch_size,
    shuffle=True,
    num_workers=1,
    collate_fn=collate_fn,
)
# valid_dataloader = DataLoader(valid, batch_size=batch_size, shuffle=True,
#                               num_workers=1, collate_fn=collate_fn)
# test = SQLDataset('test')
# test_dataloader = DataLoader(test, batch_size=batch_size, shuffle=True,
#                              num_workers=1, collate_fn=collate_fn)