Beispiel #1
0
def setup_config(run_dir, **args):
    """Translate command-line options into the experiment configuration.

    The function fills in model/baseline defaults, validates the GPU and
    batch settings, and assembles the dataset, loss, generator (``cG``),
    discriminator (``cD``) and visualization option dictionaries.

    NOTE(review): ``nset`` and ``cset`` are helpers defined outside this
    function. From their use here, ``nset(args, name, value)`` appears to
    assign a default that does not override a user-provided option, and
    ``cset(target(s), name, value)`` appears to copy a value into one or
    several dictionaries -- confirm against their definitions.

    Args:
        run_dir: Directory of the experiment. ``training_options.json`` is
            read from it when ``--reload`` is set and written to it otherwise.
        **args: The parsed command-line options, accessed both by attribute
            and by key after being wrapped in an ``EasyDict``.

    Returns:
        When ``args.reload`` is set and ``training_options.json`` exists in
        ``run_dir``: the plain object produced by ``json.load`` for that file
        (command-line options are ignored in this case).
        Otherwise: an ``EasyDict`` holding the training-loop options plus the
        ``cG``, ``cD``, ``loss_args``, ``dataset_args`` and ``vis_args``
        entries.

    Side effects:
        Sets ``os.environ["CUDA_VISIBLE_DEVICES"]`` when ``args.gpus`` is
        non-empty, and writes the resulting configuration to
        ``<run_dir>/training_options.json``. Invalid settings are reported
        through ``misc.error``.
    """
    args = EasyDict(args)  # command-line options
    train = EasyDict(run_dir=run_dir)  # training loop options
    vis = EasyDict(run_dir=run_dir)  # visualization loop options

    if args.reload:
        config_fn = os.path.join(run_dir, "training_options.json")
        if os.path.exists(config_fn):
            # Load config from the experiment's existing file (and so ignore command-line arguments)
            with open(config_fn, "rt") as f:
                config = json.load(f)
            return config
        misc.log(
            f"Warning: --reload is set for a new experiment {args.expname}," +
            f" but configuration file to reload from {config_fn} doesn't exist.",
            "red")

    # GANformer and baselines default settings
    # ----------------------------------------------------------------------------

    if args.ganformer_default:
        task = args.dataset
        nset(args, "mirror_augment", task in ["cityscapes", "ffhq"])

        nset(args, "transformer", True)
        nset(args, "components_num", {"clevr": 8}.get(task, 16))
        nset(args, "latent_size", {"clevr": 128}.get(task, 512))

        nset(args, "normalize", "layer")
        nset(args, "integration", "mul")
        nset(args, "kmeans", True)
        nset(args, "use_pos", True)
        nset(args, "mapping_ltnt2ltnt", task != "clevr")
        nset(args, "style", task != "clevr")

        nset(args, "g_arch", "resnet")
        nset(args, "mapping_resnet", True)

        # Per-dataset R1 regularization weight (passed to the loss as r1_gamma below)
        gammas = {"ffhq": 10, "cityscapes": 20, "clevr": 40, "bedrooms": 100}
        nset(args, "gamma", gammas.get(task, 10))

    if args.baseline == "GAN":
        nset(args, "style", False)
        nset(args, "latent_stem", True)

    ## k-GAN and SAGAN  are not currently supported in the pytorch version.
    ## See the TF version for implementation of these baselines!
    # if args.baseline == "SAGAN":
    #     nset(args, "style", False)
    #     nset(args, "latent_stem", True)
    #     nset(args, "g_img2img", 5)

    # if args.baseline == "kGAN":
    #     nset(args, "kgan", True)
    #     nset(args, "merge_layer", 5)
    #     nset(args, "merge_type", "softmax")
    #     nset(args, "components_num", 8)

    # General setup
    # ----------------------------------------------------------------------------

    # If the flag is specified without arguments (--arg), set to True
    for arg in [
            "cuda_bench", "allow_tf32", "keep_samples", "style", "local_noise"
    ]:
        if args[arg] is None:
            args[arg] = True

    if not any([args.train, args.eval, args.vis]):
        misc.log(
            "Warning: None of --train, --eval or --vis are provided. Therefore, we only print network shapes",
            "red")
    for arg in ["train", "eval", "vis", "last_snapshots"]:
        cset(train, arg, args[arg])

    # Start from zero devices so that an empty --gpus value is reported by the
    # validation below instead of failing on an unbound local variable.
    num_gpus = 0
    if args.gpus != "":
        num_gpus = len(args.gpus.split(","))
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
    # n & (n - 1) == 0 holds exactly for powers of two (given n >= 1)
    if not (num_gpus >= 1 and num_gpus & (num_gpus - 1) == 0):
        misc.error("Number of GPUs must be a power of two")
    args.num_gpus = num_gpus

    # CUDA settings
    for arg in ["batch_size", "batch_gpu", "allow_tf32"]:
        cset(train, arg, args[arg])
    cset(train, "cudnn_benchmark", args.cuda_bench)

    # Data setup
    # ----------------------------------------------------------------------------

    # For bedrooms, we choose the most common ratio in the
    # dataset and crop the other images into that ratio.
    ratios = {
        "clevr": 0.75,
        "bedrooms": 188 / 256,
        "cityscapes": 0.5,
        "ffhq": 1.0
    }
    args.ratio = args.ratio or ratios.get(args.dataset, 1.0)
    args.crop_ratio = 0.5 if args.resolution > 256 and args.ratio < 0.5 else None

    args.printname = args.expname
    for arg in ["total_kimg", "printname"]:
        cset(train, arg, args[arg])

    dataset_args = EasyDict(class_name="training.dataset.ImageFolderDataset",
                            path=f"{args.data_dir}/{args.dataset}",
                            max_items=args.train_images_num,
                            resolution=args.resolution,
                            ratio=args.ratio,
                            mirror_augment=args.mirror_augment)
    dataset_args.loader_args = EasyDict(num_workers=args.num_threads,
                                        pin_memory=True,
                                        prefetch_factor=2)

    # Optimization setup
    # ----------------------------------------------------------------------------

    cG = set_net("Generator", ["mapping", "synthesis"], args.g_lr, 4)
    cD = set_net("Discriminator", ["mapping", "block", "epilogue"], args.d_lr,
                 16)
    cset([cG, cD], "crop_ratio", args.crop_ratio)

    mbstd = min(
        args.batch_gpu, 4
    )  # other hyperparams behave more predictably if mbstd group size remains fixed
    cset(cD.epilogue_kwargs, "mbstd_group_size", mbstd)

    # Automatic tuning
    if args.autotune:
        batch_size = max(
            min(args.num_gpus * min(4096 // args.resolution, 32), 64),
            args.num_gpus)  # keep gpu memory consumption at bay
        # Store the batch size first, so that the per-GPU batch is derived from
        # the effective total batch size rather than from its pre-tuning value.
        nset(args, "batch_size", batch_size)
        batch_gpu = args.batch_size // args.num_gpus
        nset(args, "batch_gpu", batch_gpu)
        # NOTE(review): train.batch_size / train.batch_gpu and mbstd_group_size
        # were filled in before this branch and are not refreshed here --
        # confirm whether they should follow the tuned values.

        fmap_decay = 1 if args.resolution >= 512 else 0.5  # halve dim_base for resolutions below 512
        lr = 0.002 if args.resolution >= 1024 else 0.0025
        gamma = 0.0002 * (args.resolution**
                          2) / args.batch_size  # heuristic formula

        cset([cG.synthesis_kwargs, cD], "dim_base", int(fmap_decay * 32768))
        nset(args, "g_lr", lr)
        cset(cG.opt_args, "lr", args.g_lr)
        nset(args, "d_lr", lr)
        cset(cD.opt_args, "lr", args.d_lr)
        nset(args, "gamma", gamma)

        train.ema_rampup = 0.05
        train.ema_kimg = batch_size * 10 / 32

    if args.batch_size % (args.batch_gpu * args.num_gpus) != 0:
        misc.error(
            "--batch-size should be divided by --batch-gpu * 'num_gpus'")

    # Loss and regularization settings
    loss_args = EasyDict(class_name="training.loss.StyleGAN2Loss",
                         g_loss=args.g_loss,
                         d_loss=args.d_loss,
                         r1_gamma=args.gamma,
                         pl_weight=args.pl_weight)

    # if args.fp16:
    #     cset([cG.synthesis_kwargs, cD], "num_fp16_layers", 4) # enable mixed-precision training
    #     cset([cG.synthesis_kwargs, cD], "conv_clamp", 256) # clamp activations to avoid float16 overflow

    # cset([cG.synthesis_kwargs, cD.block_args], "fp16_channels_last", args.nhwc)

    # Evaluation and visualization
    # ----------------------------------------------------------------------------

    from metrics import metric_main
    for metric in args.metrics:
        if not metric_main.is_valid_metric(metric):
            misc.error(
                f"Unknown metric: {metric}. The valid metrics are: {metric_main.list_valid_metrics()}"
            )

    for arg in ["num_gpus", "metrics", "eval_images_num", "truncation_psi"]:
        cset(train, arg, args[arg])
    for arg in ["keep_samples", "num_heads"]:
        cset(vis, arg, args[arg])

    # Alias the two options whose command-line names do not follow the
    # "vis_<type>" pattern used for the lookup below.
    args.vis_imgs = args.vis_images
    args.vis_ltnts = args.vis_latents
    vis_types = [
        "imgs", "ltnts", "maps", "layer_maps", "interpolations", "noise_var",
        "style_mix"
    ]
    # Set of all the set visualization types option
    vis.vis_types = list({arg for arg in vis_types if args[f"vis_{arg}"]})

    # Visualization option name -> command-line option it is read from
    vis_args = {
        "attention": "transformer",
        "grid": "vis_grid",
        "num": "vis_num",
        "rich_num": "vis_rich_num",
        "section_size": "vis_section_size",
        "intrp_density": "interpolation_density",
        # "intrp_per_component": "interpolation_per_component",
        "alpha": "blending_alpha"
    }
    for arg, cmd_arg in vis_args.items():
        cset(vis, arg, args[cmd_arg])

    # Networks setup
    # ----------------------------------------------------------------------------

    # Networks architecture
    cset(cG.synthesis_kwargs, "architecture", args.g_arch)
    cset(cD, "architecture", args.d_arch)

    # Latent sizes
    if args.components_num > 0:
        if not args.transformer:  # or args.kgan):
            misc.error(
                "--components-num > 0 but the model is not using components. "
                +
                "Add --transformer for GANformer (which uses latent components)."
            )
        if args.latent_size % args.components_num != 0:
            misc.error(
                f"--latent-size ({args.latent_size}) should be divisible by --components-num (k={args.components_num})"
            )
        # The total latent size is split evenly between the components
        args.latent_size = int(args.latent_size / args.components_num)

    cG.z_dim = cG.w_dim = args.latent_size
    cset([cG, vis], "k", args.components_num +
         1)  # We add a component to modulate features globally

    # Mapping network
    args.mapping_layer_dim = args.mapping_dim
    for arg in ["num_layers", "layer_dim", "resnet", "shared", "ltnt2ltnt"]:
        field = f"mapping_{arg}"
        cset(cG.mapping_kwargs, arg, args[field])

    # StyleGAN settings
    for arg in ["style", "latent_stem", "local_noise"]:
        cset(cG.synthesis_kwargs, arg, args[arg])

    # GANformer
    cset([cG.synthesis_kwargs, cG.mapping_kwargs], "transformer",
         args.transformer)

    # Attention related settings
    for arg in ["use_pos", "num_heads", "ltnt_gate", "attention_dropout"]:
        cset([cG.mapping_kwargs, cG.synthesis_kwargs], arg, args[arg])

    # Attention types and layers
    for arg in ["start_res", "end_res"
                ]:  # , "local_attention" , "ltnt2ltnt", "img2img", "img2ltnt"
        cset(cG.synthesis_kwargs, arg, args[f"g_{arg}"])

    # Mixing and dropout
    for arg in ["style_mixing", "component_mixing"]:
        cset(loss_args, arg, args[arg])
    cset(cG, "component_dropout", args["component_dropout"])

    # Extra transformer options
    args.norm = args.normalize
    for arg in [
            "norm", "integration", "img_gate", "iterative", "kmeans",
            "kmeans_iters"
    ]:
        cset(cG.synthesis_kwargs, arg, args[arg])

    # Positional encoding
    # args.pos_dim = args.pos_dim or args.latent_size
    for arg in ["dim", "type", "init", "directions_num"]:
        field = f"pos_{arg}"
        cset(cG.synthesis_kwargs, field, args[field])

    # k-GAN
    # for arg in ["layer", "type", "same"]:
    #     field = "merge_{}".format(arg)
    #     cset(cG.args, field, args[field])
    # cset(cG.synthesis_kwargs, "merge", args.kgan)
    # if args.kgan and args.transformer:
    # misc.error("Either have --transformer for GANformer or --kgan for k-GAN, not both")

    config = EasyDict(train)
    config.update(cG=cG,
                  cD=cD,
                  loss_args=loss_args,
                  dataset_args=dataset_args,
                  vis_args=vis)

    # Save config file
    with open(os.path.join(run_dir, "training_options.json"), "wt") as f:
        json.dump(config, f, indent=2)

    return config