예제 #1
0
파일: show.py 프로젝트: zuxinrui/DCGAN
def _save_img_list(img_list, save_path, config):
    """Save a sequence of channel-first images as an animated gif.

    The gif is written to ``<save_path>/img_list.gif`` at 1 frame per second
    and 500 dpi. Frame ``i`` is titled ``"step {i * config['save_every']}"``.

    :param img_list: Iterable of 3-D arrays; each is transposed with axes
        ``(1, 2, 0)`` (presumably C, H, W -> H, W, C) before ``imshow``.
    :param save_path: Directory in which ``img_list.gif`` is created.
    :param config: Mapping that provides ``"save_every"``, the number of
        training steps between two consecutive images.
    """
    import os

    metadata = dict(title='generator images', artist='Matplotlib', comment='Movie support!')
    writer = ImageMagickWriter(fps=1, metadata=metadata)
    frames = [np.transpose(img, (1, 2, 0)) for img in img_list]
    fig, ax = plt.subplots()
    try:
        with writer.saving(fig, os.path.join(save_path, "img_list.gif"), dpi=500):
            shown = None
            for frame_index, frame in enumerate(frames):
                # Replace the previous image instead of stacking one AxesImage
                # per frame (stacked images are all re-rendered on every grab).
                if shown is not None:
                    shown.remove()
                shown = ax.imshow(frame)
                ax.set_title("step {}".format(frame_index * config["save_every"]))
                writer.grab_frame()
    finally:
        # The figure only exists to render frames; close it so repeated calls
        # do not accumulate open figures.
        plt.close(fig)
예제 #2
0
class Visualization(object):
    """Helper class to visualize the progress of the GAN training procedure.

    Each call to :meth:`plot_progress` redraws one panel comparing the kernel
    density estimates of real and generated samples and appends the figure as
    a frame of a gif written through ``ImageMagickWriter``. Call
    :meth:`finish` after the last frame so the writer is finalized.
    """

    def __init__(self, file_name, model_name, fps=15, log_interval=50, n_samples=10000):
        """Initialize the helper class and set up the gif writer.

        :param file_name: Path of the gif animation to write.
        :param model_name: Text shown as the figure's suptitle.
        :param fps: The number of frames per second when saving the gif animation.
        :param log_interval: Training iterations between two consecutive calls to
            :meth:`plot_progress`; only used to label each frame's title.
        :param n_samples: Number of real and of generated samples drawn per frame.
        """
        self.fps = fps
        self.log_interval = log_interval
        self.n_samples = n_samples
        self.figure, self.ax2 = plt.subplots(1, 1, figsize=(5, 5))
        self.figure.suptitle("{}".format(model_name))
        sns.set(color_codes=True, style='white', palette='colorblind')
        sns.despine(self.figure)
        plt.show(block=False)
        self.real_data = Dataset()
        self.step = 0  # number of frames plotted so far
        self.writer = ImageMagickWriter(fps=self.fps)
        self.writer.setup(self.figure, file_name, dpi=100)

    def plot_progress(self, gen_net):
        """Plot real vs. generated distributions and record one gif frame.

        This can be called back periodically from the training loop.

        :param gen_net: The generator; it is called with an ``(n_samples, 1)``
            ``torch.FloatTensor`` of standard-normal noise and its output must
            support ``.detach().numpy()``.
        """
        real = self.real_data.next_batch(batch_size=self.n_samples)
        noise = torch.FloatTensor(np.random.randn(self.n_samples, 1))
        # Plotting only needs the values; skip building an autograd graph.
        with torch.no_grad():
            generated = gen_net(noise)

        self.ax2.clear()
        self.ax2.set_ylim([0, 1])
        self.ax2.set_xlim([0, 8])
        sns.kdeplot(real.numpy().flatten(),
                    shade=True,
                    ax=self.ax2,
                    label='Real data')
        sns.kdeplot(generated.detach().numpy().flatten(),
                    shade=True,
                    ax=self.ax2,
                    label='Generated data')
        self.ax2.set_title('{} iterations'.format(self.step * self.log_interval))
        self.ax2.set_xlabel('Input domain')
        self.ax2.set_ylabel('Probability density')
        self.ax2.legend(loc='upper left', frameon=True)
        self.figure.canvas.draw()
        self.figure.canvas.flush_events()
        self.writer.grab_frame()
        self.step += 1

    def finish(self):
        """Finalize the gif by calling the writer's ``finish()``; call once after the last frame."""
        self.writer.finish()
예제 #3
0
    def save_animation(self, fps=5, time=5, *, filename="robot_animation.gif"):
        """Render the columns of ``self.path`` as a gif lasting roughly ``time`` seconds.

        Each selected column of ``self.path`` is assigned to ``self.theta`` and
        drawn with ``self.plot(show=False)``. At most ``fps * time`` evenly spaced
        columns are drawn. When the path has fewer columns than that, every
        column is drawn and the frame rate is lowered to
        ``number_of_frames // time``, but never below 1 fps. This keeps the clip
        close to ``time`` seconds long. ``self.theta`` is restored and the figure
        is closed even if rendering fails.

        :param fps: Target frames per second.
        :param time: Target duration of the animation in seconds.
        :param filename: Output path of the gif.
        :raises ValueError: If ``fps`` or ``time`` is not positive.
        """
        if fps <= 0 or time <= 0:
            raise ValueError("fps and time must be positive")

        original_position = self.theta
        number_of_frames = self.path.shape[1]

        frames_to_animate = fps * time
        if number_of_frames < frames_to_animate:
            step = 1
            # Draw every column and spread them over `time` seconds.
            fps = max(1, number_of_frames // time)
        else:
            # Ceiling division keeps the number of drawn frames <= frames_to_animate.
            step = -(-number_of_frames // frames_to_animate)

        fig = plt.figure()
        robot_animation = ImageMagickWriter(fps=fps)

        try:
            with robot_animation.saving(fig, filename, dpi=150):
                for column in np.arange(start=0, stop=number_of_frames, step=step):
                    self.theta = np.array(self.path[:, column])
                    self.plot(show=False)
                    robot_animation.grab_frame()
                    fig.clear()
                    self._set_plot_options()
        finally:
            # Drawing frames mutates self.theta; always put the robot back.
            self.theta = original_position
            plt.close(fig)
예제 #4
0
파일: gan_toy1d.py 프로젝트: zcrwind/gaan
class Visualization(object):
    """Helper class to visualize the progress of the GAN training procedure.

    The figure has two panels:

    - left: the discriminator and generator learning curves;
    - right: kernel density estimates of real and generated samples, plus the
      discriminator output (or the min-max normalized WGAN-GP critic)
      evaluated over the plotted input domain.
    """

    # Models whose discriminator output is drawn as a "Decision boundary".
    _DECISION_MODELS = ('gan', 'rgan', 'mdgan', 'vaegan', 'gaan')
    # Fixed limits of the distributions panel; the critic is evaluated over _X_LIM.
    _X_LIM = (-8, 8)
    _Y_LIM = (0, 1)
    # How long the final frame is repeated at the end of the gif.
    _HOLD_SECONDS = 3

    def __init__(self, save_animation=False, fps=30):
        """Initialize the helper class.

        :param save_animation: Whether the animation should be saved as a gif. Requires the ImageMagick library.
        :param fps: The number of frames per second when saving the gif animation.
        """
        self.save_animation = save_animation
        self.fps = fps
        self.figure, (self.ax1, self.ax2) = plt.subplots(1, 2, figsize=(8, 4))
        self.figure.suptitle("1D GAAN")
        sns.set(color_codes=True, style='white', palette='colorblind')
        sns.despine(self.figure)
        plt.show(block=False)

        if self.save_animation:
            self.writer = ImageMagickWriter(fps=self.fps)
            self.writer.setup(self.figure, 'Toy_1D_GAAN.gif', dpi=100)

    @staticmethod
    def _loss_label(model):
        """Return the y-axis label of the learning-curve panel for ``model``."""
        if model == 'wgangp':
            return 'Negative critic loss'
        return 'Loss'

    @staticmethod
    def _normalize_critic(critic):
        """Min-max scale ``critic`` into [0, 1]; a constant critic maps to zeros."""
        critic = np.asarray(critic, dtype=float)
        low = critic.min()
        span = critic.max() - low
        if span == 0:
            # Avoid 0/0 (all-NaN curve) when the critic is flat.
            return np.zeros_like(critic)
        return (critic - low) / span

    def plot_progress(self, gan, session, data):
        """Plot the progress of the training procedure. This can be called back from the GAN fit method.

        On the final logging step (``len(gan.loss_d_curve) - 1 == gan.n_step // gan.log_interval``)
        the last frame is repeated for a few seconds in the gif, the writer is
        finalized and ``plt.show()`` is called. Otherwise the figure is refreshed
        and, when saving, one frame is grabbed.

        :param gan: The GAN we are fitting.
        :param session: The current session of the GAN.
        :param data: The data object from which we are sampling the input data.
        """
        steps = self._plot_learning_curve(gan)
        self._plot_distributions(gan, session, data)
        self._render(is_last=len(steps) - 1 == gan.n_step // gan.log_interval)

    def _plot_learning_curve(self, gan):
        """Redraw the D/G loss curves on ax1 and return the iteration axis used."""
        steps = gan.log_interval * np.arange(len(gan.loss_d_curve))
        self.ax1.clear()
        self.ax1.plot(steps, gan.loss_d_curve, label='D loss')
        self.ax1.plot(steps, gan.loss_g_curve, label='G loss')
        self.ax1.set_title('Learning curve')
        self.ax1.set_xlabel('Iteration')
        self.ax1.set_ylabel(self._loss_label(gan.model))
        # Without a legend the 'D loss' / 'G loss' labels are never displayed.
        self.ax1.legend(frameon=True)
        return steps

    def _plot_distributions(self, gan, session, data):
        """Redraw real/generated densities and the discriminator curve on ax2."""
        g = gan.sample(session)
        # Evaluate the discriminator over the fixed plotting domain. Reading
        # ax2.get_xlim() here would give matplotlib's default (0, 1) limits on
        # the first call, because set_xlim only happens further below.
        x_low, x_high = self._X_LIM
        x = np.linspace(x_low, x_high, gan.n_sample)[:, np.newaxis]

        critic = gan.dreal(session, x)
        if gan.model == 'wgangp':
            # Normalize the critic to be in [0, 1] to make visualization easier.
            critic = self._normalize_critic(critic)
        d, _ = data.next_batch(gan.n_sample)

        self.ax2.clear()
        self.ax2.set_ylim(list(self._Y_LIM))
        self.ax2.set_xlim(list(self._X_LIM))
        if gan.model in self._DECISION_MODELS:
            self.ax2.plot(x, critic, label='Decision boundary')
        elif gan.model == 'wgangp':
            self.ax2.plot(x, critic, label='Critic (normalized)')
        sns.kdeplot(d.flatten(), shade=True, ax=self.ax2, label='Real data')
        sns.kdeplot(g.flatten(), shade=True, ax=self.ax2, label='Generated data')
        self.ax2.set_title('Distributions')
        self.ax2.set_xlabel('Input domain')
        self.ax2.set_ylabel('Probability density')
        self.ax2.legend(loc='upper left', frameon=True)

    def _render(self, is_last):
        """Refresh the window and record gif frame(s); finalize on the last step."""
        if is_last:
            if self.save_animation:
                # Repeat the final frame so the result stays visible at the end of the gif.
                for _ in range(self._HOLD_SECONDS * self.fps):
                    self.writer.grab_frame()
                self.writer.finish()
            plt.show()
        else:
            self.figure.canvas.draw()
            self.figure.canvas.flush_events()
            if self.save_animation:
                self.writer.grab_frame()
예제 #5
0
                    # NOTE(review): tail of a script whose enclosing try/loops start
                    # above this view. Moves weights[j] toward the input x by a step
                    # of learning_rate * h[j] (h[j] presumably a neighbourhood weight).
                    delta_w = learning_rate * h[j] * (x - weights[j])
                    weights[j] += delta_w

            # Copy the per-unit values h into the 2-D h_grid at position
            # (j % grid_x1, int(j / grid_x1)); for non-negative integer j this
            # equals (j % grid_x1, j // grid_x1).
            for j in range(len(index_map)):
                h_grid[int(j % grid_x1), int(j / grid_x1)] = h[j]

            # Replace the previous scatter on ax2 with one at index_map
            # positions, coloured via create_color_list(neuron_class).
            scat3.remove()
            scat3 = ax2.scatter(index_map[:, 0],
                                index_map[:, 1],
                                s=50,
                                c=create_color_list(neuron_class))
            # Upscale h_grid to 100x100 with Lanczos resampling and show it in
            # the image artist `im` (Image is presumably PIL.Image — confirm).
            upscaled_image = Image.fromarray(h_grid).resize(
                [100, 100], resample=Image.LANCZOS)
            upscaled_image = np.asarray(upscaled_image)
            im.set_data(upscaled_image)

            # Replace the weight scatter on ax; three coordinate arrays are
            # passed, so ax is presumably a 3-D axes — confirm.
            scat.remove()
            scat = ax.scatter(weights[:, 0],
                              weights[:, 1],
                              weights[:, 2],
                              s=10,
                              c=create_color_list(neuron_class),
                              marker='^')

            # Let the GUI event loop run for 0.1 s, then append the figure as a frame.
            plt.pause(0.1)
            writer.grab_frame()
            # NOTE(review): np.linalg.norm without `axis` returns one scalar for
            # the whole difference, so `.mean()` is a no-op; the loop stops when
            # that total change is < 0.01. Assumes old_weights is a copy taken
            # before this pass (set above this view) — confirm.
            if (np.linalg.norm(weights - old_weights).mean() < 0.01):
                break
# Ctrl-C ends training early; writer.finish() below still runs so the
# frames grabbed so far are finalized.
except KeyboardInterrupt:
    pass
writer.finish()
예제 #6
0
class Visualization(object):
    """Helper class to visualize the progress of the GAN training procedure.

    The figure has two panels:

    - left: the discriminator and generator learning curves;
    - right: kernel density estimates of real and generated samples, plus the
      discriminator output (or the min-max normalized WGAN-GP critic)
      evaluated over the plotted input domain.
    """

    # Models whose discriminator output is drawn as a "Decision boundary".
    _DECISION_MODELS = ('gan', 'rgan', 'mdgan', 'vaegan', 'gaan')
    # Fixed limits of the distributions panel; the critic is evaluated over _X_LIM.
    _X_LIM = (-8, 8)
    _Y_LIM = (0, 1)
    # How long the final frame is repeated at the end of the gif.
    _HOLD_SECONDS = 3

    def __init__(self, save_animation=False, fps=30):
        """Initialize the helper class.

        :param save_animation: Whether the animation should be saved as a gif. Requires the ImageMagick library.
        :param fps: The number of frames per second when saving the gif animation.
        """
        self.save_animation = save_animation
        self.fps = fps
        self.figure, (self.ax1, self.ax2) = plt.subplots(1, 2, figsize=(8, 4))
        self.figure.suptitle("1D GAAN")
        sns.set(color_codes=True, style='white', palette='colorblind')
        sns.despine(self.figure)
        plt.show(block=False)

        if self.save_animation:
            self.writer = ImageMagickWriter(fps=self.fps)
            self.writer.setup(self.figure, 'Toy_1D_GAAN.gif', dpi=100)

    @staticmethod
    def _loss_label(model):
        """Return the y-axis label of the learning-curve panel for ``model``."""
        if model == 'wgangp':
            return 'Negative critic loss'
        return 'Loss'

    @staticmethod
    def _normalize_critic(critic):
        """Min-max scale ``critic`` into [0, 1]; a constant critic maps to zeros."""
        critic = np.asarray(critic, dtype=float)
        low = critic.min()
        span = critic.max() - low
        if span == 0:
            # Avoid 0/0 (all-NaN curve) when the critic is flat.
            return np.zeros_like(critic)
        return (critic - low) / span

    def plot_progress(self, gan, session, data):
        """Plot the progress of the training procedure. This can be called back from the GAN fit method.

        On the final logging step (``len(gan.loss_d_curve) - 1 == gan.n_step // gan.log_interval``)
        the last frame is repeated for a few seconds in the gif, the writer is
        finalized and ``plt.show()`` is called. Otherwise the figure is refreshed
        and, when saving, one frame is grabbed.

        :param gan: The GAN we are fitting.
        :param session: The current session of the GAN.
        :param data: The data object from which we are sampling the input data.
        """
        steps = self._plot_learning_curve(gan)
        self._plot_distributions(gan, session, data)
        self._render(is_last=len(steps) - 1 == gan.n_step // gan.log_interval)

    def _plot_learning_curve(self, gan):
        """Redraw the D/G loss curves on ax1 and return the iteration axis used."""
        steps = gan.log_interval * np.arange(len(gan.loss_d_curve))
        self.ax1.clear()
        self.ax1.plot(steps, gan.loss_d_curve, label='D loss')
        self.ax1.plot(steps, gan.loss_g_curve, label='G loss')
        self.ax1.set_title('Learning curve')
        self.ax1.set_xlabel('Iteration')
        self.ax1.set_ylabel(self._loss_label(gan.model))
        # Without a legend the 'D loss' / 'G loss' labels are never displayed.
        self.ax1.legend(frameon=True)
        return steps

    def _plot_distributions(self, gan, session, data):
        """Redraw real/generated densities and the discriminator curve on ax2."""
        g = gan.sample(session)
        # Evaluate the discriminator over the fixed plotting domain. Reading
        # ax2.get_xlim() here would give matplotlib's default (0, 1) limits on
        # the first call, because set_xlim only happens further below.
        x_low, x_high = self._X_LIM
        x = np.linspace(x_low, x_high, gan.n_sample)[:, np.newaxis]

        critic = gan.dreal(session, x)
        if gan.model == 'wgangp':
            # Normalize the critic to be in [0, 1] to make visualization easier.
            critic = self._normalize_critic(critic)
        d, _ = data.next_batch(gan.n_sample)

        self.ax2.clear()
        self.ax2.set_ylim(list(self._Y_LIM))
        self.ax2.set_xlim(list(self._X_LIM))
        if gan.model in self._DECISION_MODELS:
            self.ax2.plot(x, critic, label='Decision boundary')
        elif gan.model == 'wgangp':
            self.ax2.plot(x, critic, label='Critic (normalized)')
        sns.kdeplot(d.flatten(), shade=True, ax=self.ax2, label='Real data')
        sns.kdeplot(g.flatten(), shade=True, ax=self.ax2, label='Generated data')
        self.ax2.set_title('Distributions')
        self.ax2.set_xlabel('Input domain')
        self.ax2.set_ylabel('Probability density')
        self.ax2.legend(loc='upper left', frameon=True)

    def _render(self, is_last):
        """Refresh the window and record gif frame(s); finalize on the last step."""
        if is_last:
            if self.save_animation:
                # Repeat the final frame so the result stays visible at the end of the gif.
                for _ in range(self._HOLD_SECONDS * self.fps):
                    self.writer.grab_frame()
                self.writer.finish()
            plt.show()
        else:
            self.figure.canvas.draw()
            self.figure.canvas.flush_events()
            if self.save_animation:
                self.writer.grab_frame()