コード例 #1
0
ファイル: vae_tc.py プロジェクト: admshumar/mnist-vae
    def plot_results(self, models):
        """Plot the encoded test set and a grid of decoded latent samples.

        Two PNG files are written to ``self.image_directory`` (created if it
        does not exist):

        * ``vae_mean.png``: scatter plot of the first two coordinates of the
          encoded test means, coloured by ``self.y_test``.
        * ``digits_over_latent.png``: a 30x30 mosaic of 28x28 images decoded
          from evenly spaced 2D latent points.

        Each figure is displayed when ``self.show`` is true and is closed
        after saving, so repeated calls do not accumulate open figures.

        # Arguments
            models (tuple): ``(encoder, decoder)`` models. The encoder is fed
                ``[gaussian_parameters, self.x_test]`` and its first output is
                split into latent means and covariances; the decoder is fed
                ``[gaussian_parameters, z]`` and its second output is reshaped
                into the decoded images.
        """
        encoder, decoder = models
        os.makedirs(self.image_directory, exist_ok=True)

        # --- Latent means of the test set, coloured by label. ---
        test_gaussian = operations.get_gaussian_parameters(self.x_test)
        z_gaussian, _ = encoder.predict([test_gaussian, self.x_test],
                                        batch_size=self.batch_size)
        z_mean, _ = operations.split_gaussian_parameters(z_gaussian)

        plt.figure(figsize=(12, 10))
        plt.scatter(z_mean[:, 0], z_mean[:, 1], c=self.y_test)
        plt.colorbar()
        plt.xlabel("z[0]")
        plt.ylabel("z[1]")
        plt.savefig(os.path.join(self.image_directory, "vae_mean.png"))
        if self.show:
            plt.show()
        plt.close()

        # --- 30x30 manifold of digits decoded from a regular latent grid. ---
        n = 30
        digit_size = 28
        # Linearly spaced coordinates matching the scatter plot above; the y
        # grid is reversed so that larger z[1] values are drawn at the top.
        grid_x = np.linspace(-4, 4, n)
        grid_y = np.linspace(-4.5, 3.5, n)[::-1]

        # Decode the whole grid in one predict() call instead of n*n calls.
        # Rows follow grid_y and columns follow grid_x, so sample i*n + j is
        # the latent point (grid_x[j], grid_y[i]) that belongs in tile (i, j).
        xs, ys = np.meshgrid(grid_x, grid_y)
        z_samples = np.column_stack([xs.ravel(), ys.ravel()])
        # Fixed Gaussian parameters fed alongside every sample (presumably
        # zero means followed by unit variances for a 2D latent space --
        # confirm against operations.split_gaussian_parameters).
        dummy_gaussians = np.tile([[0, 0, 1, 1]], (len(z_samples), 1))
        x_decoded = decoder.predict([dummy_gaussians, z_samples])
        digits = np.asarray(x_decoded[1]).reshape(n, n, digit_size,
                                                  digit_size)
        # (row, col, y, x) -> (row, y, col, x) -> one (n*size, n*size) mosaic.
        figure = digits.transpose(0, 2, 1, 3).reshape(n * digit_size,
                                                      n * digit_size)

        plt.figure(figsize=(10, 10))
        # One tick at the centre of every tile, labelled with its coordinate.
        start_range = digit_size // 2
        end_range = (n - 1) * digit_size + start_range + 1
        pixel_range = np.arange(start_range, end_range, digit_size)
        plt.xticks(pixel_range, np.round(grid_x, 1))
        plt.yticks(pixel_range, np.round(grid_y, 1))
        plt.xlabel("z[0]")
        plt.ylabel("z[1]")
        plt.imshow(figure, cmap='Greys_r')
        plt.savefig(os.path.join(self.image_directory,
                                 "digits_over_latent.png"))
        if self.show:
            plt.show()
        plt.close()
コード例 #2
0
    def __init__(self,
                 deep=True,
                 enable_activation=True,
                 enable_augmentation=False,
                 enable_batch_normalization=True,
                 enable_dropout=True,
                 enable_early_stopping=False,
                 early_stopping_patience=10,
                 enable_lr_reduction=False,
                 lr_reduction_patience=10,
                 enable_logging=True,
                 enable_manual_clusters=False,
                 enable_label_smoothing=False,
                 enable_rotations=False,
                 enable_stochastic_gradient_descent=False,
                 has_custom_layers=True,
                 has_validation_set=False,
                 validation_size=0.5,
                 is_mnist=True,
                 is_restricted=False,
                 is_standardized=False,
                 show=False,
                 with_mixture_model=False,
                 with_logistic_regression=False,
                 with_svc=False,
                 number_of_clusters=3,
                 restriction_labels=None,
                 intermediate_dimension=512,
                 latent_dimension=2,
                 exponent_of_latent_space_dimension=1,
                 augmentation_size=100,
                 covariance_coefficient=0.2,
                 number_of_epochs=5,
                 batch_size=128,
                 learning_rate_initial=1e-5,
                 learning_rate_minimum=1e-6,
                 dropout_rate=0.5,
                 l2_constant=1e-4,
                 early_stopping_delta=0.01,
                 beta=1,
                 smoothing_alpha=0.5,
                 number_of_rotations=11,
                 angle_of_rotation=30,
                 encoder_activation='relu',
                 decoder_activation='relu',
                 encoder_activation_layer=None,
                 decoder_activation_layer=None,
                 final_activation='sigmoid',
                 model_name='vae'):
        """Load the data set and configure hyperparameters, output paths and callbacks.

        Data is only loaded when ``is_mnist`` is true. The code after the
        loading step reads ``self.x_train``/``self.x_test`` unconditionally.

        :param restriction_labels: Class labels kept when ``is_restricted`` is
            true; also selects which digits are loaded for rotated MNIST.
            ``None`` (the default) means ``list(range(10))``.
        :param has_validation_set: Load MNIST with a separate validation split
            (``VAE.get_split_mnist_data``). Early stopping is only enabled when
            this is true.
        :param enable_rotations: Replace the loaded data with rotated MNIST
            (``number_of_rotations`` rotations of ``angle_of_rotation``).
        :param enable_label_smoothing: Also build uniformly smoothed labels
            (``smoothing_alpha``) for every split.
        :param enable_manual_clusters: Use ``number_of_clusters`` instead of the
            number of classes present in the data.
        :param enable_stochastic_gradient_descent: Train with ``batch_size``;
            otherwise the whole training set is one batch.
        :param beta: Stored as ``max(beta, 1)``.
        :param encoder_activation_layer: Stored as
            ``self.encoder_activation_layer``. ``None`` (the default) creates a
            fresh ``LeakyReLU()`` for this instance.
        :param decoder_activation_layer: As ``encoder_activation_layer``, for
            the decoder.
        :param model_name: Name used for the output directory.
        """
        # Default argument values are evaluated once, at definition time, so
        # list/layer defaults would be single objects shared by every
        # instance. Build them per call instead.
        if restriction_labels is None:
            restriction_labels = list(range(10))
        if encoder_activation_layer is None:
            encoder_activation_layer = LeakyReLU()
        if decoder_activation_layer is None:
            decoder_activation_layer = LeakyReLU()

        self.model_name = model_name
        self.enable_logging = enable_logging
        self.enable_label_smoothing = enable_label_smoothing
        self.deep = deep
        self.is_mnist = is_mnist
        self.is_restricted = is_restricted
        self.restriction_labels = restriction_labels
        # The early-stopping callback monitors 'val_loss', which only exists
        # when a validation set is used.
        self.enable_early_stopping = enable_early_stopping and has_validation_set
        self.enable_rotations = enable_rotations
        self.number_of_rotations = number_of_rotations
        self.angle_of_rotation = angle_of_rotation
        self.with_mixture_model = with_mixture_model
        self.with_logistic_regression = with_logistic_regression
        self.with_svc = with_svc
        self.alpha = smoothing_alpha
        self.validation_size = validation_size

        self.is_standardized = is_standardized
        self.enable_stochastic_gradient_descent = enable_stochastic_gradient_descent
        self.has_custom_layers = has_custom_layers
        self.exponent_of_latent_space_dimension = exponent_of_latent_space_dimension
        self.enable_augmentation = enable_augmentation
        self.augmentation_size = augmentation_size
        self.covariance_coefficient = covariance_coefficient
        self.show = show
        self.early_stopping_patience = early_stopping_patience
        self.enable_lr_reduction = enable_lr_reduction
        self.lr_reduction_patience = lr_reduction_patience

        self.has_validation_set = has_validation_set
        if is_mnist:
            if has_validation_set:
                x_train, y_train, x_val, y_val, x_test, y_test = VAE.get_split_mnist_data(
                )

                if is_restricted:
                    x_train, y_train = operations.restrict_data_by_label(
                        x_train, y_train, restriction_labels)
                    x_val, y_val = operations.restrict_data_by_label(
                        x_val, y_val, restriction_labels)
                    x_test, y_test = operations.restrict_data_by_label(
                        x_test, y_test, restriction_labels)

                if enable_rotations:
                    print("Rotations enabled!")
                    # Rotated data replaces the split loaded above; it is
                    # restricted to the requested labels, or to all digits.
                    rotation_labels = (restriction_labels
                                       if is_restricted else list(range(10)))
                    (x_train, y_train,
                     x_val, y_val,
                     x_test, y_test) = VAE.get_split_rotated_mnist_data(
                         rotation_labels, number_of_rotations,
                         angle_of_rotation)

                self.x_train, self.y_train, self.x_val, self.y_val, self.x_test, self.y_test \
                    = x_train, y_train, x_val, y_val, x_test, y_test

                self.y_train_binary = OneHotEncoder(y_train).encode()
                self.y_val_binary = OneHotEncoder(y_val).encode()
                self.y_test_binary = OneHotEncoder(y_test).encode()

                if enable_label_smoothing:
                    self.y_train_smooth = labels.Smoother(
                        y_train, alpha=smoothing_alpha).smooth_uniform()
                    self.y_val_smooth = labels.Smoother(
                        y_val, alpha=smoothing_alpha).smooth_uniform()
                    self.y_test_smooth = labels.Smoother(
                        y_test, alpha=smoothing_alpha).smooth_uniform()

            else:
                (x_train, y_train), (x_test, y_test) = mnist.load_data()

                if is_restricted:
                    x_train, y_train = operations.restrict_data_by_label(
                        x_train, y_train, restriction_labels)
                    x_test, y_test = operations.restrict_data_by_label(
                        x_test, y_test, restriction_labels)

                if enable_rotations:
                    print("Rotations enabled!")
                    # Rotated images and their labels are loaded separately
                    # and then shuffled together.
                    x_train = MNISTLoader('train').load(
                        restriction_labels, number_of_rotations,
                        angle_of_rotation)
                    y_train = MNISTLoader('train').load(restriction_labels,
                                                        number_of_rotations,
                                                        angle_of_rotation,
                                                        label=True)
                    x_train, y_train = VAE.shuffle(x_train, y_train)

                    x_test = MNISTLoader('test').load(restriction_labels,
                                                      number_of_rotations,
                                                      angle_of_rotation)
                    y_test = MNISTLoader('test').load(restriction_labels,
                                                      number_of_rotations,
                                                      angle_of_rotation,
                                                      label=True)
                    x_test, y_test = VAE.shuffle(x_test, y_test)

                self.x_train, self.y_train, self.x_test, self.y_test = x_train, y_train, x_test, y_test

                self.y_train_binary = OneHotEncoder(y_train).encode()
                self.y_test_binary = OneHotEncoder(y_test).encode()

                if enable_label_smoothing:
                    self.y_train_smooth = labels.Smoother(
                        y_train, alpha=smoothing_alpha).smooth_uniform()
                    self.y_test_smooth = labels.Smoother(
                        y_test, alpha=smoothing_alpha).smooth_uniform()

            if is_restricted:
                self.number_of_clusters = len(restriction_labels)
            else:
                self.number_of_clusters = len(np.unique(y_train))

            self.enable_manual_clusters = enable_manual_clusters
            if enable_manual_clusters:
                self.number_of_clusters = number_of_clusters

            self.data_width = self.x_train.shape[1]
            self.data_height = self.x_train.shape[2]
            self.data_dimension = self.data_width * self.data_height
            self.intermediate_dimension = intermediate_dimension

            self.x_train = operations.normalize(self.x_train)
            self.x_test = operations.normalize(self.x_test)
            if has_validation_set:
                self.x_val = operations.normalize(self.x_val)

            self.gaussian_train = operations.get_gaussian_parameters(
                self.x_train)
            self.gaussian_test = operations.get_gaussian_parameters(
                self.x_test)

        self.x_train_length = len(self.x_train)
        self.x_test_length = len(self.x_test)

        # Hyperparameters for the neural network.
        self.number_of_epochs = number_of_epochs

        if self.enable_stochastic_gradient_descent:
            self.batch_size = batch_size
        else:
            # Full-batch training: one batch holds the whole training set.
            self.batch_size = len(self.x_train)

        self.learning_rate = learning_rate_initial
        self.learning_rate_minimum = learning_rate_minimum
        self.enable_batch_normalization = enable_batch_normalization
        self.enable_dropout = enable_dropout
        self.enable_activation = enable_activation
        self.encoder_activation = encoder_activation  # 'relu', 'tanh', 'elu', 'softmax', 'sigmoid'
        self.decoder_activation = decoder_activation
        self.encoder_activation_layer = encoder_activation_layer
        self.decoder_activation_layer = decoder_activation_layer
        self.final_activation = final_activation
        self.dropout_rate = dropout_rate
        self.l2_constant = l2_constant
        self.early_stopping_delta = early_stopping_delta

        self.latent_dimension = latent_dimension
        # Means and variances are concatenated, hence twice the latent size.
        self.gaussian_dimension = 2 * self.latent_dimension

        self.beta = max(beta, 1)

        # The hyperparameters below are joined into the output directory name.
        self.hyper_parameter_list = [
            self.number_of_epochs, self.batch_size, self.learning_rate,
            self.encoder_activation, self.decoder_activation,
            self.enable_batch_normalization, self.enable_dropout,
            self.dropout_rate, self.l2_constant, self.early_stopping_patience,
            self.early_stopping_delta, self.latent_dimension
        ]

        if self.is_mnist:
            self.hyper_parameter_list.append("mnist")

        if self.is_restricted:
            restriction_string = ','.join(
                str(number) for number in restriction_labels)
            self.hyper_parameter_list.append(
                f"restricted_{restriction_string}")

        if self.enable_augmentation:
            augmentation_string = "_".join([
                "augmented",
                str(covariance_coefficient),
                str(augmentation_size)
            ])
            self.hyper_parameter_list.append(augmentation_string)

        if not self.enable_activation:
            self.hyper_parameter_list.append("PCA")

        if self.enable_rotations:
            self.hyper_parameter_list.append(
                f"rotated_{number_of_rotations},{angle_of_rotation}")

        if beta > 1:
            self.hyper_parameter_list.append(f"beta_{beta}")

        # NOTE(review): with the default smoothing_alpha=0.5 this tag is never
        # added; confirm whether it should depend on enable_label_smoothing.
        if smoothing_alpha > 1:
            self.hyper_parameter_list.append(f"alpha_{smoothing_alpha}")

        self.hyper_parameter_string = '_'.join(
            [str(i) for i in self.hyper_parameter_list])

        # Suffix the name with a run counter so repeated runs with identical
        # hyperparameters get distinct directories.
        self.directory_counter = directories.DirectoryCounter(
            self.hyper_parameter_string)
        self.directory_number = self.directory_counter.count()
        self.hyper_parameter_string = '_'.join([
            self.hyper_parameter_string,
            'x{:02d}'.format(self.directory_number)
        ])

        directory, image_directory = directories.DirectoryCounter.make_output_directory(
            self.hyper_parameter_string, self.model_name)
        self.experiment_directory = directory
        self.image_directory = image_directory

        # Tensorflow Input instances for declaring model inputs.
        self.mnist_shape = self.x_train.shape[1:]
        self.gaussian_shape = 2 * self.latent_dimension
        self.encoder_gaussian = Input(shape=self.gaussian_shape,
                                      name='enc_gaussian')
        self.encoder_mnist_input = Input(shape=self.mnist_shape,
                                         name='enc_mnist')
        self.auto_encoder_gaussian = Input(shape=self.gaussian_shape,
                                           name='ae_gaussian')
        self.auto_encoder_mnist_input = Input(shape=self.mnist_shape,
                                              name='ae_mnist')

        # Training callbacks: TensorBoard logging, early stopping, learning
        # rate reduction on plateau, and termination on NaN loss.
        self.tensorboard_callback = TensorBoard(log_dir=os.path.join(
            self.experiment_directory, 'tensorboard_logs'),
                                                histogram_freq=1,
                                                write_graph=False,
                                                write_images=True)

        self.early_stopping_callback = EarlyStopping(
            monitor='val_loss',
            min_delta=self.early_stopping_delta,
            patience=self.early_stopping_patience,
            mode='auto',
            restore_best_weights=True)

        self.learning_rate_callback = ReduceLROnPlateau(
            monitor='val_loss',
            factor=0.1,
            patience=self.lr_reduction_patience,
            min_lr=self.learning_rate_minimum)

        self.nan_termination_callback = TerminateOnNaN()

        self.colors = ['#00B7BA', '#FFB86F', '#5E6572', '#6B0504', '#BA5C12']
コード例 #3
0
ファイル: vae_tc.py プロジェクト: admshumar/mnist-vae
    def __init__(
            self,
            deep=True,
            is_mnist=True,
            number_of_clusters=3,
            is_restricted=False,
            is_standardized=False,
            restriction_labels=None,
            intermediate_dimension=512,
            enable_stochastic_gradient_descent=False,
            has_custom_layers=True,
            has_validation_set=False,
            exponent_of_latent_space_dimension=1,
            enable_augmentation=False,
            augmentation_size=100,
            covariance_coefficient=0.2,
            show=False,
            number_of_epochs=5,
            batch_size=128,
            learning_rate_initial=1e-5,
            learning_rate_minimum=1e-6,
            enable_batch_normalization=True,
            enable_dropout=True,
            enable_activation=True,
            encoder_activation='relu',  # 'relu', 'tanh', 'elu', 'softmax', 'sigmoid'
            decoder_activation='relu',
            final_activation='sigmoid',
            dropout_rate=0.2,
            l2_constant=1e-4,
            early_stopping_delta=1,
            beta=1,
            enable_logging=True):
        """
        Configure an MNIST variational autoencoder: load and normalize the data, set the network
        hyperparameters, create the output directory, and build the model inputs and training callbacks.
        :param deep: A boolean indicating whether the autoencoder has more than one hidden layer.
        :param is_mnist: A boolean indicating whether the data set is MNIST. Data is only loaded when this is True.
        :param number_of_clusters: An integer indicating the number of clusters to be produced by clustering
            algorithms. Ignored when is_restricted is True; len(restriction_labels) is used instead.
        :param is_restricted: A boolean indicating whether at least one class label is to be ignored.
        :param is_standardized: A boolean indicating whether the training and test sets are standardized before
            being input into the network.
        :param restriction_labels: A list of integers that indicate the class labels to be retained in the data set.
            None (the default) means [1, 2, 3].
        :param intermediate_dimension: An integer indicating the width of the intermediate layers.
        :param enable_stochastic_gradient_descent: A boolean indicating whether SGD is performed during training.
            When False, the batch size is the size of the whole training set.
        :param has_custom_layers: A boolean indicating the layer structure of the network.
        :param has_validation_set: A boolean indicating whether MNIST is loaded with a separate validation split.
        :param exponent_of_latent_space_dimension: An integer stored on the instance; the latent dimension used
            here is fixed at 2.
        :param enable_augmentation: A boolean indicating whether data augmentation is to be performed.
        :param augmentation_size: An integer indicating how much data are to be sampled for each existing data point.
        :param covariance_coefficient: A float indicating the scalar multiple of the identity covariance matrix for the
            Gaussians that are used to augment the data.
        :param show: A boolean indicating whether matplotlib.pyplot.show is invoked after inference.
            By default this is False.
        :param number_of_epochs: An integer indicating the number of training epochs. Early stopping patience is
            number_of_epochs // 5.
        :param batch_size: An integer indicating the batch size (used only with stochastic gradient descent).
        :param learning_rate_initial: A float indicating the initial learning rate.
        :param learning_rate_minimum: A float indicating the minimum learning rate (for a learning rate scheduler).
        :param enable_batch_normalization: A boolean indicating whether batch normalization is performed.
        :param enable_dropout: A boolean indicating whether dropout is performed during training.
        :param enable_activation: A boolean indicating whether activation functions are used during training. In the
            case of an autoencoder, removing network activations will give us an algorithm similar to PCA.
        :param encoder_activation: A string naming the activation function to be used in the encoder layers.
        :param decoder_activation: A string naming the activation function to be used in the decoder layers.
        :param final_activation: A string naming the activation function of the output layer.
        :param dropout_rate: A float indicating the proportion of neurons to be deactivated.
        :param l2_constant: A float indicating the amount of L2 regularization.
        :param early_stopping_delta: A float indicating the minimum change in the validation loss that counts as an
            improvement for early stopping.
        :param beta: A float indicating the beta hyperparameter for a beta-variational autoencoder. Default is 1;
            values below 1 are raised to 1, and values above 1 set enable_beta.
        :param enable_logging: A boolean stored on the instance to control logging.
        """
        # A list default would be one object shared by every instance.
        if restriction_labels is None:
            restriction_labels = [1, 2, 3]

        self.model_name = "vae_tc"
        self.enable_logging = enable_logging
        self.deep = deep
        self.is_mnist = is_mnist
        self.is_restricted = is_restricted
        self.restriction_labels = restriction_labels

        if self.is_restricted:
            self.number_of_clusters = len(self.restriction_labels)
        else:
            self.number_of_clusters = number_of_clusters

        self.is_standardized = is_standardized
        self.enable_stochastic_gradient_descent = enable_stochastic_gradient_descent
        self.has_custom_layers = has_custom_layers
        self.exponent_of_latent_space_dimension = exponent_of_latent_space_dimension
        self.enable_augmentation = enable_augmentation
        self.augmentation_size = augmentation_size
        self.covariance_coefficient = covariance_coefficient
        self.show = show

        self.has_validation_set = has_validation_set
        if self.is_mnist:
            if self.has_validation_set:
                (self.x_train, self.y_train,
                 self.x_val, self.y_val,
                 self.x_test, self.y_test) = TCVAE.get_split_mnist_data()
            else:
                (self.x_train,
                 self.y_train), (self.x_test, self.y_test) = mnist.load_data()

            self.data_width = self.x_train.shape[1]
            self.data_height = self.x_train.shape[2]
            self.data_dimension = self.data_width * self.data_height
            self.intermediate_dimension = intermediate_dimension

            self.x_train = operations.normalize(self.x_train)
            if self.has_validation_set:
                # x_val only exists when a validation split was loaded above;
                # normalizing it unconditionally raised AttributeError.
                self.x_val = operations.normalize(self.x_val)
            self.x_test = operations.normalize(self.x_test)

            self.gaussian_train = operations.get_gaussian_parameters(
                self.x_train)
            # NOTE(review): named gaussian_val but computed from the test set;
            # kept as-is because callers may pair it with x_test.
            self.gaussian_val = operations.get_gaussian_parameters(self.x_test)

        self.x_train_length = len(self.x_train)
        self.x_test_length = len(self.x_test)

        # Hyperparameters for the neural network.
        self.number_of_epochs = number_of_epochs

        if self.enable_stochastic_gradient_descent:
            self.batch_size = batch_size
        else:
            # Full-batch training: one batch holds the whole training set.
            self.batch_size = len(self.x_train)

        self.learning_rate = learning_rate_initial
        self.learning_rate_minimum = learning_rate_minimum
        self.enable_batch_normalization = enable_batch_normalization
        self.enable_dropout = enable_dropout
        self.enable_activation = enable_activation
        self.encoder_activation = encoder_activation  # 'relu', 'tanh', 'elu', 'softmax', 'sigmoid'
        self.decoder_activation = decoder_activation
        self.final_activation = final_activation
        self.dropout_rate = dropout_rate
        self.l2_constant = l2_constant
        self.patience_limit = self.number_of_epochs // 5
        self.early_stopping_delta = early_stopping_delta

        self.latent_dim = 2
        # Means and variances are concatenated, hence twice the latent size.
        self.gaussian_dimension = 2 * self.latent_dim

        self.beta = max(beta, 1)
        self.enable_beta = self.beta > 1

        # The hyperparameters below are joined into the output directory name.
        self.hyper_parameter_list = [
            self.number_of_epochs, self.batch_size, self.learning_rate,
            self.encoder_activation, self.decoder_activation,
            self.enable_batch_normalization, self.enable_dropout,
            self.dropout_rate, self.enable_beta, self.beta, self.l2_constant,
            self.patience_limit, self.early_stopping_delta, self.latent_dim
        ]

        if self.is_mnist:
            self.hyper_parameter_list.append("mnist")

        if self.is_restricted:
            # A single tag listing every retained label (e.g. "restricted_123").
            # Previously the tag was appended once per label inside the loop.
            restriction_label_string = ''.join(
                str(label) for label in restriction_labels)
            self.hyper_parameter_list.append(
                "restricted_{}".format(restriction_label_string))

        if self.enable_augmentation:
            augmentation_string = "_".join([
                "augmented",
                str(covariance_coefficient),
                str(augmentation_size)
            ])
            self.hyper_parameter_list.append(augmentation_string)

        if not self.enable_activation:
            self.hyper_parameter_list.append("PCA")

        self.hyper_parameter_string = '_'.join(
            [str(i) for i in self.hyper_parameter_list])

        # Suffix the name with a run counter so repeated runs with identical
        # hyperparameters get distinct directories.
        self.directory_counter = directories.DirectoryCounter(
            self.hyper_parameter_string)
        self.directory_number = self.directory_counter.count()
        self.hyper_parameter_string = '_'.join([
            self.hyper_parameter_string,
            'x{:02d}'.format(self.directory_number)
        ])
        self.directory = directories.DirectoryCounter.make_output_directory(
            self.hyper_parameter_string, self.model_name)
        self.image_directory = os.path.join('images', self.directory)

        # Tensorflow Input instances for declaring model inputs.
        self.mnist_shape = self.x_train.shape[1:]
        self.gaussian_shape = 2 * self.latent_dim
        self.encoder_gaussian = Input(shape=self.gaussian_shape,
                                      name='enc_gaussian')
        self.encoder_mnist_input = Input(shape=self.mnist_shape,
                                         name='enc_mnist')
        self.auto_encoder_gaussian = Input(shape=self.gaussian_shape,
                                           name='ae_gaussian')
        self.auto_encoder_mnist_input = Input(shape=self.mnist_shape,
                                              name='ae_mnist')

        # Training callbacks: TensorBoard logging, early stopping, learning
        # rate reduction on plateau, and termination on NaN loss.
        self.tensorboard_callback = TensorBoard(log_dir=os.path.join(
            self.directory, 'tensorboard_logs'),
                                                histogram_freq=2,
                                                write_graph=True,
                                                write_images=True)

        self.early_stopping_callback = EarlyStopping(
            monitor='val_loss',
            min_delta=self.early_stopping_delta,
            patience=self.patience_limit,
            mode='auto',
            restore_best_weights=True)

        self.learning_rate_callback = ReduceLROnPlateau(
            monitor='val_loss',
            factor=0.1,
            patience=50,
            min_lr=self.learning_rate_minimum)

        self.nan_termination_callback = TerminateOnNaN()

        self.colors = ['#00B7BA', '#FFB86F', '#5E6572', '#6B0504', '#BA5C12']
コード例 #4
0
ファイル: vae.py プロジェクト: admshumar/ct-vae
    def plot_results(self, models):
        """Plot the encoded test set and, for a 2D latent space, a decoded grid.

        Outputs (``self.image_directory`` is created if missing):

        * 2D latent space: ``vae_mean.png``, a scatter plot of the encoded
          test means coloured by ``self.y_test``, and ``latent.png``, a 30x30
          mosaic of images decoded from evenly spaced latent points (28x28
          images when ``self.is_mnist`` is true, 224x224 otherwise).
        * Any other latent dimension: a ``LatentSpaceTSNE`` of the means,
          constructed with ``self.experiment_directory`` and saved through
          its ``save_tsne()`` method.

        Figures are shown when ``self.show`` is true and closed afterwards.

        # Arguments
            models (tuple): ``(encoder, decoder)`` models. The encoder is fed
                ``[gaussian_parameters, self.x_test]`` and its first output is
                split into latent means and covariances; the decoder is fed
                ``[gaussian_parameters, z]`` and its second output is reshaped
                into the decoded images.
        """
        encoder, decoder = models
        test_gaussian = operations.get_gaussian_parameters(
            self.x_test, self.latent_dimension)
        os.makedirs(self.image_directory, exist_ok=True)

        z_gaussian, _ = encoder.predict([test_gaussian, self.x_test],
                                        batch_size=self.batch_size)
        z_mean, _ = operations.split_gaussian_parameters(z_gaussian)

        if self.latent_dimension != 2:
            # The latent space cannot be drawn directly: use a 2D t-SNE.
            plt.figure(figsize=(12, 10))
            tsne = LatentSpaceTSNE(z_mean, self.y_test,
                                   self.experiment_directory)
            tsne.save_tsne()
            plt.close('all')
            return

        # Display a 2D plot of the data classes in the latent space.
        plt.figure(figsize=(12, 10))
        plt.scatter(z_mean[:, 0],
                    z_mean[:, 1],
                    c=self.y_test,
                    s=8,
                    alpha=0.3)
        # NOTE(review): colorbar ticks are fixed at 0, 1, 2 (three classes).
        plt.colorbar(ticks=np.linspace(0, 2, 3))
        plt.xlabel("z[0]")
        plt.ylabel("z[1]")
        plt.savefig(os.path.join(self.image_directory, "vae_mean.png"),
                    dpi=200)
        if self.show:
            plt.show()
        plt.close('all')

        # The MNIST and non-MNIST manifolds differ only in the tile size.
        image_size = 28 if self.is_mnist else 224
        self._plot_latent_manifold(
            decoder, image_size,
            os.path.join(self.image_directory, "latent.png"))

    def _plot_latent_manifold(self, decoder, image_size, filepath, n=30):
        """Save an n x n mosaic of images decoded from a regular 2D latent grid.

        :param decoder: Model fed ``[gaussian_parameters, z]``; its second
            output is reshaped into ``image_size`` x ``image_size`` images.
        :param image_size: Side length in pixels of each decoded image.
        :param filepath: Destination path of the saved figure.
        :param n: Number of grid points along each latent axis.
        """
        # Linearly spaced coordinates corresponding to the 2D plot of data
        # classes; the y grid is reversed so larger z[1] is drawn at the top.
        grid_x = np.linspace(-4, 4, n)
        grid_y = np.linspace(-4.5, 3.5, n)[::-1]

        # Decode the whole grid in one predict() call instead of n*n calls.
        # Rows follow grid_y and columns follow grid_x, so sample i*n + j is
        # the latent point (grid_x[j], grid_y[i]) that belongs in tile (i, j).
        xs, ys = np.meshgrid(grid_x, grid_y)
        z_samples = np.column_stack([xs.ravel(), ys.ravel()])
        # Zeros followed by ones, repeated once per latent sample.
        dummy_gaussian = np.concatenate((np.zeros(self.latent_dimension),
                                         np.ones(self.latent_dimension)))
        dummy_gaussians = np.tile(dummy_gaussian, (len(z_samples), 1))
        x_decoded = decoder.predict([dummy_gaussians, z_samples])
        images = np.asarray(x_decoded[1]).reshape(n, n, image_size,
                                                  image_size)
        # (row, col, y, x) -> (row, y, col, x) -> one (n*size, n*size) mosaic.
        figure = images.transpose(0, 2, 1, 3).reshape(n * image_size,
                                                      n * image_size)

        plt.figure(figsize=(10, 10))
        # One tick at the centre of every tile, labelled with its coordinate.
        start_range = image_size // 2
        end_range = (n - 1) * image_size + start_range + 1
        pixel_range = np.arange(start_range, end_range, image_size)
        plt.xticks(pixel_range, np.round(grid_x, 1))
        plt.yticks(pixel_range, np.round(grid_y, 1))
        plt.xlabel("z[0]")
        plt.ylabel("z[1]")
        plt.imshow(figure, cmap='Greys_r')
        plt.savefig(filepath)
        if self.show:
            plt.show()
        plt.close('all')