class SiameseGAN:
    """GAN whose discriminator is trained on sample triplets.

    Each training step feeds the discriminator the list
    ``[real_x, fake_x, noise_x]``, where ``fake_x`` and ``noise_x`` are the
    first and third outputs of ``NoiseMaker.denoise_samples``. Models come
    from the project helpers ``build_generator``, ``build_discriminator`` and
    ``build_adversarial``.
    """

    def __init__(self, data_shape,
                 performance_output_path='performance/temp/'):
        """Build the models and make sure the output directory exists.

        Args:
            data_shape: Shape forwarded to the generator/discriminator
                builders.
            performance_output_path: Directory that receives per-epoch plots
                and generator checkpoints.
        """
        self.data_shape = data_shape
        self.discriminator = None
        self.generator = None
        self.adversarial = None

        self.define_gan()
        self.noisy_samples = NoiseMaker(generator=self.generator)

        self.performance_output_path = performance_output_path
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(self.performance_output_path, exist_ok=True)

    def define_gan(self):
        """Create the generator, discriminator and combined adversarial model."""
        self.generator = build_generator(input_shape=self.data_shape)
        self.discriminator = build_discriminator(input_shape=self.data_shape)

        self.adversarial = build_adversarial(
            generator_model=self.generator,
            discriminator_model=self.discriminator)

    def train(self, dataset, epochs=100, batch_size=64):
        """Train for ``epochs`` epochs, saving a snapshot after each one.

        Args:
            dataset: Object exposing ``iter_samples(n)``, ``sample_number``
                and ``test_data``.
            epochs: Number of passes over ``dataset``.
            batch_size: Nominal batch size; each step draws half of it.
        """
        for epoch in range(epochs):
            # 1-based so the log agrees with the snapshot file numbering.
            print('Epochs: %3d/%d' % (epoch + 1, epochs))
            self.single_epoch(dataset, batch_size)
            self.performance(step=epoch, test_data=dataset.test_data)

    def single_epoch(self, dataset, batch_size):
        """Run one pass over ``dataset``, training discriminator then GAN.

        Args:
            dataset: Object exposing ``iter_samples(n)`` (yielding
                ``(x, _, y)`` triples) and ``sample_number``.
            batch_size: Nominal batch size; each step draws half of it.
        """
        # At least 1: batch_size == 1 would otherwise request empty batches.
        half_batch_size = max(1, batch_size // 2)
        trained_samples = 0

        for real_x, _, _ in dataset.iter_samples(half_batch_size):
            fake_x, _, noise_x = self.noisy_samples.denoise_samples(
                real_samples=real_x)

            # NOTE(review): every triplet gets a target of 0 here; confirm this
            # matches the loss defined in build_discriminator.
            disc_targets = np.zeros(shape=(len(real_x), ))
            discriminator_loss = self.discriminator.train_on_batch(
                [real_x, fake_x, noise_x], disc_targets)

            noisy_input = self.noisy_samples.add_noise(real_x)
            gan_targets = np.ones(shape=(len(noisy_input), ))
            gan_loss = self.adversarial.train_on_batch(
                [real_x, noisy_input, noisy_input], gan_targets)

            # Count samples actually consumed; the final batch may be short.
            trained_samples = min(trained_samples + len(real_x),
                                  dataset.sample_number)
            print('     %5d/%d -> Discriminator Loss: %s, Gan Loss: %s' %
                  (trained_samples, dataset.sample_number,
                   self._format_loss(discriminator_loss),
                   self._format_loss(gan_loss)))

    @staticmethod
    def _format_loss(loss):
        """Format a ``train_on_batch`` result (scalar or list) for logging.

        Keras-style ``train_on_batch`` returns a list when the model tracks
        metrics or has several outputs, which plain '%f' cannot format.
        """
        return ', '.join('%f' % value for value in np.ravel(loss))

    def performance(self, step, test_data):
        """Plot a 10x10 grid of generated samples and checkpoint the generator.

        Args:
            step: Zero-based epoch index; files are numbered ``step + 1``.
            test_data: Sliceable collection of test samples; at most the
                first 100 are used.
        """
        grid = 10
        n_cells = grid * grid
        # Only the plotted samples need to go through the generator.
        generated, _, _ = self.noisy_samples.denoise_samples(
            real_samples=test_data[:n_cells])
        # scale from [-1,1] to [0,1]
        # NOTE(review): assumes generator output lies in [-1, 1] — confirm.
        generated = (generated + 1) / 2.0
        # min() guards test sets smaller than the grid (was an IndexError).
        for i in range(min(n_cells, len(generated))):
            pyplot.subplot(grid, grid, 1 + i)
            pyplot.axis('off')
            # Indexing assumes (N, H, W, C) samples; plots channel 0.
            pyplot.imshow(generated[i, :, :, 0], cmap='gray_r')

        fig_file = os.path.join(self.performance_output_path,
                                'generated_plot_%04d.png' % (step + 1))
        pyplot.savefig(fig_file)
        pyplot.close()

        model_file = os.path.join(self.performance_output_path,
                                  'model_%04d.h5' % (step + 1))
        self.generator.save(model_file)
        print('>Saved: %s and %s' % (fig_file, model_file))
# Example #2
# 0
class Img2ImgGAN:
    """Image-to-image GAN trained on real vs. ``NoiseMaker``-generated samples.

    Noise is injected by ``NoiseMaker`` with ``noise_type='s&p'`` (presumably
    salt-and-pepper — confirm against NoiseMaker). Models come from the
    project helpers ``build_generator``, ``build_discriminator`` and
    ``build_adversarial``.
    """

    def __init__(self, data_shape,
                 performance_output_path='performance/temp/'):
        """Build the models and make sure the output directory exists.

        Args:
            data_shape: Shape forwarded to the model builders and NoiseMaker.
            performance_output_path: Directory that receives per-epoch plots
                and generator checkpoints.
        """
        self.data_shape = data_shape
        self.discriminator = None
        self.generator = None
        self.adversarial = None

        self.define_gan()
        self.noisy_samples = NoiseMaker(generator=self.generator,
                                        shape=self.data_shape,
                                        noise_type='s&p')

        self.performance_output_path = performance_output_path
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(self.performance_output_path, exist_ok=True)

    def define_gan(self):
        """Create the generator, discriminator and combined adversarial model."""
        self.generator = build_generator(input_shape=self.data_shape)
        self.discriminator = build_discriminator(input_shape=self.data_shape)

        self.adversarial = build_adversarial(
            generator_model=self.generator,
            discriminator_model=self.discriminator)

    def train(self, dataset, epochs=100, batch_size=64):
        """Train for ``epochs`` epochs, saving a snapshot after each one.

        Args:
            dataset: Object exposing ``iter_samples(n)``, ``sample_number``
                and ``test_data``.
            epochs: Number of passes over ``dataset``.
            batch_size: Nominal batch size; each step draws half of it.
        """
        for epoch in range(epochs):
            # 1-based so the log agrees with the snapshot file numbering.
            print('Epochs: %3d/%d' % (epoch + 1, epochs))
            self.single_epoch(dataset, batch_size)
            self.performance(step=epoch, test_data=dataset.test_data)

    def single_epoch(self, dataset, batch_size):
        """Run one pass over ``dataset``, training discriminator then GAN.

        Args:
            dataset: Object exposing ``iter_samples(n)`` (yielding
                ``(x, _, y)`` triples) and ``sample_number``.
            batch_size: Nominal batch size; each step draws half of it.
        """
        # At least 1: batch_size == 1 would otherwise request empty batches.
        half_batch_size = max(1, batch_size // 2)
        trained_samples = 0

        for real_x, _, real_y in dataset.iter_samples(half_batch_size):
            fake_x, fake_y, _ = self.noisy_samples.denoise_samples(
                real_samples=real_x)
            # One discriminator batch: real samples followed by generated ones.
            batch_x = np.vstack([real_x, fake_x])
            batch_y = np.hstack([real_y, fake_y])
            discriminator_loss = self.discriminator.train_on_batch(
                batch_x, batch_y)

            noisy_input = self.noisy_samples.add_noise(real_x)
            gan_targets = np.ones(shape=(len(noisy_input), ))
            gan_loss = self.adversarial.train_on_batch(noisy_input,
                                                       gan_targets)

            # Count samples actually consumed; the final batch may be short.
            trained_samples = min(trained_samples + len(real_x),
                                  dataset.sample_number)
            print('     %5d/%d -> Discriminator Loss: %s, Gan Loss: %s' %
                  (trained_samples, dataset.sample_number,
                   self._format_loss(discriminator_loss),
                   self._format_loss(gan_loss)))

    @staticmethod
    def _format_loss(loss):
        """Format a ``train_on_batch`` result (scalar or list) for logging.

        Keras-style ``train_on_batch`` returns a list when the model tracks
        metrics or has several outputs, which plain '%f' cannot format.
        """
        return ', '.join('%f' % value for value in np.ravel(loss))

    def performance(self, step, test_data, samples_per_plot=50):
        """Save a comparison plot for a window of test samples and the generator.

        The plot concatenates, along axis 2, the test samples, the third
        output of ``denoise_samples`` and the generated samples. Each epoch
        uses the next window of ``samples_per_plot`` samples and wraps around
        once the test set is exhausted.

        Args:
            step: Zero-based epoch index; files are numbered ``step + 1``.
            test_data: Sliceable array of test samples.
            samples_per_plot: Size of the test-sample window per epoch.

        Raises:
            ValueError: If ``test_data`` is empty.
        """
        n_test = len(test_data)
        if n_test == 0:
            raise ValueError('test_data must contain at least one sample')
        # Without the modulo, epochs past the end of the test set would get
        # an empty window. For in-range steps this equals step * window.
        start = (step * samples_per_plot) % n_test
        sub_test_data = test_data[start:start + samples_per_plot]

        generated, _, noise = self.noisy_samples.denoise_samples(
            real_samples=sub_test_data)

        fig_file = os.path.join(self.performance_output_path,
                                'epoch-%04d_plot.png' % (step + 1))
        data_triplet = np.concatenate([sub_test_data, noise, generated],
                                      axis=2)
        plot_images(data_triplet, path=fig_file)

        model_file = os.path.join(self.performance_output_path,
                                  'model_%04d.h5' % (step + 1))
        self.generator.save(model_file)
        print('>Saved: %s and %s' % (fig_file, model_file))