# Example #1
# 0
def GetGANGridSearch(dataset_name, training_steps, num_seeds):
    """Build the standard single-generator GAN grid search used in the paper.

    Per seed the grid spans 2 GAN losses x 2 penalty types x 2 discriminator
    normalizations x 6 initial optimizer trials = 48 configurations.

    Args:
      dataset_name: dataset identifier. Callers also pass a list of names;
        presumably each name becomes another grid value.
      training_steps: number of training steps for every run.
      num_seeds: number of TF seeds; seeds 0 .. num_seeds - 1 are used.

    Returns:
      A meta-config dict whose values are scalars or lists, intended to be
      expanded with task_utils.CrossProduct. The "__initial_trials" entry is
      a JSON string of the expanded learning-rate / lambda / Adam grid.
    """
    # 1 learning rate x 2 lambdas x 3 (beta1, beta2, disc_iters) = 6 trials.
    initial_trials = task_utils.CrossProduct({
        "learning_rate": 0.0001,
        "lambda": [1, 10],
        ("beta1", "beta2", "disc_iters"): [
            (0.5, 0.9, 5),
            (0.5, 0.999, 5),
            (0.9, 0.999, 5),
        ],
    })

    return {  # 48 x num_seeds workers
        "gan_type": [consts.GAN_WITH_PENALTY, consts.WGAN_WITH_PENALTY],
        "penalty_type": [consts.NO_PENALTY, consts.WGANGP_PENALTY],
        "discriminator_normalization": [
            consts.NO_NORMALIZATION, consts.SPECTRAL_NORM],
        "architecture": consts.DCGAN_ARCH,
        "dataset": dataset_name,
        "tf_seed": list(range(num_seeds)),
        "training_steps": training_steps,
        "save_checkpoint_steps": 20000,
        "batch_size": 64,
        "optimizer": "adam",
        "z_dim": 64,
        "__initial_trials": json.dumps(initial_trials),
    }
# Example #2
# 0
def CreateCrossProductAndAddDefaultParams(config):
    """Expand `config` into tasks and fill in each GAN type's defaults.

    Defaults come from params.GetParameters(gan_type, "wide") and are only
    added for keys a task does not already define: explicit task values take
    precedence over the defaults.

    Args:
      config: meta-config accepted by task_utils.CrossProduct; every expanded
        task must contain a "gan_type" key.

    Returns:
      The tasks produced by task_utils.CrossProduct, each updated in place.
    """
    tasks = task_utils.CrossProduct(config)
    for task in tasks:
        gan_defaults = GetDefaultParams(
            params.GetParameters(task["gan_type"], "wide"))
        # Only fill gaps; keep whatever the task already specifies.
        for key, value in gan_defaults.items():
            task.setdefault(key, value)
    return tasks
# Example #3
# 0
def RepeatExp(gan_type, num_repeat):
    """Experiment where 1 fixed GAN is trained with `num_repeat` seeds.

    Args:
      gan_type: GAN type name, also used to look up default hyperparameters
        via params.GetParameters.
      num_repeat: number of seeds; seeds 0 .. num_repeat - 1 are used.

    Returns:
      The tasks from task_utils.CrossProduct (2 datasets x num_repeat seeds),
      each updated in place with the GAN/dataset default hyperparameters.
    """
    config = {
        "gan_type": [gan_type],
        "dataset": ["mnist", "fashion-mnist"],
        # 20 passes over 60000 examples; checkpoint every 5 passes.
        "training_steps": [20 * 60000 // BATCH_SIZE],
        "save_checkpoint_steps": [5 * 60000 // BATCH_SIZE],
        "batch_size": [BATCH_SIZE],
        "z_dim": [64],
        # Materialize the seeds as a list, as the other grids in this file do.
        # On Python 3 `range` is not a list. task_utils.CrossProduct also
        # accepts scalar (non-list) values, so a bare `range` risks being
        # taken as one scalar instead of num_repeat grid values.
        "tf_seed": list(range(num_repeat)),
    }
    tasks = task_utils.CrossProduct(config)
    # Add GAN and dataset specific hyperparams.
    # NOTE(review): update() lets the defaults overwrite keys already in the
    # task, unlike CreateCrossProductAndAddDefaultParams where task values
    # win. Confirm this is intended if the defaults can include e.g. "z_dim".
    for task in tasks:
        task.update(
            GetDefaultParams(
                params.GetParameters(task["gan_type"], task["dataset"])))
    return tasks
# Example #4
# 0
def TestExp():
    """Smoke test: run each listed GAN type for one epoch on MNIST.

    Returns:
      One task per GAN type (from task_utils.CrossProduct), each updated in
      place with that GAN's default hyperparameters for "mnist".
    """
    gan_types = [
        "GAN", "GAN_MINMAX", "WGAN", "WGAN_GP",
        "DRAGAN", "VAE", "LSGAN", "BEGAN",
    ]
    one_epoch = 60000 // BATCH_SIZE
    config = {
        "gan_type": gan_types,
        "dataset": ["mnist"],
        "training_steps": [one_epoch],
        "save_checkpoint_steps": [10000],
        "batch_size": [BATCH_SIZE],
        "tf_seed": [42],
        "z_dim": [64],
    }

    tasks = task_utils.CrossProduct(config)
    for task in tasks:
        gan_params = params.GetParameters(task["gan_type"], task["dataset"])
        # update() applies the defaults on top of the task, so a default
        # overrides any key the task already has.
        task.update(GetDefaultParams(gan_params))
    return tasks
# Example #5
# 0
def GetMultiGANGridSearch(dataset_name, training_steps, num_seeds, aggregate):
    """Build the standard MultiGAN grid search used in the paper.

    Per seed the grid spans 2 discriminator normalizations x 21 architecture
    variants = 42 configurations. The variants are 3 M-GAN rows
    (n_blocks=0, n_heads=0) plus 6 RM-GAN rows for each k in {3, 4, 5}.

    Args:
      dataset_name: dataset identifier.
      training_steps: number of training steps for every run.
      num_seeds: number of TF seeds; seeds 0 .. num_seeds - 1 are used.
      aggregate: value stored under the "aggregate" model option.

    Returns:
      A meta-config dict intended to be expanded with
      task_utils.CrossProduct. "__initial_trials" holds a JSON string of the
      expanded architecture / optimizer trials.
    """
    k_values = (3, 4, 5)

    # Tuples are (n_blocks, share_block_weights, n_heads, k).
    # M-GAN rows come first, one per k.
    architectures = [(0, False, 0, k) for k in k_values]
    # RM-GAN rows, repeated for each k in order:
    # (n_blocks, share_block_weights, n_heads).
    relational_settings = [
        (1, False, 1),
        (1, False, 2),
        (2, False, 1),
        (2, True, 1),
        (2, False, 2),
        (2, True, 2),
    ]
    architectures.extend(
        (n_blocks, shared, n_heads, k)
        for k in k_values
        for n_blocks, shared, n_heads in relational_settings)

    initial_trials = task_utils.CrossProduct({
        ("n_blocks", "share_block_weights", "n_heads", "k"): architectures,
        ("z_dim", "embedding_dim"): [(64, 32)],
        "learning_rate": 0.0001,
        "lambda": 1,
        ("beta1", "beta2", "disc_iters"): [(0.9, 0.999, 5)],
    })

    return {  # 42 x num_seeds workers
        "gan_type": "MultiGAN",
        "penalty_type": consts.WGANGP_PENALTY,
        "discriminator_normalization": [
            consts.NO_NORMALIZATION, consts.SPECTRAL_NORM],
        "architecture": consts.DCGAN_ARCH,
        "dataset": dataset_name,
        "tf_seed": list(range(num_seeds)),
        "training_steps": training_steps,
        "save_checkpoint_steps": 20000,
        "batch_size": 64,
        "optimizer": "adam",
        # Model params.
        "aggregate": aggregate,
        "__initial_trials": json.dumps(initial_trials),
    }
# Example #6
# 0
def GetMetaTasks(experiment_name):
    """Build the study-generation meta options for a named experiment.

    Args:
      experiment_name: name of an experiment, e.g. "multi_gan-debug" or one
        of the "*-paper" names in the table below.

    Returns:
      task_utils.CrossProduct applied to the selected experiment's meta
      config.

    Raises:
      ValueError: When experiment is not found.
    """
    steps = 1000000
    seeds = 5

    def _debug_config():
        # Small single-configuration MultiGAN run for quick checks.
        return {
            "gan_type": ["MultiGAN"],
            "penalty_type": [consts.WGANGP_PENALTY],
            "discriminator_normalization": [consts.SPECTRAL_NORM],
            "architecture": consts.DCGAN_ARCH,
            "dataset": ["multi-mnist-3-uniform"],
            "sampler": ["rs"],
            "tf_seed": [0],
            "training_steps": [1000],
            "save_checkpoint_steps": [100],
            "batch_size": [64],
            "optimizer": ["adam"],
            "learning_rate": [0.0001],
            "lambda": [10],
            "beta1": [0.5],
            "beta2": [0.9],
            "disc_iters": [5],
            # Model params.
            "k": [3],
            "aggregate": ["sum_clip"],
            "n_heads": 4,
            "n_blocks": 2,
            "share_block_weights": True,
            "embedding_dim": 32,
        }

    # (name, zero-argument factory) pairs, checked in order. The factories
    # are lambdas so that each grid-search helper is only looked up and
    # called for the experiment actually selected.
    experiments = (
        ("multi_gan-debug", _debug_config),

        #######################
        #  PAPER EXPERIMENTS  #
        #######################

        # MULTI-MNIST
        ("gan-base-experiment-paper",  # 240
         lambda: GetGANGridSearch(
             dataset_name="multi-mnist-3-uniform",
             training_steps=steps,
             num_seeds=seeds)),
        ("multi_gan-base-experiment-paper",  # 210
         lambda: GetMultiGANGridSearch(
             dataset_name="multi-mnist-3-uniform",
             training_steps=steps,
             num_seeds=seeds,
             aggregate="sum_clip")),
        ("gan-relational-experiment-paper",  # 480
         lambda: GetGANGridSearch(
             dataset_name=[
                 "multi-mnist-3-triplet",
                 "multi-mnist-3-uniform-rgb-occluded",
             ],
             training_steps=steps,
             num_seeds=seeds)),
        ("multi_gan-relational-experiment-triplet-paper",  # 210
         lambda: GetMultiGANGridSearch(
             dataset_name="multi-mnist-3-triplet",
             training_steps=steps,
             num_seeds=seeds,
             aggregate="sum_clip")),
        ("multi_gan-relational-experiment-rgb-occluded-paper",  # 210
         lambda: GetMultiGANGridSearch(
             dataset_name="multi-mnist-3-uniform-rgb-occluded",
             training_steps=steps,
             num_seeds=seeds,
             aggregate="implicit_alpha")),

        # CIFAR 10
        ("gan-background-experiment-paper",  # 240
         lambda: GetGANGridSearch(
             dataset_name="multi-mnist-3-uniform-rgb-occluded-cifar10",
             training_steps=steps,
             num_seeds=seeds)),
        ("multi_gan-background-experiment-paper",  # 210
         lambda: GetMultiGANGridSearchKPlusOne(
             dataset_name="multi-mnist-3-uniform-rgb-occluded-cifar10",
             training_steps=steps,
             num_seeds=seeds,
             aggregate="alpha")),
        ("multi_gan_bg-background-experiment-paper",  # 420
         lambda: GetMultiGANBackgroundGridSearch(
             dataset_name="multi-mnist-3-uniform-rgb-occluded-cifar10",
             training_steps=steps,
             num_seeds=seeds,
             aggregate="alpha")),

        # CLEVR
        ("gan-clevr-experiment-paper",  # 240
         lambda: GetGANGridSearch(
             dataset_name="clevr",
             training_steps=steps,
             num_seeds=seeds)),
        ("multi_gan-clevr-experiment-paper",  # 210
         lambda: GetMultiGANGridSearchKPlusOne(
             dataset_name="clevr",
             training_steps=steps,
             num_seeds=seeds,
             aggregate="alpha")),
        ("multi_gan_bg-clevr-experiment-paper",  # 420
         lambda: GetMultiGANBackgroundGridSearch(
             dataset_name="clevr",
             training_steps=steps,
             num_seeds=seeds,
             aggregate="alpha")),
        ("multi_gan_bg-clevr-rs200k-paper",  # 200
         lambda: GetMultiGANBackgroundRandomSearch(
             dataset_name="clevr",
             training_steps=200000,
             aggregate="alpha",
             num_tasks=200)),
    )

    for name, make_meta_config in experiments:
        if experiment_name == name:
            meta_config = make_meta_config()
            break
    else:
        raise ValueError("Unknown study-based experiment %s." %
                         experiment_name)

    return task_utils.CrossProduct(meta_config)