Esempio n. 1
0
def train_vae(batch_size=64, epochs=1000, latent_dimension=100, patience=10,
              data_dir='data', checkpoint_dir='vae', lr=1e-3):
    """Train a ``VAE`` with per-epoch checkpoints and patience-based early stopping.

    Args:
        batch_size: mini-batch size passed to ``get_data_loader``.
        epochs: maximum number of training epochs.
        latent_dimension: latent size passed to the ``VAE`` constructor.
        patience: stop once the validation loss has risen for this many
            consecutive epochs (each epoch is compared with the one before it).
        data_dir: data root passed to ``get_data_loader``.
        checkpoint_dir: directory receiving ``upsample_checkpoint_<epoch>.pth``
            files; created if it does not exist.
        lr: Adam learning rate.
    """
    import os

    device = (torch.device('cuda:0') if torch.cuda.is_available()
              else torch.device('cpu'))
    train_loader, valid_loader, _ = get_data_loader(data_dir, batch_size)

    model = VAE(latent_dimension).to(device)
    optimizer = Adam(model.parameters(), lr=lr)
    os.makedirs(checkpoint_dir, exist_ok=True)

    val_greater_count = 0
    # Start at +inf: starting at 0 made the first epoch always count as an increase.
    last_val_loss = float('inf')
    for e in range(epochs):
        model.train()
        running_loss = 0
        for images, _ in train_loader:
            images = images.to(device)
            model.zero_grad()
            outputs, mu, logvar = model(images)
            loss = compute_loss(images, outputs, mu, logvar)
            loss.backward()
            optimizer.step()
            # Detach so the running sum does not keep every batch's autograd
            # graph alive until the end of the epoch.
            running_loss += loss.detach()
        running_loss = running_loss / len(train_loader)

        model.eval()
        with torch.no_grad():
            val_loss = 0
            for images, _ in valid_loader:
                images = images.to(device)
                outputs, mu, logvar = model(images)
                val_loss += compute_loss(images, outputs, mu, logvar)
            val_loss /= len(valid_loader)

        # Count consecutive epochs whose validation loss went up.
        if val_loss > last_val_loss:
            val_greater_count += 1
        else:
            val_greater_count = 0
        last_val_loss = val_loss

        torch.save(
            {
                'epoch': e,
                'model': model.state_dict(),
                'running_loss': running_loss,
                'optim': optimizer.state_dict(),
            },
            os.path.join(checkpoint_dir, "upsample_checkpoint_{}.pth".format(e)))
        print("Epoch: {} Train Loss: {}".format(e + 1, running_loss.item()))
        print("Epoch: {} Val Loss: {}".format(e + 1, val_loss.item()))
        if val_greater_count >= patience:
            break
Esempio n. 2
0

def loss_function(recon_x, x, mu, logvar):
    """Return the VAE objective: reconstruction term plus KL divergence.

    The KL term is the closed form from Appendix B of Kingma & Welling,
    "Auto-Encoding Variational Bayes" (ICLR 2014,
    https://arxiv.org/abs/1312.6114):
        KLD = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    with ``logvar`` = log(sigma^2).

    Args:
        recon_x: decoder output, passed to ``reconstruction_function``.
        x: target batch, passed to ``reconstruction_function``.
        mu: posterior means.
        logvar: posterior log-variances.
    """
    reconstruction = reconstruction_function(recon_x, x)
    # Evaluated as (-(mu^2 + sigma^2) + 1) + log(sigma^2), the same
    # operation order as the original chained in-place expression.
    kld_terms = -(mu.pow(2) + logvar.exp()) + 1 + logvar
    kld = -0.5 * torch.sum(kld_terms)
    return reconstruction + kld


# Adam over every parameter of the module-level model; learning rate from ``args.lr``.
optimizer = optim.Adam(params=model.parameters(), lr=args.lr)


def train(epoch):
    """Run one training epoch over ``train_loader`` and return the mean batch loss.

    Uses the module-level ``model``, ``optimizer``, ``train_loader``, ``args``
    and ``loss_function``.

    Args:
        epoch: epoch index; not used by this body (presumably kept for the
            caller's logging).

    Returns:
        Mean per-batch loss as a float, or 0.0 if the loader yielded nothing.
    """
    model.train()
    train_loss = 0.0
    num_batches = 0
    for data in train_loader:
        data = Variable(data)
        if args.cuda:
            data = data.cuda()
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        # Reshape the decoder output to single-channel 32x32 images for scoring.
        recon_batch = recon_batch.view(-1, 1, 32, 32)
        loss = loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        # Apply the gradients; without this the epoch never updated the weights.
        optimizer.step()
        train_loss += loss.item()
        num_batches += 1
    return train_loss / num_batches if num_batches else 0.0
Esempio n. 3
0
def _zero_lstm_state(config, batch_size):
    """Return zero-filled (h_0, c_0) LSTM states on ``device``.

    Both tensors have shape
    (lstm_num_layers * lstm_num_direction, batch_size, lstm_num_hidden).
    """
    shape = (config.lstm_num_layers * config.lstm_num_direction, batch_size,
             config.lstm_num_hidden)
    return torch.zeros(*shape).to(device), torch.zeros(*shape).to(device)


def _truncate_at_eos(sentences, min_slots=0):
    """Cut each token sequence just before its first 'EOS' token.

    The result is padded with empty lists to at least ``min_slots`` entries,
    because callers index it by a fixed number of sample slots.
    """
    from itertools import takewhile

    pruned = [list(takewhile(lambda token: token != 'EOS', sentence))
              for sentence in sentences]
    pruned.extend([] for _ in range(min_slots - len(pruned)))
    return pruned


def _evaluate_vae(model, dataset, criterion, config, log_flag=None):
    """Evaluate ``model`` one example at a time on ``dataset``.

    Each example is decoded with ``config.importance_sampling_size`` latent
    samples. The softmax outputs are averaged over the samples to score the
    target tokens (NLL / perplexity) and to compute matches via
    ``compute_match_vae``.

    Args:
        model: the VAE; switched to eval mode here.
        dataset: sized sequence of examples accepted by ``prepare_example``.
        criterion: reconstruction loss applied to each sample's logits.
        config: run configuration (LSTM sizes, importance_sampling_size).
        log_flag: when not None, a per-example NLL line tagged with it is
            printed every 300 examples.

    Returns:
        ``(perplexity, elbo, accuracy, nll)``:
        - ``perplexity`` is exp(total NLL / total tokens).
        - ``elbo`` is the mean over examples of
          (sum of per-sample criterion + KL term returned by the model)
          / importance_sampling_size.
        - ``accuracy`` is sum of matches / total tokens.
        - ``nll`` is the summed per-example NLL.

    Raises:
        ValueError: if ``dataset`` is empty.
    """
    if len(dataset) == 0:
        raise ValueError('cannot evaluate on an empty dataset')

    k_samples = config.importance_sampling_size
    model.eval()

    nll_total = 0.0
    elbo_total = 0.0
    lengths = []
    nll_per_example = []
    match = []

    with torch.no_grad():
        for example_idx, sentence in enumerate(dataset):
            val_input, val_target = prepare_example(sentence, vocab)
            # One LSTM state row per importance sample.
            h_0, c_0 = _zero_lstm_state(config, k_samples)
            sent_len = val_input.size(1)
            lengths.append(sent_len)

            # decoder_output: (k, batch=1, sent_len, vocab_size) per the model's shape notes.
            decoder_output, kl_loss = model(val_input, h_0, c_0, [sent_len],
                                            k_samples)

            # prediction: (k, sent_len, vocab_size); prediction_mean averages over k.
            prediction = nn.functional.softmax(
                torch.squeeze(decoder_output, dim=1), dim=2)
            prediction_mean = torch.mean(prediction, 0)

            # NLL of the target tokens under the sample-averaged distribution;
            # row 0 of val_target is used because the target is shared by all k samples.
            example_nll = 0.0
            for j in range(prediction.shape[1]):
                example_nll -= torch.log(
                    prediction_mean[j][int(val_target[0][j])])
            nll_total += example_nll
            nll_per_example.append(example_nll)

            if log_flag is not None and example_idx % 300 == 0:
                print('    ppl_per_example at the ', example_idx, log_flag,
                      'case = ', example_nll)

            match.append(compute_match_vae(prediction_mean, val_target))

            # CrossEntropyLoss wants (batch, classes, length): swap the last
            # two dims, then score each of the k samples separately.
            logits = decoder_output.permute(0, 1, 3, 2)
            reconstruction_loss = 0
            for k in range(k_samples):
                reconstruction_loss += criterion(logits[k], val_target)
            elbo_total += (reconstruction_loss + kl_loss) / k_samples

    total_tokens = sum(lengths)
    perplexity = torch.exp(nll_total / total_tokens)
    accuracy = sum(match) / total_tokens
    # Normalise by the dataset actually evaluated (previously always len(val_data)).
    elbo = elbo_total / len(dataset)
    nll = sum(nll_per_example)
    return perplexity, elbo, accuracy, nll


def train(config):
    """Train the sentence VAE, then test, sample, interpolate and reconstruct.

    Relies on module-level ``vocab_size``, ``vocab``, ``device``,
    ``train_data``, ``val_data``, ``test_data`` and the helpers
    ``print_flags``, ``get_minibatch``, ``prepare_minibatch``,
    ``prepare_example`` and ``compute_match_vae``.

    Training behaviour:
    - Training cycles over ``train_data`` until exactly ``config.train_steps``
      optimizer steps have run.
    - Every ``config.eval_every`` steps the model is evaluated on
      ``val_data``. If the step is also a multiple of
      ``config.eval_every_train``, it is evaluated on ``train_data`` instead.
    - The validation evaluation with the lowest perplexity is saved to
      ./models/vae_best.pt.
    - Metric histories are written under ./np_saved_results/.

    After training:
    - Test metrics of the reloaded best model are appended to
      ./result/vae_test.txt.
    - Samples, interpolations and reconstructions are appended to the other
      ./result/ files.

    Args:
        config: run configuration (hyper-parameters and schedule).

    Returns:
        0

    Raises:
        ValueError: if ``get_minibatch`` yields no batches for ``train_data``.
    """
    separator = '+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-'

    # Print all configs to confirm parameter settings
    print_flags()

    model = VAE(vocabulary_size=vocab_size,
                dropout=1 - config.dropout_keep_prob,
                lstm_num_hidden=config.lstm_num_hidden,
                lstm_num_layers=config.lstm_num_layers,
                lstm_num_direction=config.lstm_num_direction,
                num_latent=config.num_latent,
                device=device)
    model.to(device)

    # Summed token cross-entropy; index 1 is ignored (presumably the padding
    # id — confirm against vocab).
    criterion = nn.CrossEntropyLoss(ignore_index=1, reduction='sum')
    optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)

    k_samples = config.importance_sampling_size

    # Metric histories.
    iteration = list()
    tmp_loss = list()
    train_loss = list()
    val_nll, val_perp, val_acc, val_elbo = [], [], [], []
    train_perp, train_acc, train_elbo, train_nll = [], [], [], []

    iter_i = 0
    best_perp = 1e6

    # Re-draw minibatches from train_data until train_steps is reached.
    while iter_i < config.train_steps:
        steps_before_pass = iter_i
        for train_batch in get_minibatch(train_data,
                                         batch_size=config.batch_size):
            iter_i += 1
            model.train()
            optimizer.zero_grad()

            inputs, targets, lengths_in_batch = prepare_minibatch(
                train_batch, vocab)
            h_0, c_0 = _zero_lstm_state(config, inputs.shape[0])
            decoder_output, KL_loss = model(inputs, h_0, c_0,
                                            lengths_in_batch, k_samples)

            # CrossEntropyLoss expects (batch, vocab, sent_len), so each
            # sample's (batch, sent_len, vocab) output is permuted first.
            reconstruction_loss = 0.0
            for k in range(k_samples):
                reconstruction_loss += criterion(
                    decoder_output[k].permute(0, 2, 1), targets)

            # Average both terms over the k latent samples.
            reconstruction_loss = reconstruction_loss / k_samples
            KL_loss = KL_loss / k_samples

            print('At iter', iter_i, ', rc_loss=', reconstruction_loss.item(),
                  ' KL_loss = ', KL_loss.item())

            total_loss = (reconstruction_loss + KL_loss) / config.batch_size
            tmp_loss.append(total_loss.item())
            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                           max_norm=config.max_norm)
            optimizer.step()

            if iter_i % config.eval_every == 0:
                eval_data = val_data
                eval_data_flag = 'val'
                print('Evaluating with validation at iteration ', iter_i,
                      '...')
                if iter_i % config.eval_every_train == 0:
                    eval_data = train_data
                    eval_data_flag = 'train'
                    print('Evaluating with training instead, at iteration ',
                          iter_i, '...')

                ppl_total, validation_elbo_loss, accuracy, nll = _evaluate_vae(
                    model, eval_data, criterion, config,
                    log_flag=eval_data_flag)
                print('ppl_total for iteration ', iter_i, ' =  ', ppl_total)
                print('accuracy for iteration ', iter_i, ' =  ', accuracy)

                # Mean training loss since the previous evaluation.
                avg_loss = sum(tmp_loss) / len(tmp_loss)
                tmp_loss = list()

                # Select the checkpoint on validation perplexity only, so a
                # training-set evaluation cannot overwrite the best model.
                if eval_data_flag == 'val' and ppl_total < best_perp:
                    best_perp = ppl_total
                    torch.save(model.state_dict(), "./models/vae_best.pt")

                print(
                    "[{}] Train Step {:04d}/{:04d}, "
                    "Validation Perplexity = {:.4f}, Validation loss ={:.4f}, Training Loss = {:.4f}, NLL = {:.4f}, "
                    "Validation Accuracy = {:.4f}".format(
                        datetime.now().strftime("%Y-%m-%d %H:%M"), iter_i,
                        config.train_steps, ppl_total, validation_elbo_loss,
                        avg_loss, nll, accuracy))

                # Update/save eval results every time.
                suffix = ['till_iter_' + str(iter_i)]
                iteration.append(iter_i)
                train_loss.append(avg_loss)
                np.save('./np_saved_results/train_loss.npy',
                        train_loss + suffix)

                if eval_data_flag == 'val':
                    val_perp.append(ppl_total.item())
                    val_acc.append(accuracy)
                    val_elbo.append(validation_elbo_loss.item())
                    val_nll.append(float(nll))
                    np.save('./np_saved_results/val_perp.npy', val_perp + suffix)
                    np.save('./np_saved_results/val_acc.npy', val_acc + suffix)
                    np.save('./np_saved_results/val_elbo.npy', val_elbo + suffix)
                    np.save('./np_saved_results/val_nll.npy', val_nll + suffix)
                else:
                    train_perp.append(ppl_total.item())
                    train_acc.append(accuracy)
                    train_elbo.append(validation_elbo_loss.item())
                    train_nll.append(float(nll))
                    np.save('./np_saved_results/train_perp.npy',
                            train_perp + suffix)
                    np.save('./np_saved_results/train_acc.npy',
                            train_acc + suffix)
                    np.save('./np_saved_results/train_elbo.npy',
                            train_elbo + suffix)
                    np.save('./np_saved_results/train_nll.npy',
                            train_nll + suffix)

            # Checked on every step, not only on evaluation steps; otherwise
            # the loop never ends when train_steps % eval_every != 0.
            if iter_i >= config.train_steps:
                break

        if iter_i == steps_before_pass:
            raise ValueError('get_minibatch produced no batches from train_data')

    print('Done training!')
    print(separator)

    print('Testing...')
    print(separator)
    model.load_state_dict(torch.load('./models/vae_best.pt'))
    ppl_total, validation_elbo_loss, accuracy, nll = _evaluate_vae(
        model, test_data, criterion, config)

    print('Test Perplexity on the best model is: {:.3f}'.format(ppl_total))
    print(
        'Test ELBO on the best model is: {:.3f}'.format(validation_elbo_loss))
    print('Test accuracy on the best model is: {:.3f}'.format(accuracy))
    print('Test NLL on the best model is: {:.3f}'.format(nll))
    print(separator)
    with open('./result/vae_test.txt', 'a') as result_file:
        result_file.write(
            'Learning Rate = {}, Train Step = {}, '
            'Dropout = {}, LSTM Layers = {}, '
            'Hidden Size = {}, Test Perplexity = {:.3f}, Test ELBO =  {:.3f}, Test NLL =  {:.3f}, '
            'Test Accuracy = {}\n'.format(config.learning_rate,
                                          config.train_steps,
                                          1 - config.dropout_keep_prob,
                                          config.lstm_num_layers,
                                          config.lstm_num_hidden, ppl_total,
                                          validation_elbo_loss, nll, accuracy))

    print('Sampling...')
    print(separator)

    # NOTE(review): this replaces the weights with ./models/vae_best_lisa.pt,
    # not the vae_best.pt checkpoint evaluated above — confirm this is intended.
    model.load_state_dict(
        torch.load('./models/vae_best_lisa.pt',
                   map_location=lambda storage, loc: storage))

    with torch.no_grad():
        sentences = model.sample(config.sample_size, vocab)
    samples = _truncate_at_eos(sentences, config.sample_size)

    with open('./result/vae_test_greedy_new.txt', 'a') as out:
        for idx, sen in enumerate(samples):
            if idx == 0:
                out.write('\n Greedy: \n')
            out.write('Sampling \n{}: {}\n'.format(idx, ' '.join(sen)))

    print('Interpolating...')
    print(separator)

    with torch.no_grad():
        sentences = model.interpolation(vocab)
    interpolations = _truncate_at_eos(sentences, 5)

    labels = ('z1', 'z2', 'z1+z2/2', 'z1*0.8+z2*0.2', 'z1*0.2+z2*0.8')
    with open('./result/vae_test_interpolate.txt', 'a') as out:
        out.write('\n Interpolation: \n')
        for label, sen in zip(labels, interpolations):
            out.write('Sampling {}:\n {}\n'.format(label, ' '.join(sen)))

    print('Test case reconstruction...')
    print(separator)
    test_sen = test_data[101]
    test_input, _ = prepare_example(test_sen, vocab)

    reconstructed_sentences = model.test_reconstruction(test_input, vocab)
    reconstructions = _truncate_at_eos(reconstructed_sentences, 10)

    with open('./result/vae_test_reconstruct.txt', 'a') as out:
        out.write('\n The sentence to reconstruct:\n {}\n'.format(' '.join(
            test_sen[1:])))
        for x in range(10):
            out.write('Sample: {} \n {}\n'.format(
                x, ' '.join(reconstructions[x])))

    # NOTE: plotting is disabled; the code below is kept only for reference.
    '''
  t_loss = plt.figure(figsize = (6, 4))
  plt.plot(iteration, train_loss)
  plt.xlabel('Iteration')
  plt.ylabel('Training Loss')
  t_loss.tight_layout()
  t_loss.savefig('./result/vae_training_loss.eps', format='eps')

  v_perp = plt.figure(figsize = (6, 4))
  plt.plot(iteration, val_perp)
  plt.xlabel('Iteration')
  plt.ylabel('Validation Perplexity')
  v_perp.tight_layout()
  v_perp.savefig('./result/vae_validation_perplexity.eps', format='eps')

  v_acc = plt.figure(figsize = (6, 4))
  plt.plot(iteration, val_acc)
  plt.xlabel('Iteration')
  plt.ylabel('Validation Accuracy')
  v_acc.tight_layout()
  v_acc.savefig('./result/vae_validation_accuracy.eps', format='eps')


  v_elbo = plt.figure(figsize = (6, 4))
  plt.plot(iteration, val_elbo)
  plt.xlabel('Iteration')
  plt.ylabel('Validation ELBO')
  v_elbo.tight_layout()
  v_elbo.savefig('./result/vae_validation_elbo.eps', format='eps')
  print('Figures are saved.')
  print('+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-')
  '''

    return 0
Esempio n. 4
0
    # Training setup for the face-image VAE: dataset, model and optimizer.
import torch

data_path = '/home/lilioo826/hw4_data/'
train_faceDataset = FaceDataset(data_path + 'train', data_path + 'train.csv',
                                transforms.ToTensor())
train_dataloader = DataLoader(train_faceDataset, batch_size=20, num_workers=1)

# Use the GPU only when one is available; model.cuda() fails on CPU-only hosts.
cuda = torch.cuda.is_available()
model = VAE(64, 1e-6)
if cuda:
    model.cuda()

epoch_num = 100
model.train()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Per-epoch KL / MSE histories (presumably filled by the training loop below).
klds = []
mses = []
for epoch in range(epoch_num):
    print('epoch {}'.format(epoch + 1))
    epoch_kld = 0
    epoch_mse = 0
    epoch_loss = 0
    for batch_idx, (data, label) in enumerate(train_dataloader):
        if cuda:
            data = data.cuda()
        data = Variable(data)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = model.loss_function(data, recon_batch, mu, logvar)