def training_loop( G_args = {}, # Options for generator network. D_args = {}, # Options for discriminator network. G_opt_args = {}, # Options for generator optimizer. D_opt_args = {}, # Options for discriminator optimizer. G_loss_args = {}, # Options for generator loss. D_loss_args = {}, # Options for discriminator loss. dataset_args = {}, # Options for dataset.load_dataset(). sched_args = {}, # Options for train.TrainingSchedule. grid_args = {}, # Options for train.setup_snapshot_image_grid(). setname = None, # Model name tf_config = {}, # Options for tflib.init_tf(). G_smoothing_kimg = 10.0, # Half-life of the running average of generator weights. minibatch_repeats = 4, # Number of minibatches to run before adjusting training parameters. lazy_regularization = True, # Perform regularization as a separate training step? G_reg_interval = 4, # How often the perform regularization for G? Ignored if lazy_regularization=False. D_reg_interval = 16, # How often the perform regularization for D? Ignored if lazy_regularization=False. reset_opt_for_new_lod = True, # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced? total_kimg = 25000, # Total length of the training, measured in thousands of real images. mirror_augment = False, # Enable mirror augment? mirror_augment_v = False, # Enable mirror augment vertically? drange_net = [-1,1], # Dynamic range used when feeding image data to the networks. image_snapshot_ticks = 50, # How often to save image snapshots? None = only save 'reals.png' and 'fakes-init.png'. network_snapshot_ticks = 50, # How often to save network snapshots? None = only save 'networks-final.pkl'. save_tf_graph = False, # Include full TensorFlow computation graph in the tfevents file? save_weight_histograms = False, # Include weight histograms in the tfevents file? resume_pkl = 'latest', # Network pickle to resume training from, None = train from scratch. resume_kimg = 0.0, # Assumed training progress at the beginning. Affects reporting and training schedule. resume_time = 0.0, # Assumed wallclock time at the beginning. Affects reporting. restore_partial_fn = None, # Filename of network for partial restore resume_with_new_nets = False): # Construct new networks according to G_args and D_args before resuming training? # Initialize dnnlib and TensorFlow. tflib.init_tf(tf_config) num_gpus = dnnlib.submit_config.num_gpus # Load training set. training_set = dataset.load_dataset(verbose=True, **dataset_args) # custom resolution - for saved model name below resolution = training_set.resolution if training_set.init_res != [4,4]: init_res_str = '-%dx%d' % (training_set.init_res[0], training_set.init_res[1]) else: init_res_str = '' ext = 'png' if training_set.shape[0] == 4 else 'jpg' print(' model base resolution', resolution) grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid(training_set, **grid_args) misc.save_image_grid(grid_reals, dnnlib.make_run_dir_path('_reals.%s'%ext), drange=training_set.dynamic_range, grid_size=grid_size) # Construct or load networks. with tf.device('/gpu:0'): if resume_pkl is None or resume_with_new_nets: print(' Constructing networks...') G = tflib.Network('G', num_channels=training_set.shape[0], resolution=resolution, label_size=training_set.label_size, **G_args) D = tflib.Network('D', num_channels=training_set.shape[0], resolution=resolution, label_size=training_set.label_size, **D_args) Gs = G.clone('Gs') if resume_pkl is not None: if resume_pkl == 'latest': resume_pkl, resume_kimg = misc.locate_latest_pkl(dnnlib.submit_config.run_dir_root) elif resume_pkl == 'restore_partial': print(' Restore partially...') # Initialize networks G = tflib.Network('G', num_channels=training_set.shape[0], resolution=resolution, label_size=training_set.label_size, **G_args) D = tflib.Network('D', num_channels=training_set.shape[0], resolution=resolution, label_size=training_set.label_size, **D_args) Gs = G.clone('Gs') # Load pre-trained networks assert restore_partial_fn != None G_partial, D_partial, Gs_partial = pickle.load(open(restore_partial_fn, 'rb')) # Restore (subset of) pre-trained weights (only parameters that match both name and shape) G.copy_compatible_trainables_from(G_partial) D.copy_compatible_trainables_from(D_partial) Gs.copy_compatible_trainables_from(Gs_partial) else: if resume_pkl is not None and resume_kimg == 0: resume_pkl, resume_kimg = misc.locate_latest_pkl(resume_pkl) print(' Loading networks from "%s", kimg %.3g' % (resume_pkl, resume_kimg)) rG, rD, rGs = misc.load_pkl(resume_pkl) if resume_with_new_nets: G.copy_vars_from(rG) D.copy_vars_from(rD) Gs.copy_vars_from(rGs) else: G, D, Gs = rG, rD, rGs # Print layers if needed and generate initial image snapshot # G.print_layers(); D.print_layers() sched = training_schedule(cur_nimg=total_kimg*1000, training_set=training_set, **sched_args) grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:]) grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch_gpu) misc.save_image_grid(grid_fakes, dnnlib.make_run_dir_path('fakes_init.%s'%ext), drange=drange_net, grid_size=grid_size) # Setup training inputs. print(' Building TensorFlow graph...') with tf.name_scope('Inputs'), tf.device('/cpu:0'): lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[]) lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[]) minibatch_size_in = tf.placeholder(tf.int32, name='minibatch_size_in', shape=[]) minibatch_gpu_in = tf.placeholder(tf.int32, name='minibatch_gpu_in', shape=[]) minibatch_multiplier = minibatch_size_in // (minibatch_gpu_in * num_gpus) Gs_beta = 0.5 ** tf.div(tf.cast(minibatch_size_in, tf.float32), G_smoothing_kimg * 1000.0) if G_smoothing_kimg > 0.0 else 0.0 # Setup optimizers. G_opt_args = dict(G_opt_args) D_opt_args = dict(D_opt_args) for args, reg_interval in [(G_opt_args, G_reg_interval), (D_opt_args, D_reg_interval)]: args['minibatch_multiplier'] = minibatch_multiplier args['learning_rate'] = lrate_in if lazy_regularization: mb_ratio = reg_interval / (reg_interval + 1) args['learning_rate'] *= mb_ratio if 'beta1' in args: args['beta1'] **= mb_ratio if 'beta2' in args: args['beta2'] **= mb_ratio G_opt = tflib.Optimizer(name='TrainG', **G_opt_args) D_opt = tflib.Optimizer(name='TrainD', **D_opt_args) G_reg_opt = tflib.Optimizer(name='RegG', share=G_opt, **G_opt_args) D_reg_opt = tflib.Optimizer(name='RegD', share=D_opt, **D_opt_args) # Build training graph for each GPU. data_fetch_ops = [] for gpu in range(num_gpus): with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu): # Create GPU-specific shadow copies of G and D. G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow') D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow') # Fetch training data via temporary variables. with tf.name_scope('DataFetch'): sched = training_schedule(cur_nimg=int(resume_kimg*1000), training_set=training_set, **sched_args) reals_var = tf.Variable(name='reals', trainable=False, initial_value=tf.zeros([sched.minibatch_gpu] + training_set.shape)) labels_var = tf.Variable(name='labels', trainable=False, initial_value=tf.zeros([sched.minibatch_gpu, training_set.label_size])) reals_write, labels_write = training_set.get_minibatch_tf() reals_write, labels_write = process_reals(reals_write, labels_write, lod_in, mirror_augment, mirror_augment_v, training_set.dynamic_range, drange_net) reals_write = tf.concat([reals_write, reals_var[minibatch_gpu_in:]], axis=0) labels_write = tf.concat([labels_write, labels_var[minibatch_gpu_in:]], axis=0) data_fetch_ops += [tf.assign(reals_var, reals_write)] data_fetch_ops += [tf.assign(labels_var, labels_write)] reals_read = reals_var[:minibatch_gpu_in] labels_read = labels_var[:minibatch_gpu_in] # Evaluate loss functions. lod_assign_ops = [] if 'lod' in G_gpu.vars: lod_assign_ops += [tf.assign(G_gpu.vars['lod'], lod_in)] if 'lod' in D_gpu.vars: lod_assign_ops += [tf.assign(D_gpu.vars['lod'], lod_in)] with tf.control_dependencies(lod_assign_ops): with tf.name_scope('G_loss'): G_loss, G_reg = dnnlib.util.call_func_by_name(G=G_gpu, D=D_gpu, opt=G_opt, training_set=training_set, minibatch_size=minibatch_gpu_in, **G_loss_args) with tf.name_scope('D_loss'): D_loss, D_reg = dnnlib.util.call_func_by_name(G=G_gpu, D=D_gpu, opt=D_opt, training_set=training_set, minibatch_size=minibatch_gpu_in, reals=reals_read, labels=labels_read, **D_loss_args) # Register gradients. if not lazy_regularization: if G_reg is not None: G_loss += G_reg if D_reg is not None: D_loss += D_reg else: if G_reg is not None: G_reg_opt.register_gradients(tf.reduce_mean(G_reg * G_reg_interval), G_gpu.trainables) if D_reg is not None: D_reg_opt.register_gradients(tf.reduce_mean(D_reg * D_reg_interval), D_gpu.trainables) G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables) D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables) # Setup training ops. data_fetch_op = tf.group(*data_fetch_ops) G_train_op = G_opt.apply_updates() D_train_op = D_opt.apply_updates() G_reg_op = G_reg_opt.apply_updates(allow_no_op=True) D_reg_op = D_reg_opt.apply_updates(allow_no_op=True) Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta) # Finalize graph. with tf.device('/gpu:0'): try: peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse() except tf.errors.NotFoundError: peak_gpu_mem_op = tf.constant(0) tflib.init_uninitialized_vars() # print('Initializing logs...') summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path()) if save_tf_graph: summary_log.add_graph(tf.get_default_graph()) if save_weight_histograms: G.setup_weight_histograms(); D.setup_weight_histograms() print(' Training for %d kimg (%d left) \n' % (total_kimg, total_kimg-resume_kimg)) dnnlib.RunContext.get().update('', cur_epoch=resume_kimg, max_epoch=total_kimg) maintenance_time = dnnlib.RunContext.get().get_last_update_interval() cur_nimg = int(resume_kimg * 1000) cur_tick = -1 tick_start_nimg = cur_nimg prev_lod = -1.0 running_mb_counter = 0 while cur_nimg < total_kimg * 1000: if dnnlib.RunContext.get().should_stop(): break # Choose training parameters and configure training ops. sched = training_schedule(cur_nimg=cur_nimg, training_set=training_set, **sched_args) assert sched.minibatch_size % (sched.minibatch_gpu * num_gpus) == 0 training_set.configure(sched.minibatch_gpu) # , sched.lod if reset_opt_for_new_lod: if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(sched.lod) != np.ceil(prev_lod): G_opt.reset_optimizer_state(); D_opt.reset_optimizer_state() prev_lod = sched.lod # Run training ops. feed_dict = {lod_in: sched.lod, lrate_in: sched.G_lrate, minibatch_size_in: sched.minibatch_size, minibatch_gpu_in: sched.minibatch_gpu} for _repeat in range(minibatch_repeats): rounds = range(0, sched.minibatch_size, sched.minibatch_gpu * num_gpus) run_G_reg = (lazy_regularization and running_mb_counter % G_reg_interval == 0) run_D_reg = (lazy_regularization and running_mb_counter % D_reg_interval == 0) cur_nimg += sched.minibatch_size running_mb_counter += 1 # Fast path without gradient accumulation. if len(rounds) == 1: tflib.run([G_train_op, data_fetch_op], feed_dict) if run_G_reg: tflib.run(G_reg_op, feed_dict) tflib.run([D_train_op, Gs_update_op], feed_dict) if run_D_reg: tflib.run(D_reg_op, feed_dict) # Slow path with gradient accumulation. else: for _round in rounds: tflib.run(G_train_op, feed_dict) if run_G_reg: for _round in rounds: tflib.run(G_reg_op, feed_dict) tflib.run(Gs_update_op, feed_dict) for _round in rounds: tflib.run(data_fetch_op, feed_dict) tflib.run(D_train_op, feed_dict) if run_D_reg: for _round in rounds: tflib.run(D_reg_op, feed_dict) # Perform maintenance tasks once per tick. done = (cur_nimg >= total_kimg * 1000) if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done: cur_tick += 1 cur_time = time.time() tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0 tick_start_nimg = cur_nimg tick_time = dnnlib.RunContext.get().get_time_since_last_update() total_time = dnnlib.RunContext.get().get_time_since_start() + resume_time if sched.lod == 0: left_kimg = total_kimg - cur_nimg / 1000 left_sec = left_kimg * tick_time / tick_kimg finaltime = time.asctime(time.localtime(cur_time + left_sec)) msg_final = '%ss left till %s ' % (shortime(left_sec), finaltime[11:16]) else: msg_final = '' # Report progress. # print('tick %-4d kimg %-6.1f lod %-5.2f minibch %-3d:%d time %-8s min/tick %-6.3g %s sec/kimg %-7.3g gpumem %-4.1f %d lr %.2g ' % ( print('tick %-4d kimg %-6.1f time %-8s %s min/tick %-6.3g sec/kimg %-7.3g gpumem %-4.1f lr %.2g ' % ( autosummary('Progress/tick', cur_tick), autosummary('Progress/kimg', cur_nimg / 1000.0), # autosummary('Progress/lod', sched.lod), # autosummary('Progress/minibatch', sched.minibatch_size), # autosummary('Progress/minibatch_gpu', sched.minibatch_gpu), dnnlib.util.format_time(autosummary('Timing/total_sec', total_time)), msg_final, autosummary('Timing/min_per_tick', tick_time / 60), autosummary('Timing/sec_per_kimg', tick_time / tick_kimg), # autosummary('Timing/maintenance_sec', maintenance_time), autosummary('Resources/peak_gpu_mem_gb', peak_gpu_mem_op.eval() / 2**30), sched.G_lrate)) autosummary('Timing/total_hours', total_time / (60.0 * 60.0)) autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0)) # Save snapshots. if image_snapshot_ticks is not None and (cur_tick % image_snapshot_ticks == 0 or done): grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch_gpu) misc.save_image_grid(grid_fakes, dnnlib.make_run_dir_path('fake-%04d.%s' % (cur_nimg // 1000, ext)), drange=drange_net, grid_size=grid_size) if network_snapshot_ticks is not None and (cur_tick % network_snapshot_ticks == 0 or done): pkl = dnnlib.make_run_dir_path('snapshot-%d-%s%s-%04d.pkl' % (resolution, setname[-1], init_res_str, cur_nimg // 1000)) misc.save_pkl((G, D, Gs), pkl) misc.save_pkl((Gs), dnnlib.make_run_dir_path('%s-%d-%s%s-%04d.pkl' % (setname[:-1], resolution, setname[-1], init_res_str, cur_nimg // 1000))) # Update summaries and RunContext. tflib.autosummary.save_summaries(summary_log, cur_nimg) dnnlib.RunContext.get().update('%.2f' % sched.lod, cur_epoch=cur_nimg // 1000, max_epoch=total_kimg) maintenance_time = dnnlib.RunContext.get().get_last_update_interval() - tick_time # Save final snapshot. misc.save_pkl((G, D, Gs), dnnlib.make_run_dir_path('snapshot-%d-%s%s-final.pkl' % (resolution, setname[-1], init_res_str))) misc.save_pkl((Gs), dnnlib.make_run_dir_path('%s-%d-%s%s-final.pkl' % (setname[:-1], resolution, setname[-1], init_res_str))) # All done. summary_log.close() training_set.close()
def training_loop( run_dir='.', # Output directory. training_set_kwargs={}, # Options for training set. data_loader_kwargs={}, # Options for torch.utils.data.DataLoader. G_kwargs={}, # Options for generator network. D_kwargs={}, # Options for discriminator network. G_opt_kwargs={}, # Options for generator optimizer. D_opt_kwargs={}, # Options for discriminator optimizer. augment_kwargs=None, # Options for augmentation pipeline. None = disable. loss_kwargs={}, # Options for loss function. savenames=None, # random_seed=0, # Global random seed. num_gpus=1, # Number of GPUs participating in the training. rank=0, # Rank of the current process in [0, num_gpus[. batch_size=4, # Total batch size for one training iteration. Can be larger than batch_gpu * num_gpus. batch_gpu=4, # Number of samples processed at a time by one GPU. ema_kimg=10, # Half-life of the exponential moving average (EMA) of generator weights. ema_rampup=None, # EMA ramp-up coefficient. G_reg_interval=4, # How often to perform regularization for G? None = disable lazy regularization. D_reg_interval=16, # How often to perform regularization for D? None = disable lazy regularization. augment_p=0, # Initial value of augmentation probability. ada_target=None, # ADA target value. None = fixed p. ada_interval=4, # How often to perform ADA adjustment? ada_kimg=500, # ADA adjustment speed, measured in how many kimg it takes for p to increase/decrease by one unit. total_kimg=25000, # Total length of the training, measured in thousands of real images. kimg_per_tick=4, # Progress snapshot interval. image_snapshot_ticks=50, # How often to save image snapshots? None = disable. network_snapshot_ticks=50, # How often to save network snapshots? None = disable. resume_pkl=None, # Network pickle to resume training from. resume_kimg=0.0, # Assumed training progress at the beginning. Affects reporting and training schedule. cudnn_benchmark=True, # Enable torch.backends.cudnn.benchmark? abort_fn=None, # Callback function for determining whether to abort training. Must return consistent results across ranks. progress_fn=None, # Callback function for updating training progress. Called for all ranks. **kwargs ): # Initialize. start_time = time.time() device = torch.device('cuda', rank) np.random.seed(random_seed * num_gpus + rank) torch.manual_seed(random_seed * num_gpus + rank) torch.backends.cudnn.benchmark = cudnn_benchmark # Improves training speed. conv2d_gradfix.enabled = True # Improves training speed. grid_sample_gradfix.enabled = True # Avoids errors with the augmentation pipe. run_id = str(int(time.time()))[3:-2] hostname = str(socket.gethostname()) # Load training set. # if rank == 0: # print('Loading training set...') training_set = dnnlib.util.construct_class_by_name(**training_set_kwargs) # subclass of training.dataset.Dataset training_set_sampler = misc.InfiniteSampler(dataset=training_set, rank=rank, num_replicas=num_gpus, seed=random_seed) training_set_iterator = iter(torch.utils.data.DataLoader(dataset=training_set, sampler=training_set_sampler, batch_size=batch_size // num_gpus, **data_loader_kwargs)) if rank == 0: # print('Num images: ', len(training_set)) print('Image shape:', training_set.image_shape) print('Label shape:', training_set.label_shape) ext = 'png' if training_set.image_shape[0] == 4 else 'jpg' # !!! # Construct networks. if rank == 0: print('Constructing networks...') common_kwargs = dict(c_dim=training_set.label_dim, img_resolution=training_set.resolution, img_channels=training_set.num_channels) G = dnnlib.util.construct_class_by_name(**G_kwargs, **common_kwargs).train().requires_grad_(False).to( device) # subclass of torch.nn.Module D = dnnlib.util.construct_class_by_name(**D_kwargs, **common_kwargs).train().requires_grad_(False).to( device) # subclass of torch.nn.Module G_ema = copy.deepcopy(G).eval() # Resume from existing pickle. if (resume_pkl is not None) and (rank == 0): # !!! if os.path.isdir(resume_pkl): resume_pkl, resume_kimg = locate_latest_pkl(resume_pkl) print(' Resuming from "%s", kimg %.3g' % (resume_pkl, resume_kimg)) with dnnlib.util.open_url(resume_pkl) as f: resume_data = legacy.load_network_pkl(f) for name, module in [('G', G), ('D', D), ('G_ema', G_ema)]: misc.copy_params_and_buffers(resume_data[name], module, require_all=False) # Print network summary tables. if rank == 0: # z = torch.empty([batch_gpu, G.z_dim], device=device) # c = torch.empty([batch_gpu, G.c_dim], device=device) # img = misc.print_module_summary(G, [z, c]) # misc.print_module_summary(D, [img, c]) pass # Setup augmentation. # if rank == 0: # print('Setting up augmentation...') augment_pipe = None ada_stats = None if (augment_kwargs is not None) and (augment_p > 0 or ada_target is not None): augment_pipe = dnnlib.util.construct_class_by_name(**augment_kwargs).train().requires_grad_(False).to( device) # subclass of torch.nn.Module augment_pipe.p.copy_(torch.as_tensor(augment_p)) if ada_target is not None: ada_stats = training_stats.Collector(regex='Loss/signs/real') # Distribute across GPUs. # if rank == 0: # print(f'Distributing across {num_gpus} GPUs...') ddp_modules = dict() for name, module in [('G_mapping', G.mapping), ('G_synthesis', G.synthesis), ('D', D), (None, G_ema), ('augment_pipe', augment_pipe)]: if (num_gpus > 1) and (module is not None) and len(list(module.parameters())) != 0: module.requires_grad_(True) module = torch.nn.parallel.DistributedDataParallel(module, device_ids=[device], broadcast_buffers=False) module.requires_grad_(False) if name is not None: ddp_modules[name] = module # Setup training phases. # if rank == 0: # print('Setting up training phases...') loss = dnnlib.util.construct_class_by_name(device=device, **ddp_modules, **loss_kwargs) # subclass of training.loss.Loss phases = [] for name, module, opt_kwargs, reg_interval in [('G', G, G_opt_kwargs, G_reg_interval), ('D', D, D_opt_kwargs, D_reg_interval)]: if reg_interval is None: opt = dnnlib.util.construct_class_by_name(params=module.parameters(), **opt_kwargs) # subclass of torch.optim.Optimizer phases += [dnnlib.EasyDict(name=name + 'both', module=module, opt=opt, interval=1)] else: # Lazy regularization. mb_ratio = reg_interval / (reg_interval + 1) opt_kwargs = dnnlib.EasyDict(opt_kwargs) opt_kwargs.lr = opt_kwargs.lr * mb_ratio opt_kwargs.betas = [beta ** mb_ratio for beta in opt_kwargs.betas] opt = dnnlib.util.construct_class_by_name(module.parameters(), **opt_kwargs) # subclass of torch.optim.Optimizer phases += [dnnlib.EasyDict(name=name + 'main', module=module, opt=opt, interval=1)] phases += [dnnlib.EasyDict(name=name + 'reg', module=module, opt=opt, interval=reg_interval)] for phase in phases: phase.start_event = None phase.end_event = None if rank == 0: phase.start_event = torch.cuda.Event(enable_timing=True) phase.end_event = torch.cuda.Event(enable_timing=True) # Export sample images. grid_size = None grid_z = None grid_c = None if rank == 0: # print('Exporting sample images...') grid_size, images, labels = setup_snapshot_image_grid(training_set=training_set) save_image_grid(images, os.path.join(run_dir, '_reals.' + ext), drange=[0, 255], grid_size=grid_size) grid_z = torch.randn([labels.shape[0], G.z_dim], device=device).split(batch_gpu) grid_c = torch.from_numpy(labels).to(device).split(batch_gpu) # Initialize logs. # if rank == 0: # print('Initializing logs...') stats_collector = training_stats.Collector(regex='.*') # Only upload data once, not for every process if rank == 0: url = f'https://mnr6yzqr22jgywm-adw2.adb.eu-frankfurt-1.oraclecloudapps.com/ords/thesisproject/sg_d/' \ f'report/{run_id}/{hostname}' metadata = training_set_kwargs more_metadata = [ {'run_dir': run_dir}, {'data_loader_kwargs': data_loader_kwargs}, {'G_kwargs': G_kwargs}, {'D_kwargs': D_kwargs}, {'G_opt_kwargs': G_opt_kwargs}, {'D_opt_kwargs': D_opt_kwargs}, {'augment_kwargs': augment_kwargs}, {'loss_kwargs': loss_kwargs}, {'savenames': savenames}, {'random_seed': random_seed}, {'num_gpus': num_gpus}, {'batch_size': batch_size}, {'batch_gpu': batch_gpu}, {'ema_kimg': ema_kimg}, {'ema_rampup': ema_rampup}, {'G_reg_interval': G_reg_interval}, {'D_reg_interval': D_reg_interval}, {'augment_p': augment_p}, {'ada_target': ada_target}, {'ada_interval': ada_interval}, {'ada_kimg': ada_kimg}, {'total_kimg': total_kimg}, {'kimg_per_tick': kimg_per_tick}, {'image_snapshot_ticks': image_snapshot_ticks}, {'network_snapshot_ticks': network_snapshot_ticks}, {'resume_pkl': resume_pkl}, {'resume_kimg': resume_kimg}, {'cudnn_benchmark': cudnn_benchmark} ] for d in more_metadata: metadata.update(d) requests.post(url, data=json.dumps(training_set_kwargs)) stats_jsonl = None # stats_tfevents = None if rank == 0: stats_jsonl = open(os.path.join(run_dir, 'stats.jsonl'), 'wt') # Train. if rank == 0: print(f'Training for {total_kimg} kimg...') print() cur_nimg = 0 cur_tick = 0 tick_start_nimg = cur_nimg tick_start_time = time.time() # maintenance_time = tick_start_time - start_time batch_idx = 0 if progress_fn is not None: progress_fn(0, total_kimg) while True: # Fetch training data. with torch.autograd.profiler.record_function('data_fetch'): phase_real_img, phase_real_c = next(training_set_iterator) phase_real_img = (phase_real_img.to(device).to(torch.float32) / 127.5 - 1).split(batch_gpu) phase_real_c = phase_real_c.to(device).split(batch_gpu) all_gen_z = torch.randn([len(phases) * batch_size, G.z_dim], device=device) all_gen_z = [phase_gen_z.split(batch_gpu) for phase_gen_z in all_gen_z.split(batch_size)] all_gen_c = [training_set.get_label(np.random.randint(len(training_set))) for _ in range(len(phases) * batch_size)] all_gen_c = torch.from_numpy(np.stack(all_gen_c)).pin_memory().to(device) all_gen_c = [phase_gen_c.split(batch_gpu) for phase_gen_c in all_gen_c.split(batch_size)] # Execute training phases. for phase, phase_gen_z, phase_gen_c in zip(phases, all_gen_z, all_gen_c): if batch_idx % phase.interval != 0: continue # Initialize gradient accumulation. if phase.start_event is not None: phase.start_event.record(torch.cuda.current_stream(device)) phase.opt.zero_grad(set_to_none=True) phase.module.requires_grad_(True) # Accumulate gradients over multiple rounds. for round_idx, (real_img, real_c, gen_z, gen_c) in enumerate( zip(phase_real_img, phase_real_c, phase_gen_z, phase_gen_c)): sync = (round_idx == batch_size // (batch_gpu * num_gpus) - 1) gain = phase.interval loss.accumulate_gradients(phase=phase.name, real_img=real_img, real_c=real_c, gen_z=gen_z, gen_c=gen_c, sync=sync, gain=gain) # Update weights. phase.module.requires_grad_(False) with torch.autograd.profiler.record_function(phase.name + '_opt'): for param in phase.module.parameters(): if param.grad is not None: misc.nan_to_num(param.grad, nan=0, posinf=1e5, neginf=-1e5, out=param.grad) phase.opt.step() if phase.end_event is not None: phase.end_event.record(torch.cuda.current_stream(device)) # Update G_ema. with torch.autograd.profiler.record_function('Gema'): ema_nimg = ema_kimg * 1000 if ema_rampup is not None: ema_nimg = min(ema_nimg, cur_nimg * ema_rampup) ema_beta = 0.5 ** (batch_size / max(ema_nimg, 1e-8)) for p_ema, p in zip(G_ema.parameters(), G.parameters()): p_ema.copy_(p.lerp(p_ema, ema_beta)) for b_ema, b in zip(G_ema.buffers(), G.buffers()): b_ema.copy_(b) # Update state. cur_nimg += batch_size batch_idx += 1 # Execute ADA heuristic. if (ada_stats is not None) and (batch_idx % ada_interval == 0): ada_stats.update() adjust = np.sign(ada_stats['Loss/signs/real'] - ada_target) * (batch_size * ada_interval) / ( ada_kimg * 1000) augment_pipe.p.copy_((augment_pipe.p + adjust).max(misc.constant(0, device=device))) # Perform maintenance tasks once per tick. done = (cur_nimg >= total_kimg * 1000) if (not done) and (cur_tick != 0) and (cur_nimg < tick_start_nimg + kimg_per_tick * 1000): continue tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0 tick_end_time = time.time() total_time = tick_end_time - start_time tick_time = tick_end_time - tick_start_time left_kimg = total_kimg - cur_nimg / 1000 left_sec = left_kimg * tick_time / tick_kimg finaltime = time.asctime(time.localtime(tick_end_time + left_sec)) msg_final = ' %ss left till %s ' % (shortime(left_sec), finaltime[11:16]) # Print status line, accumulating the same information in stats_collector. # tick_end_time = time.time() fields = [] fields += [f"tick {training_stats.report0('Progress/tick', cur_tick):<4d}"] fields += [f"kimg {training_stats.report0('Progress/kimg', cur_nimg / 1e3):<6.1f}"] fields += [f"time {dnnlib.util.format_time(training_stats.report0('Timing/total_sec', total_time)):<8s}"] fields += [msg_final] fields += [f"min/tick {training_stats.report0('Timing/sec_per_tick', tick_time / 60):<6.3g}"] fields += [f"sec/kimg {training_stats.report0('Timing/sec_per_kimg', tick_time / tick_kimg):<7.3g}"] fields += [f"cpumem {training_stats.report0('Resources/cpu_mem_gb', psutil.Process(os.getpid()).memory_info().rss / 2 ** 30):<4.1f}"] fields += [f"gpumem {training_stats.report0('Resources/peak_gpu_mem_gb', torch.cuda.max_memory_allocated(device) / 2 ** 30):<4.1f}"] torch.cuda.reset_peak_memory_stats() fields += [f"aug {training_stats.report0('Progress/augment', float(augment_pipe.p.cpu()) if augment_pipe is not None else 0):.3f}"] # !!! fields += ["lr %.2g" % training_stats.report0('Progress/G_lrate', G_opt_kwargs.lr)] fields += ["%.2g" % training_stats.report0('Progress/D_lrate', D_opt_kwargs.lr)] training_stats.report0('Timing/total_hours', total_time / (60 * 60)) training_stats.report0('Timing/total_days', total_time / (24 * 60 * 60)) if rank == 0: print(' '.join(fields)) # Check for abort. if (not done) and (abort_fn is not None) and abort_fn(): done = True if rank == 0: print() print('Aborting...') # Save image snapshot. if (rank == 0) and (image_snapshot_ticks is not None) and (done or cur_tick % image_snapshot_ticks == 0): images = torch.cat([G_ema(z=z, c=c, noise_mode='const').cpu() for z, c in zip(grid_z, grid_c)]).numpy() save_image_grid(images, os.path.join(run_dir, 'fake-%04d.%s' % (cur_nimg // 1000, ext)), drange=[-1, 1], grid_size=grid_size) # Save network snapshot. # snapshot_pkl = None if (network_snapshot_ticks is not None) and (done or cur_tick % network_snapshot_ticks == 0): snapshot_data = dict(training_set_kwargs=dict(training_set_kwargs)) compact_data = dict(training_set_kwargs=dict(training_set_kwargs)) for name, module in [('G', G), ('D', D), ('G_ema', G_ema), ('augment_pipe', augment_pipe)]: if module is not None: if num_gpus > 1: misc.check_ddp_consistency(module, ignore_regex=r'.*\.w_avg') module = copy.deepcopy(module).eval().requires_grad_(False).cpu() snapshot_data[name] = module if name == 'G_ema': compact_data[name] = module # !!! del module # conserve memory # !!! pkl_snap = os.path.join(run_dir, '%s-%04d.pkl' % (savenames[0], cur_nimg // 1000)) pkl_run = os.path.join(run_dir, '%s-%04d.pkl' % (savenames[1], cur_nimg // 1000)) if rank == 0: with open(pkl_snap, 'wb') as f: pickle.dump(snapshot_data, f) with open(pkl_run, 'wb') as f: pickle.dump(compact_data, f) # Evaluate metrics. ''' if (snapshot_data is not None) and (len(metrics) > 0): if rank == 0: print('Evaluating metrics...') for metric in metrics: result_dict = metric_main.calc_metric(metric=metric, G=snapshot_data['G_ema'], dataset_kwargs=training_set_kwargs, num_gpus=num_gpus, rank=rank, device=device) if rank == 0: metric_main.report_metric(result_dict, run_dir=run_dir, snapshot_pkl=snapshot_pkl) stats_metrics.update(result_dict.results) ''' try: del snapshot_data # conserve memory except UnboundLocalError: pass # !!! try: del compact_data # conserve memory except UnboundLocalError: pass # Collect statistics. for phase in phases: value = [] if (phase.start_event is not None) and (phase.end_event is not None): phase.end_event.synchronize() value = phase.start_event.elapsed_time(phase.end_event) training_stats.report0('Timing/' + phase.name, value) stats_collector.update() stats_dict = stats_collector.as_dict() # Report stats to database if rank == 0: url = f'https://mnr6yzqr22jgywm-adw2.adb.eu-frankfurt-1.oraclecloudapps.com/ords/thesisproject/sg2/' \ f'report/{run_id}/{hostname}' requests.post(url, data=json.dumps(stats_dict)) # Update logs. timestamp = time.time() if stats_jsonl is not None: fields = dict(stats_dict, timestamp=timestamp) json_data = json.dumps(fields) # Write stats to file stats_jsonl.write(json_data + '\n') stats_jsonl.flush() # if stats_tfevents is not None: # global_step = int(cur_nimg / 1e3) # walltime = timestamp - start_time # for name, value in stats_dict.items(): # stats_tfevents.add_scalar(name, value.mean, global_step=global_step, walltime=walltime) # for name, value in stats_metrics.items(): # stats_tfevents.add_scalar(f'Metrics/{name}', value, global_step=global_step, walltime=walltime) # stats_tfevents.flush() if progress_fn is not None: progress_fn(cur_nimg // 1000, total_kimg) # Update state. cur_tick += 1 tick_start_nimg = cur_nimg tick_start_time = time.time() # maintenance_time = tick_start_time - tick_end_time if done: break # Done. if rank == 0: print() print('Exiting...')