def print_nets(G, D, batch_gpu, device, log):
    """Print module summary tables for generator ``G`` and discriminator ``D``.

    Does nothing (returns None) when ``log`` is falsy. Otherwise, uninitialised
    placeholder inputs are allocated on ``device`` with shapes
    ``[batch_gpu, *G.input_shape[1:]]`` and ``[batch_gpu, *G.cond_shape[1:]]``.
    ``G`` is summarised on them, and element ``[0]`` of that summary's return
    value is fed, together with the same conditioning tensor, to ``D``'s summary.
    NOTE(review): ``[0]`` presumably selects the image from G's output - confirm
    against G's forward signature.
    """
    if log:
        placeholder_latent = torch.empty([batch_gpu, *G.input_shape[1:]], device=device)
        placeholder_cond = torch.empty([batch_gpu, *G.cond_shape[1:]], device=device)
        generator_out = torch_misc.print_module_summary(G, [placeholder_latent, placeholder_cond])
        torch_misc.print_module_summary(D, [generator_out[0], placeholder_cond])
def subprocess_fn(rank, args, temp_dir):
    """Evaluate every metric in ``args.metrics`` for ``args.G`` in one worker process.

    Args:
        rank: Index of this worker; also used as its CUDA device index.
        args: Options object read for ``num_gpus``, ``verbose``, ``G``,
            ``metrics``, ``dataset_kwargs``, ``run_dir`` and ``network_pkl``.
        temp_dir: Directory holding the rendezvous file used to initialise
            ``torch.distributed`` when ``args.num_gpus > 1``.
    """
    dnnlib.util.Logger(should_flush=True)
    multi_gpu = args.num_gpus > 1

    # Rendezvous through a shared file: gloo on Windows, nccl elsewhere.
    if multi_gpu:
        rendezvous_path = os.path.abspath(os.path.join(temp_dir, '.torch_distributed_init'))
        if os.name == 'nt':
            backend = 'gloo'
            url = 'file:///' + rendezvous_path.replace('\\', '/')
        else:
            backend = 'nccl'
            url = f'file://{rendezvous_path}'
        torch.distributed.init_process_group(backend=backend, init_method=url, rank=rank, world_size=args.num_gpus)

    # Init torch_utils; only a verbose rank 0 keeps custom-op output enabled.
    training_stats.init_multiprocessing(rank=rank, sync_device=torch.device('cuda', rank) if multi_gpu else None)
    chatty = rank == 0 and args.verbose
    if not chatty:
        custom_ops.verbosity = 'none'

    # Place a frozen copy of the generator on this rank's GPU, with TF32 disabled.
    device = torch.device('cuda', rank)
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False
    G = copy.deepcopy(args.G).eval().requires_grad_(False).to(device)
    if chatty:
        latent = torch.empty([1, G.z_dim], device=device)
        label = torch.empty([1, G.c_dim], device=device)
        misc.print_module_summary(G, [latent, label])

    # Evaluate metrics one after another; only rank 0 reports them.
    for metric_name in args.metrics:
        if chatty:
            print(f'Calculating {metric_name}...')
        monitor = metric_utils.ProgressMonitor(verbose=args.verbose)
        result_dict = metric_main.calc_metric(
            metric=metric_name,
            G=G,
            dataset_kwargs=args.dataset_kwargs,
            num_gpus=args.num_gpus,
            rank=rank,
            device=device,
            progress=monitor,
        )
        if rank == 0:
            metric_main.report_metric(result_dict, run_dir=args.run_dir, snapshot_pkl=args.network_pkl)
            if args.verbose:
                print()

    if chatty:
        print('Exiting...')
def training_loop(
    run_dir='.',                # Output directory.
    training_set_kwargs={},     # Options for training set.
    data_loader_kwargs={},      # Options for torch.utils.data.DataLoader.
    G_kwargs={},                # Options for generator network.
    D_kwargs={},                # Options for discriminator network.
    D2_kwargs={},               # Options for the face networks used when obake is enabled (read as attributes: mtcnn_output_size, mtcnn_output_margin, mtcnn_thresholds, resnet_type).
    G_opt_kwargs={},            # Options for generator optimizer.
    D_opt_kwargs={},            # Options for discriminator optimizer.
    augment_kwargs=None,        # Options for augmentation pipeline. None = disable.
    loss_kwargs={},             # Options for loss function.
    metrics=[],                 # Metrics to evaluate during training.
    random_seed=0,              # Global random seed.
    num_gpus=1,                 # Number of GPUs participating in the training.
    rank=0,                     # Rank of the current process in [0, num_gpus[.
    batch_size=4,               # Total batch size for one training iteration. Can be larger than batch_gpu * num_gpus.
    batch_gpu=4,                # Number of samples processed at a time by one GPU.
    ema_kimg=10,                # Half-life of the exponential moving average (EMA) of generator weights.
    ema_rampup=None,            # EMA ramp-up coefficient.
    G_reg_interval=4,           # How often to perform regularization for G? None = disable lazy regularization.
    D_reg_interval=16,          # How often to perform regularization for D? None = disable lazy regularization.
    augment_p=0,                # Initial value of augmentation probability.
    ada_target=None,            # ADA target value. None = fixed p.
    ada_interval=4,             # How often to perform ADA adjustment?
    ada_kimg=500,               # ADA adjustment speed, measured in how many kimg it takes for p to increase/decrease by one unit.
    total_kimg=25000,           # Total length of the training, measured in thousands of real images.
    kimg_per_tick=4,            # Progress snapshot interval.
    image_snapshot_ticks=50,    # How often to save image snapshots? None = disable.
    network_snapshot_ticks=50,  # How often to save network snapshots? None = disable.
    resume_pkl=None,            # Network pickle to resume training from.
    cudnn_benchmark=True,       # Enable torch.backends.cudnn.benchmark?
    abort_fn=None,              # Callback function for determining whether to abort training. Must return consistent results across ranks.
    progress_fn=None,           # Callback function for updating training progress. Called for all ranks.
    obake=None,                 # Obake training: <bool>. None/False = disabled.
):
    """Run the GAN training loop for one process (``rank``) until ``total_kimg`` or abort.

    When ``obake`` is truthy, an MTCNN face detector (``D_mtcnn``) and an
    InceptionResnetV1 network (``D_face``) are built from ``D2_kwargs`` and
    handed to the loss object alongside the usual G/D modules.

    Side effects (rank 0): writes ``reals.png``, ``fakes_init.png``, periodic
    ``fakesNNNNNN.png`` grids, ``network-snapshot-NNNNNN.pkl`` files,
    ``stats.jsonl`` and (if available) tensorboard event files into ``run_dir``.
    """
    # Initialize.
    start_time = time.time()
    device = torch.device('cuda', rank)
    np.random.seed(random_seed * num_gpus + rank)
    torch.manual_seed(random_seed * num_gpus + rank)
    torch.backends.cudnn.benchmark = cudnn_benchmark    # Improves training speed.
    conv2d_gradfix.enabled = True                       # Improves training speed.
    grid_sample_gradfix.enabled = True                  # Avoids errors with the augmentation pipe.

    # Load training set.
    if rank == 0:
        print('Loading training set...')
    training_set = dnnlib.util.construct_class_by_name(**training_set_kwargs)  # subclass of training.dataset.Dataset
    training_set_sampler = misc.InfiniteSampler(dataset=training_set, rank=rank, num_replicas=num_gpus, seed=random_seed)
    training_set_iterator = iter(torch.utils.data.DataLoader(
        dataset=training_set, sampler=training_set_sampler, batch_size=batch_size // num_gpus, **data_loader_kwargs))
    if rank == 0:
        print()
        print('Num images: ', len(training_set))
        print('Image shape:', training_set.image_shape)
        print('Label shape:', training_set.label_shape)
        print()

    # Construct networks.
    if rank == 0:
        print('Constructing networks...')
    common_kwargs = dict(c_dim=training_set.label_dim, img_resolution=training_set.resolution, img_channels=training_set.num_channels)
    G = dnnlib.util.construct_class_by_name(**G_kwargs, **common_kwargs).train().requires_grad_(False).to(device)  # subclass of torch.nn.Module
    D = dnnlib.util.construct_class_by_name(**D_kwargs, **common_kwargs).train().requires_grad_(False).to(device)  # subclass of torch.nn.Module
    G_ema = copy.deepcopy(G).eval()
    D_mtcnn = None
    D_face = None
    if obake:  # Truthiness, so that obake=False really disables it.
        # Options are read as attributes; EasyDict also accepts a plain dict.
        face_kwargs = dnnlib.EasyDict(D2_kwargs)
        # Build the face networks on this rank's GPU: DistributedDataParallel below
        # requires module parameters on `device`. They are not optimized by any
        # phase, so freeze them like the other modules.
        D_mtcnn = MTCNN(image_size=face_kwargs.mtcnn_output_size,
                        margin=face_kwargs.mtcnn_output_margin,
                        thresholds=face_kwargs.mtcnn_thresholds,
                        device=device).requires_grad_(False)
        D_face = InceptionResnetV1(pretrained=face_kwargs.resnet_type, device=device).eval().requires_grad_(False)

    # Resume from existing pickle.
    if (resume_pkl is not None) and (rank == 0):
        print(f'Resuming from "{resume_pkl}"')
        with dnnlib.util.open_url(resume_pkl) as f:
            resume_data = legacy.load_network_pkl(f)
        for name, module in [('G', G), ('D', D), ('G_ema', G_ema)]:
            misc.copy_params_and_buffers(resume_data[name], module, require_all=False)

    # Print network summary tables.
    if rank == 0:
        z = torch.empty([batch_gpu, G.z_dim], device=device)
        c = torch.empty([batch_gpu, G.c_dim], device=device)
        img = misc.print_module_summary(G, [z, c])
        misc.print_module_summary(D, [img, c])

    # Setup augmentation.
    if rank == 0:
        print('Setting up augmentation...')
    augment_pipe = None
    ada_stats = None
    if (augment_kwargs is not None) and (augment_p > 0 or ada_target is not None):
        augment_pipe = dnnlib.util.construct_class_by_name(**augment_kwargs).train().requires_grad_(False).to(device)  # subclass of torch.nn.Module
        augment_pipe.p.copy_(torch.as_tensor(augment_p))
        if ada_target is not None:
            ada_stats = training_stats.Collector(regex='Loss/signs/real')

    # Distribute across GPUs. The order of this list fixes the order in which
    # DDP wrappers are constructed, so it must be identical on every rank.
    if rank == 0:
        print(f'Distributing across {num_gpus} GPUs...')
    ddp_candidates = [('G_mapping', G.mapping), ('G_synthesis', G.synthesis), ('D', D)]
    if obake:
        ddp_candidates += [('D_mtcnn', D_mtcnn), ('D_face', D_face)]
    ddp_candidates += [(None, G_ema), ('augment_pipe', augment_pipe)]
    ddp_modules = dict()
    for name, module in ddp_candidates:
        if (num_gpus > 1) and (module is not None) and len(list(module.parameters())) != 0:
            module.requires_grad_(True)
            module = torch.nn.parallel.DistributedDataParallel(module, device_ids=[device], broadcast_buffers=False)
            module.requires_grad_(False)
        if name is not None:
            ddp_modules[name] = module

    # Setup training phases.
    if rank == 0:
        print('Setting up training phases...')
    loss = dnnlib.util.construct_class_by_name(device=device, **ddp_modules, **loss_kwargs)  # subclass of training.loss.Loss
    phases = []
    for name, module, opt_kwargs, reg_interval in [('G', G, G_opt_kwargs, G_reg_interval), ('D', D, D_opt_kwargs, D_reg_interval)]:
        if reg_interval is None:
            opt = dnnlib.util.construct_class_by_name(params=module.parameters(), **opt_kwargs)  # subclass of torch.optim.Optimizer
            phases += [dnnlib.EasyDict(name=name + 'both', module=module, opt=opt, interval=1)]
        else:  # Lazy regularization.
            # Scale lr and betas so the main+reg schedule matches the non-lazy one.
            mb_ratio = reg_interval / (reg_interval + 1)
            opt_kwargs = dnnlib.EasyDict(opt_kwargs)
            opt_kwargs.lr = opt_kwargs.lr * mb_ratio
            opt_kwargs.betas = [beta**mb_ratio for beta in opt_kwargs.betas]
            opt = dnnlib.util.construct_class_by_name(module.parameters(), **opt_kwargs)  # subclass of torch.optim.Optimizer
            phases += [dnnlib.EasyDict(name=name + 'main', module=module, opt=opt, interval=1)]
            phases += [dnnlib.EasyDict(name=name + 'reg', module=module, opt=opt, interval=reg_interval)]
    for phase in phases:
        phase.start_event = None
        phase.end_event = None
        if rank == 0:
            phase.start_event = torch.cuda.Event(enable_timing=True)
            phase.end_event = torch.cuda.Event(enable_timing=True)

    # Export sample images.
    grid_size = None
    grid_z = None
    grid_c = None
    if rank == 0:
        print('Exporting sample images...')
        grid_size, images, labels = setup_snapshot_image_grid(training_set=training_set)
        save_image_grid(images, os.path.join(run_dir, 'reals.png'), drange=[0, 255], grid_size=grid_size)
        grid_z = torch.randn([labels.shape[0], G.z_dim], device=device).split(batch_gpu)
        grid_c = torch.from_numpy(labels).to(device).split(batch_gpu)
        images = torch.cat([G_ema(z=z, c=c, noise_mode='const').cpu() for z, c in zip(grid_z, grid_c)]).numpy()
        save_image_grid(images, os.path.join(run_dir, 'fakes_init.png'), drange=[-1, 1], grid_size=grid_size)

    # Initialize logs.
    if rank == 0:
        print('Initializing logs...')
    stats_collector = training_stats.Collector(regex='.*')
    stats_metrics = dict()
    stats_jsonl = None
    stats_tfevents = None
    if rank == 0:
        stats_jsonl = open(os.path.join(run_dir, 'stats.jsonl'), 'wt')
        try:
            import torch.utils.tensorboard as tensorboard
            stats_tfevents = tensorboard.SummaryWriter(run_dir)
        except ImportError as err:
            print('Skipping tfevents export:', err)

    # Train.
    if rank == 0:
        print(f'Training for {total_kimg} kimg...')
        print()
    cur_nimg = 0
    cur_tick = 0
    tick_start_nimg = cur_nimg
    tick_start_time = time.time()
    maintenance_time = tick_start_time - start_time
    batch_idx = 0
    if progress_fn is not None:
        progress_fn(0, total_kimg)
    while True:

        # Fetch training data.
        with torch.autograd.profiler.record_function('data_fetch'):
            phase_real_img, phase_real_c = next(training_set_iterator)
            # uint8 [0, 255] -> float [-1, 1], split into per-GPU rounds.
            phase_real_img = (phase_real_img.to(device).to(torch.float32) / 127.5 - 1).split(batch_gpu)
            phase_real_c = phase_real_c.to(device).split(batch_gpu)
            all_gen_z = torch.randn([len(phases) * batch_size, G.z_dim], device=device)
            all_gen_z = [phase_gen_z.split(batch_gpu) for phase_gen_z in all_gen_z.split(batch_size)]
            all_gen_c = [training_set.get_label(np.random.randint(len(training_set))) for _ in range(len(phases) * batch_size)]
            all_gen_c = torch.from_numpy(np.stack(all_gen_c)).pin_memory().to(device)
            all_gen_c = [phase_gen_c.split(batch_gpu) for phase_gen_c in all_gen_c.split(batch_size)]

        # Execute training phases.
        for phase, phase_gen_z, phase_gen_c in zip(phases, all_gen_z, all_gen_c):
            if batch_idx % phase.interval != 0:
                continue

            # Initialize gradient accumulation.
            if phase.start_event is not None:
                phase.start_event.record(torch.cuda.current_stream(device))
            phase.opt.zero_grad(set_to_none=True)
            phase.module.requires_grad_(True)

            # Accumulate gradients over multiple rounds.
            for round_idx, (real_img, real_c, gen_z, gen_c) in enumerate(zip(phase_real_img, phase_real_c, phase_gen_z, phase_gen_c)):
                # True only on the final round; presumably tells the loss when DDP should sync gradients.
                sync = (round_idx == batch_size // (batch_gpu * num_gpus) - 1)
                gain = phase.interval
                loss.accumulate_gradients(phase=phase.name, real_img=real_img, real_c=real_c, gen_z=gen_z, gen_c=gen_c, sync=sync, gain=gain)

            # Update weights.
            phase.module.requires_grad_(False)
            with torch.autograd.profiler.record_function(phase.name + '_opt'):
                for param in phase.module.parameters():
                    if param.grad is not None:
                        misc.nan_to_num(param.grad, nan=0, posinf=1e5, neginf=-1e5, out=param.grad)
                phase.opt.step()
            if phase.end_event is not None:
                phase.end_event.record(torch.cuda.current_stream(device))

        # Update G_ema.
        with torch.autograd.profiler.record_function('Gema'):
            ema_nimg = ema_kimg * 1000
            if ema_rampup is not None:
                ema_nimg = min(ema_nimg, cur_nimg * ema_rampup)
            ema_beta = 0.5 ** (batch_size / max(ema_nimg, 1e-8))
            for p_ema, p in zip(G_ema.parameters(), G.parameters()):
                p_ema.copy_(p.lerp(p_ema, ema_beta))
            for b_ema, b in zip(G_ema.buffers(), G.buffers()):
                b_ema.copy_(b)

        # Update state.
        cur_nimg += batch_size
        batch_idx += 1

        # Execute ADA heuristic.
        if (ada_stats is not None) and (batch_idx % ada_interval == 0):
            ada_stats.update()
            adjust = np.sign(ada_stats['Loss/signs/real'] - ada_target) * (batch_size * ada_interval) / (ada_kimg * 1000)
            augment_pipe.p.copy_((augment_pipe.p + adjust).max(misc.constant(0, device=device)))

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if (not done) and (cur_tick != 0) and (cur_nimg < tick_start_nimg + kimg_per_tick * 1000):
            continue

        # Print status line, accumulating the same information in stats_collector.
        tick_end_time = time.time()
        fields = []
        fields += [f"tick {training_stats.report0('Progress/tick', cur_tick):<5d}"]
        fields += [f"kimg {training_stats.report0('Progress/kimg', cur_nimg / 1e3):<8.1f}"]
        fields += [f"time {dnnlib.util.format_time(training_stats.report0('Timing/total_sec', tick_end_time - start_time)):<12s}"]
        fields += [f"sec/tick {training_stats.report0('Timing/sec_per_tick', tick_end_time - tick_start_time):<7.1f}"]
        fields += [f"sec/kimg {training_stats.report0('Timing/sec_per_kimg', (tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg) * 1e3):<7.2f}"]
        fields += [f"maintenance {training_stats.report0('Timing/maintenance_sec', maintenance_time):<6.1f}"]
        fields += [f"cpumem {training_stats.report0('Resources/cpu_mem_gb', psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}"]
        fields += [f"gpumem {training_stats.report0('Resources/peak_gpu_mem_gb', torch.cuda.max_memory_allocated(device) / 2**30):<6.2f}"]
        torch.cuda.reset_peak_memory_stats()
        fields += [f"augment {training_stats.report0('Progress/augment', float(augment_pipe.p.cpu()) if augment_pipe is not None else 0):.3f}"]
        training_stats.report0('Timing/total_hours', (tick_end_time - start_time) / (60 * 60))
        training_stats.report0('Timing/total_days', (tick_end_time - start_time) / (24 * 60 * 60))
        if rank == 0:
            print(' '.join(fields))

        # Check for abort.
        if (not done) and (abort_fn is not None) and abort_fn():
            done = True
            if rank == 0:
                print()
                print('Aborting...')

        # Save image snapshot.
        if (rank == 0) and (image_snapshot_ticks is not None) and (done or cur_tick % image_snapshot_ticks == 0):
            images = torch.cat([G_ema(z=z, c=c, noise_mode='const').cpu() for z, c in zip(grid_z, grid_c)]).numpy()
            save_image_grid(images, os.path.join(run_dir, f'fakes{cur_nimg//1000:06d}.png'), drange=[-1, 1], grid_size=grid_size)

        # Save network snapshot.
        snapshot_pkl = None
        snapshot_data = None
        if (network_snapshot_ticks is not None) and (done or cur_tick % network_snapshot_ticks == 0):
            snapshot_data = dict(training_set_kwargs=dict(training_set_kwargs))
            for name, module in [('G', G), ('D', D), ('G_ema', G_ema), ('augment_pipe', augment_pipe)]:
                if module is not None:
                    if num_gpus > 1:
                        misc.check_ddp_consistency(module, ignore_regex=r'.*\.w_avg')
                    module = copy.deepcopy(module).eval().requires_grad_(False).cpu()
                snapshot_data[name] = module
                del module  # conserve memory
            snapshot_pkl = os.path.join(run_dir, f'network-snapshot-{cur_nimg//1000:06d}.pkl')
            if rank == 0:
                with open(snapshot_pkl, 'wb') as f:
                    pickle.dump(snapshot_data, f)

        # Evaluate metrics.
        if (snapshot_data is not None) and (len(metrics) > 0):
            if rank == 0:
                print('Evaluating metrics...')
            for metric in metrics:
                result_dict = metric_main.calc_metric(metric=metric, G=snapshot_data['G_ema'],
                    dataset_kwargs=training_set_kwargs, num_gpus=num_gpus, rank=rank, device=device)
                if rank == 0:
                    metric_main.report_metric(result_dict, run_dir=run_dir, snapshot_pkl=snapshot_pkl)
                stats_metrics.update(result_dict.results)
        del snapshot_data  # conserve memory

        # Collect statistics.
        for phase in phases:
            value = []
            if (phase.start_event is not None) and (phase.end_event is not None):
                phase.end_event.synchronize()
                value = phase.start_event.elapsed_time(phase.end_event)
            training_stats.report0('Timing/' + phase.name, value)
        stats_collector.update()
        stats_dict = stats_collector.as_dict()

        # Update logs.
        timestamp = time.time()
        if stats_jsonl is not None:
            fields = dict(stats_dict, timestamp=timestamp)
            stats_jsonl.write(json.dumps(fields) + '\n')
            stats_jsonl.flush()
        if stats_tfevents is not None:
            global_step = int(cur_nimg / 1e3)
            walltime = timestamp - start_time
            for name, value in stats_dict.items():
                stats_tfevents.add_scalar(name, value.mean, global_step=global_step, walltime=walltime)
            for name, value in stats_metrics.items():
                stats_tfevents.add_scalar(f'Metrics/{name}', value, global_step=global_step, walltime=walltime)
            stats_tfevents.flush()
        if progress_fn is not None:
            progress_fn(cur_nimg // 1000, total_kimg)

        # Update state.
        cur_tick += 1
        tick_start_nimg = cur_nimg
        tick_start_time = time.time()
        maintenance_time = tick_start_time - tick_end_time
        if done:
            break

    # Release log handles (previously left open).
    if stats_jsonl is not None:
        stats_jsonl.close()
    if stats_tfevents is not None:
        stats_tfevents.close()

    # Done.
    if rank == 0:
        print()
        print('Exiting...')
def training_loop(
    run_dir='.',                # Output directory.
    training_set_kwargs={},     # Options for training set.
    data_loader_kwargs={},      # Options for torch.utils.data.DataLoader.
    G_kwargs={},                # Options for generator network.
    D_kwargs={},                # Options for discriminator network.
    G_opt_kwargs={},            # Options for generator optimizer.
    D_opt_kwargs={},            # Options for discriminator optimizer.
    loss_kwargs={},             # Options for loss function.
    metrics=[],                 # Metrics to evaluate during training.
    random_seed=0,              # Global random seed.
    num_gpus=1,                 # Number of GPUs participating in the training.
    rank=0,                     # Rank of the current process in [0, num_gpus[.
    batch_size=4,               # Total batch size for one training iteration. Can be larger than batch_gpu * num_gpus.
    batch_gpu=4,                # Number of samples processed at a time by one GPU.
    ema_kimg=10,                # Half-life of the exponential moving average (EMA) of generator weights.
    ema_rampup=0.05,            # EMA ramp-up coefficient. None = no rampup.
    G_reg_interval=None,        # How often to perform regularization for G? None = disable lazy regularization.
    D_reg_interval=16,          # How often to perform regularization for D? None = disable lazy regularization.
    total_kimg=25000,           # Total length of the training, measured in thousands of real images.
    kimg_per_tick=4,            # Progress snapshot interval.
    image_snapshot_ticks=50,    # How often to save image snapshots? None = disable.
    network_snapshot_ticks=50,  # How often to save network snapshots? None = disable.
    resume_pkl=None,            # Network pickle to resume training from.
    resume_kimg=0,              # First kimg to report when resuming training.
    cudnn_benchmark=True,       # Enable torch.backends.cudnn.benchmark?
    abort_fn=None,              # Callback function for determining whether to abort training. Must return consistent results across ranks.
    progress_fn=None,           # Callback function for updating training progress. Called for all ranks.
    restart_every=-1,           # Time interval in seconds to exit code
):
    """Run the GAN training loop for one process (``rank``), with restartable checkpoints.

    When ``restart_every > 0``, the loop exits once rank 0 has run for longer
    than ``restart_every`` seconds (the decision is broadcast to all ranks).
    A checkpoint at ``misc.get_ckpt_path(run_dir)`` then stores the networks
    plus progress counters (cur_nimg, cur_tick, batch_idx, best_fid, and
    ``pl_mean`` if the loss has one). If that checkpoint file exists at startup,
    it takes precedence over ``resume_pkl`` and the counters are restored.

    Side effects (rank 0): writes ``reals.png``, ``fakes_init.png``, periodic
    ``fakesNNNNNN.png`` grids, ``best_model.pkl`` / ``best_nimg.txt`` whenever
    ``fid50k_full`` improves, ``stats.jsonl`` and (if available) tensorboard
    event files into ``run_dir``.
    """
    # Initialize.
    start_time = time.time()
    device = torch.device('cuda', rank)
    np.random.seed(random_seed * num_gpus + rank)
    torch.manual_seed(random_seed * num_gpus + rank)
    torch.backends.cudnn.benchmark = cudnn_benchmark    # Improves training speed.
    torch.backends.cuda.matmul.allow_tf32 = False       # Improves numerical accuracy.
    torch.backends.cudnn.allow_tf32 = False             # Improves numerical accuracy.
    conv2d_gradfix.enabled = True                       # Improves training speed.
    grid_sample_gradfix.enabled = True                  # Avoids errors with the augmentation pipe.
    # Progress state lives in device tensors so rank 0 can broadcast it to the other ranks.
    __RESTART__ = torch.tensor(0., device=device)       # will be broadcasted to exit loop
    __CUR_NIMG__ = torch.tensor(resume_kimg * 1000, dtype=torch.long, device=device)
    __CUR_TICK__ = torch.tensor(0, dtype=torch.long, device=device)
    __BATCH_IDX__ = torch.tensor(0, dtype=torch.long, device=device)
    __PL_MEAN__ = torch.zeros([], device=device)
    best_fid = 9999

    # Load training set.
    if rank == 0:
        print('Loading training set...')
    training_set = dnnlib.util.construct_class_by_name(**training_set_kwargs)  # subclass of training.dataset.Dataset
    training_set_sampler = misc.InfiniteSampler(dataset=training_set, rank=rank, num_replicas=num_gpus, seed=random_seed)
    training_set_iterator = iter(torch.utils.data.DataLoader(
        dataset=training_set, sampler=training_set_sampler, batch_size=batch_size // num_gpus, **data_loader_kwargs))
    if rank == 0:
        print()
        print('Num images: ', len(training_set))
        print('Image shape:', training_set.image_shape)
        print('Label shape:', training_set.label_shape)
        print()

    # Construct networks.
    if rank == 0:
        print('Constructing networks...')
    common_kwargs = dict(c_dim=training_set.label_dim, img_resolution=training_set.resolution, img_channels=training_set.num_channels)
    G = dnnlib.util.construct_class_by_name(**G_kwargs, **common_kwargs).train().requires_grad_(False).to(device)  # subclass of torch.nn.Module
    D = dnnlib.util.construct_class_by_name(**D_kwargs, **common_kwargs).train().requires_grad_(False).to(device)  # subclass of torch.nn.Module
    G_ema = copy.deepcopy(G).eval()

    # Check for existing checkpoint; it overrides resume_pkl.
    ckpt_pkl = None
    if restart_every > 0 and os.path.isfile(misc.get_ckpt_path(run_dir)):
        ckpt_pkl = resume_pkl = misc.get_ckpt_path(run_dir)

    # Resume from existing pickle (rank 0 only; weights are broadcast below).
    if (resume_pkl is not None) and (rank == 0):
        print(f'Resuming from "{resume_pkl}"')
        with dnnlib.util.open_url(resume_pkl) as f:
            resume_data = legacy.load_network_pkl(f)
        for name, module in [('G', G), ('D', D), ('G_ema', G_ema)]:
            misc.copy_params_and_buffers(resume_data[name], module, require_all=False)

        if ckpt_pkl is not None:  # Load ticks
            __CUR_NIMG__ = resume_data['progress']['cur_nimg'].to(device)
            __CUR_TICK__ = resume_data['progress']['cur_tick'].to(device)
            __BATCH_IDX__ = resume_data['progress']['batch_idx'].to(device)
            __PL_MEAN__ = resume_data['progress'].get('pl_mean', torch.zeros([])).to(device)
            best_fid = resume_data['progress']['best_fid']  # only needed for rank == 0

        del resume_data

    # Print network summary tables.
    if rank == 0:
        z = torch.empty([batch_gpu, G.z_dim], device=device)
        c = torch.empty([batch_gpu, G.c_dim], device=device)
        img = misc.print_module_summary(G, [z, c])
        misc.print_module_summary(D, [img, c])

    # Distribute across GPUs: copy rank 0's parameters/buffers to every rank.
    if rank == 0:
        print(f'Distributing across {num_gpus} GPUs...')
    for module in [G, D, G_ema]:
        if module is not None and num_gpus > 1:
            for param in misc.params_and_buffers(module):
                torch.distributed.broadcast(param, src=0)

    # Setup training phases.
    if rank == 0:
        print('Setting up training phases...')
    loss = dnnlib.util.construct_class_by_name(device=device, G=G, G_ema=G_ema, D=D, **loss_kwargs)  # subclass of training.loss.Loss
    phases = []
    for name, module, opt_kwargs, reg_interval in [('G', G, G_opt_kwargs, G_reg_interval), ('D', D, D_opt_kwargs, D_reg_interval)]:
        if reg_interval is None:
            opt = dnnlib.util.construct_class_by_name(params=module.parameters(), **opt_kwargs)  # subclass of torch.optim.Optimizer
            phases += [dnnlib.EasyDict(name=name + 'both', module=module, opt=opt, interval=1)]
        else:  # Lazy regularization.
            # Scale lr and betas so the main+reg schedule matches the non-lazy one.
            mb_ratio = reg_interval / (reg_interval + 1)
            opt_kwargs = dnnlib.EasyDict(opt_kwargs)
            opt_kwargs.lr = opt_kwargs.lr * mb_ratio
            opt_kwargs.betas = [beta**mb_ratio for beta in opt_kwargs.betas]
            opt = dnnlib.util.construct_class_by_name(module.parameters(), **opt_kwargs)  # subclass of torch.optim.Optimizer
            phases += [dnnlib.EasyDict(name=name + 'main', module=module, opt=opt, interval=1)]
            phases += [dnnlib.EasyDict(name=name + 'reg', module=module, opt=opt, interval=reg_interval)]
    for phase in phases:
        phase.start_event = None
        phase.end_event = None
        if rank == 0:
            phase.start_event = torch.cuda.Event(enable_timing=True)
            phase.end_event = torch.cuda.Event(enable_timing=True)

    # Export sample images.
    grid_size = None
    grid_z = None
    grid_c = None
    if rank == 0:
        print('Exporting sample images...')
        grid_size, images, labels = setup_snapshot_image_grid(training_set=training_set)
        save_image_grid(images, os.path.join(run_dir, 'reals.png'), drange=[0, 255], grid_size=grid_size)
        grid_z = torch.randn([labels.shape[0], G.z_dim], device=device).split(batch_gpu)
        grid_c = torch.from_numpy(labels).to(device).split(batch_gpu)
        images = torch.cat([G_ema(z=z, c=c, noise_mode='const').cpu() for z, c in zip(grid_z, grid_c)]).numpy()
        save_image_grid(images, os.path.join(run_dir, 'fakes_init.png'), drange=[-1, 1], grid_size=grid_size)

    # Initialize logs.
    if rank == 0:
        print('Initializing logs...')
    stats_collector = training_stats.Collector(regex='.*')
    stats_metrics = dict()
    stats_jsonl = None
    stats_tfevents = None
    if rank == 0:
        stats_jsonl = open(os.path.join(run_dir, 'stats.jsonl'), 'wt')
        try:
            import torch.utils.tensorboard as tensorboard
            stats_tfevents = tensorboard.SummaryWriter(run_dir)
        except ImportError as err:
            print('Skipping tfevents export:', err)

    # Train.
    if rank == 0:
        print(f'Training for {total_kimg} kimg...')
        print()
    if num_gpus > 1:  # broadcast loaded states to all
        torch.distributed.broadcast(__CUR_NIMG__, 0)
        torch.distributed.broadcast(__CUR_TICK__, 0)
        torch.distributed.broadcast(__BATCH_IDX__, 0)
        torch.distributed.broadcast(__PL_MEAN__, 0)
        torch.distributed.barrier()  # ensure all processes received this info
    cur_nimg = __CUR_NIMG__.item()
    cur_tick = __CUR_TICK__.item()
    tick_start_nimg = cur_nimg
    tick_start_time = time.time()
    maintenance_time = tick_start_time - start_time
    batch_idx = __BATCH_IDX__.item()
    if progress_fn is not None:
        progress_fn(cur_nimg // 1000, total_kimg)
    if hasattr(loss, 'pl_mean'):
        loss.pl_mean.copy_(__PL_MEAN__)

    while True:
        with torch.autograd.profiler.record_function('data_fetch'):
            phase_real_img, phase_real_c = next(training_set_iterator)
            # uint8 [0, 255] -> float [-1, 1], split into per-GPU rounds.
            phase_real_img = (phase_real_img.to(device).to(torch.float32) / 127.5 - 1).split(batch_gpu)
            phase_real_c = phase_real_c.to(device).split(batch_gpu)
            all_gen_z = torch.randn([len(phases) * batch_size, G.z_dim], device=device)
            all_gen_z = [phase_gen_z.split(batch_gpu) for phase_gen_z in all_gen_z.split(batch_size)]
            all_gen_c = [training_set.get_label(np.random.randint(len(training_set))) for _ in range(len(phases) * batch_size)]
            all_gen_c = torch.from_numpy(np.stack(all_gen_c)).pin_memory().to(device)
            all_gen_c = [phase_gen_c.split(batch_gpu) for phase_gen_c in all_gen_c.split(batch_size)]

        # Execute training phases.
        for phase, phase_gen_z, phase_gen_c in zip(phases, all_gen_z, all_gen_c):
            if batch_idx % phase.interval != 0:
                continue
            if phase.start_event is not None:
                phase.start_event.record(torch.cuda.current_stream(device))

            # Accumulate gradients.
            phase.opt.zero_grad(set_to_none=True)
            phase.module.requires_grad_(True)

            # D's feature network stays frozen during D phases.
            if phase.name in ['Dmain', 'Dboth', 'Dreg']:
                phase.module.feature_network.requires_grad_(False)

            for real_img, real_c, gen_z, gen_c in zip(phase_real_img, phase_real_c, phase_gen_z, phase_gen_c):
                loss.accumulate_gradients(
                    phase=phase.name,
                    real_img=real_img,
                    real_c=real_c,
                    gen_z=gen_z,
                    gen_c=gen_c,
                    gain=phase.interval,
                    cur_nimg=cur_nimg,
                )
            phase.module.requires_grad_(False)

            # Update weights: average gradients across ranks in one flat all_reduce,
            # sanitize NaN/Inf, then step.
            with torch.autograd.profiler.record_function(phase.name + '_opt'):
                params = [param for param in phase.module.parameters() if param.grad is not None]
                if len(params) > 0:
                    flat = torch.cat([param.grad.flatten() for param in params])
                    if num_gpus > 1:
                        torch.distributed.all_reduce(flat)
                        flat /= num_gpus
                    misc.nan_to_num(flat, nan=0, posinf=1e5, neginf=-1e5, out=flat)
                    grads = flat.split([param.numel() for param in params])
                    for param, grad in zip(params, grads):
                        param.grad = grad.reshape(param.shape)
                phase.opt.step()

            # Phase done.
            if phase.end_event is not None:
                phase.end_event.record(torch.cuda.current_stream(device))

        # Update G_ema.
        with torch.autograd.profiler.record_function('Gema'):
            ema_nimg = ema_kimg * 1000
            if ema_rampup is not None:
                ema_nimg = min(ema_nimg, cur_nimg * ema_rampup)
            ema_beta = 0.5 ** (batch_size / max(ema_nimg, 1e-8))
            for p_ema, p in zip(G_ema.parameters(), G.parameters()):
                p_ema.copy_(p.lerp(p_ema, ema_beta))
            for b_ema, b in zip(G_ema.buffers(), G.buffers()):
                b_ema.copy_(b)

        # Update state.
        cur_nimg += batch_size
        batch_idx += 1

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if (not done) and (cur_tick != 0) and (cur_nimg < tick_start_nimg + kimg_per_tick * 1000):
            continue

        # Print status line, accumulating the same information in training_stats.
        tick_end_time = time.time()
        fields = []
        fields += [f"tick {training_stats.report0('Progress/tick', cur_tick):<5d}"]
        fields += [f"kimg {training_stats.report0('Progress/kimg', cur_nimg / 1e3):<8.1f}"]
        fields += [f"time {dnnlib.util.format_time(training_stats.report0('Timing/total_sec', tick_end_time - start_time)):<12s}"]
        fields += [f"sec/tick {training_stats.report0('Timing/sec_per_tick', tick_end_time - tick_start_time):<7.1f}"]
        fields += [f"sec/kimg {training_stats.report0('Timing/sec_per_kimg', (tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg) * 1e3):<7.2f}"]
        fields += [f"maintenance {training_stats.report0('Timing/maintenance_sec', maintenance_time):<6.1f}"]
        fields += [f"cpumem {training_stats.report0('Resources/cpu_mem_gb', psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}"]
        fields += [f"gpumem {training_stats.report0('Resources/peak_gpu_mem_gb', torch.cuda.max_memory_allocated(device) / 2**30):<6.2f}"]
        fields += [f"reserved {training_stats.report0('Resources/peak_gpu_mem_reserved_gb', torch.cuda.max_memory_reserved(device) / 2**30):<6.2f}"]
        torch.cuda.reset_peak_memory_stats()
        training_stats.report0('Timing/total_hours', (tick_end_time - start_time) / (60 * 60))
        training_stats.report0('Timing/total_days', (tick_end_time - start_time) / (24 * 60 * 60))
        if rank == 0:
            print(' '.join(fields))

        # Check for abort.
        if (not done) and (abort_fn is not None) and abort_fn():
            done = True
            if rank == 0:
                print()
                print('Aborting...')

        # Check for restart: rank 0 decides on wall-clock time, every rank follows.
        if (rank == 0) and (restart_every > 0) and (time.time() - start_time > restart_every):
            print('Restart job...')
            __RESTART__ = torch.tensor(1., device=device)
        if num_gpus > 1:
            torch.distributed.broadcast(__RESTART__, 0)
        if __RESTART__:
            done = True
            print(f'Process {rank} leaving...')
            if num_gpus > 1:
                torch.distributed.barrier()

        # Save image snapshot.
        if (rank == 0) and (image_snapshot_ticks is not None) and (done or cur_tick % image_snapshot_ticks == 0):
            images = torch.cat([G_ema(z=z, c=c, noise_mode='const').cpu() for z, c in zip(grid_z, grid_c)]).numpy()
            save_image_grid(images, os.path.join(run_dir, f'fakes{cur_nimg//1000:06d}.png'), drange=[-1, 1], grid_size=grid_size)

        # Save network snapshot. The live modules are referenced directly; the
        # previous per-item "conserve memory" loop reassigned each value to itself
        # and was a no-op.
        snapshot_pkl = None
        snapshot_data = None
        if (network_snapshot_ticks is not None) and (done or cur_tick % network_snapshot_ticks == 0):
            snapshot_data = dict(G=G, D=D, G_ema=G_ema, training_set_kwargs=dict(training_set_kwargs))

        # Save Checkpoint if needed
        if (rank == 0) and (restart_every > 0) and (network_snapshot_ticks is not None) and (
                done or cur_tick % network_snapshot_ticks == 0):
            snapshot_pkl = misc.get_ckpt_path(run_dir)
            # save as tensors to avoid error for multi GPU
            snapshot_data['progress'] = {
                'cur_nimg': torch.LongTensor([cur_nimg]),
                'cur_tick': torch.LongTensor([cur_tick]),
                'batch_idx': torch.LongTensor([batch_idx]),
                'best_fid': best_fid,
            }
            if hasattr(loss, 'pl_mean'):
                snapshot_data['progress']['pl_mean'] = loss.pl_mean.cpu()

            with open(snapshot_pkl, 'wb') as f:
                pickle.dump(snapshot_data, f)

        # Evaluate metrics (skipped at tick 0).
        if cur_tick and (snapshot_data is not None) and (len(metrics) > 0):
            if rank == 0:
                print('Evaluating metrics...')
            for metric in metrics:
                result_dict = metric_main.calc_metric(metric=metric, G=snapshot_data['G_ema'], run_dir=run_dir, cur_nimg=cur_nimg,
                    dataset_kwargs=training_set_kwargs, num_gpus=num_gpus, rank=rank, device=device)
                if rank == 0:
                    metric_main.report_metric(result_dict, run_dir=run_dir, snapshot_pkl=snapshot_pkl)
                stats_metrics.update(result_dict.results)

            # save best fid ckpt
            snapshot_pkl = os.path.join(run_dir, 'best_model.pkl')
            cur_nimg_txt = os.path.join(run_dir, 'best_nimg.txt')
            if rank == 0:
                if 'fid50k_full' in stats_metrics and stats_metrics['fid50k_full'] < best_fid:
                    best_fid = stats_metrics['fid50k_full']

                    with open(snapshot_pkl, 'wb') as f:
                        dill.dump(snapshot_data, f)
                    # save curr iteration number (directly saving it to pkl leads to problems with multi GPU)
                    with open(cur_nimg_txt, 'w') as f:
                        f.write(str(cur_nimg))

        del snapshot_data  # conserve memory

        # Collect statistics.
        for phase in phases:
            value = []
            if (phase.start_event is not None) and (phase.end_event is not None) and \
                    not (phase.start_event.cuda_event == 0 and phase.end_event.cuda_event == 0):
                # Both events were not initialized yet, can happen with restart
                phase.end_event.synchronize()
                value = phase.start_event.elapsed_time(phase.end_event)
            training_stats.report0('Timing/' + phase.name, value)
        stats_collector.update()
        stats_dict = stats_collector.as_dict()

        # Update logs.
        timestamp = time.time()
        if stats_jsonl is not None:
            fields = dict(stats_dict, timestamp=timestamp)
            stats_jsonl.write(json.dumps(fields) + '\n')
            stats_jsonl.flush()
        if stats_tfevents is not None:
            global_step = int(cur_nimg / 1e3)
            walltime = timestamp - start_time
            for name, value in stats_dict.items():
                stats_tfevents.add_scalar(name, value.mean, global_step=global_step, walltime=walltime)
            for name, value in stats_metrics.items():
                stats_tfevents.add_scalar(f'Metrics/{name}', value, global_step=global_step, walltime=walltime)
            stats_tfevents.flush()
        if progress_fn is not None:
            progress_fn(cur_nimg // 1000, total_kimg)

        # Update state.
        cur_tick += 1
        tick_start_nimg = cur_nimg
        tick_start_time = time.time()
        maintenance_time = tick_start_time - tick_end_time
        if done:
            break

    # Release log handles (previously left open).
    if stats_jsonl is not None:
        stats_jsonl.close()
    if stats_tfevents is not None:
        stats_tfevents.close()

    # Done.
    if rank == 0:
        print()
        print('Exiting...')