コード例 #1
0
    def plot_reconstructions(self,
                             X_test,
                             channel_to_plot=0,
                             nb_to_plot=10,
                             plot_input=True,
                             same_scale_as_input=True):
        """
        Plot the first images of X_test and their reconstructions by the autoencoder.

        Arguments:
            X_test: numpy array of shape (nb_samples, nb_rows, nb_columns, nb_input_channels).
                    A single image of shape (nb_rows, nb_columns, nb_input_channels) is
                    also accepted and treated as a batch of one image.
                    Only the first nb_to_plot samples are reconstructed and plotted.
            channel_to_plot: int, the channel passed to the plotting helper for both the
                    input images (when plot_input is True) and the reconstructions.
            nb_to_plot: int, number of leading samples of X_test to use (default 10).
            plot_input: bool, if True the input images are plotted before their
                    reconstructions.
            same_scale_as_input: bool, if True the reconstructions are plotted with
                    v_min/v_max taken from channel channel_to_plot of the inputs, or
                    from input channel 0 when the inputs do not have that channel.
                    If False, v_min and v_max are passed as None.
        """
        shape_X = X_test.shape
        if len(shape_X) == 3:
            # Single (rows, columns, channels) image: add a leading batch axis.
            X_test = X_test.reshape((1, ) + shape_X)
        X_subset = X_test[:nb_to_plot]
        X_rec = self.reconstruction(X_subset)
        if same_scale_as_input:
            try:
                scale_source = X_subset[:, :, :, channel_to_plot]
            except IndexError:
                # The requested channel may not exist in the inputs (presumably
                # when the reconstruction has more channels than the input);
                # fall back to input channel 0 to define the intensity range.
                scale_source = X_subset[:, :, :, 0]
            v_min = np.min(scale_source)
            v_max = np.max(scale_source)
        else:
            v_min = None
            v_max = None
        if plot_input:
            bastien_utils.plot_all_images(X_subset,
                                          channel_to_plot=channel_to_plot)
        bastien_utils.plot_all_images(X_rec,
                                      channel_to_plot=channel_to_plot,
                                      v_min=v_min,
                                      v_max=v_max)
コード例 #2
0
def plot_binarized_weighted_atoms(atoms,
                                  h_test_image,
                                  nb_atoms_to_use_for_threshold_computation=20
                                  ):
    """
    Plot the binarized weighted atoms computed for one encoded test image.

    The images are produced by binarized_weighted_atoms and then shown with
    bastien_utils.plot_all_images using same_intensity_scale=True.

    Arguments:
        atoms: the atom images of the learned representation (presumably one
            atom per code coefficient; confirm against binarized_weighted_atoms).
        h_test_image: presumably the encoding of the test image.
        nb_atoms_to_use_for_threshold_computation: int, forwarded unchanged to
            binarized_weighted_atoms (default 20).
    """
    bastien_utils.plot_all_images(
        binarized_weighted_atoms(
            atoms,
            h_test_image,
            nb_atoms_to_use_for_threshold_computation=(
                nb_atoms_to_use_for_threshold_computation)),
        same_intensity_scale=True)
コード例 #3
0
def plot_10_most_used_atoms_for_an_image(h_image,
                                         atoms,
                                         same_intensity_scale=True):
    """
    Print a caption and plot the 10 atoms most used in one image's encoding.

    The atoms are selected by k_most_used_atoms with k=10. As the printed caption
    states, these are the atoms associated with the 10 highest code coefficients
    of the image.

    Arguments:
        h_image: numpy array of shape (N_features,), the encoding of an image
        atoms: numpy array of shape (N_features, N_pixels), the atom images of the learned representation
        same_intensity_scale: bool, forwarded to bastien_utils.plot_all_images (default True).
    """
    top_atoms = k_most_used_atoms(10, h_image, atoms)
    caption = 'Atoms associated with the 10 highest code coefficients of the image'
    print(caption)
    bastien_utils.plot_all_images(top_atoms,
                                  same_intensity_scale=same_intensity_scale)
コード例 #4
0
 def plot_atoms_decoder(self,
                        channel_to_plot=0,
                        nb_to_plot=-1,
                        add_bias=False,
                        normalize=False):
     """
     Plot the weights of the decoder as atom images.

     Arguments:
         channel_to_plot: int, channel to plot (there are nb_input_channels*nb_atoms atoms).
         nb_to_plot: int, number of atom images to plot. Any negative value (e.g. the
             default -1) plots the first self.latent_dim atoms, i.e. all of them
             assuming one atom per latent dimension. Otherwise the first nb_to_plot
             atoms are plotted.
         add_bias: bool, forwarded to self.atom_images_decoder; whether to add the
             decoder bias (documented as shape (784,)) to the weights.
         normalize: bool, forwarded to self.atom_images_decoder. If True each image is
             normalized, giving the artificial input images that maximize each of the
             code coefficients (with unity energy).
     """
     atoms = self.atom_images_decoder(add_bias=add_bias,
                                      normalize=normalize)
     # A negative nb_to_plot means "all": one atom per latent dimension.
     n_atoms = self.latent_dim if nb_to_plot < 0 else nb_to_plot
     bastien_utils.plot_all_images(atoms[:n_atoms],
                                   channel_to_plot=channel_to_plot,
                                   same_intensity_scale=True)
コード例 #5
0
def plot_weighted_atoms(atoms, h_test_image):
    """
    Plot the atoms weighted by the encoding coefficients of one test image.

    The weighting is delegated to atoms_weighted_by_encoding_coefficients.
    All resulting images are shown with bastien_utils.plot_all_images using
    same_intensity_scale=True.

    Arguments:
        atoms: the atom images of the learned representation.
        h_test_image: presumably the encoding of the test image.
    """
    bastien_utils.plot_all_images(
        atoms_weighted_by_encoding_coefficients(atoms, h_test_image),
        same_intensity_scale=True)