Code example #1
0
def evaluate_vae(args, model, train_loader, data_loader, epoch, dir, mode):
    """Run one evaluation pass of the VAE over ``data_loader``.

    Accumulates the batch-averaged loss terms (total loss, reconstruction
    error, KL, Fisher information, mutual information, PSNR) and divides by
    the number of batches.  In ``mode == 'validation'`` it also saves a 3x3
    reconstruction grid; in ``mode == 'test'`` it additionally plots
    reconstructions/generations/traversals, visualizes the latent space and
    estimates the test/train lower bounds.

    Returns:
        validation mode: ``(loss, re, kl, fi, mi, psnr)``
        test mode: ``(loss, re, kl, fi, mi, ll_test, ll_train,
        elbo_test, elbo_train, psnr)``

    NOTE(review): written against the pre-0.4 PyTorch API
    (``Variable(..., volatile=True)``, ``.data[0]`` scalar access) — confirm
    the installed torch version before reuse.
    """
    # running sums of the per-batch (already batch-averaged) metrics
    evaluate_loss = 0
    evaluate_re = 0
    evaluate_kl = 0
    evaluate_fi = 0
    evaluate_mi = 0
    evaluate_psnr = 0
    # set model to evaluation mode (freezes dropout / batch-norm statistics)
    model.eval()

    # evaluate
    for batch_idx, (data, target) in enumerate(data_loader):
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        # volatile=True disables graph construction during inference (pre-0.4)
        data, target = Variable(data, volatile=True), Variable(target)

        x = data

        # calculate loss function
        loss, RE, KL, FI, MI = model.calculate_loss(x, average=True)

        if args.FI is True:
            FI_, gamma = model.FI(x)
            loss += torch.mean(FI_ * gamma)

        if args.MI is True:
            loss -= args.ksi * torch.mean((MI - args.M).abs())

        if args.MI is not True:
            # MI was not part of the objective; recompute it for reporting only
            MI = model.MI(x)

        if args.FI is not True:
            # Fisher information likewise recomputed for reporting only
            FI, gamma = model.FI(x)
            FI = torch.mean(torch.exp(torch.mean(FI)))

        psnr = model.psnr(x)

        evaluate_loss += loss.data[0]
        evaluate_re += -RE.data[0]  # RE is a likelihood-style term; negate to report an error
        evaluate_kl += KL.data[0]
        evaluate_fi += FI.data[0]
        evaluate_mi += MI.abs().data[0]
        evaluate_psnr += psnr

        # dump a 3x3 grid of real + reconstructed digits once per validation epoch
        if batch_idx == 1 and mode == 'validation':
            if epoch == 1:
                if not os.path.exists(dir + 'reconstruction/'):
                    os.makedirs(dir + 'reconstruction/')
                # VISUALIZATION: plot real images
                plot_images(args,
                            data.data.cpu().numpy()[0:9],
                            dir + 'reconstruction/',
                            'real',
                            size_x=3,
                            size_y=3)
            x_mean = model.reconstruct_x(x)
            plot_images(args,
                        x_mean.data.cpu().numpy()[0:9],
                        dir + 'reconstruction/',
                        str(epoch),
                        size_x=3,
                        size_y=3)

    if mode == 'test':
        # load all data
        if args.dataset_name == 'celeba':
            # celeba datasets expose no tensor attributes; gather by iteration
            test_data = []
            test_target = []
            full_data = []
            for d, l in data_loader.dataset:
                test_data.append(d)
                test_target.append(l)

            for d, l in train_loader.dataset:
                full_data.append(d)

            test_data = Variable(torch.stack(test_data), volatile=True)
            test_target = Variable(torch.from_numpy(np.array(test_target)),
                                   volatile=True)
            # BUGFIX: the original stacked ``full_data[60000]`` (a single
            # sample) instead of slicing; cap the train set at 60000 samples.
            full_data = Variable(torch.stack(full_data[:60000]), volatile=True)

        else:
            # old-style TensorDataset exposes the whole tensors directly
            test_data = Variable(data_loader.dataset.data_tensor,
                                 volatile=True)
            test_target = Variable(data_loader.dataset.target_tensor,
                                   volatile=True)
            full_data = Variable(train_loader.dataset.data_tensor,
                                 volatile=True)

        if args.cuda:
            test_data, test_target, full_data = test_data.cuda(
            ), test_target.cuda(), full_data.cuda()

        # resample binary training data when dynamic binarization is enabled
        if args.dynamic_binarization:
            full_data = torch.bernoulli(full_data)

        # VISUALIZATION: reconstructions of the first 200 x 25 test images
        if not os.path.exists(dir + 'test_reconstruction/'):
            os.makedirs(dir + 'test_reconstruction/')
        for k in range(200):
            samples = model.reconstruct_x(test_data[k * 25:(k + 1) * 25])
            plot_images(args,
                        samples.data.cpu().numpy(),
                        dir,
                        'test_reconstruction/' + str(k),
                        size_x=5,
                        size_y=5)

        # VISUALIZATION: plot generations
        samples_rand = model.generate_x(25)

        plot_images(args,
                    samples_rand.data.cpu().numpy(),
                    dir,
                    'generations',
                    size_x=5,
                    size_y=5)

        # VISUALIZATION: latent traversals (only done for z dim > 10)
        if args.z1_size > 10:
            if args.dataset_name == 'celeba':
                # traverse the first 10 test images
                cnt = 0
                for j in range(len(test_target)):
                    if cnt == 10:
                        break
                    samples_rand = model.traversal(test_data[j])
                    plot_images(args,
                                samples_rand.data.cpu().numpy(),
                                dir,
                                'attributes' + str(cnt),
                                size_x=10,
                                size_y=10)
                    cnt += 1

            else:
                # traverse one example per label class 0..9
                for i in range(10):
                    for j in range(len(test_target)):
                        if test_target[j].data.cpu().numpy()[0] % 10 == i:
                            samples_rand = model.traversal(test_data[j])
                            plot_images(args,
                                        samples_rand.data.cpu().numpy(),
                                        dir,
                                        'attributes' + str(i),
                                        size_x=10,
                                        size_y=10)
                            break

        # VISUALIZATION: latent space scatter (not supported for celeba)
        if args.latent is True and args.dataset_name != 'celeba':
            z_mean_recon, z_logvar_recon, _ = model.q_z(
                test_data.view(-1, args.input_size[0], args.input_size[1],
                               args.input_size[2]))
            print("latent visualization")
            visualize_latent(
                z_mean_recon, test_target, dir + '/latent_' + args.model_name +
                '_' + args.model_signature)

        if args.z1_size == 2:
            # VISUALIZATION: plot low-dimensional manifold
            plot_manifold(model, args, dir)

        if args.prior == 'vampprior':
            # VISUALIZE pseudoinputs
            pseudoinputs = model.means(model.idle_input).cpu().data.numpy()

            plot_images(args,
                        pseudoinputs[0:25],
                        dir,
                        'pseudoinputs',
                        size_x=5,
                        size_y=5)

        # CALCULATE lower-bound
        t_ll_s = time.time()
        elbo_test = model.calculate_lower_bound(test_data, MB=args.MB)
        t_ll_e = time.time()
        print('Test lower-bound value {:.2f} in time: {:.2f}s'.format(
            elbo_test, t_ll_e - t_ll_s))

        # CALCULATE train lower-bound — disabled (too slow on the full set).
        # The original wrapped this constant assignment in a pointless
        # try/except with identical branches; removed.
        t_ll_s = time.time()
        elbo_train = 0.  #model.calculate_lower_bound(full_data, MB=args.MB)
        t_ll_e = time.time()
        print('Train lower-bound value {:.2f} in time: {:.2f}s'.format(
            elbo_train, t_ll_e - t_ll_s))

        # CALCULATE log-likelihood — disabled (importance sampling is slow)
        t_ll_s = time.time()
        log_likelihood_test = 0.  #model.calculate_likelihood(test_data, dir, mode='test', S=args.S, MB=args.MB)
        t_ll_e = time.time()
        print('Test log_likelihood value {:.2f} in time: {:.2f}s'.format(
            log_likelihood_test, t_ll_e - t_ll_s))

        # CALCULATE train log-likelihood — disabled (takes too much time)
        t_ll_s = time.time()
        log_likelihood_train = 0.  #model.calculate_likelihood(full_data, dir, mode='train', S=args.S, MB=args.MB)
        t_ll_e = time.time()
        print('Train log_likelihood value {:.2f} in time: {:.2f}s'.format(
            log_likelihood_train, t_ll_e - t_ll_s))

        # latent distribution statistics on the first S test samples
        t_ll_s = time.time()
        model.calculate_dist(test_data[:int(args.S)], dir, mode='test')
        t_ll_e = time.time()
        print('Test latent distribution in time: {:.2f}s'.format(t_ll_e -
                                                                 t_ll_s))

    # final averages; per-batch values are already averaged over batch size
    evaluate_loss /= len(data_loader)
    evaluate_re /= len(data_loader)
    evaluate_kl /= len(data_loader)
    evaluate_fi /= len(data_loader)
    evaluate_mi /= len(data_loader)
    evaluate_psnr /= len(data_loader)
    if mode == 'test':
        return evaluate_loss, evaluate_re, evaluate_kl, evaluate_fi, evaluate_mi, log_likelihood_test, log_likelihood_train, elbo_test, elbo_train, evaluate_psnr
    else:
        return evaluate_loss, evaluate_re, evaluate_kl, evaluate_fi, evaluate_mi, evaluate_psnr
Code example #2
0
def evaluate_vae(args, model, train_loader, data_loader, epoch, dir, mode):
    """Evaluate the VAE over ``data_loader`` and average the batch losses.

    In ``mode == 'validation'`` additionally saves a 3x3 grid of real and
    reconstructed images; in ``mode == 'test'`` plots reconstructions /
    generations / pseudoinputs and computes lower bounds and the test
    log-likelihood.

    Returns ``(loss, re, kl)`` for validation and
    ``(loss, re, kl, ll_test, ll_train, elbo_test, elbo_train)`` for test.

    NOTE(review): uses the pre-0.4 PyTorch API (``Variable(...,
    volatile=True)``, ``.data[0]``) and assumes ``*.dataset`` is an
    old-style ``TensorDataset`` exposing ``data_tensor``/``target_tensor``
    — confirm against the installed torch version.
    """
    # running sums of the (already batch-averaged) metrics
    evaluate_loss = 0
    evaluate_re = 0
    evaluate_kl = 0
    # set model to evaluation mode
    model.eval()

    # evaluate
    for batch_idx, (data, target) in enumerate(data_loader):
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        # volatile=True disables graph construction during inference (pre-0.4)
        data, target = Variable(data, volatile=True), Variable(target)

        x = data

        # calculate loss function
        loss, RE, KL = model.calculate_loss(x, average=True)

        evaluate_loss += loss.data[0]
        evaluate_re += -RE.data[0]  # negated: RE is a likelihood-style term
        evaluate_kl += KL.data[0]

        # plot a 3x3 digit grid for the second validation batch
        if batch_idx == 1 and mode == 'validation':
            if epoch == 1:
                if not os.path.exists(dir + 'reconstruction/'):
                    os.makedirs(dir + 'reconstruction/')
                # VISUALIZATION: plot real images
                plot_images(args, data.data.cpu().numpy()[0:9], dir + 'reconstruction/', 'real', size_x=3, size_y=3)
            x_mean = model.reconstruct_x(x)
            plot_images(args, x_mean.data.cpu().numpy()[0:9], dir + 'reconstruction/', str(epoch), size_x=3, size_y=3)

    if mode == 'test':
        # load all data (whole tensors, not batched)
        test_data = Variable(data_loader.dataset.data_tensor)
        test_target = Variable(data_loader.dataset.target_tensor)
        full_data = Variable(train_loader.dataset.data_tensor)

        if args.cuda:
            test_data, test_target, full_data = test_data.cuda(), test_target.cuda(), full_data.cuda()

        # resample binary training data when dynamic binarization is enabled
        if args.dynamic_binarization:
            full_data = torch.bernoulli(full_data)

        # VISUALIZATION: plot real images
        plot_images(args, test_data.data.cpu().numpy()[0:25], dir, 'real', size_x=5, size_y=5)

        # VISUALIZATION: plot reconstructions
        samples = model.reconstruct_x(test_data[0:25])

        plot_images(args, samples.data.cpu().numpy(), dir, 'reconstructions', size_x=5, size_y=5)

        # VISUALIZATION: plot generations
        samples_rand = model.generate_x(25)

        plot_images(args, samples_rand.data.cpu().numpy(), dir, 'generations', size_x=5, size_y=5)

        if args.prior == 'vampprior':
            # VISUALIZE pseudoinputs
            pseudoinputs = model.means(model.idle_input).cpu().data.numpy()

            plot_images(args, pseudoinputs[0:25], dir, 'pseudoinputs', size_x=5, size_y=5)

        # CALCULATE lower-bound on the test set
        t_ll_s = time.time()
        elbo_test = model.calculate_lower_bound(test_data, MB=args.MB)
        t_ll_e = time.time()
        print('Test lower-bound value {:.2f} in time: {:.2f}s'.format(elbo_test, t_ll_e - t_ll_s))

        # CALCULATE lower-bound on the train set
        t_ll_s = time.time()
        elbo_train = model.calculate_lower_bound(full_data, MB=args.MB)
        t_ll_e = time.time()
        print('Train lower-bound value {:.2f} in time: {:.2f}s'.format(elbo_train, t_ll_e - t_ll_s))

        # CALCULATE log-likelihood (importance-sampled with S samples)
        t_ll_s = time.time()
        log_likelihood_test = model.calculate_likelihood(test_data, dir, mode='test', S=args.S, MB=args.MB)
        t_ll_e = time.time()
        print('Test log_likelihood value {:.2f} in time: {:.2f}s'.format(log_likelihood_test, t_ll_e - t_ll_s))

        # CALCULATE train log-likelihood — disabled below (too slow)
        t_ll_s = time.time()
        log_likelihood_train = 0. #model.calculate_likelihood(full_data, dir, mode='train', S=args.S, MB=args.MB)) #commented because it takes too much time
        t_ll_e = time.time()
        print('Train log_likelihood value {:.2f} in time: {:.2f}s'.format(log_likelihood_train, t_ll_e - t_ll_s))

    # calculate final loss
    evaluate_loss /= len(data_loader)  # loss function already averages over batch size
    evaluate_re /= len(data_loader)  # re already averages over batch size
    evaluate_kl /= len(data_loader)  # kl already averages over batch size
    if mode == 'test':
        return evaluate_loss, evaluate_re, evaluate_kl, log_likelihood_test, log_likelihood_train, elbo_test, elbo_train
    else:
        return evaluate_loss, evaluate_re, evaluate_kl
Code example #3
0
File: evaluation.py  Project: sunsunyyl/SWAE
def evaluate_vae(args, model, train_loader, data_loader, epoch, results_dir,
                 mode):
    """Evaluate the VAE for one pass over ``data_loader``.

    Returns the batch-averaged ``(loss, re, kl, re_l2square)``.  In
    'validation' mode it periodically saves real/reconstructed image grids;
    in 'test' mode it additionally plots noisy-input reconstructions,
    generations, (for the vampprior) pseudo-inputs, and optionally runs the
    FID evaluation.
    """
    # running sums of the batch-averaged metrics
    total_loss = 0
    total_re = 0
    total_kl = 0
    total_l2 = 0
    # inference mode: freeze dropout / batch-norm statistics
    model.eval()

    # layout of every image grid: N columns x M rows
    if args.number_components_input >= 100:
        M, N = 10, 5
    else:
        N = min(8, args.number_components_input)
        M = max(1, args.number_components_input // N)

    for step, (batch, labels) in enumerate(data_loader):
        if args.cuda:
            batch, labels = batch.cuda(), labels.cuda()

        x = batch

        # batch-averaged loss terms
        loss, RE, KL = model.calculate_loss(x, average=True)

        # squared-L2 reconstruction loss, tracked separately
        l2 = model.recon_loss(x)

        total_loss += loss.item()
        total_re -= RE.item()  # RE is likelihood-style; negate to report an error
        total_kl += KL.item()
        total_l2 += l2.item()

        # dump image grids for the first validation batch
        if mode == 'validation' and step == 0:

            log_step = 50
            val_dir = results_dir + 'validation/'

            if epoch == 1:
                if not os.path.exists(val_dir):
                    os.makedirs(val_dir)
                # real inputs, saved once
                plot_images(args,
                            batch.data.cpu().numpy()[0:N * M],
                            val_dir,
                            'real_x',
                            size_x=N,
                            size_y=M)

            # reconstructions, saved every log_step epochs
            if epoch % log_step == 0:
                recon = model.reconstruct_x(x)
                plot_images(args,
                            recon.data.cpu().numpy()[0:N * M],
                            val_dir,
                            'recon_x_epoch' + str(epoch),
                            size_x=N,
                            size_y=M)

    if mode == 'test':
        # collect the whole test set by iterating the loader — there is no
        # standardized tensor_dataset attribute across pytorch datasets
        xs, ys = [], []
        for batch, labels in data_loader:
            xs.append(batch)
            ys.append(labels)

        test_data = torch.cat(xs, dim=0)
        test_target = torch.cat(ys, dim=0).squeeze()

        # reconstructions of inputs corrupted with increasing Gaussian noise
        for sigma in np.linspace(0.1, 0.5, num=5):
            noisy = test_data[0:N * M] + torch.randn(
                test_data[0:N * M].size()) * sigma
            if args.cuda:
                noisy = noisy.cuda()
            plot_images(args,
                        noisy.data.cpu().numpy(),
                        results_dir,
                        'real_noisy_x_' + str(round(sigma, 1)),
                        size_x=N,
                        size_y=M)

            recon_noisy = model.reconstruct_x(noisy)
            plot_images(args,
                        recon_noisy.data.cpu().numpy(),
                        results_dir,
                        'recon_noisy_x_' + str(round(sigma, 1)),
                        size_x=N,
                        size_y=M)

        if args.cuda:
            test_data, test_target = test_data.cuda(), test_target.cuda()

        if args.fid_score == True or args.fid_score_rec == True:
            evaluate_fid_score(args, model, results_dir, test_data)

        # real test images
        plot_images(args,
                    test_data.data.cpu().numpy()[0:N * M],
                    results_dir,
                    'real_x',
                    size_x=N,
                    size_y=M)

        # their reconstructions
        recon = model.reconstruct_x(test_data[0:N * M])
        plot_images(args,
                    recon.data.cpu().numpy(),
                    results_dir,
                    'recon_x',
                    size_x=N,
                    size_y=M)

        # samples from the prior
        generated = model.generate_x(N * M)
        plot_images(args,
                    generated.data.cpu().numpy(),
                    results_dir,
                    'gen_x',
                    size_x=N,
                    size_y=M)

        if args.prior == 'vampprior':
            # learned pseudo-inputs of the VampPrior
            pseudoinputs = model.means(model.idle_input).cpu().data.numpy()
            plot_images(args,
                        pseudoinputs[0:N * M],
                        results_dir,
                        'pseudoinputs',
                        size_x=N,
                        size_y=M)

    # final averages; per-batch losses are already averaged over batch size
    n_batches = len(data_loader)
    return (total_loss / n_batches, total_re / n_batches,
            total_kl / n_batches, total_l2 / n_batches)
Code example #4
0
def evaluate_vae(args, model, train_loader, data_loader, epoch, dir, mode):
    """Evaluate the VAE over ``data_loader`` and average the batch losses.

    Unlike the TensorDataset variant, this version gathers the full test and
    train sets by iterating the loaders.  In ``mode == 'validation'`` it
    saves a 3x3 grid of real and reconstructed images; in ``mode == 'test'``
    it plots reconstructions / generations / pseudoinputs and computes
    lower bounds and the test log-likelihood.

    Returns ``(loss, re, kl)`` for validation and
    ``(loss, re, kl, ll_test, ll_train, elbo_test, elbo_train)`` for test.

    NOTE(review): uses the pre-0.4 PyTorch API (``Variable(...,
    volatile=True)``, ``.data[0]``) — confirm the installed torch version.
    """
    # running sums of the (already batch-averaged) metrics
    evaluate_loss = 0
    evaluate_re = 0
    evaluate_kl = 0
    # set model to evaluation mode
    model.eval()

    # evaluate
    for batch_idx, (data, target) in enumerate(data_loader):
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        # volatile=True disables graph construction during inference (pre-0.4)
        data, target = Variable(data, volatile=True), Variable(target)

        x = data

        # calculate loss function
        loss, RE, KL = model.calculate_loss(x, average=True)

        evaluate_loss += loss.data[0]
        evaluate_re += -RE.data[0]  # negated: RE is a likelihood-style term
        evaluate_kl += KL.data[0]

        # plot a 3x3 digit grid for the second validation batch
        if batch_idx == 1 and mode == 'validation':
            if epoch == 1:
                if not os.path.exists(dir + 'reconstruction/'):
                    os.makedirs(dir + 'reconstruction/')
                # VISUALIZATION: plot real images
                plot_images(args, data.data.cpu().numpy()[0:9], dir + 'reconstruction/', 'real', size_x=3, size_y=3)
            x_mean = model.reconstruct_x(x)
            plot_images(args, x_mean.data.cpu().numpy()[0:9], dir + 'reconstruction/', str(epoch), size_x=3, size_y=3)

    if mode == 'test':
        # load all data
        # grab the test data by iterating over the loader
        # there is no standardized tensor_dataset member across pytorch datasets
        test_data, test_target = [], []
        for data, lbls in data_loader:
            test_data.append(data)
            test_target.append(lbls)

        test_data, test_target = [torch.cat(test_data, 0), torch.cat(test_target, 0).squeeze()]

        # grab the train data by iterating over the loader
        # there is no standardized tensor_dataset member across pytorch datasets
        full_data = []
        for data, _ in train_loader:
            full_data.append(data)

        full_data = torch.cat(full_data, 0)

        if args.cuda:
            test_data, test_target, full_data = test_data.cuda(), test_target.cuda(), full_data.cuda()

        # resample binary training data when dynamic binarization is enabled
        if args.dynamic_binarization:
            full_data = torch.bernoulli(full_data)

        # VISUALIZATION: plot real images
        plot_images(args, test_data.data.cpu().numpy()[0:25], dir, 'real', size_x=5, size_y=5)

        # VISUALIZATION: plot reconstructions
        samples = model.reconstruct_x(test_data[0:25])

        plot_images(args, samples.data.cpu().numpy(), dir, 'reconstructions', size_x=5, size_y=5)

        # VISUALIZATION: plot generations
        samples_rand = model.generate_x(25)

        plot_images(args, samples_rand.data.cpu().numpy(), dir, 'generations', size_x=5, size_y=5)

        if args.prior == 'vampprior':
            # VISUALIZE pseudoinputs
            pseudoinputs = model.means(model.idle_input).cpu().data.numpy()

            plot_images(args, pseudoinputs[0:25], dir, 'pseudoinputs', size_x=5, size_y=5)

        # CALCULATE lower-bound on the test set
        t_ll_s = time.time()
        elbo_test = model.calculate_lower_bound(test_data, MB=args.MB)
        t_ll_e = time.time()
        print('Test lower-bound value {:.2f} in time: {:.2f}s'.format(elbo_test, t_ll_e - t_ll_s))

        # CALCULATE lower-bound on the train set
        t_ll_s = time.time()
        elbo_train = model.calculate_lower_bound(full_data, MB=args.MB)
        t_ll_e = time.time()
        print('Train lower-bound value {:.2f} in time: {:.2f}s'.format(elbo_train, t_ll_e - t_ll_s))

        # CALCULATE log-likelihood (importance-sampled with S samples)
        t_ll_s = time.time()
        log_likelihood_test = model.calculate_likelihood(test_data, dir, mode='test', S=args.S, MB=args.MB)
        t_ll_e = time.time()
        print('Test log_likelihood value {:.2f} in time: {:.2f}s'.format(log_likelihood_test, t_ll_e - t_ll_s))

        # CALCULATE train log-likelihood — disabled below (too slow)
        t_ll_s = time.time()
        log_likelihood_train = 0. #model.calculate_likelihood(full_data, dir, mode='train', S=args.S, MB=args.MB)) #commented because it takes too much time
        t_ll_e = time.time()
        print('Train log_likelihood value {:.2f} in time: {:.2f}s'.format(log_likelihood_train, t_ll_e - t_ll_s))

    # calculate final loss
    evaluate_loss /= len(data_loader)  # loss function already averages over batch size
    evaluate_re /= len(data_loader)  # re already averages over batch size
    evaluate_kl /= len(data_loader)  # kl already averages over batch size
    if mode == 'test':
        return evaluate_loss, evaluate_re, evaluate_kl, log_likelihood_test, log_likelihood_train, elbo_test, elbo_train
    else:
        return evaluate_loss, evaluate_re, evaluate_kl
Code example #5
0
File: evaluation.py  Project: sunsunyyl/SWAE
def evaluate_wae(args, model, train_loader, data_loader, epoch, results_dir,
                 mode):
    # set loss to 0
    evaluate_loss = 0
    evaluate_recon = 0
    evaluate_z = 0
    evaluate_x = 0
    evaluate_mi = 0
    # set model to evaluation mode
    model.eval()

    if args.number_components_input >= 100:

        M = 10
        N = 5
    else:
        N = min(8, args.number_components_input)
        M = max(1, args.number_components_input // N)

    # no warmup assumed
    if args.warmup == 0:
        beta = args.beta

    # evaluate
    for batch_idx, (data, target) in enumerate(data_loader):
        if args.cuda:
            data, target = data.cuda(), target.cuda()

        x = data

        # calculate loss function
        loss, recon_loss, z_loss, x_loss, mi, _, _, _, _, _, _ = model.calculate_loss(
            x, beta, average=True)

        evaluate_loss += loss.item()
        evaluate_recon += recon_loss.item()
        evaluate_z += z_loss.item()
        evaluate_x += x_loss.item()

        #print N digits
        if batch_idx == 0 and mode == 'validation':

            log_step = 50

            if epoch == 1:
                if not os.path.exists(results_dir + 'validation/'):
                    os.makedirs(results_dir + 'validation/')
                # VISUALIZATION: plot real images
                plot_images(args,
                            data.data.cpu().numpy()[0:N * M],
                            results_dir + 'validation/',
                            'real_x',
                            size_x=N,
                            size_y=M)

            if epoch % log_step == 0:
                x_recon = model.reconstruct_x(x)
                plot_images(args,
                            x_recon.data.cpu().numpy()[0:N * M],
                            results_dir + 'validation/',
                            'recon_x_epoch' + str(epoch),
                            size_x=N,
                            size_y=M)

            if args.prior == 'x_prior' and epoch % log_step == 0:
                # VISUALIZE pseudoinputs
                pseudoinputs = model.means(model.idle_input).cpu().data.numpy()

                plot_images(args,
                            pseudoinputs[0:N * M],
                            results_dir + 'validation/',
                            'pseu_x' + '_K' +
                            str(args.number_components_input) + '_L' +
                            str(args.number_components_latent) + '_epoch' +
                            str(epoch),
                            size_x=N,
                            size_y=M)

    if mode == 'test':
        # load all data
        # grab the test data by iterating over the loader
        # there is no standardized tensor_dataset member across pytorch datasets
        test_data, test_target = [], []
        for data, lbls in data_loader:
            test_data.append(data)
            test_target.append(lbls)

        test_data, test_target = [
            torch.cat(test_data, dim=0),
            torch.cat(test_target, dim=0).squeeze()
        ]

        #test noisy input
        for i in np.linspace(0.1, 0.5, num=5):
            test_noisy_data = test_data[0:N * M] + torch.randn(
                test_data[0:N * M].size()) * i
            if args.cuda:
                test_noisy_data = test_noisy_data.cuda()
            plot_images(args,
                        test_noisy_data.data.cpu().numpy(),
                        results_dir,
                        'real_noisy_x_' + str(round(i, 1)),
                        size_x=N,
                        size_y=M)

            noisy_samples = model.reconstruct_x(test_noisy_data)
            plot_images(args,
                        noisy_samples.data.cpu().numpy(),
                        results_dir,
                        'recon_noisy_x_' + str(round(i, 1)),
                        size_x=N,
                        size_y=M)

        if args.cuda:
            test_data, test_target = test_data.cuda(), test_target.cuda()

        if args.fid_score == True or args.fid_score_rec == True:
            evaluate_fid_score(args, model, results_dir, test_data)

        # VISUALIZATION: plot real images

        plot_images(args,
                    test_data.data.cpu().numpy()[0:N * M],
                    results_dir,
                    'real_x',
                    size_x=N,
                    size_y=M)

        # VISUALIZATION: plot reconstructions
        samples = model.reconstruct_x(test_data[0:N * M])

        plot_images(args,
                    samples.data.cpu().numpy(),
                    results_dir,
                    'recon_x',
                    size_x=N,
                    size_y=M)

        # VISUALIZATION: plot generations
        samples_rand = model.generate_x(N * M)
        plot_images(args,
                    samples_rand.data.cpu().numpy(),
                    results_dir,
                    'gen_x',
                    size_x=N,
                    size_y=M)

        if args.interpolation_test == True:
            samples_z, _, _, _ = model.generate_z_prior(10)

            interp_x = np.empty([0, np.prod(args.input_size)])
            print('samples_z', samples_z[0], samples_z[1])

            for i in np.arange(0, 1.1, 0.1):

                interp_z = (1 - i) * samples_z[0] + i * samples_z[1]

                recon_x = model.decoder(interp_z).data.cpu().numpy()
                recon_x = recon_x.reshape((1, np.prod(args.input_size)))

                interp_x = np.append(interp_x, recon_x, axis=0)
            plot_images(args,
                        interp_x,
                        results_dir,
                        'interpolate_z',
                        size_x=1,
                        size_y=len(np.arange(0, 1.1, 0.1)))

    # calculate final loss
    evaluate_loss /= len(
        data_loader)  # loss function already averages over batch size
    evaluate_recon /= len(
        data_loader)  # recon already averages over batch size
    evaluate_z /= len(data_loader)  # z already averages over batch size
    evaluate_x /= len(data_loader)

    return evaluate_loss, evaluate_recon, evaluate_z, evaluate_x, evaluate_mi