def train(lr):
    # Assumes module-level imports (pickle, os, time, numpy as np, logger) and
    # configuration (file paths, hyperparameters) defined elsewhere in this script.
    with open(vocab_freq_file, 'rb') as f:
        vocab_freq = pickle.load(f)
    # Build the alpha-smoothed unigram noise distribution and its alias tables
    # for fast negative sampling.
    vocab_p = Q_w(vocab_freq, alpha)
    J, q = alias_setup(vocab_p)

    # Load data
    print 'loading dataset...'
    train_data = TextIterator(train_datafile, n_batch=n_batch, maxlen=maxlen)
    valid_data = TextIterator(valid_datafile, n_batch=n_batch, maxlen=maxlen)
    test_data = TextIterator(test_datafile, n_batch=n_batch, maxlen=maxlen)

    print 'building model...'
    model = RNNLM(n_input, n_hidden, vocabulary_size,
                  cell=rnn_cell, optimizer=optimizer, p=p, q_w=vocab_p, k=k)
    if os.path.isfile(model_dir):
        print 'loading checkpoint parameters...', model_dir
        model = load_model(model_dir, model)
    if goto_line > 0:
        # Resume iteration from a given line of the training file.
        train_data.goto_line(goto_line)
        print 'goto line:', goto_line

    print 'training start...'
    start = time.time()
    idx = 0
    for epoch in xrange(NEPOCH):
        error = 0
        for x, x_mask, y, y_mask in train_data:
            idx += 1
            # Draw k noise words per target for the NCE objective.
            negy = negative_sample(y, y_mask, k, J, q)
            cost = model.train(x, x_mask, y, negy, y_mask, lr)
            error += cost
            if np.isnan(cost) or np.isinf(cost):
                print 'NaN or Inf detected!'
                return -1
            if idx % disp_freq == 0:
                logger.info('epoch: %d idx: %d cost: %f ppl: %f'
                            % (epoch, idx, error / disp_freq,
                               np.exp(error / (1.0 * disp_freq))))
                error = 0
            if idx % save_freq == 0:
                logger.info('dumping...')
                save_model('./model/parameters_%.2f.pkl' % (time.time() - start), model)
            if idx % valid_freq == 0:
                logger.info('validating...')
                valid_cost, wer = evaluate(valid_data, model)
                logger.info('validation cost: %f perplexity: %f word error rate: %f'
                            % (valid_cost, np.exp(valid_cost), wer))
            if idx % test_freq == 0:
                logger.info('testing...')
                test_cost, wer = evaluate(test_data, model)
                logger.info('test cost: %f perplexity: %f word error rate: %f'
                            % (test_cost, np.exp(test_cost), wer))
    print 'Finished. Time = ' + str(time.time() - start)
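# The helpers Q_w, alias_setup and negative_sample used above are assumed to
# implement the usual alpha-smoothed unigram noise distribution and Walker's
# alias method for O(1) categorical sampling. A minimal sketch under that
# assumption (the exact shape returned by negative_sample is a guess):
import numpy as np


def Q_w(vocab_freq, alpha):
    """Unigram distribution raised to the power alpha, renormalized."""
    freq = np.asarray(vocab_freq, dtype='float64') ** alpha
    return freq / freq.sum()


def alias_setup(probs):
    """Build alias tables (J, q) so each draw costs O(1)."""
    K = len(probs)
    q = np.zeros(K)
    J = np.zeros(K, dtype=np.int64)
    smaller, larger = [], []
    for i, prob in enumerate(probs):
        q[i] = K * prob
        (smaller if q[i] < 1.0 else larger).append(i)
    while smaller and larger:
        small = smaller.pop()
        large = larger.pop()
        J[small] = large
        q[large] = q[large] - (1.0 - q[small])
        (smaller if q[large] < 1.0 else larger).append(large)
    return J, q


def alias_draw(J, q):
    """Draw one index from the alias tables."""
    i = int(np.floor(np.random.rand() * len(J)))
    return i if np.random.rand() < q[i] else J[i]


def negative_sample(y, y_mask, k, J, q):
    """Draw k noise words per target position (masked positions kept for shape)."""
    y = np.asarray(y)
    neg = np.zeros(y.shape + (k,), dtype=y.dtype)
    for pos in np.ndindex(y.shape):
        for j in range(k):
            neg[pos + (j,)] = alias_draw(J, q)
    return neg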
def train(config, sw):
    # Assumes module-level imports (torch, torch.nn as nn, torch.optim as optim,
    # torchtext, sys, itertools, datetime) and the project's model/data helpers.
    # Initialize the device to run the model on
    device = torch.device(config.device)

    vocab = torchtext.vocab.FastText()
    # vocab = torchtext.vocab.GloVe()

    # Get data iterators
    lm_iters, s_iters = load_data(embeddings=vocab, device=device,
                                  batch_size=config.batch_size,
                                  bptt_len=config.seq_len)
    _, valid_iter, test_iter, field = s_iters
    vocab = field.vocab
    if config.use_bptt:
        train_iter, _, _, _ = lm_iters
    else:
        train_iter, _, _, _ = s_iters

    print("Vocabulary vectors shape: {}".format(vocab.vectors.shape))

    # Create the embedding layer from the pretrained vectors
    embedding = nn.Embedding.from_pretrained(vocab.vectors).to(device)
    EMBED_DIM = 300  # FastText vectors are 300-dimensional
    num_classes = vocab.vectors.shape[0]

    # Initialize the model that we are going to use
    if config.model == "rnnlm":
        model = RNNLM(EMBED_DIM, config.hidden_dim, num_classes)
    elif config.model == "s-vae":
        model = SentenceVAE(EMBED_DIM, config.hidden_dim, num_classes,
                            fb_lambda=config.freebits_lambda,
                            wd_keep_prob=config.wdropout_prob,
                            wd_unk=embedding(torch.LongTensor(
                                [vocab.stoi["<unk>"]]).to(device)),
                            mu_f_beta=config.mu_forcing_beta)
    else:
        raise ValueError("Invalid model parameter.")
    model = model.to(device)

    # Set up the loss, optimizer and lr-scheduler
    optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
    criterion = torch.nn.NLLLoss(reduction="sum").to(device)
    scheduler = optim.lr_scheduler.StepLR(optimizer, 1,
                                          gamma=config.learning_rate_decay)

    global_step = 0
    best_nll = sys.maxsize
    best_pp = sys.maxsize

    for epoch in itertools.count():
        for batch in train_iter:
            # [1] Get data
            if config.use_bptt:
                batch_text = batch.text
                batch_target = batch.target
                txt_len = torch.full((batch_text.shape[1],), batch_text.shape[0],
                                     device=device)
            else:
                batch_text, txt_len = batch.text
                batch_target, tgt_len = batch.target

            batch_text = embedding(batch_text.to(device))
            batch_target = batch_target.to(device)

            # [2] Forward & loss
            batch_output = model(batch_text, txt_len)

            # Merge batch and sequence dimensions for evaluation
            batch_output = batch_output.view(-1, batch_output.shape[2])
            batch_target = batch_target.view(-1)

            B = batch_text.shape[1]
            nll = criterion(batch_output, batch_target) / B
            sw.add_scalar('Train/NLL', nll.item(), global_step)

            loss = nll.clone()
            for loss_name, additional_loss in model.get_additional_losses().items():
                loss += additional_loss
                sw.add_scalar('Train/' + loss_name, additional_loss, global_step)

            # [3] Optimize
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                           max_norm=config.max_norm)
            optimizer.step()
            sw.add_scalar('Train/Loss', loss.item(), global_step)

            if global_step % config.print_every == 0:
                print("[{}] Train Step {:04d}/{:04d}, "
                      "NLL = {:.2f}, Loss = {:.3f}".format(
                          datetime.now().strftime("%Y-%m-%d %H:%M"), global_step,
                          config.train_steps, nll.item(), loss.item()),
                      flush=True)

            global_step += 1

        # Validate at the end of every epoch
        epoch_nll, epoch_pp, epoch_kl, additional_losses = test_model(
            model, embedding, criterion, valid_iter, device)
        model.train()
        print("Valid NLL: {}".format(epoch_nll))
        print("Valid Perplexity: {}".format(epoch_pp))
        print("Valid KL: {}".format(epoch_kl))
        sw.add_scalar('Valid/NLL', epoch_nll, global_step)
        sw.add_scalar('Valid/Perplexity', epoch_pp, global_step)
        sw.add_scalar('Valid/KL', epoch_kl, global_step)
        # additional_losses also contains the KL, but not the multi-sample estimate
        for loss_name, additional_loss in additional_losses.items():
            sw.add_scalar('Valid/' + loss_name, additional_loss, global_step)

        # Sample some sentences
        MAX_LEN = 50
        for _ in range(5):
            text = model.temperature_sample(embedding, MAX_LEN)
            text = ' '.join(vocab.itos[w] for w in text)
            print(text)
            sw.add_text('Valid/Sample-text', text, global_step)

        if epoch_nll < best_nll:
            best_nll = epoch_nll
            save_model("best", model, config)
        if epoch_pp < best_pp:
            best_pp = epoch_pp

        if global_step >= config.train_steps:
            break

        scheduler.step()
        print("Learning Rate: {}".format(
            [group['lr'] for group in optimizer.param_groups]))

    print('Done training.')

    # Evaluate the best checkpoint on the test set
    best_model = load_model("best", config)
    test_nll, test_pp, test_kl, test_additional_losses = test_model(
        best_model, embedding, criterion, test_iter, device)
    print("Test NLL: {}".format(test_nll))
    print("Test PP: {}".format(test_pp))
    print("Test KL: {}".format(test_kl))
    print("{}".format(test_additional_losses))

    return best_model, model, {'hparam/nll': best_nll, 'hparam/pp': best_pp}
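# A possible entry point for the function above. The flag names simply mirror
# the config attributes it reads; every default value here is a placeholder
# assumption, and SummaryWriter is the TensorBoard writer passed in as `sw`.
import argparse

import torch
from torch.utils.tensorboard import SummaryWriter

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", choices=["rnnlm", "s-vae"], default="rnnlm")
    parser.add_argument("--device",
                        default="cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--seq_len", type=int, default=35)
    parser.add_argument("--use_bptt", action="store_true")
    parser.add_argument("--hidden_dim", type=int, default=256)
    parser.add_argument("--learning_rate", type=float, default=1e-3)
    parser.add_argument("--learning_rate_decay", type=float, default=0.98)
    parser.add_argument("--max_norm", type=float, default=5.0)
    parser.add_argument("--train_steps", type=int, default=100000)
    parser.add_argument("--print_every", type=int, default=100)
    parser.add_argument("--freebits_lambda", type=float, default=0.5)
    parser.add_argument("--wdropout_prob", type=float, default=0.75)
    parser.add_argument("--mu_forcing_beta", type=float, default=0.0)
    config = parser.parse_args()

    sw = SummaryWriter()
    train(config, sw)
    sw.close()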
def train(lr):
    # Load data
    logger.info('loading dataset...')
    train_data = TextIterator(train_datafile, filepath, n_batch=n_batch,
                              brown_or_huffman=brown_or_huffman,
                              mode=matrix_or_vector,
                              word2idx_path=word2idx_path)
    valid_data = TextIterator(valid_datafile, filepath, n_batch=n_batch,
                              brown_or_huffman=brown_or_huffman,
                              mode=matrix_or_vector,
                              word2idx_path=word2idx_path)
    test_data = TextIterator(test_datafile, filepath, n_batch=n_batch,
                             brown_or_huffman=brown_or_huffman,
                             mode=matrix_or_vector,
                             word2idx_path=word2idx_path)

    logger.info('building model...')
    model = RNNLM(n_input, n_hidden, vocabulary_size, cell, optimizer,
                  p=p, mode=matrix_or_vector)
    if os.path.exists(model_dir) and reload_dumps == 1:
        logger.info('loading parameters from: %s' % model_dir)
        model = load_model(model_dir, model)
    else:
        logger.info('initializing parameters...')

    logger.info('training start...')
    start = time.time()
    idx = 0
    for epoch in xrange(NEPOCH):
        error = 0
        for x, x_mask, (y_node, y_choice, y_bit_mask), y_mask in train_data:
            idx += 1
            cost = model.train(x, x_mask, y_node, y_choice, y_bit_mask, y_mask, lr)
            error += cost
            if np.isnan(cost) or np.isinf(cost):
                print 'NaN or Inf detected!'
                return -1
            if idx % disp_freq == 0:
                logger.info('epoch: %d idx: %d cost: %f ppl: %f'
                            % (epoch, idx, error / disp_freq,
                               np.exp(error / (1.0 * disp_freq))))
                error = 0
            if idx % save_freq == 0:
                logger.info('dumping...')
                save_model('./model/parameters_%.2f.pkl' % (time.time() - start), model)
            if idx % valid_freq == 0:
                logger.info('validating...')
                valid_cost = evaluate(valid_data, model)
                logger.info('valid cost: %f perplexity: %f'
                            % (valid_cost, np.exp(valid_cost)))
            if idx % test_freq == 0:
                logger.info('testing...')
                test_cost = evaluate(test_data, model)
                logger.info('test cost: %f perplexity: %f'
                            % (test_cost, np.exp(test_cost)))
            # Optionally decay the learning rate:
            # if idx % clip_freq == 0 and lr >= 0.01:
            #     lr = lr * 0.9
            sys.stdout.flush()
    print 'Finished. Time = ' + str(time.time() - start)
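# The function above relies on module-level settings defined elsewhere in the
# script. A sketch of what that configuration block might look like; all values
# are placeholder assumptions, only the names come from the code above.
import logging
import os
import sys
import time

import numpy as np

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# data
train_datafile = './data/train.txt'
valid_datafile = './data/valid.txt'
test_datafile = './data/test.txt'
filepath = './data/codes.pkl'
word2idx_path = './data/word2idx.pkl'
brown_or_huffman = 'huffman'      # which tree-structured output layer to use
matrix_or_vector = 'vector'       # target encoding mode expected by the model

# model
n_input, n_hidden, vocabulary_size = 256, 512, 10000
cell, optimizer, p = 'gru', 'adam', 0.5
model_dir = './model/parameters.pkl'
reload_dumps = 0

# training schedule
NEPOCH = 10
n_batch = 20
disp_freq, save_freq, valid_freq, test_freq = 100, 5000, 2000, 2000

if __name__ == '__main__':
    train(lr=0.001)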