def train_vae():
    """Train a VAE, checkpointing every epoch and stopping early.

    Training stops once the validation loss has risen for ``patience``
    consecutive epochs (each epoch is compared with the one before it).
    A checkpoint with the epoch index, model/optimizer state and the mean
    training loss is written to ``vae/upsample_checkpoint_<epoch>.pth``
    after every epoch.
    """
    batch_size = 64
    epochs = 1000
    latent_dimension = 100
    patience = 10
    device = torch.device(
        'cuda:0') if torch.cuda.is_available() else torch.device('cpu')

    # load data
    train_loader, valid_loader, _ = get_data_loader('data', batch_size)

    model = VAE(latent_dimension).to(device)
    optim = Adam(model.parameters(), lr=1e-3)

    val_greater_count = 0
    # Start at +inf so the first epoch is never counted as an "increase";
    # starting at 0 silently consumed one unit of patience in epoch 0.
    last_val_loss = float('inf')
    for e in range(epochs):
        running_loss = 0.0
        model.train()
        for images, _ in train_loader:
            images = images.to(device)
            model.zero_grad()
            outputs, mu, logvar = model(images)
            loss = compute_loss(images, outputs, mu, logvar)
            loss.backward()
            optim.step()
            # Accumulate a Python float: summing the tensor itself would keep
            # every batch's autograd graph alive for the whole epoch.
            running_loss += loss.item()
        running_loss = running_loss / len(train_loader)

        model.eval()
        with torch.no_grad():
            val_loss = 0.0
            for images, _ in valid_loader:
                images = images.to(device)
                outputs, mu, logvar = model(images)
                val_loss += compute_loss(images, outputs, mu, logvar).item()
            val_loss /= len(valid_loader)

        # Count consecutive epochs in which the validation loss went up.
        if val_loss > last_val_loss:
            val_greater_count += 1
        else:
            val_greater_count = 0
        last_val_loss = val_loss

        torch.save(
            {
                'epoch': e,
                'model': model.state_dict(),
                'running_loss': running_loss,
                'optim': optim.state_dict(),
            }, "vae/upsample_checkpoint_{}.pth".format(e))
        print("Epoch: {} Train Loss: {}".format(e + 1, running_loss))
        print("Epoch: {} Val Loss: {}".format(e + 1, val_loss))
        if val_greater_count >= patience:
            break
def loss_function(recon_x, x, mu, logvar):
    """Return the VAE loss: reconstruction term plus KL divergence.

    ``reconstruction_function`` is a module-level criterion defined elsewhere.
    ``mu`` and ``logvar`` parameterise the approximate posterior
    N(mu, exp(logvar)).
    """
    BCE = reconstruction_function(recon_x, x)
    # see Appendix B from VAE paper:
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114
    # KLD = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD


optimizer = optim.Adam(model.parameters(), lr=args.lr)


def train(epoch):
    """Run one training epoch over ``train_loader`` and update the weights.

    ``epoch`` is accepted for the caller's convenience and is not used here.

    Returns:
        The summed loss over all batches of the epoch (a Python float).
    """
    model.train()
    train_loss = 0
    for data in train_loader:
        if args.cuda:
            data = data.cuda()
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        # Reconstructions are compared as single-channel 32x32 images.
        recon_batch = recon_batch.view(-1, 1, 32, 32)
        loss = loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        # Without this step the gradients were computed but never applied.
        optimizer.step()
        train_loss += loss.item()
    return train_loss
def _zero_lstm_state(config, batch_size):
    """Return a zero LSTM state of shape (layers * directions, batch_size, hidden)."""
    return torch.zeros(config.lstm_num_layers * config.lstm_num_direction,
                       batch_size, config.lstm_num_hidden).to(device)


def _prune_at_eos(sentences, num_rows):
    """Cut each token list just before its first 'EOS' token.

    The result is padded with empty lists up to ``num_rows`` entries, because
    callers write a fixed number of rows to disk.
    """
    pruned = []
    for sentence in sentences:
        kept = []
        for token in sentence:
            if token == 'EOS':
                break
            kept.append(token)
        pruned.append(kept)
    pruned.extend([] for _ in range(num_rows - len(pruned)))
    return pruned


def _evaluate(model, eval_data, criterion, config, log_label=None):
    """Score ``model`` one sentence at a time on ``eval_data``.

    Every sentence is fed with ``config.importance_sampling_size`` (k) samples
    of z. Perplexity uses the token probabilities averaged over the k samples.
    The loss is ``criterion`` summed over the samples plus the model's KL
    term, divided by k.

    Args:
        model: the VAE; it is put in eval mode here.
        eval_data: sequence of examples accepted by ``prepare_example``.
        criterion: the summed cross-entropy used for training.
        config: run configuration (LSTM sizes, importance_sampling_size).
        log_label: if not None, print the per-example NLL every 300 examples,
            tagged with this label.

    Returns:
        (perplexity, mean_loss, accuracy, nll): perplexity over all tokens,
        the loss averaged over examples, summed ``compute_match_vae`` matches
        divided by the total token count, and the summed per-example NLL.
    """
    model.eval()
    k_samples = config.importance_sampling_size
    nll_total = 0.0
    loss_total = 0.0
    total_tokens = 0
    match = []
    with torch.no_grad():
        for example_idx, sentence in enumerate(eval_data):
            example_input, example_target = prepare_example(sentence, vocab)
            sent_len = example_input.size(1)
            total_tokens += sent_len
            # one LSTM state row per importance sample
            h_0 = _zero_lstm_state(config, k_samples)
            c_0 = _zero_lstm_state(config, k_samples)
            decoder_output, kl_loss = model(example_input, h_0, c_0,
                                            [sent_len], k_samples)

            # decoder_output: (k, batch=1, sent_len, vocab) per the model's
            # contract -> squeeze the batch axis, softmax over the vocabulary,
            # then average the k samples into (sent_len, vocab).
            prediction = nn.functional.softmax(
                torch.squeeze(decoder_output, dim=1), dim=2)
            prediction_mean = torch.mean(prediction, 0)
            example_nll = 0.0
            for j in range(prediction.shape[1]):
                # row 0: the target is the same for all k samples
                example_nll -= torch.log(
                    prediction_mean[j][int(example_target[0][j])])
            nll_total += example_nll
            if log_label is not None and example_idx % 300 == 0:
                print(' ppl_per_example at the ', example_idx, log_label,
                      'case = ', example_nll)
            match.append(compute_match_vae(prediction_mean, example_target))

            # CrossEntropyLoss expects (batch, vocab, sent_len) per sample.
            logits = decoder_output.permute(0, 1, 3, 2)
            reconstruction_loss = 0
            for k in range(k_samples):
                reconstruction_loss += criterion(logits[k], example_target)
            loss_total += (reconstruction_loss + kl_loss) / k_samples

    perplexity = torch.exp(nll_total / total_tokens)
    accuracy = sum(match) / total_tokens
    mean_loss = loss_total / len(eval_data)
    return perplexity, mean_loss, accuracy, nll_total


def train(config):
    """Train the sentence VAE, then test, sample, interpolate and reconstruct.

    Periodically evaluates on ``val_data``. On iterations that are multiples
    of ``config.eval_every_train``, it evaluates on ``train_data`` instead.
    Metric histories go to ``./np_saved_results``. The model with the lowest
    *validation* perplexity is kept in ``./models/vae_best.pt`` and is
    evaluated on ``test_data`` after training. Text outputs are appended
    under ``./result``.

    Returns:
        0 on completion.
    """
    banner = '+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-'
    # Print all configs to confirm parameter settings
    print_flags()

    # Initialize the model that we are going to use
    model = VAE(vocabulary_size=vocab_size,
                dropout=1 - config.dropout_keep_prob,
                lstm_num_hidden=config.lstm_num_hidden,
                lstm_num_layers=config.lstm_num_layers,
                lstm_num_direction=config.lstm_num_direction,
                num_latent=config.num_latent,
                device=device)
    model.to(device)

    # Setup the loss and optimizer. Losses are summed over tokens here and
    # divided by the batch size below. ignore_index=1 is presumably the PAD
    # id -- confirm against the vocabulary.
    criterion = nn.CrossEntropyLoss(ignore_index=1, reduction='sum')
    optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
    k_samples = config.importance_sampling_size

    # Store some measures: per-flag (perplexity, accuracy, elbo, nll) lists,
    # saved as '<flag>_<metric>.npy'.
    metric_names = ('perp', 'acc', 'elbo', 'nll')
    history = {'val': ([], [], [], []), 'train': ([], [], [], [])}
    tmp_loss = []
    train_loss = []
    iter_i = 0
    best_perp = 1e6

    while True:  # when we run out of examples, shuffle and continue
        for train_batch in get_minibatch(train_data,
                                         batch_size=config.batch_size):
            iter_i += 1
            model.train()
            optimizer.zero_grad()
            inputs, targets, lengths_in_batch = prepare_minibatch(
                train_batch, vocab)
            # one zero LSTM state row per sentence in the batch
            h_0 = _zero_lstm_state(config, inputs.shape[0])
            c_0 = _zero_lstm_state(config, inputs.shape[0])
            decoder_output, KL_loss = model(inputs, h_0, c_0,
                                            lengths_in_batch, k_samples)

            # CrossEntropyLoss expects (batch, vocab, sent_len), so swap the
            # last two axes of each sample's (batch, sent_len, vocab) output.
            reconstruction_loss = 0.0
            for k in range(k_samples):
                reconstruction_loss += criterion(
                    decoder_output[k].permute(0, 2, 1), targets)
            # get the mean over the k samples of z
            reconstruction_loss = reconstruction_loss / k_samples
            KL_loss = KL_loss / k_samples
            print('At iter', iter_i, ', rc_loss=', reconstruction_loss.item(),
                  ' KL_loss = ', KL_loss.item())

            total_loss = (reconstruction_loss + KL_loss) / config.batch_size
            tmp_loss.append(total_loss.item())
            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                           max_norm=config.max_norm)
            optimizer.step()

            if iter_i % config.eval_every == 0:
                eval_data, eval_data_flag = val_data, 'val'
                print('Evaluating with validation at iteration ', iter_i,
                      '...')
                if iter_i % config.eval_every_train == 0:
                    eval_data, eval_data_flag = train_data, 'train'
                    print('Evaluating with training instead, at iteration ',
                          iter_i, '...')

                ppl_total, validation_elbo_loss, accuracy, nll = _evaluate(
                    model, eval_data, criterion, config,
                    log_label=eval_data_flag)
                print('ppl_total for iteration ', iter_i, ' = ', ppl_total)
                print('accuracy for iteration ', iter_i, ' = ', accuracy)

                # loss of the iterations since the last eval
                avg_loss = sum(tmp_loss) / len(tmp_loss)
                tmp_loss = []

                # Select the checkpoint on validation perplexity only;
                # training perplexity would always look "better".
                if eval_data_flag == 'val' and ppl_total < best_perp:
                    best_perp = ppl_total
                    torch.save(model.state_dict(), "./models/vae_best.pt")

                print(
                    "[{}] Train Step {:04d}/{:04d}, "
                    "Validation Perplexity = {:.4f}, Validation loss ={:.4f}, Training Loss = {:.4f}, NLL = {:.4f}, "
                    "Validation Accuracy = {:.4f}".format(
                        datetime.now().strftime("%Y-%m-%d %H:%M"), iter_i,
                        config.train_steps, ppl_total, validation_elbo_loss,
                        avg_loss, nll, accuracy))

                # update/save eval results every time
                suffix = ['till_iter_' + str(iter_i)]
                train_loss.append(avg_loss)
                np.save('./np_saved_results/train_loss.npy',
                        train_loss + suffix)
                perps, accs, elbos, nlls = history[eval_data_flag]
                perps.append(ppl_total.item())
                accs.append(accuracy)
                elbos.append(validation_elbo_loss.item())
                nlls.append(float(nll))
                for metric, values in zip(metric_names,
                                          history[eval_data_flag]):
                    np.save('./np_saved_results/{}_{}.npy'.format(
                        eval_data_flag, metric), values + suffix)

            if iter_i == config.train_steps:
                break
        if iter_i == config.train_steps:
            break

    print('Done training!')
    print(banner)
    print('Testing...')
    print(banner)
    model.load_state_dict(torch.load('./models/vae_best.pt'))
    ppl_total, test_elbo, accuracy, nll = _evaluate(model, test_data,
                                                    criterion, config)
    print('Test Perplexity on the best model is: {:.3f}'.format(ppl_total))
    print('Test ELBO on the best model is: {:.3f}'.format(test_elbo))
    print('Test accuracy on the best model is: {:.3f}'.format(accuracy))
    print('Test NLL on the best model is: {:.3f}'.format(nll))
    print(banner)
    with open('./result/vae_test.txt', 'a') as f:
        f.write(
            'Learning Rate = {}, Train Step = {}, '
            'Dropout = {}, LSTM Layers = {}, '
            'Hidden Size = {}, Test Perplexity = {:.3f}, Test ELBO = {:.3f}, Test NLL = {:.3f}, '
            'Test Accuracy = {}\n'.format(
                config.learning_rate, config.train_steps,
                1 - config.dropout_keep_prob, config.lstm_num_layers,
                config.lstm_num_hidden, ppl_total, test_elbo, nll, accuracy))

    print('Sampling...')
    print(banner)
    # NOTE(review): this replaces the best model just tested with a different
    # checkpoint ('vae_best_lisa.pt'); confirm this is intended.
    model.load_state_dict(
        torch.load('./models/vae_best_lisa.pt',
                   map_location=lambda storage, loc: storage))
    with torch.no_grad():
        sentences = model.sample(config.sample_size, vocab)
    samples = _prune_at_eos(sentences, config.sample_size)
    with open('./result/vae_test_greedy_new.txt', 'a') as f:
        if samples:
            f.write('\n Greedy: \n')
        for idx, sen in enumerate(samples):
            f.write('Sampling \n{}: {}\n'.format(idx, ' '.join(sen)))

    print('Interpolating...')
    print(banner)
    with torch.no_grad():
        sentences = model.interpolation(vocab)
    interpolated = _prune_at_eos(sentences, 5)
    labels = ('z1', 'z2', 'z1+z2/2', 'z1*0.8+z2*0.2', 'z1*0.2+z2*0.8')
    with open('./result/vae_test_interpolate.txt', 'a') as f:
        f.write('\n Interpolation: \n')
        for label, sen in zip(labels, interpolated):
            f.write('Sampling {}:\n {}\n'.format(label, ' '.join(sen)))

    print('Test case reconstruction...')
    print(banner)
    test_sen = test_data[101]
    test_input, _ = prepare_example(test_sen, vocab)
    with torch.no_grad():
        reconstructed_sentences = model.test_reconstruction(test_input, vocab)
    reconstructed = _prune_at_eos(reconstructed_sentences, 10)
    with open('./result/vae_test_reconstruct.txt', 'a') as f:
        f.write('\n The sentence to reconstruct:\n {}\n'.format(' '.join(
            test_sen[1:])))
        for x in range(10):
            f.write('Sample: {} \n {}\n'.format(x,
                                                ' '.join(reconstructed[x])))
    return 0
data_path = '/home/lilioo826/hw4_data/' train_faceDataset = FaceDataset(data_path + 'train', data_path + 'train.csv', transforms.ToTensor()) train_dataloader = DataLoader(train_faceDataset, batch_size=20, num_workers=1) cuda = True model = VAE(64, 1e-6) # print(model) if cuda: model.cuda() # summary(model, (3,64,64)) # exit() epoch_num = 100 model.train() optimizer = optim.Adam(model.parameters(), lr=1e-4) klds = [] mses = [] for epoch in range(epoch_num): print('epoch {}'.format(epoch + 1)) epoch_kld = 0 epoch_mse = 0 epoch_loss = 0 for batch_idx, (data, label) in enumerate(train_dataloader): if cuda: data = data.cuda() data = Variable(data) optimizer.zero_grad() recon_batch, mu, logvar = model(data) loss = model.loss_function(data, recon_batch, mu, logvar)