def train(**kwargs):
    """Train the configured model on the Sogou dataset.

    Keyword arguments are forwarded to ``Config.parse`` and override the
    defaults in ``Config``. The model class is looked up dynamically via
    ``args.model_name`` and training is delegated to ``Trainner_transformer``.
    """
    args = Config()
    args.parse(kwargs)
    loss_func = loss_function
    score_func = batch_scorer
    train_set = DataSet(args.sog_processed + 'train/')
    dev_set = DataSet(args.sog_processed + 'dev/')
    # NOTE(review): num_workers is hard-coded; consider promoting to Config.
    train_loader = DataLoader(train_set, batch_size=args.batch_size,
                              shuffle=True, collate_fn=own_collate_fn,
                              num_workers=20)
    dev_loader = DataLoader(dev_set, batch_size=args.batch_size,
                            shuffle=True, collate_fn=own_collate_fn)
    # Fix: close the pickle file promptly (the original left the handle open).
    with open('Predictor/Utils/sogou_vocab.pkl', 'rb') as f:
        vocab = pk.load(f)
    eos_id, sos_id = vocab.token2id['<EOS>'], vocab.token2id['<BOS>']
    args.eos_id = eos_id
    args.sos_id = sos_id
    model = getattr(Models, args.model_name)(matrix=vocab.matrix, args=args)
    trainner = Trainner_transformer(args, vocab)
    trainner.train(model, loss_func, score_func, train_loader, dev_loader,
                   resume=args.resume, exp_root=args.exp_root)
def test(**kwargs):
    """Interactively inspect model predictions on the test split.

    Loads a fixed checkpoint, runs the model with teacher forcing disabled,
    and for each sample prints the source context, the prediction, the
    reference title and the batch score, pausing for keyboard input between
    samples.
    """
    args = Config()
    test_set = DataSet(args.processed_folder + 'test/')
    # NOTE(review): shuffle=True on a test loader makes runs non-reproducible;
    # kept as-is to preserve behavior — confirm whether it is intentional.
    test_loader = DataLoader(test_set, batch_size=args.batch_size,
                             shuffle=True, collate_fn=own_collate_fn)
    # Fix: close the pickle file promptly (the original left the handle open).
    with open('Predictor/Utils/vocab.pkl', 'rb') as f:
        vocab = pk.load(f)
    eos_id, sos_id = vocab.token2id['<EOS>'], vocab.token2id['<BOS>']
    args.eos_id = eos_id
    args.sos_id = sos_id
    model = getattr(Models, args.model_name)(matrix=vocab.matrix, args=args)
    # NOTE(review): checkpoint path is hard-coded; consider taking it from
    # kwargs/Config instead.
    load = _load('ckpt/saved_models/2018_08_20_02_12_38_0.2602508540088274',
                 model)
    model = load['model']
    model.to('cuda')
    #TODO complete load_state_dict and predict
    # A large negative ratio guarantees the sampled threshold never selects
    # teacher forcing, i.e. the model decodes from its own predictions.
    model.teacher_forcing_ratio = -100
    with t.no_grad():
        for data in test_loader:
            context, title, context_lenths, title_lenths = [
                i.to('cuda') for i in data
            ]
            token_id, prob_vector, token_lenth, attention_matrix = model(
                context, context_lenths, title)
            score = batch_scorer(token_id.tolist(), title.tolist(),
                                 args.eos_id)
            context_word = [[vocab.from_id_token(id.item()) for id in sample]
                            for sample in context]
            words = [[vocab.from_id_token(id.item()) for id in sample]
                     for sample in token_id]
            title_words = [[vocab.from_id_token(id.item()) for id in sample]
                           for sample in title]
            for i in zip(context_word, words, title_words):
                # Pause for the user before showing each sample; the typed
                # text itself is discarded (the original bound it to an
                # unused local).
                input('next')
                print(f'context:{i[0]},pre:{i[1]}, tru:{i[2]}, score:{score}')
net = self.linear1(inputs.transpose(1, 2)) net = self.relu(net) net = self.drop(net) net = self.linear2(net) net = net.transpose(1, 2) net = self.drop(net) return net if __name__ == '__main__': from configs import Config import pickle as pk vocab = pk.load(open('Predictor/Utils/vocab.pkl', 'rb')) args = Config() args.sos_id = vocab.token2id['<BOS>'] args.eos_id = vocab.token2id['<EOS>'] args.batch_size = 4 print(args.sos_id) matrix = vocab.matrix model = UniversalTransformer(args, matrix) from torch.utils.data import DataLoader from DataSets import DataSet from DataSets import own_collate_fn from Predictor.Utils import batch_scorer from Predictor.Utils.loss import loss_function train_set = DataSet(args.processed_folder + 'train/') train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, collate_fn=own_collate_fn) vocab = pk.load(open('Predictor/Utils/vocab.pkl', 'rb'))