def prepare(tasks, data_dir, shards_num=1, max_images=None, ratio=1.0,
            images_dir=None, format=None):  # Options for custom dataset
    """Download, unpack and preprocess every dataset named in ``tasks``.

    Args:
        tasks: Iterable of dataset names. Names found in ``catalog`` use the
            catalog configuration; any other name is treated as a custom,
            local dataset described by ``images_dir``, ``ratio`` and ``format``.
        data_dir: Root directory; each task gets the sub-directory
            ``<data_dir>/<task>``.
        shards_num: Number of shards passed to the processing function when
            ``max_images`` is set, or when the configuration declares no
            ``shards`` entry.
        max_images: Optional cap forwarded to the processing function as
            ``max_imgs``.
        ratio: Image ratio recorded in a custom task configuration.
        images_dir: Source image directory of a custom task.
        format: Key into ``formats_catalog`` selecting the processing function
            of a custom task.
    """
    mkdir(data_dir)

    for task in tasks:
        # If task not in catalog, create custom task configuration. It is
        # wrapped in EasyDict because the code below reads the configuration
        # through attributes (c.name, c.ratio, ...), which a plain dict lacks.
        c = catalog.get(task, EasyDict({
            "local": True,
            "name": task,
            "dir": images_dir,
            "ratio": ratio,
            "process": formats_catalog.get(format)
        }))

        dirname = "{}/{}".format(data_dir, task)
        mkdir(dirname)

        print(misc.bold("Preparing the {} dataset...".format(c.name)))

        if "local" not in c:
            # Only remote datasets carry filename/md5/url, so the archive path
            # is computed inside this branch.
            fname = "{}/{}".format(dirname, c.filename)
            download = not (os.path.exists(fname) and verify_md5(fname, c.md5))
            path = get_path(c.url, dirname, path=c.filename)
            if download:
                print(misc.bold("Downloading the data ({} GB)...".format(c.size)))
                download_file(c.url, path)

            if path.endswith(".zip") and not is_unzipped(path, dirname):
                print(misc.bold("Unzipping {}...".format(path)))
                unzip(path, dirname)

        if "process" in c:
            imgdir = images_dir if "local" in c else "{}/{}".format(dirname, c.dir)
            # Use the configured shard count when processing the full dataset.
            # Kept in a local so one task's value never overwrites the
            # caller-supplied shards_num for the following tasks.
            if max_images is None and "shards" in c:
                task_shards = c.shards
            else:
                task_shards = shards_num
            c.process(dirname, imgdir, ratio=c.ratio, shards_num=task_shards,
                      max_imgs=max_images)

        print(misc.bcolored("Completed preparations for {}!".format(c.name), "blue"))
def load_dataset(class_name = None, data_dir = None, verbose = False, **kwargs):
    """Instantiate a dataset object by class name and return it.

    When ``tfrecord_dir`` is among the keyword arguments, the class defaults
    to this module's ``TFRecordDataset`` and the directory is resolved
    relative to ``data_dir`` (if given). All remaining keyword arguments are
    forwarded to the dataset constructor. With ``verbose`` set, the chosen
    class and the dataset's shape, dynamic range and label size are printed.
    """
    options = dict(kwargs)

    if "tfrecord_dir" in options:
        class_name = __name__ + ".TFRecordDataset" if class_name is None else class_name
        if data_dir is not None:
            options["tfrecord_dir"] = os.path.join(data_dir, options["tfrecord_dir"])

    assert class_name is not None

    if verbose:
        print(misc.bcolored("Streaming data using %s %s..." % (class_name, data_dir), "white"))

    dataset_cls = dnnlib.util.get_obj_by_name(class_name)
    dataset = dataset_cls(**options)

    if not verbose:
        return dataset

    print("Dataset shape: ", misc.bcolored(np.int32(dataset.shape).tolist(), "blue"))
    print("Dynamic range: ", misc.bcolored(dataset.dynamic_range, "blue"))
    if dataset.label_size > 0:
        print("Label size: ", misc.bcolored(dataset.label_size, "blue"))
    return dataset
def load_nets(resume_pkl, lG, lD, lGs, recompile):
    """Restore the G, D and Gs networks from the pickle ``resume_pkl``.

    With ``recompile`` set, the variables of the restored networks are copied
    into the supplied ``lG``/``lD``/``lGs`` objects, which are then returned.
    Otherwise the restored networks themselves are returned.
    """
    print(misc.bcolored("Loading networks from %s..." % resume_pkl, "white"))
    # Only the first three entries of the pickle are used.
    rG, rD, rGs = misc.load_pkl(resume_pkl)[:3]

    if not recompile:
        return rG, rD, rGs

    print(misc.bold("Copying nets..."))
    for target, source in ((lG, rG), (lD, rD), (lGs, rGs)):
        target.copy_vars_from(source)
    return lG, lD, lGs
def task(task, name, size, dir, redownload, download = None, prepare = lambda: None):
    """Run the download and preparation steps of one dataset.

    Args:
        task: Dataset identifier; a falsy value skips the dataset entirely.
            It is also the name checked under ``dir`` to decide whether the
            data is already present.
        name: Human-readable dataset name used in the progress messages.
        size: Download size in GB, used in the progress message only.
        dir: Directory expected to contain ``task`` once downloaded.
        redownload: Force the download even when ``dir/task`` already exists.
        download: Optional zero-argument callable performing the download.
        prepare: Zero-argument callable performing the preprocessing.
    """
    if not task:
        return

    print(misc.bcolored("Preparing the {} dataset...".format(name), "blue"))

    if download is not None and (redownload or not path.exists("{}/{}".format(dir, task))):
        # Coloured output goes through misc.bcolored: everywhere else in this
        # code misc.bold is called with the text only.
        # NOTE(review): confirm misc.bold takes no colour argument.
        print(misc.bcolored("Downloading the data ({} GB)...".format(size), "blue"))
        download()
        # The original called the `misc` module itself here.
        print(misc.bold("Completed downloading {}".format(name)))

    prepare()
    print(misc.bold("Completed preparations for {}!".format(name)))
    # NOTE(review): the original ended in a dangling `except:` with no
    # matching `try`, which is a syntax error; it was dropped. Reinstate a
    # full try/except here if best-effort behaviour was intended.
def print_results(jsonl_line):
    """Print a one-line summary of a metric result record.

    ``jsonl_line`` is a mapping with the keys ``snapshot_pkl`` (path or None),
    ``total_time`` and ``results`` (metric name -> numeric value).
    """
    # Overwrite whatever progress text is currently on the line.
    print(" " * 100 + "\r")

    snapshot = jsonl_line["snapshot_pkl"]
    if snapshot is None:
        network_name = "None"
    else:
        network_name, _ = os.path.splitext(os.path.basename(snapshot))

    # Keep the name column at most 29 characters wide.
    if len(network_name) > 29:
        network_name = "..." + network_name[-26:]

    header = "%-30s" % network_name
    header += " time %-12s" % dnnlib.util.format_time(jsonl_line["total_time"])

    values = "".join(
        f" {res} {value:10.4f}" for res, value in jsonl_line["results"].items()
    )
    print(header + misc.bcolored(values, "blue"))
def get_result_str(self, screen=False):
    """Return a one-line summary of this metric's results.

    The line holds the (possibly shortened) network pickle name, the
    evaluation time and one ``<name><suffix> <value>`` entry per result.
    With ``screen`` set, the values part is coloured for terminal output.
    """
    if self._network_pkl is None:
        network_name = "None"
    else:
        network_name, _ = os.path.splitext(os.path.basename(self._network_pkl))

    # Keep the name column at most 29 characters wide.
    if len(network_name) > 29:
        network_name = "..." + network_name[-26:]

    summary = "%-30s" % network_name
    summary += " time %-12s" % dnnlib.util.format_time(self._eval_time)

    values = "".join(
        " " + self.name + res.suffix + " " + res.fmt % res.value
        for res in self._results
    )
    if screen:
        values = misc.bcolored(values, "blue")

    return summary + values
def prepare(tasks, data_dir, max_images = None, ratio = 1.0, images_dir = None, format = None): # Options for custom dataset
    """Download, unpack and preprocess every dataset named in ``tasks``.

    Names found in ``catalog`` use the catalog configuration. Any other name
    gets a custom local configuration built from ``images_dir``, ``ratio``
    and ``format``. Each task works inside ``<data_dir>/<task>``.
    """
    os.makedirs(data_dir, exist_ok = True)

    for task in tasks:
        # Fallback configuration for tasks missing from the catalog
        custom_cfg = EasyDict({
            "local": True,
            "name": task,
            "dir": images_dir,
            "ratio": ratio,
            "process": formats_catalog.get(format)
        })
        cfg = catalog.get(task, custom_cfg)
        is_local = "local" in cfg

        dirname = f"{data_dir}/{task}"
        os.makedirs(dirname, exist_ok = True)

        print(misc.bold(f"Preparing the {cfg.name} dataset..."))

        if not is_local:
            # Fetch the archive unless a copy with a matching md5 is present
            fname = f"{dirname}/{cfg.filename}"
            have_valid_copy = os.path.exists(fname) and verify_md5(fname, cfg.md5)
            path = get_path(cfg.url, dirname, path = cfg.filename)
            if not have_valid_copy:
                print(misc.bold(f"Downloading the data ({cfg.size} GB)..."))
                download_file(cfg.url, path)

            if path.endswith(".zip") and not is_unzipped(path, dirname):
                print(misc.bold(f"Unzipping {path}..."))
                unzip(path, dirname)

        if "process" in cfg:
            if is_local:
                imgdir = images_dir
            else:
                imgdir = f"{dirname}/{cfg.dir}"
            cfg.process(dirname, imgdir, ratio = cfg.ratio, max_imgs = max_images)

        print(misc.bcolored(f"Completed preparations for {cfg.name}!", "blue"))
def run(**args):
    """Translate command-line options into a training configuration and submit it.

    ``args`` holds the parsed command-line options (``main`` passes
    ``vars(args)``). The options are distributed over the generator and
    discriminator configs, the training-loop, schedule, dataset and
    visualization option sets, and the run is submitted through
    ``dnnlib.submit_run``. Invalid option combinations print an error and
    terminate the process via ``exit()``.
    """
    args = EasyDict(args)

    train = EasyDict(run_func_name="training.training_loop.training_loop")  # training loop options
    sched = EasyDict()                                                      # TrainingSchedule options
    vis = EasyDict()                                                        # visualize.eval() options
    grid = EasyDict(size="1080p", layout="random")                          # setup_snapshot_img_grid() options
    sc = dnnlib.SubmitConfig()                                              # dnnlib.submit_run() options

    # Environment configuration
    tf_config = {
        "rnd.np_random_seed": 1000,
        "allow_soft_placement": True,
        "gpu_options.per_process_gpu_memory_fraction": 1.0
    }
    if args.gpus != "":
        num_gpus = len(args.gpus.split(","))
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
        # Kept inside the branch so an empty --gpus leaves the SubmitConfig
        # default instead of reading an unbound num_gpus.
        assert num_gpus in [1, 2, 4, 8]
        sc.num_gpus = num_gpus

    # Networks configuration
    cG = set_net("G", reg_interval=4)
    cD = set_net("D", reg_interval=16)

    # Dataset configuration: known datasets override the --ratio option
    ratios = {
        "clevr": 0.75,
        "lsun-bedrooms": 0.72,
        "cityscapes": 0.5,
        "ffhq": 1.0
    }
    args.ratio = ratios.get(args.dataset, args.ratio)
    dataset_args = EasyDict(tfrecord_dir=args.dataset,
                            max_imgs=args.train_images_num,
                            ratio=args.ratio,
                            num_threads=args.num_threads)
    for arg in ["data_dir", "mirror_augment", "total_kimg"]:
        cset(train, arg, args[arg])

    # Training and Optimizations configuration
    for arg in ["eval", "train", "recompile", "last_snapshots"]:
        cset(train, arg, args[arg])

    # Round down to the closest multiple of minibatch size for validity
    args.batch_size -= args.batch_size % args.minibatch_size
    args.minibatch_std_size -= args.minibatch_std_size % args.minibatch_size
    args.latent_size -= args.latent_size % args.components_num
    if args.latent_size == 0:
        print(misc.bcolored(
            "Error: latent-size is too small. Must be a multiple of components-num.",
            "red"))
        exit()

    sched_args = {
        "G_lrate": "g_lr",
        "D_lrate": "d_lr",
        "minibatch_size": "batch_size",
        "minibatch_gpu": "minibatch_size"
    }
    for arg, cmd_arg in sched_args.items():
        cset(sched, arg, args[cmd_arg])
    cset(train, "clip", args.clip)

    # Logging and metrics configuration
    metrics = [metric_defaults[x] for x in args.metrics]
    cset(cG.args, "truncation_psi", args.truncation_psi)
    for arg in ["summarize", "keep_samples", "eval_images_num"]:
        cset(train, arg, args[arg])

    # Visualization
    args.vis_imgs = args.vis_images
    args.vis_ltnts = args.vis_latents
    vis_types = [
        "imgs", "ltnts", "maps", "layer_maps", "interpolations", "noise_var",
        "style_mix"
    ]
    # Set of all the visualization types that were switched on
    vis.vis_types = {arg for arg in vis_types if args["vis_{}".format(arg)]}

    vis_args = {
        "attention": "transformer",
        "grid": "vis_grid",
        "num": "vis_num",
        "rich_num": "vis_rich_num",
        "section_size": "vis_section_size",
        "intrp_density": "intrpolation_density",
        "intrp_per_component": "intrpolation_per_component",
        "alpha": "blending_alpha"
    }
    for arg, cmd_arg in vis_args.items():
        cset(vis, arg, args[cmd_arg])

    # Networks architecture
    cset(cG.args, "architecture", args.g_arch)
    cset(cD.args, "architecture", args.d_arch)
    cset(cG.args, "tanh", args.tanh)

    # Latent sizes
    if args.components_num > 1:
        if not (args.transformer or args.kgan):
            print(misc.bcolored(
                "Error: components-num > 1 but the model is not using components.",
                "red"))
            print(misc.bcolored(
                "Either add --transformer for GANsformer or --kgan for k-GAN).",
                "red"))
            exit()
        # latent_size was rounded to a multiple of components_num above, so
        # integer division is exact.
        args.latent_size = args.latent_size // args.components_num

    cD.args.latent_size = cG.args.latent_size = cG.args.dlatent_size = args.latent_size
    cset([cG.args, cD.args, vis], "components_num", args.components_num)

    # mapping-dim default value is the latent-size
    if args.mapping_dim is None:
        args.mapping_dim = args.latent_size

    # Mapping network
    for arg in ["layersnum", "lrmul", "dim", "resnet", "shared_dim"]:
        cset(cG.args, arg, args["mapping_{}".format(arg)])

    # StyleGAN settings
    for arg in ["style", "latent_stem", "fused_modconv", "local_noise"]:
        cset(cG.args, arg, args[arg])
    cD.args.mbstd_group_size = args.minibatch_std_size

    # GANsformer
    cset(cG.args, "transformer", args.transformer)
    cset(cD.args, "transformer", args.d_transformer)

    args.norm = args.normalize
    for arg in [
            "norm", "integration", "ltnt_gate", "img_gate", "kmeans",
            "kmeans_iters", "mapping_ltnt2ltnt"
    ]:
        cset(cG.args, arg, args[arg])

    for arg in ["use_pos", "num_heads"]:
        cset([cG.args, cD.args], arg, args[arg])

    # Positional encoding
    for arg in ["dim", "init", "directions_num"]:
        field = "pos_{}".format(arg)
        cset([cG.args, cD.args], field, args[field])

    # k-GAN
    for arg in ["layer", "type", "same"]:
        field = "merge_{}".format(arg)
        cset(cG.args, field, args[field])
    cset([cG.args, train], "merge", args.kgan)

    # Attention
    for arg in [
            "start_res", "end_res", "ltnt2ltnt", "img2img", "local_attention"
    ]:
        cset(cG.args, arg, args["g_{}".format(arg)])
        cset(cD.args, arg, args["d_{}".format(arg)])
    cset(cG.args, "img2ltnt", args.g_img2ltnt)
    cset(cD.args, "ltnt2img", args.d_ltnt2img)

    # Mixing and dropout
    for arg in [
            "style_mixing", "component_mixing", "component_dropout",
            "attention_dropout"
    ]:
        cset(cG.args, arg, args[arg])

    # Loss and regularization
    gloss_args = {
        "loss_type": "g_loss",
        "reg_weight": "g_reg_weight",
        "pathreg": "pathreg",
    }
    dloss_args = {"loss_type": "d_loss", "reg_type": "d_reg", "gamma": "gamma"}
    for arg, cmd_arg in gloss_args.items():
        cset(cG.loss_args, arg, args[cmd_arg])
    for arg, cmd_arg in dloss_args.items():
        cset(cD.loss_args, arg, args[cmd_arg])

    ##### Experiments management:
    # Whenever we start a new experiment we store its result in a directory named 'args.expname:000'.
    # When we rerun a training or evaluation command it restores the model from that directory by default.
    # If we wish to restart the model training, we can set --restart and then we will store data in a new
    # directory: 'args.expname:001' after the first restart, then 'args.expname:002' after the second, etc.

    # Find the latest directory that matches the experiment
    exp_dir = sorted(glob.glob("{}/{}:*".format(args.result_dir, args.expname)))
    run_id = 0
    if len(exp_dir) > 0:
        run_id = int(exp_dir[-1].split(":")[-1])
    # If restart, then work over a new directory
    if args.restart:
        run_id += 1

    run_name = "{}:{:03d}".format(args.expname, run_id)
    train.printname = "{} ".format(misc.bold(args.expname))

    snapshot, kimg, resume = None, 0, False
    pkls = sorted(
        glob.glob("{}/{}/network*.pkl".format(args.result_dir, run_name)))
    # Load a particular snapshot if specified
    if args.pretrained_pkl:
        matches = glob.glob(args.pretrained_pkl)
        if not matches:
            # Report a missing snapshot explicitly rather than failing with
            # an IndexError on the empty glob result.
            print(misc.bcolored(
                "Error: no snapshot matches {}.".format(args.pretrained_pkl),
                "red"))
            exit()
        # Soft links support
        snapshot = matches[0]
        if os.path.islink(snapshot):
            snapshot = os.readlink(snapshot)
        # Extract training step from the snapshot if specified
        try:
            kimg = int(snapshot.split("-")[-1].split(".")[0])
        except ValueError:
            # File name carries no numeric step suffix: keep kimg at 0
            pass
    # Find latest snapshot in the directory
    elif len(pkls) > 0:
        snapshot = pkls[-1]
        kimg = int(snapshot.split("-")[-1].split(".")[0])
        resume = True

    if snapshot:
        print(misc.bcolored("Resuming {}, kimg {}".format(snapshot, kimg),
                            "white"))
        train.resume_pkl = snapshot
        train.resume_kimg = kimg
    else:
        print(misc.bcolored("Start model training from scratch.", "white"))

    # Run environment configuration
    sc.run_dir_root = args.result_dir
    sc.run_desc = args.expname
    sc.run_id = run_id
    sc.run_name = run_name
    sc.submit_target = dnnlib.SubmitTarget.LOCAL
    sc.local.do_not_copy_source_files = True

    kwargs = EasyDict(train)
    kwargs.update(cG=cG, cD=cD)
    kwargs.update(dataset_args=dataset_args,
                  vis_args=vis,
                  sched_args=sched,
                  grid_args=grid,
                  metric_arg_list=metrics,
                  tf_config=tf_config)
    kwargs.submit_config = copy.deepcopy(sc)
    kwargs.resume = resume
    kwargs.load_config = args.reload
    dnnlib.submit_run(**kwargs)
def main():
    """Parse the command line, validate it and launch ``run``.

    Builds the argument parser for all framework, dataset, training,
    visualization and model options, checks that the dataset root directory
    exists and that every requested metric is known, then forwards the
    parsed options to ``run`` as keyword arguments. Validation failures
    print an error and terminate via ``exit()``.
    """
    parser = argparse.ArgumentParser(
        description="Train the GANsformer",
        epilog=_examples,
        formatter_class=argparse.RawDescriptionHelpFormatter)
    add = parser.add_argument

    # Framework
    # ------------------------------------------------------------------------------------------------------
    add("--expname", help="Experiment name", default="exp", type=str)
    add("--eval", help="Evaluation mode (default: False)", default=None, action="store_true")
    add("--train", help="Train mode (default: False)", default=None, metavar="BOOL", type=_str_to_bool)
    add("--gpus", help="Comma-separated list of GPUs to be used (default: %(default)s)", default="0", type=str)

    ## Resumption
    add("--pretrained_pkl", help="Filename for a snapshot to resume (optional)", default=None, type=str)
    add("--restart", help="Restart training from scratch", default=False, action="store_true")
    add("--reload",
        help="Reload options from the original experiment configuration file. " +
        "If False, uses the command line arguments when resuming training (default: %(default)s)",
        default=False, action="store_true")
    add("--recompile",
        help="Recompile model from source code when resuming training. " +
        "If False, loading modules created when the experiment first started",
        default=None, action="store_true")
    add("--last_snapshots", help="Number of last snapshots to save. -1 for all (default: 8)", default=None, type=int)

    ## Dataset
    add("--data-dir", help="Datasets root directory", required=True)
    add("--dataset", help="Training dataset name (subdirectory of data-dir).", required=True)
    add("--ratio", help="Image height/width ratio in the dataset", default=1.0, type=float)
    add("--num-threads", help="Number of input processing threads (default: %(default)s)", default=4, type=int)
    # Parsed as a boolean: without a type, any supplied string (even "False")
    # would be stored as a truthy str.
    add("--mirror-augment",
        help="Perform horizontal flip augmentation for the data (default: %(default)s)",
        default=False, metavar="BOOL", type=_str_to_bool)
    add("--train-images-num",
        help="Maximum number of images to train on. If not specified, train on the whole dataset.",
        default=None, type=int)

    ## Training
    add("--batch-size", help="Global batch size (optimization step) (default: %(default)s)", default=32, type=int)
    add("--minibatch-size",
        help="Batch size per GPU, gradients will be accumulated to match batch-size (default: %(default)s)",
        default=4, type=int)
    add("--total-kimg", help="Training length in thousands of images (default: %(default)s)",
        metavar="KIMG", default=25000, type=int)
    add("--gamma", help="R1 regularization weight (default: %(default)s)", default=10, type=float)
    add("--clip", help="Gradient clipping threshold (optional)", default=None, type=float)
    add("--g-lr", help="Generator learning rate (default: %(default)s)", default=0.002, type=float)
    add("--d-lr", help="Discriminator learning rate (default: %(default)s)", default=0.002, type=float)

    ## Logging and evaluation
    add("--result-dir", help="Root directory for experiments (default: %(default)s)", default="results", metavar="DIR")
    add("--metrics", help="Comma-separated list of metrics or none (default: %(default)s)",
        default="fid", type=_parse_comma_sep)
    add("--summarize", help="Create TensorBoard summaries (default: %(default)s)",
        default=True, metavar="BOOL", type=_str_to_bool)
    add("--truncation-psi",
        help="Truncation Psi to be used in producing sample images " +
        "(used only for visualizations, _not used_ in training or for computing metrics) (default: %(default)s)",
        default=0.65, type=float)
    add("--keep-samples",
        help="Keep all prior samples during training, or if False, just the most recent ones (default: False)",
        default=None, action="store_true")
    add("--eval-images-num", help="Number of images to evaluate metrics on (default: 50,000)",
        default=None, type=float)

    ## Visualization
    add("--vis-images", help="Save image samples", default=None, action="store_true")
    add("--vis-latents", help="Save latent vectors", default=None, action="store_true")
    add("--vis-maps", help="Save attention maps (for GANsformer only)", default=None, action="store_true")
    add("--vis-layer-maps", help="Save attention maps for all layers (for GANsformer only)",
        default=None, action="store_true")
    add("--vis-interpolations", help="Create latent interpolations", default=None, action="store_true")
    add("--vis-noise-var", help="Create noise variation visualization", default=None, action="store_true")
    add("--vis-style-mix", help="Create style mixing visualization", default=None, action="store_true")
    add("--vis-grid",
        help="Whether to save the samples in one large grid files (default: True in training)",
        default=None, action="store_true")
    # The help text here used to be a copy of the --ratio description.
    add("--vis-num", help="Number of image samples to visualize", default=None, type=int)
    add("--vis-rich-num",
        help="Number of samples for which richer visualizations will be created (default: 5)",
        default=None, type=int)
    add("--vis-section-size",
        help="Visualization section size to process at once (section-size <= vis-num) for memory footprint (default: 100)",
        default=None, type=int)
    add("--blending-alpha",
        help="Proportion for generated images and attention maps blends (default: 0.3)",
        default=None, type=float)
    add("--intrpolation-density",
        help="Number of samples in between two end points of an interpolation (default: 8)",
        default=None, type=int)
    add("--intrpolation-per-component",
        help="Whether to perform interpolation along particular latent components when true, or all of them at once otherwise (default: False)",
        default=None, action="store_true")

    # Model
    # ------------------------------------------------------------------------------------------------------

    ## General architecture
    add("--g-arch", help="Generator architecture type (default: skip)",
        default=None, choices=["orig", "skip", "resnet"], type=str)
    add("--d-arch", help="Discriminator architecture type (default: resnet)",
        default=None, choices=["orig", "skip", "resnet"], type=str)
    add("--tanh", help="tanh on generator output (default: False)", default=None, action="store_true")

    # Mapping network
    add("--mapping-layersnum", help="Number of mapping layers (default: 8)", default=None, type=int)
    add("--mapping-lrmul", help="Mapping network learning rate multiplier (default: 0.01)",
        default=None, type=float)
    add("--mapping-dim", help="Mapping layers dimension (default: latent_size)", default=None, type=int)
    add("--mapping-resnet", help="Use resnet connections in mapping layers (default: False)",
        default=None, action="store_true")
    add("--mapping-shared-dim",
        help="Perform one shared mapping to all latent components concatenated together using the set dimension (default: disabled)",
        default=None, type=int)

    # Loss
    add("--pathreg", help="Use path regularization in generator training (default: False)",
        default=None, metavar="BOOL", type=_str_to_bool)
    add("--g-loss", help="Generator loss type (default: %(default)s)", default="logistic_ns",
        choices=["logistic", "logistic_ns", "hinge", "wgan"], type=str)
    add("--g-reg-weight", help="Generator regularization weight (default: %(default)s)",
        default=1.0, type=float)
    add("--d-loss", help="Discriminator loss type (default: %(default)s)", default="logistic",
        choices=["wgan", "logistic", "hinge"], type=str)
    add("--d-reg", help="Discriminator regularization type (default: %(default)s)", default="r1",
        choices=["non", "gp", "r1", "r2"], type=str)

    # Mixing and dropout
    add("--style-mixing", help="Style mixing (layerwise) probability (default: %(default)s)",
        default=0.9, type=float)
    add("--component-mixing", help="Component mixing (objectwise) probability (default: %(default)s)",
        default=0.0, type=float)
    add("--component-dropout", help="Component dropout (default: %(default)s)", default=0.0, type=float)
    add("--attention-dropout", help="Attention dropout (default: 0.12)", default=None, type=float)

    # StyleGAN additions
    add("--style", help="Global style modulation (default: %(default)s)",
        default=True, metavar="BOOL", type=_str_to_bool)
    add("--latent-stem", help="Input latent through the generator stem grid (default: False)",
        default=None, action="store_true")
    add("--fused-modconv", help="Fuse modulation and convolution operations (default: %(default)s)",
        default=True, metavar="BOOL", type=_str_to_bool)
    add("--local-noise", help="Add stochastic local noise each layer (default: %(default)s)",
        default=True, metavar="BOOL", type=_str_to_bool)
    add("--minibatch-std-size",
        help="Add minibatch standard deviation layer in the discriminator (default: %(default)s)",
        default=4, type=int)

    ## GANsformer
    add("--transformer",
        help="Add transformer layers to the generator: top-down latents-to-image (default: False)",
        default=None, action="store_true")
    add("--latent-size",
        help="Latent size, summing the dimension of all components (default: %(default)s)",
        default=512, type=int)
    add("--components-num",
        help="Components number. Each component has latent dimension of 'latent-size / components-num'. " +
        "1 for StyleGAN since it has one global latent vector (default: %(default)s)",
        default=1, type=int)
    add("--num-heads", help="Number of attention heads (default: %(default)s)", default=1, type=int)
    add("--normalize", help="Feature normalization type (optional)", default=None,
        choices=["batch", "instance", "layer"])
    add("--integration",
        help="Feature integration type: additive, multiplicative or both (default: %(default)s)",
        default="add", choices=["add", "mul", "both"], type=str)

    # Generator attention layers
    # Transformer resolution layers
    add("--g-start-res",
        help="Transformer minimum generator resolution (logarithmic): first layer in which transformer will be applied (default: %(default)s)",
        default=0, type=int)
    add("--g-end-res",
        help="Transformer maximum generator resolution (logarithmic): last layer in which transformer will be applied (default: %(default)s)",
        default=7, type=int)

    # Discriminator attention layers
    add("--d-transformer",
        help="Add transformer layers to the discriminator (bottom-up image-to-latents) (default: False)",
        default=None, action="store_true")
    add("--d-start-res",
        help="Transformer minimum discriminator resolution (logarithmic): first layer in which transformer will be applied (default: %(default)s)",
        default=0, type=int)
    add("--d-end-res",
        help="Transformer maximum discriminator resolution (logarithmic): last layer in which transformer will be applied (default: %(default)s)",
        default=7, type=int)

    # Attention
    add("--ltnt-gate",
        help="Gate attention from latents, such that components may not send information " +
        "when gate value is low (default: False)",
        default=None, action="store_true")
    add("--img-gate",
        help="Gate attention for images, such that some image positions may not get updated " +
        "or receive information when gate value is low (default: False)",
        default=None, action="store_true")
    add("--kmeans",
        help="Track and update image-to-latents assignment centroids, used in the duplex attention (default: False)",
        default=None, action="store_true")
    add("--kmeans-iters",
        help="Number of K-means iterations per transformer layer. Note that centroids are carried from layer to layer (default: %(default)s)",
        default=1, type=int)  # -per-layer

    # Attention directions
    # format is A2B: Elements _from_ B attend _to_ elements in A, and B elements get updated accordingly.
    # Note that it means that information propagates in the following direction: A -> B
    add("--mapping-ltnt2ltnt",
        help="Add self-attention over latents in the mapping network (default: False)",
        default=None, action="store_true")
    add("--g-ltnt2ltnt",
        help="Add self-attention over latents in the synthesis network (default: False)",
        default=None, action="store_true")
    add("--g-img2img",
        help="Add self-attention between images positions in that layer of the generator (SAGAN) (default: disabled)",
        default=0, type=int)
    add("--g-img2ltnt",
        help="Add image to latents attention (bottom-up) (default: %(default)s)",
        default=None, action="store_true")
    # g-ltnt2img: default information flow direction when using --transformer

    add("--d-ltnt2img",
        help="Add latents to image attention (top-down) (default: %(default)s)",
        default=None, action="store_true")
    add("--d-ltnt2ltnt",
        help="Add self-attention over latents in the discriminator (default: False)",
        default=None, action="store_true")
    add("--d-img2img",
        help="Add self-attention over images positions in that layer of the discriminator (SAGAN) (default: disabled)",
        default=0, type=int)
    # d-img2ltnt: default information flow direction when using --d-transformer

    # Local attention operations (replacing convolution)
    add("--g-local-attention",
        help="Local attention operations in the generation up to this layer (default: disabled)",
        default=None, type=int)
    add("--d-local-attention",
        help="Local attention operations in the discriminator up to this layer (default: disabled)",
        default=None, type=int)

    # Positional encoding
    add("--use-pos", help="Use positional encoding (default: False)", default=None, action="store_true")
    add("--pos-dim", help="Positional encoding dimension (default: latent-size)", default=None, type=int)
    add("--pos-type", help="Positional encoding type (default: %(default)s)", default="sinus",
        choices=["linear", "sinus", "trainable", "trainable2d"], type=str)
    add("--pos-init",
        help="Positional encoding initialization distribution (default: %(default)s)",
        default="uniform", choices=["uniform", "normal"], type=str)
    add("--pos-directions-num",
        help="Positional encoding number of spatial directions (default: %(default)s)",
        default=2, type=int)

    ## k-GAN
    add("--kgan",
        help="Generate components-num images and then merge them (k-GAN) (default: False)",
        default=None, action="store_true")
    add("--merge-layer",
        help="Merge layer, where images get combined through alpha-composition (default: %(default)s)",
        default=-1, type=int)
    add("--merge-type", help="Merge type (default: additive)", default=None,
        choices=["sum", "softmax", "max", "leaves"], type=str)
    add("--merge-same",
        help="Merge images with same alpha weights across all spatial positions (default: %(default)s)",
        default=None, action="store_true")

    args = parser.parse_args()

    if not os.path.exists(args.data_dir):
        print(misc.bcolored("Error: dataset root directory does not exist.", "red"))
        exit()

    for metric in args.metrics:
        if metric not in metric_defaults:
            print(misc.bcolored("Error: unknown metric \"%s\"" % metric, "red"))
            exit()

    run(**vars(args))
def __init__(
        self,
        tfrecord_dir,  # Directory containing a collection of tfrecords files
        resolution=None,  # Dataset resolution, None = autodetect
        label_file=None,  # Relative path of the labels file, None = autodetect
        max_label_size=0,  # 0 = no labels, "full" = full labels, <int> = N first label components
        max_imgs=None,  # Maximum number of images to use, None = use all images
        repeat=True,  # Repeat dataset indefinitely?
        shuffle_mb=2048,  # Shuffle data within specified window (megabytes), 0 = disable shuffling
        prefetch_mb=512,  # Amount of data to prefetch (megabytes), 0 = disable prefetching
        buffer_mb=256,  # Read buffer size (megabytes)
        num_threads=4,  # Number of concurrent threads for input processing
        **kwargs):
    """Build the TensorFlow input pipeline over the tfrecords in ``tfrecord_dir``.

    Lists the ``*.tfrecords1of*`` files, reads the shape of the first record
    of each, derives the dataset resolution and the level-of-detail (lod) of
    every file, loads the labels, and creates one ``tf.data`` dataset per lod
    together with a shared iterator and its per-lod initializer ops.

    Extra keyword arguments are accepted and not used by this constructor.
    """
    self.tfrecord_dir = tfrecord_dir
    self.resolution = None
    self.resolution_log2 = None
    self.shape = []  # [channels, height, width]
    self.dtype = "uint8"
    self.dynamic_range = [0, 255]
    self.label_file = label_file
    self.label_size = None
    self.label_dtype = None
    self._np_labels = None
    self._tf_minibatch_in = None
    self._tf_labels_var = None
    self._tf_labels_dataset = None
    self._tf_datasets = dict()
    self._tf_iterator = None
    self._tf_init_ops = dict()
    self._tf_minibatch_np = None
    self._cur_minibatch = -1
    self._cur_lod = -1

    # List tfrecords files and inspect their shapes
    assert os.path.isdir(self.tfrecord_dir)
    tfr_files = sorted(
        glob.glob(os.path.join(self.tfrecord_dir, "*.tfrecords1of*")))
    # If max_imgs is not None, take a subset of images out of the 1st file. Otherwise take all files.
    # Either way tfr_files becomes a list of lists: one list of shard paths per "1of" file found.
    if max_imgs is None:
        tfr_files = [
            sorted(glob.glob(re.sub("1of.*", "*", f))) for f in tfr_files
        ]
    else:
        tfr_files = [[f] for f in tfr_files]

    assert len(tfr_files) >= 1
    tfr_shapes = []
    for tfr_file in tfr_files:
        tfr_opt = tf.io.TFRecordOptions("")
        # Only the first record of the first shard is parsed to get the shape
        for record in tf.python_io.tf_record_iterator(
                tfr_file[0], tfr_opt):
            tfr_shapes.append(self.parse_tfrecord_np(record).shape)
            break
        # Shuffle the shard order in place
        random.shuffle(tfr_file)

    # Autodetect label filename
    if self.label_file is None:
        guess = sorted(
            glob.glob(os.path.join(self.tfrecord_dir, "*.labels")))
        if len(guess):
            self.label_file = guess[0]
    elif not os.path.isfile(self.label_file):
        # Fall back to interpreting the given path relative to tfrecord_dir
        guess = os.path.join(self.tfrecord_dir, self.label_file)
        if os.path.isfile(guess):
            self.label_file = guess

    # Determine shape and resolution
    max_shape = max(tfr_shapes, key=np.prod)
    self.resolution = resolution if resolution is not None else max_shape[1]
    self.resolution_log2 = int(np.log2(self.resolution))
    self.shape = [max_shape[0], self.resolution, self.resolution]
    # lod = number of halvings between the dataset resolution and each file's resolution
    tfr_lods = [
        self.resolution_log2 - int(np.log2(shape[1]))
        for shape in tfr_shapes
    ]
    assert all(shape[0] == max_shape[0] for shape in tfr_shapes)
    assert all(shape[1] == shape[2] for shape in tfr_shapes)
    assert all(shape[1] == self.resolution // (2**lod)
               for shape, lod in zip(tfr_shapes, tfr_lods))
    assert all(lod in range(self.resolution_log2 - 1) for lod in tfr_lods)

    # Load labels
    assert max_label_size == "full" or max_label_size >= 0
    # Default: 2**20 rows of zero-width labels (i.e. no label components)
    self._np_labels = np.zeros([1 << 20, 0], dtype=np.float32)
    if self.label_file is not None and max_label_size != 0:
        self._np_labels = np.load(self.label_file)
        assert self._np_labels.ndim == 2
    if max_label_size != "full" and self._np_labels.shape[
            1] > max_label_size:
        self._np_labels = self._np_labels[:, :max_label_size]
    if max_imgs is not None and self._np_labels.shape[0] > max_imgs:
        self._np_labels = self._np_labels[:max_imgs]
    # NOTE(review): this triggers when there are FEWER label rows than max_imgs;
    # the message text reads oddly for that condition - confirm intended wording.
    if max_imgs is not None and self._np_labels.shape[0] < max_imgs:
        print(misc.bcolored("Too many images. increase number.", "red"))
        exit()
    self.label_size = self._np_labels.shape[1]
    self.label_dtype = self._np_labels.dtype.name

    # Build TF expressions
    with tf.name_scope("Dataset"), tf.device("/cpu:0"):
        self._tf_minibatch_in = tf.placeholder(tf.int64,
                                               name="minibatch_in",
                                               shape=[])
        self._tf_labels_var = tflib.create_var_with_large_initial_value(
            self._np_labels, name="labels_var")
        self._tf_labels_dataset = tf.data.Dataset.from_tensor_slices(
            self._tf_labels_var)
        for tfr_file, tfr_shape, tfr_lod in zip(tfr_files, tfr_shapes,
                                                tfr_lods):
            # Files larger than the requested resolution are skipped
            if tfr_lod < 0:
                continue
            # Load dataset
            dset = tf.data.TFRecordDataset(tfr_file,
                                           compression_type="",
                                           buffer_size=buffer_mb << 20,
                                           num_parallel_reads=num_threads)
            # If max_imgs is set, take a subset of the data
            if max_imgs is not None:
                dset = dset.take(max_imgs)
            # Parse the TF records
            dset = dset.map(self.parse_tfrecord_tf,
                            num_parallel_calls=num_threads)
            # Zip images with their labels (0s if no labels)
            dset = tf.data.Dataset.zip((dset, self._tf_labels_dataset))
            # Shuffle and repeat
            # Buffer sizes below are item counts: megabytes divided by bytes per item, rounded up
            bytes_per_item = np.prod(tfr_shape) * np.dtype(
                self.dtype).itemsize
            if shuffle_mb > 0:
                dset = dset.shuffle((
                    (shuffle_mb << 20) - 1) // bytes_per_item + 1)
            if repeat:
                dset = dset.repeat()
            # Prefetch and batch
            if prefetch_mb > 0:
                dset = dset.prefetch((
                    (prefetch_mb << 20) - 1) // bytes_per_item + 1)
            dset = dset.batch(self._tf_minibatch_in)
            self._tf_datasets[tfr_lod] = dset
        # Initialize data iterator
        # The iterator structure is taken from the lod-0 dataset, so a lod-0 file must exist
        self._tf_iterator = tf.data.Iterator.from_structure(
            self._tf_datasets[0].output_types,
            self._tf_datasets[0].output_shapes)
        self._tf_init_ops = {lod: self._tf_iterator.make_initializer(dset) \
            for lod, dset in self._tf_datasets.items()}
def training_loop(
    # General configuration
    train                   = False,    # Training mode
    eval                    = False,    # Evaluation mode
    vis                     = False,    # Visualization mode
    run_dir                 = ".",      # Output directory
    num_gpus                = 1,        # Number of GPUs participating in the training
    rank                    = 0,        # Rank of the current process in [0, num_gpus)
    cG                      = {},       # Options for generator network
    cD                      = {},       # Options for discriminator network
    # Data
    dataset_args            = {},       # Options for training set
    drange_net              = [-1,1],   # Dynamic range used when feeding image data to the networks
    # Optimization
    loss_args               = {},       # Options for loss function
    total_kimg              = 25000,    # Total length of the training, measured in thousands of real images
    batch_size              = 4,        # Total batch size for one training iteration. Can be larger than batch_gpu * num_gpus
    batch_gpu               = 4,        # Number of samples processed at a time by one GPU
    ema_kimg                = 10.0,     # Half-life of the exponential moving average (EMA) of generator weights
    ema_rampup              = None,     # EMA ramp-up coefficient
    cudnn_benchmark         = True,     # Enable torch.backends.cudnn.benchmark?
    allow_tf32              = False,    # Enable torch.backends.cuda.matmul.allow_tf32 and torch.backends.cudnn.allow_tf32?
    # Logging
    resume_pkl              = None,     # Network pickle to resume training from
    resume_kimg             = 0.0,      # Assumed training progress at the beginning
                                        # Affects reporting and training schedule
    kimg_per_tick           = 8,        # Progress snapshot interval
    img_snapshot_ticks      = 3,        # How often to save image snapshots? None = disable
    network_snapshot_ticks  = 3,        # How often to save network snapshots? None = disable
    last_snapshots          = 10,       # Maximal number of prior snapshots to save
    printname               = "",       # Experiment name for logging
    # Evaluation
    vis_args                = {},       # Options for vis.vis
    metrics                 = [],       # Metrics to evaluate during training
    eval_images_num         = 50000,    # Sample size for the metrics
    truncation_psi          = 0.7       # Style strength multiplier for the truncation trick (used for visualizations only)
):
    """Run evaluation, visualization and/or training of the G/D networks.

    The function first loads the dataset and the networks, then optionally
    evaluates (``eval``) and visualizes (``vis``, rank 0 only). If ``train``
    is false it terminates the process via ``exit()``; otherwise it trains
    until ``total_kimg`` thousand images have been processed, performing
    maintenance (status line, image/network snapshots, metric evaluation,
    log updates) once per tick.

    NOTE(review): the dict/list defaults are shared between calls. This
    function only forwards them to helpers; whether those helpers mutate
    them is not visible here.
    """
    # Initialize
    start_time = time.time()
    device = init_cuda(rank, cudnn_benchmark, allow_tf32)
    log = (rank == 0)  # Only the rank-0 process prints and saves visualizations

    dataset, dataset_iter = load_dataset(dataset_args, batch_size, rank, num_gpus, log) # Load training set

    nets = construct_nets(cG, cD, dataset, device, log) if train else None        # Construct networks
    G, D, Gs = load_nets(resume_pkl, nets, device, log)                           # Resume from existing pickle
    print_nets(G, D, batch_gpu, device, log)                                      # Print network summary tables

    if eval:
        misc.log("Run evaluation...", log = log)
        evaluate(Gs, resume_pkl, metrics, eval_images_num, dataset_args, num_gpus, rank, device, log)

    if vis and log:
        misc.log("Produce visualizations...")
        visualize.vis(Gs, dataset, device, batch_gpu, drange_net = drange_net, ratio = dataset.ratio,
            truncation_psi = truncation_psi, **vis_args)

    if not train:
        exit()

    nets = distribute_nets(G, D, Gs, device, num_gpus, log)                                   # Distribute networks across GPUs
    loss, stages = setup_training_stages(loss_args, G, cG, D, cD, nets, device, log)          # Setup training stages (losses and optimizers)
    grid_size, grid_z, grid_c = init_img_grid(dataset, G.input_shape, device, run_dir, log)   # Initialize an image grid
    logger = init_logger(run_dir, log)                                                        # Initialize logs

    # Train
    misc.log(f"Training for {total_kimg} kimg...", "white", log)
    cur_nimg, cur_tick, batch_idx = int(resume_kimg * 1000), 0, 0
    tick_start_nimg, tick_start_time = cur_nimg, time.time()
    stats = None  # Stays None until the first tick has been collected

    while True:
        # Fetch training data
        real_img, real_c, gen_zs, gen_cs = fetch_data(dataset, dataset_iter, G.input_shape, drange_net,
            device, len(stages), batch_size, batch_gpu)

        # Execute training stages
        for stage, gen_z, gen_c in zip(stages, gen_zs, gen_cs):
            # Each stage runs only every `stage.interval` batches
            if batch_idx % stage.interval != 0:
                continue
            run_training_stage(loss, stage, device, real_img, real_c, gen_z, gen_c, batch_size, batch_gpu, num_gpus)

        # Update Gs
        update_ema_network(G, Gs, batch_size, cur_nimg, ema_kimg, ema_rampup)

        # Update state
        cur_nimg += batch_size
        batch_idx += 1

        # Perform maintenance tasks once per tick
        done = (cur_nimg >= total_kimg * 1000)
        if (not done) and (cur_tick != 0) and (cur_nimg < tick_start_nimg + kimg_per_tick * 1000):
            continue

        # Print status line and accumulate the info in logger.collector
        tick_end_time = time.time()
        if stats is not None:
            default = dnnlib.EasyDict({'mean': -1})  # Placeholder for statistics that are not available
            fields = []
            fields.append("tick " + misc.bold(f"{training_stats.report0('Progress/tick', cur_tick):<5d}"))
            fields.append("kimg " + misc.bcolored(f"{training_stats.report0('Progress/kimg', cur_nimg / 1e3):<8.1f}", "red"))
            fields.append("")
            fields.append("loss/reg: G (" + misc.bcolored(f"{stats.get('Loss/G/loss', default).mean:>6.3f}", "blue"))
            fields.append(misc.bold(f"{stats.get('Loss/G/reg', default).mean:>6.3f}") + ")")
            fields.append("D "+ misc.bcolored(f"({stats.get('Loss/D/loss', default).mean:>6.3f}", "blue"))
            fields.append(misc.bold(f"{stats.get('Loss/D/reg', default).mean:>6.3f}") + ")")
            fields.append("")
            fields.append("time " + misc.bold(f"{dnnlib.util.format_time(training_stats.report0('Timing/total_sec', tick_end_time - start_time)):<12s}"))
            fields.append(f"sec/kimg {training_stats.report0('Timing/sec_per_kimg', (tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg) * 1e3):<7.2f}")
            # The labels now match the reported quantities: process RSS is CPU memory,
            # torch.cuda.max_memory_allocated is GPU memory (they used to be swapped)
            fields.append(f"mem: CPU {training_stats.report0('Resources/cpu_mem_gb', psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}")
            fields.append(f"GPU {training_stats.report0('Resources/peak_gpu_mem_gb', torch.cuda.max_memory_allocated(device) / 2**30):<6.2f}")
            fields.append(misc.bold(printname))
            # Restart peak tracking so that the next tick reports its own peak
            torch.cuda.reset_peak_memory_stats()
            training_stats.report0('Timing/total_hours', (tick_end_time - start_time) / (60 * 60))
            training_stats.report0('Timing/total_days', (tick_end_time - start_time) / (24 * 60 * 60))
            misc.log(" ".join(fields), log = log)

        # Save image snapshot
        if log and (img_snapshot_ticks is not None) and (done or cur_tick % img_snapshot_ticks == 0):
            visualize.vis(Gs, dataset, device, batch_gpu, training = True,
                step = cur_nimg // 1000, grid_size = grid_size, latents = grid_z,
                labels = grid_c, drange_net = drange_net, ratio = dataset.ratio, **vis_args)

        # Save network snapshot
        if (network_snapshot_ticks is not None) and (done or cur_tick % network_snapshot_ticks == 0):
            snapshot_data, snapshot_pkl = save_nets(G, D, Gs, cur_nimg, dataset_args, run_dir, num_gpus > 1, last_snapshots, log)

            # Evaluate metrics
            evaluate(snapshot_data["Gs"], snapshot_pkl, metrics, eval_images_num,
                dataset_args, num_gpus, rank, device, log, logger, run_dir)
            del snapshot_data

        # Collect stats and update logs
        stats = collect_stats(logger, stages)
        update_logger(logger, stats, cur_nimg, start_time)

        cur_tick += 1
        tick_start_nimg, tick_start_time = cur_nimg, time.time()
        if done:
            break

    # Done
    misc.log("Done!", "blue")
def eval(G,
    dataset,                          # The dataset object for accessing the data
    batch_size,                       # Visualization batch size
    training            = False,      # Training mode
    latents             = None,       # Source latents to generate images from
    labels              = None,       # Source labels to generate images from (0 if no labels are used)
    # Model settings
    components_num      = 1,          # Number of components the model has
    drange_net          = [-1, 1],    # Model image output range
    attention           = False,      # Whereas the model produces attention maps (for visualization)
    # Visualization settings
    vis_types           = None,       # Visualization types to be created
    num                 = 100,        # Number of produced samples
    rich_num            = 5,          # Number of samples for which richer visualizations will be created
                                      # (requires more memory and disk space, and therefore rich_num < num)
    grid                = None,       # Whether to save the samples in one large grid files
                                      # or in separated files one per sample
    grid_size           = None,       # Grid proportions (w, h)
    step                = None,       # Step number to be used in visualization filenames
    verbose             = None,       # Verbose print progress messages (None = verbose unless training)
    # Visualization-specific settings
    alpha               = 0.3,        # Proportion for generated images and attention maps blends
    intrp_density       = 8,          # Number of samples in between two end points of an interpolation
    intrp_per_component = False,      # Whether to perform interpolation along particular latent components (True)
                                      # or all of them at once (False)
    noise_samples_num   = 100,        # Number of samples used to compute noise variation visualization
    section_size        = 100):       # Visualization section size (section_size <= num) for reducing memory footprint
    """Generate samples from G and save the requested visualizations under "eval/".

    Depending on the selected visualization types this saves image samples,
    latents, attention maps and blends, per-layer attention maps,
    interpolations, a noise variance map and style mixing tables.

    During training only {"imgs", "maps"} are produced and the number of
    samples is taken from `latents` (which must then be provided).
    """
    def pattern_of(subdir, step, suffix):
        # The "%06d" is left in place; presumably the save helpers fill it
        # with the sample index -- TODO confirm against misc.save_images_builder
        prefix = "" if step is None else "{}_".format(step)
        return "eval/{}/{}%06d.{}".format(subdir, prefix, suffix)

    # For time efficiency, during training save only image and map samples
    # rather than richer visualizations
    vis = vis_types
    if training:
        vis = {"imgs", "maps"}
        section_size = num = len(latents)
    elif vis is None:
        vis = {"imgs", "maps", "ltnts", "interpolations", "noise_var"}

    # Set default options
    # Save image samples in one grid file during training
    if grid is None:
        grid = training
    # Disable verbose during training. An unset (None) value now resolves to
    # "verbose unless training"; previously None stayed falsy so that the
    # default could never enable the progress messages
    verbose = (verbose is None or bool(verbose)) and not training
    # If grid size is provided, set number of visualized images accordingly
    if grid_size is not None:
        num = np.prod(grid_size)

    # build image functions
    save_images = misc.save_images_builder(drange_net, grid_size, grid, verbose)
    save_blends = misc.save_blends_builder(drange_net, grid_size, grid, verbose, alpha)

    # Fix the synthesis noise variables to sampled values
    noise_vars = [var for name, var in G.subnets.synthesis.vars.items() if name.startswith("noise")]
    noise_var_vals = {var: np.random.randn(*var.shape.as_list()) for var in noise_vars}
    tflib.set_vars(noise_var_vals)

    # Create directories
    dirs = []
    if "imgs" in vis:
        dirs += ["images"]
    if "ltnts" in vis:
        dirs += ["latents-z", "latents-w"]
    if "maps" in vis:
        dirs += ["maps", "softmaps", "blends", "softblends"]
    if "layer_maps" in vis:
        dirs += ["layer_maps"]
    if "interpolations" in vis:
        # Same names as the per-sample directories created further below
        dirs += ["interpolations-z", "interpolations-w"]
    for subdir in dirs:
        misc.mkdir(dnnlib.make_run_dir_path("eval/{}".format(subdir)))

    # Produce visualizations
    for idx in range(0, num, section_size):
        curr_size = curr_batch_size(num, idx, section_size)
        if verbose and num > curr_size:
            print("--- Batch {}/{}".format(idx + 1, num))

        # Compute source latents images will be produced from
        # NOTE(review): latents/labels are sampled only on the first section and then
        # reused by the following ones -- confirm this is intended when section_size < num
        if latents is None:
            latents = np.random.randn(curr_size, *G.input_shape[1:])
        if labels is None:
            labels = dataset.get_minibatch_np(curr_size)

        # Run network over latents and produce images and attention maps
        if verbose:
            print("Running network...")
        images, attmaps_all_layers, wlatents_all_layers = G.run(latents, labels, randomize_noise=False,
            minibatch_size=batch_size, return_dlatents=True)  # is_visualization = True

        # For memory efficiency, save full information only for a small amount of images
        attmaps_all_layers = attmaps_all_layers[:rich_num]
        wlatents = wlatents_all_layers[:, :, 0]

        # Save image samples
        if "imgs" in vis:
            if verbose:
                print("Saving image samples...")
            save_images(images, pattern_of("images", step, "png"), idx)

        # Save latent vectors
        if "ltnts" in vis:
            if verbose:
                print("Saving latents...")
            misc.save_npys(latents, pattern_of("latents-z", step, "npy"), idx)
            misc.save_npys(wlatents, pattern_of("latents-w", step, "npy"), idx)

        # For the GANsformer model, save attention maps
        if attention:
            # One color per component, shared by the "maps" and "layer_maps" visualizations
            # (it used to be defined only when "maps" was requested)
            pallete = np.expand_dims(misc.get_colors(components_num), axis=(2, 3))

            if "maps" in vis:
                soft_maps = attmaps_all_layers[:, :, -1, 0]
                # Hard assignment: 1 for the component with the maximal value along axis 1
                maps = (soft_maps == np.amax(soft_maps, axis=1, keepdims=True)).astype(float)

                soft_maps = np.sum(pallete * np.expand_dims(soft_maps, axis=2), axis=1)
                maps = np.sum(pallete * np.expand_dims(maps, axis=2), axis=1)

                if verbose:
                    print("Saving maps...")
                save_images(soft_maps, pattern_of("softmaps", step, "png"), idx)
                save_images(maps, pattern_of("maps", step, "png"), idx)
                # Soft maps go to "softblends" and hard maps to "blends",
                # matching the "softmaps"/"maps" naming above (they used to be swapped)
                save_blends(soft_maps, images, pattern_of("softblends", step, "png"), idx)
                save_blends(maps, images, pattern_of("blends", step, "png"), idx)

            # Save maps from all attention heads and layers
            # (for efficiency, only for a small number of images)
            if "layer_maps" in vis:
                all_maps = []
                layers_num, heads_num = attmaps_all_layers.shape[2], attmaps_all_layers.shape[3]
                maps_fakes = np.split(attmaps_all_layers, layers_num, axis=2)
                for layer, lmaps in enumerate(maps_fakes):
                    # After squeezing the layer axis, axis 2 is the heads axis
                    lmaps = np.split(np.squeeze(lmaps, axis=2), heads_num, axis=2)
                    for head, hmap in enumerate(lmaps):
                        hmap = (hmap == np.amax(hmap, axis=1, keepdims=True)).astype(float)
                        hmap = np.sum(pallete * hmap, axis=1)
                        all_maps.append((hmap, "l{}_h{}".format(layer, head)))

                if verbose:
                    print("Saving layer maps...")
                for i in trange(rich_num):
                    misc.mkdir(dnnlib.make_run_dir_path("eval/layer_maps/%06d" % i))
                # Files are placed in the per-sample directories created above,
                # using the same optional step prefix as pattern_of
                step_prefix = "" if step is None else "{}_".format(step)
                for maps, name in tqdm(all_maps):
                    dirname = "eval/layer_maps/%06d/{}{}.png".format(step_prefix, name)
                    save_images(maps, dirname, idx)

    # Produce interpolations between pairs or source latents
    # In the GANsformer case, varying one component at a time
    if "interpolations" in vis:
        ts = np.array(np.linspace(0.0, 1.0, num=intrp_density, endpoint=True))

        def update(t, fn, ts, dim):
            # Interpolate from t[0] to t[1] using fn, either for all components
            # at once or for one component at a time
            if dim == 3:
                ts = ts[:, np.newaxis]
            if not intrp_per_component:
                return fn(t[0], t[1], ts)
            t_ups = []
            for c in range(components_num):
                # copy over all the components except component c that will get interpolated
                t_up = np.tile(np.copy(t[0])[None], [intrp_density] + [1] * dim)
                # interpolate component c
                t_up[:, c] = fn(t[0, c], t[1, c], ts)
                t_ups.append(t_up)
            return np.concatenate(t_ups, axis=0)

        def save_interpolation(imgs, name, sample_idx):
            # Save the last frame as png and the whole sequence as gif, one pair per component.
            # The directory name uses a dash, matching the directories created for each sample
            imgs = np.split(imgs, components_num, axis=0)
            for c in range(components_num):
                filename = "eval/interpolations-%s/%06d/%02d" % (name, sample_idx, c)
                imgs[c] = [misc.to_pil(img, drange=drange_net) for img in imgs[c]]
                imgs[c][-1].save(dnnlib.make_run_dir_path("{}.png".format(filename)))
                misc.save_gif(imgs[c], dnnlib.make_run_dir_path("{}.gif".format(filename)))

        if verbose:
            print("Generating interpolations...")
        for i in trange(rich_num):
            misc.mkdir(dnnlib.make_run_dir_path("eval/interpolations-z/%06d" % i))
            misc.mkdir(dnnlib.make_run_dir_path("eval/interpolations-w/%06d" % i))

            # Interpolate from the i-th source latent towards a random one
            z = np.random.randn(2, *G.input_shape[1:])
            z[0] = latents[i:i + 1]
            w = G.run(z, labels, randomize_noise=False, return_dlatents=True,
                minibatch_size=batch_size)[-1]

            z_up = update(z, slerp, ts, 2)
            w_up = update(w, lerp, ts, 3)

            imgs1 = G.run(z_up, labels, randomize_noise=False, minibatch_size=batch_size)[0]
            # NOTE(review): this call passes take_wlatents while the style mixing below
            # passes take_dlatents -- verify which keyword G.run accepts
            imgs2 = G.run(w_up, labels, randomize_noise=False, minibatch_size=batch_size,
                take_wlatents=True)[0]

            save_interpolation(imgs1, "z", i)
            save_interpolation(imgs2, "w", i)

    # Compute noise variance map
    # Shows what areas vary the most given fixed source
    # latents due to the use of stochastic local noise
    if "noise_var" in vis:
        if verbose:
            print("Generating noise variance...")
        z = np.tile(np.random.randn(1, *G.input_shape[1:]), [noise_samples_num, 1, 1])
        imgs = G.run(z, labels, minibatch_size=batch_size)[0]
        imgs = np.stack([misc.to_pil(img, drange=drange_net) for img in imgs], axis=0)
        diff = np.std(np.mean(imgs, axis=3), axis=0) * 4
        diff = np.clip(diff + 0.5, 0, 255).astype(np.uint8)
        PIL.Image.fromarray(diff, "L").save(dnnlib.make_run_dir_path("eval/noise_variance.png"))

    # Compute style mixing table, varying using the latent A in some of the layers and latent B in rest.
    # For the GANsformer, also produce component mixes (using latents from A in some of the components,
    # and latents from B in the rest.
    if "style_mix" in vis:
        if verbose:
            print("Generating style mixes...")
        cols, rows = 4, 2
        row_lens = np.array([2, 5, 8, 11])

        # Create latent mixes
        mixes = {
            "layer": (np.arange(wlatents_all_layers.shape[2]) < row_lens[:, None]).astype(np.float32)[:, None, None, None, :, None],
            "component": (np.arange(wlatents_all_layers.shape[1]) < row_lens[:, None]).astype(np.float32)[:, None, None, :, None, None]
        }
        orig_imgs = images[:cols + rows]
        col_ltnts = wlatents_all_layers[:cols][None, None]
        row_ltnts = wlatents_all_layers[cols:cols + rows][None, :, None]

        for name, mix in mixes.items():
            # Produce image mixes
            mix_ltnts = mix * row_ltnts + (1 - mix) * col_ltnts
            mix_ltnts = np.reshape(mix_ltnts, [-1, *wlatents_all_layers.shape[1:]])
            mix_imgs = G.run(mix_ltnts, labels, randomize_noise=False, take_dlatents=True,
                minibatch_size=batch_size)[0]
            mix_imgs = np.reshape(mix_imgs, [len(row_lens) * rows, cols, *mix_imgs.shape[1:]])

            # Create image table canvas
            H, W = mix_imgs.shape[-2:]
            canvas = PIL.Image.new("RGB", (W * (cols + 1), H * (len(row_lens) * rows + 1)), "black")

            # Place image mixes respectively at each position (row_idx, col_idx)
            # The first row and the first column hold the original images
            for row_idx, row_elem in enumerate([None] + list(range(len(row_lens) * rows))):
                for col_idx, col_elem in enumerate([None] + list(range(cols))):
                    if (row_elem, col_elem) == (None, None):
                        continue
                    if row_elem is None:
                        img = orig_imgs[col_elem]
                    elif col_elem is None:
                        img = orig_imgs[cols + (row_elem % rows)]
                    else:
                        img = mix_imgs[row_elem, col_elem]
                    canvas.paste(misc.to_pil(img, drange=drange_net), (W * col_idx, H * row_idx))

            canvas.save(dnnlib.make_run_dir_path("eval/{}_mixing.png".format(name)))

    if verbose:
        print(misc.bcolored("Visualizations Completed!", "blue"))
def training_loop(
    # Configurations
    cG={}, cD={},  # Generator and Discriminator command-line arguments
    dataset_args={},  # dataset.load_dataset() options
    sched_args={},  # train.TrainingSchedule options
    vis_args={},  # vis.eval options
    grid_args={},  # train.setup_snapshot_img_grid() options
    metric_arg_list=[],  # MetricGroup Options
    tf_config={},  # tflib.init_tf() options
    eval=False,  # Evaluation mode
    train=False,  # Training mode
    # Data
    data_dir=None,  # Directory to load datasets from
    total_kimg=25000,  # Total length of the training, measured in thousands of real images
    mirror_augment=False,  # Enable mirror augmentation?
    drange_net=[
        -1, 1
    ],  # Dynamic range used when feeding image data to the networks
    ratio=1.0,  # Image height/width ratio in the dataset
    # Optimization
    minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters
    lazy_regularization=True,  # Perform regularization as a separate training step?
    smoothing_kimg=10.0,  # Half-life of the running average of generator weights
    clip=None,  # Clip gradients threshold
    # Resumption
    resume_pkl=None,  # Network pickle to resume training from, None = train from scratch.
    resume_kimg=0.0,  # Assumed training progress at the beginning
    # Affects reporting and training schedule
    resume_time=0.0,  # Assumed wallclock time at the beginning, affects reporting
    recompile=False,  # Recompile network from source code (otherwise loads from snapshot)
    # Logging
    summarize=True,  # Create TensorBoard summaries
    save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms=False,  # Include weight histograms in the tfevents file?
    img_snapshot_ticks=3,  # How often to save image snapshots? None = disable
    network_snapshot_ticks=3,  # How often to save network snapshots? None = only save network-final.pkl
    last_snapshots=10,  # Maximal number of prior snapshots to save
    eval_images_num=50000,  # Sample size for the metrics
    printname="",  # Experiment name for logging
    # Architecture
    merge=False):  # Generate several images and then merge them
    """Build the TensorFlow training graph and evaluate and/or train G and D.

    The networks are constructed (or loaded from `resume_pkl`), a grid of real
    images is saved as "reals.png", and then:
      * if `eval` is set, a network snapshot is saved and both the metrics and
        the visualizations are run on it;
      * if `train` is not set, the dataset is closed and the process exits;
      * otherwise the per-GPU graph is built and training runs until
        `total_kimg` thousand images were processed (or the run context asks
        to stop), with progress reports and snapshots once per tick. Finally
        "network-final.pkl" is saved.

    `cG` and `cD` are accessed and assigned through attributes (e.g. `.name`,
    `.args`, `.loss_args`, `.opt`), so plain dicts such as the defaults would
    not work here -- presumably callers pass EasyDict-like objects.
    `merge` is not referenced in this function body.
    """
    # Initialize dnnlib and TensorFlow
    tflib.init_tf(tf_config)
    num_gpus = dnnlib.submit_config.num_gpus
    cG.name, cD.name = "g", "d"

    # Load dataset, configure training scheduler and metrics object
    dataset = data.load_dataset(data_dir=dnnlib.convert_path(data_dir),
                                verbose=True,
                                **dataset_args)
    # The initial schedule is computed at the end-of-training image count; it
    # is recomputed with the actual progress at each training loop iteration
    sched = training_schedule(sched_args,
                              cur_nimg=total_kimg * 1000,
                              dataset=dataset)
    metrics = metric_base.MetricGroup(metric_arg_list)

    # Construct or load networks
    with tf.device("/gpu:0"):
        no_op = tf.no_op()
        G, D, Gs = None, None, None
        if resume_pkl is None or recompile:
            misc.log("Constructing networks...", "white")
            G = tflib.Network("G",
                              num_channels=dataset.shape[0],
                              resolution=dataset.shape[1],
                              label_size=dataset.label_size,
                              **cG.args)
            D = tflib.Network("D",
                              num_channels=dataset.shape[0],
                              resolution=dataset.shape[1],
                              label_size=dataset.label_size,
                              **cD.args)
            Gs = G.clone("Gs")
        if resume_pkl is not None:
            G, D, Gs = load_nets(resume_pkl, G, D, Gs, recompile)

    G.print_layers()
    D.print_layers()

    # Train/Evaluate/Visualize
    # Labels are optional but not essential
    grid_size, grid_reals, grid_labels = misc.setup_snapshot_img_grid(
        dataset, **grid_args)
    misc.save_img_grid(grid_reals,
                       dnnlib.make_run_dir_path("reals.png"),
                       drange=dataset.dynamic_range,
                       grid_size=grid_size)
    # Fixed latents, one per grid cell, reused for every image snapshot
    grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:])

    if eval:
        # Save a snapshot of the current network to evaluate
        pkl = dnnlib.make_run_dir_path("network-eval-snapshot-%06d.pkl" %
                                       resume_kimg)
        misc.save_pkl((G, D, Gs), pkl, remove=False)

        # Quantitative evaluation
        # (the returned value is not used further in this function)
        metric = metrics.run(pkl,
                             num_imgs=eval_images_num,
                             run_dir=dnnlib.make_run_dir_path(),
                             data_dir=dnnlib.convert_path(data_dir),
                             num_gpus=num_gpus,
                             ratio=ratio,
                             tf_config=tf_config,
                             mirror_augment=mirror_augment)

        # Qualitative evaluation
        # NOTE(review): visualizes G rather than the averaged Gs -- confirm intended
        visualize.eval(G,
                       dataset,
                       batch_size=sched.minibatch_gpu,
                       drange_net=drange_net,
                       ratio=ratio,
                       **vis_args)

    if not train:
        # Nothing more to do: release the dataset and terminate the process
        dataset.close()
        exit()

    # Setup training inputs
    misc.log("Building TensorFlow graph...", "white")
    with tf.name_scope("Inputs"), tf.device("/cpu:0"):
        lrate_in_g = tf.placeholder(tf.float32, name="lrate_in_g", shape=[])
        lrate_in_d = tf.placeholder(tf.float32, name="lrate_in_d", shape=[])
        step = tf.placeholder(tf.int32, name="step", shape=[])
        minibatch_size_in = tf.placeholder(tf.int32,
                                           name="minibatch_size_in",
                                           shape=[])
        minibatch_gpu_in = tf.placeholder(tf.int32,
                                          name="minibatch_gpu_in",
                                          shape=[])
        # Number of accumulation rounds needed to cover one full minibatch
        minibatch_multiplier = minibatch_size_in // (minibatch_gpu_in *
                                                     num_gpus)
        # Decay of the generator weights average: 0.5 ** (minibatch_size / (smoothing_kimg * 1000)),
        # or 0.0 when smoothing is disabled
        beta = 0.5**tf.div(tf.cast(minibatch_size_in, tf.float32),
                           smoothing_kimg *
                           1000.0) if smoothing_kimg > 0.0 else 0.0

    # Set optimizers
    for cN, lr in [(cG, lrate_in_g), (cD, lrate_in_d)]:
        set_optimizer(cN, lr, minibatch_multiplier, lazy_regularization, clip)

    # Build training graph for each GPU
    data_fetch_ops = []
    for gpu in range(num_gpus):
        with tf.name_scope("GPU%d" % gpu), tf.device("/gpu:%d" % gpu):
            # Create GPU-specific shadow copies of G and D
            for cN, N in [(cG, G), (cD, D)]:
                cN.gpu = N if gpu == 0 else N.clone(N.name + "_shadow")
            # Gs_gpu is not referenced again in this function body
            Gs_gpu = Gs if gpu == 0 else Gs.clone(Gs.name + "_shadow")

            # Fetch training data via temporary variables
            with tf.name_scope("DataFetch"):
                reals, labels = dataset.get_minibatch_tf()
                reals = process_reals(reals, dataset.dynamic_range,
                                      drange_net, mirror_augment)
                reals, reals_fetch = read_data(
                    reals, "reals", [sched.minibatch_gpu] + dataset.shape,
                    minibatch_gpu_in)
                labels, labels_fetch = read_data(
                    labels, "labels",
                    [sched.minibatch_gpu, dataset.label_size],
                    minibatch_gpu_in)
                data_fetch_ops += [reals_fetch, labels_fetch]

            # Evaluate loss functions
            with tf.name_scope("G_loss"):
                cG.loss, cG.reg = dnnlib.util.call_func_by_name(
                    G=cG.gpu,
                    D=cD.gpu,
                    dataset=dataset,
                    reals=reals,
                    minibatch_size=minibatch_gpu_in,
                    **cG.loss_args)

            with tf.name_scope("D_loss"):
                cD.loss, cD.reg = dnnlib.util.call_func_by_name(
                    G=cG.gpu,
                    D=cD.gpu,
                    dataset=dataset,
                    reals=reals,
                    labels=labels,
                    minibatch_size=minibatch_gpu_in,
                    **cD.loss_args)

            for cN in [cG, cD]:
                set_optimizer_ops(cN, lazy_regularization, no_op)

    # Setup training ops
    data_fetch_op = tf.group(*data_fetch_ops)
    for cN in [cG, cD]:
        cN.train_op = cN.opt.apply_updates()
        cN.reg_op = cN.reg_opt.apply_updates(allow_no_op=True)
    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=beta)

    # Finalize graph
    with tf.device("/gpu:0"):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            # Fall back to reporting 0 when the memory statistics op is not found
            peak_gpu_mem_op = tf.constant(0)
    tflib.init_uninitialized_vars()

    # Tensorboard summaries
    if summarize:
        misc.log("Initializing logs...", "white")
        summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path())
        if save_tf_graph:
            summary_log.add_graph(tf.get_default_graph())
        if save_weight_histograms:
            G.setup_weight_histograms()
            D.setup_weight_histograms()

    # Initialize training
    misc.log("Training for %d kimg..." % total_kimg, "white")
    dnnlib.RunContext.get().update("",
                                   cur_epoch=resume_kimg,
                                   max_epoch=total_kimg)
    # maintenance_time is assigned here and at the end of each tick, but never read in this function
    maintenance_time = dnnlib.RunContext.get().get_last_update_interval()

    # cur_tick starts at -1 so that the first loop iteration already performs a tick
    cur_tick, running_mb_counter = -1, 0
    cur_nimg = int(resume_kimg * 1000)
    tick_start_nimg = cur_nimg
    for cN in [cG, cD]:
        # Running averages of the loss statistics (None until the first value arrives)
        cN.lossvals_agg = {
            k: None
            for k in ["loss", "reg", "norm", "reg_norm"]
        }
        cN.opt.reset_optimizer_state()

    # Training loop
    while cur_nimg < total_kimg * 1000:
        if dnnlib.RunContext.get().should_stop():
            break

        # Choose training parameters and configure training ops
        sched = training_schedule(sched_args,
                                  cur_nimg=cur_nimg,
                                  dataset=dataset)
        assert sched.minibatch_size % (sched.minibatch_gpu * num_gpus) == 0
        dataset.configure(sched.minibatch_gpu)

        # Run training ops
        feed_dict = {
            lrate_in_g: sched.G_lrate,
            lrate_in_d: sched.D_lrate,
            minibatch_size_in: sched.minibatch_size,
            minibatch_gpu_in: sched.minibatch_gpu,
            step: sched.kimg
        }

        # Several iterations before updating training parameters
        for _repeat in range(minibatch_repeats):
            # One round per (minibatch_gpu * num_gpus) samples of the minibatch
            rounds = range(0, sched.minibatch_size,
                           sched.minibatch_gpu * num_gpus)
            for cN in [cG, cD]:
                # With lazy regularization, the regularization op runs only every reg_interval minibatches
                cN.run_reg = lazy_regularization and (
                    running_mb_counter % cN.reg_interval == 0)
            cur_nimg += sched.minibatch_size
            running_mb_counter += 1

            for cN in [cG, cD]:
                cN.lossvals = {
                    k: None
                    for k in ["loss", "reg", "norm", "reg_norm"]
                }

            # Gradient accumulation
            for _round in rounds:
                cG.lossvals.update(
                    tflib.run([cG.train_op, cG.ops], feed_dict)[1])
                if cG.run_reg:
                    _, cG.lossvals["reg_norm"] = tflib.run(
                        [cG.reg_op, cG.reg_norm], feed_dict)

                tflib.run(data_fetch_op, feed_dict)

                cD.lossvals.update(
                    tflib.run([cD.train_op, cD.ops], feed_dict)[1])
                if cD.run_reg:
                    _, cD.lossvals["reg_norm"] = tflib.run(
                        [cD.reg_op, cD.reg_norm], feed_dict)

            tflib.run([Gs_update_op], feed_dict)

            # Track loss statistics
            for cN in [cG, cD]:
                for k in cN.lossvals_agg:
                    cN.lossvals_agg[k] = emaAvg(cN.lossvals_agg[k],
                                                cN.lossvals[k])

        # Perform maintenance tasks once per tick
        done = (cur_nimg >= total_kimg * 1000)
        if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = dnnlib.RunContext.get().get_time_since_last_update()
            total_time = dnnlib.RunContext.get().get_time_since_start(
            ) + resume_time

            # Report progress
            # (loss values that are still None are printed as 0)
            print(
                ("tick %s kimg %s loss/reg: G (%s %s) D (%s %s) grad norms: G (%s %s) D (%s %s) "
                 + "time %s sec/kimg %s maxGPU %sGB %s") %
                (misc.bold("%-5d" % autosummary("Progress/tick", cur_tick)),
                 misc.bcolored(
                     "{:>8.1f}".format(
                         autosummary("Progress/kimg", cur_nimg / 1000.0)),
                     "red"),
                 misc.bcolored("{:>6.3f}".format(cG.lossvals_agg["loss"] or 0),
                               "blue"),
                 misc.bold("{:>6.3f}".format(cG.lossvals_agg["reg"] or 0)),
                 misc.bcolored("{:>6.3f}".format(cD.lossvals_agg["loss"] or 0),
                               "blue"),
                 misc.bold("{:>6.3f}".format(cD.lossvals_agg["reg"] or 0)),
                 misc.cond_bcolored(cG.lossvals_agg["norm"], 20.0, "red"),
                 misc.cond_bcolored(cG.lossvals_agg["reg_norm"], 20.0, "red"),
                 misc.cond_bcolored(cD.lossvals_agg["norm"], 20.0, "red"),
                 misc.cond_bcolored(cD.lossvals_agg["reg_norm"], 20.0, "red"),
                 misc.bold("%-10s" % dnnlib.util.format_time(
                     autosummary("Timing/total_sec", total_time))),
                 "{:>7.2f}".format(
                     autosummary("Timing/sec_per_kimg",
                                 tick_time / tick_kimg)), "{:>4.1f}".format(
                                     autosummary("Resources/peak_gpu_mem_gb",
                                                 peak_gpu_mem_op.eval() /
                                                 2**30)), printname))

            autosummary("Timing/total_hours", total_time / (60.0 * 60.0))
            autosummary("Timing/total_days",
                        total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots
            if img_snapshot_ticks is not None and (
                    cur_tick % img_snapshot_ticks == 0 or done):
                visualize.eval(G,
                               dataset,
                               batch_size=sched.minibatch_gpu,
                               training=True,
                               step=cur_nimg // 1000,
                               grid_size=grid_size,
                               latents=grid_latents,
                               labels=grid_labels,
                               drange_net=drange_net,
                               ratio=ratio,
                               **vis_args)

            if network_snapshot_ticks is not None and (
                    cur_tick % network_snapshot_ticks == 0 or done):
                pkl = dnnlib.make_run_dir_path("network-snapshot-%06d.pkl" %
                                               (cur_nimg // 1000))
                misc.save_pkl((G, D, Gs), pkl, remove=False)

                # This repeats the enclosing condition, so the metrics run for every saved snapshot
                if cur_tick % network_snapshot_ticks == 0 or done:
                    metric = metrics.run(
                        pkl,
                        num_imgs=eval_images_num,
                        run_dir=dnnlib.make_run_dir_path(),
                        data_dir=dnnlib.convert_path(data_dir),
                        num_gpus=num_gpus,
                        ratio=ratio,
                        tf_config=tf_config,
                        mirror_augment=mirror_augment)

                # Keep only the last `last_snapshots` files (in sorted name order) matching "network*.pkl"
                if last_snapshots > 0:
                    misc.rm(
                        sorted(
                            glob.glob(dnnlib.make_run_dir_path(
                                "network*.pkl")))[:-last_snapshots])

            # Update summaries and RunContext
            if summarize:
                metrics.update_autosummaries()
                tflib.autosummary.save_summaries(summary_log, cur_nimg)

            dnnlib.RunContext.get().update(None,
                                           cur_epoch=cur_nimg // 1000,
                                           max_epoch=total_kimg)
            maintenance_time = dnnlib.RunContext.get(
            ).get_last_update_interval() - tick_time

    # Save final snapshot
    misc.save_pkl((G, D, Gs),
                  dnnlib.make_run_dir_path("network-final.pkl"),
                  remove=False)

    # All done
    if summarize:
        summary_log.close()
    dataset.close()