Ejemplo n.º 1
0
def test_save_load():
    """Check that a saved-then-reloaded TensorFlowGLM scores like the original.

    A softmax GLM is trained on Iris for a few epochs, saved via ``save()``,
    reloaded with ``TensorFlowModel.load_model`` and both models' training and
    testing errors are compared to within 1e-6.
    """
    print("========== Test save, load tensorflow models ==========")

    np.random.seed(random_seed())

    iris_train, iris_test = demo.load_iris()
    x_train, y_train = iris_train
    x_test, y_test = iris_test
    print("Number of training samples = {}".format(x_train.shape[0]))
    print("Number of testing samples = {}".format(x_test.shape[0]))

    original = TensorFlowGLM(model_name="iris_TensorFlowGLM_softmax",
                             link='softmax',
                             loss='softmax',
                             num_epochs=5,
                             random_state=random_seed())
    original.fit(x_train, y_train)

    def split_errors(model):
        # score() is treated as accuracy here, so 1 - score is the error.
        # Returns (training error, testing error), computed in that order.
        return (1.0 - model.score(x_train, y_train),
                1.0 - model.score(x_test, y_test))

    print("After training:")
    before = split_errors(original)
    print("Training error = {:.4f}".format(before[0]))
    print("Testing error = {:.4f}".format(before[1]))

    reloaded = TensorFlowModel.load_model(original.save())

    print("After save and load:")
    after = split_errors(reloaded)
    print("Training error = {:.4f}".format(after[0]))
    print("Testing error = {:.4f}".format(after[1]))

    # The round trip must not change predictions: errors must agree closely.
    for err_before, err_after in zip(before, after):
        assert abs(err_before - err_after) < 1e-6
Ejemplo n.º 2
0
def test_dfm_save_and_load(show_figure=False, block_figure_on_end=False):
    """Smoke-test saving, reloading and resuming training of a DFM on MNIST.

    Args:
        show_figure: forwarded to each ``Display`` as ``show``.
        block_figure_on_end: forwarded to each ``Display`` as ``block_on_end``.
    """
    print("========== Test Save and Load functions of DFM on MNIST data ==========")

    np.random.seed(random_seed())

    def to_images(data):
        # Reshape each sample into a 28x28x1 image and map x -> 2x - 1.
        return data.astype(np.float32).reshape([-1, 28, 28, 1]) / 0.5 - 1.

    (x_train, _), (x_test, _) = demo.load_mnist()
    x_train = to_images(x_train)
    x_test = to_images(x_test)

    # Options common to both figures.
    shared = dict(dpi='auto',
                  show=show_figure,
                  block_on_end=block_figure_on_end)
    loss_monitor = {'metrics': ['d_loss', 'g_loss'],
                    'type': 'line',
                    'labels': ["discriminator loss", "generator loss"],
                    'title': "Losses",
                    'xlabel': "epoch",
                    'ylabel': "loss",
                    }
    sample_monitor = {'metrics': ['x_samples'],
                      'title': "Generated data",
                      'type': 'img',
                      'num_samples': 100,
                      'tile_shape': (10, 10),
                      }
    loss_display = Display(layout=(1, 1), monitor=[loss_monitor], **shared)
    sample_display = Display(layout=(1, 1),
                             figsize=(10, 10),
                             freq=1,
                             monitor=[sample_monitor],
                             **shared)

    model = DFM(model_name="DFM_MNIST_SaveLoad",
                num_z=10,  # set to 100 for a full run
                img_size=(28, 28, 1),
                batch_size=32,  # set to 64 for a full run
                num_conv_layers=3,  # set to 3 for a full run
                num_gen_feature_maps=4,  # set to 32 for a full run
                num_dis_feature_maps=4,  # set to 32 for a full run
                alpha=0.03 / 10,  # 0.03 / 1024
                noise_std=1.0,
                num_dfm_layers=1,  # 2
                num_dfm_hidden=10,  # 1024
                metrics=['d_loss', 'g_loss'],
                callbacks=[loss_display, sample_display],
                num_epochs=4,  # set to 100 for a full run
                random_state=random_seed(),
                verbose=1)
    model.fit(x_train)

    restored = TensorFlowModel.load_model(model.save())
    # Raise the epoch budget above the 4 used above, then fit again.
    restored.num_epochs = 10
    restored.fit(x_train)
Ejemplo n.º 3
0
def test_continue_training():
    """Exercise resumed training of a TensorFlowGLM on Iris.

    Trains for 5 epochs, raises ``num_epochs`` to 10 and fits again, then
    saves, reloads, raises ``num_epochs`` to 15 and fits once more, printing
    the training/testing errors after each stage.
    """
    print("========== Test continue training tensorflow models ==========")

    np.random.seed(random_seed())

    (x_train, y_train), (x_test, y_test) = demo.load_iris()
    print("Number of training samples = {}".format(x_train.shape[0]))
    print("Number of testing samples = {}".format(x_test.shape[0]))

    def report(model):
        # Both errors are computed before either is printed.
        train_err = 1.0 - model.score(x_train, y_train)
        test_err = 1.0 - model.score(x_test, y_test)
        print("Training error = %.4f" % train_err)
        print("Testing error = %.4f" % test_err)

    num_epochs = 5
    clf = TensorFlowGLM(model_name="iris_TensorFlowGLM_softmax",
                        link='softmax',
                        loss='softmax',
                        optimizer='sgd',
                        batch_size=10,
                        num_epochs=num_epochs,
                        random_state=random_seed())
    clf.fit(x_train, y_train)
    print("After training for {0:d} epochs".format(num_epochs))
    report(clf)

    # Stage 2: same estimator, larger epoch budget.
    clf.num_epochs = 10
    print("Set number of epoch to {0:d}, then continue training...".format(clf.num_epochs))
    clf.fit(x_train, y_train)
    report(clf)

    # Stage 3: round-trip through save/load, then raise the budget again.
    clf1 = TensorFlowModel.load_model(clf.save())
    clf1.num_epochs = 15
    print("Save, load, set number of epoch to {0:d}, "
          "then continue training...".format(clf1.num_epochs))
    clf1.fit(x_train, y_train)
    report(clf1)
Ejemplo n.º 4
0
def test_dcgan_cifar10_inception_metric(show_figure=False,
                                        block_figure_on_end=False):
    """Train a small DCGAN on a CIFAR10 subset while tracking IS and FID.

    The run checkpoints the best Inception Score and the best FID models, then
    reloads the best-FID checkpoint, re-attaches inception metrics and
    resumes training.

    Args:
        show_figure: forwarded to every ``Display`` as ``show``.
        block_figure_on_end: forwarded to every ``Display`` as
            ``block_on_end``.
    """
    print(
        "========== Test DCGAN with Inception Score and FID on CIFAR10 data =========="
    )

    np.random.seed(random_seed())

    num_data = 128
    # Only a small slice of the training images is used; the test split is
    # not needed by this test, so it is not converted.
    (x_train, _), _ = demo.load_cifar10()
    # Reshape to 32x32x3 images and map x -> 2x - 1.
    x_train = x_train[:num_data].astype(np.float32).reshape(
        [-1, 32, 32, 3]) / 0.5 - 1.

    root_dir = os.path.join(model_dir(), "male/DCGAN/cifar10")
    # The best-FID path is written by ``checkpoints_fid`` and read back after
    # training; keep it in one variable so the two uses cannot drift apart.
    best_is_path = os.path.join(root_dir, "checkpoints_is/the_best_is.ckpt")
    best_fid_path = os.path.join(root_dir, "checkpoints_fid/the_best_fid.ckpt")
    checkpoints_is = ModelCheckpoint(best_is_path,
                                     mode='max',
                                     monitor='inception_score',
                                     verbose=1,
                                     save_best_only=True)
    checkpoints_fid = ModelCheckpoint(best_fid_path,
                                      mode='min',
                                      monitor='FID',
                                      verbose=1,
                                      save_best_only=True)

    # Options common to every figure below.
    shared = dict(dpi='auto',
                  show=show_figure,
                  block_on_end=block_figure_on_end)
    loss_display = Display(
        layout=(1, 1),
        filepath=[
            os.path.join(root_dir, "loss/loss_{epoch:04d}.png"),
            os.path.join(root_dir, "loss/loss_{epoch:04d}.pdf"),
        ],
        monitor=[
            {
                'metrics': ['d_loss', 'g_loss'],
                'type': 'line',
                'labels': ["discriminator loss", "generator loss"],
                'title': "Losses",
                'xlabel': "epoch",
                'ylabel': "loss",
            },
        ],
        **shared)
    inception_score_display = Display(
        layout=(1, 1),
        filepath=[
            os.path.join(root_dir,
                         "inception_score/inception_score_{epoch:04d}.png"),
            os.path.join(root_dir,
                         "inception_score/inception_score_{epoch:04d}.pdf"),
        ],
        monitor=[
            {
                'metrics': ['inception_score'],
                'type': 'line',
                'labels': ["Inception Score"],
                'title': "Scores",
                'xlabel': "epoch",
                'ylabel': "score",
            },
        ],
        **shared)
    fid_display = Display(
        layout=(1, 1),
        filepath=[
            os.path.join(root_dir, "FID/FID_{epoch:04d}.png"),
            os.path.join(root_dir, "FID/FID_{epoch:04d}.pdf"),
        ],
        monitor=[
            {
                'metrics': ['FID'],
                'type': 'line',
                'labels': ["FID"],
                'title': "Scores",
                'xlabel': "epoch",
                'ylabel': "score",
            },
        ],
        **shared)
    sample_display = Display(
        layout=(1, 1),
        figsize=(10, 10),
        freq=1,
        filepath=os.path.join(root_dir, "samples/samples_{epoch:04d}.png"),
        monitor=[
            {
                'metrics': ['x_samples'],
                'title': "Generated data",
                'type': 'img',
                'num_samples': 100,
                'tile_shape': (10, 10),
            },
        ],
        **shared)

    model = DCGAN(
        model_name="DCGAN_CIFAR10",
        num_z=10,  # set to 100 for a full run
        img_size=(32, 32, 3),
        batch_size=16,  # set to 64 for a full run
        num_conv_layers=3,  # set to 3 for a full run
        num_gen_feature_maps=4,  # set to 32 for a full run
        num_dis_feature_maps=4,  # set to 32 for a full run
        metrics=[
            'd_loss', 'g_loss', 'inception_score', 'inception_score_std', 'FID'
        ],
        callbacks=[
            loss_display, inception_score_display, fid_display, sample_display,
            checkpoints_is, checkpoints_fid
        ],
        num_epochs=4,  # set to 100 for a full run
        inception_metrics=[InceptionScore(),
                           FID(data="cifar10")],
        inception_metrics_freq=1,
        # summary_freq=1,  # uncomment this for a full run
        log_path=os.path.join(root_dir, "logs"),
        random_state=random_seed(),
        verbose=1)

    model.fit(x_train)
    # ``checkpoints_fid`` uses save_best_only=True, so this file holds the
    # lowest-FID model seen so far, not necessarily the last epoch.
    print("Reloading the best-FID model at: {}".format(best_fid_path))
    model1 = TensorFlowModel.load_model(best_fid_path)
    # Inception metrics are re-attached explicitly after loading
    # (presumably load_model does not restore them -- TODO confirm).
    model1.inception_metrics = InceptionMetricList(
        [InceptionScore(), FID(data="cifar10")])
    model1.num_epochs = 6
    model1.fit(x_train)
    print("Done!")
Ejemplo n.º 5
0
def test_dcgan_save_and_load(show_figure=False, block_figure_on_end=False):
    """Smoke-test saving a DCGAN to an explicit checkpoint and resuming it.

    Args:
        show_figure: forwarded to each ``Display`` as ``show``.
        block_figure_on_end: forwarded to each ``Display`` as ``block_on_end``.
    """
    print(
        "========== Test Save and Load functions of DCGAN on MNIST data =========="
    )

    np.random.seed(random_seed())

    def to_images(data):
        # Reshape each sample into a 28x28x1 image and map x -> 2x - 1.
        return data.astype(np.float32).reshape([-1, 28, 28, 1]) / 0.5 - 1.

    (x_train, _), (x_test, _) = demo.load_mnist()
    x_train = to_images(x_train)
    x_test = to_images(x_test)

    root_dir = os.path.join(model_dir(), "male/DCGAN/MNIST")

    # Options common to both figures.
    shared = dict(dpi='auto',
                  show=show_figure,
                  block_on_end=block_figure_on_end)
    loss_monitor = {
        'metrics': ['d_loss', 'g_loss'],
        'type': 'line',
        'labels': ["discriminator loss", "generator loss"],
        'title': "Losses",
        'xlabel': "epoch",
        'ylabel': "loss",
    }
    sample_monitor = {
        'metrics': ['x_samples'],
        'title': "Generated data",
        'type': 'img',
        'num_samples': 100,
        'tile_shape': (10, 10),
    }
    loss_display = Display(layout=(1, 1), monitor=[loss_monitor], **shared)
    sample_display = Display(layout=(1, 1),
                             figsize=(10, 10),
                             freq=1,
                             monitor=[sample_monitor],
                             **shared)

    model = DCGAN(
        model_name="DCGAN_MNIST_SaveLoad",
        num_z=10,  # set to 100 for a full run
        img_size=(28, 28, 1),
        batch_size=16,  # set to 64 for a full run
        num_conv_layers=3,  # set to 3 for a full run
        num_gen_feature_maps=4,  # set to 32 for a full run
        num_dis_feature_maps=4,  # set to 32 for a full run
        metrics=['d_loss', 'g_loss'],
        callbacks=[loss_display, sample_display],
        num_epochs=2,  # set to 100 for a full run
        log_path=os.path.join(root_dir, "logs"),
        random_state=random_seed(),
        verbose=1)
    model.fit(x_train)

    print("Saving model...")
    checkpoint_path = model.save(os.path.join(root_dir, "checkpoints/ckpt"))
    print("Reloading model...")
    resumed = TensorFlowModel.load_model(checkpoint_path)
    # Raise the epoch budget above the 2 used above, then fit again.
    resumed.num_epochs = 4
    resumed.fit(x_train)
    print("Done!")
Ejemplo n.º 6
0
def test_wgan_gp_resnet_cifar10(show_figure=False, block_figure_on_end=False):
    """Smoke-test a tiny WGAN-GP-ResNet on CIFAR10 with a save/reload cycle.

    Args:
        show_figure: forwarded to each ``Display`` as ``show``.
        block_figure_on_end: forwarded to each ``Display`` as ``block_on_end``.
    """
    print("========== Test WGAN-GP-ResNet on CIFAR10 data ==========")

    np.random.seed(random_seed())

    def to_images(data):
        # Reshape to 32x32x3 images and map x -> 2x - 1.
        return data.astype(np.float32).reshape([-1, 32, 32, 3]) / 0.5 - 1.

    num_data = 128
    (x_train, _), (x_test, _) = demo.load_cifar10()
    x_train = to_images(x_train[:num_data])
    x_test = to_images(x_test)

    root_dir = os.path.join(model_dir(), "male/WGAN-GP-ResNet/CIFAR10")

    # Options common to both figures.
    shared = dict(dpi='auto',
                  show=show_figure,
                  block_on_end=block_figure_on_end)
    loss_files = [
        os.path.join(root_dir, "loss/loss_{epoch:04d}.png"),
        os.path.join(root_dir, "loss/loss_{epoch:04d}.pdf"),
    ]
    loss_monitor = {
        'metrics': ['d_loss', 'g_loss'],
        'type': 'line',
        'labels': ["discriminator loss", "generator loss"],
        'title': "Losses",
        'xlabel': "epoch",
        'ylabel': "loss",
    }
    sample_monitor = {
        'metrics': ['x_samples'],
        'title': "Generated data",
        'type': 'img',
        'num_samples': 100,
        'tile_shape': (10, 10),
    }
    loss_display = Display(layout=(1, 1),
                           filepath=loss_files,
                           monitor=[loss_monitor],
                           **shared)
    sample_display = Display(layout=(1, 1),
                             figsize=(10, 10),
                             freq=1,
                             monitor=[sample_monitor],
                             **shared)

    model = WGAN_GP_ResNet(
        model_name="WGAN_GP_ResNet_CIFAR10",
        num_z=10,  # set to 100 for a full run
        img_size=(32, 32, 3),
        batch_size=16,  # set to 64 for a full run
        g_blocks=('up', 'up', 'up'),
        d_blocks=('down', 'down', None, None),
        num_gen_feature_maps=4,  # set to 32 for a full run
        num_dis_feature_maps=4,  # set to 32 for a full run
        metrics=['d_loss', 'g_loss'],
        callbacks=[loss_display, sample_display],
        num_epochs=2,  # set to 100 for a full run
        random_state=random_seed(),
        log_path=os.path.join(root_dir, "logs"),
        verbose=1)
    model.fit(x_train)

    print("Saving model...")
    checkpoint_path = model.save(os.path.join(root_dir, "checkpoints/ckpt"))
    print("Reloading model...")
    resumed = TensorFlowModel.load_model(checkpoint_path)
    # Raise the epoch budget above the 2 used above, then fit again.
    resumed.num_epochs = 4
    resumed.fit(x_train)
    print("Done!")
Ejemplo n.º 7
0
def test_wgan_gp_resnet_cifar10_inception_metric(show_figure=False,
                                                 block_figure_on_end=False):
    """Train a small WGAN-GP-ResNet on CIFAR10 while tracking IS and FID.

    The run checkpoints the best Inception Score and the best FID models, then
    reloads the best-FID checkpoint, re-attaches inception metrics and
    resumes training.

    Args:
        show_figure: forwarded to every ``Display`` as ``show``.
        block_figure_on_end: forwarded to every ``Display`` as
            ``block_on_end``.
    """
    print(
        "========== Test WGAN-GP-ResNet with Inception Score and FID on CIFAR10 data =========="
    )

    np.random.seed(random_seed())

    num_data = 128
    # Only a small slice of the training images is used; the test split is
    # not needed by this test, so it is not converted.
    (x_train, _), _ = demo.load_cifar10()
    # Reshape to 32x32x3 images and map x -> 2x - 1.
    x_train = x_train[:num_data].astype(np.float32).reshape(
        [-1, 32, 32, 3]) / 0.5 - 1.

    # For a full run, train on the whole CIFAR10 training set instead
    # (requires ``import pickle`` and ``from male.configs import data_dir``):
    # with open(os.path.join(data_dir(), "cifar10/cifar10_train.pkl"),
    #           "rb") as f:
    #     tmp = pickle.load(f)
    # x_train = tmp['data'].astype(np.float32).reshape(
    #     [-1, 32, 32, 3]) / 127.5 - 1.

    root_dir = os.path.join(model_dir(), "male/WGAN-GP-ResNet/CIFAR10")
    # The best-FID path is written by ``checkpoints_fid`` and read back after
    # training; keep it in one variable so the two uses cannot drift apart.
    best_is_path = os.path.join(root_dir, "checkpoints_is/the_best_is.ckpt")
    best_fid_path = os.path.join(root_dir, "checkpoints_fid/the_best_fid.ckpt")
    checkpoints_is = ModelCheckpoint(best_is_path,
                                     mode='max',
                                     monitor='inception_score',
                                     verbose=1,
                                     save_best_only=True)
    checkpoints_fid = ModelCheckpoint(best_fid_path,
                                      mode='min',
                                      monitor='FID',
                                      verbose=1,
                                      save_best_only=True)

    # Options common to every figure below.
    shared = dict(dpi='auto',
                  show=show_figure,
                  block_on_end=block_figure_on_end)
    loss_display = Display(
        layout=(1, 1),
        filepath=[
            os.path.join(root_dir, "loss/loss_{epoch:04d}.png"),
            os.path.join(root_dir, "loss/loss_{epoch:04d}.pdf"),
        ],
        monitor=[
            {
                'metrics': ['d_loss', 'g_loss'],
                'type': 'line',
                'labels': ["discriminator loss", "generator loss"],
                'title': "Losses",
                'xlabel': "epoch",
                'ylabel': "loss",
            },
        ],
        **shared)
    inception_score_display = Display(
        layout=(1, 1),
        filepath=[
            os.path.join(root_dir,
                         "inception_score/inception_score_{epoch:04d}.png"),
            os.path.join(root_dir,
                         "inception_score/inception_score_{epoch:04d}.pdf"),
        ],
        monitor=[
            {
                'metrics': ['inception_score'],
                'type': 'line',
                'labels': ["Inception Score"],
                'title': "Scores",
                'xlabel': "epoch",
                'ylabel': "score",
            },
        ],
        **shared)
    fid_display = Display(
        layout=(1, 1),
        filepath=[
            os.path.join(root_dir, "FID/FID_{epoch:04d}.png"),
            os.path.join(root_dir, "FID/FID_{epoch:04d}.pdf"),
        ],
        monitor=[
            {
                'metrics': ['FID'],
                'type': 'line',
                'labels': ["FID"],
                'title': "Scores",
                'xlabel': "epoch",
                'ylabel': "score",
            },
        ],
        **shared)
    sample_display = Display(
        layout=(1, 1),
        figsize=(10, 10),
        freq=1,
        filepath=os.path.join(root_dir, "samples/samples_{epoch:04d}.png"),
        monitor=[
            {
                'metrics': ['x_samples'],
                'title': "Generated data",
                'type': 'img',
                'num_samples': 100,
                'tile_shape': (10, 10),
            },
        ],
        **shared)

    model = WGAN_GP_ResNet(
        model_name="WGAN_GP_ResNet_CIFAR10",
        num_z=8,  # set to 128 for a full run
        img_size=(32, 32, 3),
        batch_size=64,  # set to 64 for a full run
        g_blocks=('up', 'up', 'up'),
        d_blocks=('down', 'down', None, None),
        num_gen_feature_maps=8,  # set to 128 for a full run
        num_dis_feature_maps=8,  # set to 128 for a full run
        metrics=[
            'd_loss', 'g_loss', 'inception_score', 'inception_score_std', 'FID'
        ],
        callbacks=[
            loss_display, inception_score_display, fid_display, sample_display,
            checkpoints_is, checkpoints_fid
        ],
        num_epochs=2,  # set to 500 for a full run
        inception_metrics=[InceptionScore(),
                           FID(data='cifar10')],
        inception_metrics_freq=1,
        num_inception_samples=100,  # set to 50000 for a full run
        # summary_freq=1,  # uncomment this for a full run
        log_path=os.path.join(root_dir, 'logs'),
        random_state=random_seed(),
        verbose=1)

    model.fit(x_train)
    # ``checkpoints_fid`` uses save_best_only=True, so this file holds the
    # lowest-FID model seen so far, not necessarily the last epoch.
    print('Reloading the best-FID model at: {}'.format(best_fid_path))
    model1 = TensorFlowModel.load_model(best_fid_path)
    # Inception metrics are re-attached explicitly after loading
    # (presumably load_model does not restore them -- TODO confirm).
    model1.inception_metrics = InceptionMetricList(
        [InceptionScore(), FID(data='cifar10')])
    model1.num_epochs = 4
    model1.fit(x_train)
    print('Done!')
Ejemplo n.º 8
0
def test_tfglm_save_load(show=False, block_figure_on_end=False):
    """Fit a softmax TensorFlowGLM on Iris with a fixed validation split, save
    it to an explicit checkpoint path, reload it and keep training.

    Args:
        show: forwarded to each ``Display`` as ``show``.
        block_figure_on_end: forwarded to each ``Display`` as ``block_on_end``.
    """
    print(
        "========== Test Save and Load functions for TensorFlowGLM ==========")

    np.random.seed(random_seed())

    (x_train, y_train), (x_test, y_test) = demo.load_iris()
    n_train = x_train.shape[0]
    n_test = x_test.shape[0]
    print("Number of training samples = {}".format(n_train))
    print("Number of testing samples = {}".format(n_test))

    # Both splits are stacked; the ``cv`` fold list given to the model
    # decides which rows are used for validation.
    x = np.vstack([x_train, x_test])
    y = np.concatenate([y_train, y_test])

    def report(model):
        # Both errors are computed before either is printed.
        train_err = 1.0 - model.score(x_train, y_train)
        test_err = 1.0 - model.score(x_test, y_test)
        print("Training error = %.4f" % train_err)
        print("Testing error = %.4f" % test_err)

    early_stopping = EarlyStopping(monitor='val_err', patience=5, verbose=1)
    checkpoint_pattern = os.path.join(
        model_dir(), "male/TensorFlowGLM/iris_{epoch:04d}_{val_err:.6f}.pkl")
    checkpoint = ModelCheckpoint(checkpoint_pattern,
                                 mode='min',
                                 monitor='val_err',
                                 verbose=0,
                                 save_best_only=True)

    # Options common to both figures.
    shared = dict(dpi='auto',
                  freq=1,
                  show=show,
                  block_on_end=block_figure_on_end)
    curve_panels = [
        {
            'metrics': ['loss', 'val_loss'],
            'type': 'line',
            'labels': ["training loss", "validation loss"],
            'title': "Learning losses",
            'xlabel': "epoch",
            'ylabel': "loss",
        },
        {
            'metrics': ['err', 'val_err'],
            'type': 'line',
            'title': "Learning errors",
            'xlabel': "epoch",
            'ylabel': "error",
        },
        {
            'metrics': ['err'],
            'type': 'line',
            'labels': ["training error"],
            'title': "Learning errors",
            'xlabel': "epoch",
            'ylabel': "error",
        },
    ]
    weight_panels = [
        {
            'metrics': ['weights'],
            'title': "Learned weights",
            'type': 'img',
            'disp_dim': (2, 2),
            'tile_shape': (3, 1),
        },
    ]
    loss_display = Display(title="Learning curves",
                           layout=(3, 1),
                           monitor=curve_panels,
                           **shared)
    weight_display = Display(title="Filters",
                             layout=(1, 1),
                             figsize=(6, 15),
                             monitor=weight_panels,
                             **shared)

    # -1 presumably marks rows never used for validation and 0 the single
    # validation fold (PredefinedSplit-style) -- TODO confirm in TensorFlowGLM.
    fold_ids = [-1] * n_train + [0] * n_test

    clf = TensorFlowGLM(
        model_name="TensorFlowGLM_softmax_cv",
        link='softmax',
        loss='softmax',
        optimizer='sgd',
        num_epochs=4,
        batch_size=10,
        task='classification',
        metrics=['loss', 'err'],
        callbacks=[early_stopping, checkpoint, loss_display, weight_display],
        cv=fold_ids,
        random_state=random_seed(),
        verbose=1)
    clf.fit(x, y)
    report(clf)

    save_file_path = os.path.join(model_dir(),
                                  "male/TensorFlowGLM/saved_model.ckpt")
    clf.save(file_path=save_file_path)
    clf1 = TensorFlowModel.load_model(save_file_path)

    # Raise the epoch budget above the 4 used above, then fit again.
    clf1.num_epochs = 10
    clf1.fit(x, y)
    report(clf1)
Ejemplo n.º 9
0
def test_gan_save_and_load(show_figure=False, block_figure_on_end=False):
    """Smoke-test saving an MLP GAN on MNIST to a checkpoint and resuming it.

    Args:
        show_figure: forwarded to each ``Display`` as ``show``.
        block_figure_on_end: forwarded to each ``Display`` as ``block_on_end``.
    """
    print(
        "========== Test Save and Load functions of GAN on MNIST data =========="
    )

    np.random.seed(random_seed())

    # The test split is not used by this test, so it is not converted.
    (x_train, _), _ = demo.load_mnist()
    # The generator below ends in a sigmoid (outputs in (0, 1)), so the real
    # images are kept in the same range. Rescaling them to [-1, 1] would let
    # the discriminator separate real from fake on value range alone.
    # NOTE(review): assumes demo.load_mnist() already yields pixels scaled to
    # [0, 1]; also confirm GAN (num_x=784) accepts this 4-D image layout.
    x_train = x_train.astype(np.float32).reshape([-1, 28, 28, 1])

    root_dir = os.path.join(model_dir(), "male/GAN/MNIST")

    # Options common to both figures.
    shared = dict(dpi='auto',
                  show=show_figure,
                  block_on_end=block_figure_on_end)
    loss_display = Display(layout=(1, 1),
                           monitor=[
                               {
                                   'metrics': ['d_loss', 'g_loss'],
                                   'type': 'line',
                                   'labels':
                                   ["discriminator loss", "generator loss"],
                                   'title': "Losses",
                                   'xlabel': "epoch",
                                   'ylabel': "loss",
                               },
                           ],
                           **shared)
    sample_display = Display(layout=(1, 1),
                             figsize=(10, 10),
                             freq=1,
                             monitor=[
                                 {
                                     'metrics': ['x_samples'],
                                     'title': "Generated data",
                                     'type': 'img',
                                     'num_samples': 100,
                                     'tile_shape': (10, 10),
                                 },
                             ],
                             **shared)

    model = GAN(model_name="GAN_MNIST_SaveLoad",
                num_x=784,
                num_discriminator_hiddens=(16, ),
                discriminator_batchnorm=False,
                discriminator_act_funcs=('lrelu', ),
                discriminator_learning_rate=0.001,
                num_z=8,
                generator_distribution=Uniform(low=(-1.0, ) * 8,
                                               high=(1.0, ) * 8),
                generator_batchnorm=False,
                num_generator_hiddens=(16, ),
                generator_act_funcs=('lrelu', ),
                generator_out_func='sigmoid',
                generator_learning_rate=0.001,
                batch_size=32,
                metrics=['d_loss', 'g_loss'],
                callbacks=[loss_display, sample_display],
                num_epochs=2,
                log_path=os.path.join(root_dir, "logs"),
                random_state=random_seed(),
                verbose=1)

    model.fit(x_train)

    print("Saving model...")
    save_file_path = model.save(os.path.join(root_dir, "checkpoints/ckpt"))
    print("Reloading model...")
    model1 = TensorFlowModel.load_model(save_file_path)
    # Raise the epoch budget above the 2 used above, then fit again.
    model1.num_epochs = 4
    model1.fit(x_train)
    print("Done!")