Example #1
def setup_training_stages(loss_args, G, cG, D, cD, ddp_nets, device, log):
    misc.log("Setting up training stages...", "white", log)
    loss = dnnlib.util.construct_class_by_name(device = device, **ddp_nets, **loss_args) # subclass of training.loss.Loss
    stages = []

    for name, net, config in [("G", G, cG), ("D", D, cD)]:
        if config.reg_interval is None:
            opt = dnnlib.util.construct_class_by_name(params = net.parameters(), **config.opt_args) # subclass of torch.optim.Optimizer
            stages.append(dnnlib.EasyDict(name = f"{name}_both", net = net, opt = opt, interval = 1))
        # Lazy regularization
        else: 
            mb_ratio = config.reg_interval / (config.reg_interval + 1)
            opt_args = dnnlib.EasyDict(config.opt_args)
            opt_args.lr = opt_args.lr * mb_ratio
            opt_args.betas = [beta ** mb_ratio for beta in opt_args.betas]
            opt = dnnlib.util.construct_class_by_name(net.parameters(), **opt_args) # subclass of torch.optim.Optimizer
            stages.append(dnnlib.EasyDict(name = f"{name}_main", net = net, opt = opt, interval = 1))
            stages.append(dnnlib.EasyDict(name = f"{name}_reg", net = net, opt = opt, interval = config.reg_interval))

    for stage in stages:
        stage.start_event = None
        stage.end_event = None
        if log:
            stage.start_event = torch.cuda.Event(enable_timing = True)
            stage.end_event = torch.cuda.Event(enable_timing = True)

    return loss, stages
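
# A standalone sketch (not from the repository) of the lazy-regularization adjustment above:
# when the regularization term runs only every reg_interval minibatches, the learning rate and
# Adam betas are rescaled by mb_ratio = reg_interval / (reg_interval + 1) so that the combined
# main + reg updates behave like the non-lazy setting. The numbers below are illustrative only.
import torch

net = torch.nn.Linear(8, 8)
reg_interval = 16
mb_ratio = reg_interval / (reg_interval + 1)                           # 16/17 ~= 0.941
lr, betas = 0.002, (0.0, 0.99)
opt = torch.optim.Adam(net.parameters(),
                       lr = lr * mb_ratio,                             # ~0.00188
                       betas = [beta ** mb_ratio for beta in betas])   # [0.0, ~0.9906]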
Example #2
def construct_nets(cG, cD, dataset, device, log):
    misc.log("Constructing networks...", "white", log)
    common_kwargs = dict(c_dim = dataset.label_dim, img_resolution = dataset.resolution, img_channels = dataset.num_channels)
    G = dnnlib.util.construct_class_by_name(**cG, **common_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module
    D = dnnlib.util.construct_class_by_name(**cD, **common_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module
    Gs = copy.deepcopy(G).eval()
    return G, D, Gs
Example #3
def load_nets(resume_pkl, lG, lD, lGs, recompile):
    misc.log("Loading networks from %s..." % resume_pkl, "white")
    rG, rD, rGs = pretrained_networks.load_networks(resume_pkl)
    
    if recompile:
        misc.log("Copying nets...")
        lG.copy_vars_from(rG); lD.copy_vars_from(rD); lGs.copy_vars_from(rGs)
    else:
        lG, lD, lGs = rG, rD, rGs
    return lG, lD, lGs
Example #4
def distribute_nets(G, D, Gs, device, num_gpus, log):
    misc.log(f"Distributing across {num_gpus} GPUs...", "white", log)
    networks = {}
    for name, net in [("G", G), ("D", D), (None, Gs)]: # ("G_mapping", G.mapping), ("G_synthesis", G.synthesis)
        if (num_gpus > 1) and (net is not None) and len(list(net.parameters())) != 0:
            net.requires_grad_(True)
            net = torch.nn.parallel.DistributedDataParallel(net, device_ids = [device], broadcast_buffers = False, 
                find_unused_parameters = True)
            net.requires_grad_(False)
        if name is not None:
            networks[name] = net
    return networks
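
# A minimal single-process sketch (gloo backend on CPU, file:// rendezvous; illustrative only)
# of why the nets above are briefly switched to requires_grad_(True): DistributedDataParallel
# raises an error when wrapping a module with no trainable parameters, so gradients are enabled
# just for construction and disabled again right after.
import tempfile
import torch
import torch.distributed as dist

init_file = tempfile.NamedTemporaryFile(suffix = ".ddp", delete = False).name
dist.init_process_group("gloo", init_method = f"file://{init_file}", rank = 0, world_size = 1)

net = torch.nn.Linear(4, 4)
net.requires_grad_(True)    # DDP requires at least one parameter with requires_grad = True
net = torch.nn.parallel.DistributedDataParallel(net, broadcast_buffers = False,
    find_unused_parameters = True)
net.requires_grad_(False)   # gradients get re-enabled per stage during training
dist.destroy_process_group()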
Example #5
def load_nets(load_pkl, nets, device, log):
    if (load_pkl is not None) and log:
        misc.log(f"Resuming from {load_pkl}", "white", log)
        resume_data = loader.load_network(load_pkl)

        if nets is not None:
            G, D, Gs = nets
            for name, net in [("G", G), ("D", D), ("Gs", Gs)]:
                torch_misc.copy_params_and_buffers(resume_data[name], net, require_all = False)
        else:
            for net in ["G", "D", "Gs"]:
                resume_data[net] = copy.deepcopy(resume_data[net]).eval().requires_grad_(False).to(device)
            nets = (resume_data["G"], resume_data["D"], resume_data["Gs"])

    return nets
Example #6
def load_dataset(dataset_args, batch_size, rank, num_gpus, log):
    misc.log("Loading training set...", "white", log)
    dataset = dnnlib.util.construct_class_by_name(**dataset_args) # subclass of training.dataset.Dataset
    dataset_sampler = torch_misc.InfiniteSampler(dataset = dataset, rank = rank, num_replicas = num_gpus)
    dataset_iter = iter(torch.utils.data.DataLoader(dataset = dataset, sampler = dataset_sampler, 
        batch_size = batch_size//num_gpus, **dataset_args.loader_args))
    misc.log(f"Num images: {misc.bcolored(len(dataset), 'blue')}", log = log)
    misc.log(f"Image shape: {misc.bcolored(dataset.image_shape, 'blue')}", log = log)
    misc.log(f"Label shape: {misc.bcolored(dataset.label_shape, 'blue')}", log = log)
    return dataset, dataset_iter
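
# A simplified sketch of what a rank-aware infinite sampler such as torch_misc.InfiniteSampler
# presumably does (the names and details below are illustrative, not the repository's code):
# reshuffle once per pass and yield indices forever, striding by num_replicas so that each
# rank/GPU sees a disjoint slice of every epoch.
import numpy as np
import torch

class SimpleInfiniteSampler(torch.utils.data.Sampler):
    def __init__(self, dataset, rank = 0, num_replicas = 1, seed = 0):
        self.size = len(dataset)
        self.rank, self.num_replicas = rank, num_replicas
        self.rng = np.random.RandomState(seed)

    def __iter__(self):
        while True:
            order = self.rng.permutation(self.size)
            yield from order[self.rank::self.num_replicas]

# Usage (illustrative): iter(torch.utils.data.DataLoader(dataset, batch_size = 4,
#                            sampler = SimpleInfiniteSampler(dataset, rank, num_gpus)))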
Example #7
def run_cmdline(argv):
    parser = argparse.ArgumentParser(prog = argv[0], description = "Download and prepare data for the GANformer.")
    parser.add_argument("--data-dir",       help = "Directory of created dataset", default = "datasets", type = str)
    parser.add_argument("--max-images",     help = "Maximum number of images to have in the dataset (optional).", default = None, type = int)
    # Default tasks
    parser.add_argument("--clevr",          help = "Prepare the CLEVR dataset (6.41GB download, 100k images)", dest = "tasks", action = "append_const", const = "clevr")
    parser.add_argument("--bedrooms",       help = "Prepare the LSUN-bedrooms dataset (42.8GB download, 3M images)", dest = "tasks", action = "append_const", const = "bedrooms")
    parser.add_argument("--ffhq",           help = "Prepare the FFHQ dataset (13GB download, 70k images)", dest = "tasks", action = "append_const", const = "ffhq")
    parser.add_argument("--cityscapes",     help = "Prepare the cityscapes dataset (1.8GB download, 25k images)", dest = "tasks", action = "append_const", const = "cityscapes")
    # Create a new task with custom images
    parser.add_argument("--task",           help = "New dataset name", type = str, dest = "tasks", action = "append")
    parser.add_argument("--images-dir",     help = "Provide source image directory/file to convert into png-directory dataset (saves varied image resolutions)", default = None, type = str)
    parser.add_argument("--format",         help = "Images format", default = None, choices = ["png", "jpg", "npy", "hdf5", "tfds", "lmdb", "tfrecords"], type = str)
    parser.add_argument("--ratio",          help = "Images height/width", default = 1.0, type = float)

    args = parser.parse_args()
    if not args.tasks:
        misc.error("No tasks specified. Please see '-h' for help.")
    if args.max_images is not None and args.max_images < 50000:
        misc.log(f"Warning: max-images is set to {args.max_images}. We recommend setting it at least to 50,000 to allow statistically correct computation of the FID-50k metric.", "red")

    prepare(**vars(args))
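
# A small self-contained illustration (hypothetical parser mirroring the one above) of how the
# append_const / append actions accumulate the chosen tasks into a single args.tasks list.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--clevr", dest = "tasks", action = "append_const", const = "clevr")
parser.add_argument("--ffhq",  dest = "tasks", action = "append_const", const = "ffhq")
parser.add_argument("--task",  dest = "tasks", action = "append", type = str)

args = parser.parse_args(["--clevr", "--ffhq", "--task", "my_dataset"])
assert args.tasks == ["clevr", "ffhq", "my_dataset"]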
Example #8
def setup_savefile(args, run_name, run_dir, config):
    snapshot, kimg, resume = None, 0, False
    pkls = sorted(glob.glob(f"{run_dir}/network*.pkl"))
    # Load a particular snapshot if specified
    if args.pretrained_pkl is not None and args.pretrained_pkl != "None":
        # Soft links support
        if args.pretrained_pkl.startswith("gdrive"):
            if args.pretrained_pkl not in loader.pretrained_networks:
                misc.error(
                    "--pretrained_pkl {} not available in the catalog (see loader.pretrained_networks dict)".format(args.pretrained_pkl)
                )

            snapshot = args.pretrained_pkl
        else:
            snapshot = glob.glob(args.pretrained_pkl)[0]
            if os.path.islink(snapshot):
                snapshot = os.readlink(snapshot)

        # Extract training step from the snapshot if specified
        try:
            kimg = int(snapshot.split("-")[-1].split(".")[0])
        except:
            pass

    # Find latest snapshot in the directory
    elif len(pkls) > 0:
        snapshot = pkls[-1]
        kimg = int(snapshot.split("-")[-1].split(".")[0])
        resume = True

    if snapshot:
        misc.log(f"Resuming {run_name}, from {snapshot}, kimg {kimg}", "white")
        config.resume_pkl = snapshot
        config.resume_kimg = kimg
    else:
        misc.log("Start model training from scratch", "white")
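
# A worked example (illustrative filename) of the kimg extraction used above: the training step
# is the trailing number of a snapshot named like "network-snapshot-<kimg>.pkl".
snapshot = "network-snapshot-001234.pkl"
kimg = int(snapshot.split("-")[-1].split(".")[0])   # "001234" -> 1234
assert kimg == 1234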
Example #9
def run(**args):
    args = EasyDict(args)
    train = EasyDict(run_func_name="training.training_loop.training_loop"
                     )  # training loop options
    sched = EasyDict()  # TrainingSchedule options
    vis = EasyDict()  # visualize.eval() options
    grid = EasyDict(size="1080p",
                    layout="random")  # setup_snapshot_img_grid() options
    sc = dnnlib.SubmitConfig()  # dnnlib.submit_run() options

    # If the flag is specified without arguments (--arg), set to True
    for arg in [
            "summarize", "keep_samples", "style", "fused_modconv",
            "local_noise"
    ]:
        if args[arg] is None:
            args[arg] = True

    if not args.train and not args.eval:
        misc.log(
            "Warning: Neither --train nor --eval are provided. Therefore, we only print network shapes",
            "red")

    if args.gansformer_default:
        task = args.dataset
        pretrained = "gdrive:{}-snapshot.pkl".format(task)
        if pretrained not in pretrained_networks.gdrive_urls:
            pretrained = None

        nset(args, "recompile", pretrained is not None)
        nset(args, "pretrained_pkl", pretrained)
        nset(args, "mirror_augment", task in ["cityscapes", "ffhq"])

        nset(args, "transformer", True)
        nset(args, "components_num", {"clevr": 8}.get(task, 16))
        nset(args, "latent_size", {"clevr": 128}.get(task, 512))

        nset(args, "normalize", "layer")
        nset(args, "integration", "mul")
        nset(args, "kmeans", True)
        nset(args, "use_pos", True)
        nset(args, "mapping_ltnt2ltnt", task != "clevr")
        nset(args, "style", task != "clevr")

        nset(args, "g_arch", "resnet")
        nset(args, "mapping_resnet", True)

        gammas = {"ffhq": 10, "cities": 20, "clevr": 40, "bedrooms": 100}
        nset(args, "gamma", gammas.get(task, 10))

    if args.baseline == "GAN":
        nset(args, "style", False)
        nset(args, "latent_stem", True)

    if args.baseline == "SAGAN":
        nset(args, "style", False)
        nset(args, "latent_stem", True)
        nset(args, "g_img2img", 5)

    if args.baseline == "kGAN":
        nset(args, "kgan", True)
        nset(args, "merge_layer", 5)
        nset(args, "merge_type", "softmax")
        nset(args, "components_num", 8)

    # Environment configuration
    tf_config = {
        "rnd.np_random_seed": 1000,
        "allow_soft_placement": True,
        "gpu_options.per_process_gpu_memory_fraction": 1.0
    }
    if args.gpus != "":
        num_gpus = len(args.gpus.split(","))
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
    assert num_gpus in [1, 2, 4, 8]
    sc.num_gpus = num_gpus

    # Networks configuration
    cG = set_net("G", reg_interval=4)
    cD = set_net("D", reg_interval=16)

    # Dataset configuration
    # For bedrooms, we choose the most common ratio in the
    # dataset and crop the other images into that ratio.
    ratios = {
        "clevr": 0.75,
        "bedrooms": 188 / 256,
        "cityscapes": 0.5,
        "ffhq": 1.0
    }
    args.ratio = ratios.get(args.dataset, args.ratio)
    dataset_args = EasyDict(tfrecord_dir=args.dataset,
                            max_imgs=args.train_images_num,
                            num_threads=args.num_threads)
    for arg in ["data_dir", "mirror_augment", "total_kimg", "ratio"]:
        cset(train, arg, args[arg])

    # Training and Optimizations configuration
    for arg in ["eval", "train", "recompile", "last_snapshots"]:
        cset(train, arg, args[arg])

    # Round down to the closest multiple of the minibatch size for validity
    args.batch_size -= args.batch_size % args.minibatch_size
    args.minibatch_std_size -= args.minibatch_std_size % args.minibatch_size
    args.latent_size -= args.latent_size % args.components_num
    if args.latent_size == 0:
        misc.error(
            "--latent-size is too small. Must best a multiply of components-num"
        )

    sched_args = {
        "G_lrate": "g_lr",
        "D_lrate": "d_lr",
        "minibatch_size": "batch_size",
        "minibatch_gpu": "minibatch_size"
    }
    for arg, cmd_arg in sched_args.items():
        cset(sched, arg, args[cmd_arg])
    cset(train, "clip", args.clip)

    # Logging and metrics configuration
    metrics = [metric_defaults[x] for x in args.metrics]

    cset(cG.args, "truncation_psi", args.truncation_psi)
    for arg in ["keep_samples", "num_heads"]:
        cset(vis, arg, args[arg])
    for arg in ["summarize", "eval_images_num"]:
        cset(train, arg, args[arg])

    # Visualization
    args.vis_imgs = args.vis_images
    args.vis_ltnts = args.vis_latents
    vis_types = [
        "imgs", "ltnts", "maps", "layer_maps", "interpolations", "noise_var",
        "style_mix"
    ]
    # Collect the set of enabled visualization types
    vis.vis_types = {arg for arg in vis_types if args["vis_{}".format(arg)]}

    vis_args = {
        "attention": "transformer",
        "grid": "vis_grid",
        "num": "vis_num",
        "rich_num": "vis_rich_num",
        "section_size": "vis_section_size",
        "intrp_density": "interpolation_density",
        # "intrp_per_component": "interpolation_per_component",
        "alpha": "blending_alpha"
    }
    for arg, cmd_arg in vis_args.items():
        cset(vis, arg, args[cmd_arg])

    # Networks architecture
    cset(cG.args, "architecture", args.g_arch)
    cset(cD.args, "architecture", args.d_arch)
    cset(cG.args, "tanh", args.tanh)

    # Latent sizes
    if args.components_num > 1:
        if not (args.transformer or args.kgan):
            misc.error(
                "--components-num > 1 but the model is not using components. "
                +
                "Either add --transformer for GANsformer or --kgan for k-GAN.")

        args.latent_size = int(args.latent_size / args.components_num)
    cD.args.latent_size = cG.args.latent_size = cG.args.dlatent_size = args.latent_size
    cset([cG.args, cD.args, vis], "components_num", args.components_num)

    # Mapping network
    for arg in ["layersnum", "lrmul", "dim", "resnet", "shared_dim"]:
        field = "mapping_{}".format(arg)
        cset(cG.args, field, args[field])

    # StyleGAN settings
    for arg in ["style", "latent_stem", "fused_modconv", "local_noise"]:
        cset(cG.args, arg, args[arg])
    cD.args.mbstd_group_size = args.minibatch_std_size

    # GANsformer
    cset(cG.args, "transformer", args.transformer)
    cset(cD.args, "transformer", args.d_transformer)

    args.norm = args.normalize
    for arg in [
            "norm", "integration", "ltnt_gate", "img_gate", "iterative",
            "kmeans", "kmeans_iters", "mapping_ltnt2ltnt"
    ]:
        cset(cG.args, arg, args[arg])

    for arg in ["use_pos", "num_heads"]:
        cset([cG.args, cD.args], arg, args[arg])

    # Positional encoding
    for arg in ["dim", "init", "directions_num"]:
        field = "pos_{}".format(arg)
        cset([cG.args, cD.args], field, args[field])

    # k-GAN
    for arg in ["layer", "type", "same"]:
        field = "merge_{}".format(arg)
        cset(cG.args, field, args[field])
    cset([cG.args, train], "merge", args.kgan)

    if args.kgan and args.transformer:
        misc.error(
            "Either have --transformer for GANsformer or --kgan for k-GAN, not both"
        )

    # Attention
    for arg in ["start_res", "end_res", "ltnt2ltnt",
                "img2img"]:  # , "local_attention"
        cset(cG.args, arg, args["g_{}".format(arg)])
        cset(cD.args, arg, args["d_{}".format(arg)])
    cset(cG.args, "img2ltnt", args.g_img2ltnt)
    # cset(cD.args, "ltnt2img", args.d_ltnt2img)

    # Mixing and dropout
    for arg in [
            "style_mixing", "component_mixing", "component_dropout",
            "attention_dropout"
    ]:
        cset(cG.args, arg, args[arg])

    # Loss and regularization
    gloss_args = {
        "loss_type": "g_loss",
        "reg_weight": "g_reg_weight",
        # "pathreg": "pathreg",
    }
    dloss_args = {"loss_type": "d_loss", "reg_type": "d_reg", "gamma": "gamma"}
    for arg, cmd_arg in gloss_args.items():
        cset(cG.loss_args, arg, args[cmd_arg])
    for arg, cmd_arg in dloss_args.items():
        cset(cD.loss_args, arg, args[cmd_arg])

    ##### Experiments management:
    # Whenever we start a new experiment, we store its results in a directory named 'args.expname-000'.
    # When we rerun a training or evaluation command, it restores the model from that directory by default.
    # If we wish to restart the model training, we can set --restart, and data will then be stored in a new
    # directory: 'args.expname-001' after the first restart, 'args.expname-002' after the second, etc.
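    # For example (hypothetical values): with --expname exp and an existing results/exp-000,
    # run_id resolves to 0 and that run is resumed by default; adding --restart bumps run_id
    # to 1, so training starts fresh under results/exp-001.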

    # Find the latest directory that matches the experiment
    exp_dir = sorted(glob.glob("{}/{}-*".format(args.result_dir,
                                                args.expname)))
    run_id = 0
    if len(exp_dir) > 0:
        run_id = int(exp_dir[-1].split("-")[-1])
    # If restart, then work over a new directory
    if args.restart:
        run_id += 1

    run_name = "{}-{:03d}".format(args.expname, run_id)
    train.printname = "{} ".format(misc.bold(args.expname))

    snapshot, kimg, resume = None, 0, False
    pkls = sorted(
        glob.glob("{}/{}/network*.pkl".format(args.result_dir, run_name)))
    # Load a particular snapshot if specified
    if args.pretrained_pkl is not None and args.pretrained_pkl != "None":
        # Soft links support
        if args.pretrained_pkl.startswith("gdrive"):
            if args.pretrained_pkl not in pretrained_networks.gdrive_urls:
                misc.error(
                    "--pretrained_pkl {} not available in the catalog (see pretrained_networks.py)".format(args.pretrained_pkl)
                )

            snapshot = args.pretrained_pkl
        else:
            snapshot = glob.glob(args.pretrained_pkl)[0]
            if os.path.islink(snapshot):
                snapshot = os.readlink(snapshot)

        # Extract training step from the snapshot if specified
        try:
            kimg = int(snapshot.split("-")[-1].split(".")[0])
        except:
            pass

    # Find latest snapshot in the directory
    elif len(pkls) > 0:
        snapshot = pkls[-1]
        kimg = int(snapshot.split("-")[-1].split(".")[0])
        resume = True

    if snapshot:
        misc.log(
            "Resuming {}, from {}, kimg {}".format(run_name, snapshot, kimg),
            "white")
        train.resume_pkl = snapshot
        train.resume_kimg = kimg
    else:
        misc.log("Start model training from scratch", "white")

    # Run environment configuration
    sc.run_dir_root = args.result_dir
    sc.run_desc = args.expname
    sc.run_id = run_id
    sc.run_name = run_name
    sc.submit_target = dnnlib.SubmitTarget.LOCAL
    sc.local.do_not_copy_source_files = True

    kwargs = EasyDict(train)
    kwargs.update(cG=cG, cD=cD)
    kwargs.update(dataset_args=dataset_args,
                  vis_args=vis,
                  sched_args=sched,
                  grid_args=grid,
                  metric_arg_list=metrics,
                  tf_config=tf_config)
    kwargs.submit_config = copy.deepcopy(sc)
    kwargs.resume = resume
    kwargs.load_config = args.reload

    dnnlib.submit_run(**kwargs)
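
# A standalone worked example (illustrative numbers) of the rounding performed above: sizes are
# rounded down to the nearest multiple of the minibatch / components size so later splits divide
# evenly.
batch_size, minibatch_size = 33, 8
batch_size -= batch_size % minibatch_size        # 33 -> 32
latent_size, components_num = 512, 16
latent_size -= latent_size % components_num      # 512 is already a multiple of 16, stays 512
assert (batch_size, latent_size) == (32, 512)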
Example #10
def training_loop(
    # General configuration
    train                   = False,    # Training mode
    eval                    = False,    # Evaluation mode
    vis                     = False,    # Visualization mode        
    run_dir                 = ".",      # Output directory
    num_gpus                = 1,        # Number of GPUs participating in the training
    rank                    = 0,        # Rank of the current process in [0, num_gpus)
    cG                      = {},       # Options for generator network
    cD                      = {},       # Options for discriminator network
    # Data
    dataset_args            = {},       # Options for training set
    drange_net              = [-1,1],   # Dynamic range used when feeding image data to the networks
    # Optimization
    loss_args               = {},       # Options for loss function
    total_kimg              = 25000,    # Total length of the training, measured in thousands of real images
    batch_size              = 4,        # Total batch size for one training iteration. Can be larger than batch_gpu * num_gpus
    batch_gpu               = 4,        # Number of samples processed at a time by one GPU
    ema_kimg                = 10.0,     # Half-life of the exponential moving average (EMA) of generator weights
    ema_rampup              = None,     # EMA ramp-up coefficient
    cudnn_benchmark         = True,     # Enable torch.backends.cudnn.benchmark?
    allow_tf32              = False,    # Enable torch.backends.cuda.matmul.allow_tf32 and torch.backends.cudnn.allow_tf32?
    # Logging
    resume_pkl              = None,     # Network pickle to resume training from
    resume_kimg             = 0.0,      # Assumed training progress at the beginning
                                        # Affects reporting and training schedule    
    kimg_per_tick           = 8,        # Progress snapshot interval
    img_snapshot_ticks      = 3,        # How often to save image snapshots? None = disable
    network_snapshot_ticks  = 3,        # How often to save network snapshots? None = disable    
    last_snapshots          = 10,       # Maximal number of prior snapshots to save
    printname               = "",       # Experiment name for logging
    # Evaluation
    vis_args                = {},       # Options for vis.vis
    metrics                 = [],       # Metrics to evaluate during training
    eval_images_num         = 50000,    # Sample size for the metrics
    truncation_psi          = 0.7       # Style strength multiplier for the truncation trick (used for visualizations only)
):  
    # Initialize
    start_time = time.time()
    device = init_cuda(rank, cudnn_benchmark, allow_tf32)
    log = (rank == 0)

    dataset, dataset_iter = load_dataset(dataset_args, batch_size, rank, num_gpus, log) # Load training set

    nets = construct_nets(cG, cD, dataset, device, log) if train else None        # Construct networks
    G, D, Gs = load_nets(resume_pkl, nets, device, log)                           # Resume from existing pickle
    print_nets(G, D, batch_gpu, device, log)                                      # Print network summary tables

    if eval:
        misc.log("Run evaluation...", log = log)
        evaluate(Gs, resume_pkl, metrics, eval_images_num, dataset_args, num_gpus, rank, device, log)

    if vis and log:
        misc.log("Produce visualizations...")
        visualize.vis(Gs, dataset, device, batch_gpu, drange_net = drange_net, ratio = dataset.ratio, 
            truncation_psi = truncation_psi, **vis_args)

    if not train:
        exit()

    nets = distribute_nets(G, D, Gs, device, num_gpus, log)                                   # Distribute networks across GPUs
    loss, stages = setup_training_stages(loss_args, G, cG, D, cD, nets, device, log)          # Setup training stages (losses and optimizers)
    grid_size, grid_z, grid_c = init_img_grid(dataset, G.input_shape, device, run_dir, log)   # Initialize an image grid    
    logger = init_logger(run_dir, log)                                                        # Initialize logs

    # Train
    misc.log(f"Training for {total_kimg} kimg...", "white", log)    
    cur_nimg, cur_tick, batch_idx = int(resume_kimg * 1000), 0, 0
    tick_start_nimg, tick_start_time = cur_nimg, time.time()
    stats = None

    while True:
        # Fetch training data
        real_img, real_c, gen_zs, gen_cs = fetch_data(dataset, dataset_iter, G.input_shape, drange_net, 
            device, len(stages), batch_size, batch_gpu)

        # Execute training stages
        for stage, gen_z, gen_c in zip(stages, gen_zs, gen_cs):
            if batch_idx % stage.interval != 0:
                continue
            run_training_stage(loss, stage, device, real_img, real_c, gen_z, gen_c, batch_size, batch_gpu, num_gpus)

        # Update Gs
        update_ema_network(G, Gs, batch_size, cur_nimg, ema_kimg, ema_rampup)

        # Update state
        cur_nimg += batch_size
        batch_idx += 1

        # Perform maintenance tasks once per tick
        done = (cur_nimg >= total_kimg * 1000)
        if (not done) and (cur_tick != 0) and (cur_nimg < tick_start_nimg + kimg_per_tick * 1000):
            continue

        # Print status line and accumulate the info in logger.collector
        tick_end_time = time.time()
        if stats is not None:
            default = dnnlib.EasyDict({'mean': -1})
            fields = []
            fields.append("tick " + misc.bold(f"{training_stats.report0('Progress/tick', cur_tick):<5d}"))
            fields.append("kimg " + misc.bcolored(f"{training_stats.report0('Progress/kimg', cur_nimg / 1e3):<8.1f}", "red"))
            fields.append("")
            fields.append("loss/reg: G (" + misc.bcolored(f"{stats.get('Loss/G/loss', default).mean:>6.3f}", "blue"))
            fields.append(misc.bold(f"{stats.get('Loss/G/reg', default).mean:>6.3f}") + ")")
            fields.append("D "+ misc.bcolored(f"({stats.get('Loss/D/loss', default).mean:>6.3f}", "blue"))
            fields.append(misc.bold(f"{stats.get('Loss/D/reg', default).mean:>6.3f}") + ")")
            fields.append("")
            fields.append("time " + misc.bold(f"{dnnlib.util.format_time(training_stats.report0('Timing/total_sec', tick_end_time - start_time)):<12s}"))
            fields.append(f"sec/kimg {training_stats.report0('Timing/sec_per_kimg', (tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg) * 1e3):<7.2f}")
            fields.append(f"mem: GPU {training_stats.report0('Resources/cpu_mem_gb', psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}")
            fields.append(f"CPU {training_stats.report0('Resources/peak_gpu_mem_gb', torch.cuda.max_memory_allocated(device) / 2**30):<6.2f}")
            fields.append(misc.bold(printname))
            torch.cuda.reset_peak_memory_stats()
            training_stats.report0('Timing/total_hours', (tick_end_time - start_time) / (60 * 60))
            training_stats.report0('Timing/total_days', (tick_end_time - start_time) / (24 * 60 * 60))
            misc.log(" ".join(fields), log = log)

        # Save image snapshot
        if log and (img_snapshot_ticks is not None) and (done or cur_tick % img_snapshot_ticks == 0):
            visualize.vis(Gs, dataset, device, batch_gpu, training = True,
                step = cur_nimg // 1000, grid_size = grid_size, latents = grid_z, 
                labels = grid_c, drange_net = drange_net, ratio = dataset.ratio, **vis_args)

        # Save network snapshot
        if (network_snapshot_ticks is not None) and (done or cur_tick % network_snapshot_ticks == 0):
            snapshot_data, snapshot_pkl = save_nets(G, D, Gs, cur_nimg, dataset_args, run_dir, num_gpus > 1, last_snapshots, log)

            # Evaluate metrics
            evaluate(snapshot_data["Gs"], snapshot_pkl, metrics, eval_images_num,
                dataset_args, num_gpus, rank, device, log, logger, run_dir)
            del snapshot_data

        # Collect stats and update logs
        stats = collect_stats(logger, stages)
        update_logger(logger, stats, cur_nimg, start_time)

        cur_tick += 1
        tick_start_nimg, tick_start_time = cur_nimg, time.time()
        maintenance_time = tick_start_time - tick_end_time
        if done:
            break

    # Done
    misc.log("Done!", "blue")
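
# A minimal sketch of the exponential-moving-average update that update_ema_network above
# presumably performs (following the common StyleGAN2-ADA scheme; the repository's exact
# implementation is not shown here): the decay beta is chosen so the average has a half-life
# of ema_kimg thousand images, optionally ramped up early in training via ema_rampup.
import copy
import torch

def ema_update(G, Gs, batch_size, cur_nimg, ema_kimg = 10.0, ema_rampup = None):
    ema_nimg = ema_kimg * 1000
    if ema_rampup is not None:
        ema_nimg = min(ema_nimg, cur_nimg * ema_rampup)
    beta = 0.5 ** (batch_size / max(ema_nimg, 1e-8))
    with torch.no_grad():
        for p_ema, p in zip(Gs.parameters(), G.parameters()):
            p_ema.copy_(p.lerp(p_ema, beta))    # Gs <- beta * Gs + (1 - beta) * G
        for b_ema, b in zip(Gs.buffers(), G.buffers()):
            b_ema.copy_(b)

G = torch.nn.Linear(4, 4)
Gs = copy.deepcopy(G).eval().requires_grad_(False)
ema_update(G, Gs, batch_size = 32, cur_nimg = 64000)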
Example #11
def run_cmdline(argv):
    parser = argparse.ArgumentParser(
        prog=argv[0],
        description="Download and prepare data for the GANsformer.")
    parser.add_argument("--data-dir",
                        help="Directory of created dataset",
                        default="datasets",
                        type=str)
    parser.add_argument(
        "--shards-num",
        help="Number of shards to split each dataset to (optional)",
        default=1,
        type=int)
    parser.add_argument(
        "--max-images",
        help=
        "Maximum number of images to have in the dataset (optional). Use to reduce the produced tfrecords file size",
        default=None,
        type=int)
    # Default tasks
    parser.add_argument(
        "--clevr",
        help=
        "Prepare the CLEVR dataset (18GB download, up to 15.5GB tfrecords, 100k images)",
        dest="tasks",
        action="append_const",
        const="clevr")
    parser.add_argument(
        "--bedrooms",
        help=
        "Prepare the LSUN-bedrooms dataset (42.8GB, up to 480GB tfrecords, 3M images)",
        dest="tasks",
        action="append_const",
        const="bedrooms")
    parser.add_argument(
        "--ffhq",
        help=
        "Prepare the FFHQ dataset (13GB download, 13GB tfrecords, 70k images)",
        dest="tasks",
        action="append_const",
        const="ffhq")
    parser.add_argument(
        "--cityscapes",
        help=
        "Prepare the cityscapes dataset (1.8GB, 8GB tfrecords, 25k images)",
        dest="tasks",
        action="append_const",
        const="cityscapes")
    # Create a new task with custom images
    parser.add_argument("--task",
                        help="New dataset name",
                        type=str,
                        dest="tasks",
                        action="append")
    parser.add_argument(
        "--images-dir",
        help=
        "Provide source image directory to convert into tfrecords (will be searched recursively)",
        default=None,
        type=str)
    parser.add_argument("--format",
                        help="Images format",
                        default=None,
                        choices=["png", "jpg", "npy", "hdf5", "tfds", "lmdb"],
                        type=str)
    parser.add_argument("--ratio",
                        help="Images height/width",
                        default=1.0,
                        type=float)

    args = parser.parse_args()
    if not args.tasks:
        misc.error("No tasks specified. Please see '-h' for help.")

    if args.max_images is not None and args.max_images < 50000:
        misc.log(
            "Warning: max-images is set to {}. We recommend setting it at least to 50,000 to allow statistically correct computation of the FID-50k metric."
            .format(args.max_images), "red")

    prepare(**vars(args))
Example #12
def training_loop(
    # Configurations
    cG={},
    cD={},  # Generator and Discriminator command-line arguments
    dataset_args={},  # dataset.load_dataset() options
    sched_args={},  # train.TrainingSchedule options
    vis_args={},  # vis.eval options
    grid_args={},  # train.setup_snapshot_img_grid() options
    metric_arg_list=[],  # MetricGroup Options
    tf_config={},  # tflib.init_tf() options
    eval=False,  # Evaluation mode
    train=False,  # Training mode
    # Data
    data_dir=None,  # Directory to load datasets from
    total_kimg=25000,  # Total length of the training, measured in thousands of real images
    mirror_augment=False,  # Enable mirror augmentation?
    drange_net=[
        -1, 1
    ],  # Dynamic range used when feeding image data to the networks
    ratio=1.0,  # Image height/width ratio in the dataset
    # Optimization
    minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters
    lazy_regularization=True,  # Perform regularization as a separate training step?
    smoothing_kimg=10.0,  # Half-life of the running average of generator weights
    clip=None,  # Clip gradients threshold
    # Resumption
    resume_pkl=None,  # Network pickle to resume training from, None = train from scratch.
    resume_kimg=0.0,  # Assumed training progress at the beginning
    # Affects reporting and training schedule
    resume_time=0.0,  # Assumed wallclock time at the beginning, affects reporting
    recompile=False,  # Recompile network from source code (otherwise loads from snapshot)
    # Logging
    summarize=True,  # Create TensorBoard summaries
    save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms=False,  # Include weight histograms in the tfevents file?
    img_snapshot_ticks=3,  # How often to save image snapshots? None = disable
    network_snapshot_ticks=3,  # How often to save network snapshots? None = only save networks-final.pkl
    last_snapshots=10,  # Maximal number of prior snapshots to save
    eval_images_num=50000,  # Sample size for the metrics
    printname="",  # Experiment name for logging
    # Architecture
    merge=False):  # Generate several images and then merge them

    # Initialize dnnlib and TensorFlow
    tflib.init_tf(tf_config)
    num_gpus = dnnlib.submit_config.num_gpus
    cG.name, cD.name = "g", "d"

    # Load dataset, configure training scheduler and metrics object
    dataset = data.load_dataset(data_dir=dnnlib.convert_path(data_dir),
                                verbose=True,
                                **dataset_args)
    sched = training_schedule(sched_args,
                              cur_nimg=total_kimg * 1000,
                              dataset=dataset)
    metrics = metric_base.MetricGroup(metric_arg_list)

    # Construct or load networks
    with tf.device("/gpu:0"):
        no_op = tf.no_op()
        G, D, Gs = None, None, None
        if resume_pkl is None or recompile:
            misc.log("Constructing networks...", "white")
            G = tflib.Network("G",
                              num_channels=dataset.shape[0],
                              resolution=dataset.shape[1],
                              label_size=dataset.label_size,
                              **cG.args)
            D = tflib.Network("D",
                              num_channels=dataset.shape[0],
                              resolution=dataset.shape[1],
                              label_size=dataset.label_size,
                              **cD.args)
            Gs = G.clone("Gs")
        if resume_pkl is not None:
            G, D, Gs = load_nets(resume_pkl, G, D, Gs, recompile)

    G.print_layers()
    D.print_layers()

    # Train/Evaluate/Visualize
    # Labels are optional but not essential
    grid_size, grid_reals, grid_labels = misc.setup_snapshot_img_grid(
        dataset, **grid_args)
    misc.save_img_grid(grid_reals,
                       dnnlib.make_run_dir_path("reals.png"),
                       drange=dataset.dynamic_range,
                       grid_size=grid_size)
    grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:])

    if eval:
        # Save a snapshot of the current network to evaluate
        pkl = dnnlib.make_run_dir_path("network-eval-snapshot-%06d.pkl" %
                                       resume_kimg)
        misc.save_pkl((G, D, Gs), pkl, remove=False)

        # Quantitative evaluation
        metric = metrics.run(pkl,
                             num_imgs=eval_images_num,
                             run_dir=dnnlib.make_run_dir_path(),
                             data_dir=dnnlib.convert_path(data_dir),
                             num_gpus=num_gpus,
                             ratio=ratio,
                             tf_config=tf_config,
                             mirror_augment=mirror_augment)

        # Qualitative evaluation
        visualize.eval(G,
                       dataset,
                       batch_size=sched.minibatch_gpu,
                       drange_net=drange_net,
                       ratio=ratio,
                       **vis_args)

    if not train:
        dataset.close()
        exit()

    # Setup training inputs
    misc.log("Building TensorFlow graph...", "white")
    with tf.name_scope("Inputs"), tf.device("/cpu:0"):
        lrate_in_g = tf.placeholder(tf.float32, name="lrate_in_g", shape=[])
        lrate_in_d = tf.placeholder(tf.float32, name="lrate_in_d", shape=[])
        step = tf.placeholder(tf.int32, name="step", shape=[])
        minibatch_size_in = tf.placeholder(tf.int32,
                                           name="minibatch_size_in",
                                           shape=[])
        minibatch_gpu_in = tf.placeholder(tf.int32,
                                          name="minibatch_gpu_in",
                                          shape=[])
        minibatch_multiplier = minibatch_size_in // (minibatch_gpu_in *
                                                     num_gpus)
        beta = 0.5**tf.div(tf.cast(minibatch_size_in,
                                   tf.float32), smoothing_kimg *
                           1000.0) if smoothing_kimg > 0.0 else 0.0

    # Set optimizers
    for cN, lr in [(cG, lrate_in_g), (cD, lrate_in_d)]:
        set_optimizer(cN, lr, minibatch_multiplier, lazy_regularization, clip)

    # Build training graph for each GPU
    data_fetch_ops = []
    for gpu in range(num_gpus):
        with tf.name_scope("GPU%d" % gpu), tf.device("/gpu:%d" % gpu):

            # Create GPU-specific shadow copies of G and D
            for cN, N in [(cG, G), (cD, D)]:
                cN.gpu = N if gpu == 0 else N.clone(N.name + "_shadow")
            Gs_gpu = Gs if gpu == 0 else Gs.clone(Gs.name + "_shadow")

            # Fetch training data via temporary variables
            with tf.name_scope("DataFetch"):
                reals, labels = dataset.get_minibatch_tf()
                reals = process_reals(reals, dataset.dynamic_range, drange_net,
                                      mirror_augment)
                reals, reals_fetch = read_data(
                    reals, "reals", [sched.minibatch_gpu] + dataset.shape,
                    minibatch_gpu_in)
                labels, labels_fetch = read_data(
                    labels, "labels",
                    [sched.minibatch_gpu, dataset.label_size],
                    minibatch_gpu_in)
                data_fetch_ops += [reals_fetch, labels_fetch]

            # Evaluate loss functions
            with tf.name_scope("G_loss"):
                cG.loss, cG.reg = dnnlib.util.call_func_by_name(
                    G=cG.gpu,
                    D=cD.gpu,
                    dataset=dataset,
                    reals=reals,
                    minibatch_size=minibatch_gpu_in,
                    **cG.loss_args)

            with tf.name_scope("D_loss"):
                cD.loss, cD.reg = dnnlib.util.call_func_by_name(
                    G=cG.gpu,
                    D=cD.gpu,
                    dataset=dataset,
                    reals=reals,
                    labels=labels,
                    minibatch_size=minibatch_gpu_in,
                    **cD.loss_args)

            for cN in [cG, cD]:
                set_optimizer_ops(cN, lazy_regularization, no_op)

    # Setup training ops
    data_fetch_op = tf.group(*data_fetch_ops)
    for cN in [cG, cD]:
        cN.train_op = cN.opt.apply_updates()
        cN.reg_op = cN.reg_opt.apply_updates(allow_no_op=True)
    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=beta)

    # Finalize graph
    with tf.device("/gpu:0"):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)
    tflib.init_uninitialized_vars()

    # Tensorboard summaries
    if summarize:
        misc.log("Initializing logs...", "white")
        summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path())
        if save_tf_graph:
            summary_log.add_graph(tf.get_default_graph())
        if save_weight_histograms:
            G.setup_weight_histograms()
            D.setup_weight_histograms()

    # Initialize training
    misc.log("Training for %d kimg..." % total_kimg, "white")
    dnnlib.RunContext.get().update("",
                                   cur_epoch=resume_kimg,
                                   max_epoch=total_kimg)
    maintenance_time = dnnlib.RunContext.get().get_last_update_interval()

    cur_tick, running_mb_counter = -1, 0
    cur_nimg = int(resume_kimg * 1000)
    tick_start_nimg = cur_nimg
    for cN in [cG, cD]:
        cN.lossvals_agg = {
            k: None
            for k in ["loss", "reg", "norm", "reg_norm"]
        }
        cN.opt.reset_optimizer_state()

    # Training loop
    while cur_nimg < total_kimg * 1000:
        if dnnlib.RunContext.get().should_stop():
            break

        # Choose training parameters and configure training ops
        sched = training_schedule(sched_args,
                                  cur_nimg=cur_nimg,
                                  dataset=dataset)
        assert sched.minibatch_size % (sched.minibatch_gpu * num_gpus) == 0
        dataset.configure(sched.minibatch_gpu)

        # Run training ops
        feed_dict = {
            lrate_in_g: sched.G_lrate,
            lrate_in_d: sched.D_lrate,
            minibatch_size_in: sched.minibatch_size,
            minibatch_gpu_in: sched.minibatch_gpu,
            step: sched.kimg
        }

        # Several iterations before updating training parameters
        for _repeat in range(minibatch_repeats):
            rounds = range(0, sched.minibatch_size,
                           sched.minibatch_gpu * num_gpus)
            for cN in [cG, cD]:
                cN.run_reg = lazy_regularization and (running_mb_counter %
                                                      cN.reg_interval == 0)
            cur_nimg += sched.minibatch_size
            running_mb_counter += 1

            for cN in [cG, cD]:
                cN.lossvals = {
                    k: None
                    for k in ["loss", "reg", "norm", "reg_norm"]
                }

            # Gradient accumulation
            for _round in rounds:
                cG.lossvals.update(
                    tflib.run([cG.train_op, cG.ops], feed_dict)[1])
                if cG.run_reg:
                    _, cG.lossvals["reg_norm"] = tflib.run(
                        [cG.reg_op, cG.reg_norm], feed_dict)

                tflib.run(data_fetch_op, feed_dict)

                cD.lossvals.update(
                    tflib.run([cD.train_op, cD.ops], feed_dict)[1])
                if cD.run_reg:
                    _, cD.lossvals["reg_norm"] = tflib.run(
                        [cD.reg_op, cD.reg_norm], feed_dict)

            tflib.run([Gs_update_op], feed_dict)

            # Track loss statistics
            for cN in [cG, cD]:
                for k in cN.lossvals_agg:
                    cN.lossvals_agg[k] = emaAvg(cN.lossvals_agg[k],
                                                cN.lossvals[k])

        # Perform maintenance tasks once per tick
        done = (cur_nimg >= total_kimg * 1000)
        if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = dnnlib.RunContext.get().get_time_since_last_update()
            total_time = dnnlib.RunContext.get().get_time_since_start(
            ) + resume_time

            # Report progress
            print(
                ("tick %s kimg %s   loss/reg: G (%s %s) D (%s %s)   grad norms: G (%s %s) D (%s %s)   "
                 + "time %s sec/kimg %s maxGPU %sGB %s") %
                (misc.bold("%-5d" % autosummary("Progress/tick", cur_tick)),
                 misc.bcolored(
                     "{:>8.1f}".format(
                         autosummary("Progress/kimg", cur_nimg / 1000.0)),
                     "red"),
                 misc.bcolored("{:>6.3f}".format(cG.lossvals_agg["loss"] or 0),
                               "blue"),
                 misc.bold("{:>6.3f}".format(cG.lossvals_agg["reg"] or 0)),
                 misc.bcolored("{:>6.3f}".format(cD.lossvals_agg["loss"] or 0),
                               "blue"),
                 misc.bold("{:>6.3f}".format(cD.lossvals_agg["reg"] or 0)),
                 misc.cond_bcolored(cG.lossvals_agg["norm"], 20.0, "red"),
                 misc.cond_bcolored(cG.lossvals_agg["reg_norm"], 20.0, "red"),
                 misc.cond_bcolored(cD.lossvals_agg["norm"], 20.0, "red"),
                 misc.cond_bcolored(cD.lossvals_agg["reg_norm"], 20.0, "red"),
                 misc.bold("%-10s" % dnnlib.util.format_time(
                     autosummary("Timing/total_sec", total_time))),
                 "{:>7.2f}".format(
                     autosummary("Timing/sec_per_kimg",
                                 tick_time / tick_kimg)),
                 "{:>4.1f}".format(
                     autosummary("Resources/peak_gpu_mem_gb",
                                 peak_gpu_mem_op.eval() / 2**30)), printname))

            autosummary("Timing/total_hours", total_time / (60.0 * 60.0))
            autosummary("Timing/total_days", total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots
            if img_snapshot_ticks is not None and (
                    cur_tick % img_snapshot_ticks == 0 or done):
                visualize.eval(G,
                               dataset,
                               batch_size=sched.minibatch_gpu,
                               training=True,
                               step=cur_nimg // 1000,
                               grid_size=grid_size,
                               latents=grid_latents,
                               labels=grid_labels,
                               drange_net=drange_net,
                               ratio=ratio,
                               **vis_args)

            if network_snapshot_ticks is not None and (
                    cur_tick % network_snapshot_ticks == 0 or done):
                pkl = dnnlib.make_run_dir_path("network-snapshot-%06d.pkl" %
                                               (cur_nimg // 1000))
                misc.save_pkl((G, D, Gs), pkl, remove=False)

                if cur_tick % network_snapshot_ticks == 0 or done:
                    metric = metrics.run(
                        pkl,
                        num_imgs=eval_images_num,
                        run_dir=dnnlib.make_run_dir_path(),
                        data_dir=dnnlib.convert_path(data_dir),
                        num_gpus=num_gpus,
                        ratio=ratio,
                        tf_config=tf_config,
                        mirror_augment=mirror_augment)

                if last_snapshots > 0:
                    misc.rm(
                        sorted(
                            glob.glob(dnnlib.make_run_dir_path(
                                "network*.pkl")))[:-last_snapshots])

            # Update summaries and RunContext
            if summarize:
                metrics.update_autosummaries()
                tflib.autosummary.save_summaries(summary_log, cur_nimg)

            dnnlib.RunContext.get().update(None,
                                           cur_epoch=cur_nimg // 1000,
                                           max_epoch=total_kimg)
            maintenance_time = dnnlib.RunContext.get(
            ).get_last_update_interval() - tick_time

    # Save final snapshot
    misc.save_pkl((G, D, Gs),
                  dnnlib.make_run_dir_path("network-final.pkl"),
                  remove=False)

    # All done
    if summarize:
        summary_log.close()
    dataset.close()
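
# A small worked example (illustrative numbers) for the constants computed in the graph above:
# minibatch_multiplier counts how many GPU-sized rounds are accumulated into one full minibatch,
# and beta is the Gs moving-average decay, giving a half-life of smoothing_kimg thousand images.
minibatch_size, minibatch_gpu, num_gpus = 64, 8, 2
minibatch_multiplier = minibatch_size // (minibatch_gpu * num_gpus)   # 4 accumulation rounds
smoothing_kimg = 10.0
beta = 0.5 ** (minibatch_size / (smoothing_kimg * 1000.0))            # ~0.9956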
Example #13
def vis(G,
    dataset,                          # The dataset object for accessing the data
    device,                           # Device to run visualization on
    batch_size,                       # Visualization batch size
    run_dir             = ".",        # Output directory
    training            = False,      # Training mode
    latents             = None,       # Source latents to generate images from
    labels              = None,       # Source labels to generate images from (0 if no labels are used)
    ratio               = 1.0,        # Image height/width ratio in the dataset
    truncation_psi      = 0.7,        # Style strength multiplier for the truncation trick (used for visualizations only)
    # Model settings
    k                   = 1,          # Number of components the model has
    drange_net          = [-1,1],     # Model image output range
    attention           = False,      # Whether the model produces attention maps (for visualization)
    num_heads           = 1,          # Number of attention heads
    # Visualization settings
    vis_types           = None,       # Visualization types to be created
    num                 = 100,        # Number of produced samples
    rich_num            = 5,          # Number of samples for which richer visualizations will be created
                                      # (requires more memory and disk space, and therefore rich_num <= num)
    grid                = None,       # Whether to save the samples in one large grid file
                                      # or in separate files, one per sample
    grid_size           = None,       # Grid proportions (w, h)
    step                = None,       # Step number to be used in visualization filenames
    verbose             = None,       # Whether to print verbose progress messages
    keep_samples        = True,       # Keep all prior samples during training
    # Visualization-specific settings
    alpha               = 0.3,        # Proportion for generated images and attention maps blends
    intrp_density       = 8,          # Number of samples in between two end points of an interpolation
    intrp_per_component = False,      # Whether to perform interpolation along particular latent components (True)
                                      # or all of them at once (False)
    noise_samples_num   = 100,        # Number of samples used to compute noise variation visualization
    section_size        = 100):       # Visualization section size (section_size <= num) for reducing memory footprint
    
    def prefix(step): return "" if step is None else f"{step:06d}_"
    def pattern_of(dir, step, suffix): return f"{run_dir}/visuals/{dir}/{prefix(step)}%06d.{suffix}"

    # Set default options
    if verbose is None: verbose = not training # Disable verbose during training
    if grid is None: grid = training # Save image samples in one grid file during training    
    if grid_size is not None: section_size = rich_num = num = np.prod(grid_size) # If grid size is provided, set images number accordingly

    _labels, _latents = labels, latents
    if _latents is not None: assert num == _latents.shape[0]
    if _labels  is not None: assert num == _labels.shape[0]
    assert rich_num <= section_size

    vis = vis_types
    # For time efficiency, during training save only image and map samples rather than richer visualizations
    if training:
        vis = {"imgs"} # , "maps"
        # if num_heads == 1:
        #    vis.add("layer_maps")
    else:
        vis = vis or {"imgs", "maps", "ltnts", "interpolations", "noise_var"}

    # Build utility functions
    save_images = misc.save_images_builder(drange_net, ratio, grid_size, grid, verbose)
    save_blends = misc.save_blends_builder(drange_net, ratio, grid_size, grid, verbose, alpha)

    crange = trange if verbose else range
    section_of = lambda a, i, n: a[i * n: (i + 1) * n]

    get_rnd_latents = lambda n: torch.randn([n, *G.input_shape[1:]], device = device)
    get_rnd_labels = lambda n: torch.from_numpy(dataset.get_random_labels(n)).to(device)

    # Create directories
    dirs = []
    if "imgs" in vis:            dirs += ["images"]
    if "ltnts" in vis:           dirs += ["latents-z", "latents-w"]
    if "maps" in vis:            dirs += ["maps", "softmaps", "blends", "softblends"]
    if "layer_maps" in vis:      dirs += ["layer_maps"]
    if "interpolations" in vis:  dirs += ["interpolations-z", "interpolations-w"]

    if not keep_samples:
        shutil.rmtree(f"{run_dir}/visuals")
    for dir in dirs:
        os.makedirs(f"{run_dir}/visuals/{dir}", exist_ok = True)

    if verbose:
        print("Running network and saving samples...")
    
    # Produce visualizations
    for idx in crange(0, num, section_size):
        curr_size = curr_section_size(num, idx, section_size)

        # Compute source latents/labels that images will be produced from
        latents = get_rnd_latents(curr_size) if _latents is None else section_of(_latents, idx, section_size)
        labels  = get_rnd_labels(curr_size)  if _labels  is None else section_of(_labels,  idx, section_size)
        if idx == 0:
            latents0, labels0 = latents, labels
        # Run network over latents and produce images and attention maps
        ret = run(G, latents, labels, batch_size, truncation_psi, noise_mode = "const", 
            return_att = True, return_ws = True)
        # For memory efficiency, save full information only for a small amount of images
        images, attmaps_all_layers, wlatents_all_layers = ret
        soft_maps = attmaps_all_layers[:,:,-1,0] if attention else None
        attmaps_all_layers = attmaps_all_layers[:rich_num]
        wlatents = wlatents_all_layers[:,:,0]
        # Save image samples
        if "imgs" in vis:
            save_images(images, pattern_of("images", step, "png"), idx)

        # Save latent vectors
        if "ltnts" in vis:
            misc.save_npys(latents, pattern_of("latents-z", step, "npy"), verbose, idx)
            misc.save_npys(wlatents, pattern_of("latents-w", step, "npy"), verbose, idx)

        # For the GANformer model, save attention maps
        if attention:
            if "maps" in vis:
                pallete = np.expand_dims(misc.get_colors(k - 1), axis = [2, 3])
                maps = (soft_maps == np.amax(soft_maps, axis = 1, keepdims = True)).astype(float)

                soft_maps = np.sum(pallete * np.expand_dims(soft_maps, axis = 2), axis = 1)
                maps = np.sum(pallete * np.expand_dims(maps, axis = 2), axis = 1)

                save_images(soft_maps, pattern_of("softmaps", step, "png"), idx)
                save_images(maps, pattern_of("maps", step, "png"), idx)

                save_blends(soft_maps, images, pattern_of("softblends", step, "png"), idx)
                save_blends(maps, images, pattern_of("blends", step, "png"), idx)

            # Save maps from all attention heads and layers
            # (for efficiency, only for a small number of images)
            if "layer_maps" in vis:
                all_maps = []
                maps_fakes = np.split(attmaps_all_layers, attmaps_all_layers.shape[2], axis = 2)
                for layer, lmaps in enumerate(maps_fakes):
                    lmaps = np.split(np.squeeze(lmaps, axis = 2), lmaps.shape[2], axis = 2)
                    for head, hmap in enumerate(lmaps):
                        hmap = (hmap == np.amax(hmap, axis = 1, keepdims = True)).astype(float)
                        hmap = np.sum(pallete * hmap, axis = 1)
                        all_maps.append((hmap, f"l{layer}_h{head}"))

                if not grid:
                    for i in range(rich_num):
                        stepdir = "" if step is None else (f"/{step:06d}")
                        os.makedirs(f"{run_dir}/visuals/layer_maps/%06d" % i + stepdir, exist_ok = True)
                for maps, name in all_maps:
                    if grid:
                        pattern = f"{run_dir}/visuals/layer_maps/{prefix(step)}%06d-{name}.png"
                    else:
                        pattern = f"{run_dir}/visuals/layer_maps/%06d/{stepdir}/{name}.png"
                    save_images(maps, pattern, idx)

    # Produce interpolations between pairs of source latents
    # In the GANformer case, varying one component at a time
    if "interpolations" in vis:
        ts = torch.linspace(0.0, 1.0, steps = intrp_density)

        if verbose:
            print("Generating interpolations...")
        for i in crange(rich_num):
            os.makedirs(f"{run_dir}/visuals/interpolations-z/%06d" % i, exist_ok = True)
            os.makedirs(f"{run_dir}/visuals/interpolations-w/%06d" % i, exist_ok = True)

            z = get_rnd_latents(2)
            z[0] = latents0[i]
            c = labels0[i:i+1]
            
            w = run(G, z, c, batch_size, truncation_psi, noise_mode = "const", return_ws = True)[-1]

            def update(t, fn, ts, dim):
                if dim == 3:
                    ts = ts[:, None]
                t_ups = []

                if intrp_per_component:
                    for c in range(k - 1):
                        # copy over all the components except component c that will get interpolated
                        t_up = torch.clone(t[0]).unsqueeze(0).repeat((intrp_density, ) + tuple([1] * dim))
                        # interpolate component c
                        t_up[:,c] = fn(t[0, c], t[1, c], ts)
                        t_ups.append(t_up)

                    t_up = torch.cat(t_ups, dim = 0)
                else:
                    t_up = fn(t[0], t[1], ts.unsqueeze(1))

                return t_up

            z_up = update(z, slerp, ts, 2)
            w_up = update(w, lerp, ts, 3)

            imgs1 = run(G, z_up, c, batch_size, truncation_psi, noise_mode = "const")[0]
            imgs2 = run(G, w_up, c, batch_size, truncation_psi, noise_mode = "const", take_w = True)[0]

            def save_interpolation(imgs, name):
                imgs = np.split(imgs, k - 1, axis = 0)
                for c in range(k - 1):
                    filename = f"{run_dir}/visuals/interpolations-{name}/{i:06d}/{c:02d}"
                    imgs[c] = [misc.to_pil(img, drange = drange_net) for img in imgs[c]]
                    imgs[c][-1].save(f"{filename}.png")
                    misc.save_gif(imgs[c], f"{filename}.gif")

            save_interpolation(imgs1, "z")
            save_interpolation(imgs2, "w")

    # Compute noise variance map
    # Shows what areas vary the most given fixed source
    # latents due to the use of stochastic local noise
    if "noise_var" in vis:
        if verbose:
            print("Generating noise variance...")

        z = get_rnd_latents(1).repeat(noise_samples_num, 1, 1)
        c = get_rnd_labels(1)
        imgs = run(G, z, c, batch_size, truncation_psi)[0]
        imgs = np.stack([misc.to_pil(img, drange = drange_net) for img in imgs], axis = 0)
        diff = np.std(np.mean(imgs, axis = 3), axis = 0) * 4
        diff = np.clip(diff + 0.5, 0, 255).astype(np.uint8)
        PIL.Image.fromarray(diff, "L").save(f"{run_dir}/visuals/noise-variance.png")

    # Compute a style-mixing table, using latent A in some of the layers and latent B in the rest.
    # For the GANformer, also produce component mixes (using latents from A for some of the components,
    # and latents from B for the rest).
    if "style_mix" in vis:
        if verbose:
            print("Generating style mixes...")
        cols, rows = 4, 2
        row_lens = np.array([2, 5, 8, 11])
        c = get_rnd_labels(1)

        # Create latent mixes
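        # Each mask broadcasts against w-latents of (assumed) shape [batch, components, layers, dim]:
        # the "layer" mix swaps latents along the layer axis and the "component" mix along the component
        # axis, with row_lens controlling how many layers/components come from the row latent (the rest
        # come from the column latent)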
        mixes = {
            "layer": (np.arange(wlatents_all_layers.shape[2]) < row_lens[:,None]).astype(np.float32)[:,None,None,None,:,None],
            "component": (np.arange(wlatents_all_layers.shape[1]) < row_lens[:,None]).astype(np.float32)[:,None,None,:,None,None]
        }
        ws = wlatents_all_layers[:cols+rows]
        orig_imgs = images[:cols+rows]
        col_z = wlatents_all_layers[:cols][None, None]
        row_z = wlatents_all_layers[cols:cols+rows][None,:,None]

        for name, mix in mixes.items():
            # Produce image mixes
            mix_z = mix * row_z  + (1 - mix) * col_z
            mix_z = torch.from_numpy(np.reshape(mix_z, [-1, *wlatents_all_layers.shape[1:]])).to(device)
            mix_imgs = run(G, mix_z, c, batch_size, truncation_psi, noise_mode = "const", take_w = True)[0]
            mix_imgs = np.reshape(mix_imgs, [len(row_lens) * rows, cols, *mix_imgs.shape[1:]])

            # Create image table canvas
            H, W = mix_imgs.shape[-2:]
            canvas = PIL.Image.new("RGB", (W * (cols + 1), H * (len(row_lens) * rows + 1)), "black")

            # Place image mixes respectively at each position (row_idx, col_idx)
            for row_idx, row_elem in enumerate([None] + list(range(len(row_lens) * rows))):
                for col_idx, col_elem in enumerate([None] + list(range(cols))):
                    if (row_elem, col_elem) == (None, None):  continue
                    if row_elem is None:                    img = orig_imgs[col_elem]
                    elif col_elem is None:                  img = orig_imgs[cols + (row_elem % rows)]
                    else:                                   img = mix_imgs[row_elem, col_elem]

                    canvas.paste(misc.to_pil(img, drange = drange_net), (W * col_idx, H * row_idx))

            canvas.save(f"{run_dir}/visuals/{name}-mixing.png")

    if verbose:
        misc.log("Visualizations Completed!", "blue")
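
# The interpolation code above calls `lerp` and `slerp` helpers that are defined elsewhere in the
# repository. A minimal sketch of what they might look like, inferred from their usage here (an
# assumption, not the repository's actual implementation); `_expand_weights` is a hypothetical helper:
import torch

def _expand_weights(t, ref):
    # Reshape interpolation weights to [steps, 1, 1, ...] so they broadcast
    # against a reference latent tensor of arbitrary rank
    return t.reshape(-1, *([1] * ref.dim()))

def lerp(a, b, t):
    # Linear interpolation between latent tensors a and b
    t = _expand_weights(t, a)
    return a + (b - a) * t

def slerp(a, b, t):
    # Spherical interpolation: move from a towards b along the hypersphere, which tends to
    # traverse a Gaussian latent space more evenly than straight-line interpolation
    t = _expand_weights(t, a)
    a_n = a / a.norm(dim = -1, keepdim = True)
    b_n = b / b.norm(dim = -1, keepdim = True)
    omega = torch.acos((a_n * b_n).sum(dim = -1, keepdim = True).clamp(-1, 1))
    so = torch.sin(omega).clamp(min = 1e-8)
    return (torch.sin((1.0 - t) * omega) / so) * a + (torch.sin(t * omega) / so) * b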
Example #14
0
def setup_config(run_dir, **args):
    args = EasyDict(args)  # command-line options
    train = EasyDict(run_dir=run_dir)  # training loop options
    vis = EasyDict(run_dir=run_dir)  # visualization loop options

    if args.reload:
        config_fn = os.path.join(run_dir, "training_options.json")
        if os.path.exists(config_fn):
            # Load the config from the experiment's existing file (and thus ignore command-line arguments)
            with open(config_fn, "rt") as f:
                config = json.load(f)
            return config
        misc.log(
            f"Warning: --reload is set for a new experiment {args.expname}," +
            f" but configuration file to reload from {config_fn} doesn't exist.",
            "red")

    # GANformer and baselines default settings
    # ----------------------------------------------------------------------------

    if args.ganformer_default:
        task = args.dataset
        nset(args, "mirror_augment", task in ["cityscapes", "ffhq"])

        nset(args, "transformer", True)
        nset(args, "components_num", {"clevr": 8}.get(task, 16))
        nset(args, "latent_size", {"clevr": 128}.get(task, 512))

        nset(args, "normalize", "layer")
        nset(args, "integration", "mul")
        nset(args, "kmeans", True)
        nset(args, "use_pos", True)
        nset(args, "mapping_ltnt2ltnt", task != "clevr")
        nset(args, "style", task != "clevr")

        nset(args, "g_arch", "resnet")
        nset(args, "mapping_resnet", True)

        gammas = {"ffhq": 10, "cityscapes": 20, "clevr": 40, "bedrooms": 100}
        nset(args, "gamma", gammas.get(task, 10))

    if args.baseline == "GAN":
        nset(args, "style", False)
        nset(args, "latent_stem", True)

    ## k-GAN and SAGAN are not currently supported in the PyTorch version.
    ## See the TF version for implementation of these baselines!
    # if args.baseline == "SAGAN":
    #     nset(args, "style", False)
    #     nset(args, "latent_stem", True)
    #     nset(args, "g_img2img", 5)

    # if args.baseline == "kGAN":
    #     nset(args, "kgan", True)
    #     nset(args, "merge_layer", 5)
    #     nset(args, "merge_type", "softmax")
    #     nset(args, "components_num", 8)

    # General setup
    # ----------------------------------------------------------------------------

    # If the flag is specified without arguments (--arg), set to True
    for arg in [
            "cuda_bench", "allow_tf32", "keep_samples", "style", "local_noise"
    ]:
        if args[arg] is None:
            args[arg] = True

    if not any([args.train, args.eval, args.vis]):
        misc.log(
            "Warning: None of --train, --eval or --vis are provided. Therefore, we only print network shapes",
            "red")
    for arg in ["train", "eval", "vis", "last_snapshots"]:
        cset(train, arg, args[arg])

    num_gpus = 1 # default to a single GPU when --gpus is not specified
    if args.gpus != "":
        num_gpus = len(args.gpus.split(","))
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
    if not (num_gpus >= 1 and num_gpus & (num_gpus - 1) == 0):
        misc.error("Number of GPUs must be a power of two")
    args.num_gpus = num_gpus

    # CUDA settings
    for arg in ["batch_size", "batch_gpu", "allow_tf32"]:
        cset(train, arg, args[arg])
    cset(train, "cudnn_benchmark", args.cuda_bench)

    # Data setup
    # ----------------------------------------------------------------------------

    # For bedrooms, we choose the most common ratio in the
    # dataset and crop the other images into that ratio.
    ratios = {
        "clevr": 0.75,
        "bedrooms": 188 / 256,
        "cityscapes": 0.5,
        "ffhq": 1.0
    }
    args.ratio = args.ratio or ratios.get(args.dataset, 1.0)
    args.crop_ratio = 0.5 if args.resolution > 256 and args.ratio < 0.5 else None

    args.printname = args.expname
    for arg in ["total_kimg", "printname"]:
        cset(train, arg, args[arg])

    dataset_args = EasyDict(class_name="training.dataset.ImageFolderDataset",
                            path=f"{args.data_dir}/{args.dataset}",
                            max_items=args.train_images_num,
                            resolution=args.resolution,
                            ratio=args.ratio,
                            mirror_augment=args.mirror_augment)
    dataset_args.loader_args = EasyDict(num_workers=args.num_threads,
                                        pin_memory=True,
                                        prefetch_factor=2)

    # Optimization setup
    # ----------------------------------------------------------------------------

    cG = set_net("Generator", ["mapping", "synthesis"], args.g_lr, 4)
    cD = set_net("Discriminator", ["mapping", "block", "epilogue"], args.d_lr,
                 16)
    cset([cG, cD], "crop_ratio", args.crop_ratio)

    mbstd = min(args.batch_gpu, 4) # other hyperparams behave more predictably if mbstd group size remains fixed
    cset(cD.epilogue_kwargs, "mbstd_group_size", mbstd)

    # Automatic tuning
    if args.autotune:
        batch_size = max(min(args.num_gpus * min(4096 // args.resolution, 32), 64), args.num_gpus) # keep gpu memory consumption at bay
        batch_gpu = batch_size // args.num_gpus # per-GPU batch derived from the auto-tuned batch size
        nset(args, "batch_size", batch_size)
        nset(args, "batch_gpu", batch_gpu)

        fmap_decay = 1 if args.resolution >= 512 else 0.5 # use a wider feature-map base at high resolutions
        lr = 0.002 if args.resolution >= 1024 else 0.0025
        gamma = 0.0002 * (args.resolution ** 2) / args.batch_size # heuristic formula

        cset([cG.synthesis_kwargs, cD], "dim_base", int(fmap_decay * 32768))
        nset(args, "g_lr", lr)
        cset(cG.opt_args, "lr", args.g_lr)
        nset(args, "d_lr", lr)
        cset(cD.opt_args, "lr", args.d_lr)
        nset(args, "gamma", gamma)

        train.ema_rampup = 0.05
        train.ema_kimg = batch_size * 10 / 32

    if args.batch_size % (args.batch_gpu * args.num_gpus) != 0:
        misc.error("--batch-size should be divisible by --batch-gpu times the number of GPUs")

    # Loss and regularization settings
    loss_args = EasyDict(class_name="training.loss.StyleGAN2Loss",
                         g_loss=args.g_loss,
                         d_loss=args.d_loss,
                         r1_gamma=args.gamma,
                         pl_weight=args.pl_weight)

    # if args.fp16:
    #     cset([cG.synthesis_kwargs, cD], "num_fp16_layers", 4) # enable mixed-precision training
    #     cset([cG.synthesis_kwargs, cD], "conv_clamp", 256) # clamp activations to avoid float16 overflow

    # cset([cG.synthesis_kwargs, cD.block_args], "fp16_channels_last", args.nhwc)

    # Evaluation and visualization
    # ----------------------------------------------------------------------------

    from metrics import metric_main
    for metric in args.metrics:
        if not metric_main.is_valid_metric(metric):
            misc.error(
                f"Unknown metric: {metric}. The valid metrics are: {metric_main.list_valid_metrics()}"
            )

    for arg in ["num_gpus", "metrics", "eval_images_num", "truncation_psi"]:
        cset(train, arg, args[arg])
    for arg in ["keep_samples", "num_heads"]:
        cset(vis, arg, args[arg])

    args.vis_imgs = args.vis_images
    args.vis_ltnts = args.vis_latents
    vis_types = [
        "imgs", "ltnts", "maps", "layer_maps", "interpolations", "noise_var",
        "style_mix"
    ]
    # Collect the set of visualization types that were enabled
    vis.vis_types = list({arg for arg in vis_types if args[f"vis_{arg}"]})

    vis_args = {
        "attention": "transformer",
        "grid": "vis_grid",
        "num": "vis_num",
        "rich_num": "vis_rich_num",
        "section_size": "vis_section_size",
        "intrp_density": "interpolation_density",
        # "intrp_per_component": "interpolation_per_component",
        "alpha": "blending_alpha"
    }
    for arg, cmd_arg in vis_args.items():
        cset(vis, arg, args[cmd_arg])

    # Networks setup
    # ----------------------------------------------------------------------------

    # Networks architecture
    cset(cG.synthesis_kwargs, "architecture", args.g_arch)
    cset(cD, "architecture", args.d_arch)

    # Latent sizes
    if args.components_num > 0:
        if not args.transformer: # (or args.kgan, in the TF version)
            misc.error("--components-num > 0 but the model is not using components. " +
                "Add --transformer for GANformer (which uses latent components).")
        if args.latent_size % args.components_num != 0:
            misc.error(f"--latent-size ({args.latent_size}) should be divisible by --components-num ({args.components_num})")
        args.latent_size = int(args.latent_size / args.components_num)

    cG.z_dim = cG.w_dim = args.latent_size
    cset([cG, vis], "k", args.components_num + 1) # we add a component to modulate features globally

    # Mapping network
    args.mapping_layer_dim = args.mapping_dim
    for arg in ["num_layers", "layer_dim", "resnet", "shared", "ltnt2ltnt"]:
        field = f"mapping_{arg}"
        cset(cG.mapping_kwargs, arg, args[field])

    # StyleGAN settings
    for arg in ["style", "latent_stem", "local_noise"]:
        cset(cG.synthesis_kwargs, arg, args[arg])

    # GANformer
    cset([cG.synthesis_kwargs, cG.mapping_kwargs], "transformer",
         args.transformer)

    # Attention related settings
    for arg in ["use_pos", "num_heads", "ltnt_gate", "attention_dropout"]:
        cset([cG.mapping_kwargs, cG.synthesis_kwargs], arg, args[arg])

    # Attention types and layers
    for arg in ["start_res", "end_res"]: # , "local_attention", "ltnt2ltnt", "img2img", "img2ltnt"
        cset(cG.synthesis_kwargs, arg, args[f"g_{arg}"])

    # Mixing and dropout
    for arg in ["style_mixing", "component_mixing"]:
        cset(loss_args, arg, args[arg])
    cset(cG, "component_dropout", args["component_dropout"])

    # Extra transformer options
    args.norm = args.normalize
    for arg in [
            "norm", "integration", "img_gate", "iterative", "kmeans",
            "kmeans_iters"
    ]:
        cset(cG.synthesis_kwargs, arg, args[arg])

    # Positional encoding
    # args.pos_dim = args.pos_dim or args.latent_size
    for arg in ["dim", "type", "init", "directions_num"]:
        field = f"pos_{arg}"
        cset(cG.synthesis_kwargs, field, args[field])

    # k-GAN
    # for arg in ["layer", "type", "same"]:
    #     field = "merge_{}".format(arg)
    #     cset(cG.args, field, args[field])
    # cset(cG.synthesis_kwargs, "merge", args.kgan)
    # if args.kgan and args.transformer:
    # misc.error("Either have --transformer for GANformer or --kgan for k-GAN, not both")

    config = EasyDict(train)
    config.update(cG=cG,
                  cD=cD,
                  loss_args=loss_args,
                  dataset_args=dataset_args,
                  vis_args=vis)

    # Save config file
    with open(os.path.join(run_dir, "training_options.json"), "wt") as f:
        json.dump(config, f, indent=2)

    return config
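
# setup_config relies on two small helpers, `nset` and `cset`, that are defined elsewhere in the
# repository. A minimal sketch of their assumed semantics, inferred from how they are used above
# (an assumption, not the repository's actual implementation):
def nset(args, name, value):
    # Set a default: assign `value` only when the option has not been set already
    if args.get(name) is None:
        args[name] = value

def cset(dicts, name, value):
    # Copy a value into one or several configuration dicts, skipping None so that
    # unset command-line options do not overwrite existing defaults
    if value is None:
        return
    if not isinstance(dicts, list):
        dicts = [dicts]
    for d in dicts:
        d[name] = value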