Example #1
0
def test_inception_score():
    """Print the Inception Score of the first 100 CIFAR-10 training images.

    The images are cast to float32, laid out as (N, 32, 32, 3) and multiplied
    by 255 before scoring. The score's elements 0 and 1 are printed as
    ``a+-b``.
    """
    print("========== Test Inception score ==========")

    (images, _labels), (_, _) = demo.load_cifar10()
    subset = images[:100].astype(np.float32)
    subset = subset.reshape([-1, 32, 32, 3]) * 255.0
    result = InceptionScore().score(subset)
    print("Inception score: {:.4f}+-{:.4f}".format(result[0], result[1]))
Example #2
0
def test_cganv1_cifar10(show_figure=False, block_figure_on_end=False):
    """Smoke-test CGANv1 on the full CIFAR-10 training split for one epoch.

    A loss-curve ``Display`` and a generated-sample ``Display`` are attached
    as callbacks. The class labels ``y_train`` are passed to ``fit`` together
    with the images.

    Args:
        show_figure: forwarded to both displays as ``show``.
        block_figure_on_end: forwarded to both displays as ``block_on_end``.
    """
    print("========== Test CGANv1 on CIFAR10 data ==========")

    np.random.seed(random_seed())

    def _prepare(images):
        # float32 NHWC, then x / 0.5 - 1 (maps [0, 1] to [-1, 1] if the
        # loader yields values in [0, 1] -- assumed, not checked here).
        return images.astype(np.float32).reshape([-1, 32, 32, 3]) / 0.5 - 1.

    (x_train, y_train), (x_test, y_test) = demo.load_cifar10()
    x_train, x_test = _prepare(x_train), _prepare(x_test)

    loss_monitor = {
        'metrics': ['d_loss', 'g_loss'],
        'type': 'line',
        'labels': ["discriminator loss", "generator loss"],
        'title': "CGANv1 - Losses",
        'xlabel': "epoch",
        'ylabel': "loss",
    }
    sample_monitor = {
        'metrics': ['x_samples'],
        'title': "CGANv1 - Generated data",
        'type': 'img',
        'num_samples': 100,
        'tile_shape': (10, 10),
    }

    loss_display = Display(layout=(1, 1), dpi='auto', show=show_figure,
                           block_on_end=block_figure_on_end,
                           monitor=[loss_monitor])
    sample_display = Display(layout=(1, 1), dpi='auto', figsize=(10, 10),
                             freq=1, show=show_figure,
                             block_on_end=block_figure_on_end,
                             monitor=[sample_monitor])

    model = CGANv1(
        model_name="CGANv1_CIFAR10",
        num_z=10,  # set to 100 for a full run
        img_size=(32, 32, 3),
        batch_size=64,  # 64 is also the full-run value
        num_conv_layers=3,  # 3 is also the full-run value
        num_gen_feature_maps=2,  # set to 32 for a full run
        num_dis_feature_maps=2,  # set to 32 for a full run
        metrics=['d_loss', 'g_loss'],
        callbacks=[loss_display, sample_display],
        num_epochs=1,  # set to 100 for a full run
        random_state=random_seed(),
        verbose=1)

    model.fit(x_train, y_train)
Example #3
0
def test_frechet_inception_distance():
    """Print the FID of 100 CIFAR-10 training images against two references.

    Case #0 uses ``FID(data="cifar10")``. Case #1 uses the first 200 training
    images (same x * 255 scaling) as the reference set.
    """
    print("========== Test Frechet Inception Distance (FID) ==========")

    (x_train, y_train), (_, _) = demo.load_cifar10()
    x_train = x_train.astype(np.float32).reshape([-1, 32, 32, 3]) * 255.0
    queries = x_train[:100]
    references = ["cifar10", x_train[:200]]
    for case_id, reference in enumerate(references):
        fid_value = FID(data=reference).score(queries)
        print("Case #{}: FID = {:.4f}".format(case_id, fid_value))
Example #4
0
def test_wgan_cifar10(show_figure=False, block_figure_on_end=False):
    """Smoke-test WGAN with a Uniform(-1, 1) z prior on 128 CIFAR-10 images.

    Loss curves (PNG and PDF) and sample grids are written below
    ``model_dir()/male/WGAN/CIFAR10``. The model's ``log_path`` is ``logs/``
    inside that directory.

    Args:
        show_figure: forwarded to both displays as ``show``.
        block_figure_on_end: forwarded to both displays as ``block_on_end``.
    """
    print("========== Test WGAN on CIFAR10 data ==========")

    np.random.seed(random_seed())

    num_data = 128
    (x_train, y_train), (x_test, y_test) = demo.load_cifar10()
    x_train = x_train[:num_data].astype(np.float32)
    x_train = x_train.reshape([-1, 32, 32, 3]) / 0.5 - 1.
    x_test = x_test.astype(np.float32).reshape([-1, 32, 32, 3]) / 0.5 - 1.

    root_dir = os.path.join(model_dir(), "male/WGAN/CIFAR10")
    loss_files = [
        os.path.join(root_dir, "loss/loss_{epoch:04d}.png"),
        os.path.join(root_dir, "loss/loss_{epoch:04d}.pdf"),
    ]
    sample_file = os.path.join(root_dir, "samples/samples_{epoch:04d}.png")

    loss_display = Display(
        layout=(1, 1),
        dpi='auto',
        show=show_figure,
        block_on_end=block_figure_on_end,
        filepath=loss_files,
        monitor=[{
            'metrics': ['d_loss', 'g_loss'],
            'type': 'line',
            'labels': ["discriminator loss", "generator loss"],
            'title': "Losses",
            'xlabel': "epoch",
            'ylabel': "loss",
        }])
    sample_display = Display(
        layout=(1, 1),
        dpi='auto',
        figsize=(10, 10),
        freq=1,
        show=show_figure,
        block_on_end=block_figure_on_end,
        filepath=sample_file,
        monitor=[{
            'metrics': ['x_samples'],
            'title': "Generated data",
            'type': 'img',
            'num_samples': 100,
            'tile_shape': (10, 10),
        }])

    model = WGAN(
        model_name="WGAN_CIFAR10",
        num_z=10,  # set to 100 for a full run
        z_prior=Uniform1D(low=-1.0, high=1.0),
        img_size=(32, 32, 3),
        batch_size=16,  # set to 64 for a full run
        num_conv_layers=3,  # 3 is also the full-run value
        num_gen_feature_maps=4,  # set to 32 for a full run
        num_dis_feature_maps=4,  # set to 32 for a full run
        metrics=['d_loss', 'g_loss'],
        callbacks=[loss_display, sample_display],
        num_epochs=4,  # set to 100 for a full run
        log_path=os.path.join(root_dir, "logs"),
        random_state=random_seed(),
        verbose=1)

    model.fit(x_train)
Example #5
0
def test_inception_metric_list():
    """Score 100 CIFAR-10 images with several Inception-based metrics at once.

    The ``InceptionMetricList`` holds an ``InceptionScore`` and two ``FID``
    metrics: one with ``data="cifar10"`` and one using the first 200 training
    images as reference. Tuple results are printed as ``a+-b`` and scalar
    results as a single value.
    """
    # The header names what this test actually exercises: the IS + FID list,
    # not FID alone.
    print("========== Test Inception Metric List (IS + FID) ==========")

    (x_train, y_train), (_, _) = demo.load_cifar10()
    x_train = x_train.astype(np.float32).reshape([-1, 32, 32, 3]) * 255.0
    metric_list = InceptionMetricList(
        [InceptionScore(),
         FID(data="cifar10"),
         FID(data=x_train[:200])])
    scores = metric_list.score(x_train[:100])
    for i, s in enumerate(scores):
        if isinstance(s, tuple):
            print("Case #{}: score = {:.4f}+-{:.4f}".format(i, s[0], s[1]))
        else:
            print("Case #{}: score = {:.4f}".format(i, s))
Example #6
0
def test_image_saver_callback():
    """Run two ``ImageSaver`` callbacks while a softmax GLM trains on MNIST.

    One saver tiles the first 100 MNIST training images and the other the
    first 100 CIFAR-10 training images. Both use ``freq=1`` and an
    ``{epoch:04d}`` file pattern under
    ``model_dir()/male/callbacks/imagesaver/``.
    """
    np.random.seed(random_seed())

    (x_train, y_train), (_, _) = demo.load_mnist()
    (cifar10_train, _), (_, _) = demo.load_cifar10()

    def _saver(relative_path, img_size, images):
        # Every saver shares freq, metric name and tile layout.
        return ImageSaver(
            freq=1,
            filepath=os.path.join(model_dir(), relative_path),
            monitor={
                'metrics': 'x_data',
                'img_size': img_size,
                'tile_shape': (10, 10),
                'images': images,
            })

    savers = [
        _saver("male/callbacks/imagesaver/mnist/mnist_{epoch:04d}.png",
               (28, 28, 1),
               x_train[:100].reshape([-1, 28, 28, 1])),
        _saver("male/callbacks/imagesaver/cifar10/cifar10_{epoch:04d}.png",
               (32, 32, 3),
               cifar10_train[:100].reshape([-1, 32, 32, 3])),
    ]

    clf = GLM(model_name="imagesaver_callback",
              link='softmax',
              loss='softmax',
              optimizer=SGD(learning_rate=0.001),
              num_epochs=4,
              batch_size=100,
              task='classification',
              callbacks=savers,
              random_state=random_seed(),
              verbose=1)
    clf.fit(x_train, y_train)
Example #7
0
def test_dcgan_image_saver():
    """Train tiny DCGANs on MNIST, then CIFAR-10, saving sample tiles.

    Each run reseeds NumPy, keeps the first 128 training images and rescales
    them with ``x / 0.5 - 1``. An ``ImageSaver`` writes per-epoch sample grids
    under ``model_dir()/male/DCGAN/imagesaver/<dataset>/``.
    """
    print("========== Test DCGAN with Image Saver ==========")

    num_data = 128
    # (loader, image shape, dataset tag used in paths, model name)
    runs = (
        (demo.load_mnist, (28, 28, 1), "mnist", "DCGAN_MNIST"),
        (demo.load_cifar10, (32, 32, 3), "cifar10", "DCGAN_CIFAR10"),
    )

    for loader, img_size, tag, model_name in runs:
        np.random.seed(random_seed())

        (x_train, y_train), (x_test, y_test) = loader()
        shape = [-1] + list(img_size)
        x_train = x_train[:num_data].astype(np.float32).reshape(shape) / 0.5 - 1.
        x_test = x_test.astype(np.float32).reshape(shape) / 0.5 - 1.

        root_dir = os.path.join(model_dir(), "male/DCGAN/imagesaver/" + tag)
        imgsaver = ImageSaver(
            freq=1,
            filepath=os.path.join(root_dir, tag + "_{epoch:04d}.png"),
            monitor={
                'metrics': 'x_samples',
                'num_samples': 100,
                'tile_shape': (10, 10),
            })

        model = DCGAN(
            model_name=model_name,
            num_z=10,  # set to 100 for a full run
            img_size=img_size,
            batch_size=16,  # set to 64 for a full run
            num_conv_layers=3,  # 3 is also the full-run value
            num_gen_feature_maps=4,  # set to 32 for a full run
            num_dis_feature_maps=4,  # set to 32 for a full run
            metrics=['d_loss', 'g_loss'],
            callbacks=[imgsaver],
            num_epochs=4,  # set to 100 for a full run
            random_state=random_seed(),
            log_path=os.path.join(root_dir, "logs"),
            verbose=1)

        model.fit(x_train)
Example #8
0
def test_dcgan_cifar10_inception_metric(show_figure=False,
                                        block_figure_on_end=False):
    """Train DCGAN on 128 CIFAR-10 images while tracking IS and FID.

    Two ``ModelCheckpoint`` callbacks keep the best-Inception-Score model
    (``mode='max'``) and the best-FID model (``mode='min'``). After training,
    the best-FID checkpoint is reloaded, given a fresh ``InceptionMetricList``
    and ``num_epochs=6``, and ``fit`` is called on it again.

    Args:
        show_figure: forwarded to every display as ``show``.
        block_figure_on_end: forwarded to every display as ``block_on_end``.
    """
    print(
        "========== Test DCGAN with Inception Score and FID on CIFAR10 data =========="
    )

    np.random.seed(random_seed())

    num_data = 128
    (x_train, y_train), (x_test, y_test) = demo.load_cifar10()
    x_train = x_train[:num_data].astype(np.float32)
    x_train = x_train.reshape([-1, 32, 32, 3]) / 0.5 - 1.
    x_test = x_test.astype(np.float32).reshape([-1, 32, 32, 3]) / 0.5 - 1.

    root_dir = os.path.join(model_dir(), "male/DCGAN/cifar10")
    best_is_path = os.path.join(root_dir, "checkpoints_is/the_best_is.ckpt")
    best_fid_path = os.path.join(root_dir, "checkpoints_fid/the_best_fid.ckpt")

    checkpoints_is = ModelCheckpoint(best_is_path,
                                     mode='max',
                                     monitor='inception_score',
                                     verbose=1,
                                     save_best_only=True)
    checkpoints_fid = ModelCheckpoint(best_fid_path,
                                      mode='min',
                                      monitor='FID',
                                      verbose=1,
                                      save_best_only=True)

    def _line_display(paths, monitor):
        # All line-plot displays share layout, dpi and show/block settings.
        return Display(layout=(1, 1),
                       dpi='auto',
                       show=show_figure,
                       block_on_end=block_figure_on_end,
                       filepath=paths,
                       monitor=[monitor])

    loss_display = _line_display(
        [os.path.join(root_dir, "loss/loss_{epoch:04d}.png"),
         os.path.join(root_dir, "loss/loss_{epoch:04d}.pdf")],
        {
            'metrics': ['d_loss', 'g_loss'],
            'type': 'line',
            'labels': ["discriminator loss", "generator loss"],
            'title': "Losses",
            'xlabel': "epoch",
            'ylabel': "loss",
        })
    inception_score_display = _line_display(
        [os.path.join(root_dir, "inception_score/inception_score_{epoch:04d}.png"),
         os.path.join(root_dir, "inception_score/inception_score_{epoch:04d}.pdf")],
        {
            'metrics': ['inception_score'],
            'type': 'line',
            'labels': ["Inception Score"],
            'title': "Scores",
            'xlabel': "epoch",
            'ylabel': "score",
        })
    fid_display = _line_display(
        [os.path.join(root_dir, "FID/FID_{epoch:04d}.png"),
         os.path.join(root_dir, "FID/FID_{epoch:04d}.pdf")],
        {
            'metrics': ['FID'],
            'type': 'line',
            'labels': ["FID"],
            'title': "Scores",
            'xlabel': "epoch",
            'ylabel': "score",
        })
    sample_display = Display(
        layout=(1, 1),
        dpi='auto',
        figsize=(10, 10),
        freq=1,
        show=show_figure,
        block_on_end=block_figure_on_end,
        filepath=os.path.join(root_dir, "samples/samples_{epoch:04d}.png"),
        monitor=[{
            'metrics': ['x_samples'],
            'title': "Generated data",
            'type': 'img',
            'num_samples': 100,
            'tile_shape': (10, 10),
        }])

    model = DCGAN(
        model_name="DCGAN_CIFAR10",
        num_z=10,  # set to 100 for a full run
        img_size=(32, 32, 3),
        batch_size=16,  # set to 64 for a full run
        num_conv_layers=3,  # 3 is also the full-run value
        num_gen_feature_maps=4,  # set to 32 for a full run
        num_dis_feature_maps=4,  # set to 32 for a full run
        metrics=['d_loss', 'g_loss', 'inception_score',
                 'inception_score_std', 'FID'],
        callbacks=[loss_display, inception_score_display, fid_display,
                   sample_display, checkpoints_is, checkpoints_fid],
        num_epochs=4,  # set to 100 for a full run
        inception_metrics=[InceptionScore(), FID(data="cifar10")],
        inception_metrics_freq=1,
        # summary_freq=1,  # uncomment this for a full run
        log_path=os.path.join(root_dir, "logs"),
        random_state=random_seed(),
        verbose=1)

    model.fit(x_train)

    print("Reloading the latest model at: {}".format(best_fid_path))
    reloaded = TensorFlowModel.load_model(best_fid_path)
    reloaded.inception_metrics = InceptionMetricList(
        [InceptionScore(), FID(data="cifar10")])
    reloaded.num_epochs = 6
    reloaded.fit(x_train)
    print("Done!")
Example #9
0
def test_wgan_gp_cifar10_fid(show_figure=False, block_figure_on_end=False):
    """Train WGAN-GP on 128 CIFAR-10 images while tracking two FID variants.

    ``FID`` is computed against ``data="cifar10"``. ``FID_100points`` is
    computed against the first 100 training images. The best-FID checkpoint,
    the loss and FID plots, and the sample grids are written under
    ``model_dir()/male/WGAN-GP/CIFAR10``.

    Args:
        show_figure: forwarded to every display as ``show``.
        block_figure_on_end: forwarded to every display as ``block_on_end``.
    """
    print("========== Test WGAN-GP with FID on CIFAR10 data ==========")

    np.random.seed(random_seed())

    num_data = 128
    (x_train, y_train), (x_test, y_test) = demo.load_cifar10()
    x_train = x_train[:num_data].astype(np.float32).reshape([-1, 32, 32, 3]) / 0.5 - 1.
    x_test = x_test.astype(np.float32).reshape([-1, 32, 32, 3]) / 0.5 - 1.

    # x_train holds x / 0.5 - 1, so (x_train + 1) * 127.5 == x * 255.0. That
    # is the same "* 255.0" scaling used when FID is given CIFAR-10 arrays
    # directly. Passing x_train itself would give FID a reference set on a
    # different intensity scale than the "cifar10" reference.
    # NOTE(review): assumes the model feeds its generated samples to FID on
    # that same x * 255 scale -- confirm in the inception-metric code.
    fid_reference = (x_train[:100] + 1.0) * 127.5

    root_dir = os.path.join(model_dir(), "male/WGAN-GP/CIFAR10")
    checkpoints = ModelCheckpoint(
        os.path.join(root_dir, "checkpoints/{epoch:04d}_{FID:.6f}.ckpt"),
        mode='min',
        monitor='FID',
        verbose=1,
        save_best_only=True)
    loss_display = Display(layout=(1, 1),
                           dpi='auto',
                           show=show_figure,
                           block_on_end=block_figure_on_end,
                           filepath=[os.path.join(root_dir, "loss/loss_{epoch:04d}.png"),
                                     os.path.join(root_dir, "loss/loss_{epoch:04d}.pdf")],
                           monitor=[{'metrics': ['d_loss', 'g_loss'],
                                     'type': 'line',
                                     'labels': ["discriminator loss", "generator loss"],
                                     'title': "Losses",
                                     'xlabel': "epoch",
                                     'ylabel': "loss",
                                     },
                                    ])
    fid_display = Display(layout=(1, 1),
                          dpi='auto',
                          show=show_figure,
                          block_on_end=block_figure_on_end,
                          filepath=[os.path.join(root_dir, "FID/FID_{epoch:04d}.png"),
                                    os.path.join(root_dir, "FID/FID_{epoch:04d}.pdf")],
                          monitor=[{'metrics': ['FID'],
                                    'type': 'line',
                                    'labels': ["FID"],
                                    'title': "Scores",
                                    'xlabel': "epoch",
                                    'ylabel': "score",
                                    },
                                   ],
                          )
    sample_display = Display(layout=(1, 1),
                             dpi='auto',
                             figsize=(10, 10),
                             freq=1,
                             show=show_figure,
                             block_on_end=block_figure_on_end,
                             filepath=os.path.join(root_dir, "samples/samples_{epoch:04d}.png"),
                             monitor=[{'metrics': ['x_samples'],
                                       'title': "Generated data",
                                       'type': 'img',
                                       'num_samples': 100,
                                       'tile_shape': (10, 10),
                                       },
                                      ])

    model = WGAN_GP(model_name="WGAN_GP_CIFAR10",
                    num_z=10,  # set to 100 for a full run
                    img_size=(32, 32, 3),
                    batch_size=16,  # set to 64 for a full run
                    num_conv_layers=3,  # 3 is also the full-run value
                    num_gen_feature_maps=4,  # set to 32 for a full run
                    num_dis_feature_maps=4,  # set to 32 for a full run
                    metrics=['d_loss', 'g_loss', 'FID', 'FID_100points'],
                    callbacks=[loss_display, fid_display,
                               sample_display, checkpoints],
                    num_epochs=4,  # set to 100 for a full run
                    inception_metrics=[FID(data="cifar10"),
                                       FID(name="FID_100points", data=fid_reference)],
                    inception_metrics_freq=1,
                    # summary_freq=1,  # uncomment this for a full run
                    log_path=os.path.join(root_dir, "logs"),
                    random_state=random_seed(),
                    verbose=1)

    model.fit(x_train)
Example #10
0
def test_gank_image_saver():
    """Train GANK (logit loss) on MNIST, then CIFAR-10, saving sample tiles.

    Each run reseeds NumPy, uses the dataset's full training split rescaled
    with ``x / 0.5 - 1``, and writes per-epoch sample grids under
    ``model_dir()/male/GANK/imagesaver/<dataset>/``.
    """
    print("========== Test GANK-Logit with Image Saver ==========")

    # (loader, image shape, dataset tag used in paths, model name).
    # Suggested full-run epochs: 100 for MNIST, 500 for CIFAR-10.
    runs = (
        (demo.load_mnist, (28, 28, 1), "mnist", "GANK_MNIST"),
        (demo.load_cifar10, (32, 32, 3), "cifar10", "GANK_CIFAR10"),
    )

    for loader, img_size, tag, model_name in runs:
        np.random.seed(random_seed())

        (x_train, y_train), (x_test, y_test) = loader()
        shape = [-1] + list(img_size)
        x_train = x_train.astype(np.float32).reshape(shape) / 0.5 - 1.
        x_test = x_test.astype(np.float32).reshape(shape) / 0.5 - 1.

        imgsaver = ImageSaver(
            freq=1,
            filepath=os.path.join(
                model_dir(),
                "male/GANK/imagesaver/" + tag + "/" + tag + "_{epoch:04d}.png"),
            monitor={
                'metrics': 'x_samples',
                'num_samples': 100,
                'tile_shape': (10, 10),
            })

        model = GANK(
            model_name=model_name,
            num_random_features=50,  # set to 1000 for a full run
            gamma_init=0.01,
            loss='logit',
            num_z=10,  # set to 100 for a full run
            img_size=img_size,
            batch_size=32,  # set to 64 for a full run
            num_conv_layers=3,  # 3 is also the full-run value
            num_gen_feature_maps=4,  # set to 32 for a full run
            num_dis_feature_maps=4,  # set to 32 for a full run
            metrics=['d_loss', 'g_loss'],
            callbacks=[imgsaver],
            num_epochs=4,
            random_state=random_seed(),
            verbose=1)
        model.fit(x_train)
Example #11
0
def test_gank_logit_cifar10_inception_score(show_figure=False,
                                            block_figure_on_end=False):
    """Train GANK-Logit on CIFAR-10, checkpointing the best Inception Score.

    Args:
        show_figure: when true, the loss, Inception-Score and sample displays
            are attached as callbacks (each with ``show=show_figure``).
            Otherwise only the checkpoint callback is used.
        block_figure_on_end: forwarded to every display as ``block_on_end``.
    """
    print(
        "========== Test GANK-Logit with Inception Score on CIFAR10 data =========="
    )

    np.random.seed(random_seed())

    (x_train, y_train), (x_test, y_test) = demo.load_cifar10()
    x_train = x_train.astype(np.float32).reshape([-1, 32, 32, 3]) / 0.5 - 1.
    x_test = x_test.astype(np.float32).reshape([-1, 32, 32, 3]) / 0.5 - 1.

    # One base directory for every artifact of this run.
    root_dir = os.path.join(model_dir(), "male/GANK/Logit/cifar10")

    checkpoint = ModelCheckpoint(
        os.path.join(root_dir,
                     "checkpoints/{epoch:04d}_{inception_score:.6f}.ckpt"),
        mode='max',
        monitor='inception_score',
        verbose=0,
        save_best_only=True)
    loss_display = Display(layout=(1, 1),
                           dpi='auto',
                           show=show_figure,
                           block_on_end=block_figure_on_end,
                           filepath=[
                               os.path.join(root_dir,
                                            "loss/loss_{epoch:04d}.png"),
                               os.path.join(root_dir,
                                            "loss/loss_{epoch:04d}.pdf")
                           ],
                           monitor=[
                               {
                                   'metrics': ['d_loss', 'g_loss'],
                                   'type': 'line',
                                   'labels': ["discriminator loss", "generator loss"],
                                   'title': "Losses",
                                   'xlabel': "epoch",
                                   'ylabel': "loss",
                               },
                           ])
    inception_score_display = Display(
        layout=(1, 1),
        dpi='auto',
        show=show_figure,
        block_on_end=block_figure_on_end,
        filepath=[
            os.path.join(root_dir,
                         "inception_score/inception_score_{epoch:04d}.png"),
            os.path.join(root_dir,
                         "inception_score/inception_score_{epoch:04d}.pdf")
        ],
        monitor=[
            {
                'metrics': ['inception_score'],
                'type': 'line',
                'labels': ["Inception Score"],
                'title': "Scores",
                'xlabel': "epoch",
                'ylabel': "score",
            },
        ],
    )
    sample_display = Display(layout=(1, 1),
                             dpi='auto',
                             figsize=(10, 10),
                             freq=1,
                             show=show_figure,
                             block_on_end=block_figure_on_end,
                             filepath=os.path.join(
                                 root_dir, "samples/samples_{epoch:04d}.png"),
                             monitor=[
                                 {
                                     'metrics': ['x_samples'],
                                     'title': "Generated data",
                                     'type': 'img',
                                     'num_samples': 100,
                                     'tile_shape': (10, 10),
                                 },
                             ])

    # Without this, show_figure / block_figure_on_end would be accepted but
    # ignored, because the displays were never handed to the model. They are
    # attached only on request, so the default run stays checkpoint-only.
    if show_figure:
        callbacks = [loss_display, inception_score_display, sample_display,
                     checkpoint]
    else:
        callbacks = [checkpoint]

    model = GANK(
        model_name="GANK-Logit_CIFAR10",
        num_random_features=50,  # set 1000 for a full run
        gamma_init=0.01,
        loss='logit',
        num_z=10,  # set to 100 for a full run
        img_size=(32, 32, 3),
        batch_size=32,  # set to 64 for a full run
        num_conv_layers=3,  # 3 is also the full-run value
        num_gen_feature_maps=4,  # set to 32 for a full run
        num_dis_feature_maps=4,  # set to 32 for a full run
        metrics=['d_loss', 'g_loss', 'inception_score'],
        callbacks=callbacks,
        num_epochs=4,  # set to 500 for a full run
        inception_score_freq=1,
        random_state=random_seed(),
        verbose=1)

    model.fit(x_train)
Example #12
0
def test_wgan_gp_resnet_cifar10(show_figure=False, block_figure_on_end=False):
    """Train WGAN-GP-ResNet briefly, save it, reload it and fit it again.

    The first ``fit`` uses ``num_epochs=2``. The model is saved under
    ``checkpoints/`` in ``model_dir()/male/WGAN-GP-ResNet/CIFAR10``. The
    reloaded copy gets ``num_epochs=4`` before its own ``fit`` call.

    Args:
        show_figure: forwarded to both displays as ``show``.
        block_figure_on_end: forwarded to both displays as ``block_on_end``.
    """
    print("========== Test WGAN-GP-ResNet on CIFAR10 data ==========")

    np.random.seed(random_seed())

    num_data = 128
    (x_train, y_train), (x_test, y_test) = demo.load_cifar10()
    x_train = x_train[:num_data].astype(np.float32)
    x_train = x_train.reshape([-1, 32, 32, 3]) / 0.5 - 1.
    x_test = x_test.astype(np.float32).reshape([-1, 32, 32, 3]) / 0.5 - 1.

    root_dir = os.path.join(model_dir(), "male/WGAN-GP-ResNet/CIFAR10")

    # Settings shared by both displays.
    display_kwargs = dict(layout=(1, 1),
                          dpi='auto',
                          show=show_figure,
                          block_on_end=block_figure_on_end)
    loss_display = Display(
        filepath=[os.path.join(root_dir, "loss/loss_{epoch:04d}.png"),
                  os.path.join(root_dir, "loss/loss_{epoch:04d}.pdf")],
        monitor=[{
            'metrics': ['d_loss', 'g_loss'],
            'type': 'line',
            'labels': ["discriminator loss", "generator loss"],
            'title': "Losses",
            'xlabel': "epoch",
            'ylabel': "loss",
        }],
        **display_kwargs)
    sample_display = Display(
        figsize=(10, 10),
        freq=1,
        monitor=[{
            'metrics': ['x_samples'],
            'title': "Generated data",
            'type': 'img',
            'num_samples': 100,
            'tile_shape': (10, 10),
        }],
        **display_kwargs)

    model = WGAN_GP_ResNet(
        model_name="WGAN_GP_ResNet_CIFAR10",
        num_z=10,  # set to 100 for a full run
        img_size=(32, 32, 3),
        batch_size=16,  # set to 64 for a full run
        g_blocks=('up', 'up', 'up'),
        d_blocks=('down', 'down', None, None),
        num_gen_feature_maps=4,  # set to 32 for a full run
        num_dis_feature_maps=4,  # set to 32 for a full run
        metrics=['d_loss', 'g_loss'],
        callbacks=[loss_display, sample_display],
        num_epochs=2,  # set to 100 for a full run
        random_state=random_seed(),
        log_path=os.path.join(root_dir, "logs"),
        verbose=1)

    model.fit(x_train)

    print("Saving model...")
    checkpoint_path = model.save(os.path.join(root_dir, "checkpoints/ckpt"))
    print("Reloading model...")
    restored = TensorFlowModel.load_model(checkpoint_path)
    restored.num_epochs = 4
    restored.fit(x_train)
    print("Done!")
Example #13
0
def test_wgan_gp_resnet_cifar10_inception_metric(show_figure=False,
                                                 block_figure_on_end=False):
    """Train WGAN-GP-ResNet on a CIFAR10 subset while tracking IS and FID.

    The model is fitted with Inception Score and FID computed every epoch
    (``inception_metrics_freq=1``). Two ``ModelCheckpoint`` callbacks monitor
    ``inception_score`` (mode 'max') and ``FID`` (mode 'min'). After training,
    the best-FID checkpoint is reloaded, its inception metrics are re-attached,
    and fitting is resumed on the reloaded model.

    Parameters
    ----------
    show_figure : bool
        Forwarded as ``show`` to every ``Display`` callback.
    block_figure_on_end : bool
        Forwarded as ``block_on_end`` to every ``Display`` callback.
    """
    print(
        "========== Test WGAN-GP-ResNet with Inception Score and FID on CIFAR10 data =========="
    )

    np.random.seed(random_seed())

    num_data = 128
    # Only a small training subset is used; labels and the test split are
    # not needed, so they are not converted.
    (x_train, _), (_, _) = demo.load_cifar10()
    # `/ 0.5 - 1.` maps [0, 1] to [-1, 1]; assumes demo.load_cifar10 returns
    # pixels in [0, 1] — TODO confirm.
    x_train = x_train[:num_data].astype(np.float32).reshape(
        [-1, 32, 32, 3]) / 0.5 - 1.

    # For a full run, replace x_train with the whole training set
    # (raw 0-255 pixels, hence `/ 127.5 - 1.`):
    # import pickle
    # from male.configs import data_dir
    # with open(os.path.join(data_dir(), "cifar10/cifar10_train.pkl"),
    #           "rb") as f:
    #     tmp = pickle.load(f)
    # x_train = tmp['data'].astype(np.float32).reshape(
    #     [-1, 32, 32, 3]) / 127.5 - 1.

    root_dir = os.path.join(model_dir(), "male/WGAN-GP-ResNet/CIFAR10")
    # Single source of truth for checkpoint paths: the FID one is both
    # written by a callback below and reloaded after training.
    best_is_path = os.path.join(root_dir, "checkpoints_is/the_best_is.ckpt")
    best_fid_path = os.path.join(root_dir,
                                 "checkpoints_fid/the_best_fid.ckpt")

    checkpoints_is = ModelCheckpoint(best_is_path,
                                     mode='max',
                                     monitor='inception_score',
                                     verbose=1,
                                     save_best_only=True)
    checkpoints_fid = ModelCheckpoint(best_fid_path,
                                      mode='min',
                                      monitor='FID',
                                      verbose=1,
                                      save_best_only=True)
    loss_display = Display(layout=(1, 1),
                           dpi='auto',
                           show=show_figure,
                           block_on_end=block_figure_on_end,
                           filepath=[
                               os.path.join(root_dir,
                                            "loss/loss_{epoch:04d}.png"),
                               os.path.join(root_dir,
                                            "loss/loss_{epoch:04d}.pdf")
                           ],
                           monitor=[
                               {
                                   'metrics': ['d_loss', 'g_loss'],
                                   'type': 'line',
                                   'labels':
                                   ["discriminator loss", "generator loss"],
                                   'title': "Losses",
                                   'xlabel': "epoch",
                                   'ylabel': "loss",
                               },
                           ])
    inception_score_display = Display(
        layout=(1, 1),
        dpi='auto',
        show=show_figure,
        block_on_end=block_figure_on_end,
        filepath=[
            os.path.join(root_dir, "inception_score/"
                         "inception_score_{epoch:04d}.png"),
            os.path.join(root_dir, "inception_score/"
                         "inception_score_{epoch:04d}.pdf")
        ],
        monitor=[
            {
                'metrics': ['inception_score'],
                'type': 'line',
                'labels': ["Inception Score"],
                'title': "Scores",
                'xlabel': "epoch",
                'ylabel': "score",
            },
        ],
    )
    fid_display = Display(
        layout=(1, 1),
        dpi='auto',
        show=show_figure,
        block_on_end=block_figure_on_end,
        filepath=[
            os.path.join(root_dir, "FID/"
                         "FID_{epoch:04d}.png"),
            os.path.join(root_dir, "FID/"
                         "FID_{epoch:04d}.pdf")
        ],
        monitor=[
            {
                'metrics': ['FID'],
                'type': 'line',
                'labels': ["FID"],
                'title': "Scores",
                'xlabel': "epoch",
                'ylabel': "score",
            },
        ],
    )
    sample_display = Display(layout=(1, 1),
                             dpi='auto',
                             figsize=(10, 10),
                             freq=1,
                             show=show_figure,
                             block_on_end=block_figure_on_end,
                             filepath=os.path.join(
                                 root_dir, "samples/samples_{epoch:04d}.png"),
                             monitor=[
                                 {
                                     'metrics': ['x_samples'],
                                     'title': "Generated data",
                                     'type': 'img',
                                     'num_samples': 100,
                                     'tile_shape': (10, 10),
                                 },
                             ])

    model = WGAN_GP_ResNet(
        model_name="WGAN_GP_ResNet_CIFAR10",
        num_z=8,  # set to 128 for a full run
        img_size=(32, 32, 3),
        batch_size=64,  # also 64 for a full run
        g_blocks=('up', 'up', 'up'),
        d_blocks=('down', 'down', None, None),
        num_gen_feature_maps=8,  # set to 128 for a full run
        num_dis_feature_maps=8,  # set to 128 for a full run
        metrics=[
            'd_loss', 'g_loss', 'inception_score', 'inception_score_std', 'FID'
        ],
        callbacks=[
            loss_display, inception_score_display, fid_display, sample_display,
            checkpoints_is, checkpoints_fid
        ],
        num_epochs=2,  # set to 500 for a full run
        inception_metrics=[InceptionScore(),
                           FID(data='cifar10')],
        inception_metrics_freq=1,
        num_inception_samples=100,  # set to 50000 for a full run
        # summary_freq=1,  # uncomment this for a full run
        log_path=os.path.join(root_dir, 'logs'),
        random_state=random_seed(),
        verbose=1)

    model.fit(x_train)

    # Reload the checkpoint written by `checkpoints_fid` (the best FID seen,
    # not necessarily the latest epoch) and continue training it.
    print('Reloading the best-FID model at: {}'.format(best_fid_path))
    model1 = TensorFlowModel.load_model(best_fid_path)
    # Inception metrics are re-attached explicitly after loading; presumably
    # they are not restored from the checkpoint — TODO confirm.
    model1.inception_metrics = InceptionMetricList(
        [InceptionScore(), FID(data='cifar10')])
    model1.num_epochs = 4
    model1.fit(x_train)
    print('Done!')