def GetAllTrainingParams():
  """Returns the sorted names of all training parameters.

  The result is the union of the keys of the "wide" parameter sets of every
  GAN in SUPPORTED_GANS, plus the extra name "architecture".

  Returns:
    A sorted list of parameter names.
  """
  all_params = {"architecture"}
  # GetParameters is queried with only the GAN type and the "wide" range (no
  # dataset is passed), so a single call per GAN covers all of its keys.
  # NOTE(review): if parameter sets ever become dataset-specific, iterate over
  # the datasets here and pass them to GetParameters.
  for gan_type in SUPPORTED_GANS:
    all_params.update(params.GetParameters(gan_type, "wide").keys())
  result = sorted(all_params)
  logging.info("All training parameter exported: %s", result)
  return result
def CreateCrossProductAndAddDefaultParams(config):
  """Expands `config` into tasks and fills in each GAN's default params.

  Args:
    config: Mapping from parameter name to a list of candidate values, passed
      to task_utils.CrossProduct.

  Returns:
    The list of task dicts from task_utils.CrossProduct. Each task is updated
    in place with the default hyperparameters of its "gan_type" (taken from
    the "wide" parameter set). Values already present in a task are kept;
    defaults are only added for missing keys.
  """
  task_list = task_utils.CrossProduct(config)
  for current in task_list:
    gan_defaults = GetDefaultParams(
        params.GetParameters(current["gan_type"], "wide"))
    # Explicit task values take precedence over the GAN defaults.
    for param_name, default_value in gan_defaults.items():
      current.setdefault(param_name, default_value)
  return task_list
def RepeatExp(gan_type, num_repeat):
  """Experiment where 1 fixed gan is trained with num_repeat seeds.

  Args:
    gan_type: Name of the GAN to train.
    num_repeat: Number of seeds; tasks use tf_seed = 0, ..., num_repeat - 1.

  Returns:
    The list of task dicts from task_utils.CrossProduct over MNIST and
    Fashion-MNIST and the seeds. Each task is extended with the default
    hyperparameters of its GAN. Values set explicitly in `config` are not
    overridden by those defaults.
  """
  config = {
      "gan_type": [gan_type],
      "dataset": ["mnist", "fashion-mnist"],
      # 20 epochs of 60000 examples; checkpoint every 5 epochs.
      "training_steps": [20 * 60000 // BATCH_SIZE],
      "save_checkpoint_steps": [5 * 60000 // BATCH_SIZE],
      "batch_size": [BATCH_SIZE],
      "z_dim": [64],
      "tf_seed": list(range(num_repeat)),
  }
  tasks = task_utils.CrossProduct(config)
  # Add GAN and dataset specific hyperparams. Only keys missing from the task
  # are filled in, so explicit settings such as z_dim and batch_size above
  # win over the defaults. This is the same precedence as
  # CreateCrossProductAndAddDefaultParams.
  for task in tasks:
    defaults = GetDefaultParams(
        params.GetParameters(task["gan_type"], task["dataset"]))
    for name, value in defaults.items():
      task.setdefault(name, value)
  return tasks
def TestExp():
  """Run for one epoch over all tested GANs.

  Returns:
    The list of task dicts from task_utils.CrossProduct over the GAN types
    below (MNIST, tf_seed 42). Each task is extended with the default
    hyperparameters of its GAN. Values set explicitly in `config` are not
    overridden by those defaults.
  """
  config = {
      "gan_type": [
          "GAN", "GAN_MINMAX", "WGAN", "WGAN_GP", "DRAGAN", "VAE", "LSGAN",
          "BEGAN"
      ],
      "dataset": ["mnist"],
      # One epoch of 60000 examples.
      "training_steps": [60000 // BATCH_SIZE],
      "save_checkpoint_steps": [10000],
      "batch_size": [BATCH_SIZE],
      "tf_seed": [42],
      "z_dim": [64],
  }
  tasks = task_utils.CrossProduct(config)
  # Add GAN and dataset specific hyperparams. Only keys missing from the task
  # are filled in, so explicit settings such as z_dim and batch_size above
  # win over the defaults. This is the same precedence as
  # CreateCrossProductAndAddDefaultParams.
  for task in tasks:
    defaults = GetDefaultParams(
        params.GetParameters(task["gan_type"], task["dataset"]))
    for name, value in defaults.items():
      task.setdefault(name, value)
  return tasks
def _TrainingSchedule(dataset_name, batch_size):
  """Returns (training_steps, save_checkpoint_steps) for `dataset_name`.

  Each dataset is trained for a fixed number of epochs, and a checkpoint is
  saved every 5 epochs.

  Args:
    dataset_name: "cifar10", "celeba", or anything else (treated as
      MNIST/Fashion-MNIST).
    batch_size: Number of examples per training step.
  """
  if dataset_name == "cifar10":
    # 100 epochs for CIFAR-10 (50000 examples per epoch).
    num_epochs, epoch_size = 100, 50000
  elif dataset_name == "celeba":
    # 40 epochs for CelebA (162000 examples per epoch).
    num_epochs, epoch_size = 40, 162000
  else:
    # 20 epochs for MNIST/Fashion-MNIST. Uses 60000 examples per epoch, the
    # same epoch size used by RepeatExp and TestExp.
    num_epochs, epoch_size = 20, 60000
  return (num_epochs * epoch_size // batch_size,
          5 * epoch_size // batch_size)


def TuneParams(gan_types, dataset_name, num_samples, num_repeat,
               range_type="wide"):
  """Create a set of tasks for optimizing model score.

  For each GAN type, builds `num_samples` hyperparameter settings: the GAN's
  defaults first, then settings drawn by GetSample. Each setting is repeated
  with `num_repeat` seeds.

  Args:
    gan_types: Iterable of GAN type names.
    dataset_name: Dataset to train on. It also selects the training schedule.
    num_samples: Settings per GAN, including the default one. Values below 1
      still yield the default setting.
    num_repeat: Seeds per setting (tf_seed = 0, ..., num_repeat - 1).
    range_type: Parameter range passed to params.GetParameters.

  Returns:
    A list of collections.OrderedDict tasks whose keys are sorted.
  """
  # The schedule depends only on the dataset, so compute it once.
  training_steps, save_checkpoint_steps = _TrainingSchedule(
      dataset_name, BATCH_SIZE)
  all_tasks = []
  for gan_type in gan_types:
    gan_params = params.GetParameters(gan_type, dataset_name, range_type)
    # Always include default params.
    samples = [GetDefaultParams(gan_params)]
    for _ in range(num_samples - 1):
      samples.append(GetSample(gan_params))
    for idx, sample in enumerate(samples):
      for seed in range(num_repeat):
        # Deep copy so tasks sharing a sample never share mutable values.
        task = copy.deepcopy(sample)
        task["gan_type"] = gan_type
        task["sample_id"] = idx  # For joins in dremel.
        task["dataset"] = dataset_name
        task["tf_seed"] = seed
        task["training_steps"] = training_steps
        task["save_checkpoint_steps"] = save_checkpoint_steps
        all_tasks.append(collections.OrderedDict(sorted(task.items())))
  return all_tasks
def testParameterRanges(self):
  """Checks how many parameters the "wide" ranges of WGAN and BEGAN define."""
  # (GAN type, expected number of parameter names), checked in this order.
  expected_sizes = (("WGAN", 5), ("BEGAN", 6))
  for gan_name, expected_size in expected_sizes:
    param_ranges = params.GetParameters(gan_name, "wide")
    self.assertEqual(len(list(param_ranges.keys())), expected_size)