def __init__(self, z_dim, batch_size):
    """Set up the MNIST GAN (CNN variant): data, dimensions, plot batches, graph.

    Loads MNIST with one-hot labels, records the sizes used by the generator
    and discriminator, draws a fixed 100-sample batch for plotting, then
    builds the model and starts the session.

    Args:
        z_dim: Length of the noise vector (also sizes the noise mean/cov).
        batch_size: Training batch size, stored on the instance.
    """
    logger.info("MNIST GAN CNN")
    logger.info("z_dim: {}".format(z_dim))

    data_dir = get_project_directory("mnist", "dataset")
    logger.info("mnist_data_dir: {}".format(data_dir))
    self.mnist_data = input_data.read_data_sets(data_dir, one_hot=True)

    # Noise parameters: zero mean vector and identity covariance matrix.
    self.z_dim = z_dim
    self.z_mean = np.zeros(z_dim)
    self.z_cov = np.eye(z_dim)

    # 784 = 28 * 28 pixel inputs; 10 label classes.
    self.x_dim = 784
    self.y_dim = 10
    self.batch_size = batch_size
    self.G_input_dim = self.y_dim
    self.D_input_dim = self.x_dim + self.y_dim

    # NOTE(review): presumably smoothed discriminator targets for
    # real / fake samples — confirm against the loss definition.
    self.real_prob_val = 0.9
    self.fake_prob_val = 0.1

    # Fixed batch of 100 samples (a 10x10 grid) used for plotting; the
    # noise batch is drawn with a fixed random_seed of 1.
    grid_count = 10 * 10
    self.plot_batch_x, self.plot_batch_y = self.get_data_batches(
        grid_count, normalize="no")
    self.plot_batch_z = self.get_noise_batches(grid_count, random_seed=1)
    self.std_test_data = False

    self.__build_model()
    self.__start_session()
def __init__(self):
    """Load MNIST with one-hot labels, then build model, accuracy op and session."""
    self.mnist_data = input_data.read_data_sets(
        get_project_directory("mnist", "dataset"), one_hot=True)
    self.__build_model()
    self.__build_accuracy_computation()
    self.__start_session()
def plot_fake_data(self, epoch, batch_fx, fig_plots=(10, 10), figsize=(10, 10)):
    """Save a grid of generated MNIST images for one epoch as a PNG.

    The file is written to the project's ``mnist/results/<start_timestamp>``
    directory as ``mnist_gan_mlp_<epoch>.png``.

    Args:
        epoch: Epoch number, embedded in the output filename.
        batch_fx: Generated images; its total size must be a multiple of
            784 because it is reshaped to ``(-1, 28, 28)``.
        fig_plots: ``(rows, cols)`` of the subplot grid. Only the first
            ``rows * cols`` images are drawn; any extra images are ignored
            rather than making ``plt.subplot`` raise.
        figsize: Figure size passed to ``plt.figure``.
    """
    images = batch_fx.reshape(-1, 28, 28)
    rows, cols = fig_plots
    # plt.subplot only accepts indices 1..rows*cols, so cap the panel count.
    n_panels = min(images.shape[0], rows * cols)

    fig = plt.figure(figsize=figsize)
    try:
        for idx in range(n_panels):
            plt.subplot(rows, cols, idx + 1)
            plt.imshow(images[idx], interpolation='nearest', cmap='gray_r')
            plt.axis('off')
        plt.tight_layout()
        img_dir = get_project_directory("mnist", "results", self.start_timestamp)
        plt.savefig(os.path.join(img_dir, "mnist_gan_mlp_{}.png".format(epoch)))
    finally:
        # pyplot keeps figures alive until explicitly closed; release this one
        # even when plotting or saving fails, so figures don't pile up.
        plt.close(fig)
self.std_batch_z = self.get_noise_batches(batch_size=10 * 10) self.std_test_data = True loss_G, loss_D, D1, D2, batch_fx = self.sess.run( [self.loss_G, self.loss_D, self.D1, self.D2, self.G], feed_dict={ self.x: self.std_batch_x, self.y: self.std_batch_y, self.z: self.std_batch_z, self.keep_prob: 1.0 }) self.plot_losses(epoch=i, loss_G=loss_G, loss_D=loss_D) self.plot_fake_data(epoch=i, batch_fx=batch_fx) time_diff = time.time() - start_time start_time = time.time() logger.info( "Epoch: {:3d} - L_G: {:0.3f} - L_D: {:0.3f} - D1: {:0.3f} - D2: {:0.3f} - Time: {:0.1f}" .format(i, loss_G, loss_D, D1[0][0], D2[0][0], time_diff)) if __name__ == "__main__": setup_logger(log_directory=get_project_directory("mnist_cnn", "logs"), file_handler_type=HandlerType.TIME_ROTATING_FILE_HANDLER, allow_console_logging=True, allow_file_logging=True, max_file_size_bytes=10000, change_log_level=None) mnist_gan_mlp = GAN_CNN(z_dim=10, batch_size=100) mnist_gan_mlp.run(epochs=1000, batch_size=100, summary_epochs=1)