def get_dataset(input_params, dataset_type: "problem_type.ProblemType | str"):
    """Build the dataset object that matches ``dataset_type``.

    Args:
        input_params: Parameters forwarded unchanged to the selected dataset
            constructor.
        dataset_type: Either a ``problem_type.ProblemType`` member or that
            member's ``name`` string (e.g. ``"VANILLA_MNIST"``). Name strings
            are what this function historically compared against, so they
            remain supported.

    Returns:
        The dataset instance for the requested problem type:
        ``mnist.MnistDataset``, ``fashion_mnist.FashionMnistDataset``,
        ``cifar10.Cifar10Dataset`` or ``summer2winter.SummerToWinterDataset``.
        The ``CONDITIONAL_*`` types build the same classes as their
        ``VANILLA_*`` counterparts but pass ``with_labels=True``.

    Raises:
        NotImplementedError: If ``dataset_type`` does not name a supported
            problem type.
    """
    problem_types = problem_type.ProblemType

    # The lookup below is keyed by enum *names*; accept enum members too by
    # normalising them to their name, so both calling styles work.
    if isinstance(dataset_type, problem_types):
        dataset_type = dataset_type.name

    # Zero-argument factories so that only the selected dataset is constructed.
    factories = {
        problem_types.VANILLA_MNIST.name:
            lambda: mnist.MnistDataset(input_params),
        problem_types.VANILLA_FASHION_MNIST.name:
            lambda: fashion_mnist.FashionMnistDataset(input_params),
        problem_types.VANILLA_CIFAR10.name:
            lambda: cifar10.Cifar10Dataset(input_params),
        problem_types.CONDITIONAL_MNIST.name:
            lambda: mnist.MnistDataset(input_params, with_labels=True),
        problem_types.CONDITIONAL_FASHION_MNIST.name:
            lambda: fashion_mnist.FashionMnistDataset(input_params, with_labels=True),
        problem_types.CONDITIONAL_CIFAR10.name:
            lambda: cifar10.Cifar10Dataset(input_params, with_labels=True),
        problem_types.CYCLE_SUMMER2WINTER.name:
            lambda: summer2winter.SummerToWinterDataset(input_params),
    }

    try:
        factory = factories[dataset_type]
    except (KeyError, TypeError):  # TypeError: unhashable dataset_type
        raise NotImplementedError(
            f"Unsupported dataset type: {dataset_type!r}") from None
    return factory()
from gans.trainers import vanilla_gan_trainer model_parameters = edict({ 'img_height': 28, 'img_width': 28, 'num_channels': 1, 'batch_size': 16, 'num_epochs': 10, 'buffer_size': 1000, 'latent_size': 100, 'learning_rate_generator': 0.0001, 'learning_rate_discriminator': 0.0001, 'save_images_every_n_steps': 10 }) dataset = mnist.MnistDataset(model_parameters) def validation_dataset(): return tf.random.normal( [model_parameters.batch_size, model_parameters.latent_size]) validation_dataset = validation_dataset() generator = sequential.SequentialModel(layers=[ keras.Input(shape=[model_parameters.latent_size]), layers.Dense(units=7 * 7 * 256, use_bias=False), layers.BatchNormalization(), layers.LeakyReLU(), layers.Reshape((7, 7, 256)),
model_parameters = edict({ 'img_height': 28, 'img_width': 28, 'num_channels': 1, 'batch_size': 16, 'num_epochs': 10, 'buffer_size': 1000, 'latent_size': 100, 'num_classes': 10, 'learning_rate_generator': 0.0001, 'learning_rate_discriminator': 0.0001, 'save_images_every_n_steps': 10 }) dataset = mnist.MnistDataset(model_parameters, with_labels=True) def validation_dataset(): test_batch_size = model_parameters.num_classes ** 2 labels = np.repeat(list(range(model_parameters.num_classes)), model_parameters.num_classes) validation_samples = [tf.random.normal([test_batch_size, model_parameters.latent_size]), np.array(labels)] return validation_samples validation_dataset = validation_dataset() generator = latent_to_image.LatentToImageGenerator(model_parameters) discriminator = discriminator.Discriminator(model_parameters) generator_optimizer = optimizers.Adam(