Example #1
def prepare(tasks,
            data_dir,
            shards_num=1,
            max_images=None,
            ratio=1.0,
            images_dir=None,
            format=None):  # Options for custom dataset
    mkdir(data_dir)
    for task in tasks:
        # If task not in catalog, create custom task configuration
        c = catalog.get(
            task, {
                "local": True,
                "name": task,
                "dir": images_dir,
                "ratio": ratio,
                "process": formats_catalog.get(format)
            })

        dirname = "{}/{}".format(data_dir, task)
        mkdir(dirname)

        # try:
        print(misc.bold("Preparing the {} dataset...".format(c.name)))
        fname = "{}/{}".format(dirname, c.filename)

        if "local" not in c:
            download = not ((os.path.exists(fname)
                             and verify_md5(fname, c.md5)))
            path = get_path(c.url, dirname, path=c.filename)

            if download:
                print(
                    misc.bold("Downloading the data ({} GB)...".format(
                        c.size)))
                download_file(c.url, path)
                # print(misc.bold("Completed downloading {}".format(c.name)))

            if path.endswith(".zip"):
                if not is_unzipped(path, dirname):
                    print(misc.bold("Unzipping {}...".format(path)))
                    unzip(path, dirname)
                    # print(misc.bold("Completed unzipping {}".format(path)))

        if "process" in c:
            imgdir = images_dir if "local" in c else ("{}/{}".format(
                dirname, c.dir))
            shards_num = c.shards if max_images is None else shards_num
            c.process(dirname,
                      imgdir,
                      ratio=c.ratio,
                      shards_num=shards_num,
                      max_imgs=max_images)

        print(
            misc.bcolored("Completed preparations for {}!".format(c.name),
                          "blue"))
Example #2
def load_dataset(class_name = None, data_dir = None, verbose = False, **kwargs):
    kwargs = dict(kwargs)
    if "tfrecord_dir" in kwargs:
        if class_name is None:
            class_name = __name__ + ".TFRecordDataset"
        if data_dir is not None:
            kwargs["tfrecord_dir"] = os.path.join(data_dir, kwargs["tfrecord_dir"])

    assert class_name is not None
    if verbose:
        print(misc.bcolored("Streaming data using %s %s..." % (class_name, data_dir), "white"))
    dataset = dnnlib.util.get_obj_by_name(class_name)(**kwargs)
    if verbose:
        print("Dataset shape: ", misc.bcolored(np.int32(dataset.shape).tolist(), "blue"))
        print("Dynamic range: ", misc.bcolored(dataset.dynamic_range, "blue"))
        if dataset.label_size > 0:
            print("Label size: ", misc.bcolored(dataset.label_size, "blue"))
    return dataset
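A hedged usage sketch for load_dataset, assuming the tfrecords were prepared under datasets/ffhq (hypothetical path) and that TFRecordDataset is defined in this module as referenced above:

# Passing tfrecord_dir selects TFRecordDataset by default and joins the path with data_dir
dataset = load_dataset(data_dir="datasets", verbose=True,
                       tfrecord_dir="ffhq", max_label_size=0)
print(dataset.shape)  # [channels, height, width]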
Example #3
def load_nets(resume_pkl, lG, lD, lGs, recompile):
    print(misc.bcolored("Loading networks from %s..." % resume_pkl, "white"))
    rG, rD, rGs = misc.load_pkl(resume_pkl)[:3]
    if recompile:
        print(misc.bold("Copying nets..."))
        lG.copy_vars_from(rG)
        lD.copy_vars_from(rD)
        lGs.copy_vars_from(rGs)
    else:
        lG, lD, lGs = rG, rD, rGs
    return lG, lD, lGs
Example #4
def task(task, name, size, dir, redownload, download = None, prepare = lambda: None):
    if task:
        print(misc.bcolored("Preparing the {} dataset...".format(name), "blue"))
        if download is not None and (redownload or not path.exists("{}/{}".format(dir, task))):
            print(misc.bold("Downloading the data ({} GB)...".format(size), "blue"))
            download()
            print(misc("Completed downloading {}".format(name)))
        
        prepare()
        print(misc.bold("Completed preparations for {}!".format(name)))
    except:
Example #5
def print_results(jsonl_line):
    print(" " * 100 + "\r")
    if jsonl_line["snapshot_pkl"] is None:
        network_name = "None"
    else:
        network_name = os.path.splitext(
            os.path.basename(jsonl_line["snapshot_pkl"]))[0]
    if len(network_name) > 29:
        network_name = "..." + network_name[-26:]

    result_str = "%-30s" % network_name
    result_str += " time %-12s" % dnnlib.util.format_time(
        jsonl_line["total_time"])
    nums = ""
    for res, value in jsonl_line["results"].items():
        nums += f" {res} {value:10.4f}"
    nums = misc.bcolored(nums, "blue")
    result_str += nums
    print(result_str)
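A sketch of the jsonl_line dict that print_results expects; the keys are inferred from the accesses above and the values are hypothetical:

jsonl_line = {
    "snapshot_pkl": "results/exp:000/network-snapshot-001000.pkl",  # or None
    "total_time": 372.5,                                            # seconds
    "results": {"fid50k": 13.2814},
}
print_results(jsonl_line)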
Example #6
    def get_result_str(self, screen=False):
        if self._network_pkl is None:
            network_name = "None"
        else:
            network_name = os.path.splitext(os.path.basename(
                self._network_pkl))[0]
        if len(network_name) > 29:
            network_name = "..." + network_name[-26:]

        result_str = "%-30s" % network_name
        result_str += " time %-12s" % dnnlib.util.format_time(self._eval_time)
        nums = ""
        for res in self._results:
            nums += " " + self.name + res.suffix + " "
            nums += res.fmt % res.value
        if screen:
            nums = misc.bcolored(nums, "blue")
        result_str += nums

        return result_str
Example #7
def prepare(tasks, data_dir, max_images = None, 
        ratio = 1.0, images_dir = None, format = None): # Options for custom dataset
    os.makedirs(data_dir, exist_ok = True)
    for task in tasks:
        # If task not in catalog, create custom task configuration
        c = catalog.get(task, EasyDict({
            "local": True,
            "name": task,
            "dir": images_dir,
            "ratio": ratio,
            "process": formats_catalog.get(format)
        }))

        dirname = f"{data_dir}/{task}"
        os.makedirs(dirname, exist_ok = True)

        # try:
        print(misc.bold(f"Preparing the {c.name} dataset..."))
        
        if "local" not in c:
            fname = f"{dirname}/{c.filename}"
            download = not ((os.path.exists(fname) and verify_md5(fname, c.md5)))
            path = get_path(c.url, dirname, path = c.filename)

            if download:
                print(misc.bold(f"Downloading the data ({c.size} GB)..."))
                download_file(c.url, path)
            
            if path.endswith(".zip"):
                if not is_unzipped(path, dirname):
                    print(misc.bold(f"Unzipping {path}..."))
                    unzip(path, dirname)

        if "process" in c:
            imgdir = images_dir if "local" in c else (f"{dirname}/{c.dir}")
            c.process(dirname, imgdir, ratio = c.ratio, max_imgs = max_images)

        print(misc.bcolored(f"Completed preparations for {c.name}!", "blue"))
Example #8
def run(**args):
    args = EasyDict(args)
    train = EasyDict(run_func_name="training.training_loop.training_loop"
                     )  # training loop options
    sched = EasyDict()  # TrainingSchedule options
    vis = EasyDict()  # visualize.eval() options
    grid = EasyDict(size="1080p",
                    layout="random")  # setup_snapshot_img_grid() options
    sc = dnnlib.SubmitConfig()  # dnnlib.submit_run() options

    # Environment configuration
    tf_config = {
        "rnd.np_random_seed": 1000,
        "allow_soft_placement": True,
        "gpu_options.per_process_gpu_memory_fraction": 1.0
    }
    if args.gpus != "":
        num_gpus = len(args.gpus.split(","))
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
    assert num_gpus in [1, 2, 4, 8]
    sc.num_gpus = num_gpus

    # Networks configuration
    cG = set_net("G", reg_interval=4)
    cD = set_net("D", reg_interval=16)

    # Dataset configuration
    ratios = {
        "clevr": 0.75,
        "lsun-bedrooms": 0.72,
        "cityscapes": 0.5,
        "ffhq": 1.0
    }
    args.ratio = ratios.get(args.dataset, args.ratio)
    dataset_args = EasyDict(tfrecord_dir=args.dataset,
                            max_imgs=args.train_images_num,
                            ratio=args.ratio,
                            num_threads=args.num_threads)
    for arg in ["data_dir", "mirror_augment", "total_kimg"]:
        cset(train, arg, args[arg])

    # Training and Optimizations configuration
    for arg in ["eval", "train", "recompile", "last_snapshots"]:
        cset(train, arg, args[arg])

    # Round to the closest multiple of the minibatch size for validity
    args.batch_size -= args.batch_size % args.minibatch_size
    args.minibatch_std_size -= args.minibatch_std_size % args.minibatch_size
    args.latent_size -= args.latent_size % args.components_num
    if args.latent_size == 0:
        print(
            misc.bcolored(
                "Error: latent-size is too small. Must best a multiply of components-num.",
                "red"))
        exit()

    sched_args = {
        "G_lrate": "g_lr",
        "D_lrate": "d_lr",
        "minibatch_size": "batch_size",
        "minibatch_gpu": "minibatch_size"
    }
    for arg, cmd_arg in sched_args.items():
        cset(sched, arg, args[cmd_arg])
    cset(train, "clip", args.clip)

    # Logging and metrics configuration
    metrics = [metric_defaults[x] for x in args.metrics]

    cset(cG.args, "truncation_psi", args.truncation_psi)
    for arg in ["summarize", "keep_samples", "eval_images_num"]:
        cset(train, arg, args[arg])

    # Visualization
    args.vis_imgs = args.vis_images
    args.vis_ltnts = args.vis_latents
    vis_types = [
        "imgs", "ltnts", "maps", "layer_maps", "interpolations", "noise_var",
        "style_mix"
    ]
    # Collect the set of visualization types enabled on the command line
    vis.vis_types = {arg for arg in vis_types if args["vis_{}".format(arg)]}

    vis_args = {
        "attention": "transformer",
        "grid": "vis_grid",
        "num": "vis_num",
        "rich_num": "vis_rich_num",
        "section_size": "vis_section_size",
        "intrp_density": "intrpolation_density",
        "intrp_per_component": "intrpolation_per_component",
        "alpha": "blending_alpha"
    }
    for arg, cmd_arg in vis_args.items():
        cset(vis, arg, args[cmd_arg])

    # Networks architecture
    cset(cG.args, "architecture", args.g_arch)
    cset(cD.args, "architecture", args.d_arch)
    cset(cG.args, "tanh", args.tanh)

    # Latent sizes
    if args.components_num > 1:
        if not (args.transformer or args.kgan):
            print(
                misc.bcolored(
                    "Error: components-num > 1 but the model is not using components.",
                    "red"))
            print(
                misc.bcolored(
                    "Either add --transformer for GANsformer or --kgan for k-GAN).",
                    "red"))
            exit()
        args.latent_size = int(args.latent_size / args.components_num)
    cD.args.latent_size = cG.args.latent_size = cG.args.dlatent_size = args.latent_size
    cset([cG.args, cD.args, vis], "components_num", args.components_num)
    # mapping-dim default value is the latent-size
    if args.mapping_dim is None:
        args.mapping_dim = args.latent_size

    # Mapping network
    for arg in ["layersnum", "lrmul", "dim", "resnet", "shared_dim"]:
        cset(cG.args, arg, args["mapping_{}".format(arg)])

    # StyleGAN settings
    for arg in ["style", "latent_stem", "fused_modconv", "local_noise"]:
        cset(cG.args, arg, args[arg])
    cD.args.mbstd_group_size = args.minibatch_std_size

    # GANsformer
    cset(cG.args, "transformer", args.transformer)
    cset(cD.args, "transformer", args.d_transformer)

    args.norm = args.normalize
    for arg in [
            "norm", "integration", "ltnt_gate", "img_gate", "kmeans",
            "kmeans_iters", "mapping_ltnt2ltnt"
    ]:
        cset(cG.args, arg, args[arg])

    for arg in ["use_pos", "num_heads"]:
        cset([cG.args, cD.args], arg, args[arg])

    # Positional encoding
    for arg in ["dim", "init", "directions_num"]:
        field = "pos_{}".format(arg)
        cset([cG.args, cD.args], field, args[field])

    # k-GAN
    for arg in ["layer", "type", "same"]:
        field = "merge_{}".format(arg)
        cset(cG.args, field, args[field])
    cset([cG.args, train], "merge", args.kgan)

    # Attention
    for arg in [
            "start_res", "end_res", "ltnt2ltnt", "img2img", "local_attention"
    ]:
        cset(cG.args, arg, args["g_{}".format(arg)])
        cset(cD.args, arg, args["d_{}".format(arg)])
    cset(cG.args, "img2ltnt", args.g_img2ltnt)
    cset(cD.args, "ltnt2img", args.d_ltnt2img)

    # Mixing and dropout
    for arg in [
            "style_mixing", "component_mixing", "component_dropout",
            "attention_dropout"
    ]:
        cset(cG.args, arg, args[arg])

    # Loss and regularization
    gloss_args = {
        "loss_type": "g_loss",
        "reg_weight": "g_reg_weight",
        "pathreg": "pathreg",
    }
    dloss_args = {"loss_type": "d_loss", "reg_type": "d_reg", "gamma": "gamma"}
    for arg, cmd_arg in gloss_args.items():
        cset(cG.loss_args, arg, args[cmd_arg])
    for arg, cmd_arg in dloss_args.items():
        cset(cD.loss_args, arg, args[cmd_arg])

    ##### Experiments management:
    # Whenever we start a new experiment we store its result in a directory named 'args.expname:000'.
    # When we rerun a training or evaluation command it restores the model from that directory by default.
    # If we wish to restart the model training, we can set --restart and then we will store data in a new
    # directory: 'args.expname:001' after the first restart, then 'args.expname:002' after the second, etc.
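    # For example (hypothetical values): with --result-dir results --expname exp, the first
    # run writes to results/exp:000. Rerunning the same command resumes from results/exp:000,
    # while adding --restart starts fresh in results/exp:001, then results/exp:002, and so on.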

    # Find the latest directory that matches the experiment
    exp_dir = sorted(glob.glob("{}/{}:*".format(args.result_dir,
                                                args.expname)))
    run_id = 0
    if len(exp_dir) > 0:
        run_id = int(exp_dir[-1].split(":")[-1])
    # If restart, then work over a new directory
    if args.restart:
        run_id += 1

    run_name = "{}:{:03d}".format(args.expname, run_id)
    train.printname = "{} ".format(misc.bold(args.expname))

    snapshot, kimg, resume = None, 0, False
    pkls = sorted(
        glob.glob("{}/{}/network*.pkl".format(args.result_dir, run_name)))
    # Load a particular snapshot if specified
    if args.pretrained_pkl:
        # Soft links support
        snapshot = glob.glob(args.pretrained_pkl)[0]
        if os.path.islink(snapshot):
            snapshot = os.readlink(snapshot)

        # Extract training step from the snapshot if specified
        try:
            kimg = int(snapshot.split("-")[-1].split(".")[0])
        except:
            pass

    # Find latest snapshot in the directory
    elif len(pkls) > 0:
        snapshot = pkls[-1]
        kimg = int(snapshot.split("-")[-1].split(".")[0])
        resume = True

    if snapshot:
        print(
            misc.bcolored("Resuming {}, kimg {}".format(snapshot, kimg),
                          "white"))
        train.resume_pkl = snapshot
        train.resume_kimg = kimg
    else:
        print(misc.bcolored("Start model training from scratch.", "white"))

    # Run environment configuration
    sc.run_dir_root = args.result_dir
    sc.run_desc = args.expname
    sc.run_id = run_id
    sc.run_name = run_name
    sc.submit_target = dnnlib.SubmitTarget.LOCAL
    sc.local.do_not_copy_source_files = True

    kwargs = EasyDict(train)
    kwargs.update(cG=cG, cD=cD)
    kwargs.update(dataset_args=dataset_args,
                  vis_args=vis,
                  sched_args=sched,
                  grid_args=grid,
                  metric_arg_list=metrics,
                  tf_config=tf_config)
    kwargs.submit_config = copy.deepcopy(sc)
    kwargs.resume = resume
    kwargs.load_config = args.reload

    dnnlib.submit_run(**kwargs)
Example #9
def main():
    parser = argparse.ArgumentParser(
        description="Train the GANsformer",
        epilog=_examples,
        formatter_class=argparse.RawDescriptionHelpFormatter)

    # Framework
    # ------------------------------------------------------------------------------------------------------

    parser.add_argument("--expname",
                        help="Experiment name",
                        default="exp",
                        type=str)
    parser.add_argument("--eval",
                        help="Evaluation mode (default: False)",
                        default=None,
                        action="store_true")
    parser.add_argument("--train",
                        help="Train mode (default: False)",
                        default=None,
                        metavar="BOOL",
                        type=_str_to_bool)
    parser.add_argument(
        "--gpus",
        help="Comma-separated list of GPUs to be used (default: %(default)s)",
        default="0",
        type=str)

    ## Resumption
    parser.add_argument("--pretrained_pkl",
                        help="Filename for a snapshot to resume (optional)",
                        default=None,
                        type=str)
    parser.add_argument("--restart",
                        help="Restart training from scratch",
                        default=False,
                        action="store_true")
    parser.add_argument(
        "--reload",
        help="Reload options from the original experiment configuration file. "
        +
        "If False, uses the command line arguments when resuming training (default: %(default)s)",
        default=False,
        action="store_true")
    parser.add_argument(
        "--recompile",
        help="Recompile model from source code when resuming training. " +
        "If False, loading modules created when the experiment first started",
        default=None,
        action="store_true")
    parser.add_argument(
        "--last_snapshots",
        help="Number of last snapshots to save. -1 for all (default: 8)",
        default=None,
        type=int)

    ## Dataset
    parser.add_argument("--data-dir",
                        help="Datasets root directory",
                        required=True)
    parser.add_argument(
        "--dataset",
        help="Training dataset name (subdirectory of data-dir).",
        required=True)
    parser.add_argument("--ratio",
                        help="Image height/width ratio in the dataset",
                        default=1.0,
                        type=float)
    parser.add_argument(
        "--num-threads",
        help="Number of input processing threads (default: %(default)s)",
        default=4,
        type=int)
    parser.add_argument(
        "--mirror-augment",
        help=
        "Perform horizontal flip augmentation for the data (default: %(default)s)",
        default=False)
    parser.add_argument(
        "--train-images-num",
        help=
        "Maximum number of images to train on. If not specified, train on the whole dataset.",
        default=None,
        type=int)

    ## Training
    parser.add_argument(
        "--batch-size",
        help="Global batch size (optimization step) (default: %(default)s)",
        default=32,
        type=int)
    parser.add_argument(
        "--minibatch-size",
        help=
        "Batch size per GPU, gradients will be accumulated to match batch-size (default: %(default)s)",
        default=4,
        type=int)
    parser.add_argument(
        "--total-kimg",
        help="Training length in thousands of images (default: %(default)s)",
        metavar="KIMG",
        default=25000,
        type=int)
    parser.add_argument("--gamma",
                        help="R1 regularization weight (default: %(default)s)",
                        default=10,
                        type=float)
    parser.add_argument("--clip",
                        help="Gradient clipping threshold (optional)",
                        default=None,
                        type=float)
    parser.add_argument("--g-lr",
                        help="Generator learning rate (default: %(default)s)",
                        default=0.002,
                        type=float)
    parser.add_argument(
        "--d-lr",
        help="Discriminator learning rate (default: %(default)s)",
        default=0.002,
        type=float)

    ## Logging and evaluation
    parser.add_argument(
        "--result-dir",
        help="Root directory for experiments (default: %(default)s)",
        default="results",
        metavar="DIR")
    parser.add_argument(
        "--metrics",
        help="Comma-separated list of metrics or none (default: %(default)s)",
        default="fid",
        type=_parse_comma_sep)
    parser.add_argument(
        "--summarize",
        help="Create TensorBoard summaries (default: %(default)s)",
        default=True,
        metavar="BOOL",
        type=_str_to_bool)
    parser.add_argument(
        "--truncation-psi",
        help="Truncation Psi to be used in producing sample images " +
        "(used only for visualizations, _not used_ in training or for computing metrics) (default: %(default)s)",
        default=0.65,
        type=float)
    parser.add_argument(
        "--keep-samples",
        help=
        "Keep all prior samples during training, or if False, just the most recent ones (default: False)",
        default=None,
        action="store_true")
    parser.add_argument(
        "--eval-images-num",
        help="Number of images to evaluate metrics on (default: 50,000)",
        default=None,
        type=float)

    ## Visualization
    parser.add_argument("--vis-images",
                        help="Save image samples",
                        default=None,
                        action="store_true")
    parser.add_argument("--vis-latents",
                        help="Save latent vectors",
                        default=None,
                        action="store_true")
    parser.add_argument("--vis-maps",
                        help="Save attention maps (for GANsformer only)",
                        default=None,
                        action="store_true")
    parser.add_argument(
        "--vis-layer-maps",
        help="Save attention maps for all layers (for GANsformer only)",
        default=None,
        action="store_true")
    parser.add_argument("--vis-interpolations",
                        help="Create latent interpolations",
                        default=None,
                        action="store_true")
    parser.add_argument("--vis-noise-var",
                        help="Create noise variation visualization",
                        default=None,
                        action="store_true")
    parser.add_argument("--vis-style-mix",
                        help="Create style mixing visualization",
                        default=None,
                        action="store_true")

    parser.add_argument(
        "--vis-grid",
        help=
        "Whether to save the samples in one large grid files (default: True in training)",
        default=None,
        action="store_true")
    parser.add_argument("--vis-num",
                        help="Image height/width ratio in the dataset",
                        default=None,
                        type=int)
    parser.add_argument(
        "--vis-rich-num",
        help=
        "Number of samples for which richer visualizations will be created (default: 5)",
        default=None,
        type=int)
    parser.add_argument(
        "--vis-section-size",
        help=
        "Visualization section size to process at one (section-size <= vis-num) for memory footprint (default: 100)",
        default=None,
        type=int)
    parser.add_argument(
        "--blending-alpha",
        help=
        "Proportion for generated images and attention maps blends (default: 0.3)",
        default=None,
        type=float)
    parser.add_argument(
        "--intrpolation-density",
        help=
        "Number of samples in between two end points of an interpolation (default: 8)",
        default=None,
        type=int)
    parser.add_argument(
        "--intrpolation-per-component",
        help=
        "Whether to perform interpolation along particular latent components when true, or all of them at once otherwise (default: False)",
        default=None,
        action="store_true")

    # Model
    # ------------------------------------------------------------------------------------------------------

    ## General architecture
    parser.add_argument("--g-arch",
                        help="Generator architecture type (default: skip)",
                        default=None,
                        choices=["orig", "skip", "resnet"],
                        type=str)
    parser.add_argument(
        "--d-arch",
        help="Discriminator architecture type (default: resnet)",
        default=None,
        choices=["orig", "skip", "resnet"],
        type=str)
    parser.add_argument("--tanh",
                        help="tanh on generator output (default: False)",
                        default=None,
                        action="store_true")

    # Mapping network
    parser.add_argument("--mapping-layersnum",
                        help="Number of mapping layers (default: 8)",
                        default=None,
                        type=int)
    parser.add_argument(
        "--mapping-lrmul",
        help="Mapping network learning rate multiplier (default: 0.01)",
        default=None,
        type=float)
    parser.add_argument("--mapping-dim",
                        help="Mapping layers dimension (default: latent_size)",
                        default=None,
                        type=int)
    parser.add_argument(
        "--mapping-resnet",
        help="Use resent connections in mapping layers (default: False)",
        default=None,
        action="store_true")
    parser.add_argument(
        "--mapping-shared-dim",
        help=
        "Perform one shared mapping to all latent components concatenated together using the set dimension (default: disabled)",
        default=None,
        type=int)

    # Loss
    parser.add_argument(
        "--pathreg",
        help="Use path regularization in generator training (default: False",
        default=None,
        metavar="BOOL",
        type=_str_to_bool)
    parser.add_argument("--g-loss",
                        help="Generator loss type (default: %(default)s)",
                        default="logistic_ns",
                        choices=["logistic", "logistic_ns", "hinge", "wgan"],
                        type=str)
    parser.add_argument(
        "--g-reg-weight",
        help="Generator regularization weight (default: %(default)s)",
        default=1.0,
        type=float)

    parser.add_argument("--d-loss",
                        help="Discriminator loss type (default: %(default)s)",
                        default="logistic",
                        choices=["wgan", "logistic", "hinge"],
                        type=str)
    parser.add_argument(
        "--d-reg",
        help="Discriminator regularization type (default: %(default)s)",
        default="r1",
        choices=["non", "gp", "r1", "r2"],
        type=str)

    # Mixing and dropout
    parser.add_argument(
        "--style-mixing",
        help="Style mixing (layerwise) probability (default: %(default)s)",
        default=0.9,
        type=float)
    parser.add_argument(
        "--component-mixing",
        help="Component mixing (objectwise) probability (default: %(default)s)",
        default=0.0,
        type=float)
    parser.add_argument("--component-dropout",
                        help="Component dropout (default: %(default)s)",
                        default=0.0,
                        type=float)
    parser.add_argument("--attention-dropout",
                        help="Attention dropout (default: 0.12)",
                        default=None,
                        type=float)

    # StyleGAN additions
    parser.add_argument("--style",
                        help="Global style modulation (default: %(default)s)",
                        default=True,
                        metavar="BOOL",
                        type=_str_to_bool)
    parser.add_argument(
        "--latent-stem",
        help="Input latent through the generator stem grid (default: False)",
        default=None,
        action="store_true")
    parser.add_argument(
        "--fused-modconv",
        help=
        "Fuse modulation and convolution operations (default: %(default)s)",
        default=True,
        metavar="BOOL",
        type=_str_to_bool)
    parser.add_argument(
        "--local-noise",
        help="Add stochastic local noise each layer (default: %(default)s)",
        default=True,
        metavar="BOOL",
        type=_str_to_bool)
    parser.add_argument(
        "--minibatch-std-size",
        help=
        "Add minibatch standard deviation layer in the discriminator (default: %(default)s)",
        default=4,
        type=int)

    ## GANsformer
    parser.add_argument(
        "--transformer",
        help=
        "Add transformer layers to the generator: top-down latents-to-image (default: False)",
        default=None,
        action="store_true")
    parser.add_argument(
        "--latent-size",
        help=
        "Latent size, summing the dimension of all components (default: %(default)s)",
        default=512,
        type=int)
    parser.add_argument(
        "--components-num",
        help=
        "Components number. Each component has latent dimension of 'latent-size / components-num'. "
        +
        "1 for StyleGAN since it has one global latent vector (default: %(default)s)",
        default=1,
        type=int)
    parser.add_argument(
        "--num-heads",
        help="Number of attention heads (default: %(default)s)",
        default=1,
        type=int)
    parser.add_argument("--normalize",
                        help="Feature normalization type (optional)",
                        default=None,
                        choices=["batch", "instance", "layer"])
    parser.add_argument(
        "--integration",
        help=
        "Feature integration type: additive, multiplicative or both (default: %(default)s)",
        default="add",
        choices=["add", "mul", "both"],
        type=str)

    # Generator attention layers
    # Transformer resolution layers
    parser.add_argument(
        "--g-start-res",
        help=
        "Transformer minimum generator resolution (logarithmic): first layer in which transformer will be applied (default: %(default)s)",
        default=0,
        type=int)
    parser.add_argument(
        "--g-end-res",
        help=
        "Transformer maximum generator resolution (logarithmic): last layer in which transformer will be applied (default: %(default)s)",
        default=7,
        type=int)

    # Discriminator attention layers
    parser.add_argument(
        "--d-transformer",
        help=
        "Add transformer layers to the discriminator (bottom-up image-to-latents) (default: False)",
        default=None,
        action="store_true")
    parser.add_argument(
        "--d-start-res",
        help=
        "Transformer minimum discriminator resolution (logarithmic): first layer in which transformer will be applied (default: %(default)s)",
        default=0,
        type=int)
    parser.add_argument(
        "--d-end-res",
        help=
        "Transformer maximum discriminator resolution (logarithmic): last layer in which transformer will be applied (default: %(default)s)",
        default=7,
        type=int)

    # Attention
    parser.add_argument(
        "--ltnt-gate",
        help=
        "Gate attention from latents, such that components may not send information "
        + "when gate value is low (default: False)",
        default=None,
        action="store_true")
    parser.add_argument(
        "--img-gate",
        help=
        "Gate attention for images, such that some image positions may not get updated "
        + "or receive information when gate value is low (default: False)",
        default=None,
        action="store_true")
    parser.add_argument(
        "--kmeans",
        help=
        "Track and update image-to-latents assignment centroids, used in the duplex attention (default: False)",
        default=None,
        action="store_true")
    parser.add_argument(
        "--kmeans-iters",
        help=
        "Number of K-means iterations per transformer layer. Note that centroids are carried from layer to layer (default: %(default)s)",
        default=1,
        type=int)  # -per-layer

    # Attention directions
    # format is A2B: Elements _from_ B attend _to_ elements in A, and B elements get updated accordingly.
    # Note that it means that information propagates in the following direction: A -> B
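    # For example, --g-img2ltnt (A=img, B=ltnt) lets the latents attend to image positions
    # and be updated by them, i.e. information flows bottom-up from the image to the latents.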
    parser.add_argument(
        "--mapping-ltnt2ltnt",
        help=
        "Add self-attention over latents in the mapping network (default: False)",
        default=None,
        action="store_true")
    parser.add_argument(
        "--g-ltnt2ltnt",
        help=
        "Add self-attention over latents in the synthesis network (default: False)",
        default=None,
        action="store_true")
    parser.add_argument(
        "--g-img2img",
        help=
        "Add self-attention between images positions in that layer of the generator (SAGAN) (default: disabled)",
        default=0,
        type=int)
    parser.add_argument(
        "--g-img2ltnt",
        help=
        "Add image to latents attention (bottom-up) (default: %(default)s)",
        default=None,
        action="store_true")
    # g-ltnt2img: default information flow direction when using --transformer

    parser.add_argument(
        "--d-ltnt2img",
        help="Add latents to image attention (top-down) (default: %(default)s)",
        default=None,
        action="store_true")
    parser.add_argument(
        "--d-ltnt2ltnt",
        help=
        "Add self-attention over latents in the discriminator (default: False)",
        default=None,
        action="store_true")
    parser.add_argument(
        "--d-img2img",
        help=
        "Add self-attention over images positions in that layer of the discriminator (SAGAN) (default: disabled)",
        default=0,
        type=int)
    # d-img2ltnt: default information flow direction when using --d-transformer

    # Local attention operations (replacing convolution)
    parser.add_argument(
        "--g-local-attention",
        help=
        "Local attention operations in the generation up to this layer (default: disabled)",
        default=None,
        type=int)
    parser.add_argument(
        "--d-local-attention",
        help=
        "Local attention operations in the discriminator up to this layer (default: disabled)",
        default=None,
        type=int)

    # Positional encoding
    parser.add_argument("--use-pos",
                        help="Use positional encoding (default: False)",
                        default=None,
                        action="store_true")
    parser.add_argument(
        "--pos-dim",
        help="Positional encoding dimension (default: latent-size)",
        default=None,
        type=int)
    parser.add_argument(
        "--pos-type",
        help="Positional encoding type (default: %(default)s)",
        default="sinus",
        choices=["linear", "sinus", "trainable", "trainable2d"],
        type=str)
    parser.add_argument(
        "--pos-init",
        help=
        "Positional encoding initialization distribution (default: %(default)s)",
        default="uniform",
        choices=["uniform", "normal"],
        type=str)
    parser.add_argument(
        "--pos-directions-num",
        help=
        "Positional encoding number of spatial directions (default: %(default)s)",
        default=2,
        type=int)

    ## k-GAN
    parser.add_argument(
        "--kgan",
        help=
        "Generate components-num images and then merge them (k-GAN) (default: False)",
        default=None,
        action="store_true")
    parser.add_argument(
        "--merge-layer",
        help=
        "Merge layer, where images get combined through alpha-composition (default: %(default)s)",
        default=-1,
        type=int)
    parser.add_argument("--merge-type",
                        help="Merge type (default: additive)",
                        default=None,
                        choices=["sum", "softmax", "max", "leaves"],
                        type=str)
    parser.add_argument(
        "--merge-same",
        help=
        "Merge images with same alpha weights across all spatial positions (default: %(default)s)",
        default=None,
        action="store_true")

    args = parser.parse_args()

    if not os.path.exists(args.data_dir):
        print(
            misc.bcolored("Error: dataset root directory does not exist.",
                          "red"))
        exit()

    for metric in args.metrics:
        if metric not in metric_defaults:
            print(misc.bcolored("Error: unknown metric \"%s\"" % metric,
                                "red"))
            exit()

    run(**vars(args))
Example #10
    def __init__(
            self,
            tfrecord_dir,  # Directory containing a collection of tfrecords files
            resolution=None,  # Dataset resolution, None = autodetect
            label_file=None,  # Relative path of the labels file, None = autodetect
            max_label_size=0,  # 0 = no labels, "full" = full labels, <int> = N first label components
            max_imgs=None,  # Maximum number of images to use, None = use all images
            repeat=True,  # Repeat dataset indefinitely?
            shuffle_mb=2048,  # Shuffle data within specified window (megabytes), 0 = disable shuffling
            prefetch_mb=512,  # Amount of data to prefetch (megabytes), 0 = disable prefetching
            buffer_mb=256,  # Read buffer size (megabytes)
            num_threads=4,  # Number of concurrent threads for input processing
            **kwargs):

        self.tfrecord_dir = tfrecord_dir
        self.resolution = None
        self.resolution_log2 = None
        self.shape = []  # [channels, height, width]
        self.dtype = "uint8"
        self.dynamic_range = [0, 255]
        self.label_file = label_file
        self.label_size = None
        self.label_dtype = None
        self._np_labels = None
        self._tf_minibatch_in = None
        self._tf_labels_var = None
        self._tf_labels_dataset = None
        self._tf_datasets = dict()
        self._tf_iterator = None
        self._tf_init_ops = dict()
        self._tf_minibatch_np = None
        self._cur_minibatch = -1
        self._cur_lod = -1

        # List tfrecords files and inspect their shapes
        assert os.path.isdir(self.tfrecord_dir)
        tfr_files = sorted(
            glob.glob(os.path.join(self.tfrecord_dir, "*.tfrecords1of*")))
        # If max_imgs is not None, take a subset of images out of the 1st file. Otherwise take all files.
        if max_imgs is None:
            tfr_files = [
                sorted(glob.glob(re.sub("1of.*", "*", f))) for f in tfr_files
            ]
        else:
            tfr_files = [[f] for f in tfr_files]

        assert len(tfr_files) >= 1
        tfr_shapes = []
        for tfr_file in tfr_files:
            tfr_opt = tf.io.TFRecordOptions("")
            for record in tf.python_io.tf_record_iterator(
                    tfr_file[0], tfr_opt):
                tfr_shapes.append(self.parse_tfrecord_np(record).shape)
                break
            random.shuffle(tfr_file)

        # Autodetect label filename
        if self.label_file is None:
            guess = sorted(
                glob.glob(os.path.join(self.tfrecord_dir, "*.labels")))
            if len(guess):
                self.label_file = guess[0]
        elif not os.path.isfile(self.label_file):
            guess = os.path.join(self.tfrecord_dir, self.label_file)
            if os.path.isfile(guess):
                self.label_file = guess

        # Determine shape and resolution
        max_shape = max(tfr_shapes, key=np.prod)
        self.resolution = resolution if resolution is not None else max_shape[1]
        self.resolution_log2 = int(np.log2(self.resolution))
        self.shape = [max_shape[0], self.resolution, self.resolution]
        tfr_lods = [
            self.resolution_log2 - int(np.log2(shape[1]))
            for shape in tfr_shapes
        ]
        assert all(shape[0] == max_shape[0] for shape in tfr_shapes)
        assert all(shape[1] == shape[2] for shape in tfr_shapes)
        assert all(shape[1] == self.resolution // (2**lod)
                   for shape, lod in zip(tfr_shapes, tfr_lods))
        assert all(lod in range(self.resolution_log2 - 1) for lod in tfr_lods)

        # Load labels
        assert max_label_size == "full" or max_label_size >= 0
        self._np_labels = np.zeros([1 << 20, 0], dtype=np.float32)
        if self.label_file is not None and max_label_size != 0:
            self._np_labels = np.load(self.label_file)
            assert self._np_labels.ndim == 2
        if max_label_size != "full" and self._np_labels.shape[
                1] > max_label_size:
            self._np_labels = self._np_labels[:, :max_label_size]
        if max_imgs is not None and self._np_labels.shape[0] > max_imgs:
            self._np_labels = self._np_labels[:max_imgs]
        if max_imgs is not None and self._np_labels.shape[0] < max_imgs:
            print(misc.bcolored("Too many images. increase number.", "red"))
            exit()
        self.label_size = self._np_labels.shape[1]
        self.label_dtype = self._np_labels.dtype.name

        # Build TF expressions
        with tf.name_scope("Dataset"), tf.device("/cpu:0"):
            self._tf_minibatch_in = tf.placeholder(tf.int64,
                                                   name="minibatch_in",
                                                   shape=[])
            self._tf_labels_var = tflib.create_var_with_large_initial_value(
                self._np_labels, name="labels_var")
            self._tf_labels_dataset = tf.data.Dataset.from_tensor_slices(
                self._tf_labels_var)
            for tfr_file, tfr_shape, tfr_lod in zip(tfr_files, tfr_shapes,
                                                    tfr_lods):
                if tfr_lod < 0:
                    continue
                # Load dataset
                dset = tf.data.TFRecordDataset(tfr_file,
                                               compression_type="",
                                               buffer_size=buffer_mb << 20,
                                               num_parallel_reads=num_threads)

                # If max_imgs is set, take a subset of the data
                if max_imgs is not None:
                    dset = dset.take(max_imgs)

                # Parse the TF records
                dset = dset.map(self.parse_tfrecord_tf,
                                num_parallel_calls=num_threads)

                # Zip images with their labels (0s if no labels)
                dset = tf.data.Dataset.zip((dset, self._tf_labels_dataset))

                # Shuffle and repeat
                bytes_per_item = np.prod(tfr_shape) * np.dtype(
                    self.dtype).itemsize
                if shuffle_mb > 0:
                    dset = dset.shuffle((
                        (shuffle_mb << 20) - 1) // bytes_per_item + 1)
                if repeat:
                    dset = dset.repeat()

                # Prefetch and batch
                if prefetch_mb > 0:
                    dset = dset.prefetch((
                        (prefetch_mb << 20) - 1) // bytes_per_item + 1)
                dset = dset.batch(self._tf_minibatch_in)
                self._tf_datasets[tfr_lod] = dset

            # Initialize data iterator
            self._tf_iterator = tf.data.Iterator.from_structure(
                self._tf_datasets[0].output_types,
                self._tf_datasets[0].output_shapes)
            self._tf_init_ops = {lod: self._tf_iterator.make_initializer(dset) \
                for lod, dset in self._tf_datasets.items()}
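A hedged instantiation sketch, assuming this __init__ belongs to the TFRecordDataset class referenced in Example #2, that TensorFlow was initialized first (e.g. via dnnlib.tflib.init_tf()), and that the tfrecords directory below (hypothetical) was produced by the data preparation script:

dataset = TFRecordDataset("datasets/ffhq", max_label_size=0, max_imgs=70000)
print(dataset.shape, dataset.dynamic_range, dataset.label_size)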
Example #11
def training_loop(
    # General configuration
    train                   = False,    # Training mode
    eval                    = False,    # Evaluation mode
    vis                     = False,    # Visualization mode        
    run_dir                 = ".",      # Output directory
    num_gpus                = 1,        # Number of GPUs participating in the training
    rank                    = 0,        # Rank of the current process in [0, num_gpus)
    cG                      = {},       # Options for generator network
    cD                      = {},       # Options for discriminator network
    # Data
    dataset_args            = {},       # Options for training set
    drange_net              = [-1,1],   # Dynamic range used when feeding image data to the networks
    # Optimization
    loss_args               = {},       # Options for loss function
    total_kimg              = 25000,    # Total length of the training, measured in thousands of real images
    batch_size              = 4,        # Total batch size for one training iteration. Can be larger than batch_gpu * num_gpus
    batch_gpu               = 4,        # Number of samples processed at a time by one GPU
    ema_kimg                = 10.0,     # Half-life of the exponential moving average (EMA) of generator weights
    ema_rampup              = None,     # EMA ramp-up coefficient
    cudnn_benchmark         = True,     # Enable torch.backends.cudnn.benchmark?
    allow_tf32              = False,    # Enable torch.backends.cuda.matmul.allow_tf32 and torch.backends.cudnn.allow_tf32?
    # Logging
    resume_pkl              = None,     # Network pickle to resume training from
    resume_kimg             = 0.0,      # Assumed training progress at the beginning
                                        # Affects reporting and training schedule    
    kimg_per_tick           = 8,        # Progress snapshot interval
    img_snapshot_ticks      = 3,        # How often to save image snapshots? None = disable
    network_snapshot_ticks  = 3,        # How often to save network snapshots? None = disable    
    last_snapshots          = 10,       # Maximal number of prior snapshots to save
    printname               = "",       # Experiment name for logging
    # Evaluation
    vis_args                = {},       # Options for vis.vis
    metrics                 = [],       # Metrics to evaluate during training
    eval_images_num         = 50000,    # Sample size for the metrics
    truncation_psi          = 0.7       # Style strength multiplier for the truncation trick (used for visualizations only)
):  
    # Initialize
    start_time = time.time()
    device = init_cuda(rank, cudnn_benchmark, allow_tf32)
    log = (rank == 0)

    dataset, dataset_iter = load_dataset(dataset_args, batch_size, rank, num_gpus, log) # Load training set

    nets = construct_nets(cG, cD, dataset, device, log) if train else None        # Construct networks
    G, D, Gs = load_nets(resume_pkl, nets, device, log)                           # Resume from existing pickle
    print_nets(G, D, batch_gpu, device, log)                                      # Print network summary tables

    if eval:
        misc.log("Run evaluation...", log = log)
        evaluate(Gs, resume_pkl, metrics, eval_images_num, dataset_args, num_gpus, rank, device, log)

    if vis and log:
        misc.log("Produce visualizations...")
        visualize.vis(Gs, dataset, device, batch_gpu, drange_net = drange_net, ratio = dataset.ratio, 
            truncation_psi = truncation_psi, **vis_args)

    if not train:
        exit()

    nets = distribute_nets(G, D, Gs, device, num_gpus, log)                                   # Distribute networks across GPUs
    loss, stages = setup_training_stages(loss_args, G, cG, D, cD, nets, device, log)          # Setup training stages (losses and optimizers)
    grid_size, grid_z, grid_c = init_img_grid(dataset, G.input_shape, device, run_dir, log)   # Initialize an image grid    
    logger = init_logger(run_dir, log)                                                        # Initialize logs

    # Train
    misc.log(f"Training for {total_kimg} kimg...", "white", log)    
    cur_nimg, cur_tick, batch_idx = int(resume_kimg * 1000), 0, 0
    tick_start_nimg, tick_start_time = cur_nimg, time.time()
    stats = None

    while True:
        # Fetch training data
        real_img, real_c, gen_zs, gen_cs = fetch_data(dataset, dataset_iter, G.input_shape, drange_net, 
            device, len(stages), batch_size, batch_gpu)

        # Execute training stages
        for stage, gen_z, gen_c in zip(stages, gen_zs, gen_cs):
            if batch_idx % stage.interval != 0:
                continue
            run_training_stage(loss, stage, device, real_img, real_c, gen_z, gen_c, batch_size, batch_gpu, num_gpus)

        # Update Gs
        update_ema_network(G, Gs, batch_size, cur_nimg, ema_kimg, ema_rampup)

        # Update state
        cur_nimg += batch_size
        batch_idx += 1

        # Perform maintenance tasks once per tick
        done = (cur_nimg >= total_kimg * 1000)
        if (not done) and (cur_tick != 0) and (cur_nimg < tick_start_nimg + kimg_per_tick * 1000):
            continue

        # Print status line and accumulate the info in logger.collector
        tick_end_time = time.time()
        if stats is not None:
            default = dnnlib.EasyDict({'mean': -1})
            fields = []
            fields.append("tick " + misc.bold(f"{training_stats.report0('Progress/tick', cur_tick):<5d}"))
            fields.append("kimg " + misc.bcolored(f"{training_stats.report0('Progress/kimg', cur_nimg / 1e3):<8.1f}", "red"))
            fields.append("")
            fields.append("loss/reg: G (" + misc.bcolored(f"{stats.get('Loss/G/loss', default).mean:>6.3f}", "blue"))
            fields.append(misc.bold(f"{stats.get('Loss/G/reg', default).mean:>6.3f}") + ")")
            fields.append("D "+ misc.bcolored(f"({stats.get('Loss/D/loss', default).mean:>6.3f}", "blue"))
            fields.append(misc.bold(f"{stats.get('Loss/D/reg', default).mean:>6.3f}") + ")")
            fields.append("")
            fields.append("time " + misc.bold(f"{dnnlib.util.format_time(training_stats.report0('Timing/total_sec', tick_end_time - start_time)):<12s}"))
            fields.append(f"sec/kimg {training_stats.report0('Timing/sec_per_kimg', (tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg) * 1e3):<7.2f}")
            fields.append(f"mem: GPU {training_stats.report0('Resources/cpu_mem_gb', psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}")
            fields.append(f"CPU {training_stats.report0('Resources/peak_gpu_mem_gb', torch.cuda.max_memory_allocated(device) / 2**30):<6.2f}")
            fields.append(misc.bold(printname))
            torch.cuda.reset_peak_memory_stats()
            training_stats.report0('Timing/total_hours', (tick_end_time - start_time) / (60 * 60))
            training_stats.report0('Timing/total_days', (tick_end_time - start_time) / (24 * 60 * 60))
            misc.log(" ".join(fields), log = log)

        # Save image snapshot
        if log and (img_snapshot_ticks is not None) and (done or cur_tick % img_snapshot_ticks == 0):
            visualize.vis(Gs, dataset, device, batch_gpu, training = True,
                step = cur_nimg // 1000, grid_size = grid_size, latents = grid_z, 
                labels = grid_c, drange_net = drange_net, ratio = dataset.ratio, **vis_args)

        # Save network snapshot
        if (network_snapshot_ticks is not None) and (done or cur_tick % network_snapshot_ticks == 0):
            snapshot_data, snapshot_pkl = save_nets(G, D, Gs, cur_nimg, dataset_args, run_dir, num_gpus > 1, last_snapshots, log)

            # Evaluate metrics
            evaluate(snapshot_data["Gs"], snapshot_pkl, metrics, eval_images_num,
                dataset_args, num_gpus, rank, device, log, logger, run_dir)
            del snapshot_data

        # Collect stats and update logs
        stats = collect_stats(logger, stages)
        update_logger(logger, stats, cur_nimg, start_time)

        cur_tick += 1
        tick_start_nimg, tick_start_time = cur_nimg, time.time()
        maintenance_time = tick_start_time - tick_end_time
        if done:
            break

    # Done
    misc.log("Done!", "blue")
Example #12
def eval(
    G,
    dataset,  # The dataset object for accessing the data
    batch_size,  # Visualization batch size
    training=False,  # Training mode
    latents=None,  # Source latents to generate images from
    labels=None,  # Source labels to generate images from (0 if no labels are used)
    # Model settings
    components_num=1,  # Number of components the model has
    drange_net=[-1, 1],  # Model image output range
    attention=False,  # Whether the model produces attention maps (for visualization)
    # Visualization settings
    vis_types=None,  # Visualization types to be created
    num=100,  # Number of produced samples
    rich_num=5,  # Number of samples for which richer visualizations will be created
    # (requires more memory and disk space, and therefore rich_num < num)
    grid=None,  # Whether to save the samples in one large grid files
    # or in separated files one per sample
    grid_size=None,  # Grid proportions (w, h)
    step=None,  # Step number to be used in visualization filenames
    verbose=None,  # Print verbose progress messages
    # Visualization-specific settings
    alpha=0.3,  # Blending proportion between generated images and attention maps
    intrp_density=8,  # Number of samples in between two end points of an interpolation
    intrp_per_component=False,  # Whether to perform interpolation along particular latent components (True)
    # or all of them at once (False)
    noise_samples_num=100,  # Number of samples used to compute noise variation visualization
    section_size=100
):  # Visualization section size (section_size <= num) for reducing memory footprint
    def pattern_of(dir, step, suffix):
        return "eval/{}/{}%06d.{}".format(
            dir, "" if step is None else "{}_".format(step), suffix)

    # For time efficiency, during training save only image and map samples
    # rather than richer visualizations
    vis = vis_types
    if training:
        vis = {"imgs", "maps"}
        section_size = num = len(latents)
    else:
        if vis is None:
            vis = {"imgs", "maps", "ltnts", "interpolations", "noise_var"}

    # Set default options
    # Save image samples in one grid file during training
    if grid is None:
        grid = training
    # Disable verbose during training
    if verbose is None:
        verbose = not training
    # If grid size is provided, set number of visualized images accordingly
    if grid_size is not None:
        num = np.prod(grid_size)

    # Build the image-saving functions
    save_images = misc.save_images_builder(drange_net, grid_size, grid,
                                           verbose)
    save_blends = misc.save_blends_builder(drange_net, grid_size, grid,
                                           verbose, alpha)

    # Set noise variables to fixed random values
    noise_vars = [
        var for name, var in G.subnets.synthesis.vars.items()
        if name.startswith("noise")
    ]
    noise_var_vals = {
        var: np.random.randn(*var.shape.as_list())
        for var in noise_vars
    }
    tflib.set_vars(noise_var_vals)

    # Create directories
    dirs = []
    if "imgs" in vis: dirs += ["images"]
    if "ltnts" in vis: dirs += ["latents-z", "latents-w"]
    if "maps" in vis: dirs += ["maps", "softmaps", "blends", "softblends"]
    if "layer_maps" in vis: dirs += ["layer_maps"]
    if "interpolations" in vis: dirs += ["interpolations-z", "interpolation-w"]

    for dir in dirs:
        misc.mkdir(dnnlib.make_run_dir_path("eval/{}".format(dir)))

    # Produce visualizations
    for idx in range(0, num, section_size):
        curr_size = curr_batch_size(num, idx, section_size)
        if verbose and num > curr_size:
            print("--- Batch {}/{}".format(idx + 1, num))

        # Compute source latents images will be produced from
        if latents is None:
            latents = np.random.randn(curr_size, *G.input_shape[1:])
        if labels is None:
            labels = dataset.get_minibatch_np(curr_size)

        # Run network over latents and produce images and attention maps
        if verbose:
            print("Running network...")
        images, attmaps_all_layers, wlatents_all_layers = G.run(
            latents,
            labels,
            randomize_noise=False,
            minibatch_size=batch_size,
            return_dlatents=True)  # is_visualization = True
        # For memory efficiency, save full information only for a small number of images
        attmaps_all_layers = attmaps_all_layers[:rich_num]
        wlatents = wlatents_all_layers[:, :, 0]

        # Save image samples
        if "imgs" in vis:
            if verbose:
                print("Saving image samples...")
            save_images(images, pattern_of("images", step, "png"), idx)

        # Save latent vectors
        if "ltnts" in vis:
            if verbose:
                print("Saving latents...")
            misc.save_npys(latents, pattern_of("latents-z", step, "npy"), idx)
            misc.save_npys(wlatents, pattern_of("latents-w", step, "npy"), idx)

        # For the GANsformer model, save attention maps
        if attention:
            if "maps" in vis:
                soft_maps = attmaps_all_layers[:, :, -1, 0]

                pallete = np.expand_dims(misc.get_colors(components_num),
                                         axis=[2, 3])
                maps = (soft_maps == np.amax(soft_maps, axis=1,
                                             keepdims=True)).astype(float)
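                # maps: hard one-hot assignment of each spatial position to its
                # highest-attention latent component (soft_maps keeps the full
                # attention distribution)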

                soft_maps = np.sum(pallete * np.expand_dims(soft_maps, axis=2),
                                   axis=1)
                maps = np.sum(pallete * np.expand_dims(maps, axis=2), axis=1)

                if verbose:
                    print("Saving maps...")
                save_images(soft_maps, pattern_of("softmaps", step, "png"),
                            idx)
                save_images(maps, pattern_of("maps", step, "png"), idx)

                save_blends(maps, images, pattern_of("blends", step, "png"),
                            idx)
                save_blends(soft_maps, images,
                            pattern_of("softblends", step, "png"), idx)

            # Save maps from all attention heads and layers
            # (for efficiency, only for a small number of images)
            if "layer_maps" in vis:
                all_maps = []
                maps_fakes = np.split(attmaps_all_layers,
                                      attmaps_all_layers.shape[2],
                                      axis=2)
                for layer, lmaps in enumerate(maps_fakes):
                    lmaps = np.split(np.squeeze(lmaps, axis=2),
                                     lmaps.shape[3],
                                     axis=2)
                    for head, hmap in enumerate(lmaps):
                        hmap = (hmap == np.amax(hmap, axis=1,
                                                keepdims=True)).astype(float)
                        hmap = np.sum(pallete * hmap, axis=1)
                        all_maps.append((hmap, "l{}_h{}".format(layer, head)))

                if verbose:
                    print("Saving layer maps...")
                for i in trange(rich_num):
                    misc.mkdir(
                        dnnlib.make_run_dir_path("eval/layer_maps/%06d" % i))

                for maps, name in tqdm(all_maps):
                    dirname = "eval/layer_maps{}/%06d/{}{}.png".format(
                        "" if step is None else ("/" + step), name)
                    save_images(maps, dirname, idx)

    # Produce interpolations between pairs of source latents
    # In the GANsformer case, varying one component at a time
    if "interpolations" in vis:
        ts = np.array(np.linspace(0.0, 1.0, num=intrp_density, endpoint=True))

        if verbose:
            print("Generating interpolations...")
        for i in trange(rich_num):
            misc.mkdir(
                dnnlib.make_run_dir_path("eval/interpolations-z/%06d" % i))
            misc.mkdir(
                dnnlib.make_run_dir_path("eval/interpolations-w/%06d" % i))

            z = np.random.randn(2, *G.input_shape[1:])
            z[0] = latents[i:i + 1]
            w = G.run(z,
                      labels,
                      randomize_noise=False,
                      return_dlatents=True,
                      minibatch_size=batch_size)[-1]

            def update(t, fn, ts, dim):
                if dim == 3:
                    ts = ts[:, np.newaxis]
                t_ups = []

                if intrp_per_component:
                    for c in range(components_num):
                        # copy over all the components except component c that will get interpolated
                        t_up = np.tile(
                            np.copy(t[0])[None], [intrp_density] + [1] * dim)
                        # interpolate component c
                        t_up[:, c] = fn(t[0, c], t[1, c], ts)
                        t_ups.append(t_up)

                    t_up = np.concatenate(t_ups, axis=0)
                else:
                    t_up = fn(t[0], t[1], ts)

                return t_up

            z_up = update(z, slerp, ts, 2)
            w_up = update(w, lerp, ts, 3)

            imgs1 = G.run(z_up,
                          labels,
                          randomize_noise=False,
                          minibatch_size=batch_size)[0]
            imgs2 = G.run(w_up,
                          labels,
                          randomize_noise=False,
                          minibatch_size=batch_size,
                          take_wlatents=True)[0]

            def save_interpolation(imgs, name):
                imgs = np.split(imgs, components_num, axis=0)
                for c in range(components_num):
                    filename = "eval/interpolations_%s/%06d/%02d" % (name, i,
                                                                     c)
                    imgs[c] = [
                        misc.to_pil(img, drange=drange_net) for img in imgs[c]
                    ]
                    imgs[c][-1].save(
                        dnnlib.make_run_dir_path("{}.png".format(filename)))
                    misc.save_gif(
                        imgs[c],
                        dnnlib.make_run_dir_path("{}.gif".format(filename)))

            save_interpolation(imgs1, "z")
            save_interpolation(imgs2, "w")

    # Compute noise variance map
    # Shows what areas vary the most given fixed source
    # latents due to the use of stochastic local noise
    if "noise_var" in vis:
        if verbose:
            print("Generating noise variance...")
        z = np.tile(np.random.randn(1, *G.input_shape[1:]),
                    [noise_samples_num, 1, 1])
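        # All samples share the same latent code, so any variation between the
        # generated images comes only from the stochastic local noise inputs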
        imgs = G.run(z, labels, minibatch_size=batch_size)[0]
        imgs = np.stack([misc.to_pil(img, drange=drange_net) for img in imgs],
                        axis=0)
        diff = np.std(np.mean(imgs, axis=3), axis=0) * 4
        diff = np.clip(diff + 0.5, 0, 255).astype(np.uint8)
        PIL.Image.fromarray(diff, "L").save(
            dnnlib.make_run_dir_path("eval/noise_variance.png"))

    # Compute a style mixing table, using latents from A in some of the layers and latents from B in the rest.
    # For the GANsformer, also produce component mixes (using latents from A for some of the components,
    # and latents from B for the rest).
    if "style_mix" in vis:
        if verbose:
            print("Generating style mixes...")
        cols, rows = 4, 2
        row_lens = np.array([2, 5, 8, 11])

        # Create latent mixes
        mixes = {
            "layer": (np.arange(wlatents_all_layers.shape[2]) <
                      row_lens[:, None]).astype(np.float32)[:, None, None,
                                                            None, :, None],
            "component":
            (np.arange(wlatents_all_layers.shape[1]) <
             row_lens[:, None]).astype(np.float32)[:, None, None, :, None,
                                                   None]
        }
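        # Assuming wlatents_all_layers has shape [batch, components, layers, dim],
        # each mask above broadcasts against row_ltnts/col_ltnts below: the first
        # row_lens[r] layers (or components) take their latents from the row
        # source and the remaining ones from the column source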
        ws = wlatents_all_layers[:cols + rows]
        orig_imgs = images[:cols + rows]
        col_ltnts = wlatents_all_layers[:cols][None, None]
        row_ltnts = wlatents_all_layers[cols:cols + rows][None, :, None]

        for name, mix in mixes.items():
            # Produce image mixes
            mix_ltnts = mix * row_ltnts + (1 - mix) * col_ltnts
            mix_ltnts = np.reshape(mix_ltnts,
                                   [-1, *wlatents_all_layers.shape[1:]])
            mix_imgs = G.run(mix_ltnts,
                             labels,
                             randomize_noise=False,
                             take_dlatents=True,
                             minibatch_size=batch_size)[0]
            mix_imgs = np.reshape(
                mix_imgs, [len(row_lens) * rows, cols, *mix_imgs.shape[1:]])

            # Create image table canvas
            H, W = mix_imgs.shape[-2:]
            canvas = PIL.Image.new("RGB", (W * (cols + 1), H *
                                           (len(row_lens) * rows + 1)),
                                   "black")

            # Place image mixes respectively at each position (row_idx, col_idx)
            for row_idx, row_elem in enumerate(
                [None] + list(range(len(row_lens) * rows))):
                for col_idx, col_elem in enumerate([None] + list(range(cols))):
                    if (row_elem, col_elem) == (None, None): continue
                    if row_elem is None: img = orig_imgs[col_elem]
                    elif col_elem is None:
                        img = orig_imgs[cols + (row_elem % rows)]
                    else:
                        img = mix_imgs[row_elem, col_elem]

                    canvas.paste(misc.to_pil(img, drange=drange_net),
                                 (W * col_idx, H * row_idx))

            canvas.save(
                dnnlib.make_run_dir_path("eval/{}_mixing.png".format(name)))

    if verbose:
        print(misc.bcolored("Visualizations Completed!", "blue"))
Example #13
0
def training_loop(
    # Configurations
    cG={},
    cD={},  # Generator and Discriminator command-line arguments
    dataset_args={},  # dataset.load_dataset() options
    sched_args={},  # train.TrainingSchedule options
    vis_args={},  # vis.eval options
    grid_args={},  # train.setup_snapshot_img_grid() options
    metric_arg_list=[],  # MetricGroup Options
    tf_config={},  # tflib.init_tf() options
    eval=False,  # Evaluation mode
    train=False,  # Training mode
    # Data
    data_dir=None,  # Directory to load datasets from
    total_kimg=25000,  # Total length of the training, measured in thousands of real images
    mirror_augment=False,  # Enable mirror augmentation?
    drange_net=[-1, 1],  # Dynamic range used when feeding image data to the networks
    ratio=1.0,  # Image height/width ratio in the dataset
    # Optimization
    minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters
    lazy_regularization=True,  # Perform regularization as a separate training step?
    smoothing_kimg=10.0,  # Half-life of the running average of generator weights
    clip=None,  # Clip gradients threshold
    # Resumption
    resume_pkl=None,  # Network pickle to resume training from, None = train from scratch.
    resume_kimg=0.0,  # Assumed training progress at the beginning
    # Affects reporting and training schedule
    resume_time=0.0,  # Assumed wallclock time at the beginning, affects reporting
    recompile=False,  # Recompile network from source code (otherwise loads from snapshot)
    # Logging
    summarize=True,  # Create TensorBoard summaries
    save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms=False,  # Include weight histograms in the tfevents file?
    img_snapshot_ticks=3,  # How often to save image snapshots? None = disable
    network_snapshot_ticks=3,  # How often to save network snapshots? None = only save networks-final.pkl
    last_snapshots=10,  # Maximal number of prior snapshots to save
    eval_images_num=50000,  # Sample size for the metrics
    printname="",  # Experiment name for logging
    # Architecture
    merge=False):  # Generate several images and then merge them

    # Initialize dnnlib and TensorFlow
    tflib.init_tf(tf_config)
    num_gpus = dnnlib.submit_config.num_gpus
    cG.name, cD.name = "g", "d"

    # Load dataset, configure training scheduler and metrics object
    dataset = data.load_dataset(data_dir=dnnlib.convert_path(data_dir),
                                verbose=True,
                                **dataset_args)
    sched = training_schedule(sched_args,
                              cur_nimg=total_kimg * 1000,
                              dataset=dataset)
    metrics = metric_base.MetricGroup(metric_arg_list)

    # Construct or load networks
    with tf.device("/gpu:0"):
        no_op = tf.no_op()
        G, D, Gs = None, None, None
        if resume_pkl is None or recompile:
            misc.log("Constructing networks...", "white")
            G = tflib.Network("G",
                              num_channels=dataset.shape[0],
                              resolution=dataset.shape[1],
                              label_size=dataset.label_size,
                              **cG.args)
            D = tflib.Network("D",
                              num_channels=dataset.shape[0],
                              resolution=dataset.shape[1],
                              label_size=dataset.label_size,
                              **cD.args)
            Gs = G.clone("Gs")
        if resume_pkl is not None:
            G, D, Gs = load_nets(resume_pkl, G, D, Gs, recompile)

    G.print_layers()
    D.print_layers()

    # Train/Evaluate/Visualize
    # Labels are optional
    grid_size, grid_reals, grid_labels = misc.setup_snapshot_img_grid(
        dataset, **grid_args)
    misc.save_img_grid(grid_reals,
                       dnnlib.make_run_dir_path("reals.png"),
                       drange=dataset.dynamic_range,
                       grid_size=grid_size)
    grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:])

    if eval:
        # Save a snapshot of the current network to evaluate
        pkl = dnnlib.make_run_dir_path("network-eval-snapshot-%06d.pkl" %
                                       resume_kimg)
        misc.save_pkl((G, D, Gs), pkl, remove=False)

        # Quantitative evaluation
        metric = metrics.run(pkl,
                             num_imgs=eval_images_num,
                             run_dir=dnnlib.make_run_dir_path(),
                             data_dir=dnnlib.convert_path(data_dir),
                             num_gpus=num_gpus,
                             ratio=ratio,
                             tf_config=tf_config,
                             mirror_augment=mirror_augment)

        # Qualitative evaluation
        visualize.eval(G,
                       dataset,
                       batch_size=sched.minibatch_gpu,
                       drange_net=drange_net,
                       ratio=ratio,
                       **vis_args)

    if not train:
        dataset.close()
        exit()

    # Setup training inputs
    misc.log("Building TensorFlow graph...", "white")
    with tf.name_scope("Inputs"), tf.device("/cpu:0"):
        lrate_in_g = tf.placeholder(tf.float32, name="lrate_in_g", shape=[])
        lrate_in_d = tf.placeholder(tf.float32, name="lrate_in_d", shape=[])
        step = tf.placeholder(tf.int32, name="step", shape=[])
        minibatch_size_in = tf.placeholder(tf.int32,
                                           name="minibatch_size_in",
                                           shape=[])
        minibatch_gpu_in = tf.placeholder(tf.int32,
                                          name="minibatch_gpu_in",
                                          shape=[])
        minibatch_multiplier = minibatch_size_in // (minibatch_gpu_in *
                                                     num_gpus)
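        # Gradient-accumulation factor: e.g. minibatch_size=64 with
        # minibatch_gpu=8 on 2 GPUs gives a multiplier of 4, i.e. gradients
        # from 4 rounds are averaged before each parameter update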
        beta = 0.5**tf.div(tf.cast(minibatch_size_in,
                                   tf.float32), smoothing_kimg *
                           1000.0) if smoothing_kimg > 0.0 else 0.0
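        # EMA decay for Gs: beta = 0.5 ** (minibatch_size / (smoothing_kimg * 1000)),
        # e.g. minibatch_size=32 and smoothing_kimg=10 give beta ~= 0.9978, so the
        # moving average has a half-life of 10k images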

    # Set optimizers
    for cN, lr in [(cG, lrate_in_g), (cD, lrate_in_d)]:
        set_optimizer(cN, lr, minibatch_multiplier, lazy_regularization, clip)

    # Build training graph for each GPU
    data_fetch_ops = []
    for gpu in range(num_gpus):
        with tf.name_scope("GPU%d" % gpu), tf.device("/gpu:%d" % gpu):

            # Create GPU-specific shadow copies of G and D
            for cN, N in [(cG, G), (cD, D)]:
                cN.gpu = N if gpu == 0 else N.clone(N.name + "_shadow")
            Gs_gpu = Gs if gpu == 0 else Gs.clone(Gs.name + "_shadow")

            # Fetch training data via temporary variables
            with tf.name_scope("DataFetch"):
                reals, labels = dataset.get_minibatch_tf()
                reals = process_reals(reals, dataset.dynamic_range, drange_net,
                                      mirror_augment)
                reals, reals_fetch = read_data(
                    reals, "reals", [sched.minibatch_gpu] + dataset.shape,
                    minibatch_gpu_in)
                labels, labels_fetch = read_data(
                    labels, "labels",
                    [sched.minibatch_gpu, dataset.label_size],
                    minibatch_gpu_in)
                data_fetch_ops += [reals_fetch, labels_fetch]

            # Evaluate loss functions
            with tf.name_scope("G_loss"):
                cG.loss, cG.reg = dnnlib.util.call_func_by_name(
                    G=cG.gpu,
                    D=cD.gpu,
                    dataset=dataset,
                    reals=reals,
                    minibatch_size=minibatch_gpu_in,
                    **cG.loss_args)

            with tf.name_scope("D_loss"):
                cD.loss, cD.reg = dnnlib.util.call_func_by_name(
                    G=cG.gpu,
                    D=cD.gpu,
                    dataset=dataset,
                    reals=reals,
                    labels=labels,
                    minibatch_size=minibatch_gpu_in,
                    **cD.loss_args)

            for cN in [cG, cD]:
                set_optimizer_ops(cN, lazy_regularization, no_op)

    # Setup training ops
    data_fetch_op = tf.group(*data_fetch_ops)
    for cN in [cG, cD]:
        cN.train_op = cN.opt.apply_updates()
        cN.reg_op = cN.reg_opt.apply_updates(allow_no_op=True)
    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=beta)

    # Finalize graph
    with tf.device("/gpu:0"):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)
    tflib.init_uninitialized_vars()

    # Tensorboard summaries
    if summarize:
        misc.log("Initializing logs...", "white")
        summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path())
        if save_tf_graph:
            summary_log.add_graph(tf.get_default_graph())
        if save_weight_histograms:
            G.setup_weight_histograms()
            D.setup_weight_histograms()

    # Initialize training
    misc.log("Training for %d kimg..." % total_kimg, "white")
    dnnlib.RunContext.get().update("",
                                   cur_epoch=resume_kimg,
                                   max_epoch=total_kimg)
    maintenance_time = dnnlib.RunContext.get().get_last_update_interval()

    cur_tick, running_mb_counter = -1, 0
    cur_nimg = int(resume_kimg * 1000)
    tick_start_nimg = cur_nimg
    for cN in [cG, cD]:
        cN.lossvals_agg = {
            k: None
            for k in ["loss", "reg", "norm", "reg_norm"]
        }
        cN.opt.reset_optimizer_state()

    # Training loop
    while cur_nimg < total_kimg * 1000:
        if dnnlib.RunContext.get().should_stop():
            break

        # Choose training parameters and configure training ops
        sched = training_schedule(sched_args,
                                  cur_nimg=cur_nimg,
                                  dataset=dataset)
        assert sched.minibatch_size % (sched.minibatch_gpu * num_gpus) == 0
        dataset.configure(sched.minibatch_gpu)

        # Run training ops
        feed_dict = {
            lrate_in_g: sched.G_lrate,
            lrate_in_d: sched.D_lrate,
            minibatch_size_in: sched.minibatch_size,
            minibatch_gpu_in: sched.minibatch_gpu,
            step: sched.kimg
        }

        # Several iterations before updating training parameters
        for _repeat in range(minibatch_repeats):
            rounds = range(0, sched.minibatch_size,
                           sched.minibatch_gpu * num_gpus)
            for cN in [cG, cD]:
                cN.run_reg = lazy_regularization and (running_mb_counter %
                                                      cN.reg_interval == 0)
            cur_nimg += sched.minibatch_size
            running_mb_counter += 1

            for cN in [cG, cD]:
                cN.lossvals = {
                    k: None
                    for k in ["loss", "reg", "norm", "reg_norm"]
                }

            # Gradient accumulation
            for _round in rounds:
                cG.lossvals.update(
                    tflib.run([cG.train_op, cG.ops], feed_dict)[1])
                if cG.run_reg:
                    _, cG.lossvals["reg_norm"] = tflib.run(
                        [cG.reg_op, cG.reg_norm], feed_dict)

                tflib.run(data_fetch_op, feed_dict)

                cD.lossvals.update(
                    tflib.run([cD.train_op, cD.ops], feed_dict)[1])
                if cD.run_reg:
                    _, cD.lossvals["reg_norm"] = tflib.run(
                        [cD.reg_op, cD.reg_norm], feed_dict)

            tflib.run([Gs_update_op], feed_dict)

            # Track loss statistics
            for cN in [cG, cD]:
                for k in cN.lossvals_agg:
                    cN.lossvals_agg[k] = emaAvg(cN.lossvals_agg[k],
                                                cN.lossvals[k])

        # Perform maintenance tasks once per tick
        done = (cur_nimg >= total_kimg * 1000)
        if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = dnnlib.RunContext.get().get_time_since_last_update()
            total_time = dnnlib.RunContext.get().get_time_since_start(
            ) + resume_time

            # Report progress
            print(
                ("tick %s kimg %s   loss/reg: G (%s %s) D (%s %s)   grad norms: G (%s %s) D (%s %s)   "
                 + "time %s sec/kimg %s maxGPU %sGB %s") %
                (misc.bold("%-5d" % autosummary("Progress/tick", cur_tick)),
                 misc.bcolored(
                     "{:>8.1f}".format(
                         autosummary("Progress/kimg", cur_nimg / 1000.0)),
                     "red"),
                 misc.bcolored("{:>6.3f}".format(cG.lossvals_agg["loss"] or 0),
                               "blue"),
                 misc.bold("{:>6.3f}".format(cG.lossvals_agg["reg"] or 0)),
                 misc.bcolored("{:>6.3f}".format(cD.lossvals_agg["loss"] or 0),
                               "blue"),
                 misc.bold("{:>6.3f}".format(cD.lossvals_agg["reg"] or 0)),
                 misc.cond_bcolored(cG.lossvals_agg["norm"], 20.0, "red"),
                 misc.cond_bcolored(cG.lossvals_agg["reg_norm"], 20.0, "red"),
                 misc.cond_bcolored(cD.lossvals_agg["norm"], 20.0, "red"),
                 misc.cond_bcolored(cD.lossvals_agg["reg_norm"], 20.0, "red"),
                 misc.bold("%-10s" % dnnlib.util.format_time(
                     autosummary("Timing/total_sec", total_time))),
                 "{:>7.2f}".format(
                     autosummary("Timing/sec_per_kimg",
                                 tick_time / tick_kimg)),
                 "{:>4.1f}".format(
                     autosummary("Resources/peak_gpu_mem_gb",
                                 peak_gpu_mem_op.eval() / 2**30)), printname))

            autosummary("Timing/total_hours", total_time / (60.0 * 60.0))
            autosummary("Timing/total_days", total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots
            if img_snapshot_ticks is not None and (
                    cur_tick % img_snapshot_ticks == 0 or done):
                visualize.eval(G,
                               dataset,
                               batch_size=sched.minibatch_gpu,
                               training=True,
                               step=cur_nimg // 1000,
                               grid_size=grid_size,
                               latents=grid_latents,
                               labels=grid_labels,
                               drange_net=drange_net,
                               ratio=ratio,
                               **vis_args)

            if network_snapshot_ticks is not None and (
                    cur_tick % network_snapshot_ticks == 0 or done):
                pkl = dnnlib.make_run_dir_path("network-snapshot-%06d.pkl" %
                                               (cur_nimg // 1000))
                misc.save_pkl((G, D, Gs), pkl, remove=False)

                if cur_tick % network_snapshot_ticks == 0 or done:
                    metric = metrics.run(
                        pkl,
                        num_imgs=eval_images_num,
                        run_dir=dnnlib.make_run_dir_path(),
                        data_dir=dnnlib.convert_path(data_dir),
                        num_gpus=num_gpus,
                        ratio=ratio,
                        tf_config=tf_config,
                        mirror_augment=mirror_augment)

                if last_snapshots > 0:
                    misc.rm(
                        sorted(
                            glob.glob(dnnlib.make_run_dir_path(
                                "network*.pkl")))[:-last_snapshots])

            # Update summaries and RunContext
            if summarize:
                metrics.update_autosummaries()
                tflib.autosummary.save_summaries(summary_log, cur_nimg)

            dnnlib.RunContext.get().update(None,
                                           cur_epoch=cur_nimg // 1000,
                                           max_epoch=total_kimg)
            maintenance_time = dnnlib.RunContext.get(
            ).get_last_update_interval() - tick_time

    # Save final snapshot
    misc.save_pkl((G, D, Gs),
                  dnnlib.make_run_dir_path("network-final.pkl"),
                  remove=False)

    # All done
    if summarize:
        summary_log.close()
    dataset.close()
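
# The emaAvg helper used above to aggregate loss statistics is not shown here;
# a minimal sketch matching how it is called in the training loop (hypothetical
# decay value, tolerant of None on either argument) might look like:
def _ema_avg_sketch(avg, value, alpha=0.995):
    # Exponential moving average that ignores missing values
    if value is None:
        return avg
    if avg is None:
        return value
    return alpha * avg + (1.0 - alpha) * value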