def predict(args):
    best_model_fname = args.model_name

    # rebuild the model with the same hyper-parameters used at training time
    input_encoder = encoder(len(vocab), args.embedding_size,
                            args.hidden_size, args.para_init)
    inter_atten = atten(args.hidden_size, args.cls_num, args.para_init)
    input_encoder.cuda()
    inter_atten.cuda()

    input_encoder.load_state_dict(
        torch.load(best_model_fname + '_input-encoder.pt'))
    inter_atten.load_state_dict(
        torch.load(best_model_fname + '_inter-atten.pt'))
    input_encoder.eval()
    inter_atten.eval()

    res = []
    stance = []
    for i, query in enumerate(query_iter):
        print(i)
        # tile the single query so it can be paired with every document in a batch
        test_src_batch = query.query
        test_src_batch = test_src_batch.repeat(args.batch_size * 4, 1)
        test_src_batch = test_src_batch.cuda()
        docs_poss = torch.FloatTensor([])
        for j, batch in enumerate(doc_iter):
            test_tgt_batch = batch.doc
            if args.cuda:
                test_tgt_batch = test_tgt_batch.cuda()
            test_src_linear, test_tgt_linear = input_encoder(
                test_src_batch[:test_tgt_batch.size(0)], test_tgt_batch)
            log_prob = inter_atten(test_src_linear, test_tgt_linear)
            _, preds = log_prob.data.max(dim=1)
            # collect the positive-class score and predicted class for every document
            docs_poss = torch.cat((docs_poss, log_prob[:, 1].cpu().data), dim=0)
            stance = np.concatenate((stance, preds.cpu().numpy()), axis=0)
        res.append(docs_poss.numpy())
    np.save('res.npy', res)
    np.save('stance.npy', stance)
    print(np.array(res).shape)
    print(np.array(stance).shape)
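# --- Usage sketch (not part of the original script; the function name and the
# rectangular shape of res.npy are assumptions): load the arrays saved by
# predict() above and rank the documents for each query by their
# positive-class score. Assumes every query was scored against the same
# document list.
import numpy as np

def rank_saved_predictions(top_k=10):
    res = np.load('res.npy', allow_pickle=True)  # per-query positive-class scores
    for q_idx, doc_scores in enumerate(res):
        top_docs = np.argsort(-np.asarray(doc_scores))[:top_k]
        print(q_idx, top_docs.tolist())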
def predict(args):
    if args.max_length < 0:
        args.max_length = 9999

    # create logger
    logger_name = "mylog"
    logger = logging.getLogger(logger_name)
    logger.setLevel(logging.INFO)

    # file handler
    fh = logging.FileHandler(args.log_dir + args.log_fname)
    fh.setLevel(logging.INFO)
    logger.addHandler(fh)

    out_file = open(args.log_dir + 'test_out.txt', 'w')
    err_file = open(args.log_dir + 'test_err.txt', 'w')

    # stream handler
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    logger.addHandler(console)

    torch.cuda.set_device(args.gpu_id)

    for arg in vars(args):
        logger.info(str(arg) + ' ' + str(getattr(args, arg)))

    # load test data
    logger.info('loading data...')
    test_data = snli_data(args.test_file, args.max_length)
    test_batches = test_data.batches
    test_id_batches = test_data.id_batches
    logger.info('test size # sent ' + str(test_data.size))
    if args.ignore_ques:
        test_data.have_ques = 0

    # get input embeddings
    logger.info('loading input embeddings...')
    word_vecs = w2v(args.w2v_file).word_vecs

    # build the model
    input_encoder = encoder(word_vecs.size(0), args.embedding_size,
                            args.hidden_size, args.para_init)
    input_encoder.embedding.weight.data.copy_(word_vecs)
    input_encoder.embedding.weight.requires_grad = False
    inter_atten = atten(args.hidden_size, 3, args.para_init)
    seq_atten = SeqAttnMatch(args.hidden_size, args.para_init)
    input_encoder.cuda(args.gpu_id)
    inter_atten.cuda(args.gpu_id)
    seq_atten.cuda(args.gpu_id)

    input_encoder.load_state_dict(
        torch.load(args.trained_encoder, map_location={'cuda:0': 'cuda:1'}))
    inter_atten.load_state_dict(
        torch.load(args.trained_attn, map_location={'cuda:0': 'cuda:1'}))
    seq_atten.load_state_dict(
        torch.load(args.seq_attn, map_location={'cuda:0': 'cuda:1'}))
    input_encoder.eval()
    inter_atten.eval()
    seq_atten.eval()

    tot_corr = 0.0
    tot_eg = 0.0
    for i in range(len(test_batches)):
        test_src_batch, test_tgt_batch, test_src_ques_batch, \
            test_targ_ques_batch, test_lbl_batch = test_batches[i]
        test_src_ids, test_targ_ids = test_id_batches[i]

        test_src_batch = Variable(test_src_batch.cuda(args.gpu_id))
        test_tgt_batch = Variable(test_tgt_batch.cuda(args.gpu_id))
        test_src_ques_batch = Variable(test_src_ques_batch.cuda(args.gpu_id))
        test_targ_ques_batch = Variable(test_targ_ques_batch.cuda(args.gpu_id))

        test_src_linear, test_tgt_linear, test_src_ques_linear, \
            test_targ_ques_linear = input_encoder(
                test_src_batch, test_tgt_batch,
                test_src_ques_batch, test_targ_ques_batch)

        if test_data.have_ques == 1:
            # prepare (all-zero, i.e. nothing masked) question masks
            test_src_ques_mask = Variable(
                torch.from_numpy(np.zeros(
                    test_src_ques_linear.data.shape[:2])).byte().cuda(args.gpu_id))
            test_targ_ques_mask = Variable(
                torch.from_numpy(np.zeros(
                    test_targ_ques_linear.data.shape[:2])).byte().cuda(args.gpu_id))
            test_src_linear = seq_atten(test_src_linear,
                                        test_src_ques_linear, test_src_ques_mask)
            test_tgt_linear = seq_atten(test_tgt_linear,
                                        test_targ_ques_linear, test_targ_ques_mask)

        log_prob = inter_atten(test_src_linear, test_tgt_linear)
        # convert log-probabilities to probabilities for reporting
        norm_probs = F.softmax(log_prob, dim=1)
        probs, predict = norm_probs.data.max(dim=1)

        j = 0
        corr = 0
        for m_prob, m_pred in zip(probs, predict):
            if m_pred == test_lbl_batch[j]:
                corr += 1
            else:
                err_file.write(str(test_src_ids[j]) + '\t' + str(test_targ_ids[j]) +
                               '\t' + str(m_pred) + '\t' + str(test_lbl_batch[j]) + '\n')
            out_file.write(str(test_src_ids[j]) + '\t' + str(test_targ_ids[j]) +
                           '\t' + str(m_prob) + '\t' + str(m_pred) + '\n')
            j += 1
        tot_corr += corr
        tot_eg += j

    out_file.close()
    err_file.close()
    print('Accuracy: ' + str(tot_corr / tot_eg))
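# --- Hypothetical entry point (not part of the original script): a minimal
# argparse wiring for the flags that predict() above reads. The defaults are
# illustrative assumptions, not the values used in the original experiments.
import argparse

def get_args():
    parser = argparse.ArgumentParser(description='decomposable-attention prediction')
    parser.add_argument('--max_length', type=int, default=-1)
    parser.add_argument('--log_dir', type=str, default='./logs/')
    parser.add_argument('--log_fname', type=str, default='predict.log')
    parser.add_argument('--gpu_id', type=int, default=0)
    parser.add_argument('--test_file', type=str, required=True)
    parser.add_argument('--w2v_file', type=str, required=True)
    parser.add_argument('--embedding_size', type=int, default=300)
    parser.add_argument('--hidden_size', type=int, default=200)
    parser.add_argument('--para_init', type=float, default=0.01)
    parser.add_argument('--trained_encoder', type=str, required=True)
    parser.add_argument('--trained_attn', type=str, required=True)
    parser.add_argument('--seq_attn', type=str, required=True)
    parser.add_argument('--ignore_ques', action='store_true')
    return parser.parse_args()

if __name__ == '__main__':
    predict(get_args())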
def train(args):
    if args.max_length < 0:
        args.max_length = 9999

    # create logger
    logger_name = "mylog"
    logger = logging.getLogger(logger_name)
    logger.setLevel(logging.INFO)

    # file handler
    fh = logging.FileHandler(args.log_dir + args.log_fname)
    fh.setLevel(logging.INFO)
    logger.addHandler(fh)

    # stream handler
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    logger.addHandler(console)

    torch.cuda.set_device(args.gpu_id)

    for arg in vars(args):
        logger.info(str(arg) + ' ' + str(getattr(args, arg)))

    # load train/dev/test data
    logger.info('loading data...')
    train_data = snli_data(args.train_file, args.max_length)
    train_batches = train_data.batches
    train_lbl_size = 3
    dev_data = snli_data(args.dev_file, args.max_length)
    dev_batches = dev_data.batches
    test_data = snli_data(args.test_file, args.max_length)
    test_batches = test_data.batches
    logger.info('train size # sent ' + str(train_data.size))
    logger.info('dev size # sent ' + str(dev_data.size))
    logger.info('test size # sent ' + str(test_data.size))

    # get input embeddings
    logger.info('loading input embeddings...')
    word_vecs = w2v(args.w2v_file).word_vecs

    best_dev = []  # (epoch, dev_acc)

    # build the model
    input_encoder = encoder(word_vecs.size(0), args.embedding_size,
                            args.hidden_size, args.para_init)
    input_encoder.embedding.weight.data.copy_(word_vecs)
    input_encoder.embedding.weight.requires_grad = False
    seq_atten = SeqAttnMatch(args.hidden_size, args.para_init)
    inter_atten = atten(args.hidden_size, train_lbl_size, args.para_init)
    input_encoder.cuda(args.gpu_id)
    inter_atten.cuda(args.gpu_id)
    seq_atten.cuda(args.gpu_id)

    if args.resume:
        logger.info('loading trained model.')
        input_encoder.load_state_dict(
            torch.load(args.trained_encoder, map_location={'cuda:0': 'cuda:1'}))
        inter_atten.load_state_dict(
            torch.load(args.trained_attn, map_location={'cuda:0': 'cuda:1'}))
        seq_atten.load_state_dict(
            torch.load(args.seq_attn, map_location={'cuda:0': 'cuda:1'}))

    # test before training starts
    input_encoder.eval()
    seq_atten.eval()
    inter_atten.eval()
    correct = 0.
    total = 0.
    logger.info('test before training starts')
    for i in range(len(test_data.batches)):
        test_src_batch, test_tgt_batch, test_src_ques_batch, \
            test_targ_ques_batch, test_lbl_batch = test_data.batches[i]

        test_src_batch = Variable(test_src_batch.cuda(args.gpu_id))
        test_tgt_batch = Variable(test_tgt_batch.cuda(args.gpu_id))
        test_src_ques_batch = Variable(test_src_ques_batch.cuda(args.gpu_id))
        test_targ_ques_batch = Variable(test_targ_ques_batch.cuda(args.gpu_id))
        test_lbl_batch = Variable(test_lbl_batch.cuda(args.gpu_id))

        test_src_linear, test_tgt_linear, test_src_ques_linear, \
            test_targ_ques_linear = input_encoder(
                test_src_batch, test_tgt_batch,
                test_src_ques_batch, test_targ_ques_batch)

        if test_data.have_ques == 1:
            # prepare (all-zero) question masks
            test_src_ques_mask = Variable(
                torch.from_numpy(np.zeros(
                    test_src_ques_linear.data.shape[:2])).byte().cuda(args.gpu_id))
            test_targ_ques_mask = Variable(
                torch.from_numpy(np.zeros(
                    test_targ_ques_linear.data.shape[:2])).byte().cuda(args.gpu_id))
            test_src_linear = seq_atten(test_src_linear,
                                        test_src_ques_linear, test_src_ques_mask)
            test_tgt_linear = seq_atten(test_tgt_linear,
                                        test_targ_ques_linear, test_targ_ques_mask)

        log_prob = inter_atten(test_src_linear, test_tgt_linear)
        _, predict = log_prob.data.max(dim=1)
        total += test_lbl_batch.data.size(0)
        correct += torch.sum(predict == test_lbl_batch.data)

    test_acc = correct / total
    logger.info('init-test-acc %.3f' % (test_acc))

    input_encoder.train()
    seq_atten.train()
    inter_atten.train()

    # the embedding layer is frozen, so only optimize parameters that require grad
    para1 = filter(lambda p: p.requires_grad, input_encoder.parameters())
    para2 = inter_atten.parameters()
    para3 = seq_atten.parameters()

    if args.optimizer == 'Adagrad':
        input_optimizer = optim.Adagrad(para1, lr=args.lr,
                                        weight_decay=args.weight_decay)
        inter_atten_optimizer = optim.Adagrad(para2, lr=args.lr,
                                              weight_decay=args.weight_decay)
        seq_atten_optimizer = optim.Adagrad(para3, lr=args.lr,
                                            weight_decay=args.weight_decay)
    elif args.optimizer == 'Adadelta':
        input_optimizer = optim.Adadelta(para1, lr=args.lr)
        inter_atten_optimizer = optim.Adadelta(para2, lr=args.lr)
        seq_atten_optimizer = optim.Adadelta(para3, lr=args.lr)
    else:
        logger.info('No Optimizer.')
        sys.exit()

    if args.resume:
        input_optimizer.load_state_dict(
            torch.load(args.input_optimizer, map_location={'cuda:0': 'cuda:1'}))
        inter_atten_optimizer.load_state_dict(
            torch.load(args.inter_atten_optimizer, map_location={'cuda:0': 'cuda:1'}))
        seq_atten_optimizer.load_state_dict(
            torch.load(args.seq_atten_optimizer, map_location={'cuda:0': 'cuda:1'}))

    criterion = nn.NLLLoss(size_average=True)
    # criterion = nn.CrossEntropyLoss()

    logger.info('start to train...')
    for k in range(args.epoch):
        total = 0.
        correct = 0.
        loss_data = 0.
        train_sents = 0.
        shuffle(train_batches)
        timer = time.time()
        for i in range(len(train_batches)):
            train_src_batch, train_tgt_batch, train_src_ques_batch, \
                train_targ_ques_batch, train_lbl_batch = train_batches[i]

            train_src_batch = Variable(train_src_batch.cuda(args.gpu_id))
            train_tgt_batch = Variable(train_tgt_batch.cuda(args.gpu_id))
            train_src_ques_batch = Variable(train_src_ques_batch.cuda(args.gpu_id))
            train_targ_ques_batch = Variable(train_targ_ques_batch.cuda(args.gpu_id))
            train_lbl_batch = Variable(train_lbl_batch.cuda(args.gpu_id))

            batch_size = train_src_batch.size(0)
            train_sents += batch_size

            input_optimizer.zero_grad()
            inter_atten_optimizer.zero_grad()
            seq_atten_optimizer.zero_grad()

            # initialize the Adagrad accumulators on the first epoch
            if k == 0 and args.optimizer == 'Adagrad' and not args.resume:
                for group in input_optimizer.param_groups:
                    for p in group['params']:
                        state = input_optimizer.state[p]
                        state['sum'] += args.Adagrad_init
                for group in inter_atten_optimizer.param_groups:
                    for p in group['params']:
                        state = inter_atten_optimizer.state[p]
                        state['sum'] += args.Adagrad_init
                for group in seq_atten_optimizer.param_groups:
                    for p in group['params']:
                        state = seq_atten_optimizer.state[p]
                        state['sum'] += args.Adagrad_init
            elif k == 0 and args.optimizer == 'Adagrad' and args.seq_atten_optimizer == 'none':
                # resuming, but no saved state for the sequence-attention optimizer
                for group in seq_atten_optimizer.param_groups:
                    for p in group['params']:
                        state = seq_atten_optimizer.state[p]
                        state['sum'] += args.Adagrad_init

            train_src_linear, train_tgt_linear, train_src_ques_linear, \
                train_targ_ques_linear = input_encoder(
                    train_src_batch, train_tgt_batch,
                    train_src_ques_batch, train_targ_ques_batch)

            if train_data.have_ques == 1:
                # prepare (all-zero) question masks
                train_src_ques_mask = Variable(
                    torch.from_numpy(np.zeros(
                        train_src_ques_linear.data.shape[:2])).byte().cuda(args.gpu_id))
                train_targ_ques_mask = Variable(
                    torch.from_numpy(np.zeros(
                        train_targ_ques_linear.data.shape[:2])).byte().cuda(args.gpu_id))
                train_src_linear = seq_atten(train_src_linear,
                                             train_src_ques_linear, train_src_ques_mask)
                train_tgt_linear = seq_atten(train_tgt_linear,
                                             train_targ_ques_linear, train_targ_ques_mask)

            log_prob = inter_atten(train_src_linear, train_tgt_linear)

            loss = criterion(log_prob, train_lbl_batch)
            loss.backward()

            grad_norm = 0.
            para_norm = 0.
            # accumulate squared norms of all Linear weights and biases
            for m in input_encoder.modules():
                if isinstance(m, nn.Linear):
                    grad_norm += m.weight.grad.data.norm() ** 2
                    para_norm += m.weight.data.norm() ** 2
                    if m.bias is not None:
                        grad_norm += m.bias.grad.data.norm() ** 2
                        para_norm += m.bias.data.norm() ** 2
            for m in inter_atten.modules():
                if isinstance(m, nn.Linear):
                    grad_norm += m.weight.grad.data.norm() ** 2
                    para_norm += m.weight.data.norm() ** 2
                    if m.bias is not None:
                        grad_norm += m.bias.grad.data.norm() ** 2
                        para_norm += m.bias.data.norm() ** 2
            if train_data.have_ques == 1:
                for m in seq_atten.modules():
                    if isinstance(m, nn.Linear):
                        grad_norm += m.weight.grad.data.norm() ** 2
                        para_norm += m.weight.data.norm() ** 2
                        if m.bias is not None:
                            grad_norm += m.bias.grad.data.norm() ** 2
                            para_norm += m.bias.data.norm() ** 2

            grad_norm = grad_norm ** 0.5
            para_norm = para_norm ** 0.5

            # manual gradient clipping by global norm
            shrinkage = args.max_grad_norm / (grad_norm + 0.01)
            if shrinkage < 1:
                for m in input_encoder.modules():
                    if isinstance(m, nn.Linear):
                        m.weight.grad.data = m.weight.grad.data * shrinkage
                for m in inter_atten.modules():
                    if isinstance(m, nn.Linear):
                        m.weight.grad.data = m.weight.grad.data * shrinkage
                        m.bias.grad.data = m.bias.grad.data * shrinkage
                if train_data.have_ques == 1:
                    for m in seq_atten.modules():
                        if isinstance(m, nn.Linear):
                            m.weight.grad.data = m.weight.grad.data * shrinkage
                            m.bias.grad.data = m.bias.grad.data * shrinkage

            input_optimizer.step()
            inter_atten_optimizer.step()
            if train_data.have_ques == 1:
                seq_atten_optimizer.step()

            _, predict = log_prob.data.max(dim=1)
            total += train_lbl_batch.data.size(0)
            correct += torch.sum(predict == train_lbl_batch.data)
            loss_data += loss.data[0] * batch_size

            # log every display_interval batches and at the end of the epoch
            if (i + 1) % args.display_interval == 0 or i == len(train_batches) - 1:
                logger.info(
                    'epoch %d, batches %d|%d, train-acc %.3f, loss %.3f, '
                    'para-norm %.3f, grad-norm %.3f, time %.2fs, ' %
                    (k, i + 1, len(train_batches), correct / total,
                     loss_data / train_sents, para_norm, grad_norm,
                     time.time() - timer))
                train_sents = 0.
                timer = time.time()
                loss_data = 0.
                correct = 0.
                total = 0.

        # evaluate
        if (k + 1) % args.dev_interval == 0:
            input_encoder.eval()
            inter_atten.eval()
            seq_atten.eval()
            correct = 0.
            total = 0.
            for i in range(len(dev_batches)):
                dev_src_batch, dev_tgt_batch, dev_src_ques_batch, \
                    dev_targ_ques_batch, dev_lbl_batch = dev_batches[i]

                dev_src_batch = Variable(dev_src_batch.cuda(args.gpu_id))
                dev_tgt_batch = Variable(dev_tgt_batch.cuda(args.gpu_id))
                dev_src_ques_batch = Variable(dev_src_ques_batch.cuda(args.gpu_id))
                dev_targ_ques_batch = Variable(dev_targ_ques_batch.cuda(args.gpu_id))
                dev_lbl_batch = Variable(dev_lbl_batch.cuda(args.gpu_id))

                dev_src_linear, dev_tgt_linear, dev_src_ques_linear, \
                    dev_targ_ques_linear = input_encoder(
                        dev_src_batch, dev_tgt_batch,
                        dev_src_ques_batch, dev_targ_ques_batch)

                if dev_data.have_ques == 1:
                    # prepare (all-zero) question masks
                    dev_src_ques_mask = Variable(
                        torch.from_numpy(np.zeros(
                            dev_src_ques_linear.data.shape[:2])).byte().cuda(args.gpu_id))
                    dev_targ_ques_mask = Variable(
                        torch.from_numpy(np.zeros(
                            dev_targ_ques_linear.data.shape[:2])).byte().cuda(args.gpu_id))
                    dev_src_linear = seq_atten(dev_src_linear,
                                               dev_src_ques_linear, dev_src_ques_mask)
                    dev_tgt_linear = seq_atten(dev_tgt_linear,
                                               dev_targ_ques_linear, dev_targ_ques_mask)

                log_prob = inter_atten(dev_src_linear, dev_tgt_linear)
                _, predict = log_prob.data.max(dim=1)
                total += dev_lbl_batch.data.size(0)
                correct += torch.sum(predict == dev_lbl_batch.data)

            dev_acc = correct / total
            logger.info('dev-acc %.3f' % (dev_acc))

            # save a checkpoint at the first dev evaluation, and afterwards
            # only when dev accuracy improves over the current best
            if (k + 1) / args.dev_interval == 1 or dev_acc > best_dev[-1][1]:
                model_fname = '%s%s_epoch-%d_dev-acc-%.3f' % (
                    args.model_path, args.log_fname.split('.')[0], k, dev_acc)
                torch.save(input_encoder.state_dict(),
                           model_fname + '_input-encoder.pt')
                torch.save(inter_atten.state_dict(),
                           model_fname + '_inter-atten.pt')
                torch.save(seq_atten.state_dict(),
                           model_fname + '_seq-atten.pt')
                torch.save(input_optimizer.state_dict(),
                           model_fname + '_input-optimizer.pt')
                torch.save(inter_atten_optimizer.state_dict(),
                           model_fname + '_inter-atten-optimizer.pt')
                torch.save(seq_atten_optimizer.state_dict(),
                           model_fname + '_seq-atten-optimizer.pt')
                best_dev.append((k, dev_acc, model_fname))
                logger.info('current best-dev:')
                for t in best_dev:
                    logger.info('\t%d %.3f' % (t[0], t[1]))
                logger.info('save model!')

            input_encoder.train()
            inter_atten.train()
            seq_atten.train()

    logger.info('training end!')

    # test with the best dev checkpoint
    best_model_fname = best_dev[-1][2]
    input_encoder.load_state_dict(
        torch.load(best_model_fname + '_input-encoder.pt',
                   map_location={'cuda:0': 'cuda:1'}))
    inter_atten.load_state_dict(
        torch.load(best_model_fname + '_inter-atten.pt',
                   map_location={'cuda:0': 'cuda:1'}))
    seq_atten.load_state_dict(
        torch.load(best_model_fname + '_seq-atten.pt',
                   map_location={'cuda:0': 'cuda:1'}))
    input_encoder.eval()
    inter_atten.eval()
    seq_atten.eval()

    correct = 0.
    total = 0.
    for i in range(len(test_batches)):
        test_src_batch, test_tgt_batch, test_src_ques_batch, \
            test_targ_ques_batch, test_lbl_batch = test_batches[i]

        test_src_batch = Variable(test_src_batch.cuda(args.gpu_id))
        test_tgt_batch = Variable(test_tgt_batch.cuda(args.gpu_id))
        test_src_ques_batch = Variable(test_src_ques_batch.cuda(args.gpu_id))
        test_targ_ques_batch = Variable(test_targ_ques_batch.cuda(args.gpu_id))
        test_lbl_batch = Variable(test_lbl_batch.cuda(args.gpu_id))

        test_src_linear, test_tgt_linear, test_src_ques_linear, \
            test_targ_ques_linear = input_encoder(
                test_src_batch, test_tgt_batch,
                test_src_ques_batch, test_targ_ques_batch)

        if test_data.have_ques == 1:
            # prepare (all-zero) question masks
            test_src_ques_mask = Variable(
                torch.from_numpy(np.zeros(
                    test_src_ques_linear.data.shape[:2])).byte().cuda(args.gpu_id))
            test_targ_ques_mask = Variable(
                torch.from_numpy(np.zeros(
                    test_targ_ques_linear.data.shape[:2])).byte().cuda(args.gpu_id))
            test_src_linear = seq_atten(test_src_linear,
                                        test_src_ques_linear, test_src_ques_mask)
            test_tgt_linear = seq_atten(test_tgt_linear,
                                        test_targ_ques_linear, test_targ_ques_mask)

        log_prob = inter_atten(test_src_linear, test_tgt_linear)
        _, predict = log_prob.data.max(dim=1)
        total += test_lbl_batch.data.size(0)
        correct += torch.sum(predict == test_lbl_batch.data)

    test_acc = correct / total
    logger.info('test-acc %.3f' % (test_acc))
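# --- Alternative sketch (not the repository's code): the manual per-module
# norm accumulation and shrinking in the training loop above can instead use
# PyTorch's built-in clipping utility, which clips gradients in place and
# returns the pre-clipping global norm for logging. The helper name and the
# way the sub-networks are passed in are illustrative assumptions.
import torch

def clip_all_grads(modules, max_grad_norm):
    """Clip the global gradient norm across several sub-networks at once."""
    params = [p for m in modules for p in m.parameters() if p.grad is not None]
    return torch.nn.utils.clip_grad_norm_(params, max_grad_norm)

# usage inside the training loop, right after loss.backward():
#   grad_norm = clip_all_grads([input_encoder, inter_atten, seq_atten],
#                              args.max_grad_norm)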
def train(args):
    if args.max_length < 0:
        args.max_length = 9999

    # create logger
    logger_name = "mylog"
    logger = logging.getLogger(logger_name)
    logger.setLevel(logging.INFO)

    # file handler
    fh = logging.FileHandler(args.log_dir + args.log_fname)
    fh.setLevel(logging.INFO)
    logger.addHandler(fh)

    # stream handler
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    logger.addHandler(console)

    torch.cuda.set_device(args.gpu_id)

    for arg in vars(args):
        logger.info(str(arg) + ' ' + str(getattr(args, arg)))

    # load train/dev data
    logger.info('loading data...')
    train_data = snli_data(args.train_file, args.max_length, meta=args.dev_file)
    train_batches = train_data.batches
    train_lbl_size = 3
    # TODO: use a proper dev split (from train or val, not test)
    dev_data = snli_data(args.test_file, args.max_length)
    dev_batches = dev_data.batches
    # test_data = snli_data(args.test_file, args.max_length)
    # test_batches = test_data.batches
    logger.info('train size # sent ' + str(train_data.size))
    logger.info('dev size # sent ' + str(dev_data.size))

    # get input embeddings
    logger.info('loading input embeddings...')
    word_vecs = w2v(args.w2v_file).word_vecs

    best_dev = []  # (epoch, dev_acc)

    # build the model
    input_encoder = encoder(word_vecs.size(0), args.embedding_size,
                            args.hidden_size, args.para_init)
    input_encoder.embedding.weight.data.copy_(word_vecs)
    input_encoder.embedding.weight.requires_grad = False
    inter_atten = atten(args.hidden_size, train_lbl_size, args.para_init)
    input_encoder.cuda()
    inter_atten.cuda()

    para1 = filter(lambda p: p.requires_grad, input_encoder.parameters())
    para2 = inter_atten.parameters()

    if args.optimizer == 'Adagrad':
        input_optimizer = optim.Adagrad(para1, lr=args.lr,
                                        weight_decay=args.weight_decay)
        inter_atten_optimizer = optim.Adagrad(para2, lr=args.lr,
                                              weight_decay=args.weight_decay)
    elif args.optimizer == 'Adadelta':
        input_optimizer = optim.Adadelta(para1, lr=args.lr)
        inter_atten_optimizer = optim.Adadelta(para2, lr=args.lr)
    else:
        logger.info('No Optimizer.')
        sys.exit()

    criterion = nn.NLLLoss(size_average=True)

    # load a trained checkpoint and dump intermediate outputs for the train set
    saved_state = torch.load(
        './data/runs/420_10_10_epoch10/log54_epoch-189_dev-acc-0.833_input-encoder.pt')
    input_encoder.load_state_dict(saved_state)
    saved_state = torch.load(
        './data/runs/420_10_10_epoch10/log54_epoch-189_dev-acc-0.833_inter-atten.pt')
    inter_atten.load_state_dict(saved_state)

    prec1 = generate_intermediate_outputs(train_data,
                                          (input_encoder, inter_atten),
                                          criterion, 0)
def train(args):
    # create logger
    logger_name = "mylog"
    logger = logging.getLogger(logger_name)
    logger.setLevel(logging.INFO)

    # file handler
    fh = logging.FileHandler(args.log_dir + args.log_fname)
    fh.setLevel(logging.INFO)
    logger.addHandler(fh)

    # stream handler
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    logger.addHandler(console)

    for arg in vars(args):
        logger.info(str(arg) + ' ' + str(getattr(args, arg)))

    # load train/dev data
    logger.info('loading data...')
    # text_field = data.Field(lower=True, batch_first=True)
    # label_field = data.Field(sequential=False)
    # train_iter, dev_iter, test_iter = snli(text_field, label_field, repeat=False)
    batch_num_train = int(len(train_iter.dataset) / args.batch_size)
    logger.info('train size # sent ' + str(len(train_iter.dataset)))
    logger.info('dev size # sent ' + str(len(dev_iter.dataset)))

    # get input embeddings
    logger.info('loading input embeddings...')
    logger.info('vocab size ' + str(len(vocab)))

    best_dev = []  # (epoch, dev_acc)

    # build the model
    input_encoder = encoder(len(vocab), args.embedding_size,
                            args.hidden_size, args.para_init)
    # input_encoder.embedding.weight.data.copy_(vocab.vectors)
    # input_encoder.embedding.weight.requires_grad = False
    inter_atten = atten(args.hidden_size, args.cls_num, args.para_init)
    input_encoder.cuda()
    inter_atten.cuda()

    para1 = filter(lambda p: p.requires_grad, input_encoder.parameters())
    para2 = inter_atten.parameters()

    if args.optimizer == 'Adagrad':
        input_optimizer = optim.Adagrad(para1, lr=args.lr,
                                        weight_decay=args.weight_decay)
        inter_atten_optimizer = optim.Adagrad(para2, lr=args.lr,
                                              weight_decay=args.weight_decay)
    elif args.optimizer == 'Adadelta':
        input_optimizer = optim.Adadelta(para1, lr=args.lr)
        inter_atten_optimizer = optim.Adadelta(para2, lr=args.lr)
    else:
        logger.info('No Optimizer.')
        sys.exit()

    criterion = nn.NLLLoss(reduction='elementwise_mean')
    # criterion = nn.CrossEntropyLoss()

    logger.info('start to train...')
    for k in range(args.epoch):
        total = 0.
        correct = 0.
        loss_data = 0.
        train_sents = 0.
        timer = time.time()
        for i, batch in enumerate(train_iter):
            train_src_batch, train_tgt_batch, train_lbl_batch = make_batch(batch, args)
            # train_src_batch, train_tgt_batch, train_lbl_batch = batch.premise, batch.hypothesis, batch.label
            # train_lbl_batch.sub_(1)
            if args.cuda:
                train_src_batch, train_tgt_batch, train_lbl_batch = \
                    train_src_batch.cuda(), train_tgt_batch.cuda(), train_lbl_batch.cuda()

            train_sents += args.batch_size

            input_optimizer.zero_grad()
            inter_atten_optimizer.zero_grad()

            # initialize the Adagrad accumulators on the first epoch
            if k == 0 and args.optimizer == 'Adagrad':
                for group in input_optimizer.param_groups:
                    for p in group['params']:
                        state = input_optimizer.state[p]
                        state['sum'] += args.Adagrad_init
                for group in inter_atten_optimizer.param_groups:
                    for p in group['params']:
                        state = inter_atten_optimizer.state[p]
                        state['sum'] += args.Adagrad_init

            train_src_linear, train_tgt_linear = input_encoder(
                train_src_batch, train_tgt_batch)

            log_prob = inter_atten(train_src_linear, train_tgt_linear)

            loss = criterion(log_prob, train_lbl_batch)
            loss.backward()

            grad_norm = 0.
            para_norm = 0.
            # accumulate squared norms of all Linear weights and biases
            for m in input_encoder.modules():
                if isinstance(m, nn.Linear):
                    grad_norm += m.weight.grad.data.norm() ** 2
                    para_norm += m.weight.data.norm() ** 2
                    if m.bias is not None:
                        grad_norm += m.bias.grad.data.norm() ** 2
                        para_norm += m.bias.data.norm() ** 2
            for m in inter_atten.modules():
                if isinstance(m, nn.Linear):
                    grad_norm += m.weight.grad.data.norm() ** 2
                    para_norm += m.weight.data.norm() ** 2
                    if m.bias is not None:
                        grad_norm += m.bias.grad.data.norm() ** 2
                        para_norm += m.bias.data.norm() ** 2

            grad_norm = grad_norm ** 0.5
            para_norm = para_norm ** 0.5

            # manual gradient rescaling by global norm: shrink large gradients,
            # and also scale up very small ones (shrinkage > 50)
            shrinkage = args.max_grad_norm / grad_norm
            if shrinkage < 1 or shrinkage > 50:
                for m in input_encoder.modules():
                    if isinstance(m, nn.Linear):
                        m.weight.grad.data = m.weight.grad.data * shrinkage
                for m in inter_atten.modules():
                    if isinstance(m, nn.Linear):
                        m.weight.grad.data = m.weight.grad.data * shrinkage
                        m.bias.grad.data = m.bias.grad.data * shrinkage

            input_optimizer.step()
            inter_atten_optimizer.step()

            _, predict = log_prob.data.max(dim=1)
            total += train_lbl_batch.data.size(0)
            correct += torch.sum(predict == train_lbl_batch.data)
            loss_data += loss.data * args.batch_size

            # log every display_interval batches and at the end of the epoch
            if (i + 1) % args.display_interval == 0 or i == batch_num_train - 1:
                logger.info(
                    'epoch %d, batches %d|%d, train-acc %.3f, loss %.3f, '
                    'para-norm %.3f, grad-norm %.3f, time %.2fs, ' %
                    (k, i + 1, batch_num_train,
                     correct.cpu().data.numpy() / total,
                     loss_data / train_sents, para_norm, grad_norm,
                     time.time() - timer))
                train_sents = 0.
                timer = time.time()
                loss_data = 0.
                correct = 0.
                total = 0.

        # evaluate
        if (k + 1) % args.dev_interval == 0:
            input_encoder.eval()
            inter_atten.eval()
            correct = 0.
            total = 0.
            for batch in dev_iter:
                dev_src_batch, dev_tgt_batch, dev_lbl_batch = make_batch(batch, args)
                # dev_src_batch, dev_tgt_batch, dev_lbl_batch = batch.premise, batch.hypothesis, batch.label
                # dev_lbl_batch.sub_(1)
                # if args.cuda:
                #     dev_src_batch, dev_tgt_batch, dev_lbl_batch = dev_src_batch.cuda(), dev_tgt_batch.cuda(), dev_lbl_batch.cuda()
                # if dev_lbl_batch.data.size(0) == 1:
                #     # single-sample batch
                #     dev_src_batch = torch.unsqueeze(dev_src_batch, 0)
                #     dev_tgt_batch = torch.unsqueeze(dev_tgt_batch, 0)

                dev_src_linear, dev_tgt_linear = input_encoder(
                    dev_src_batch, dev_tgt_batch)

                log_prob = inter_atten(dev_src_linear, dev_tgt_linear)

                _, predict = log_prob.data.max(dim=1)
                total += dev_lbl_batch.data.size(0)
                correct += torch.sum(predict == dev_lbl_batch.data)

            dev_acc = correct.cpu().data.numpy() / total
            logger.info('dev-acc %.3f' % (dev_acc))

            # save a checkpoint at the first dev evaluation, and afterwards
            # only when dev accuracy improves over the current best
            if (k + 1) / args.dev_interval == 1 or dev_acc > best_dev[-1][1]:
                model_fname = '%s%s_epoch-%d_dev-acc-%.3f' % (
                    args.model_path, args.log_fname.split('.')[0], k, dev_acc)
                torch.save(input_encoder.state_dict(),
                           model_fname + '_input-encoder.pt')
                torch.save(inter_atten.state_dict(),
                           model_fname + '_inter-atten.pt')
                best_dev.append((k, dev_acc, model_fname))
                logger.info('current best-dev:')
                for t in best_dev:
                    logger.info('\t%d %.3f' % (t[0], t[1]))
                logger.info('save model!')

            input_encoder.train()
            inter_atten.train()

    logger.info('training end!')