Example #1
0
    def _init(self):
        """Run the parent ``_init`` and then fill in GRN1D auxiliary defaults.

        * ``self.aux_discriminators`` defaults (when ``None``) to
          ``[Uniform1D(low=-5, high=-1), Uniform1D(low=1, high=5)]``.
        * ``self.num_aux`` is set to the number of auxiliary distributions.
        * ``self.aux_coeffs`` defaults (when ``None``) to ``0.1`` per
          auxiliary distribution.

        NOTE(review): a user-supplied ``aux_coeffs`` is not checked against
        ``num_aux`` here -- confirm whether a length mismatch is caught later.
        """
        super(GRN1D, self)._init()

        if self.aux_discriminators is None:
            negative_side = Uniform1D(low=-5, high=-1)
            positive_side = Uniform1D(low=1, high=5)
            self.aux_discriminators = [negative_side, positive_side]

        # number of auxiliary distributions
        self.num_aux = len(self.aux_discriminators)

        if self.aux_coeffs is None:
            self.aux_coeffs = [0.1 for _ in range(self.num_aux)]
Example #2
0
def test_cgan_mnist(show_figure=False, block_figure_on_end=False):
    """Smoke-test conditional GAN (CGAN) training on MNIST.

    A deliberately small CGAN is fitted for one epoch on the MNIST training
    images together with their labels. Two displays are attached: one for the
    discriminator/generator losses and one for a 10x10 grid of generated
    samples.

    Args:
        show_figure: forwarded to ``Display(show=...)`` for both displays.
        block_figure_on_end: forwarded to ``Display(block_on_end=...)``.
    """
    print("========== Test CGAN on MNIST data ==========")

    np.random.seed(random_seed())

    # Only the training split is used by this test, so the test split is not
    # converted (it used to be cast/rescaled and then never read).
    (x_train, y_train), _ = demo.load_mnist()
    # `/ 0.5 - 1.` maps [0, 1] onto [-1, 1]; assumes load_mnist returns pixel
    # values in [0, 1] -- TODO confirm.
    x_train = x_train.astype(np.float32).reshape([-1, 28, 28, 1]) / 0.5 - 1.

    loss_display = Display(layout=(1, 1),
                           dpi='auto',
                           show=show_figure,
                           block_on_end=block_figure_on_end,
                           monitor=[
                               {
                                   'metrics': ['d_loss', 'g_loss'],
                                   'type': 'line',
                                   'labels': ["discriminator loss", "generator loss"],
                                   'title': "Losses",
                                   'xlabel': "epoch",
                                   'ylabel': "loss",
                               },
                           ])
    sample_display = Display(layout=(1, 1),
                             dpi='auto',
                             figsize=(10, 10),
                             freq=1,
                             show=show_figure,
                             block_on_end=block_figure_on_end,
                             monitor=[
                                 {
                                     'metrics': ['x_samples'],
                                     'title': "Generated data",
                                     'type': 'img',
                                     'num_samples': 100,
                                     'tile_shape': (10, 10),
                                 },
                             ])

    model = CGAN(
        model_name="CGAN_MNIST",
        num_z=10,  # set to 100 for a full run
        z_prior=Uniform1D(low=-1.0, high=1.0),
        img_size=(28, 28, 1),
        batch_size=64,  # 64 is also the value used for a full run
        num_conv_layers=3,  # set to 3 for a full run
        num_gen_feature_maps=2,  # set to 32 for a full run
        num_dis_feature_maps=2,  # set to 32 for a full run
        metrics=['d_loss', 'g_loss'],
        callbacks=[loss_display, sample_display],
        num_epochs=1,  # set to 100 for a full run
        random_state=random_seed(),
        verbose=1)

    # The conditional model is fitted on images *and* their labels.
    model.fit(x_train, y_train)
Example #3
0
def test_wgan_cifar10(show_figure=False, block_figure_on_end=False):
    """Smoke-test WGAN training on a small slice of CIFAR-10.

    Only the first 128 training images are used, and the network is kept tiny
    (see the "for a full run" notes) so the test stays fast. The displays are
    given file paths under ``<model_dir()>/male/WGAN/CIFAR10``: the loss
    figure as ``.png`` and ``.pdf``, the sample grid as ``.png``.

    Args:
        show_figure: forwarded to ``Display(show=...)``.
        block_figure_on_end: forwarded to ``Display(block_on_end=...)``.
    """
    print("========== Test WGAN on CIFAR10 data ==========")

    np.random.seed(random_seed())

    num_data = 128
    (x_train, y_train), (x_test, y_test) = demo.load_cifar10()
    # Keep a small subset, cast to float32, reshape to NHWC and rescale with
    # `/ 0.5 - 1.` (maps [0, 1] to [-1, 1] if inputs are in [0, 1] -- confirm).
    x_train = x_train[:num_data].astype(np.float32)
    x_train = x_train.reshape([-1, 32, 32, 3]) / 0.5 - 1.
    x_test = x_test.astype(np.float32).reshape([-1, 32, 32, 3]) / 0.5 - 1.

    out_dir = os.path.join(model_dir(), "male/WGAN/CIFAR10")

    loss_monitor = dict(metrics=['d_loss', 'g_loss'],
                        type='line',
                        labels=["discriminator loss", "generator loss"],
                        title="Losses",
                        xlabel="epoch",
                        ylabel="loss")
    loss_paths = [
        os.path.join(out_dir, "loss/loss_{epoch:04d}.png"),
        os.path.join(out_dir, "loss/loss_{epoch:04d}.pdf"),
    ]
    loss_display = Display(layout=(1, 1),
                           dpi='auto',
                           show=show_figure,
                           block_on_end=block_figure_on_end,
                           filepath=loss_paths,
                           monitor=[loss_monitor])

    sample_monitor = dict(metrics=['x_samples'],
                          title="Generated data",
                          type='img',
                          num_samples=100,
                          tile_shape=(10, 10))
    sample_path = os.path.join(out_dir, "samples/samples_{epoch:04d}.png")
    sample_display = Display(layout=(1, 1),
                             dpi='auto',
                             figsize=(10, 10),
                             freq=1,
                             show=show_figure,
                             block_on_end=block_figure_on_end,
                             filepath=sample_path,
                             monitor=[sample_monitor])

    model = WGAN(
        model_name="WGAN_CIFAR10",
        num_z=10,                  # full run: 100
        z_prior=Uniform1D(low=-1.0, high=1.0),
        img_size=(32, 32, 3),
        batch_size=16,             # full run: 64
        num_conv_layers=3,         # full run: 3
        num_gen_feature_maps=4,    # full run: 32
        num_dis_feature_maps=4,    # full run: 32
        metrics=['d_loss', 'g_loss'],
        callbacks=[loss_display, sample_display],
        num_epochs=4,              # full run: 100
        log_path=os.path.join(out_dir, "logs"),
        random_state=random_seed(),
        verbose=1,
    )
    model.fit(x_train)
Example #4
0
def test_wgan_gp_mnist(show_figure=False, block_figure_on_end=False):
    """Smoke-test WGAN-GP on MNIST with a uniform and a Gaussian z prior.

    Two small WGAN-GP models are fitted one after the other. One uses a
    ``Uniform1D(-1, 1)`` z prior and the other a ``Gaussian1D(0, 1)`` z prior.
    Each run gets its own freshly built ``Display`` callbacks and its own
    output directory under ``<model_dir()>/male/WGAN-GP/MNIST/<model_name>``.
    Previously both runs shared one set of callbacks, the same
    ``loss/loss_{epoch:04d}.*`` figure paths and the same ``logs`` directory.
    With equal ``num_epochs`` the second run overwrote the first run's loss
    figures.

    Args:
        show_figure: forwarded to ``Display(show=...)``.
        block_figure_on_end: forwarded to ``Display(block_on_end=...)``.
    """
    print("========== Test WGAN-GP on MNIST data ==========")

    np.random.seed(random_seed())

    # Only the training split is used by this test.
    (x_train, y_train), _ = demo.load_mnist()
    # `/ 0.5 - 1.` maps [0, 1] onto [-1, 1]; assumes load_mnist returns pixel
    # values in [0, 1] -- TODO confirm.
    x_train = x_train.astype(np.float32).reshape([-1, 28, 28, 1]) / 0.5 - 1.

    root_dir = os.path.join(model_dir(), "male/WGAN-GP/MNIST")

    def _make_callbacks(run_dir):
        """Build a fresh [loss, sample] Display pair writing under *run_dir*."""
        loss_display = Display(layout=(1, 1),
                               dpi='auto',
                               show=show_figure,
                               block_on_end=block_figure_on_end,
                               filepath=[os.path.join(run_dir, "loss/loss_{epoch:04d}.png"),
                                         os.path.join(run_dir, "loss/loss_{epoch:04d}.pdf")],
                               monitor=[{'metrics': ['d_loss', 'g_loss'],
                                         'type': 'line',
                                         'labels': ["discriminator loss", "generator loss"],
                                         'title': "Losses",
                                         'xlabel': "epoch",
                                         'ylabel': "loss",
                                         },
                                        ])
        sample_display = Display(layout=(1, 1),
                                 dpi='auto',
                                 figsize=(10, 10),
                                 freq=1,
                                 show=show_figure,
                                 block_on_end=block_figure_on_end,
                                 monitor=[{'metrics': ['x_samples'],
                                           'title': "Generated data",
                                           'type': 'img',
                                           'num_samples': 100,
                                           'tile_shape': (10, 10),
                                           },
                                          ])
        return [loss_display, sample_display]

    uniform_name = "WGAN_GP_MNIST_z_uniform"
    uniform_dir = os.path.join(root_dir, uniform_name)
    model = WGAN_GP(model_name=uniform_name,
                    num_z=10,  # set to 100 for a full run
                    z_prior=Uniform1D(low=-1.0, high=1.0),
                    img_size=(28, 28, 1),
                    batch_size=16,  # set to 64 for a full run
                    num_conv_layers=3,  # set to 3 for a full run
                    num_gen_feature_maps=4,  # set to 32 for a full run
                    num_dis_feature_maps=4,  # set to 32 for a full run
                    metrics=['d_loss', 'g_loss'],
                    callbacks=_make_callbacks(uniform_dir),
                    num_epochs=4,  # set to 100 for a full run
                    # summary_freq=1,  # uncomment this for a full run
                    random_state=random_seed(),
                    log_path=os.path.join(uniform_dir, "logs"),
                    verbose=1)

    model.fit(x_train)

    gaussian_name = "WGAN_GP_MNIST_z_Gaussian"
    gaussian_dir = os.path.join(root_dir, gaussian_name)
    model = WGAN_GP(model_name=gaussian_name,
                    num_z=10,  # set to 100 for a full run
                    z_prior=Gaussian1D(mu=0.0, sigma=1.0),
                    img_size=(28, 28, 1),
                    batch_size=32,  # set to 64 for a full run
                    num_conv_layers=3,  # set to 3 for a full run
                    num_gen_feature_maps=4,  # set to 32 for a full run
                    num_dis_feature_maps=4,  # set to 32 for a full run
                    metrics=['d_loss', 'g_loss'],
                    callbacks=_make_callbacks(gaussian_dir),
                    num_epochs=4,  # set to 100 for a full run
                    # summary_freq=1,  # uncomment this for a full run
                    random_state=random_seed(),
                    log_path=os.path.join(gaussian_dir, "logs"),
                    verbose=1)

    model.fit(x_train)
Example #5
0
def test_d2gan1d_gaussian1d(show_figure=False, block_figure_on_end=False):
    """Smoke-test D2GAN1D on 1-D data drawn from ``Gaussian1D(mu=4.0, sigma=0.5)``.

    The model is built with ``generator=Uniform1D(low=-8.0, high=8.0)``, a
    small hidden layer and only a few epochs. Three displays are attached:
    a 1x3 panel (losses, log-likelihood, discriminator scores d1x/d2x), a
    histogram and an averaged histogram (both with ``freq=10``).

    Args:
        show_figure: forwarded to ``Display(show=...)`` for every display.
        block_figure_on_end: forwarded to ``Display(block_on_end=...)``.
    """
    shared = dict(dpi='auto', show=show_figure, block_on_end=block_figure_on_end)

    losses_panel = dict(metrics=['d_loss', 'g_loss'],
                        type='line',
                        labels=["discriminator loss", "generator loss"],
                        title="Losses",
                        xlabel="epoch",
                        ylabel="loss")
    loglik_panel = dict(metrics=['loglik'],
                        type='line',
                        labels=["Log-likelihood"],
                        title="Evaluation",
                        xlabel="epoch",
                        ylabel="loglik")
    scores_panel = dict(metrics=['d1x', 'd2x'],
                        type='line',
                        labels=["d1x_mean", "d2x_mean"],
                        title="Discriminative scores",
                        xlabel="epoch",
                        ylabel="score")
    loss_display = Display(layout=(1, 3),
                           title='Loss',
                           monitor=[losses_panel, loglik_panel, scores_panel],
                           **shared)

    hist_panel = dict(metrics=['distribution'],
                      type='hist',
                      title="Histogram of D2GAN1D",
                      xlabel="Data values",
                      ylabel="Probability density")
    distribution_display = Display(layout=(1, 1),
                                   freq=10,
                                   title='Histogram',
                                   monitor=[hist_panel],
                                   **shared)

    avg_hist_panel = dict(metrics=['avg_distribution'],
                          type='hist',
                          title="Averaged Histogram of D2GAN1D",
                          xlabel="Data values",
                          ylabel="Probability density")
    avg_distribution_display = Display(layout=(1, 1),
                                       freq=10,
                                       title='Average Histogram',
                                       monitor=[avg_hist_panel],
                                       **shared)

    tracked = ['d1x', 'd2x', 'd1_loss', 'd2_loss', 'd_loss', 'g_loss', 'loglik']
    displays = [loss_display, distribution_display, avg_distribution_display]

    model = D2GAN1D(
        data=Gaussian1D(mu=4.0, sigma=0.5),
        generator=Uniform1D(low=-8.0, high=8.0),
        alpha=1.0,
        beta=1.0,
        num_z=10,          # full run: 100
        num_epochs=4,      # full run: 1000
        hidden_size=16,    # full run: 128
        batch_size=16,
        minibatch_discriminator=False,
        loglik_freq=1,
        generator_learning_rate=0.0001,
        discriminator_learning_rate=0.0001,
        metrics=tracked,
        callbacks=displays,
        random_state=random_seed(),
        verbose=1,
    )
    model.fit()
Example #6
0
def test_gan1d_gaussian1d(show_figure=False, block_figure_on_end=False):
    """Smoke-test GAN1D on 1-D data drawn from ``Gaussian1D(mu=4.0, sigma=0.5)``.

    The model is built with ``generator=Uniform1D(low=-8.0, high=8.0)`` and is
    kept very small (see the "full run" notes). Three displays are attached:
    a 2x1 panel (losses and log-likelihood), a histogram and an averaged
    histogram (both with ``freq=1``).

    Args:
        show_figure: forwarded to ``Display(show=...)`` for every display.
        block_figure_on_end: forwarded to ``Display(block_on_end=...)``.
    """
    banner = "========== Test GAN on 1D data generated from a Gaussian distribution =========="
    print(banner)

    losses_panel = dict(metrics=['d_loss', 'g_loss'],
                        type='line',
                        labels=["discriminator loss", "generator loss"],
                        title="Losses",
                        xlabel="epoch",
                        ylabel="loss")
    loglik_panel = dict(metrics=['loglik'],
                        type='line',
                        labels=["Log-likelihood"],
                        title="Evaluation",
                        xlabel="epoch",
                        ylabel="loglik")
    loss_display = Display(layout=(2, 1),
                           dpi='auto',
                           title='Loss',
                           show=show_figure,
                           block_on_end=block_figure_on_end,
                           monitor=[losses_panel, loglik_panel])

    hist_panel = dict(metrics=['distribution'],
                      type='hist',
                      title="Histogram of GAN1D",
                      xlabel="Data values",
                      ylabel="Probability density")
    distribution_display = Display(layout=(1, 1),
                                   dpi='auto',
                                   freq=1,
                                   title='Histogram',
                                   show=show_figure,
                                   block_on_end=block_figure_on_end,
                                   monitor=[hist_panel])

    avg_hist_panel = dict(metrics=['avg_distribution'],
                          type='hist',
                          title="Averaged Histogram of GAN1D",
                          xlabel="Data values",
                          ylabel="Probability density")
    avg_distribution_display = Display(layout=(1, 1),
                                       dpi='auto',
                                       freq=1,
                                       title='Average Histogram',
                                       show=show_figure,
                                       block_on_end=block_figure_on_end,
                                       monitor=[avg_hist_panel])

    displays = [loss_display, distribution_display, avg_distribution_display]

    model = GAN1D(
        data=Gaussian1D(mu=4.0, sigma=0.5),
        generator=Uniform1D(low=-8.0, high=8.0),
        num_z=8,           # full run: 100
        num_epochs=4,      # full run: 355
        hidden_size=4,     # full run: 128
        batch_size=8,      # full run: 128
        minibatch_discriminator=False,
        loglik_freq=1,
        generator_learning_rate=0.0001,
        discriminator_learning_rate=0.0001,
        metrics=['d_loss', 'g_loss', 'loglik'],
        callbacks=displays,
        random_state=random_seed(),
        verbose=1,
    )
    model.fit()
Example #7
0
def test_wgan_save_and_load(show_figure=False, block_figure_on_end=False):
    """Exercise WGAN ``save`` and ``TensorFlowModel.load_model`` on MNIST.

    A tiny WGAN is fitted for 2 epochs on the first 128 MNIST training images.
    The test then saves it under ``<model_dir()>/male/WGAN/MNIST/checkpoints/ckpt``
    and reloads it from the path returned by ``save``. It sets
    ``num_epochs = 4`` on the reloaded model and fits it again (presumably
    resuming from the saved epoch -- confirm against ``fit``).

    Args:
        show_figure: forwarded to ``Display(show=...)``.
        block_figure_on_end: forwarded to ``Display(block_on_end=...)``.
    """
    print("========== Test Save and Load functions of WGAN on MNIST data ==========")

    np.random.seed(random_seed())

    (x_train, y_train), (x_test, y_test) = demo.load_mnist()
    num_data = 128
    # Keep a small subset, cast to float32, reshape to NHWC and rescale with
    # `/ 0.5 - 1.` (maps [0, 1] to [-1, 1] if inputs are in [0, 1] -- confirm).
    x_train = x_train[:num_data].astype(np.float32)
    x_train = x_train.reshape([-1, 28, 28, 1]) / 0.5 - 1.
    x_test = x_test.astype(np.float32).reshape([-1, 28, 28, 1]) / 0.5 - 1.

    out_dir = os.path.join(model_dir(), "male/WGAN/MNIST")

    loss_monitor = dict(metrics=['d_loss', 'g_loss'],
                        type='line',
                        labels=["discriminator loss", "generator loss"],
                        title="Losses",
                        xlabel="epoch",
                        ylabel="loss")
    loss_display = Display(layout=(1, 1),
                           dpi='auto',
                           show=show_figure,
                           block_on_end=block_figure_on_end,
                           monitor=[loss_monitor])

    sample_monitor = dict(metrics=['x_samples'],
                          title="Generated data",
                          type='img',
                          num_samples=100,
                          tile_shape=(10, 10))
    sample_display = Display(layout=(1, 1),
                             dpi='auto',
                             figsize=(10, 10),
                             freq=1,
                             show=show_figure,
                             block_on_end=block_figure_on_end,
                             monitor=[sample_monitor])

    model = WGAN(
        model_name="WGAN_MNIST_SaveLoad",
        num_z=8,
        z_prior=Uniform1D(low=-1.0, high=1.0),
        img_size=(28, 28, 1),
        batch_size=16,
        num_conv_layers=3,
        num_gen_feature_maps=4,
        num_dis_feature_maps=4,
        metrics=['d_loss', 'g_loss'],
        callbacks=[loss_display, sample_display],
        num_epochs=2,
        log_path=os.path.join(out_dir, "logs"),
        random_state=random_seed(),
        verbose=1,
    )
    model.fit(x_train)

    print("Saving model...")
    ckpt_path = model.save(os.path.join(out_dir, "checkpoints/ckpt"))

    print("Reloading model...")
    restored = TensorFlowModel.load_model(ckpt_path)
    restored.num_epochs = 4
    restored.fit(x_train)

    print("Done!")
Example #8
0
def test_grn1d_gaussian1d(block_figure_on_end=False, show_figure=False):
    """Smoke-test GRN1D on 1-D data drawn from ``Gaussian1D(mu=4.0, sigma=0.5)``.

    The model is built with ``generator=Uniform1D(low=-8, high=8.0)`` and one
    auxiliary discriminator, ``Gaussian1D(mu=0.0, sigma=0.5)``, with
    coefficient 0.5. It is fitted for a few epochs. Three displays are
    attached: a 2x1 panel (losses incl. ``a_loss_1`` and log-likelihood), a
    histogram and an averaged histogram.

    Args:
        block_figure_on_end: forwarded to ``Display(block_on_end=...)``.
        show_figure: forwarded to ``Display(show=...)``. Added as a trailing
            keyword (default ``False``) so existing positional callers keep
            working and this test matches the other GAN tests' signature.
    """
    print(
        "========== Test GRN on 1D data generated from a Gaussian distribution =========="
    )

    loss_display = Display(
        layout=(2, 1),
        show=show_figure,
        block_on_end=block_figure_on_end,
        monitor=[
            {
                'metrics': ['d_loss', 'g_loss', 'a_loss_1'],
                'type': 'line',
                'labels': ["discriminator loss", "generator loss", "auxiliary_1 loss"],
                'title': "Losses",
                'xlabel': "epoch",
                'ylabel': "loss",
            },
            {
                'metrics': ['loglik'],
                'type': 'line',
                'labels': ["Log-likelihood"],
                'title': "Evaluation",
                'xlabel': "epoch",
                'ylabel': "loglik",
            },
        ])
    distribution_display = Display(layout=(1, 1),
                                   freq=1,
                                   show=show_figure,
                                   block_on_end=block_figure_on_end,
                                   monitor=[
                                       {
                                           'metrics': ['distribution'],
                                           'type': 'hist',
                                           'title': "Histograms of GRN1D",
                                           'xlabel': "Data values",
                                           'ylabel': "Probability density",
                                       },
                                   ])

    avg_distribution_display = Display(layout=(1, 1),
                                       freq=1,
                                       show=show_figure,
                                       block_on_end=block_figure_on_end,
                                       monitor=[
                                           {
                                               'metrics': ['avg_distribution'],
                                               'type': 'hist',
                                               'title': "Average Histograms of GRN1D",
                                               'xlabel': "Data values",
                                               'ylabel': "Probability density",
                                           },
                                       ])

    model = GRN1D(data=Gaussian1D(mu=4.0, sigma=0.5),
                  generator=Uniform1D(low=-8, high=8.0),
                  # One auxiliary distribution, so its loss is tracked as
                  # 'a_loss_1' in the metrics/display below.
                  aux_discriminators=[Gaussian1D(mu=0.0, sigma=0.5)],
                  aux_coeffs=[0.5],
                  num_epochs=4,
                  hidden_size=20,
                  batch_size=12,
                  aux_batch_size=10,
                  discriminator_learning_rate=0.001,
                  generator_learning_rate=0.001,
                  aux_learning_rate=0.001,
                  loglik_freq=1,
                  metrics=['d_loss', 'g_loss', 'a_loss_1', 'loglik'],
                  callbacks=[
                      loss_display, distribution_display,
                      avg_distribution_display
                  ],
                  random_state=random_seed(),
                  verbose=1)
    model.fit()