Example #1
0
def prepare(tasks,
            data_dir,
            shards_num=1,
            max_images=None,
            ratio=1.0,
            images_dir=None,
            format=None):  # Options for custom dataset
    """Download, unpack and pre-process every dataset listed in ``tasks``.

    Each task is looked up in ``catalog``. A task that is not in the catalog
    is treated as a custom local dataset described by ``images_dir``,
    ``ratio`` and ``format``.

    Args:
        tasks: Iterable of dataset names.
        data_dir: Root directory; each task gets ``data_dir/<task>``.
        shards_num: Shard count passed to the processing function when
            ``max_images`` is given, or when the task defines no ``shards``.
        max_images: Optional cap forwarded to the processing function as
            ``max_imgs``.
        ratio: Ratio stored in the custom-task configuration.
        images_dir: Image directory of a custom (non-catalog) task.
        format: Key into ``formats_catalog`` selecting the processing
            function of a custom task.
    """
    mkdir(data_dir)
    for task in tasks:
        # If task not in catalog, create custom task configuration.
        # The fallback is a plain dict, so every field below is read with
        # item access, which works for the fallback as well as for dict-based
        # catalog entries (attribute access would fail on the plain dict).
        # NOTE(review): assumes catalog entries support c[...] and c.get(...)
        # in addition to attribute access - confirm against the catalog type.
        c = catalog.get(
            task, {
                "local": True,
                "name": task,
                "dir": images_dir,
                "ratio": ratio,
                "process": formats_catalog.get(format)
            })
        is_local = "local" in c

        dirname = "{}/{}".format(data_dir, task)
        mkdir(dirname)

        print(misc.bold("Preparing the {} dataset...".format(c["name"])))

        if not is_local:
            # Only catalog entries carry filename/md5/url/size, so these
            # fields are touched exclusively on the download path.
            fname = "{}/{}".format(dirname, c["filename"])
            download = not (os.path.exists(fname)
                            and verify_md5(fname, c["md5"]))
            path = get_path(c["url"], dirname, path=c["filename"])

            if download:
                print(
                    misc.bold("Downloading the data ({} GB)...".format(
                        c["size"])))
                download_file(c["url"], path)

            if path.endswith(".zip") and not is_unzipped(path, dirname):
                print(misc.bold("Unzipping {}...".format(path)))
                unzip(path, dirname)

        if "process" in c:
            imgdir = images_dir if is_local else "{}/{}".format(
                dirname, c["dir"])
            # Use the task's own shard count when no image cap is given;
            # custom tasks define none, so fall back to the caller's value.
            # A separate local keeps the shards_num argument intact across
            # loop iterations.
            if max_images is None:
                task_shards = c.get("shards", shards_num)
            else:
                task_shards = shards_num
            c["process"](dirname,
                         imgdir,
                         ratio=c["ratio"],
                         shards_num=task_shards,
                         max_imgs=max_images)

        print(
            misc.bcolored("Completed preparations for {}!".format(c["name"]),
                          "blue"))
Example #2
0
def task(task, name, size, dir, redownload, download = None, prepare = lambda: None):
    """Optionally download and then prepare a single dataset.

    Nothing happens when ``task`` is falsy.

    Args:
        task: Dataset identifier; also the sub-path checked under ``dir``.
        name: Human-readable dataset name used in the progress messages.
        size: Download size in GB, used in the progress message only.
        dir: Directory expected to contain ``task`` once downloaded.
        redownload: Force the download even if ``dir/task`` already exists.
        download: Zero-argument callable performing the download, or None
            to skip downloading altogether.
        prepare: Zero-argument callable run after the optional download.
    """
    if task:
        print(misc.bcolored("Preparing the {} dataset...".format(name), "blue"))
        if download is not None and (redownload or not path.exists("{}/{}".format(dir, task))):
            # A colour argument belongs to bcolored; bold is called with the
            # text only everywhere else.
            print(misc.bcolored("Downloading the data ({} GB)...".format(size), "blue"))
            download()
            # misc is only ever used through its helpers; calling misc(...)
            # directly would fail, so format the message with bold instead.
            print(misc.bold("Completed downloading {}".format(name)))

        prepare()
        print(misc.bold("Completed preparations for {}!".format(name)))
    except:
Example #3
0
def load_nets(resume_pkl, lG, lD, lGs, recompile):
    """Load the (G, D, Gs) networks stored in ``resume_pkl``.

    When ``recompile`` is set, the variables of the loaded networks are copied
    into the already constructed ``lG``, ``lD`` and ``lGs``, which are then
    returned. Otherwise the loaded networks themselves are returned and the
    passed-in networks are ignored.
    """
    print(misc.bcolored("Loading networks from %s..." % resume_pkl, "white"))
    # Only the first three pickled entries are used
    restored = misc.load_pkl(resume_pkl)
    rG, rD, rGs = restored[:3]

    if not recompile:
        return rG, rD, rGs

    print(misc.bold("Copying nets..."))
    for target, source in ((lG, rG), (lD, rD), (lGs, rGs)):
        target.copy_vars_from(source)
    return lG, lD, lGs
Example #4
0
def verify_md5(filename, md5):
    """Return True if the MD5 hex digest of ``filename`` equals ``md5``.

    Args:
        filename: Path of the file to hash.
        md5: Expected hex digest string.

    Returns:
        bool: Whether the computed digest equals ``md5``. The outcome is also
        printed.
    """
    print(f"Verify MD5 for {filename}...")
    digest = hashlib.md5()
    with open(filename, "rb") as f:
        # Hash in 1 MiB chunks so large archives are never loaded into
        # memory in one piece.
        for chunk in iter(lambda: f.read(1 << 20), b""):
            digest.update(chunk)
    result = md5 == digest.hexdigest()
    if result:
        print(misc.bold("MD5 matches!"))
    else:
        print("MD5 doesn't match. Will redownload the file.")
    return result
Example #5
0
def prepare(tasks, data_dir, max_images = None, 
        ratio = 1.0, images_dir = None, format = None): # Options for custom dataset
    """Fetch, unpack and process each dataset named in ``tasks``.

    Every task gets its own directory ``data_dir/<task>``. Tasks found in
    ``catalog`` are downloaded (unless an MD5-verified copy already exists)
    and unzipped when the archive is a ``.zip``. Tasks missing from the
    catalog are treated as custom local datasets built from ``images_dir``,
    ``ratio`` and ``format``. If the configuration has a ``process`` entry
    it is invoked on the task directory.
    """
    os.makedirs(data_dir, exist_ok = True)

    for task in tasks:
        # Configuration used when the task is not part of the catalog
        custom_cfg = EasyDict({
            "local": True,
            "name": task,
            "dir": images_dir,
            "ratio": ratio,
            "process": formats_catalog.get(format)
        })
        cfg = catalog.get(task, custom_cfg)

        task_dir = f"{data_dir}/{task}"
        os.makedirs(task_dir, exist_ok = True)

        print(misc.bold(f"Preparing the {cfg.name} dataset..."))

        if "local" not in cfg:
            archive_file = f"{task_dir}/{cfg.filename}"
            # Download unless a copy exists and passes the MD5 check
            needs_download = (not os.path.exists(archive_file)
                or not verify_md5(archive_file, cfg.md5))
            archive_path = get_path(cfg.url, task_dir, path = cfg.filename)

            if needs_download:
                print(misc.bold(f"Downloading the data ({cfg.size} GB)..."))
                download_file(cfg.url, archive_path)

            if archive_path.endswith(".zip") and not is_unzipped(archive_path, task_dir):
                print(misc.bold(f"Unzipping {archive_path}..."))
                unzip(archive_path, task_dir)

        if "process" in cfg:
            # Local datasets read from images_dir, catalog ones from their own sub-directory
            if "local" in cfg:
                imgdir = images_dir
            else:
                imgdir = f"{task_dir}/{cfg.dir}"
            cfg.process(task_dir, imgdir, ratio = cfg.ratio, max_imgs = max_images)

        print(misc.bcolored(f"Completed preparations for {cfg.name}!", "blue"))
Example #6
0
def run(**args):
    """Build the full training configuration from keyword options and submit the run.

    The keyword arguments are wrapped in an EasyDict and distributed over the
    configuration objects for the training loop, schedule, visualization,
    snapshot grid, generator, discriminator and dataset. The function then
    resolves the experiment directory and the snapshot to resume from, and
    finally hands everything to ``dnnlib.submit_run``.

    Values are copied into the configuration objects through ``cset``, a
    helper defined elsewhere; see its definition for the exact copy rules.

    The process is terminated through ``exit()`` when ``latent_size`` rounds
    down to zero, or when ``components_num > 1`` without ``transformer`` or
    ``kgan`` being set.
    """
    args = EasyDict(args)
    train = EasyDict(run_func_name="training.training_loop.training_loop"
                     )  # training loop options
    sched = EasyDict()  # TrainingSchedule options
    vis = EasyDict()  # visualize.eval() options
    grid = EasyDict(size="1080p",
                    layout="random")  # setup_snapshot_img_grid() options
    sc = dnnlib.SubmitConfig()  # dnnlib.submit_run() options

    # Environment configuration
    tf_config = {
        "rnd.np_random_seed": 1000,
        "allow_soft_placement": True,
        "gpu_options.per_process_gpu_memory_fraction": 1.0
    }
    # NOTE(review): num_gpus is only assigned when args.gpus is non-empty; an
    # empty string reaches the assert below with the name unbound. The
    # intended fallback is not visible here - confirm and handle explicitly.
    if args.gpus != "":
        num_gpus = len(args.gpus.split(","))
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
    assert num_gpus in [1, 2, 4, 8]
    sc.num_gpus = num_gpus

    # Networks configuration
    cG = set_net("G", reg_interval=4)
    cD = set_net("D", reg_interval=16)

    # Dataset configuration
    ratios = {
        "clevr": 0.75,
        "lsun-bedrooms": 0.72,
        "cityscapes": 0.5,
        "ffhq": 1.0
    }
    # Known datasets override the user-provided ratio
    args.ratio = ratios.get(args.dataset, args.ratio)
    dataset_args = EasyDict(tfrecord_dir=args.dataset,
                            max_imgs=args.train_images_num,
                            ratio=args.ratio,
                            num_threads=args.num_threads)
    for arg in ["data_dir", "mirror_augment", "total_kimg"]:
        cset(train, arg, args[arg])

    # Training and Optimizations configuration
    for arg in ["eval", "train", "recompile", "last_snapshots"]:
        cset(train, arg, args[arg])

    # Round to the closest multiply of minibatch size for validity
    args.batch_size -= args.batch_size % args.minibatch_size
    args.minibatch_std_size -= args.minibatch_std_size % args.minibatch_size
    args.latent_size -= args.latent_size % args.components_num
    if args.latent_size == 0:
        print(
            misc.bcolored(
                "Error: latent-size is too small. Must best a multiply of components-num.",
                "red"))
        exit()

    # Maps schedule option name -> command-line option name
    sched_args = {
        "G_lrate": "g_lr",
        "D_lrate": "d_lr",
        "minibatch_size": "batch_size",
        "minibatch_gpu": "minibatch_size"
    }
    for arg, cmd_arg in sched_args.items():
        cset(sched, arg, args[cmd_arg])
    cset(train, "clip", args.clip)

    # Logging and metrics configuration
    metrics = [metric_defaults[x] for x in args.metrics]

    cset(cG.args, "truncation_psi", args.truncation_psi)
    for arg in ["summarize", "keep_samples", "eval_images_num"]:
        cset(train, arg, args[arg])

    # Visualization
    args.vis_imgs = args.vis_images
    args.vis_ltnts = args.vis_latents
    vis_types = [
        "imgs", "ltnts", "maps", "layer_maps", "interpolations", "noise_var",
        "style_mix"
    ]
    # Set of all the set visualization types option
    vis.vis_types = {arg for arg in vis_types if args["vis_{}".format(arg)]}

    # Maps visualization option name -> command-line option name
    vis_args = {
        "attention": "transformer",
        "grid": "vis_grid",
        "num": "vis_num",
        "rich_num": "vis_rich_num",
        "section_size": "vis_section_size",
        "intrp_density": "intrpolation_density",
        "intrp_per_component": "intrpolation_per_component",
        "alpha": "blending_alpha"
    }
    for arg, cmd_arg in vis_args.items():
        cset(vis, arg, args[cmd_arg])

    # Networks architecture
    cset(cG.args, "architecture", args.g_arch)
    cset(cD.args, "architecture", args.d_arch)
    cset(cG.args, "tanh", args.tanh)

    # Latent sizes
    if args.components_num > 1:
        if not (args.transformer or args.kgan):
            print(
                misc.bcolored(
                    "Error: components-num > 1 but the model is not using components.",
                    "red"))
            print(
                misc.bcolored(
                    "Either add --transformer for GANsformer or --kgan for k-GAN).",
                    "red"))
            exit()
        # The total latent size is split evenly between the components
        args.latent_size = int(args.latent_size / args.components_num)
    cD.args.latent_size = cG.args.latent_size = cG.args.dlatent_size = args.latent_size
    cset([cG.args, cD.args, vis], "components_num", args.components_num)
    # mapping-dim default value is the latent-size
    if args.mapping_dim is None:
        args.mapping_dim = args.latent_size

    # Mapping network
    for arg in ["layersnum", "lrmul", "dim", "resnet", "shared_dim"]:
        cset(cG.args, arg, args["mapping_{}".format(arg)])

    # StyleGAN settings
    for arg in ["style", "latent_stem", "fused_modconv", "local_noise"]:
        cset(cG.args, arg, args[arg])
    cD.args.mbstd_group_size = args.minibatch_std_size

    # GANsformer
    cset(cG.args, "transformer", args.transformer)
    cset(cD.args, "transformer", args.d_transformer)

    args.norm = args.normalize
    for arg in [
            "norm", "integration", "ltnt_gate", "img_gate", "kmeans",
            "kmeans_iters", "mapping_ltnt2ltnt"
    ]:
        cset(cG.args, arg, args[arg])

    for arg in ["use_pos", "num_heads"]:
        cset([cG.args, cD.args], arg, args[arg])

    # Positional encoding
    for arg in ["dim", "init", "directions_num"]:
        field = "pos_{}".format(arg)
        cset([cG.args, cD.args], field, args[field])

    # k-GAN
    for arg in ["layer", "type", "same"]:
        field = "merge_{}".format(arg)
        cset(cG.args, field, args[field])
    cset([cG.args, train], "merge", args.kgan)

    # Attention
    for arg in [
            "start_res", "end_res", "ltnt2ltnt", "img2img", "local_attention"
    ]:
        cset(cG.args, arg, args["g_{}".format(arg)])
        cset(cD.args, arg, args["d_{}".format(arg)])
    cset(cG.args, "img2ltnt", args.g_img2ltnt)
    cset(cD.args, "ltnt2img", args.d_ltnt2img)

    # Mixing and dropout
    for arg in [
            "style_mixing", "component_mixing", "component_dropout",
            "attention_dropout"
    ]:
        cset(cG.args, arg, args[arg])

    # Loss and regularization
    gloss_args = {
        "loss_type": "g_loss",
        "reg_weight": "g_reg_weight",
        "pathreg": "pathreg",
    }
    dloss_args = {"loss_type": "d_loss", "reg_type": "d_reg", "gamma": "gamma"}
    for arg, cmd_arg in gloss_args.items():
        cset(cG.loss_args, arg, args[cmd_arg])
    for arg, cmd_arg in dloss_args.items():
        cset(cD.loss_args, arg, args[cmd_arg])

    ##### Experiments management:
    # Whenever we start a new experiment we store its result in a directory named 'args.expname:000'.
    # When we rerun a training or evaluation command it restores the model from that directory by default.
    # If we wish to restart the model training, we can set --restart and then we will store data in a new
    # directory: 'args.expname:001' after the first restart, then 'args.expname:002' after the second, etc.

    # Find the latest directory that matches the experiment
    exp_dir = sorted(glob.glob("{}/{}:*".format(args.result_dir,
                                                args.expname)))
    run_id = 0
    if len(exp_dir) > 0:
        run_id = int(exp_dir[-1].split(":")[-1])
    # If restart, then work over a new directory
    if args.restart:
        run_id += 1

    run_name = "{}:{:03d}".format(args.expname, run_id)
    train.printname = "{} ".format(misc.bold(args.expname))

    snapshot, kimg, resume = None, 0, False
    pkls = sorted(
        glob.glob("{}/{}/network*.pkl".format(args.result_dir, run_name)))
    # Load a particular snapshot is specified
    if args.pretrained_pkl:
        # Soft links support
        snapshot = glob.glob(args.pretrained_pkl)[0]
        if os.path.islink(snapshot):
            snapshot = os.readlink(snapshot)

        # Extract training step from the snapshot if specified.
        # A file name without a numeric "-<kimg>" suffix simply keeps kimg at
        # 0; only the int() conversion can fail here, so nothing broader
        # (e.g. KeyboardInterrupt) is swallowed.
        try:
            kimg = int(snapshot.split("-")[-1].split(".")[0])
        except ValueError:
            pass

    # Find latest snapshot in the directory
    elif len(pkls) > 0:
        snapshot = pkls[-1]
        kimg = int(snapshot.split("-")[-1].split(".")[0])
        resume = True

    if snapshot:
        print(
            misc.bcolored("Resuming {}, kimg {}".format(snapshot, kimg),
                          "white"))
        train.resume_pkl = snapshot
        train.resume_kimg = kimg
    else:
        print(misc.bcolored("Start model training from scratch.", "white"))

    # Run environment configuration
    sc.run_dir_root = args.result_dir
    sc.run_desc = args.expname
    sc.run_id = run_id
    sc.run_name = run_name
    sc.submit_target = dnnlib.SubmitTarget.LOCAL
    sc.local.do_not_copy_source_files = True

    # Assemble the keyword arguments handed to the training loop
    kwargs = EasyDict(train)
    kwargs.update(cG=cG, cD=cD)
    kwargs.update(dataset_args=dataset_args,
                  vis_args=vis,
                  sched_args=sched,
                  grid_args=grid,
                  metric_arg_list=metrics,
                  tf_config=tf_config)
    kwargs.submit_config = copy.deepcopy(sc)
    kwargs.resume = resume
    kwargs.load_config = args.reload

    dnnlib.submit_run(**kwargs)
Example #7
0
def training_loop(
    # General configuration
    train                   = False,    # Training mode
    eval                    = False,    # Evaluation mode
    vis                     = False,    # Visualization mode        
    run_dir                 = ".",      # Output directory
    num_gpus                = 1,        # Number of GPUs participating in the training
    rank                    = 0,        # Rank of the current process in [0, num_gpus]
    cG                      = {},       # Options for generator network
    cD                      = {},       # Options for discriminator network
    # Data
    dataset_args            = {},       # Options for training set
    drange_net              = [-1,1],   # Dynamic range used when feeding image data to the networks
    # Optimization
    loss_args               = {},       # Options for loss function
    total_kimg              = 25000,    # Total length of the training, measured in thousands of real images
    batch_size              = 4,        # Total batch size for one training iteration. Can be larger than batch_gpu * num_gpus
    batch_gpu               = 4,        # Number of samples processed at a time by one GPU
    ema_kimg                = 10.0,     # Half-life of the exponential moving average (EMA) of generator weights
    ema_rampup              = None,     # EMA ramp-up coefficient
    cudnn_benchmark         = True,     # Enable torch.backends.cudnnbenchmark?
    allow_tf32              = False,    # Enable torch.backends.cuda.matmul.allow_tf32 and torch.backends.cudnnallow_tf32?
    # Logging
    resume_pkl              = None,     # Network pickle to resume training from
    resume_kimg             = 0.0,      # Assumed training progress at the beginning
                                        # Affects reporting and training schedule    
    kimg_per_tick           = 8,        # Progress snapshot interval
    img_snapshot_ticks      = 3,        # How often to save image snapshots? None = disable
    network_snapshot_ticks  = 3,        # How often to save network snapshots? None = disable    
    last_snapshots          = 10,       # Maximal number of prior snapshots to save
    printname               = "",       # Experiment name for logging
    # Evaluation
    vis_args                = {},       # Options for vis.vis
    metrics                 = [],       # Metrics to evaluate during training
    eval_images_num         = 50000,    # Sample size for the metrics
    truncation_psi          = 0.7       # Style strength multiplier for the truncation trick (used for visualizations only)
):  
    """Evaluate, visualize and/or train the generator and discriminator networks.

    The dataset and networks are always loaded. Depending on the flags, the
    function then runs the metrics (``eval``), produces visualizations
    (``vis``, on the logging rank only) and, if ``train`` is set, enters the
    training loop until ``total_kimg`` thousand images have been processed.
    When ``train`` is False the process is terminated with ``exit()`` after
    the optional evaluation and visualization.

    The mutable default arguments are only read in this function, never
    modified in place.
    """
    # Initialize
    start_time = time.time()
    device = init_cuda(rank, cudnn_benchmark, allow_tf32)
    log = (rank == 0)   # Only the process of rank 0 logs

    dataset, dataset_iter = load_dataset(dataset_args, batch_size, rank, num_gpus, log) # Load training set

    nets = construct_nets(cG, cD, dataset, device, log) if train else None        # Construct networks
    G, D, Gs = load_nets(resume_pkl, nets, device, log)                           # Resume from existing pickle
    print_nets(G, D, batch_gpu, device, log)                                      # Print network summary tables

    if eval:
        misc.log("Run evaluation...", log = log)
        evaluate(Gs, resume_pkl, metrics, eval_images_num, dataset_args, num_gpus, rank, device, log)

    if vis and log:
        misc.log("Produce visualizations...")
        visualize.vis(Gs, dataset, device, batch_gpu, drange_net = drange_net, ratio = dataset.ratio, 
            truncation_psi = truncation_psi, **vis_args)

    if not train:
        exit()

    nets = distribute_nets(G, D, Gs, device, num_gpus, log)                                   # Distribute networks across GPUs
    loss, stages = setup_training_stages(loss_args, G, cG, D, cD, nets, device, log)          # Setup training stages (losses and optimizers)
    grid_size, grid_z, grid_c = init_img_grid(dataset, G.input_shape, device, run_dir, log)   # Initialize an image grid    
    logger = init_logger(run_dir, log)                                                        # Initialize logs

    # Train
    misc.log(f"Training for {total_kimg} kimg...", "white", log)    
    cur_nimg, cur_tick, batch_idx = int(resume_kimg * 1000), 0, 0
    tick_start_nimg, tick_start_time = cur_nimg, time.time()
    stats = None    # Filled at the end of the first tick, so the first status line is skipped

    while True:
        # Fetch training data
        real_img, real_c, gen_zs, gen_cs = fetch_data(dataset, dataset_iter, G.input_shape, drange_net, 
            device, len(stages), batch_size, batch_gpu)

        # Execute training stages
        for stage, gen_z, gen_c in zip(stages, gen_zs, gen_cs):
            # Each stage only runs every stage.interval batches
            if batch_idx % stage.interval != 0:
                continue
            run_training_stage(loss, stage, device, real_img, real_c, gen_z, gen_c, batch_size, batch_gpu, num_gpus)

        # Update Gs
        update_ema_network(G, Gs, batch_size, cur_nimg, ema_kimg, ema_rampup)

        # Update state
        cur_nimg += batch_size
        batch_idx += 1

        # Perform maintenance tasks once per tick
        done = (cur_nimg >= total_kimg * 1000)
        if (not done) and (cur_tick != 0) and (cur_nimg < tick_start_nimg + kimg_per_tick * 1000):
            continue

        # Print status line and accumulate the info in logger.collector
        tick_end_time = time.time()
        if stats is not None:
            default = dnnlib.EasyDict({'mean': -1})     # Placeholder for statistics that were not collected
            fields = []
            fields.append("tick " + misc.bold(f"{training_stats.report0('Progress/tick', cur_tick):<5d}"))
            fields.append("kimg " + misc.bcolored(f"{training_stats.report0('Progress/kimg', cur_nimg / 1e3):<8.1f}", "red"))
            fields.append("")
            fields.append("loss/reg: G (" + misc.bcolored(f"{stats.get('Loss/G/loss', default).mean:>6.3f}", "blue"))
            fields.append(misc.bold(f"{stats.get('Loss/G/reg', default).mean:>6.3f}") + ")")
            fields.append("D "+ misc.bcolored(f"({stats.get('Loss/D/loss', default).mean:>6.3f}", "blue"))
            fields.append(misc.bold(f"{stats.get('Loss/D/reg', default).mean:>6.3f}") + ")")
            fields.append("")
            fields.append("time " + misc.bold(f"{dnnlib.util.format_time(training_stats.report0('Timing/total_sec', tick_end_time - start_time)):<12s}"))
            fields.append(f"sec/kimg {training_stats.report0('Timing/sec_per_kimg', (tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg) * 1e3):<7.2f}")
            # The "GPU" label goes with the peak GPU memory and the "CPU" label with the
            # process's resident memory
            fields.append(f"mem: GPU {training_stats.report0('Resources/peak_gpu_mem_gb', torch.cuda.max_memory_allocated(device) / 2**30):<6.2f}")
            fields.append(f"CPU {training_stats.report0('Resources/cpu_mem_gb', psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}")
            fields.append(misc.bold(printname))
            torch.cuda.reset_peak_memory_stats()    # Peak GPU memory is measured per tick
            training_stats.report0('Timing/total_hours', (tick_end_time - start_time) / (60 * 60))
            training_stats.report0('Timing/total_days', (tick_end_time - start_time) / (24 * 60 * 60))
            misc.log(" ".join(fields), log = log)

        # Save image snapshot
        if log and (img_snapshot_ticks is not None) and (done or cur_tick % img_snapshot_ticks == 0):
            visualize.vis(Gs, dataset, device, batch_gpu, training = True,
                step = cur_nimg // 1000, grid_size = grid_size, latents = grid_z, 
                labels = grid_c, drange_net = drange_net, ratio = dataset.ratio, **vis_args)

        # Save network snapshot
        if (network_snapshot_ticks is not None) and (done or cur_tick % network_snapshot_ticks == 0):
            snapshot_data, snapshot_pkl = save_nets(G, D, Gs, cur_nimg, dataset_args, run_dir, num_gpus > 1, last_snapshots, log)

            # Evaluate metrics
            evaluate(snapshot_data["Gs"], snapshot_pkl, metrics, eval_images_num,
                dataset_args, num_gpus, rank, device, log, logger, run_dir)
            del snapshot_data

        # Collect stats and update logs
        stats = collect_stats(logger, stages)
        update_logger(logger, stats, cur_nimg, start_time)

        # Start the next tick
        cur_tick += 1
        tick_start_nimg, tick_start_time = cur_nimg, time.time()
        if done:
            break

    # Done
    misc.log("Done!", "blue")
Example #8
0
def training_loop(
    # Configurations
    cG={},
    cD={},  # Generator and Discriminator command-line arguments
    dataset_args={},  # dataset.load_dataset() options
    sched_args={},  # train.TrainingSchedule options
    vis_args={},  # vis.eval options
    grid_args={},  # train.setup_snapshot_img_grid() options
    metric_arg_list=[],  # MetricGroup Options
    tf_config={},  # tflib.init_tf() options
    eval=False,  # Evaluation mode
    train=False,  # Training mode
    # Data
    data_dir=None,  # Directory to load datasets from
    total_kimg=25000,  # Total length of the training, measured in thousands of real images
    mirror_augment=False,  # Enable mirror augmentation?
    drange_net=[
        -1, 1
    ],  # Dynamic range used when feeding image data to the networks
    ratio=1.0,  # Image height/width ratio in the dataset
    # Optimization
    minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters
    lazy_regularization=True,  # Perform regularization as a separate training step?
    smoothing_kimg=10.0,  # Half-life of the running average of generator weights
    clip=None,  # Clip gradients threshold
    # Resumption
    resume_pkl=None,  # Network pickle to resume training from, None = train from scratch.
    resume_kimg=0.0,  # Assumed training progress at the beginning
    # Affects reporting and training schedule
    resume_time=0.0,  # Assumed wallclock time at the beginning, affects reporting
    recompile=False,  # Recompile network from source code (otherwise loads from snapshot)
    # Logging
    summarize=True,  # Create TensorBoard summaries
    save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms=False,  # Include weight histograms in the tfevents file?
    img_snapshot_ticks=3,  # How often to save image snapshots? None = disable
    network_snapshot_ticks=3,  # How often to save network snapshots? None = only save networks-final.pkl
    last_snapshots=10,  # Maximal number of prior snapshots to save
    eval_images_num=50000,  # Sample size for the metrics
    printname="",  # Experiment name for logging
    # Architecture
    merge=False):  # Generate several images and then merge them

    # Initialize dnnlib and TensorFlow
    tflib.init_tf(tf_config)
    num_gpus = dnnlib.submit_config.num_gpus
    cG.name, cD.name = "g", "d"

    # Load dataset, configure training scheduler and metrics object
    dataset = data.load_dataset(data_dir=dnnlib.convert_path(data_dir),
                                verbose=True,
                                **dataset_args)
    sched = training_schedule(sched_args,
                              cur_nimg=total_kimg * 1000,
                              dataset=dataset)
    metrics = metric_base.MetricGroup(metric_arg_list)

    # Construct or load networks
    with tf.device("/gpu:0"):
        no_op = tf.no_op()
        G, D, Gs = None, None, None
        if resume_pkl is None or recompile:
            misc.log("Constructing networks...", "white")
            G = tflib.Network("G",
                              num_channels=dataset.shape[0],
                              resolution=dataset.shape[1],
                              label_size=dataset.label_size,
                              **cG.args)
            D = tflib.Network("D",
                              num_channels=dataset.shape[0],
                              resolution=dataset.shape[1],
                              label_size=dataset.label_size,
                              **cD.args)
            Gs = G.clone("Gs")
        if resume_pkl is not None:
            G, D, Gs = load_nets(resume_pkl, G, D, Gs, recompile)

    G.print_layers()
    D.print_layers()

    # Train/Evaluate/Visualize
    # Labels are optional but not essential
    grid_size, grid_reals, grid_labels = misc.setup_snapshot_img_grid(
        dataset, **grid_args)
    misc.save_img_grid(grid_reals,
                       dnnlib.make_run_dir_path("reals.png"),
                       drange=dataset.dynamic_range,
                       grid_size=grid_size)
    grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:])

    if eval:
        # Save a snapshot of the current network to evaluate
        pkl = dnnlib.make_run_dir_path("network-eval-snapshot-%06d.pkl" %
                                       resume_kimg)
        misc.save_pkl((G, D, Gs), pkl, remove=False)

        # Quantitative evaluation
        metric = metrics.run(pkl,
                             num_imgs=eval_images_num,
                             run_dir=dnnlib.make_run_dir_path(),
                             data_dir=dnnlib.convert_path(data_dir),
                             num_gpus=num_gpus,
                             ratio=ratio,
                             tf_config=tf_config,
                             mirror_augment=mirror_augment)

        # Qualitative evaluation
        visualize.eval(G,
                       dataset,
                       batch_size=sched.minibatch_gpu,
                       drange_net=drange_net,
                       ratio=ratio,
                       **vis_args)

    if not train:
        dataset.close()
        exit()

    # Setup training inputs
    misc.log("Building TensorFlow graph...", "white")
    with tf.name_scope("Inputs"), tf.device("/cpu:0"):
        lrate_in_g = tf.placeholder(tf.float32, name="lrate_in_g", shape=[])
        lrate_in_d = tf.placeholder(tf.float32, name="lrate_in_d", shape=[])
        step = tf.placeholder(tf.int32, name="step", shape=[])
        minibatch_size_in = tf.placeholder(tf.int32,
                                           name="minibatch_size_in",
                                           shape=[])
        minibatch_gpu_in = tf.placeholder(tf.int32,
                                          name="minibatch_gpu_in",
                                          shape=[])
        minibatch_multiplier = minibatch_size_in // (minibatch_gpu_in *
                                                     num_gpus)
        beta = 0.5**tf.div(tf.cast(minibatch_size_in,
                                   tf.float32), smoothing_kimg *
                           1000.0) if smoothing_kimg > 0.0 else 0.0

    # Set optimizers
    for cN, lr in [(cG, lrate_in_g), (cD, lrate_in_d)]:
        set_optimizer(cN, lr, minibatch_multiplier, lazy_regularization, clip)

    # Build training graph for each GPU
    data_fetch_ops = []
    for gpu in range(num_gpus):
        with tf.name_scope("GPU%d" % gpu), tf.device("/gpu:%d" % gpu):

            # Create GPU-specific shadow copies of G and D
            for cN, N in [(cG, G), (cD, D)]:
                cN.gpu = N if gpu == 0 else N.clone(N.name + "_shadow")
            Gs_gpu = Gs if gpu == 0 else Gs.clone(Gs.name + "_shadow")

            # Fetch training data via temporary variables
            with tf.name_scope("DataFetch"):
                reals, labels = dataset.get_minibatch_tf()
                reals = process_reals(reals, dataset.dynamic_range, drange_net,
                                      mirror_augment)
                reals, reals_fetch = read_data(
                    reals, "reals", [sched.minibatch_gpu] + dataset.shape,
                    minibatch_gpu_in)
                labels, labels_fetch = read_data(
                    labels, "labels",
                    [sched.minibatch_gpu, dataset.label_size],
                    minibatch_gpu_in)
                data_fetch_ops += [reals_fetch, labels_fetch]

            # Evaluate loss functions
            with tf.name_scope("G_loss"):
                cG.loss, cG.reg = dnnlib.util.call_func_by_name(
                    G=cG.gpu,
                    D=cD.gpu,
                    dataset=dataset,
                    reals=reals,
                    minibatch_size=minibatch_gpu_in,
                    **cG.loss_args)

            with tf.name_scope("D_loss"):
                cD.loss, cD.reg = dnnlib.util.call_func_by_name(
                    G=cG.gpu,
                    D=cD.gpu,
                    dataset=dataset,
                    reals=reals,
                    labels=labels,
                    minibatch_size=minibatch_gpu_in,
                    **cD.loss_args)

            for cN in [cG, cD]:
                set_optimizer_ops(cN, lazy_regularization, no_op)

    # Setup training ops
    # Bundle all data fetch ops into one op so that a single run call
    # triggers every one of them.
    data_fetch_op = tf.group(*data_fetch_ops)
    for cN in [cG, cD]:
        # Each network gets a main update op and a separate regularization
        # update op. The regularization optimizer is called with
        # allow_no_op=True -- presumably because it may have no registered
        # gradients (e.g. when lazy regularization is off); TODO confirm.
        cN.train_op = cN.opt.apply_updates()
        cN.reg_op = cN.reg_opt.apply_updates(allow_no_op=True)
    # Op that updates Gs as a moving average of G, using `beta`.
    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=beta)

    # Finalize graph
    with tf.device("/gpu:0"):
        # Fall back to a constant 0 when the memory-stats op raises
        # NotFoundError, so that the progress report can still evaluate it.
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)
    tflib.init_uninitialized_vars()

    # Tensorboard summaries
    # `summary_log` is only created when `summarize` is set; every later use
    # of it is guarded by the same flag.
    if summarize:
        misc.log("Initializing logs...", "white")
        summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path())
        if save_tf_graph:
            summary_log.add_graph(tf.get_default_graph())
        if save_weight_histograms:
            G.setup_weight_histograms()
            D.setup_weight_histograms()

    # Initialize training
    misc.log("Training for %d kimg..." % total_kimg, "white")
    dnnlib.RunContext.get().update("",
                                   cur_epoch=resume_kimg,
                                   max_epoch=total_kimg)
    # NOTE(review): maintenance_time is assigned here and again at the end of
    # each tick, but it is never read in the visible part of this function.
    maintenance_time = dnnlib.RunContext.get().get_last_update_interval()

    # cur_tick starts at -1 so that the `cur_tick < 0` test in the training
    # loop forces a maintenance pass after the very first batch of iterations.
    cur_tick, running_mb_counter = -1, 0
    # Progress is tracked in images; resume_kimg is given in thousands.
    cur_nimg = int(resume_kimg * 1000)
    tick_start_nimg = cur_nimg
    for cN in [cG, cD]:
        # Running aggregates of the reported statistics; None until the first
        # value has been folded in.
        cN.lossvals_agg = {
            k: None
            for k in ["loss", "reg", "norm", "reg_norm"]
        }
        # NOTE(review): only cN.opt is reset here, not cN.reg_opt -- confirm
        # that this is intended.
        cN.opt.reset_optimizer_state()

    # Training loop
    # Runs until `total_kimg` thousand images have been processed or the run
    # context requests a stop. Each pass performs `minibatch_repeats`
    # minibatches with fixed schedule parameters and then, once per tick,
    # reports progress and writes snapshots.
    while cur_nimg < total_kimg * 1000:
        if dnnlib.RunContext.get().should_stop():
            break

        # Choose training parameters and configure training ops
        sched = training_schedule(sched_args,
                                  cur_nimg=cur_nimg,
                                  dataset=dataset)
        # The minibatch must split evenly into accumulation rounds of
        # minibatch_gpu images on each of the num_gpus GPUs.
        assert sched.minibatch_size % (sched.minibatch_gpu * num_gpus) == 0
        dataset.configure(sched.minibatch_gpu)

        # Run training ops
        # The same feed dict is passed to every run call made in this pass.
        feed_dict = {
            lrate_in_g: sched.G_lrate,
            lrate_in_d: sched.D_lrate,
            minibatch_size_in: sched.minibatch_size,
            minibatch_gpu_in: sched.minibatch_gpu,
            step: sched.kimg
        }

        # Several iterations before updating training parameters
        for _repeat in range(minibatch_repeats):
            # One entry per accumulation round: minibatch_size divided by the
            # number of images consumed per round across all GPUs.
            rounds = range(0, sched.minibatch_size,
                           sched.minibatch_gpu * num_gpus)
            for cN in [cG, cD]:
                # With lazy regularization, the regularization op runs only on
                # every reg_interval-th minibatch; otherwise it is never run
                # separately here.
                cN.run_reg = lazy_regularization and (running_mb_counter %
                                                      cN.reg_interval == 0)
            # The counters advance before the ops run, so cur_nimg already
            # includes the minibatch that is about to be processed.
            cur_nimg += sched.minibatch_size
            running_mb_counter += 1

            # Fresh per-minibatch statistics; entries stay None when the
            # corresponding op does not report them.
            for cN in [cG, cD]:
                cN.lossvals = {
                    k: None
                    for k in ["loss", "reg", "norm", "reg_norm"]
                }

            # Gradient accumulation
            # Each round does: G step (plus optional G regularization), fetch
            # new data, then D step (plus optional D regularization).
            for _round in rounds:
                # Index [1] selects the result of fetching cG.ops, which is
                # merged into lossvals (presumably a dict of statistics
                # keyed like lossvals -- TODO confirm).
                cG.lossvals.update(
                    tflib.run([cG.train_op, cG.ops], feed_dict)[1])
                if cG.run_reg:
                    _, cG.lossvals["reg_norm"] = tflib.run(
                        [cG.reg_op, cG.reg_norm], feed_dict)

                tflib.run(data_fetch_op, feed_dict)

                cD.lossvals.update(
                    tflib.run([cD.train_op, cD.ops], feed_dict)[1])
                if cD.run_reg:
                    _, cD.lossvals["reg_norm"] = tflib.run(
                        [cD.reg_op, cD.reg_norm], feed_dict)

            # Update the averaged generator once per minibatch.
            tflib.run([Gs_update_op], feed_dict)

            # Track loss statistics
            # emaAvg is defined elsewhere; presumably it folds the new value
            # into an exponential moving average and handles None inputs --
            # TODO confirm.
            for cN in [cG, cD]:
                for k in cN.lossvals_agg:
                    cN.lossvals_agg[k] = emaAvg(cN.lossvals_agg[k],
                                                cN.lossvals[k])

        # Perform maintenance tasks once per tick
        # A tick fires on the first pass (cur_tick < 0), whenever at least
        # sched.tick_kimg thousand images were processed since the previous
        # tick, and on the final pass.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            # Thousands of images actually processed during this tick.
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = dnnlib.RunContext.get().get_time_since_last_update()
            # resume_time is added so the total includes time from before a
            # resumed run.
            total_time = dnnlib.RunContext.get().get_time_since_start(
            ) + resume_time

            # Report progress
            # Every value passed through autosummary is also recorded under
            # the given summary name. The `or 0` guards cover aggregates that
            # are still None. Peak GPU memory is converted from bytes to GiB
            # by dividing by 2**30.
            print(
                ("tick %s kimg %s   loss/reg: G (%s %s) D (%s %s)   grad norms: G (%s %s) D (%s %s)   "
                 + "time %s sec/kimg %s maxGPU %sGB %s") %
                (misc.bold("%-5d" % autosummary("Progress/tick", cur_tick)),
                 misc.bcolored(
                     "{:>8.1f}".format(
                         autosummary("Progress/kimg", cur_nimg / 1000.0)),
                     "red"),
                 misc.bcolored("{:>6.3f}".format(cG.lossvals_agg["loss"] or 0),
                               "blue"),
                 misc.bold("{:>6.3f}".format(cG.lossvals_agg["reg"] or 0)),
                 misc.bcolored("{:>6.3f}".format(cD.lossvals_agg["loss"] or 0),
                               "blue"),
                 misc.bold("{:>6.3f}".format(cD.lossvals_agg["reg"] or 0)),
                 misc.cond_bcolored(cG.lossvals_agg["norm"], 20.0, "red"),
                 misc.cond_bcolored(cG.lossvals_agg["reg_norm"], 20.0, "red"),
                 misc.cond_bcolored(cD.lossvals_agg["norm"], 20.0, "red"),
                 misc.cond_bcolored(cD.lossvals_agg["reg_norm"], 20.0, "red"),
                 misc.bold("%-10s" % dnnlib.util.format_time(
                     autosummary("Timing/total_sec", total_time))),
                 "{:>7.2f}".format(
                     autosummary("Timing/sec_per_kimg",
                                 tick_time / tick_kimg)),
                 "{:>4.1f}".format(
                     autosummary("Resources/peak_gpu_mem_gb",
                                 peak_gpu_mem_op.eval() / 2**30)), printname))

            autosummary("Timing/total_hours", total_time / (60.0 * 60.0))
            autosummary("Timing/total_days", total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots
            # Image snapshots: every img_snapshot_ticks ticks and on the final
            # tick; disabled when img_snapshot_ticks is None.
            if img_snapshot_ticks is not None and (
                    cur_tick % img_snapshot_ticks == 0 or done):
                visualize.eval(G,
                               dataset,
                               batch_size=sched.minibatch_gpu,
                               training=True,
                               step=cur_nimg // 1000,
                               grid_size=grid_size,
                               latents=grid_latents,
                               labels=grid_labels,
                               drange_net=drange_net,
                               ratio=ratio,
                               **vis_args)

            # Network snapshots: same cadence rule, controlled by
            # network_snapshot_ticks. The file name carries the current kimg.
            if network_snapshot_ticks is not None and (
                    cur_tick % network_snapshot_ticks == 0 or done):
                pkl = dnnlib.make_run_dir_path("network-snapshot-%06d.pkl" %
                                               (cur_nimg // 1000))
                misc.save_pkl((G, D, Gs), pkl, remove=False)

                # NOTE(review): this condition repeats the enclosing one, so
                # it is always true here and metrics run on every snapshot.
                if cur_tick % network_snapshot_ticks == 0 or done:
                    metric = metrics.run(
                        pkl,
                        num_imgs=eval_images_num,
                        run_dir=dnnlib.make_run_dir_path(),
                        data_dir=dnnlib.convert_path(data_dir),
                        num_gpus=num_gpus,
                        ratio=ratio,
                        tf_config=tf_config,
                        mirror_augment=mirror_augment)

                # Pass every "network*.pkl" path in the run dir, except the
                # last `last_snapshots` in sorted name order, to misc.rm
                # (presumably deleting them -- TODO confirm).
                # NOTE(review): the pattern also matches "network-final.pkl"
                # if one already exists in the run dir -- confirm intended.
                if last_snapshots > 0:
                    misc.rm(
                        sorted(
                            glob.glob(dnnlib.make_run_dir_path(
                                "network*.pkl")))[:-last_snapshots])

            # Update summaries and RunContext
            if summarize:
                metrics.update_autosummaries()
                tflib.autosummary.save_summaries(summary_log, cur_nimg)

            dnnlib.RunContext.get().update(None,
                                           cur_epoch=cur_nimg // 1000,
                                           max_epoch=total_kimg)
            # Interval reported by the run context after the update above,
            # minus the training time of this tick.
            # NOTE(review): the value is not read anywhere in the visible code.
            maintenance_time = dnnlib.RunContext.get(
            ).get_last_update_interval() - tick_time

    # Save final snapshot
    # Written as "network-final.pkl" in the run dir, whether the loop ended by
    # reaching total_kimg or by a stop request.
    misc.save_pkl((G, D, Gs),
                  dnnlib.make_run_dir_path("network-final.pkl"),
                  remove=False)

    # All done
    # summary_log exists only when summarize was set, hence the guard.
    if summarize:
        summary_log.close()
    dataset.close()