def setup_training_stages(loss_args, G, cG, D, cD, ddp_nets, device, log):
    misc.log("Setting up training stages...", "white", log)
    loss = dnnlib.util.construct_class_by_name(device = device, **ddp_nets, **loss_args) # subclass of training.loss.Loss
    stages = []

    for name, net, config in [("G", G, cG), ("D", D, cD)]:
        if config.reg_interval is None:
            opt = dnnlib.util.construct_class_by_name(params = net.parameters(), **config.opt_args) # subclass of torch.optim.Optimizer
            stages.append(dnnlib.EasyDict(name = f"{name}_both", net = net, opt = opt, interval = 1))
        # Lazy regularization
        else:
            mb_ratio = config.reg_interval / (config.reg_interval + 1)
            opt_args = dnnlib.EasyDict(config.opt_args)
            opt_args.lr = opt_args.lr * mb_ratio
            opt_args.betas = [beta ** mb_ratio for beta in opt_args.betas]
            opt = dnnlib.util.construct_class_by_name(params = net.parameters(), **opt_args) # subclass of torch.optim.Optimizer
            stages.append(dnnlib.EasyDict(name = f"{name}_main", net = net, opt = opt, interval = 1))
            stages.append(dnnlib.EasyDict(name = f"{name}_reg", net = net, opt = opt, interval = config.reg_interval))

    for stage in stages:
        stage.start_event = None
        stage.end_event = None
        if log:
            stage.start_event = torch.cuda.Event(enable_timing = True)
            stage.end_event = torch.cuda.Event(enable_timing = True)

    return loss, stages
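# The lazy-regularization branch above follows the StyleGAN2 trick: since the regularized
# stage runs only once every reg_interval steps, the learning rate and Adam betas are
# rescaled so the combined updates match a non-lazy schedule. A minimal numeric sketch
# (standalone, with hypothetical values; not part of the repo):
reg_interval = 16                                    # e.g. D is regularized once every 16 steps
mb_ratio = reg_interval / (reg_interval + 1)         # 16/17 ~= 0.941

lr, betas = 0.002, (0.0, 0.99)
scaled_lr = lr * mb_ratio                            # ~0.00188
scaled_betas = [beta ** mb_ratio for beta in betas]  # [0.0, ~0.9906]
print(scaled_lr, scaled_betas)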
def construct_nets(cG, cD, dataset, device, log):
    misc.log("Constructing networks...", "white", log)
    common_kwargs = dict(c_dim = dataset.label_dim, img_resolution = dataset.resolution, img_channels = dataset.num_channels)
    G = dnnlib.util.construct_class_by_name(**cG, **common_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module
    D = dnnlib.util.construct_class_by_name(**cD, **common_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module
    Gs = copy.deepcopy(G).eval()
    return G, D, Gs
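# construct_class_by_name resolves a dotted class path from the config dict and instantiates
# it with the remaining arguments. A minimal sketch of such a helper (the real dnnlib version
# also supports module search paths and caching; this is only the gist):
import importlib

def construct_class_by_name_sketch(*args, class_name = None, **kwargs):
    # Split "package.module.Class" into the module path and the class attribute
    module_name, _, attr = class_name.rpartition(".")
    cls = getattr(importlib.import_module(module_name), attr)
    return cls(*args, **kwargs)

# e.g. construct_class_by_name_sketch(params = net.parameters(), class_name = "torch.optim.Adam", lr = 0.002)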
def load_nets(resume_pkl, lG, lD, lGs, recompile):
    misc.log("Loading networks from %s..." % resume_pkl, "white")
    rG, rD, rGs = pretrained_networks.load_networks(resume_pkl)
    if recompile:
        misc.log("Copying nets...")
        lG.copy_vars_from(rG); lD.copy_vars_from(rD); lGs.copy_vars_from(rGs)
    else:
        lG, lD, lGs = rG, rD, rGs
    return lG, lD, lGs
def distribute_nets(G, D, Gs, device, num_gpus, log):
    misc.log(f"Distributing across {num_gpus} GPUs...", "white", log)
    networks = {}
    for name, net in [("G", G), ("D", D), (None, Gs)]: # ("G_mapping", G.mapping), ("G_synthesis", G.synthesis)
        if (num_gpus > 1) and (net is not None) and len(list(net.parameters())) != 0:
            net.requires_grad_(True)
            net = torch.nn.parallel.DistributedDataParallel(net, device_ids = [device],
                broadcast_buffers = False, find_unused_parameters = True)
            net.requires_grad_(False)
        if name is not None:
            networks[name] = net
    return networks
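# distribute_nets assumes torch.distributed has already been initialized elsewhere (rank
# and process-group setup are outside this excerpt). A sketch of the per-process setup such
# a wrapper relies on, with hypothetical env-var based rendezvous values (launchers like
# torchrun normally set these for you):
import os
import torch

def init_distributed_sketch(rank, num_gpus):
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")  # hypothetical rendezvous settings
    os.environ.setdefault("MASTER_PORT", "29500")
    torch.distributed.init_process_group("nccl", rank = rank, world_size = num_gpus)
    device = torch.device("cuda", rank)
    torch.cuda.set_device(device)
    return device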
def load_nets(load_pkl, nets, device, log):
    if load_pkl is not None:
        misc.log(f"Resuming from {load_pkl}", "white", log)
        resume_data = loader.load_network(load_pkl)

        if nets is not None:
            G, D, Gs = nets
            for name, net in [("G", G), ("D", D), ("Gs", Gs)]:
                torch_misc.copy_params_and_buffers(resume_data[name], net, require_all = False)
        else:
            for net in ["G", "D", "Gs"]:
                resume_data[net] = copy.deepcopy(resume_data[net]).eval().requires_grad_(False).to(device)
            nets = (resume_data["G"], resume_data["D"], resume_data["Gs"])

    return nets
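# copy_params_and_buffers comes from the repo's torch utilities (as in StyleGAN2-ADA); the
# essential behavior is a name-matched copy of parameters and buffers, tolerating missing
# entries when require_all is False. A sketch under that assumption, not the exact helper:
import torch

def copy_params_and_buffers_sketch(src_module, dst_module, require_all = False):
    src_tensors = dict(src_module.named_parameters())
    src_tensors.update(dict(src_module.named_buffers()))
    for name, tensor in list(dst_module.named_parameters()) + list(dst_module.named_buffers()):
        assert (name in src_tensors) or (not require_all), f"missing tensor: {name}"
        if name in src_tensors:
            with torch.no_grad():
                tensor.copy_(src_tensors[name].detach())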
def load_dataset(dataset_args, batch_size, rank, num_gpus, log):
    misc.log("Loading training set...", "white", log)

    dataset = dnnlib.util.construct_class_by_name(**dataset_args) # subclass of training.dataset.Dataset
    dataset_sampler = torch_misc.InfiniteSampler(dataset = dataset, rank = rank, num_replicas = num_gpus)
    dataset_iter = iter(torch.utils.data.DataLoader(dataset = dataset, sampler = dataset_sampler,
        batch_size = batch_size//num_gpus, **dataset_args.loader_args))

    misc.log(f"Num images: {misc.bcolored(len(dataset), 'blue')}", log = log)
    misc.log(f"Image shape: {misc.bcolored(dataset.image_shape, 'blue')}", log = log)
    misc.log(f"Label shape: {misc.bcolored(dataset.label_shape, 'blue')}", log = log)

    return dataset, dataset_iter
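# InfiniteSampler yields an endless, shuffled, rank-sharded stream of indices so the training
# loop can keep drawing batches without the DataLoader ever being exhausted. A simplified
# sketch of the idea (not the repo's exact implementation, which also reshuffles over time):
import numpy as np
import torch

class InfiniteSamplerSketch(torch.utils.data.Sampler):
    def __init__(self, dataset, rank = 0, num_replicas = 1, seed = 0):
        self.dataset_len = len(dataset)
        self.rank, self.num_replicas, self.seed = rank, num_replicas, seed

    def __iter__(self):
        order = np.random.RandomState(self.seed).permutation(self.dataset_len)
        idx = self.rank  # shard the stream across replicas
        while True:
            yield int(order[idx % self.dataset_len])
            idx += self.num_replicas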
def run_cmdline(argv):
    parser = argparse.ArgumentParser(prog = argv[0], description = "Download and prepare data for the GANformer.")
    parser.add_argument("--data-dir", help = "Directory of created dataset", default = "datasets", type = str)
    parser.add_argument("--max-images", help = "Maximum number of images to have in the dataset (optional).", default = None, type = int)

    # Default tasks
    parser.add_argument("--clevr", help = "Prepare the CLEVR dataset (6.41GB download, 100k images)", dest = "tasks", action = "append_const", const = "clevr")
    parser.add_argument("--bedrooms", help = "Prepare the LSUN-bedrooms dataset (42.8GB download, 3M images)", dest = "tasks", action = "append_const", const = "bedrooms")
    parser.add_argument("--ffhq", help = "Prepare the FFHQ dataset (13GB download, 70k images)", dest = "tasks", action = "append_const", const = "ffhq")
    parser.add_argument("--cityscapes", help = "Prepare the cityscapes dataset (1.8GB download, 25k images)", dest = "tasks", action = "append_const", const = "cityscapes")

    # Create a new task with custom images
    parser.add_argument("--task", help = "New dataset name", type = str, dest = "tasks", action = "append")
    parser.add_argument("--images-dir", help = "Provide source image directory/file to convert into png-directory dataset (saves varied image resolutions)", default = None, type = str)
    parser.add_argument("--format", help = "Images format", default = None, choices = ["png", "jpg", "npy", "hdf5", "tfds", "lmdb", "tfrecords"], type = str)
    parser.add_argument("--ratio", help = "Images height/width", default = 1.0, type = float)

    args = parser.parse_args()
    if not args.tasks:
        misc.error("No tasks specified. Please see '-h' for help.")
    if args.max_images is not None and args.max_images < 50000:
        misc.log(f"Warning: max-images is set to {args.max_images}. We recommend setting it at least to 50,000 " +
            "to allow statistically correct computation of the FID-50k metric.", "red")

    prepare(**vars(args))
def setup_savefile(args, run_name, run_dir, config):
    snapshot, kimg, resume = None, 0, False
    pkls = sorted(glob.glob(f"{run_dir}/network*.pkl"))
    # Load a particular snapshot if specified
    if args.pretrained_pkl is not None and args.pretrained_pkl != "None":
        # Soft links support
        if args.pretrained_pkl.startswith("gdrive"):
            if args.pretrained_pkl not in loader.pretrained_networks:
                misc.error(f"--pretrained_pkl {args.pretrained_pkl} not available in the catalog (see loader.pretrained_networks dict)")
            snapshot = args.pretrained_pkl
        else:
            snapshot = glob.glob(args.pretrained_pkl)[0]
            if os.path.islink(snapshot):
                snapshot = os.readlink(snapshot)

        # Extract training step from the snapshot filename if possible
        try:
            kimg = int(snapshot.split("-")[-1].split(".")[0])
        except ValueError:
            pass
    # Find the latest snapshot in the directory
    elif len(pkls) > 0:
        snapshot = pkls[-1]
        kimg = int(snapshot.split("-")[-1].split(".")[0])
        resume = True

    if snapshot:
        misc.log(f"Resuming {run_name}, from {snapshot}, kimg {kimg}", "white")
        config.resume_pkl = snapshot
        config.resume_kimg = kimg
    else:
        misc.log("Start model training from scratch", "white")
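# Snapshot files follow a network-snapshot-<kimg>.pkl naming convention, so the resumption
# point is recovered from the filename alone. For example (hypothetical path):
snapshot = "results/ffhq-000/network-snapshot-002048.pkl"
kimg = int(snapshot.split("-")[-1].split(".")[0])  # -> 2048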
def run(**args):
    args = EasyDict(args)
    train = EasyDict(run_func_name="training.training_loop.training_loop") # training loop options
    sched = EasyDict()                                 # TrainingSchedule options
    vis = EasyDict()                                   # visualize.eval() options
    grid = EasyDict(size="1080p", layout="random")     # setup_snapshot_img_grid() options
    sc = dnnlib.SubmitConfig()                         # dnnlib.submit_run() options

    # If the flag is specified without arguments (--arg), set to True
    for arg in ["summarize", "keep_samples", "style", "fused_modconv", "local_noise"]:
        if args[arg] is None:
            args[arg] = True

    if not args.train and not args.eval:
        misc.log("Warning: Neither --train nor --eval are provided. Therefore, we only print network shapes", "red")

    if args.gansformer_default:
        task = args.dataset
        pretrained = "gdrive:{}-snapshot.pkl".format(task)
        if pretrained not in pretrained_networks.gdrive_urls:
            pretrained = None

        nset(args, "recompile", pretrained is not None)
        nset(args, "pretrained_pkl", pretrained)
        nset(args, "mirror_augment", task in ["cityscapes", "ffhq"])

        nset(args, "transformer", True)
        nset(args, "components_num", {"clevr": 8}.get(task, 16))
        nset(args, "latent_size", {"clevr": 128}.get(task, 512))

        nset(args, "normalize", "layer")
        nset(args, "integration", "mul")
        nset(args, "kmeans", True)
        nset(args, "use_pos", True)
        nset(args, "mapping_ltnt2ltnt", task != "clevr")
        nset(args, "style", task != "clevr")

        nset(args, "g_arch", "resnet")
        nset(args, "mapping_resnet", True)

        gammas = {"ffhq": 10, "cityscapes": 20, "clevr": 40, "bedrooms": 100}
        nset(args, "gamma", gammas.get(task, 10))

    if args.baseline == "GAN":
        nset(args, "style", False)
        nset(args, "latent_stem", True)

    if args.baseline == "SAGAN":
        nset(args, "style", False)
        nset(args, "latent_stem", True)
        nset(args, "g_img2img", 5)

    if args.baseline == "kGAN":
        nset(args, "kgan", True)
        nset(args, "merge_layer", 5)
        nset(args, "merge_type", "softmax")
        nset(args, "components_num", 8)

    # Environment configuration
    tf_config = {
        "rnd.np_random_seed": 1000,
        "allow_soft_placement": True,
        "gpu_options.per_process_gpu_memory_fraction": 1.0
    }
    num_gpus = 1
    if args.gpus != "":
        num_gpus = len(args.gpus.split(","))
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
    assert num_gpus in [1, 2, 4, 8]
    sc.num_gpus = num_gpus

    # Networks configuration
    cG = set_net("G", reg_interval=4)
    cD = set_net("D", reg_interval=16)

    # Dataset configuration
    # For bedrooms, we choose the most common ratio in the
    # dataset and crop the other images into that ratio.
    ratios = {"clevr": 0.75, "bedrooms": 188 / 256, "cityscapes": 0.5, "ffhq": 1.0}
    args.ratio = ratios.get(args.dataset, args.ratio)
    dataset_args = EasyDict(tfrecord_dir=args.dataset, max_imgs=args.train_images_num, num_threads=args.num_threads)
    for arg in ["data_dir", "mirror_augment", "total_kimg", "ratio"]:
        cset(train, arg, args[arg])

    # Training and Optimizations configuration
    for arg in ["eval", "train", "recompile", "last_snapshots"]:
        cset(train, arg, args[arg])

    # Round to the closest multiple of minibatch size for validity
    args.batch_size -= args.batch_size % args.minibatch_size
    args.minibatch_std_size -= args.minibatch_std_size % args.minibatch_size
    args.latent_size -= args.latent_size % args.components_num
    if args.latent_size == 0:
        misc.error("--latent-size is too small. Must be a multiple of --components-num")

    sched_args = {
        "G_lrate": "g_lr",
        "D_lrate": "d_lr",
        "minibatch_size": "batch_size",
        "minibatch_gpu": "minibatch_size"
    }
    for arg, cmd_arg in sched_args.items():
        cset(sched, arg, args[cmd_arg])
    cset(train, "clip", args.clip)

    # Logging and metrics configuration
    metrics = [metric_defaults[x] for x in args.metrics]
    cset(cG.args, "truncation_psi", args.truncation_psi)
    for arg in ["keep_samples", "num_heads"]:
        cset(vis, arg, args[arg])
    for arg in ["summarize", "eval_images_num"]:
        cset(train, arg, args[arg])

    # Visualization
    args.vis_imgs = args.vis_images
    args.vis_ltnts = args.vis_latents
    vis_types = ["imgs", "ltnts", "maps", "layer_maps", "interpolations", "noise_var", "style_mix"]
    # Set of all the enabled visualization types
    vis.vis_types = {arg for arg in vis_types if args["vis_{}".format(arg)]}

    vis_args = {
        "attention": "transformer",
        "grid": "vis_grid",
        "num": "vis_num",
        "rich_num": "vis_rich_num",
        "section_size": "vis_section_size",
        "intrp_density": "interpolation_density",
        # "intrp_per_component": "interpolation_per_component",
        "alpha": "blending_alpha"
    }
    for arg, cmd_arg in vis_args.items():
        cset(vis, arg, args[cmd_arg])

    # Networks architecture
    cset(cG.args, "architecture", args.g_arch)
    cset(cD.args, "architecture", args.d_arch)
    cset(cG.args, "tanh", args.tanh)

    # Latent sizes
    if args.components_num > 1:
        if not (args.transformer or args.kgan):
            misc.error("--components-num > 1 but the model is not using components. " +
                "Either add --transformer for GANsformer or --kgan for k-GAN.")
        args.latent_size = int(args.latent_size / args.components_num)

    cD.args.latent_size = cG.args.latent_size = cG.args.dlatent_size = args.latent_size
    cset([cG.args, cD.args, vis], "components_num", args.components_num)

    # Mapping network
    for arg in ["layersnum", "lrmul", "dim", "resnet", "shared_dim"]:
        field = "mapping_{}".format(arg)
        cset(cG.args, field, args[field])

    # StyleGAN settings
    for arg in ["style", "latent_stem", "fused_modconv", "local_noise"]:
        cset(cG.args, arg, args[arg])
    cD.args.mbstd_group_size = args.minibatch_std_size

    # GANsformer
    cset(cG.args, "transformer", args.transformer)
    cset(cD.args, "transformer", args.d_transformer)

    args.norm = args.normalize
    for arg in ["norm", "integration", "ltnt_gate", "img_gate", "iterative", "kmeans",
                "kmeans_iters", "mapping_ltnt2ltnt"]:
        cset(cG.args, arg, args[arg])

    for arg in ["use_pos", "num_heads"]:
        cset([cG.args, cD.args], arg, args[arg])

    # Positional encoding
    for arg in ["dim", "init", "directions_num"]:
        field = "pos_{}".format(arg)
        cset([cG.args, cD.args], field, args[field])

    # k-GAN
    for arg in ["layer", "type", "same"]:
        field = "merge_{}".format(arg)
        cset(cG.args, field, args[field])
    cset([cG.args, train], "merge", args.kgan)
    if args.kgan and args.transformer:
        misc.error("Either have --transformer for GANsformer or --kgan for k-GAN, not both")

    # Attention
    for arg in ["start_res", "end_res", "ltnt2ltnt", "img2img"]: # , "local_attention"
        cset(cG.args, arg, args["g_{}".format(arg)])
        cset(cD.args, arg, args["d_{}".format(arg)])
    cset(cG.args, "img2ltnt", args.g_img2ltnt)
    # cset(cD.args, "ltnt2img", args.d_ltnt2img)

    # Mixing and dropout
    for arg in ["style_mixing", "component_mixing", "component_dropout", "attention_dropout"]:
        cset(cG.args, arg, args[arg])

    # Loss and regularization
    gloss_args = {
        "loss_type": "g_loss",
        "reg_weight": "g_reg_weight",
        # "pathreg": "pathreg",
    }
    dloss_args = {
        "loss_type": "d_loss",
        "reg_type": "d_reg",
        "gamma": "gamma"
    }
    for arg, cmd_arg in gloss_args.items():
        cset(cG.loss_args, arg, args[cmd_arg])
    for arg, cmd_arg in dloss_args.items():
        cset(cD.loss_args, arg, args[cmd_arg])

    ##### Experiments management:
    # Whenever we start a new experiment we store its result in a directory named 'args.expname:000'.
    # When we rerun a training or evaluation command it restores the model from that directory by default.
    # If we wish to restart the model training, we can set --restart and then we will store data in a new
    # directory: 'args.expname:001' after the first restart, then 'args.expname:002' after the second, etc.

    # Find the latest directory that matches the experiment
    exp_dir = sorted(glob.glob("{}/{}-*".format(args.result_dir, args.expname)))
    run_id = 0
    if len(exp_dir) > 0:
        run_id = int(exp_dir[-1].split("-")[-1])
    # If restart, then work over a new directory
    if args.restart:
        run_id += 1

    run_name = "{}-{:03d}".format(args.expname, run_id)
    train.printname = "{} ".format(misc.bold(args.expname))

    snapshot, kimg, resume = None, 0, False
    pkls = sorted(glob.glob("{}/{}/network*.pkl".format(args.result_dir, run_name)))
    # Load a particular snapshot if specified
    if args.pretrained_pkl is not None and args.pretrained_pkl != "None":
        # Soft links support
        if args.pretrained_pkl.startswith("gdrive"):
            if args.pretrained_pkl not in pretrained_networks.gdrive_urls:
                misc.error("--pretrained_pkl {} not available in the catalog (see pretrained_networks.py)".format(args.pretrained_pkl))
            snapshot = args.pretrained_pkl
        else:
            snapshot = glob.glob(args.pretrained_pkl)[0]
            if os.path.islink(snapshot):
                snapshot = os.readlink(snapshot)

        # Extract training step from the snapshot filename if possible
        try:
            kimg = int(snapshot.split("-")[-1].split(".")[0])
        except ValueError:
            pass
    # Find the latest snapshot in the directory
    elif len(pkls) > 0:
        snapshot = pkls[-1]
        kimg = int(snapshot.split("-")[-1].split(".")[0])
        resume = True

    if snapshot:
        misc.log("Resuming {}, from {}, kimg {}".format(run_name, snapshot, kimg), "white")
        train.resume_pkl = snapshot
        train.resume_kimg = kimg
    else:
        misc.log("Start model training from scratch", "white")

    # Run environment configuration
    sc.run_dir_root = args.result_dir
    sc.run_desc = args.expname
    sc.run_id = run_id
    sc.run_name = run_name
    sc.submit_target = dnnlib.SubmitTarget.LOCAL
    sc.local.do_not_copy_source_files = True

    kwargs = EasyDict(train)
    kwargs.update(cG=cG, cD=cD)
    kwargs.update(dataset_args=dataset_args, vis_args=vis, sched_args=sched,
        grid_args=grid, metric_arg_list=metrics, tf_config=tf_config)
    kwargs.submit_config = copy.deepcopy(sc)
    kwargs.resume = resume
    kwargs.load_config = args.reload

    dnnlib.submit_run(**kwargs)
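# nset and cset are small config helpers used throughout: nset fills a default only when the
# user didn't provide the flag, while cset copies a value into one or more config dicts only
# when it is set. A sketch of the intended semantics (the real helpers live in the repo's
# train script and may differ in details, e.g. inspecting sys.argv):
def nset(args, name, value):
    # Set a default: fill args[name] only if the user didn't provide it
    if args[name] is None:
        args[name] = value

def cset(targets, name, value):
    # Conditionally copy: write value into every target dict, but only if it is set
    if value is None:
        return
    if not isinstance(targets, list):
        targets = [targets]
    for target in targets:
        target[name] = value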
def training_loop(
    # General configuration
    train                  = False,     # Training mode
    eval                   = False,     # Evaluation mode
    vis                    = False,     # Visualization mode
    run_dir                = ".",       # Output directory
    num_gpus               = 1,         # Number of GPUs participating in the training
    rank                   = 0,         # Rank of the current process in [0, num_gpus)
    cG                     = {},        # Options for generator network
    cD                     = {},        # Options for discriminator network
    # Data
    dataset_args           = {},        # Options for training set
    drange_net             = [-1,1],    # Dynamic range used when feeding image data to the networks
    # Optimization
    loss_args              = {},        # Options for loss function
    total_kimg             = 25000,     # Total length of the training, measured in thousands of real images
    batch_size             = 4,         # Total batch size for one training iteration. Can be larger than batch_gpu * num_gpus
    batch_gpu              = 4,         # Number of samples processed at a time by one GPU
    ema_kimg               = 10.0,      # Half-life of the exponential moving average (EMA) of generator weights
    ema_rampup             = None,      # EMA ramp-up coefficient
    cudnn_benchmark        = True,      # Enable torch.backends.cudnn.benchmark?
    allow_tf32             = False,     # Enable torch.backends.cuda.matmul.allow_tf32 and torch.backends.cudnn.allow_tf32?
    # Logging
    resume_pkl             = None,      # Network pickle to resume training from
    resume_kimg            = 0.0,       # Assumed training progress at the beginning
                                        # Affects reporting and training schedule
    kimg_per_tick          = 8,         # Progress snapshot interval
    img_snapshot_ticks     = 3,         # How often to save image snapshots? None = disable
    network_snapshot_ticks = 3,         # How often to save network snapshots? None = disable
    last_snapshots         = 10,        # Maximal number of prior snapshots to save
    printname              = "",        # Experiment name for logging
    # Evaluation
    vis_args               = {},        # Options for vis.vis
    metrics                = [],        # Metrics to evaluate during training
    eval_images_num        = 50000,     # Sample size for the metrics
    truncation_psi         = 0.7        # Style strength multiplier for the truncation trick (used for visualizations only)
):
    # Initialize
    start_time = time.time()
    device = init_cuda(rank, cudnn_benchmark, allow_tf32)
    log = (rank == 0)

    dataset, dataset_iter = load_dataset(dataset_args, batch_size, rank, num_gpus, log) # Load training set

    nets = construct_nets(cG, cD, dataset, device, log) if train else None  # Construct networks
    G, D, Gs = load_nets(resume_pkl, nets, device, log)                     # Resume from existing pickle
    print_nets(G, D, batch_gpu, device, log)                                # Print network summary tables

    if eval:
        misc.log("Run evaluation...", log = log)
        evaluate(Gs, resume_pkl, metrics, eval_images_num, dataset_args, num_gpus, rank, device, log)

    if vis and log:
        misc.log("Produce visualizations...")
        visualize.vis(Gs, dataset, device, batch_gpu, drange_net = drange_net, ratio = dataset.ratio,
            truncation_psi = truncation_psi, **vis_args)

    if not train:
        exit()

    nets = distribute_nets(G, D, Gs, device, num_gpus, log)                                 # Distribute networks across GPUs
    loss, stages = setup_training_stages(loss_args, G, cG, D, cD, nets, device, log)        # Setup training stages (losses and optimizers)
    grid_size, grid_z, grid_c = init_img_grid(dataset, G.input_shape, device, run_dir, log) # Initialize an image grid
    logger = init_logger(run_dir, log)                                                      # Initialize logs

    # Train
    misc.log(f"Training for {total_kimg} kimg...", "white", log)
    cur_nimg, cur_tick, batch_idx = int(resume_kimg * 1000), 0, 0
    tick_start_nimg, tick_start_time = cur_nimg, time.time()
    stats = None

    while True:
        # Fetch training data
        real_img, real_c, gen_zs, gen_cs = fetch_data(dataset, dataset_iter, G.input_shape, drange_net,
            device, len(stages), batch_size, batch_gpu)

        # Execute training stages
        for stage, gen_z, gen_c in zip(stages, gen_zs, gen_cs):
            if batch_idx % stage.interval != 0:
                continue
            run_training_stage(loss, stage, device, real_img, real_c, gen_z, gen_c, batch_size, batch_gpu, num_gpus)

        # Update Gs
        update_ema_network(G, Gs, batch_size, cur_nimg, ema_kimg, ema_rampup)

        # Update state
        cur_nimg += batch_size
        batch_idx += 1

        # Perform maintenance tasks once per tick
        done = (cur_nimg >= total_kimg * 1000)
        if (not done) and (cur_tick != 0) and (cur_nimg < tick_start_nimg + kimg_per_tick * 1000):
            continue

        # Print status line and accumulate the info in logger.collector
        tick_end_time = time.time()
        if stats is not None:
            default = dnnlib.EasyDict({'mean': -1})
            fields = []
            fields.append("tick " + misc.bold(f"{training_stats.report0('Progress/tick', cur_tick):<5d}"))
            fields.append("kimg " + misc.bcolored(f"{training_stats.report0('Progress/kimg', cur_nimg / 1e3):<8.1f}", "red"))
            fields.append("")
            fields.append("loss/reg: G (" + misc.bcolored(f"{stats.get('Loss/G/loss', default).mean:>6.3f}", "blue"))
            fields.append(misc.bold(f"{stats.get('Loss/G/reg', default).mean:>6.3f}") + ")")
            fields.append("D (" + misc.bcolored(f"{stats.get('Loss/D/loss', default).mean:>6.3f}", "blue"))
            fields.append(misc.bold(f"{stats.get('Loss/D/reg', default).mean:>6.3f}") + ")")
            fields.append("")
            fields.append("time " + misc.bold(f"{dnnlib.util.format_time(training_stats.report0('Timing/total_sec', tick_end_time - start_time)):<12s}"))
            fields.append(f"sec/kimg {training_stats.report0('Timing/sec_per_kimg', (tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg) * 1e3):<7.2f}")
            fields.append(f"mem: CPU {training_stats.report0('Resources/cpu_mem_gb', psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}")
            fields.append(f"GPU {training_stats.report0('Resources/peak_gpu_mem_gb', torch.cuda.max_memory_allocated(device) / 2**30):<6.2f}")
            fields.append(misc.bold(printname))
            torch.cuda.reset_peak_memory_stats()
            training_stats.report0('Timing/total_hours', (tick_end_time - start_time) / (60 * 60))
            training_stats.report0('Timing/total_days', (tick_end_time - start_time) / (24 * 60 * 60))
            misc.log(" ".join(fields), log = log)

        # Save image snapshot
        if log and (img_snapshot_ticks is not None) and (done or cur_tick % img_snapshot_ticks == 0):
            visualize.vis(Gs, dataset, device, batch_gpu, training = True,
                step = cur_nimg // 1000, grid_size = grid_size, latents = grid_z,
                labels = grid_c, drange_net = drange_net, ratio = dataset.ratio, **vis_args)

        # Save network snapshot
        if (network_snapshot_ticks is not None) and (done or cur_tick % network_snapshot_ticks == 0):
            snapshot_data, snapshot_pkl = save_nets(G, D, Gs, cur_nimg, dataset_args, run_dir, num_gpus > 1, last_snapshots, log)

            # Evaluate metrics
            evaluate(snapshot_data["Gs"], snapshot_pkl, metrics, eval_images_num,
                dataset_args, num_gpus, rank, device, log, logger, run_dir)
            del snapshot_data

        # Collect stats and update logs
        stats = collect_stats(logger, stages)
        update_logger(logger, stats, cur_nimg, start_time)

        cur_tick += 1
        tick_start_nimg, tick_start_time = cur_nimg, time.time()
        maintenance_time = tick_start_time - tick_end_time
        if done:
            break

    # Done
    misc.log("Done!", "blue")
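# update_ema_network is not shown in this excerpt; the usual StyleGAN2-ADA formulation tracks
# Gs as an exponential moving average of G with half-life ema_kimg (in thousands of images)
# and an optional early ramp-up. A sketch under that assumption:
import torch

@torch.no_grad()
def update_ema_network_sketch(G, Gs, batch_size, cur_nimg, ema_kimg = 10.0, ema_rampup = None):
    # Half-life in images; optionally ramped up early in training
    ema_nimg = ema_kimg * 1000
    if ema_rampup is not None:
        ema_nimg = min(ema_nimg, cur_nimg * ema_rampup)
    ema_beta = 0.5 ** (batch_size / max(ema_nimg, 1e-8))
    # Lerp Gs parameters towards G; copy buffers verbatim
    for p_ema, p in zip(Gs.parameters(), G.parameters()):
        p_ema.copy_(p.lerp(p_ema, ema_beta))
    for b_ema, b in zip(Gs.buffers(), G.buffers()):
        b_ema.copy_(b)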
def run_cmdline(argv):
    parser = argparse.ArgumentParser(
        prog=argv[0],
        description="Download and prepare data for the GANsformer.")
    parser.add_argument("--data-dir", help="Directory of created dataset", default="datasets", type=str)
    parser.add_argument("--shards-num", help="Number of shards to split each dataset to (optional)", default=1, type=int)
    parser.add_argument("--max-images",
                        help="Maximum number of images to have in the dataset (optional). Use to reduce the produced tfrecords file size",
                        default=None, type=int)

    # Default tasks
    parser.add_argument("--clevr",
                        help="Prepare the CLEVR dataset (18GB download, up to 15.5GB tfrecords, 100k images)",
                        dest="tasks", action="append_const", const="clevr")
    parser.add_argument("--bedrooms",
                        help="Prepare the LSUN-bedrooms dataset (42.8GB, up to 480GB tfrecords, 3M images)",
                        dest="tasks", action="append_const", const="bedrooms")
    parser.add_argument("--ffhq",
                        help="Prepare the FFHQ dataset (13GB download, 13GB tfrecords, 70k images)",
                        dest="tasks", action="append_const", const="ffhq")
    parser.add_argument("--cityscapes",
                        help="Prepare the cityscapes dataset (1.8GB, 8GB tfrecords, 25k images)",
                        dest="tasks", action="append_const", const="cityscapes")

    # Create a new task with custom images
    parser.add_argument("--task", help="New dataset name", type=str, dest="tasks", action="append")
    parser.add_argument("--images-dir",
                        help="Provide source image directory to convert into tfrecords (will be searched recursively)",
                        default=None, type=str)
    parser.add_argument("--format", help="Images format", default=None,
                        choices=["png", "jpg", "npy", "hdf5", "tfds", "lmdb"], type=str)
    parser.add_argument("--ratio", help="Images height/width", default=1.0, type=float)

    args = parser.parse_args()
    if not args.tasks:
        misc.error("No tasks specified. Please see '-h' for help.")
    if args.max_images is not None and args.max_images < 50000:
        misc.log(
            "Warning: max-images is set to {}. We recommend setting it at least to 50,000 to allow "
            "statistically correct computation of the FID-50k metric.".format(args.max_images), "red")

    prepare(**vars(args))
def training_loop(
        # Configurations
        cG={}, cD={},                   # Generator and Discriminator command-line arguments
        dataset_args={},                # dataset.load_dataset() options
        sched_args={},                  # train.TrainingSchedule options
        vis_args={},                    # vis.eval options
        grid_args={},                   # train.setup_snapshot_img_grid() options
        metric_arg_list=[],             # MetricGroup options
        tf_config={},                   # tflib.init_tf() options
        eval=False,                     # Evaluation mode
        train=False,                    # Training mode
        # Data
        data_dir=None,                  # Directory to load datasets from
        total_kimg=25000,               # Total length of the training, measured in thousands of real images
        mirror_augment=False,           # Enable mirror augmentation?
        drange_net=[-1, 1],             # Dynamic range used when feeding image data to the networks
        ratio=1.0,                      # Image height/width ratio in the dataset
        # Optimization
        minibatch_repeats=4,            # Number of minibatches to run before adjusting training parameters
        lazy_regularization=True,       # Perform regularization as a separate training step?
        smoothing_kimg=10.0,            # Half-life of the running average of generator weights
        clip=None,                      # Clip gradients threshold
        # Resumption
        resume_pkl=None,                # Network pickle to resume training from, None = train from scratch
        resume_kimg=0.0,                # Assumed training progress at the beginning
                                        # Affects reporting and training schedule
        resume_time=0.0,                # Assumed wallclock time at the beginning, affects reporting
        recompile=False,                # Recompile network from source code (otherwise loads from snapshot)
        # Logging
        summarize=True,                 # Create TensorBoard summaries
        save_tf_graph=False,            # Include full TensorFlow computation graph in the tfevents file?
        save_weight_histograms=False,   # Include weight histograms in the tfevents file?
        img_snapshot_ticks=3,           # How often to save image snapshots? None = disable
        network_snapshot_ticks=3,       # How often to save network snapshots? None = only save networks-final.pkl
        last_snapshots=10,              # Maximal number of prior snapshots to save
        eval_images_num=50000,          # Sample size for the metrics
        printname="",                   # Experiment name for logging
        # Architecture
        merge=False):                   # Generate several images and then merge them

    # Initialize dnnlib and TensorFlow
    tflib.init_tf(tf_config)
    num_gpus = dnnlib.submit_config.num_gpus
    cG.name, cD.name = "g", "d"

    # Load dataset, configure training scheduler and metrics object
    dataset = data.load_dataset(data_dir=dnnlib.convert_path(data_dir), verbose=True, **dataset_args)
    sched = training_schedule(sched_args, cur_nimg=total_kimg * 1000, dataset=dataset)
    metrics = metric_base.MetricGroup(metric_arg_list)

    # Construct or load networks
    with tf.device("/gpu:0"):
        no_op = tf.no_op()
        G, D, Gs = None, None, None
        if resume_pkl is None or recompile:
            misc.log("Constructing networks...", "white")
            G = tflib.Network("G", num_channels=dataset.shape[0], resolution=dataset.shape[1],
                              label_size=dataset.label_size, **cG.args)
            D = tflib.Network("D", num_channels=dataset.shape[0], resolution=dataset.shape[1],
                              label_size=dataset.label_size, **cD.args)
            Gs = G.clone("Gs")
        if resume_pkl is not None:
            G, D, Gs = load_nets(resume_pkl, G, D, Gs, recompile)

    G.print_layers()
    D.print_layers()

    # Train/Evaluate/Visualize
    # Labels are optional but not essential
    grid_size, grid_reals, grid_labels = misc.setup_snapshot_img_grid(dataset, **grid_args)
    misc.save_img_grid(grid_reals, dnnlib.make_run_dir_path("reals.png"),
                       drange=dataset.dynamic_range, grid_size=grid_size)
    grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:])

    if eval:
        # Save a snapshot of the current network to evaluate
        pkl = dnnlib.make_run_dir_path("network-eval-snapshot-%06d.pkl" % resume_kimg)
        misc.save_pkl((G, D, Gs), pkl, remove=False)

        # Quantitative evaluation
        metric = metrics.run(pkl, num_imgs=eval_images_num, run_dir=dnnlib.make_run_dir_path(),
                             data_dir=dnnlib.convert_path(data_dir), num_gpus=num_gpus, ratio=ratio,
                             tf_config=tf_config, mirror_augment=mirror_augment)

        # Qualitative evaluation
        visualize.eval(G, dataset, batch_size=sched.minibatch_gpu,
                       drange_net=drange_net, ratio=ratio, **vis_args)

    if not train:
        dataset.close()
        exit()

    # Setup training inputs
    misc.log("Building TensorFlow graph...", "white")
    with tf.name_scope("Inputs"), tf.device("/cpu:0"):
        lrate_in_g = tf.placeholder(tf.float32, name="lrate_in_g", shape=[])
        lrate_in_d = tf.placeholder(tf.float32, name="lrate_in_d", shape=[])
        step = tf.placeholder(tf.int32, name="step", shape=[])
        minibatch_size_in = tf.placeholder(tf.int32, name="minibatch_size_in", shape=[])
        minibatch_gpu_in = tf.placeholder(tf.int32, name="minibatch_gpu_in", shape=[])
        minibatch_multiplier = minibatch_size_in // (minibatch_gpu_in * num_gpus)
        beta = 0.5**tf.div(tf.cast(minibatch_size_in, tf.float32),
                           smoothing_kimg * 1000.0) if smoothing_kimg > 0.0 else 0.0

    # Set optimizers
    for cN, lr in [(cG, lrate_in_g), (cD, lrate_in_d)]:
        set_optimizer(cN, lr, minibatch_multiplier, lazy_regularization, clip)

    # Build training graph for each GPU
    data_fetch_ops = []
    for gpu in range(num_gpus):
        with tf.name_scope("GPU%d" % gpu), tf.device("/gpu:%d" % gpu):

            # Create GPU-specific shadow copies of G and D
            for cN, N in [(cG, G), (cD, D)]:
                cN.gpu = N if gpu == 0 else N.clone(N.name + "_shadow")
            Gs_gpu = Gs if gpu == 0 else Gs.clone(Gs.name + "_shadow")

            # Fetch training data via temporary variables
            with tf.name_scope("DataFetch"):
                reals, labels = dataset.get_minibatch_tf()
                reals = process_reals(reals, dataset.dynamic_range, drange_net, mirror_augment)
                reals, reals_fetch = read_data(reals, "reals",
                                               [sched.minibatch_gpu] + dataset.shape, minibatch_gpu_in)
                labels, labels_fetch = read_data(labels, "labels",
                                                 [sched.minibatch_gpu, dataset.label_size], minibatch_gpu_in)
                data_fetch_ops += [reals_fetch, labels_fetch]

            # Evaluate loss functions
            with tf.name_scope("G_loss"):
                cG.loss, cG.reg = dnnlib.util.call_func_by_name(
                    G=cG.gpu, D=cD.gpu, dataset=dataset, reals=reals,
                    minibatch_size=minibatch_gpu_in, **cG.loss_args)
            with tf.name_scope("D_loss"):
                cD.loss, cD.reg = dnnlib.util.call_func_by_name(
                    G=cG.gpu, D=cD.gpu, dataset=dataset, reals=reals, labels=labels,
                    minibatch_size=minibatch_gpu_in, **cD.loss_args)

            for cN in [cG, cD]:
                set_optimizer_ops(cN, lazy_regularization, no_op)

    # Setup training ops
    data_fetch_op = tf.group(*data_fetch_ops)
    for cN in [cG, cD]:
        cN.train_op = cN.opt.apply_updates()
        cN.reg_op = cN.reg_opt.apply_updates(allow_no_op=True)
    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=beta)

    # Finalize graph
    with tf.device("/gpu:0"):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)
    tflib.init_uninitialized_vars()

    # TensorBoard summaries
    if summarize:
        misc.log("Initializing logs...", "white")
        summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path())
        if save_tf_graph:
            summary_log.add_graph(tf.get_default_graph())
        if save_weight_histograms:
            G.setup_weight_histograms()
            D.setup_weight_histograms()

    # Initialize training
    misc.log("Training for %d kimg..." % total_kimg, "white")
    dnnlib.RunContext.get().update("", cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = dnnlib.RunContext.get().get_last_update_interval()

    cur_tick, running_mb_counter = -1, 0
    cur_nimg = int(resume_kimg * 1000)
    tick_start_nimg = cur_nimg
    for cN in [cG, cD]:
        cN.lossvals_agg = {k: None for k in ["loss", "reg", "norm", "reg_norm"]}
        cN.opt.reset_optimizer_state()

    # Training loop
    while cur_nimg < total_kimg * 1000:
        if dnnlib.RunContext.get().should_stop():
            break

        # Choose training parameters and configure training ops
        sched = training_schedule(sched_args, cur_nimg=cur_nimg, dataset=dataset)
        assert sched.minibatch_size % (sched.minibatch_gpu * num_gpus) == 0
        dataset.configure(sched.minibatch_gpu)

        # Run training ops
        feed_dict = {
            lrate_in_g: sched.G_lrate,
            lrate_in_d: sched.D_lrate,
            minibatch_size_in: sched.minibatch_size,
            minibatch_gpu_in: sched.minibatch_gpu,
            step: sched.kimg
        }

        # Several iterations before updating training parameters
        for _repeat in range(minibatch_repeats):
            rounds = range(0, sched.minibatch_size, sched.minibatch_gpu * num_gpus)
            for cN in [cG, cD]:
                cN.run_reg = lazy_regularization and (running_mb_counter % cN.reg_interval == 0)
            cur_nimg += sched.minibatch_size
            running_mb_counter += 1

            for cN in [cG, cD]:
                cN.lossvals = {k: None for k in ["loss", "reg", "norm", "reg_norm"]}

            # Gradient accumulation
            for _round in rounds:
                cG.lossvals.update(tflib.run([cG.train_op, cG.ops], feed_dict)[1])
                if cG.run_reg:
                    _, cG.lossvals["reg_norm"] = tflib.run([cG.reg_op, cG.reg_norm], feed_dict)

                tflib.run(data_fetch_op, feed_dict)

                cD.lossvals.update(tflib.run([cD.train_op, cD.ops], feed_dict)[1])
                if cD.run_reg:
                    _, cD.lossvals["reg_norm"] = tflib.run([cD.reg_op, cD.reg_norm], feed_dict)

            tflib.run([Gs_update_op], feed_dict)

            # Track loss statistics
            for cN in [cG, cD]:
                for k in cN.lossvals_agg:
                    cN.lossvals_agg[k] = emaAvg(cN.lossvals_agg[k], cN.lossvals[k])

        # Perform maintenance tasks once per tick
        done = (cur_nimg >= total_kimg * 1000)
        if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = dnnlib.RunContext.get().get_time_since_last_update()
            total_time = dnnlib.RunContext.get().get_time_since_start() + resume_time

            # Report progress
            print(("tick %s kimg %s loss/reg: G (%s %s) D (%s %s) grad norms: G (%s %s) D (%s %s) " +
                   "time %s sec/kimg %s maxGPU %sGB %s") %
                  (misc.bold("%-5d" % autosummary("Progress/tick", cur_tick)),
                   misc.bcolored("{:>8.1f}".format(autosummary("Progress/kimg", cur_nimg / 1000.0)), "red"),
                   misc.bcolored("{:>6.3f}".format(cG.lossvals_agg["loss"] or 0), "blue"),
                   misc.bold("{:>6.3f}".format(cG.lossvals_agg["reg"] or 0)),
                   misc.bcolored("{:>6.3f}".format(cD.lossvals_agg["loss"] or 0), "blue"),
                   misc.bold("{:>6.3f}".format(cD.lossvals_agg["reg"] or 0)),
                   misc.cond_bcolored(cG.lossvals_agg["norm"], 20.0, "red"),
                   misc.cond_bcolored(cG.lossvals_agg["reg_norm"], 20.0, "red"),
                   misc.cond_bcolored(cD.lossvals_agg["norm"], 20.0, "red"),
                   misc.cond_bcolored(cD.lossvals_agg["reg_norm"], 20.0, "red"),
                   misc.bold("%-10s" % dnnlib.util.format_time(autosummary("Timing/total_sec", total_time))),
                   "{:>7.2f}".format(autosummary("Timing/sec_per_kimg", tick_time / tick_kimg)),
                   "{:>4.1f}".format(autosummary("Resources/peak_gpu_mem_gb", peak_gpu_mem_op.eval() / 2**30)),
                   printname))
            autosummary("Timing/total_hours", total_time / (60.0 * 60.0))
            autosummary("Timing/total_days", total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots
            if img_snapshot_ticks is not None and (cur_tick % img_snapshot_ticks == 0 or done):
                visualize.eval(G, dataset, batch_size=sched.minibatch_gpu, training=True,
                               step=cur_nimg // 1000, grid_size=grid_size, latents=grid_latents,
                               labels=grid_labels, drange_net=drange_net, ratio=ratio, **vis_args)
            if network_snapshot_ticks is not None and (cur_tick % network_snapshot_ticks == 0 or done):
                pkl = dnnlib.make_run_dir_path("network-snapshot-%06d.pkl" % (cur_nimg // 1000))
                misc.save_pkl((G, D, Gs), pkl, remove=False)

                if cur_tick % network_snapshot_ticks == 0 or done:
                    metric = metrics.run(pkl, num_imgs=eval_images_num, run_dir=dnnlib.make_run_dir_path(),
                                         data_dir=dnnlib.convert_path(data_dir), num_gpus=num_gpus, ratio=ratio,
                                         tf_config=tf_config, mirror_augment=mirror_augment)

                if last_snapshots > 0:
                    misc.rm(sorted(glob.glob(dnnlib.make_run_dir_path("network*.pkl")))[:-last_snapshots])

            # Update summaries and RunContext
            if summarize:
                metrics.update_autosummaries()
                tflib.autosummary.save_summaries(summary_log, cur_nimg)
            dnnlib.RunContext.get().update(None, cur_epoch=cur_nimg // 1000, max_epoch=total_kimg)
            maintenance_time = dnnlib.RunContext.get().get_last_update_interval() - tick_time

    # Save final snapshot
    misc.save_pkl((G, D, Gs), dnnlib.make_run_dir_path("network-final.pkl"), remove=False)

    # All done
    if summarize:
        summary_log.close()
    dataset.close()
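# The rounds range above implements gradient accumulation: when the scheduled minibatch_size
# exceeds what fits across the GPUs (minibatch_gpu * num_gpus), the loop runs several
# forward/backward rounds before a single parameter update. A small numeric sketch of how
# the rounds are derived, with hypothetical schedule values:
minibatch_size = 64   # total samples per optimizer update
minibatch_gpu = 8     # samples per GPU per round
num_gpus = 2

rounds = range(0, minibatch_size, minibatch_gpu * num_gpus)
print(len(rounds))    # -> 4 accumulation rounds per update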
def vis(G,
    dataset,                     # The dataset object for accessing the data
    device,                      # Device to run visualization on
    batch_size,                  # Visualization batch size
    run_dir = ".",               # Output directory
    training = False,            # Training mode
    latents = None,              # Source latents to generate images from
    labels = None,               # Source labels to generate images from (0 if no labels are used)
    ratio = 1.0,                 # Image height/width ratio in the dataset
    truncation_psi = 0.7,        # Style strength multiplier for the truncation trick (used for visualizations only)
    # Model settings
    k = 1,                       # Number of components the model has
    drange_net = [-1,1],         # Model image output range
    attention = False,           # Whether the model produces attention maps (for visualization)
    num_heads = 1,               # Number of attention heads
    # Visualization settings
    vis_types = None,            # Visualization types to be created
    num = 100,                   # Number of produced samples
    rich_num = 5,                # Number of samples for which richer visualizations will be created
                                 # (requires more memory and disk space, and therefore rich_num <= num)
    grid = None,                 # Whether to save the samples in one large grid file
                                 # or in separate files, one per sample
    grid_size = None,            # Grid proportions (w, h)
    step = None,                 # Step number to be used in visualization filenames
    verbose = None,              # Verbosely print progress messages
    keep_samples = True,         # Keep all prior samples during training
    # Visualization-specific settings
    alpha = 0.3,                 # Proportion for generated images and attention maps blends
    intrp_density = 8,           # Number of samples in between two end points of an interpolation
    intrp_per_component = False, # Whether to perform interpolation along particular latent components (True)
                                 # or all of them at once (False)
    noise_samples_num = 100,     # Number of samples used to compute noise variation visualization
    section_size = 100):         # Visualization section size (section_size <= num) for reducing memory footprint

    def prefix(step): return "" if step is None else f"{step:06d}_"
    def pattern_of(dir, step, suffix): return f"{run_dir}/visuals/{dir}/{prefix(step)}%06d.{suffix}"

    # Set default options
    if verbose is None: verbose = not training # Disable verbose during training
    if grid is None: grid = training # Save image samples in one grid file during training
    if grid_size is not None: section_size = rich_num = num = np.prod(grid_size) # If grid size is provided, set images number accordingly

    _labels, _latents = labels, latents
    if _latents is not None: assert num == _latents.shape[0]
    if _labels is not None: assert num == _labels.shape[0]
    assert rich_num <= section_size

    vis = vis_types
    # For time efficiency, during training save only image and map samples rather than richer visualizations
    if training:
        vis = {"imgs"} # , "maps"
        # if num_heads == 1: vis.add("layer_maps")
    else:
        vis = vis or {"imgs", "maps", "ltnts", "interpolations", "noise_var"}

    # Build utility functions
    save_images = misc.save_images_builder(drange_net, ratio, grid_size, grid, verbose)
    save_blends = misc.save_blends_builder(drange_net, ratio, grid_size, grid, verbose, alpha)
    crange = trange if verbose else range
    section_of = lambda a, i, n: a[i * n: (i + 1) * n]
    get_rnd_latents = lambda n: torch.randn([n, *G.input_shape[1:]], device = device)
    get_rnd_labels = lambda n: torch.from_numpy(dataset.get_random_labels(n)).to(device)

    # Create directories
    dirs = []
    if "imgs" in vis:           dirs += ["images"]
    if "ltnts" in vis:          dirs += ["latents-z", "latents-w"]
    if "maps" in vis:           dirs += ["maps", "softmaps", "blends", "softblends"]
    if "layer_maps" in vis:     dirs += ["layer_maps"]
    if "interpolations" in vis: dirs += ["interpolations-z", "interpolations-w"]

    if not keep_samples:
        shutil.rmtree(f"{run_dir}/visuals")
    for dir in dirs:
        os.makedirs(f"{run_dir}/visuals/{dir}", exist_ok = True)

    if verbose:
        print("Running network and saving samples...")

    # Produce visualizations
    for idx in crange(0, num, section_size):
        curr_size = curr_section_size(num, idx, section_size)

        # Compute source latents/labels that images will be produced from
        latents = get_rnd_latents(curr_size) if _latents is None else section_of(_latents, idx, section_size)
        labels = get_rnd_labels(curr_size) if _labels is None else section_of(_labels, idx, section_size)
        if idx == 0:
            latents0, labels0 = latents, labels

        # Run network over latents and produce images and attention maps
        ret = run(G, latents, labels, batch_size, truncation_psi, noise_mode = "const",
            return_att = True, return_ws = True)
        # For memory efficiency, save full information only for a small amount of images
        images, attmaps_all_layers, wlatents_all_layers = ret
        soft_maps = attmaps_all_layers[:,:,-1,0] if attention else None
        attmaps_all_layers = attmaps_all_layers[:rich_num]
        wlatents = wlatents_all_layers[:,:,0]

        # Save image samples
        if "imgs" in vis:
            save_images(images, pattern_of("images", step, "png"), idx)

        # Save latent vectors
        if "ltnts" in vis:
            misc.save_npys(latents, pattern_of("latents-z", step, "npy"), verbose, idx)
            misc.save_npys(wlatents, pattern_of("latents-w", step, "npy"), verbose, idx)

        # For the GANformer model, save attention maps
        if attention:
            # Color palette shared by the maps and layer_maps visualizations
            pallete = np.expand_dims(misc.get_colors(k - 1), axis = [2, 3])

            if "maps" in vis:
                maps = (soft_maps == np.amax(soft_maps, axis = 1, keepdims = True)).astype(float)
                soft_maps = np.sum(pallete * np.expand_dims(soft_maps, axis = 2), axis = 1)
                maps = np.sum(pallete * np.expand_dims(maps, axis = 2), axis = 1)

                save_images(soft_maps, pattern_of("softmaps", step, "png"), idx)
                save_images(maps, pattern_of("maps", step, "png"), idx)

                save_blends(soft_maps, images, pattern_of("softblends", step, "png"), idx)
                save_blends(maps, images, pattern_of("blends", step, "png"), idx)

            # Save maps from all attention heads and layers
            # (for efficiency, only for a small number of images)
            if "layer_maps" in vis:
                all_maps = []
                maps_fakes = np.split(attmaps_all_layers, attmaps_all_layers.shape[2], axis = 2)
                for layer, lmaps in enumerate(maps_fakes):
                    lmaps = np.split(np.squeeze(lmaps, axis = 2), lmaps.shape[2], axis = 2)
                    for head, hmap in enumerate(lmaps):
                        hmap = (hmap == np.amax(hmap, axis = 1, keepdims = True)).astype(float)
                        hmap = np.sum(pallete * hmap, axis = 1)
                        all_maps.append((hmap, f"l{layer}_h{head}"))

                stepdir = "" if step is None else f"/{step:06d}"
                if not grid:
                    for i in range(rich_num):
                        os.makedirs(f"{run_dir}/visuals/layer_maps/%06d" % i + stepdir, exist_ok = True)
                for maps, name in all_maps:
                    if grid:
                        pattern = f"{run_dir}/visuals/layer_maps/{prefix(step)}%06d-{name}.png"
                    else:
                        pattern = f"{run_dir}/visuals/layer_maps/%06d{stepdir}/{name}.png"
                    save_images(maps, pattern, idx)

    # Produce interpolations between pairs of source latents
    # In the GANformer case, varying one component at a time
    if "interpolations" in vis:
        ts = torch.linspace(0.0, 1.0, steps = intrp_density)

        if verbose:
            print("Generating interpolations...")
        for i in crange(rich_num):
            os.makedirs(f"{run_dir}/visuals/interpolations-z/%06d" % i, exist_ok = True)
            os.makedirs(f"{run_dir}/visuals/interpolations-w/%06d" % i, exist_ok = True)

            z = get_rnd_latents(2)
            z[0] = latents0[i]
            c = labels0[i:i+1]
            w = run(G, z, c, batch_size, truncation_psi, noise_mode = "const", return_ws = True)[-1]

            def update(t, fn, ts, dim):
                if dim == 3:
                    ts = ts[:, None]
                t_ups = []

                if intrp_per_component:
                    for c in range(k - 1):
                        # Copy over all the components except component c, which gets interpolated
                        t_up = torch.clone(t[0]).unsqueeze(0).repeat((intrp_density, ) + tuple([1] * dim))
                        # Interpolate component c
                        t_up[:,c] = fn(t[0, c], t[1, c], ts)
                        t_ups.append(t_up)
                    t_up = torch.cat(t_ups, dim = 0)
                else:
                    t_up = fn(t[0], t[1], ts.unsqueeze(1))

                return t_up

            z_up = update(z, slerp, ts, 2)
            w_up = update(w, lerp, ts, 3)

            imgs1 = run(G, z_up, c, batch_size, truncation_psi, noise_mode = "const")[0]
            imgs2 = run(G, w_up, c, batch_size, truncation_psi, noise_mode = "const", take_w = True)[0]

            def save_interpolation(imgs, name):
                imgs = np.split(imgs, k - 1, axis = 0)
                for c in range(k - 1):
                    filename = f"{run_dir}/visuals/interpolations-{name}/{i:06d}/{c:02d}"
                    imgs[c] = [misc.to_pil(img, drange = drange_net) for img in imgs[c]]
                    imgs[c][-1].save(f"{filename}.png")
                    misc.save_gif(imgs[c], f"{filename}.gif")

            save_interpolation(imgs1, "z")
            save_interpolation(imgs2, "w")

    # Compute noise variance map
    # Shows what areas vary the most given fixed source
    # latents due to the use of stochastic local noise
    if "noise_var" in vis:
        if verbose:
            print("Generating noise variance...")
        z = get_rnd_latents(1).repeat(noise_samples_num, 1, 1)
        c = get_rnd_labels(1)
        imgs = run(G, z, c, batch_size, truncation_psi)[0]
        imgs = np.stack([misc.to_pil(img, drange = drange_net) for img in imgs], axis = 0)
        diff = np.std(np.mean(imgs, axis = 3), axis = 0) * 4
        diff = np.clip(diff + 0.5, 0, 255).astype(np.uint8)
        PIL.Image.fromarray(diff, "L").save(f"{run_dir}/visuals/noise-variance.png")

    # Compute style mixing table, using latent A in some of the layers and latent B in the rest.
    # For the GANformer, also produce component mixes (using latents from A in some of the components,
    # and latents from B in the rest).
    if "style_mix" in vis:
        if verbose:
            print("Generating style mixes...")
        cols, rows = 4, 2
        row_lens = np.array([2, 5, 8, 11])
        c = get_rnd_labels(1)

        # Create latent mixes
        mixes = {
            "layer": (np.arange(wlatents_all_layers.shape[2]) < row_lens[:,None]).astype(np.float32)[:,None,None,None,:,None],
            "component": (np.arange(wlatents_all_layers.shape[1]) < row_lens[:,None]).astype(np.float32)[:,None,None,:,None,None]
        }
        ws = wlatents_all_layers[:cols+rows]
        orig_imgs = images[:cols+rows]
        col_z = wlatents_all_layers[:cols][None, None]
        row_z = wlatents_all_layers[cols:cols+rows][None,:,None]

        for name, mix in mixes.items():
            # Produce image mixes
            mix_z = mix * row_z + (1 - mix) * col_z
            mix_z = torch.from_numpy(np.reshape(mix_z, [-1, *wlatents_all_layers.shape[1:]])).to(device)
            mix_imgs = run(G, mix_z, c, batch_size, truncation_psi, noise_mode = "const", take_w = True)[0]
            mix_imgs = np.reshape(mix_imgs, [len(row_lens) * rows, cols, *mix_imgs.shape[1:]])

            # Create image table canvas
            H, W = mix_imgs.shape[-2:]
            canvas = PIL.Image.new("RGB", (W * (cols + 1), H * (len(row_lens) * rows + 1)), "black")

            # Place image mixes respectively at each position (row_idx, col_idx)
            for row_idx, row_elem in enumerate([None] + list(range(len(row_lens) * rows))):
                for col_idx, col_elem in enumerate([None] + list(range(cols))):
                    if (row_elem, col_elem) == (None, None):
                        continue
                    if row_elem is None:
                        img = orig_imgs[col_elem]
                    elif col_elem is None:
                        img = orig_imgs[cols + (row_elem % rows)]
                    else:
                        img = mix_imgs[row_elem, col_elem]
                    canvas.paste(misc.to_pil(img, drange = drange_net), (W * col_idx, H * row_idx))

            canvas.save(f"{run_dir}/visuals/{name}-mixing.png")

    if verbose:
        misc.log("Visualizations Completed!", "blue")
def setup_config(run_dir, **args):
    args = EasyDict(args)               # command-line options
    train = EasyDict(run_dir=run_dir)   # training loop options
    vis = EasyDict(run_dir=run_dir)     # visualization loop options

    if args.reload:
        config_fn = os.path.join(run_dir, "training_options.json")
        if os.path.exists(config_fn):
            # Load config from the experiment's existing file (and so ignore command-line arguments)
            with open(config_fn, "rt") as f:
                config = json.load(f)
            return config
        misc.log(
            f"Warning: --reload is set for a new experiment {args.expname},"
            f" but the configuration file to reload from {config_fn} doesn't exist.", "red")

    # GANformer and baselines default settings
    # ----------------------------------------------------------------------------
    if args.ganformer_default:
        task = args.dataset
        nset(args, "mirror_augment", task in ["cityscapes", "ffhq"])

        nset(args, "transformer", True)
        nset(args, "components_num", {"clevr": 8}.get(task, 16))
        nset(args, "latent_size", {"clevr": 128}.get(task, 512))

        nset(args, "normalize", "layer")
        nset(args, "integration", "mul")
        nset(args, "kmeans", True)
        nset(args, "use_pos", True)
        nset(args, "mapping_ltnt2ltnt", task != "clevr")
        nset(args, "style", task != "clevr")

        nset(args, "g_arch", "resnet")
        nset(args, "mapping_resnet", True)

        gammas = {"ffhq": 10, "cityscapes": 20, "clevr": 40, "bedrooms": 100}
        nset(args, "gamma", gammas.get(task, 10))

    if args.baseline == "GAN":
        nset(args, "style", False)
        nset(args, "latent_stem", True)

    ## k-GAN and SAGAN are not currently supported in the pytorch version.
    ## See the TF version for implementation of these baselines!
    # if args.baseline == "SAGAN":
    #     nset(args, "style", False)
    #     nset(args, "latent_stem", True)
    #     nset(args, "g_img2img", 5)

    # if args.baseline == "kGAN":
    #     nset(args, "kgan", True)
    #     nset(args, "merge_layer", 5)
    #     nset(args, "merge_type", "softmax")
    #     nset(args, "components_num", 8)

    # General setup
    # ----------------------------------------------------------------------------
    # If the flag is specified without arguments (--arg), set to True
    for arg in ["cuda_bench", "allow_tf32", "keep_samples", "style", "local_noise"]:
        if args[arg] is None:
            args[arg] = True

    if not any([args.train, args.eval, args.vis]):
        misc.log(
            "Warning: None of --train, --eval or --vis are provided. Therefore, we only print network shapes",
            "red")

    for arg in ["train", "eval", "vis", "last_snapshots"]:
        cset(train, arg, args[arg])

    num_gpus = 1
    if args.gpus != "":
        num_gpus = len(args.gpus.split(","))
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
    if not (num_gpus >= 1 and num_gpus & (num_gpus - 1) == 0):
        misc.error("Number of GPUs must be a power of two")
    args.num_gpus = num_gpus

    # CUDA settings
    for arg in ["batch_size", "batch_gpu", "allow_tf32"]:
        cset(train, arg, args[arg])
    cset(train, "cudnn_benchmark", args.cuda_bench)

    # Data setup
    # ----------------------------------------------------------------------------
    # For bedrooms, we choose the most common ratio in the
    # dataset and crop the other images into that ratio.
    ratios = {
        "clevr": 0.75,
        "bedrooms": 188 / 256,
        "cityscapes": 0.5,
        "ffhq": 1.0
    }
    args.ratio = args.ratio or ratios.get(args.dataset, 1.0)
    args.crop_ratio = 0.5 if args.resolution > 256 and args.ratio < 0.5 else None

    args.printname = args.expname
    for arg in ["total_kimg", "printname"]:
        cset(train, arg, args[arg])

    dataset_args = EasyDict(class_name="training.dataset.ImageFolderDataset",
                            path=f"{args.data_dir}/{args.dataset}",
                            max_items=args.train_images_num,
                            resolution=args.resolution,
                            ratio=args.ratio,
                            mirror_augment=args.mirror_augment)
    dataset_args.loader_args = EasyDict(num_workers=args.num_threads,
                                        pin_memory=True,
                                        prefetch_factor=2)

    # Optimization setup
    # ----------------------------------------------------------------------------
    cG = set_net("Generator", ["mapping", "synthesis"], args.g_lr, 4)
    cD = set_net("Discriminator", ["mapping", "block", "epilogue"], args.d_lr, 16)
    cset([cG, cD], "crop_ratio", args.crop_ratio)

    mbstd = min(args.batch_gpu, 4)  # other hyperparams behave more predictably if mbstd group size remains fixed
    cset(cD.epilogue_kwargs, "mbstd_group_size", mbstd)

    # Automatic tuning
    if args.autotune:
        batch_size = max(min(args.num_gpus * min(4096 // args.resolution, 32), 64),
                         args.num_gpus)  # keep gpu memory consumption at bay
        nset(args, "batch_size", batch_size)
        batch_gpu = args.batch_size // args.num_gpus
        nset(args, "batch_gpu", batch_gpu)

        fmap_decay = 1 if args.resolution >= 512 else 0.5
        lr = 0.002 if args.resolution >= 1024 else 0.0025
        gamma = 0.0002 * (args.resolution**2) / args.batch_size  # heuristic formula

        cset([cG.synthesis_kwargs, cD], "dim_base", int(fmap_decay * 32768))
        nset(args, "g_lr", lr)
        cset(cG.opt_args, "lr", args.g_lr)
        nset(args, "d_lr", lr)
        cset(cD.opt_args, "lr", args.d_lr)
        nset(args, "gamma", gamma)

        train.ema_rampup = 0.05
        train.ema_kimg = batch_size * 10 / 32

    if args.batch_size % (args.batch_gpu * args.num_gpus) != 0:
        misc.error("--batch-size should be divisible by --batch-gpu * 'num_gpus'")

    # Loss and regularization settings
    loss_args = EasyDict(class_name="training.loss.StyleGAN2Loss",
                         g_loss=args.g_loss,
                         d_loss=args.d_loss,
                         r1_gamma=args.gamma,
                         pl_weight=args.pl_weight)

    # if args.fp16:
    #     cset([cG.synthesis_kwargs, cD], "num_fp16_layers", 4)  # enable mixed-precision training
    #     cset([cG.synthesis_kwargs, cD], "conv_clamp", 256)     # clamp activations to avoid float16 overflow
    #     cset([cG.synthesis_kwargs, cD.block_args], "fp16_channels_last", args.nhwc)

    # Evaluation and visualization
    # ----------------------------------------------------------------------------
    from metrics import metric_main
    for metric in args.metrics:
        if not metric_main.is_valid_metric(metric):
            misc.error(f"Unknown metric: {metric}. The valid metrics are: {metric_main.list_valid_metrics()}")

    for arg in ["num_gpus", "metrics", "eval_images_num", "truncation_psi"]:
        cset(train, arg, args[arg])
    for arg in ["keep_samples", "num_heads"]:
        cset(vis, arg, args[arg])

    args.vis_imgs = args.vis_images
    args.vis_ltnts = args.vis_latents
    vis_types = ["imgs", "ltnts", "maps", "layer_maps", "interpolations", "noise_var", "style_mix"]
    # Set of all the enabled visualization types
    vis.vis_types = list({arg for arg in vis_types if args[f"vis_{arg}"]})

    vis_args = {
        "attention": "transformer",
        "grid": "vis_grid",
        "num": "vis_num",
        "rich_num": "vis_rich_num",
        "section_size": "vis_section_size",
        "intrp_density": "interpolation_density",
        # "intrp_per_component": "interpolation_per_component",
        "alpha": "blending_alpha"
    }
    for arg, cmd_arg in vis_args.items():
        cset(vis, arg, args[cmd_arg])

    # Networks setup
    # ----------------------------------------------------------------------------
    # Networks architecture
    cset(cG.synthesis_kwargs, "architecture", args.g_arch)
    cset(cD, "architecture", args.d_arch)

    # Latent sizes
    if args.components_num > 0:
        if not args.transformer:  # or args.kgan
            misc.error("--components-num > 0 but the model is not using components. " +
                       "Add --transformer for GANformer (which uses latent components).")
        if args.latent_size % args.components_num != 0:
            misc.error(f"--latent-size ({args.latent_size}) should be divisible by --components-num ({args.components_num})")
        args.latent_size = int(args.latent_size / args.components_num)

    cG.z_dim = cG.w_dim = args.latent_size
    cset([cG, vis], "k", args.components_num + 1)  # We add a component to modulate features globally

    # Mapping network
    args.mapping_layer_dim = args.mapping_dim
    for arg in ["num_layers", "layer_dim", "resnet", "shared", "ltnt2ltnt"]:
        field = f"mapping_{arg}"
        cset(cG.mapping_kwargs, arg, args[field])

    # StyleGAN settings
    for arg in ["style", "latent_stem", "local_noise"]:
        cset(cG.synthesis_kwargs, arg, args[arg])

    # GANformer
    cset([cG.synthesis_kwargs, cG.mapping_kwargs], "transformer", args.transformer)

    # Attention related settings
    for arg in ["use_pos", "num_heads", "ltnt_gate", "attention_dropout"]:
        cset([cG.mapping_kwargs, cG.synthesis_kwargs], arg, args[arg])

    # Attention types and layers
    for arg in ["start_res", "end_res"]:  # , "local_attention", "ltnt2ltnt", "img2img", "img2ltnt"
        cset(cG.synthesis_kwargs, arg, args[f"g_{arg}"])

    # Mixing and dropout
    for arg in ["style_mixing", "component_mixing"]:
        cset(loss_args, arg, args[arg])
    cset(cG, "component_dropout", args["component_dropout"])

    # Extra transformer options
    args.norm = args.normalize
    for arg in ["norm", "integration", "img_gate", "iterative", "kmeans", "kmeans_iters"]:
        cset(cG.synthesis_kwargs, arg, args[arg])

    # Positional encoding
    # args.pos_dim = args.pos_dim or args.latent_size
    for arg in ["dim", "type", "init", "directions_num"]:
        field = f"pos_{arg}"
        cset(cG.synthesis_kwargs, field, args[field])

    # k-GAN (not currently supported in the pytorch version)
    # for arg in ["layer", "type", "same"]:
    #     field = "merge_{}".format(arg)
    #     cset(cG.args, field, args[field])
    # cset(cG.synthesis_kwargs, "merge", args.kgan)
    # if args.kgan and args.transformer:
    #     misc.error("Either have --transformer for GANformer or --kgan for k-GAN, not both")

    config = EasyDict(train)
    config.update(cG=cG, cD=cD, loss_args=loss_args, dataset_args=dataset_args, vis_args=vis)

    # Save config file
    with open(os.path.join(run_dir, "training_options.json"), "wt") as f:
        json.dump(config, f, indent=2)

    return config
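# The latent bookkeeping above divides the overall latent budget evenly across components
# and adds one extra component that modulates features globally. A worked example using the
# CLEVR defaults from this config (components_num = 8, latent_size = 128):
components_num = 8
latent_size = 128

assert latent_size % components_num == 0
per_component = latent_size // components_num  # -> 16 dims per latent component
k = components_num + 1                         # -> 9: one extra global component
print(per_component, k)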