def test_inception_score():
    """Print the Inception Score of 100 real CIFAR-10 training images.

    The first 100 training images are cast to float32, reshaped to
    ``(N, 32, 32, 3)`` and multiplied by 255 before scoring (this presumes
    ``demo.load_cifar10`` yields pixels in ``[0, 1]`` -- TODO confirm).
    The value returned by ``InceptionScore.score`` is read as a
    ``(mean, std)`` pair via its first two entries.
    """
    print("========== Test Inception score ==========")
    (train_images, _train_labels), (_, _) = demo.load_cifar10()
    batch = train_images[:100].astype(np.float32)
    batch = batch.reshape([-1, 32, 32, 3]) * 255.0
    result = InceptionScore().score(batch)
    mean_is, std_is = result[0], result[1]
    print("Inception score: {:.4f}+-{:.4f}".format(mean_is, std_is))
def test_cganv1_cifar10(show_figure=False, block_figure_on_end=False):
    """Train a deliberately tiny CGANv1 on labelled CIFAR-10 for one epoch.

    The whole training split (images and labels) is passed to ``fit``. The
    network sizes are shrunk so the test stays quick. One ``Display`` plots
    ``d_loss``/``g_loss`` and a second one tiles 100 generated samples.

    Args:
        show_figure: forwarded to both displays as ``show``.
        block_figure_on_end: forwarded to both displays as ``block_on_end``.
    """
    print("========== Test CGANv1 on CIFAR10 data ==========")
    np.random.seed(random_seed())

    def _scale(images):
        # float32 NHWC; x / 0.5 - 1 maps [0, 1] onto [-1, 1].
        return images.astype(np.float32).reshape([-1, 32, 32, 3]) / 0.5 - 1.

    (x_train, y_train), (x_test, y_test) = demo.load_cifar10()
    x_train = _scale(x_train)
    x_test = _scale(x_test)

    view_opts = {'show': show_figure, 'block_on_end': block_figure_on_end}
    loss_display = Display(
        layout=(1, 1), dpi='auto',
        monitor=[{
            'metrics': ['d_loss', 'g_loss'],
            'type': 'line',
            'labels': ["discriminator loss", "generator loss"],
            'title': "CGANv1 - Losses",
            'xlabel': "epoch",
            'ylabel': "loss",
        }],
        **view_opts)
    sample_display = Display(
        layout=(1, 1), dpi='auto', figsize=(10, 10), freq=1,
        monitor=[{
            'metrics': ['x_samples'],
            'title': "CGANv1 - Generated data",
            'type': 'img',
            'num_samples': 100,
            'tile_shape': (10, 10),
        }],
        **view_opts)

    model = CGANv1(
        model_name="CGANv1_CIFAR10",
        num_z=10,                 # 100 for a full run
        img_size=(32, 32, 3),
        batch_size=64,            # already the full-run value
        num_conv_layers=3,        # already the full-run value
        num_gen_feature_maps=2,   # 32 for a full run
        num_dis_feature_maps=2,   # 32 for a full run
        metrics=['d_loss', 'g_loss'],
        callbacks=[loss_display, sample_display],
        num_epochs=1,             # 100 for a full run
        random_state=random_seed(),
        verbose=1)
    model.fit(x_train, y_train)
def test_frechet_inception_distance():
    """Print FID scores of 100 real CIFAR-10 images against two references.

    Case #0 compares against the built-in ``"cifar10"`` reference and case #1
    against the first 200 training images. Images are scaled by 255 first
    (assumes the loader yields ``[0, 1]`` pixels -- TODO confirm).
    """
    print("========== Test Frechet Inception Distance (FID) ==========")
    (x_train, y_train), (_, _) = demo.load_cifar10()
    pixels = x_train.astype(np.float32).reshape([-1, 32, 32, 3]) * 255.0
    queries = pixels[:100]
    # (message template, FID reference) pairs; each FID is built lazily inside
    # the loop so construction and scoring keep the original interleaving.
    cases = [
        ("Case #0: FID = {:.4f}", "cifar10"),
        ("Case #1: FID = {:.4f}", pixels[:200]),
    ]
    for template, reference in cases:
        fid_value = FID(data=reference).score(queries)
        print(template.format(fid_value))
def test_wgan_cifar10(show_figure=False, block_figure_on_end=False):
    """Train a small WGAN on 128 CIFAR-10 images for a few epochs.

    The loss display gets PNG and PDF output paths and the sample display a
    PNG path, all under ``<model_dir()>/male/WGAN/CIFAR10``. ``log_path``
    points at its ``logs`` sub-directory.

    Args:
        show_figure: forwarded to each ``Display`` as ``show``.
        block_figure_on_end: forwarded to each ``Display`` as ``block_on_end``.
    """
    print("========== Test WGAN on CIFAR10 data ==========")
    np.random.seed(random_seed())
    num_data = 128
    (x_train, y_train), (x_test, y_test) = demo.load_cifar10()
    # x / 0.5 - 1 maps [0, 1] onto [-1, 1] (assumes the loader yields [0, 1]).
    x_train = x_train[:num_data].astype(np.float32).reshape([-1, 32, 32, 3]) / 0.5 - 1.
    x_test = x_test.astype(np.float32).reshape([-1, 32, 32, 3]) / 0.5 - 1.

    root_dir = os.path.join(model_dir(), "male/WGAN/CIFAR10")
    shared = {
        'layout': (1, 1),
        'dpi': 'auto',
        'show': show_figure,
        'block_on_end': block_figure_on_end,
    }

    loss_curves = {
        'metrics': ['d_loss', 'g_loss'],
        'type': 'line',
        'labels': ["discriminator loss", "generator loss"],
        'title': "Losses",
        'xlabel': "epoch",
        'ylabel': "loss",
    }
    loss_display = Display(
        filepath=[os.path.join(root_dir, "loss/loss_{epoch:04d}.png"),
                  os.path.join(root_dir, "loss/loss_{epoch:04d}.pdf")],
        monitor=[loss_curves],
        **shared)

    sample_tiles = {
        'metrics': ['x_samples'],
        'title': "Generated data",
        'type': 'img',
        'num_samples': 100,
        'tile_shape': (10, 10),
    }
    sample_display = Display(
        figsize=(10, 10), freq=1,
        filepath=os.path.join(root_dir, "samples/samples_{epoch:04d}.png"),
        monitor=[sample_tiles],
        **shared)

    model = WGAN(
        model_name="WGAN_CIFAR10",
        num_z=10,                 # 100 for a full run
        z_prior=Uniform1D(low=-1.0, high=1.0),
        img_size=(32, 32, 3),
        batch_size=16,            # 64 for a full run
        num_conv_layers=3,        # already the full-run value
        num_gen_feature_maps=4,   # 32 for a full run
        num_dis_feature_maps=4,   # 32 for a full run
        metrics=['d_loss', 'g_loss'],
        callbacks=[loss_display, sample_display],
        num_epochs=4,             # 100 for a full run
        log_path=os.path.join(root_dir, "logs"),
        random_state=random_seed(),
        verbose=1)
    model.fit(x_train)
def test_inception_metric_list():
    """Score 100 CIFAR-10 images with several Inception-based metrics at once.

    ``InceptionMetricList`` bundles an Inception Score, an FID against the
    ``"cifar10"`` reference, and an FID against the first 200 training
    images. ``score`` yields one result per metric. Tuple results are
    printed as ``mean+-std`` and anything else as a single value.
    """
    # The header names what this test exercises (the metric list). The old
    # text "Test Frechet Inception Distance (FID)" was copied from
    # test_frechet_inception_distance and mislabelled this test's output.
    print("========== Test Inception Metric List ==========")
    (x_train, _), (_, _) = demo.load_cifar10()
    # Same [0, 255] scaling as the other Inception tests (assumes the loader
    # returns pixels in [0, 1] -- TODO confirm).
    x_train = x_train.astype(np.float32).reshape([-1, 32, 32, 3]) * 255.0
    metric_list = InceptionMetricList(
        [InceptionScore(), FID(data="cifar10"), FID(data=x_train[:200])])
    scores = metric_list.score(x_train[:100])
    for i, s in enumerate(scores):
        if isinstance(s, tuple):
            print("Case #{}: score = {:.4f}+-{:.4f}".format(i, s[0], s[1]))
        else:
            print("Case #{}: score = {:.4f}".format(i, s))
def test_image_saver_callback():
    """Exercise ``ImageSaver`` with static image tiles during a GLM fit.

    Two savers (``freq=1``) are attached to a softmax ``GLM`` trained on
    MNIST. One tiles the first 100 MNIST training images and the other the
    first 100 CIFAR-10 training images. Each writes to a ``{epoch:04d}``
    file pattern under ``<model_dir()>/male/callbacks/imagesaver``.
    """
    np.random.seed(random_seed())
    (mnist_x, mnist_y), (_, _) = demo.load_mnist()
    (cifar10_x, _), (_, _) = demo.load_cifar10()

    mnist_tiles = {
        'metrics': 'x_data',
        'img_size': (28, 28, 1),
        'tile_shape': (10, 10),
        'images': mnist_x[:100].reshape([-1, 28, 28, 1]),
    }
    mnist_saver = ImageSaver(
        freq=1,
        filepath=os.path.join(model_dir(),
                              "male/callbacks/imagesaver/"
                              "mnist/mnist_{epoch:04d}.png"),
        monitor=mnist_tiles)

    cifar10_tiles = {
        'metrics': 'x_data',
        'img_size': (32, 32, 3),
        'tile_shape': (10, 10),
        'images': cifar10_x[:100].reshape([-1, 32, 32, 3]),
    }
    cifar10_saver = ImageSaver(
        freq=1,
        filepath=os.path.join(model_dir(),
                              "male/callbacks/imagesaver/"
                              "cifar10/cifar10_{epoch:04d}.png"),
        monitor=cifar10_tiles)

    classifier = GLM(
        model_name="imagesaver_callback",
        link='softmax',
        loss='softmax',
        optimizer=SGD(learning_rate=0.001),
        num_epochs=4,
        batch_size=100,
        task='classification',
        callbacks=[mnist_saver, cifar10_saver],
        random_state=random_seed(),
        verbose=1)
    classifier.fit(mnist_x, mnist_y)
def test_dcgan_image_saver():
    """Train tiny DCGANs on MNIST and then CIFAR-10 with an ``ImageSaver``.

    For each dataset the random seed is reset. The first ``num_data``
    training images are mapped with ``x / 0.5 - 1`` (i.e. [0, 1] -> [-1, 1]
    if the loader yields [0, 1]). A 4-epoch DCGAN is then fitted while
    ``ImageSaver`` gets a per-epoch PNG path for a 10x10 tile of 100 samples.
    """
    print("========== Test DCGAN with Image Saver ==========")
    num_data = 128

    def _fit_dcgan(load_data, img_size, out_dir, png_pattern, model_name):
        # Seed, load and normalise one dataset, then fit a small DCGAN on it.
        np.random.seed(random_seed())
        shape = [-1] + list(img_size)
        (x_train, y_train), (x_test, y_test) = load_data()
        x_train = x_train[:num_data].astype(np.float32).reshape(shape) / 0.5 - 1.
        x_test = x_test.astype(np.float32).reshape(shape) / 0.5 - 1.
        root_dir = os.path.join(model_dir(), out_dir)
        imgsaver = ImageSaver(
            freq=1,
            filepath=os.path.join(root_dir, png_pattern),
            monitor={
                'metrics': 'x_samples',
                'num_samples': 100,
                'tile_shape': (10, 10),
            })
        model = DCGAN(
            model_name=model_name,
            num_z=10,                 # 100 for a full run
            img_size=img_size,
            batch_size=16,            # 64 for a full run
            num_conv_layers=3,        # already the full-run value
            num_gen_feature_maps=4,   # 32 for a full run
            num_dis_feature_maps=4,   # 32 for a full run
            metrics=['d_loss', 'g_loss'],
            callbacks=[imgsaver],
            num_epochs=4,             # 100 for a full run
            random_state=random_seed(),
            log_path=os.path.join(root_dir, "logs"),
            verbose=1)
        model.fit(x_train)

    _fit_dcgan(demo.load_mnist, (28, 28, 1),
               "male/DCGAN/imagesaver/mnist", "mnist_{epoch:04d}.png",
               "DCGAN_MNIST")
    _fit_dcgan(demo.load_cifar10, (32, 32, 3),
               "male/DCGAN/imagesaver/cifar10", "cifar10_{epoch:04d}.png",
               "DCGAN_CIFAR10")
def test_dcgan_cifar10_inception_metric(show_figure=False,
                                        block_figure_on_end=False):
    """Train a small DCGAN with Inception Score + FID tracking, then resume.

    The model is fitted on 128 CIFAR-10 images with both metrics evaluated
    every epoch. Two ``ModelCheckpoint`` callbacks (``save_best_only=True``)
    monitor ``inception_score`` (max) and ``FID`` (min). Afterwards the FID
    checkpoint is reloaded, its Inception metrics are re-attached,
    ``num_epochs`` is set to 6, and ``fit`` is called again.

    Args:
        show_figure: forwarded to each ``Display`` as ``show``.
        block_figure_on_end: forwarded to each ``Display`` as ``block_on_end``.
    """
    print(
        "========== Test DCGAN with Inception Score and FID on CIFAR10 data =========="
    )
    np.random.seed(random_seed())
    num_data = 128
    (x_train, y_train), (x_test, y_test) = demo.load_cifar10()
    x_train = x_train[:num_data].astype(np.float32).reshape([-1, 32, 32, 3]) / 0.5 - 1.
    x_test = x_test.astype(np.float32).reshape([-1, 32, 32, 3]) / 0.5 - 1.

    root_dir = os.path.join(model_dir(), "male/DCGAN/cifar10")
    best_is_path = os.path.join(root_dir, "checkpoints_is/the_best_is.ckpt")
    best_fid_path = os.path.join(root_dir, "checkpoints_fid/the_best_fid.ckpt")

    checkpoints_is = ModelCheckpoint(best_is_path, mode='max',
                                     monitor='inception_score', verbose=1,
                                     save_best_only=True)
    checkpoints_fid = ModelCheckpoint(best_fid_path, mode='min', monitor='FID',
                                      verbose=1, save_best_only=True)

    common = dict(layout=(1, 1), dpi='auto', show=show_figure,
                  block_on_end=block_figure_on_end)

    loss_display = Display(
        filepath=[os.path.join(root_dir, "loss/loss_{epoch:04d}.png"),
                  os.path.join(root_dir, "loss/loss_{epoch:04d}.pdf")],
        monitor=[{
            'metrics': ['d_loss', 'g_loss'],
            'type': 'line',
            'labels': ["discriminator loss", "generator loss"],
            'title': "Losses",
            'xlabel': "epoch",
            'ylabel': "loss",
        }],
        **common)
    inception_score_display = Display(
        filepath=[os.path.join(root_dir, "inception_score/"
                                         "inception_score_{epoch:04d}.png"),
                  os.path.join(root_dir, "inception_score/"
                                         "inception_score_{epoch:04d}.pdf")],
        monitor=[{
            'metrics': ['inception_score'],
            'type': 'line',
            'labels': ["Inception Score"],
            'title': "Scores",
            'xlabel': "epoch",
            'ylabel': "score",
        }],
        **common)
    fid_display = Display(
        filepath=[os.path.join(root_dir, "FID/" "FID_{epoch:04d}.png"),
                  os.path.join(root_dir, "FID/" "FID_{epoch:04d}.pdf")],
        monitor=[{
            'metrics': ['FID'],
            'type': 'line',
            'labels': ["FID"],
            'title': "Scores",
            'xlabel': "epoch",
            'ylabel': "score",
        }],
        **common)
    sample_display = Display(
        figsize=(10, 10), freq=1,
        filepath=os.path.join(root_dir, "samples/samples_{epoch:04d}.png"),
        monitor=[{
            'metrics': ['x_samples'],
            'title': "Generated data",
            'type': 'img',
            'num_samples': 100,
            'tile_shape': (10, 10),
        }],
        **common)

    model = DCGAN(
        model_name="DCGAN_CIFAR10",
        num_z=10,                 # 100 for a full run
        img_size=(32, 32, 3),
        batch_size=16,            # 64 for a full run
        num_conv_layers=3,        # already the full-run value
        num_gen_feature_maps=4,   # 32 for a full run
        num_dis_feature_maps=4,   # 32 for a full run
        metrics=['d_loss', 'g_loss', 'inception_score',
                 'inception_score_std', 'FID'],
        callbacks=[loss_display, inception_score_display, fid_display,
                   sample_display, checkpoints_is, checkpoints_fid],
        num_epochs=4,             # 100 for a full run
        inception_metrics=[InceptionScore(), FID(data="cifar10")],
        inception_metrics_freq=1,
        # summary_freq=1,  # uncomment this for a full run
        log_path=os.path.join(root_dir, "logs"),
        random_state=random_seed(),
        verbose=1)
    model.fit(x_train)

    print("Reloading the latest model at: {}".format(best_fid_path))
    resumed = TensorFlowModel.load_model(best_fid_path)
    # Metrics are re-attached explicitly on the reloaded model before resuming.
    resumed.inception_metrics = InceptionMetricList(
        [InceptionScore(), FID(data="cifar10")])
    resumed.num_epochs = 6
    resumed.fit(x_train)
    print("Done!")
def test_wgan_gp_cifar10_fid(show_figure=False, block_figure_on_end=False):
    """Train a small WGAN-GP on 128 CIFAR-10 images while tracking two FIDs.

    FID is evaluated every epoch against (a) the built-in ``"cifar10"``
    reference, reported as ``FID``, and (b) the first 100 training images,
    reported as ``FID_100points``. A best-``FID`` checkpoint (``mode='min'``)
    is written to a path under ``<model_dir()>/male/WGAN-GP/CIFAR10``. The
    loss/FID/sample displays are also given paths there.

    Args:
        show_figure: forwarded to each ``Display`` as ``show``.
        block_figure_on_end: forwarded to each ``Display`` as ``block_on_end``.
    """
    print("========== Test WGAN-GP with FID on CIFAR10 data ==========")
    np.random.seed(random_seed())
    num_data = 128
    (x_train, y_train), (x_test, y_test) = demo.load_cifar10()
    # x / 0.5 - 1 maps [0, 1] pixels onto [-1, 1] for GAN training.
    x_train = x_train[:num_data].astype(np.float32).reshape([-1, 32, 32, 3]) / 0.5 - 1.
    x_test = x_test.astype(np.float32).reshape([-1, 32, 32, 3]) / 0.5 - 1.

    # FID reference images must share the pixel scale of the images the
    # metric scores. The other FID tests in this file feed FID data in
    # [0, 255]. Presumably the model also rescales its generated samples to
    # [0, 255] before scoring (TODO confirm). Previously the [-1, 1] training
    # tensor was passed straight in; undo the mapping so that
    # (x + 1) * 127.5 == original_pixel * 255.
    fid_reference = (x_train[:100] + 1.) * 127.5

    root_dir = os.path.join(model_dir(), "male/WGAN-GP/CIFAR10")
    checkpoints = ModelCheckpoint(
        os.path.join(root_dir, "checkpoints/{epoch:04d}_{FID:.6f}.ckpt"),
        mode='min',
        monitor='FID',
        verbose=1,
        save_best_only=True)
    loss_display = Display(
        layout=(1, 1), dpi='auto',
        show=show_figure, block_on_end=block_figure_on_end,
        filepath=[os.path.join(root_dir, "loss/loss_{epoch:04d}.png"),
                  os.path.join(root_dir, "loss/loss_{epoch:04d}.pdf")],
        monitor=[{
            'metrics': ['d_loss', 'g_loss'],
            'type': 'line',
            'labels': ["discriminator loss", "generator loss"],
            'title': "Losses",
            'xlabel': "epoch",
            'ylabel': "loss",
        }])
    fid_display = Display(
        layout=(1, 1), dpi='auto',
        show=show_figure, block_on_end=block_figure_on_end,
        filepath=[os.path.join(root_dir, "FID/" "FID_{epoch:04d}.png"),
                  os.path.join(root_dir, "FID/" "FID_{epoch:04d}.pdf")],
        monitor=[{
            'metrics': ['FID'],
            'type': 'line',
            'labels': ["FID"],
            'title': "Scores",
            'xlabel': "epoch",
            'ylabel': "score",
        }])
    sample_display = Display(
        layout=(1, 1), dpi='auto', figsize=(10, 10), freq=1,
        show=show_figure, block_on_end=block_figure_on_end,
        filepath=os.path.join(root_dir, "samples/" "samples_{epoch:04d}.png"),
        monitor=[{
            'metrics': ['x_samples'],
            'title': "Generated data",
            'type': 'img',
            'num_samples': 100,
            'tile_shape': (10, 10),
        }])

    model = WGAN_GP(
        model_name="WGAN_GP_CIFAR10",
        num_z=10,                 # set to 100 for a full run
        img_size=(32, 32, 3),
        batch_size=16,            # set to 64 for a full run
        num_conv_layers=3,        # set to 3 for a full run
        num_gen_feature_maps=4,   # set to 32 for a full run
        num_dis_feature_maps=4,   # set to 32 for a full run
        metrics=['d_loss', 'g_loss', 'FID', 'FID_100points'],
        callbacks=[loss_display, fid_display, sample_display, checkpoints],
        num_epochs=4,             # set to 100 for a full run
        inception_metrics=[FID(data="cifar10"),
                           FID(name="FID_100points", data=fid_reference)],
        inception_metrics_freq=1,
        # summary_freq=1,  # uncomment this for a full run
        log_path=os.path.join(root_dir, "logs"),
        random_state=random_seed(),
        verbose=1)
    model.fit(x_train)
def test_gank_image_saver():
    """Train tiny logit-loss GANKs on MNIST and then CIFAR-10 with ``ImageSaver``.

    For each dataset the seed is reset. The full training split is mapped
    with ``x / 0.5 - 1`` ([0, 1] -> [-1, 1] if the loader yields [0, 1]).
    A 4-epoch GANK is then fitted while ``ImageSaver`` gets a per-epoch PNG
    path for a 10x10 tile of 100 generated samples.
    """
    print("========== Test GANK-Logit with Image Saver ==========")

    def _fit_gank(load_data, img_size, saver_path, model_name):
        # Seed, load and normalise one dataset, then fit a small GANK on it.
        np.random.seed(random_seed())
        shape = [-1] + list(img_size)
        (x_train, y_train), (x_test, y_test) = load_data()
        x_train = x_train.astype(np.float32).reshape(shape) / 0.5 - 1.
        x_test = x_test.astype(np.float32).reshape(shape) / 0.5 - 1.
        imgsaver = ImageSaver(
            freq=1,
            filepath=os.path.join(model_dir(), saver_path),
            monitor={
                'metrics': 'x_samples',
                'num_samples': 100,
                'tile_shape': (10, 10),
            })
        model = GANK(
            model_name=model_name,
            num_random_features=50,   # 1000 for a full run
            gamma_init=0.01,
            loss='logit',
            num_z=10,                 # 100 for a full run
            img_size=img_size,
            batch_size=32,            # 64 for a full run
            num_conv_layers=3,        # already the full-run value
            num_gen_feature_maps=4,   # 32 for a full run
            num_dis_feature_maps=4,   # 32 for a full run
            metrics=['d_loss', 'g_loss'],
            callbacks=[imgsaver],
            num_epochs=4,             # 100 (MNIST) / 500 (CIFAR-10) for a full run
            random_state=random_seed(),
            verbose=1)
        model.fit(x_train)

    _fit_gank(demo.load_mnist, (28, 28, 1),
              "male/GANK/imagesaver/" "mnist/mnist_{epoch:04d}.png",
              "GANK_MNIST")
    _fit_gank(demo.load_cifar10, (32, 32, 3),
              "male/GANK/imagesaver/" "cifar10/cifar10_{epoch:04d}.png",
              "GANK_CIFAR10")
def test_gank_logit_cifar10_inception_score(show_figure=False,
                                            block_figure_on_end=False):
    """Train a tiny logit-loss GANK on CIFAR-10 with an Inception-Score checkpoint.

    Only the ``ModelCheckpoint`` (``monitor='inception_score'``,
    ``mode='max'``, ``save_best_only=True``) is passed as a callback. The
    loss, Inception-Score and sample ``Display`` objects are still
    constructed, but the fuller callback list that uses them is commented out.

    NOTE(review): this passes ``inception_score_freq=1``, whereas the other
    GAN tests here use ``inception_metrics=[...]`` + ``inception_metrics_freq``.
    Verify that GANK still accepts this keyword and actually produces
    ``inception_score``; otherwise the checkpoint has nothing to monitor.

    Args:
        show_figure: forwarded to each ``Display`` as ``show``.
        block_figure_on_end: forwarded to each ``Display`` as ``block_on_end``.
    """
    print(
        "========== Test GANK-Logit with Inception Score on CIFAR10 data =========="
    )
    np.random.seed(random_seed())
    (x_train, y_train), (x_test, y_test) = demo.load_cifar10()
    x_train = x_train.astype(np.float32).reshape([-1, 32, 32, 3]) / 0.5 - 1.
    x_test = x_test.astype(np.float32).reshape([-1, 32, 32, 3]) / 0.5 - 1.

    def _out(relative_path):
        # Resolve an output path under the model directory.
        return os.path.join(model_dir(), relative_path)

    filepath = _out("male/GANK/Logit/cifar10/checkpoints/"
                    "{epoch:04d}_{inception_score:.6f}.ckpt")
    checkpoint = ModelCheckpoint(filepath, mode='max',
                                 monitor='inception_score', verbose=0,
                                 save_best_only=True)

    view = dict(layout=(1, 1), dpi='auto', show=show_figure,
                block_on_end=block_figure_on_end)
    loss_display = Display(
        filepath=[_out("male/GANK/Logit/cifar10/" "loss/loss_{epoch:04d}.png"),
                  _out("male/GANK/Logit/cifar10/" "loss/loss_{epoch:04d}.pdf")],
        monitor=[{
            'metrics': ['d_loss', 'g_loss'],
            'type': 'line',
            'labels': ["discriminator loss", "generator loss"],
            'title': "Losses",
            'xlabel': "epoch",
            'ylabel': "loss",
        }],
        **view)
    inception_score_display = Display(
        filepath=[_out("male/GANK/Logit/cifar10/inception_score/"
                       "inception_score_{epoch:04d}.png"),
                  _out("male/GANK/Logit/cifar10/inception_score/"
                       "inception_score_{epoch:04d}.pdf")],
        monitor=[{
            'metrics': ['inception_score'],
            'type': 'line',
            'labels': ["Inception Score"],
            'title': "Scores",
            'xlabel': "epoch",
            'ylabel': "score",
        }],
        **view)
    sample_display = Display(
        figsize=(10, 10), freq=1,
        filepath=_out("male/GANK/Logit/cifar10/samples/"
                      "samples_{epoch:04d}.png"),
        monitor=[{
            'metrics': ['x_samples'],
            'title': "Generated data",
            'type': 'img',
            'num_samples': 100,
            'tile_shape': (10, 10),
        }],
        **view)

    model = GANK(
        model_name="GANK-Logit_CIFAR10",
        num_random_features=50,   # 1000 for a full run
        gamma_init=0.01,
        loss='logit',
        num_z=10,                 # 100 for a full run
        img_size=(32, 32, 3),
        batch_size=32,            # 64 for a full run
        num_conv_layers=3,        # already the full-run value
        num_gen_feature_maps=4,   # 32 for a full run
        num_dis_feature_maps=4,   # 32 for a full run
        metrics=['d_loss', 'g_loss', 'inception_score'],
        # callbacks=[loss_display, inception_score_display, sample_display, checkpoint],
        callbacks=[checkpoint],
        num_epochs=4,             # 500 for a full run
        inception_score_freq=1,
        random_state=random_seed(),
        verbose=1)
    model.fit(x_train)
def test_wgan_gp_resnet_cifar10(show_figure=False, block_figure_on_end=False):
    """Train a tiny WGAN-GP-ResNet, save it, reload it and train again.

    The model is fitted for 2 epochs on 128 CIFAR-10 images and saved under
    ``<model_dir()>/male/WGAN-GP-ResNet/CIFAR10/checkpoints``. The path
    returned by ``save`` is then reloaded with
    ``TensorFlowModel.load_model``, ``num_epochs`` is set to 4, and ``fit``
    is called again.

    Args:
        show_figure: forwarded to each ``Display`` as ``show``.
        block_figure_on_end: forwarded to each ``Display`` as ``block_on_end``.
    """
    print("========== Test WGAN-GP-ResNet on CIFAR10 data ==========")
    np.random.seed(random_seed())
    num_data = 128
    (x_train, y_train), (x_test, y_test) = demo.load_cifar10()
    x_train = x_train[:num_data].astype(np.float32).reshape([-1, 32, 32, 3]) / 0.5 - 1.
    x_test = x_test.astype(np.float32).reshape([-1, 32, 32, 3]) / 0.5 - 1.
    root_dir = os.path.join(model_dir(), "male/WGAN-GP-ResNet/CIFAR10")

    display_kwargs = {
        'layout': (1, 1),
        'dpi': 'auto',
        'show': show_figure,
        'block_on_end': block_figure_on_end,
    }
    loss_display = Display(
        filepath=[os.path.join(root_dir, "loss/loss_{epoch:04d}.png"),
                  os.path.join(root_dir, "loss/loss_{epoch:04d}.pdf")],
        monitor=[{
            'metrics': ['d_loss', 'g_loss'],
            'type': 'line',
            'labels': ["discriminator loss", "generator loss"],
            'title': "Losses",
            'xlabel': "epoch",
            'ylabel': "loss",
        }],
        **display_kwargs)
    # No filepath here: the sample display is not given an output file.
    sample_display = Display(
        figsize=(10, 10), freq=1,
        monitor=[{
            'metrics': ['x_samples'],
            'title': "Generated data",
            'type': 'img',
            'num_samples': 100,
            'tile_shape': (10, 10),
        }],
        **display_kwargs)

    model = WGAN_GP_ResNet(
        model_name="WGAN_GP_ResNet_CIFAR10",
        num_z=10,                 # 100 for a full run
        img_size=(32, 32, 3),
        batch_size=16,            # 64 for a full run
        g_blocks=('up', 'up', 'up'),
        d_blocks=('down', 'down', None, None),
        num_gen_feature_maps=4,   # 32 for a full run
        num_dis_feature_maps=4,   # 32 for a full run
        metrics=['d_loss', 'g_loss'],
        callbacks=[loss_display, sample_display],
        num_epochs=2,             # 100 for a full run
        random_state=random_seed(),
        log_path=os.path.join(root_dir, "logs"),
        verbose=1)
    model.fit(x_train)

    print("Saving model...")
    checkpoint_prefix = os.path.join(root_dir, "checkpoints/ckpt")
    save_file_path = model.save(checkpoint_prefix)

    print("Reloading model...")
    restored = TensorFlowModel.load_model(save_file_path)
    restored.num_epochs = 4
    restored.fit(x_train)
    print("Done!")
def test_wgan_gp_resnet_cifar10_inception_metric(show_figure=False,
                                                 block_figure_on_end=False):
    """Train a small WGAN-GP-ResNet with Inception Score + FID, then resume.

    The model is fitted on 128 CIFAR-10 images, with both metrics evaluated
    every epoch on ``num_inception_samples=100`` samples. Two
    ``ModelCheckpoint`` callbacks (``save_best_only=True``) monitor
    ``inception_score`` (max) and ``FID`` (min). The FID checkpoint is then
    reloaded, its Inception metrics are re-attached, ``num_epochs`` is set
    to 4, and ``fit`` is called again.

    Args:
        show_figure: forwarded to each ``Display`` as ``show``.
        block_figure_on_end: forwarded to each ``Display`` as ``block_on_end``.
    """
    print(
        "========== Test WGAN-GP-ResNet with Inception Score and FID on CIFAR10 data =========="
    )
    np.random.seed(random_seed())
    num_data = 128
    (x_train, y_train), (x_test, y_test) = demo.load_cifar10()
    x_train = x_train[:num_data].astype(np.float32).reshape([-1, 32, 32, 3]) / 0.5 - 1.
    x_test = x_test.astype(np.float32).reshape([-1, 32, 32, 3]) / 0.5 - 1.
    # For a full run, replace x_train with the whole pickled training split:
    #     import pickle
    #     from male.configs import data_dir
    #     tmp = pickle.load(open(os.path.join(data_dir(), "cifar10/cifar10_train.pkl"), "rb"))
    #     x_train = tmp['data'].astype(np.float32).reshape(
    #         [-1, 32, 32, 3]) / 127.5 - 1.

    root_dir = os.path.join(model_dir(), "male/WGAN-GP-ResNet/CIFAR10")
    best_is_path = os.path.join(root_dir, "checkpoints_is/the_best_is.ckpt")
    best_fid_path = os.path.join(root_dir, 'checkpoints_fid/the_best_fid.ckpt')

    checkpoints_is = ModelCheckpoint(best_is_path, mode='max',
                                     monitor='inception_score', verbose=1,
                                     save_best_only=True)
    checkpoints_fid = ModelCheckpoint(best_fid_path, mode='min', monitor='FID',
                                      verbose=1, save_best_only=True)

    common = dict(layout=(1, 1), dpi='auto', show=show_figure,
                  block_on_end=block_figure_on_end)

    loss_display = Display(
        filepath=[os.path.join(root_dir, "loss/loss_{epoch:04d}.png"),
                  os.path.join(root_dir, "loss/loss_{epoch:04d}.pdf")],
        monitor=[{
            'metrics': ['d_loss', 'g_loss'],
            'type': 'line',
            'labels': ["discriminator loss", "generator loss"],
            'title': "Losses",
            'xlabel': "epoch",
            'ylabel': "loss",
        }],
        **common)
    inception_score_display = Display(
        filepath=[os.path.join(root_dir, "inception_score/"
                                         "inception_score_{epoch:04d}.png"),
                  os.path.join(root_dir, "inception_score/"
                                         "inception_score_{epoch:04d}.pdf")],
        monitor=[{
            'metrics': ['inception_score'],
            'type': 'line',
            'labels': ["Inception Score"],
            'title': "Scores",
            'xlabel': "epoch",
            'ylabel': "score",
        }],
        **common)
    fid_display = Display(
        filepath=[os.path.join(root_dir, "FID/" "FID_{epoch:04d}.png"),
                  os.path.join(root_dir, "FID/" "FID_{epoch:04d}.pdf")],
        monitor=[{
            'metrics': ['FID'],
            'type': 'line',
            'labels': ["FID"],
            'title': "Scores",
            'xlabel': "epoch",
            'ylabel': "score",
        }],
        **common)
    sample_display = Display(
        figsize=(10, 10), freq=1,
        filepath=os.path.join(root_dir, "samples/samples_{epoch:04d}.png"),
        monitor=[{
            'metrics': ['x_samples'],
            'title': "Generated data",
            'type': 'img',
            'num_samples': 100,
            'tile_shape': (10, 10),
        }],
        **common)

    model = WGAN_GP_ResNet(
        model_name="WGAN_GP_ResNet_CIFAR10",
        num_z=8,                  # 128 for a full run
        img_size=(32, 32, 3),
        batch_size=64,            # already the full-run value
        g_blocks=('up', 'up', 'up'),
        d_blocks=('down', 'down', None, None),
        num_gen_feature_maps=8,   # 128 for a full run
        num_dis_feature_maps=8,   # 128 for a full run
        metrics=['d_loss', 'g_loss', 'inception_score',
                 'inception_score_std', 'FID'],
        callbacks=[loss_display, inception_score_display, fid_display,
                   sample_display, checkpoints_is, checkpoints_fid],
        num_epochs=2,             # 500 for a full run
        inception_metrics=[InceptionScore(), FID(data='cifar10')],
        inception_metrics_freq=1,
        num_inception_samples=100,  # 50000 for a full run
        # summary_freq=1,  # uncomment this for a full run
        log_path=os.path.join(root_dir, 'logs'),
        random_state=random_seed(),
        verbose=1)
    model.fit(x_train)

    print('Reloading the latest model at: {}'.format(best_fid_path))
    resumed = TensorFlowModel.load_model(best_fid_path)
    # Metrics are re-attached explicitly on the reloaded model before resuming.
    resumed.inception_metrics = InceptionMetricList(
        [InceptionScore(), FID(data='cifar10')])
    resumed.num_epochs = 4
    resumed.fit(x_train)
    print('Done!')