def training_schedule(
    cur_nimg,
    resolution_log2,
    num_gpus,
    lod_initial_resolution=4,
    lod_training_kimg=600,
    lod_transition_kimg=600,
    minibatch_base=4,
    minibatch_dict=None,
    max_minibatch_per_gpu=None,
    G_lrate_base=0.001,
    G_lrate_dict=None,
    D_lrate_base=0.001,
    D_lrate_dict=None,
    lrate_rampup_kimg=0,
    tick_kimg_base=60,
    tick_kimg_dict=None,
):
    """Return the progressive-growing training schedule after ``cur_nimg`` images.

    Training is split into consecutive phases of
    ``lod_training_kimg + lod_transition_kimg`` kimg. During the first part
    of a phase the level of detail (``lod``) is held constant. During the
    second part it decreases linearly by 1.0, so ``resolution`` doubles at
    the start of the next phase. ``lod`` is clamped at 0, so ``resolution``
    never exceeds ``2 ** resolution_log2``.

    Args:
        cur_nimg: Number of training images shown so far.
        resolution_log2: log2 of the highest resolution the schedule grows
            to (reached once ``lod`` hits 0).
        num_gpus: The minibatch is rounded down to a multiple of this.
        lod_initial_resolution: Resolution at ``cur_nimg == 0``. Its log2 is
            floored, so non-powers of two are rounded down.
        lod_training_kimg: kimg per phase during which ``lod`` is constant.
        lod_transition_kimg: kimg per phase over which ``lod`` drops by 1.0.
            With 0 there is no gradual transition and ``lod`` drops in
            integer steps.
        minibatch_base: Minibatch size for resolutions not in
            ``minibatch_dict``.
        minibatch_dict: Optional ``{resolution: minibatch}`` overrides.
        max_minibatch_per_gpu: Optional ``{resolution: per-GPU cap}``.
        G_lrate_base: Generator learning rate for resolutions not in
            ``G_lrate_dict``.
        G_lrate_dict: Optional ``{resolution: lrate}`` overrides for G.
        D_lrate_base: Discriminator learning rate for resolutions not in
            ``D_lrate_dict``.
        D_lrate_dict: Optional ``{resolution: lrate}`` overrides for D.
        lrate_rampup_kimg: If > 0, both learning rates are scaled linearly
            from 0 up to their full value over this many kimg.
        tick_kimg_base: kimg between ticks for resolutions not in
            ``tick_kimg_dict``.
        tick_kimg_dict: ``{resolution: kimg per tick}``. ``None`` selects the
            built-in table below. These defaults are smaller than the
            official implementation's because the author's GPU is slow.

    Returns:
        An ``EasyDict`` with these keys:

        - ``kimg``, ``lod``, ``resolution``.
        - ``resolution_log2``: log2 of ``resolution``.
        - ``alpha``: ``1 - frac(lod)``. It is 1.0 while ``lod`` is an
          integer and rises from 0 towards 1.0 during a transition.
        - ``minibatch``, ``G_lrate``, ``D_lrate`` and ``tick_kimg``.

    Raises:
        ValueError: If the minibatch for the current resolution is not
            positive after rounding down to a multiple of ``num_gpus`` and
            applying ``max_minibatch_per_gpu``.
    """
    # ``None`` defaults avoid sharing one mutable dict object across calls.
    # An explicitly passed dict (including ``{}``) is used as-is.
    if minibatch_dict is None:
        minibatch_dict = {}
    if max_minibatch_per_gpu is None:
        max_minibatch_per_gpu = {}
    if G_lrate_dict is None:
        G_lrate_dict = {}
    if D_lrate_dict is None:
        D_lrate_dict = {}
    if tick_kimg_dict is None:
        tick_kimg_dict = {4: 60, 8: 40, 16: 20, 32: 20, 64: 20,
                          128: 20, 256: 20, 512: 20, 1024: 20}

    # EasyDict is expected to be imported at module level (dnnlib ships one).
    s = EasyDict()
    s.kimg = cur_nimg / 1000.0

    # Which phase we are in, and how many kimg into it.
    phase_dur = lod_training_kimg + lod_transition_kimg
    phase_idx = int(np.floor(s.kimg / phase_dur)) if phase_dur > 0 else 0
    phase_kimg = s.kimg - phase_idx * phase_dur

    # Level-of-detail and resolution.
    s.lod = resolution_log2
    s.lod -= np.floor(np.log2(lod_initial_resolution))
    s.lod -= phase_idx
    if lod_transition_kimg > 0:
        # Elapsed fraction (in [0, 1)) of the transition part of this phase.
        s.lod -= max(phase_kimg - lod_training_kimg, 0.0) / lod_transition_kimg
    s.lod = max(s.lod, 0.0)
    s.resolution = 2 ** (resolution_log2 - int(np.floor(s.lod)))
    s.resolution_log2 = int(np.log2(s.resolution))
    # resolution_log2 - s.resolution_log2 == floor(lod), so alpha == 1 - frac(lod).
    s.alpha = 1 - (s.lod - (resolution_log2 - s.resolution_log2))
    assert 0 <= s.alpha <= 1.0

    # Minibatch size: must split evenly across GPUs, optionally capped per GPU.
    s.minibatch = minibatch_dict.get(s.resolution, minibatch_base)
    s.minibatch -= s.minibatch % num_gpus
    if s.resolution in max_minibatch_per_gpu:
        s.minibatch = min(s.minibatch, max_minibatch_per_gpu[s.resolution] * num_gpus)
    if s.minibatch <= 0:
        # Without this check, e.g. minibatch_base=4 with 8 GPUs would silently
        # yield an empty minibatch.
        raise ValueError(
            "minibatch size for resolution {} is {} after rounding to a "
            "multiple of num_gpus={}; it must be positive".format(
                s.resolution, s.minibatch, num_gpus))

    # Learning rate, with an optional linear warm-up.
    s.G_lrate = G_lrate_dict.get(s.resolution, G_lrate_base)
    s.D_lrate = D_lrate_dict.get(s.resolution, D_lrate_base)
    if lrate_rampup_kimg > 0:
        rampup = min(s.kimg / lrate_rampup_kimg, 1.0)
        s.G_lrate *= rampup
        s.D_lrate *= rampup

    # Other parameters.
    s.tick_kimg = tick_kimg_dict.get(s.resolution, tick_kimg_base)
    return s
#desc += '-celebahq'; dataset = EasyDict(tfrecord_dir='celebahq'); train.mirror_augment = True #desc += '-bedroom'; dataset = EasyDict(tfrecord_dir='lsun-bedroom-full'); train.mirror_augment = False #desc += '-car'; dataset = EasyDict(tfrecord_dir='lsun-car-512x384'); train.mirror_augment = False #desc += '-cat'; dataset = EasyDict(tfrecord_dir='lsun-cat-full'); train.mirror_augment = False # Number of GPUs. #desc += '-1gpu'; submit_config.num_gpus = 1; sched.minibatch_size = 4 #desc += '-2gpu'; submit_config.num_gpus = 2; sched.minibatch_size = 32 desc += '-4gpu' submit_config.num_gpus = 4 sched.minibatch_size = 16 #desc += '-8gpu'; submit_config.num_gpus = 8; sched.minibatch_size = 32 # Default options. train.total_kimg = 25000 sched.G_lrate = 0.003 sched.D_lrate = sched.G_lrate # related to frequency of logs: sched.tick_kimg = 10 image_snapshot_ticks = 1 network_snapshot_ticks = 10 # debug ones: # sched.tick_kimg = 0.001 # image_snapshot_ticks = 1 # network_snapshot_ticks = 1 # WGAN-GP loss for CelebA-HQ. # desc += '-wgangp'; G_loss = EasyDict(func_name='training.loss.G_wgan'); D_loss = EasyDict(func_name='training.loss.D_wgan_gp'); sched.G_lrate_dict = {k: min(v, 0.002) for k, v in sched.G_lrate_dict.items()}; sched.D_lrate_dict = EasyDict(sched.G_lrate_dict)