コード例 #1
0
ファイル: eval_gan_lib.py プロジェクト: xiao7199/compare_gan
def GetAllTrainingParams():
    """Return the sorted names of every training parameter across all GANs.

    The result is the union of "architecture" and the keys of
    `params.GetParameters(gan_type, "wide")` for every entry in
    SUPPORTED_GANS. The union is also logged at INFO level.

    Returns:
      A sorted list of parameter names.
    """
    all_params = {"architecture"}
    for gan_type in SUPPORTED_GANS:
        # GetParameters is called here without any dataset argument, so one
        # call per GAN yields the same keys as repeating it once per dataset.
        # NOTE(review): if the parameter set is meant to vary by dataset,
        # pass the dataset name to GetParameters and loop over datasets.
        all_params.update(params.GetParameters(gan_type, "wide").keys())
    ordered_params = sorted(all_params)
    logging.info("All training parameter exported: %s", ordered_params)
    return ordered_params
コード例 #2
0
def CreateCrossProductAndAddDefaultParams(config):
    """Expand `config` into tasks and back-fill default GAN hyperparameters.

    Each task produced by task_utils.CrossProduct(config) is updated in place
    with the defaults that GetDefaultParams derives from
    params.GetParameters(task["gan_type"], "wide"). Keys the task already
    defines keep their existing values; only missing keys are added.

    Args:
      config: Passed unchanged to task_utils.CrossProduct.

    Returns:
      The object returned by task_utils.CrossProduct, with its tasks updated
      in place.
    """

    def _BackfillDefaults(task):
        # Layer the task on top of the defaults, then copy the merged result
        # back into the task so explicitly configured values win.
        merged = GetDefaultParams(
            params.GetParameters(task["gan_type"], "wide"))
        merged.update(task)
        task.update(merged)

    expanded = task_utils.CrossProduct(config)
    for entry in expanded:
        _BackfillDefaults(entry)
    return expanded
コード例 #3
0
def RepeatExp(gan_type, num_repeat):
    """Experiment where 1 fixed gan is trained with num_repeat seeds.

    Builds the cross product of {"mnist", "fashion-mnist"} with tf_seed
    values 0..num_repeat-1 for a single `gan_type`. Every task uses
    training_steps = 20 * 60000 // BATCH_SIZE,
    save_checkpoint_steps = 5 * 60000 // BATCH_SIZE, batch_size = BATCH_SIZE
    and z_dim = 64.

    Args:
      gan_type: Name of the GAN to train.
      num_repeat: Number of seeds; tasks get tf_seed 0..num_repeat-1.

    Returns:
      The tasks from task_utils.CrossProduct, each back-filled with the
      defaults from GetDefaultParams(params.GetParameters(gan_type, dataset)).
    """
    config = {
        "gan_type": [gan_type],
        "dataset": ["mnist", "fashion-mnist"],
        "training_steps": [20 * 60000 // BATCH_SIZE],
        "save_checkpoint_steps": [5 * 60000 // BATCH_SIZE],
        "batch_size": [BATCH_SIZE],
        "z_dim": [64],
        "tf_seed": range(num_repeat)
    }
    tasks = task_utils.CrossProduct(config)
    # Add GAN and dataset specific hyperparams. Values set explicitly in
    # `config` above take precedence: a plain task.update(defaults) would
    # silently overwrite them (e.g. batch_size, z_dim) whenever the defaults
    # share a key. This matches CreateCrossProductAndAddDefaultParams.
    for task in tasks:
        defaults = GetDefaultParams(
            params.GetParameters(task["gan_type"], task["dataset"]))
        defaults.update(task)
        task.update(defaults)
    return tasks
コード例 #4
0
def TestExp():
    """Run for one epoch over all tested GANs.

    Builds one MNIST task per GAN type in the list below, with
    training_steps = 60000 // BATCH_SIZE, save_checkpoint_steps = 10000,
    batch_size = BATCH_SIZE, tf_seed = 42 and z_dim = 64.

    Returns:
      The tasks from task_utils.CrossProduct, each back-filled with the
      defaults from GetDefaultParams(params.GetParameters(gan_type, dataset)).
    """
    config = {
        "gan_type": [
            "GAN", "GAN_MINMAX", "WGAN", "WGAN_GP", "DRAGAN", "VAE", "LSGAN",
            "BEGAN"
        ],
        "dataset": ["mnist"],
        "training_steps": [60000 // BATCH_SIZE],
        "save_checkpoint_steps": [10000],
        "batch_size": [BATCH_SIZE],
        "tf_seed": [42],
        "z_dim": [64],
    }
    tasks = task_utils.CrossProduct(config)
    # Add GAN and dataset specific hyperparams. Values set explicitly in
    # `config` above take precedence: a plain task.update(defaults) would
    # silently overwrite them (e.g. batch_size, z_dim) whenever the defaults
    # share a key. This matches CreateCrossProductAndAddDefaultParams.
    for task in tasks:
        defaults = GetDefaultParams(
            params.GetParameters(task["gan_type"], task["dataset"]))
        defaults.update(task)
        task.update(defaults)
    return tasks
コード例 #5
0
def TuneParams(gan_types,
               dataset_name,
               num_samples,
               num_repeat,
               range_type="wide"):
    """Create a set of tasks for optimizing model score.

    For each GAN type, the first hyperparameter setting is always the
    defaults (GetDefaultParams). It is followed by num_samples - 1 settings
    drawn with GetSample. Each setting is emitted num_repeat times, with
    tf_seed 0..num_repeat-1.

    Args:
      gan_types: Iterable of GAN type names.
      dataset_name: Dataset to train on. It also selects the training
        schedule (training_steps / save_checkpoint_steps).
      num_samples: Number of settings per GAN, including the defaults.
        Values below 1 still produce the default setting.
      num_repeat: Number of seeds per setting.
      range_type: Forwarded to params.GetParameters.

    Returns:
      A list of collections.OrderedDict tasks whose keys are sorted.
    """
    # The schedule depends only on dataset_name, so compute it once instead
    # of on every iteration of the gan x sample x seed loops.
    if dataset_name == "cifar10":
        # 100 epochs for CIFAR10.
        epoch = 50000
        num_epochs = 100
    elif dataset_name == "celeba":
        # 40 epochs for CelebA.
        epoch = 162000
        num_epochs = 40
    else:
        # 20 epochs for MNIST/FASHION-MNIST.
        # NOTE(review): RepeatExp/TestExp use 60000 examples per MNIST epoch;
        # confirm that 50000 is intended here.
        epoch = 50000
        num_epochs = 20
    training_steps = num_epochs * epoch // BATCH_SIZE
    save_checkpoint_steps = 5 * epoch // BATCH_SIZE

    all_tasks = []
    for gan_type in gan_types:
        gan_params = params.GetParameters(gan_type, dataset_name, range_type)
        # Always include default params.
        samples = [GetDefaultParams(gan_params)]
        for _ in range(num_samples - 1):
            samples.append(GetSample(gan_params))

        for idx, sample in enumerate(samples):
            for i in range(num_repeat):
                # Deep copy so the seeds of one setting never share mutable state.
                s = copy.deepcopy(sample)
                s["gan_type"] = gan_type
                s["sample_id"] = idx  # For joins in dremel.
                s["dataset"] = dataset_name
                s["tf_seed"] = i
                s["training_steps"] = training_steps
                s["save_checkpoint_steps"] = save_checkpoint_steps
                all_tasks.append(s)

    return [collections.OrderedDict(sorted(x.items())) for x in all_tasks]
コード例 #6
0
    def testParameterRanges(self):
        # Expected number of keys returned by
        # params.GetParameters(gan_type, "wide") for each GAN type, checked
        # in this order.
        expected_key_counts = [("WGAN", 5), ("BEGAN", 6)]
        for gan_type, expected in expected_key_counts:
            ranges = params.GetParameters(gan_type, "wide")
            self.assertEqual(len(list(ranges.keys())), expected)