Example 1
def print_nets(G, D, batch_gpu, device, log):
    if not log:
        return
    z = torch.empty([batch_gpu, *G.input_shape[1:]], device=device)
    c = torch.empty([batch_gpu, *G.cond_shape[1:]], device=device)
    img = torch_misc.print_module_summary(G, [z, c])[0]
    torch_misc.print_module_summary(D, [img, c])
def subprocess_fn(rank, args, temp_dir):
    dnnlib.util.Logger(should_flush=True)

    # Init torch.distributed.
    if args.num_gpus > 1:
        init_file = os.path.abspath(
            os.path.join(temp_dir, '.torch_distributed_init'))
        if os.name == 'nt':
            init_method = 'file:///' + init_file.replace('\\', '/')
            torch.distributed.init_process_group(backend='gloo',
                                                 init_method=init_method,
                                                 rank=rank,
                                                 world_size=args.num_gpus)
        else:
            init_method = f'file://{init_file}'
            torch.distributed.init_process_group(backend='nccl',
                                                 init_method=init_method,
                                                 rank=rank,
                                                 world_size=args.num_gpus)

    # Init torch_utils.
    sync_device = torch.device('cuda', rank) if args.num_gpus > 1 else None
    training_stats.init_multiprocessing(rank=rank, sync_device=sync_device)
    if rank != 0 or not args.verbose:
        custom_ops.verbosity = 'none'

    # Print network summary.
    device = torch.device('cuda', rank)
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False
    G = copy.deepcopy(args.G).eval().requires_grad_(False).to(device)
    if rank == 0 and args.verbose:
        z = torch.empty([1, G.z_dim], device=device)
        c = torch.empty([1, G.c_dim], device=device)
        misc.print_module_summary(G, [z, c])

    # Calculate each metric.
    for metric in args.metrics:
        if rank == 0 and args.verbose:
            print(f'Calculating {metric}...')
        progress = metric_utils.ProgressMonitor(verbose=args.verbose)
        result_dict = metric_main.calc_metric(
            metric=metric,
            G=G,
            dataset_kwargs=args.dataset_kwargs,
            num_gpus=args.num_gpus,
            rank=rank,
            device=device,
            progress=progress)
        if rank == 0:
            metric_main.report_metric(result_dict,
                                      run_dir=args.run_dir,
                                      snapshot_pkl=args.network_pkl)
        if rank == 0 and args.verbose:
            print()

    # Done.
    if rank == 0 and args.verbose:
        print('Exiting...')
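
A minimal launch sketch (not part of the original listing) showing how a subprocess_fn like the one above is typically started, assuming args is a simple namespace (e.g. dnnlib.EasyDict) that carries at least num_gpus; single-GPU runs call the function directly, multi-GPU runs spawn one process per GPU and share a temporary directory for the file-based torch.distributed init:

import tempfile
import torch


def launch(args):
    # Hypothetical helper: run subprocess_fn once per GPU.
    with tempfile.TemporaryDirectory() as temp_dir:
        if args.num_gpus == 1:
            subprocess_fn(rank=0, args=args, temp_dir=temp_dir)
        else:
            torch.multiprocessing.spawn(fn=subprocess_fn,
                                        args=(args, temp_dir),
                                        nprocs=args.num_gpus)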
Example 3
def training_loop(
        run_dir='.',  # Output directory.
        training_set_kwargs={},  # Options for training set.
        data_loader_kwargs={},  # Options for torch.utils.data.DataLoader.
        G_kwargs={},  # Options for generator network.
        D_kwargs={},  # Options for discriminator network.
        D2_kwargs={},  # Options for discriminator network.
        G_opt_kwargs={},  # Options for generator optimizer.
        D_opt_kwargs={},  # Options for discriminator optimizer.
        augment_kwargs=None,  # Options for augmentation pipeline. None = disable.
        loss_kwargs={},  # Options for loss function.
        metrics=[],  # Metrics to evaluate during training.
        random_seed=0,  # Global random seed.
        num_gpus=1,  # Number of GPUs participating in the training.
        rank=0,  # Rank of the current process in [0, num_gpus[.
        batch_size=4,  # Total batch size for one training iteration. Can be larger than batch_gpu * num_gpus.
        batch_gpu=4,  # Number of samples processed at a time by one GPU.
        ema_kimg=10,  # Half-life of the exponential moving average (EMA) of generator weights.
        ema_rampup=None,  # EMA ramp-up coefficient.
        G_reg_interval=4,  # How often to perform regularization for G? None = disable lazy regularization.
        D_reg_interval=16,  # How often to perform regularization for D? None = disable lazy regularization.
        augment_p=0,  # Initial value of augmentation probability.
        ada_target=None,  # ADA target value. None = fixed p.
        ada_interval=4,  # How often to perform ADA adjustment?
        ada_kimg=500,  # ADA adjustment speed, measured in how many kimg it takes for p to increase/decrease by one unit.
        total_kimg=25000,  # Total length of the training, measured in thousands of real images.
        kimg_per_tick=4,  # Progress snapshot interval.
        image_snapshot_ticks=50,  # How often to save image snapshots? None = disable.
        network_snapshot_ticks=50,  # How often to save network snapshots? None = disable.
        resume_pkl=None,  # Network pickle to resume training from.
        cudnn_benchmark=True,  # Enable torch.backends.cudnn.benchmark?
        abort_fn=None,  # Callback function for determining whether to abort training. Must return consistent results across ranks.
        progress_fn=None,  # Callback function for updating training progress. Called for all ranks.
        obake=None,  # Enable Obake training (adds MTCNN / face-recognition modules). None = disable.
):
    # Initialize.
    start_time = time.time()
    device = torch.device('cuda', rank)
    np.random.seed(random_seed * num_gpus + rank)
    torch.manual_seed(random_seed * num_gpus + rank)
    torch.backends.cudnn.benchmark = cudnn_benchmark  # Improves training speed.
    conv2d_gradfix.enabled = True  # Improves training speed.
    grid_sample_gradfix.enabled = True  # Avoids errors with the augmentation pipe.

    # Load training set.
    if rank == 0:
        print('Loading training set...')
    training_set = dnnlib.util.construct_class_by_name(
        **training_set_kwargs)  # subclass of training.dataset.Dataset
    training_set_sampler = misc.InfiniteSampler(dataset=training_set,
                                                rank=rank,
                                                num_replicas=num_gpus,
                                                seed=random_seed)
    training_set_iterator = iter(
        torch.utils.data.DataLoader(dataset=training_set,
                                    sampler=training_set_sampler,
                                    batch_size=batch_size // num_gpus,
                                    **data_loader_kwargs))
    if rank == 0:
        print()
        print('Num images: ', len(training_set))
        print('Image shape:', training_set.image_shape)
        print('Label shape:', training_set.label_shape)
        print()

    # Construct networks.
    if rank == 0:
        print('Constructing networks...')
    common_kwargs = dict(c_dim=training_set.label_dim,
                         img_resolution=training_set.resolution,
                         img_channels=training_set.num_channels)
    G = dnnlib.util.construct_class_by_name(
        **G_kwargs, **common_kwargs).train().requires_grad_(False).to(
            device)  # subclass of torch.nn.Module
    D = dnnlib.util.construct_class_by_name(
        **D_kwargs, **common_kwargs).train().requires_grad_(False).to(
            device)  # subclass of torch.nn.Module
    G_ema = copy.deepcopy(G).eval()
    if obake is not None:
        D_mtcnn = MTCNN(image_size=D2_kwargs.mtcnn_output_size,
                        margin=D2_kwargs.mtcnn_output_margin,
                        thresholds=D2_kwargs.mtcnn_thresholds)
        D_face = InceptionResnetV1(pretrained=D2_kwargs.resnet_type).eval()

    # Resume from existing pickle.
    if (resume_pkl is not None) and (rank == 0):
        print(f'Resuming from "{resume_pkl}"')
        with dnnlib.util.open_url(resume_pkl) as f:
            resume_data = legacy.load_network_pkl(f)
        for name, module in [('G', G), ('D', D), ('G_ema', G_ema)]:
            misc.copy_params_and_buffers(resume_data[name],
                                         module,
                                         require_all=False)

    # Print network summary tables.
    if rank == 0:
        z = torch.empty([batch_gpu, G.z_dim], device=device)
        c = torch.empty([batch_gpu, G.c_dim], device=device)
        img = misc.print_module_summary(G, [z, c])
        misc.print_module_summary(D, [img, c])

    # Setup augmentation.
    if rank == 0:
        print('Setting up augmentation...')
    augment_pipe = None
    ada_stats = None
    if (augment_kwargs is not None) and (augment_p > 0
                                         or ada_target is not None):
        augment_pipe = dnnlib.util.construct_class_by_name(
            **augment_kwargs).train().requires_grad_(False).to(
                device)  # subclass of torch.nn.Module
        augment_pipe.p.copy_(torch.as_tensor(augment_p))
        if ada_target is not None:
            ada_stats = training_stats.Collector(regex='Loss/signs/real')

    # Distribute across GPUs.
    if rank == 0:
        print(f'Distributing across {num_gpus} GPUs...')
    ddp_modules = dict()
    module_list = [('G_mapping', G.mapping), ('G_synthesis', G.synthesis),
                   ('D', D)]
    if obake is not None:
        module_list += [('D_mtcnn', D_mtcnn), ('D_face', D_face)]
    module_list += [(None, G_ema), ('augment_pipe', augment_pipe)]
    for name, module in module_list:
        if (num_gpus > 1) and (module is not None) and len(
                list(module.parameters())) != 0:
            module.requires_grad_(True)
            module = torch.nn.parallel.DistributedDataParallel(
                module, device_ids=[device], broadcast_buffers=False)
            module.requires_grad_(False)
        if name is not None:
            ddp_modules[name] = module

    # Setup training phases.
    if rank == 0:
        print('Setting up training phases...')
    loss = dnnlib.util.construct_class_by_name(
        device=device, **ddp_modules,
        **loss_kwargs)  # subclass of training.loss.Loss
    phases = []
    for name, module, opt_kwargs, reg_interval in [
        ('G', G, G_opt_kwargs, G_reg_interval),
        ('D', D, D_opt_kwargs, D_reg_interval)
    ]:
        if reg_interval is None:
            opt = dnnlib.util.construct_class_by_name(
                params=module.parameters(),
                **opt_kwargs)  # subclass of torch.optim.Optimizer
            phases += [
                dnnlib.EasyDict(name=name + 'both',
                                module=module,
                                opt=opt,
                                interval=1)
            ]
        else:  # Lazy regularization.
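            # With lazy regularization the regularization term runs only every
            # reg_interval minibatches, so the learning rate and Adam betas are
            # rescaled by mb_ratio = reg_interval / (reg_interval + 1) to keep
            # the effective update statistics comparable to running it every step.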
            mb_ratio = reg_interval / (reg_interval + 1)
            opt_kwargs = dnnlib.EasyDict(opt_kwargs)
            opt_kwargs.lr = opt_kwargs.lr * mb_ratio
            opt_kwargs.betas = [beta**mb_ratio for beta in opt_kwargs.betas]
            opt = dnnlib.util.construct_class_by_name(
                module.parameters(),
                **opt_kwargs)  # subclass of torch.optim.Optimizer
            phases += [
                dnnlib.EasyDict(name=name + 'main',
                                module=module,
                                opt=opt,
                                interval=1)
            ]
            phases += [
                dnnlib.EasyDict(name=name + 'reg',
                                module=module,
                                opt=opt,
                                interval=reg_interval)
            ]
    for phase in phases:
        phase.start_event = None
        phase.end_event = None
        if rank == 0:
            phase.start_event = torch.cuda.Event(enable_timing=True)
            phase.end_event = torch.cuda.Event(enable_timing=True)

    # Export sample images.
    grid_size = None
    grid_z = None
    grid_c = None
    if rank == 0:
        print('Exporting sample images...')
        grid_size, images, labels = setup_snapshot_image_grid(
            training_set=training_set)
        save_image_grid(images,
                        os.path.join(run_dir, 'reals.png'),
                        drange=[0, 255],
                        grid_size=grid_size)
        grid_z = torch.randn([labels.shape[0], G.z_dim],
                             device=device).split(batch_gpu)
        grid_c = torch.from_numpy(labels).to(device).split(batch_gpu)
        images = torch.cat([
            G_ema(z=z, c=c, noise_mode='const').cpu()
            for z, c in zip(grid_z, grid_c)
        ]).numpy()
        save_image_grid(images,
                        os.path.join(run_dir, 'fakes_init.png'),
                        drange=[-1, 1],
                        grid_size=grid_size)

    # Initialize logs.
    if rank == 0:
        print('Initializing logs...')
    stats_collector = training_stats.Collector(regex='.*')
    stats_metrics = dict()
    stats_jsonl = None
    stats_tfevents = None
    if rank == 0:
        stats_jsonl = open(os.path.join(run_dir, 'stats.jsonl'), 'wt')
        try:
            import torch.utils.tensorboard as tensorboard
            stats_tfevents = tensorboard.SummaryWriter(run_dir)
        except ImportError as err:
            print('Skipping tfevents export:', err)

    # Train.
    if rank == 0:
        print(f'Training for {total_kimg} kimg...')
        print()
    cur_nimg = 0
    cur_tick = 0
    tick_start_nimg = cur_nimg
    tick_start_time = time.time()
    maintenance_time = tick_start_time - start_time
    batch_idx = 0
    if progress_fn is not None:
        progress_fn(0, total_kimg)
    while True:

        # Fetch training data.
        with torch.autograd.profiler.record_function('data_fetch'):
            phase_real_img, phase_real_c = next(training_set_iterator)
            phase_real_img = (
                phase_real_img.to(device).to(torch.float32) / 127.5 -
                1).split(batch_gpu)
            phase_real_c = phase_real_c.to(device).split(batch_gpu)
            all_gen_z = torch.randn([len(phases) * batch_size, G.z_dim],
                                    device=device)
            all_gen_z = [
                phase_gen_z.split(batch_gpu)
                for phase_gen_z in all_gen_z.split(batch_size)
            ]
            all_gen_c = [
                training_set.get_label(np.random.randint(len(training_set)))
                for _ in range(len(phases) * batch_size)
            ]
            all_gen_c = torch.from_numpy(
                np.stack(all_gen_c)).pin_memory().to(device)
            all_gen_c = [
                phase_gen_c.split(batch_gpu)
                for phase_gen_c in all_gen_c.split(batch_size)
            ]

        # Execute training phases.
        for phase, phase_gen_z, phase_gen_c in zip(phases, all_gen_z,
                                                   all_gen_c):
            if batch_idx % phase.interval != 0:
                continue

            # Initialize gradient accumulation.
            if phase.start_event is not None:
                phase.start_event.record(torch.cuda.current_stream(device))
            phase.opt.zero_grad(set_to_none=True)
            phase.module.requires_grad_(True)

            # Accumulate gradients over multiple rounds.
            for round_idx, (real_img, real_c, gen_z, gen_c) in enumerate(
                    zip(phase_real_img, phase_real_c, phase_gen_z,
                        phase_gen_c)):
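                # DDP gradient synchronization is requested only on the last
                # accumulation round; earlier rounds skip the all-reduce.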
                sync = (round_idx == batch_size // (batch_gpu * num_gpus) - 1)
                gain = phase.interval
                # print(phase.name)  # Debug output left disabled.
                loss.accumulate_gradients(phase=phase.name,
                                          real_img=real_img,
                                          real_c=real_c,
                                          gen_z=gen_z,
                                          gen_c=gen_c,
                                          sync=sync,
                                          gain=gain)

            # Update weights.
            phase.module.requires_grad_(False)
            with torch.autograd.profiler.record_function(phase.name + '_opt'):
                for param in phase.module.parameters():
                    if param.grad is not None:
                        misc.nan_to_num(param.grad,
                                        nan=0,
                                        posinf=1e5,
                                        neginf=-1e5,
                                        out=param.grad)
                phase.opt.step()
            if phase.end_event is not None:
                phase.end_event.record(torch.cuda.current_stream(device))

        # Update G_ema.
        with torch.autograd.profiler.record_function('Gema'):
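            # ema_beta is chosen so that the EMA weights have a half-life of
            # ema_nimg images (optionally ramped up early in training).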
            ema_nimg = ema_kimg * 1000
            if ema_rampup is not None:
                ema_nimg = min(ema_nimg, cur_nimg * ema_rampup)
            ema_beta = 0.5**(batch_size / max(ema_nimg, 1e-8))
            for p_ema, p in zip(G_ema.parameters(), G.parameters()):
                p_ema.copy_(p.lerp(p_ema, ema_beta))
            for b_ema, b in zip(G_ema.buffers(), G.buffers()):
                b_ema.copy_(b)

        # Update state.
        cur_nimg += batch_size
        batch_idx += 1

        # Execute ADA heuristic.
        if (ada_stats is not None) and (batch_idx % ada_interval == 0):
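            # ADA heuristic: nudge the augmentation probability p toward ada_target
            # based on the sign of E[sign(D(real))]; the step size lets p change by
            # at most one unit per ada_kimg thousand real images.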
            ada_stats.update()
            adjust = np.sign(ada_stats['Loss/signs/real'] - ada_target) * (
                batch_size * ada_interval) / (ada_kimg * 1000)
            augment_pipe.p.copy_(
                (augment_pipe.p + adjust).max(misc.constant(0, device=device)))

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if (not done) and (cur_tick != 0) and (
                cur_nimg < tick_start_nimg + kimg_per_tick * 1000):
            continue

        # Print status line, accumulating the same information in stats_collector.
        tick_end_time = time.time()
        fields = []
        fields += [
            f"tick {training_stats.report0('Progress/tick', cur_tick):<5d}"
        ]
        fields += [
            f"kimg {training_stats.report0('Progress/kimg', cur_nimg / 1e3):<8.1f}"
        ]
        fields += [
            f"time {dnnlib.util.format_time(training_stats.report0('Timing/total_sec', tick_end_time - start_time)):<12s}"
        ]
        fields += [
            f"sec/tick {training_stats.report0('Timing/sec_per_tick', tick_end_time - tick_start_time):<7.1f}"
        ]
        fields += [
            f"sec/kimg {training_stats.report0('Timing/sec_per_kimg', (tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg) * 1e3):<7.2f}"
        ]
        fields += [
            f"maintenance {training_stats.report0('Timing/maintenance_sec', maintenance_time):<6.1f}"
        ]
        fields += [
            f"cpumem {training_stats.report0('Resources/cpu_mem_gb', psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}"
        ]
        fields += [
            f"gpumem {training_stats.report0('Resources/peak_gpu_mem_gb', torch.cuda.max_memory_allocated(device) / 2**30):<6.2f}"
        ]
        torch.cuda.reset_peak_memory_stats()
        fields += [
            f"augment {training_stats.report0('Progress/augment', float(augment_pipe.p.cpu()) if augment_pipe is not None else 0):.3f}"
        ]
        training_stats.report0('Timing/total_hours',
                               (tick_end_time - start_time) / (60 * 60))
        training_stats.report0('Timing/total_days',
                               (tick_end_time - start_time) / (24 * 60 * 60))
        if rank == 0:
            print(' '.join(fields))

        # Check for abort.
        if (not done) and (abort_fn is not None) and abort_fn():
            done = True
            if rank == 0:
                print()
                print('Aborting...')

        # Save image snapshot.
        if (rank == 0) and (image_snapshot_ticks is not None) and (
                done or cur_tick % image_snapshot_ticks == 0):
            images = torch.cat([
                G_ema(z=z, c=c, noise_mode='const').cpu()
                for z, c in zip(grid_z, grid_c)
            ]).numpy()
            save_image_grid(images,
                            os.path.join(run_dir,
                                         f'fakes{cur_nimg//1000:06d}.png'),
                            drange=[-1, 1],
                            grid_size=grid_size)

        # Save network snapshot.
        snapshot_pkl = None
        snapshot_data = None
        if (network_snapshot_ticks is not None) and \
                (done or cur_tick % network_snapshot_ticks == 0):
            snapshot_data = dict(training_set_kwargs=dict(training_set_kwargs))
            for name, module in [('G', G), ('D', D), ('G_ema', G_ema),
                                 ('augment_pipe', augment_pipe)]:
                if module is not None:
                    if num_gpus > 1:
                        misc.check_ddp_consistency(module,
                                                   ignore_regex=r'.*\.w_avg')
                    module = copy.deepcopy(module).eval().requires_grad_(
                        False).cpu()
                snapshot_data[name] = module
                del module  # conserve memory
            snapshot_pkl = os.path.join(
                run_dir, f'network-snapshot-{cur_nimg//1000:06d}.pkl')
            if rank == 0:
                with open(snapshot_pkl, 'wb') as f:
                    pickle.dump(snapshot_data, f)

        # Evaluate metrics.
        if (snapshot_data is not None) and (len(metrics) > 0):
            if rank == 0:
                print('Evaluating metrics...')
            for metric in metrics:
                result_dict = metric_main.calc_metric(
                    metric=metric,
                    G=snapshot_data['G_ema'],
                    dataset_kwargs=training_set_kwargs,
                    num_gpus=num_gpus,
                    rank=rank,
                    device=device)
                if rank == 0:
                    metric_main.report_metric(result_dict,
                                              run_dir=run_dir,
                                              snapshot_pkl=snapshot_pkl)
                stats_metrics.update(result_dict.results)
        del snapshot_data  # conserve memory

        # Collect statistics.
        for phase in phases:
            value = []
            if (phase.start_event is not None) and (phase.end_event
                                                    is not None):
                phase.end_event.synchronize()
                value = phase.start_event.elapsed_time(phase.end_event)
            training_stats.report0('Timing/' + phase.name, value)
        stats_collector.update()
        stats_dict = stats_collector.as_dict()

        # Update logs.
        timestamp = time.time()
        if stats_jsonl is not None:
            fields = dict(stats_dict, timestamp=timestamp)
            stats_jsonl.write(json.dumps(fields) + '\n')
            stats_jsonl.flush()
        if stats_tfevents is not None:
            global_step = int(cur_nimg / 1e3)
            walltime = timestamp - start_time
            for name, value in stats_dict.items():
                stats_tfevents.add_scalar(name,
                                          value.mean,
                                          global_step=global_step,
                                          walltime=walltime)
            for name, value in stats_metrics.items():
                stats_tfevents.add_scalar(f'Metrics/{name}',
                                          value,
                                          global_step=global_step,
                                          walltime=walltime)
            stats_tfevents.flush()
        if progress_fn is not None:
            progress_fn(cur_nimg // 1000, total_kimg)

        # Update state.
        cur_tick += 1
        tick_start_nimg = cur_nimg
        tick_start_time = time.time()
        maintenance_time = tick_start_time - tick_end_time
        if done:
            break

    # Done.
    if rank == 0:
        print()
        print('Exiting...')
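
A small worked example (an illustration, not part of the original listing) of the gradient-accumulation arithmetic used in the loop above: with the illustrative values batch_size=32, batch_gpu=4 and num_gpus=2, each GPU draws batch_size // num_gpus = 16 samples per iteration and processes them in 16 // batch_gpu = 4 rounds, synchronizing DDP gradients only on the final round.

batch_size, batch_gpu, num_gpus = 32, 4, 2
rounds_per_gpu = batch_size // (batch_gpu * num_gpus)   # -> 4
for round_idx in range(rounds_per_gpu):
    sync = (round_idx == rounds_per_gpu - 1)            # True only on the last round
    print(round_idx, sync)                              # rounds 0..2 -> False, round 3 -> True
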
def training_loop(
        run_dir='.',  # Output directory.
        training_set_kwargs={},  # Options for training set.
        data_loader_kwargs={},  # Options for torch.utils.data.DataLoader.
        G_kwargs={},  # Options for generator network.
        D_kwargs={},  # Options for discriminator network.
        G_opt_kwargs={},  # Options for generator optimizer.
        D_opt_kwargs={},  # Options for discriminator optimizer.
        loss_kwargs={},  # Options for loss function.
        metrics=[],  # Metrics to evaluate during training.
        random_seed=0,  # Global random seed.
        num_gpus=1,  # Number of GPUs participating in the training.
        rank=0,  # Rank of the current process in [0, num_gpus[.
        batch_size=4,  # Total batch size for one training iteration. Can be larger than batch_gpu * num_gpus.
        batch_gpu=4,  # Number of samples processed at a time by one GPU.
        ema_kimg=10,  # Half-life of the exponential moving average (EMA) of generator weights.
        ema_rampup=0.05,  # EMA ramp-up coefficient. None = no rampup.
        G_reg_interval=None,  # How often to perform regularization for G? None = disable lazy regularization.
        D_reg_interval=16,  # How often to perform regularization for D? None = disable lazy regularization.
        total_kimg=25000,  # Total length of the training, measured in thousands of real images.
        kimg_per_tick=4,  # Progress snapshot interval.
        image_snapshot_ticks=50,  # How often to save image snapshots? None = disable.
        network_snapshot_ticks=50,  # How often to save network snapshots? None = disable.
        resume_pkl=None,  # Network pickle to resume training from.
        resume_kimg=0,  # First kimg to report when resuming training.
        cudnn_benchmark=True,  # Enable torch.backends.cudnn.benchmark?
        abort_fn=None,  # Callback function for determining whether to abort training. Must return consistent results across ranks.
        progress_fn=None,  # Callback function for updating training progress. Called for all ranks.
        restart_every=-1,  # Time interval in seconds to exit code
):
    # Initialize.
    start_time = time.time()
    device = torch.device('cuda', rank)
    np.random.seed(random_seed * num_gpus + rank)
    torch.manual_seed(random_seed * num_gpus + rank)
    torch.backends.cudnn.benchmark = cudnn_benchmark  # Improves training speed.
    torch.backends.cuda.matmul.allow_tf32 = False  # Improves numerical accuracy.
    torch.backends.cudnn.allow_tf32 = False  # Improves numerical accuracy.
    conv2d_gradfix.enabled = True  # Improves training speed.
    grid_sample_gradfix.enabled = True  # Avoids errors with the augmentation pipe.
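    # Progress counters are stored as CUDA tensors so that, after resuming from a
    # checkpoint, rank 0 can broadcast them to every other rank before training.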
    __RESTART__ = torch.tensor(
        0., device=device)  # broadcast to all ranks to signal loop exit
    __CUR_NIMG__ = torch.tensor(resume_kimg * 1000,
                                dtype=torch.long,
                                device=device)
    __CUR_TICK__ = torch.tensor(0, dtype=torch.long, device=device)
    __BATCH_IDX__ = torch.tensor(0, dtype=torch.long, device=device)
    __PL_MEAN__ = torch.zeros([], device=device)
    best_fid = 9999

    # Load training set.
    if rank == 0:
        print('Loading training set...')
    training_set = dnnlib.util.construct_class_by_name(
        **training_set_kwargs)  # subclass of training.dataset.Dataset
    training_set_sampler = misc.InfiniteSampler(dataset=training_set,
                                                rank=rank,
                                                num_replicas=num_gpus,
                                                seed=random_seed)
    training_set_iterator = iter(
        torch.utils.data.DataLoader(dataset=training_set,
                                    sampler=training_set_sampler,
                                    batch_size=batch_size // num_gpus,
                                    **data_loader_kwargs))
    if rank == 0:
        print()
        print('Num images: ', len(training_set))
        print('Image shape:', training_set.image_shape)
        print('Label shape:', training_set.label_shape)
        print()

    # Construct networks.
    if rank == 0:
        print('Constructing networks...')
    common_kwargs = dict(c_dim=training_set.label_dim,
                         img_resolution=training_set.resolution,
                         img_channels=training_set.num_channels)
    G = dnnlib.util.construct_class_by_name(
        **G_kwargs, **common_kwargs).train().requires_grad_(False).to(
            device)  # subclass of torch.nn.Module
    D = dnnlib.util.construct_class_by_name(
        **D_kwargs, **common_kwargs).train().requires_grad_(False).to(
            device)  # subclass of torch.nn.Module
    G_ema = copy.deepcopy(G).eval()

    # Check for existing checkpoint
    ckpt_pkl = None
    if restart_every > 0 and os.path.isfile(misc.get_ckpt_path(run_dir)):
        ckpt_pkl = resume_pkl = misc.get_ckpt_path(run_dir)

    # Resume from existing pickle.
    if (resume_pkl is not None) and (rank == 0):
        print(f'Resuming from "{resume_pkl}"')
        with dnnlib.util.open_url(resume_pkl) as f:
            resume_data = legacy.load_network_pkl(f)
        for name, module in [('G', G), ('D', D), ('G_ema', G_ema)]:
            misc.copy_params_and_buffers(resume_data[name],
                                         module,
                                         require_all=False)

        if ckpt_pkl is not None:  # Load ticks
            __CUR_NIMG__ = resume_data['progress']['cur_nimg'].to(device)
            __CUR_TICK__ = resume_data['progress']['cur_tick'].to(device)
            __BATCH_IDX__ = resume_data['progress']['batch_idx'].to(device)
            __PL_MEAN__ = resume_data['progress'].get('pl_mean', torch.zeros(
                [])).to(device)
            best_fid = resume_data['progress'][
                'best_fid']  # only needed for rank == 0

        del resume_data

    # Print network summary tables.
    if rank == 0:
        z = torch.empty([batch_gpu, G.z_dim], device=device)
        c = torch.empty([batch_gpu, G.c_dim], device=device)
        img = misc.print_module_summary(G, [z, c])
        misc.print_module_summary(D, [img, c])

    # Distribute across GPUs.
    if rank == 0:
        print(f'Distributing across {num_gpus} GPUs...')
    for module in [G, D, G_ema]:
        if module is not None and num_gpus > 1:
            for param in misc.params_and_buffers(module):
                torch.distributed.broadcast(param, src=0)

    # Setup training phases.
    if rank == 0:
        print('Setting up training phases...')
    loss = dnnlib.util.construct_class_by_name(
        device=device, G=G, G_ema=G_ema, D=D,
        **loss_kwargs)  # subclass of training.loss.Loss
    phases = []
    for name, module, opt_kwargs, reg_interval in [
        ('G', G, G_opt_kwargs, G_reg_interval),
        ('D', D, D_opt_kwargs, D_reg_interval)
    ]:
        if reg_interval is None:
            opt = dnnlib.util.construct_class_by_name(
                params=module.parameters(),
                **opt_kwargs)  # subclass of torch.optim.Optimizer
            phases += [
                dnnlib.EasyDict(name=name + 'both',
                                module=module,
                                opt=opt,
                                interval=1)
            ]
        else:  # Lazy regularization.
            mb_ratio = reg_interval / (reg_interval + 1)
            opt_kwargs = dnnlib.EasyDict(opt_kwargs)
            opt_kwargs.lr = opt_kwargs.lr * mb_ratio
            opt_kwargs.betas = [beta**mb_ratio for beta in opt_kwargs.betas]
            opt = dnnlib.util.construct_class_by_name(
                module.parameters(),
                **opt_kwargs)  # subclass of torch.optim.Optimizer
            phases += [
                dnnlib.EasyDict(name=name + 'main',
                                module=module,
                                opt=opt,
                                interval=1)
            ]
            phases += [
                dnnlib.EasyDict(name=name + 'reg',
                                module=module,
                                opt=opt,
                                interval=reg_interval)
            ]
    for phase in phases:
        phase.start_event = None
        phase.end_event = None
        if rank == 0:
            phase.start_event = torch.cuda.Event(enable_timing=True)
            phase.end_event = torch.cuda.Event(enable_timing=True)

    # Export sample images.
    grid_size = None
    grid_z = None
    grid_c = None
    if rank == 0:
        print('Exporting sample images...')
        grid_size, images, labels = setup_snapshot_image_grid(
            training_set=training_set)
        save_image_grid(images,
                        os.path.join(run_dir, 'reals.png'),
                        drange=[0, 255],
                        grid_size=grid_size)

        grid_z = torch.randn([labels.shape[0], G.z_dim],
                             device=device).split(batch_gpu)
        grid_c = torch.from_numpy(labels).to(device).split(batch_gpu)
        images = torch.cat([
            G_ema(z=z, c=c, noise_mode='const').cpu()
            for z, c in zip(grid_z, grid_c)
        ]).numpy()

        save_image_grid(images,
                        os.path.join(run_dir, 'fakes_init.png'),
                        drange=[-1, 1],
                        grid_size=grid_size)

    # Initialize logs.
    if rank == 0:
        print('Initializing logs...')
    stats_collector = training_stats.Collector(regex='.*')
    stats_metrics = dict()
    stats_jsonl = None
    stats_tfevents = None
    if rank == 0:
        stats_jsonl = open(os.path.join(run_dir, 'stats.jsonl'), 'wt')
        try:
            import torch.utils.tensorboard as tensorboard
            stats_tfevents = tensorboard.SummaryWriter(run_dir)
        except ImportError as err:
            print('Skipping tfevents export:', err)

    # Train.
    if rank == 0:
        print(f'Training for {total_kimg} kimg...')
        print()
    if num_gpus > 1:  # broadcast loaded states to all
        torch.distributed.broadcast(__CUR_NIMG__, 0)
        torch.distributed.broadcast(__CUR_TICK__, 0)
        torch.distributed.broadcast(__BATCH_IDX__, 0)
        torch.distributed.broadcast(__PL_MEAN__, 0)
        torch.distributed.barrier()  # ensure all processes received this info
    cur_nimg = __CUR_NIMG__.item()
    cur_tick = __CUR_TICK__.item()
    tick_start_nimg = cur_nimg
    tick_start_time = time.time()
    maintenance_time = tick_start_time - start_time
    batch_idx = __BATCH_IDX__.item()
    if progress_fn is not None:
        progress_fn(cur_nimg // 1000, total_kimg)
    if hasattr(loss, 'pl_mean'):
        loss.pl_mean.copy_(__PL_MEAN__)
    while True:

        with torch.autograd.profiler.record_function('data_fetch'):
            phase_real_img, phase_real_c = next(training_set_iterator)
            phase_real_img = (
                phase_real_img.to(device).to(torch.float32) / 127.5 -
                1).split(batch_gpu)
            phase_real_c = phase_real_c.to(device).split(batch_gpu)
            all_gen_z = torch.randn([len(phases) * batch_size, G.z_dim],
                                    device=device)
            all_gen_z = [
                phase_gen_z.split(batch_gpu)
                for phase_gen_z in all_gen_z.split(batch_size)
            ]
            all_gen_c = [
                training_set.get_label(np.random.randint(len(training_set)))
                for _ in range(len(phases) * batch_size)
            ]
            all_gen_c = torch.from_numpy(
                np.stack(all_gen_c)).pin_memory().to(device)
            all_gen_c = [
                phase_gen_c.split(batch_gpu)
                for phase_gen_c in all_gen_c.split(batch_size)
            ]

        # Execute training phases.
        for phase, phase_gen_z, phase_gen_c in zip(phases, all_gen_z,
                                                   all_gen_c):
            if batch_idx % phase.interval != 0:
                continue
            if phase.start_event is not None:
                phase.start_event.record(torch.cuda.current_stream(device))

            # Accumulate gradients.
            phase.opt.zero_grad(set_to_none=True)
            phase.module.requires_grad_(True)

            if phase.name in ['Dmain', 'Dboth', 'Dreg']:
                phase.module.feature_network.requires_grad_(False)

            for real_img, real_c, gen_z, gen_c in zip(phase_real_img,
                                                      phase_real_c,
                                                      phase_gen_z,
                                                      phase_gen_c):
                loss.accumulate_gradients(phase=phase.name,
                                          real_img=real_img,
                                          real_c=real_c,
                                          gen_z=gen_z,
                                          gen_c=gen_c,
                                          gain=phase.interval,
                                          cur_nimg=cur_nimg)
            phase.module.requires_grad_(False)

            # Update weights.
            with torch.autograd.profiler.record_function(phase.name + '_opt'):
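                # Manually synchronize gradients across GPUs: flatten all grads into
                # one tensor, all-reduce and average, sanitize NaN/Inf values, then
                # scatter the result back into the individual parameter .grad fields.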
                params = [
                    param for param in phase.module.parameters()
                    if param.grad is not None
                ]
                if len(params) > 0:
                    flat = torch.cat(
                        [param.grad.flatten() for param in params])
                    if num_gpus > 1:
                        torch.distributed.all_reduce(flat)
                        flat /= num_gpus
                    misc.nan_to_num(flat,
                                    nan=0,
                                    posinf=1e5,
                                    neginf=-1e5,
                                    out=flat)
                    grads = flat.split([param.numel() for param in params])
                    for param, grad in zip(params, grads):
                        param.grad = grad.reshape(param.shape)
                phase.opt.step()

            # Phase done.
            if phase.end_event is not None:
                phase.end_event.record(torch.cuda.current_stream(device))

        # Update G_ema.
        with torch.autograd.profiler.record_function('Gema'):
            ema_nimg = ema_kimg * 1000
            if ema_rampup is not None:
                ema_nimg = min(ema_nimg, cur_nimg * ema_rampup)
            ema_beta = 0.5**(batch_size / max(ema_nimg, 1e-8))
            for p_ema, p in zip(G_ema.parameters(), G.parameters()):
                p_ema.copy_(p.lerp(p_ema, ema_beta))
            for b_ema, b in zip(G_ema.buffers(), G.buffers()):
                b_ema.copy_(b)

        # Update state.
        cur_nimg += batch_size
        batch_idx += 1

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if (not done) and (cur_tick != 0) and (
                cur_nimg < tick_start_nimg + kimg_per_tick * 1000):
            continue

        # Print status line, accumulating the same information in training_stats.
        tick_end_time = time.time()
        fields = []
        fields += [
            f"tick {training_stats.report0('Progress/tick', cur_tick):<5d}"
        ]
        fields += [
            f"kimg {training_stats.report0('Progress/kimg', cur_nimg / 1e3):<8.1f}"
        ]
        fields += [
            f"time {dnnlib.util.format_time(training_stats.report0('Timing/total_sec', tick_end_time - start_time)):<12s}"
        ]
        fields += [
            f"sec/tick {training_stats.report0('Timing/sec_per_tick', tick_end_time - tick_start_time):<7.1f}"
        ]
        fields += [
            f"sec/kimg {training_stats.report0('Timing/sec_per_kimg', (tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg) * 1e3):<7.2f}"
        ]
        fields += [
            f"maintenance {training_stats.report0('Timing/maintenance_sec', maintenance_time):<6.1f}"
        ]
        fields += [
            f"cpumem {training_stats.report0('Resources/cpu_mem_gb', psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}"
        ]
        fields += [
            f"gpumem {training_stats.report0('Resources/peak_gpu_mem_gb', torch.cuda.max_memory_allocated(device) / 2**30):<6.2f}"
        ]
        fields += [
            f"reserved {training_stats.report0('Resources/peak_gpu_mem_reserved_gb', torch.cuda.max_memory_reserved(device) / 2**30):<6.2f}"
        ]
        torch.cuda.reset_peak_memory_stats()
        training_stats.report0('Timing/total_hours',
                               (tick_end_time - start_time) / (60 * 60))
        training_stats.report0('Timing/total_days',
                               (tick_end_time - start_time) / (24 * 60 * 60))
        if rank == 0:
            print(' '.join(fields))

        # Check for abort.
        if (not done) and (abort_fn is not None) and abort_fn():
            done = True
            if rank == 0:
                print()
                print('Aborting...')

        # Check for restart.
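        # Rank 0 decides whether the wall-clock budget (restart_every seconds) has
        # been exceeded and broadcasts the flag so that all ranks exit together.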
        if (rank == 0) and (restart_every > 0) and (time.time() - start_time >
                                                    restart_every):
            print('Restart job...')
            __RESTART__ = torch.tensor(1., device=device)
        if num_gpus > 1:
            torch.distributed.broadcast(__RESTART__, 0)
        if __RESTART__:
            done = True
            print(f'Process {rank} leaving...')
            if num_gpus > 1:
                torch.distributed.barrier()

        # Save image snapshot.
        if (rank == 0) and (image_snapshot_ticks is not None) and (
                done or cur_tick % image_snapshot_ticks == 0):
            images = torch.cat([
                G_ema(z=z, c=c, noise_mode='const').cpu()
                for z, c in zip(grid_z, grid_c)
            ]).numpy()
            save_image_grid(images,
                            os.path.join(run_dir,
                                         f'fakes{cur_nimg//1000:06d}.png'),
                            drange=[-1, 1],
                            grid_size=grid_size)

        # Save network snapshot.
        snapshot_pkl = None
        snapshot_data = None
        if (network_snapshot_ticks is not None) and \
                (done or cur_tick % network_snapshot_ticks == 0):
            snapshot_data = dict(G=G,
                                 D=D,
                                 G_ema=G_ema,
                                 training_set_kwargs=dict(training_set_kwargs))
            for key, value in snapshot_data.items():
                if isinstance(value, torch.nn.Module):
                    snapshot_data[key] = value
                del value  # conserve memory

        # Save Checkpoint if needed
        if (rank == 0) and (restart_every > 0) and \
                (network_snapshot_ticks is not None) and \
                (done or cur_tick % network_snapshot_ticks == 0):
            snapshot_pkl = misc.get_ckpt_path(run_dir)
            # save as tensors to avoid error for multi GPU
            snapshot_data['progress'] = {
                'cur_nimg': torch.LongTensor([cur_nimg]),
                'cur_tick': torch.LongTensor([cur_tick]),
                'batch_idx': torch.LongTensor([batch_idx]),
                'best_fid': best_fid,
            }
            if hasattr(loss, 'pl_mean'):
                snapshot_data['progress']['pl_mean'] = loss.pl_mean.cpu()

            with open(snapshot_pkl, 'wb') as f:
                pickle.dump(snapshot_data, f)

        # Evaluate metrics.
        if cur_tick and (snapshot_data is not None) and (len(metrics) > 0):
            if rank == 0:
                print('Evaluating metrics...')
            for metric in metrics:
                result_dict = metric_main.calc_metric(
                    metric=metric,
                    G=snapshot_data['G_ema'],
                    run_dir=run_dir,
                    cur_nimg=cur_nimg,
                    dataset_kwargs=training_set_kwargs,
                    num_gpus=num_gpus,
                    rank=rank,
                    device=device)
                if rank == 0:
                    metric_main.report_metric(result_dict,
                                              run_dir=run_dir,
                                              snapshot_pkl=snapshot_pkl)
                stats_metrics.update(result_dict.results)

            # save best fid ckpt
            snapshot_pkl = os.path.join(run_dir, 'best_model.pkl')
            cur_nimg_txt = os.path.join(run_dir, 'best_nimg.txt')
            if rank == 0:
                if 'fid50k_full' in stats_metrics and stats_metrics[
                        'fid50k_full'] < best_fid:
                    best_fid = stats_metrics['fid50k_full']

                    with open(snapshot_pkl, 'wb') as f:
                        dill.dump(snapshot_data, f)
                    # save curr iteration number (directly saving it to pkl leads to problems with multi GPU)
                    with open(cur_nimg_txt, 'w') as f:
                        f.write(str(cur_nimg))

        del snapshot_data  # conserve memory

        # Collect statistics.
        for phase in phases:
            value = []
            if (phase.start_event is not None) and (phase.end_event is not None) and \
                    not (phase.start_event.cuda_event == 0 and phase.end_event.cuda_event == 0):
                # Events with cuda_event == 0 were never recorded; this can happen right after a restart.
                phase.end_event.synchronize()
                value = phase.start_event.elapsed_time(phase.end_event)
            training_stats.report0('Timing/' + phase.name, value)
        stats_collector.update()
        stats_dict = stats_collector.as_dict()

        # Update logs.
        timestamp = time.time()
        if stats_jsonl is not None:
            fields = dict(stats_dict, timestamp=timestamp)
            stats_jsonl.write(json.dumps(fields) + '\n')
            stats_jsonl.flush()
        if stats_tfevents is not None:
            global_step = int(cur_nimg / 1e3)
            walltime = timestamp - start_time
            for name, value in stats_dict.items():
                stats_tfevents.add_scalar(name,
                                          value.mean,
                                          global_step=global_step,
                                          walltime=walltime)
            for name, value in stats_metrics.items():
                stats_tfevents.add_scalar(f'Metrics/{name}',
                                          value,
                                          global_step=global_step,
                                          walltime=walltime)
            stats_tfevents.flush()
        if progress_fn is not None:
            progress_fn(cur_nimg // 1000, total_kimg)

        # Update state.
        cur_tick += 1
        tick_start_nimg = cur_nimg
        tick_start_time = time.time()
        maintenance_time = tick_start_time - tick_end_time
        if done:
            break

    # Done.
    if rank == 0:
        print()
        print('Exiting...')
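
A small worked example (an illustration, not part of the original listing) of the lazy-regularization rescaling used when constructing the optimizers above, assuming the common StyleGAN2 settings lr=0.002 and betas=(0, 0.99) with D_reg_interval=16:

reg_interval = 16
mb_ratio = reg_interval / (reg_interval + 1)            # ~0.941
lr = 0.002 * mb_ratio                                   # ~0.00188
betas = [beta ** mb_ratio for beta in (0.0, 0.99)]      # ~[0.0, 0.9906]
print(lr, betas)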