def train_gan(separate_funcs=False,
              D_training_repeats=1,
              G_learning_rate_max=0.0010,
              D_learning_rate_max=0.0010,
              G_smoothing=0.999,
              adam_beta1=0.0,
              adam_beta2=0.99,
              adam_epsilon=1e-8,
              minibatch_default=16,
              minibatch_overrides={},
              rampup_kimg=40,
              rampdown_kimg=0,
              lod_initial_resolution=4,
              lod_training_kimg=400,
              lod_transition_kimg=400,
              total_kimg=10000,
              dequantize_reals=False,
              gdrop_beta=0.9,
              gdrop_lim=0.5,
              gdrop_coef=0.2,
              gdrop_exp=2.0,
              drange_net=[-1, 1],
              drange_viz=[-1, 1],
              image_grid_size=None,
              tick_kimg_default=50,
              tick_kimg_overrides={32: 20, 64: 10, 128: 10, 256: 5, 512: 2, 1024: 1},
              image_snapshot_ticks=4,
              network_snapshot_ticks=40,
              image_grid_type='default',
              resume_network_pkl=None,
              resume_kimg=0.0,
              resume_time=0.0):

    # Load dataset and build networks.
    training_set, drange_orig = load_dataset()
    if resume_network_pkl:
        print 'Resuming', resume_network_pkl
        G, D, _ = misc.load_pkl(os.path.join(config.result_dir, resume_network_pkl))
    else:
        G = network.Network(num_channels=training_set.shape[1],
                            resolution=training_set.shape[2],
                            label_size=training_set.labels.shape[1],
                            **config.G)
        D = network.Network(num_channels=training_set.shape[1],
                            resolution=training_set.shape[2],
                            label_size=training_set.labels.shape[1],
                            **config.D)
    Gs = G.create_temporally_smoothed_version(beta=G_smoothing, explicit_updates=True)
    misc.print_network_topology_info(G.output_layers)
    misc.print_network_topology_info(D.output_layers)

    # Setup snapshot image grid.
    if image_grid_type == 'default':
        if image_grid_size is None:
            w, h = G.output_shape[3], G.output_shape[2]
            image_grid_size = np.clip(1920 / w, 3, 16), np.clip(1080 / h, 2, 16)
        example_real_images, snapshot_fake_labels = training_set.get_random_minibatch(
            np.prod(image_grid_size), labels=True)
        snapshot_fake_latents = random_latents(np.prod(image_grid_size), G.input_shape)
    elif image_grid_type == 'category':
        W = training_set.labels.shape[1]
        H = W if image_grid_size is None else image_grid_size[1]
        image_grid_size = W, H
        snapshot_fake_latents = random_latents(W * H, G.input_shape)
        snapshot_fake_labels = np.zeros((W * H, W), dtype=training_set.labels.dtype)
        example_real_images = np.zeros((W * H,) + training_set.shape[1:], dtype=training_set.dtype)
        for x in xrange(W):
            snapshot_fake_labels[x::W, x] = 1.0
            indices = np.arange(training_set.shape[0])[training_set.labels[:, x] != 0]
            for y in xrange(H):
                example_real_images[x + y * W] = training_set.h5_lods[0][np.random.choice(indices)]
    else:
        raise ValueError('Invalid image_grid_type', image_grid_type)

    # Theano input variables and compile generation func.
    print 'Setting up Theano...'
    real_images_var = T.TensorType('float32', [False] * len(D.input_shape))('real_images_var')
    real_labels_var = T.TensorType('float32', [False] * len(training_set.labels.shape))('real_labels_var')
    fake_latents_var = T.TensorType('float32', [False] * len(G.input_shape))('fake_latents_var')
    fake_labels_var = T.TensorType('float32', [False] * len(training_set.labels.shape))('fake_labels_var')
    G_lrate = theano.shared(np.float32(0.0))
    D_lrate = theano.shared(np.float32(0.0))
    gen_fn = theano.function([fake_latents_var, fake_labels_var],
                             Gs.eval_nd(fake_latents_var, fake_labels_var, ignore_unused_inputs=True),
                             on_unused_input='ignore')

    # Misc init.
    resolution_log2 = int(np.round(np.log2(G.output_shape[2])))
    initial_lod = max(resolution_log2 - int(np.round(np.log2(lod_initial_resolution))), 0)
    cur_lod = 0.0
    min_lod, max_lod = -1.0, -2.0
    fake_score_avg = 0.0

    if config.D.get('mbdisc_kernels', None):
        print 'Initializing minibatch discrimination...'
        if hasattr(D, 'cur_lod'): D.cur_lod.set_value(np.float32(initial_lod))
        D.eval(real_images_var, deterministic=False, init=True)
        init_layers = lasagne.layers.get_all_layers(D.output_layers)
        init_updates = [update for layer in init_layers
                        for update in getattr(layer, 'init_updates', [])]
        init_fn = theano.function(inputs=[real_images_var], outputs=None, updates=init_updates)
        init_reals = training_set.get_random_minibatch(500, lod=initial_lod)
        init_reals = misc.adjust_dynamic_range(init_reals, drange_orig, drange_net)
        init_fn(init_reals)
        del init_reals

    # Save example images.
    snapshot_fake_images = gen_fn(snapshot_fake_latents, snapshot_fake_labels)
    result_subdir = misc.create_result_subdir(config.result_dir, config.run_desc)
    misc.save_image_grid(example_real_images, os.path.join(result_subdir, 'reals.png'),
                         drange=drange_orig, grid_size=image_grid_size)
    misc.save_image_grid(snapshot_fake_images, os.path.join(result_subdir, 'fakes%06d.png' % 0),
                         drange=drange_viz, grid_size=image_grid_size)

    # Training loop.
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    tick_start_time = time.time()
    tick_train_out = []
    train_start_time = tick_start_time - resume_time
    while cur_nimg < total_kimg * 1000:

        # Calculate current LOD.
        cur_lod = initial_lod
        if lod_training_kimg or lod_transition_kimg:
            tlod = (cur_nimg / 1000.0) / (lod_training_kimg + lod_transition_kimg)
            cur_lod -= np.floor(tlod)
            if lod_transition_kimg:
                cur_lod -= max(1.0 + (np.fmod(tlod, 1.0) - 1.0) *
                               (lod_training_kimg + lod_transition_kimg) / lod_transition_kimg, 0.0)
            cur_lod = max(cur_lod, 0.0)

        # Look up resolution-dependent parameters.
        cur_res = 2 ** (resolution_log2 - int(np.floor(cur_lod)))
        minibatch_size = minibatch_overrides.get(cur_res, minibatch_default)
        tick_duration_kimg = tick_kimg_overrides.get(cur_res, tick_kimg_default)

        # Update network config.
        lrate_coef = misc.rampup(cur_nimg / 1000.0, rampup_kimg)
        lrate_coef *= misc.rampdown_linear(cur_nimg / 1000.0, total_kimg, rampdown_kimg)
        G_lrate.set_value(np.float32(lrate_coef * G_learning_rate_max))
        D_lrate.set_value(np.float32(lrate_coef * D_learning_rate_max))
        if hasattr(G, 'cur_lod'): G.cur_lod.set_value(np.float32(cur_lod))
        if hasattr(D, 'cur_lod'): D.cur_lod.set_value(np.float32(cur_lod))

        # Setup training func for current LOD.
        new_min_lod, new_max_lod = int(np.floor(cur_lod)), int(np.ceil(cur_lod))
        if min_lod != new_min_lod or max_lod != new_max_lod:
            print 'Compiling training funcs...'
            min_lod, max_lod = new_min_lod, new_max_lod

            # Pre-process reals.
            real_images_expr = real_images_var
            if dequantize_reals:
                rnd = theano.sandbox.rng_mrg.MRG_RandomStreams(
                    lasagne.random.get_rng().randint(1, 2147462579))
                epsilon_noise = rnd.uniform(size=real_images_expr.shape, low=-0.5, high=0.5, dtype='float32')
                real_images_expr = T.cast(real_images_expr, 'float32') + epsilon_noise  # match original implementation of Improved Wasserstein
            real_images_expr = misc.adjust_dynamic_range(real_images_expr, drange_orig, drange_net)
            if min_lod > 0:  # compensate for shrink_based_on_lod
                real_images_expr = T.extra_ops.repeat(real_images_expr, 2**min_lod, axis=2)
                real_images_expr = T.extra_ops.repeat(real_images_expr, 2**min_lod, axis=3)

            # Optimize loss.
            G_loss, D_loss, real_scores_out, fake_scores_out = evaluate_loss(
                G, D, min_lod, max_lod, real_images_expr, real_labels_var,
                fake_latents_var, fake_labels_var, **config.loss)
            G_updates = adam(G_loss, G.trainable_params(), learning_rate=G_lrate,
                             beta1=adam_beta1, beta2=adam_beta2, epsilon=adam_epsilon).items()
            D_updates = adam(D_loss, D.trainable_params(), learning_rate=D_lrate,
                             beta1=adam_beta1, beta2=adam_beta2, epsilon=adam_epsilon).items()

            # Compile training funcs.
            if not separate_funcs:
                GD_train_fn = theano.function(
                    [real_images_var, real_labels_var, fake_latents_var, fake_labels_var],
                    [G_loss, D_loss, real_scores_out, fake_scores_out],
                    updates=G_updates + D_updates + Gs.updates,
                    on_unused_input='ignore')
            else:
                D_train_fn = theano.function(
                    [real_images_var, real_labels_var, fake_latents_var, fake_labels_var],
                    [G_loss, D_loss, real_scores_out, fake_scores_out],
                    updates=D_updates,
                    on_unused_input='ignore')
                G_train_fn = theano.function(
                    [fake_latents_var, fake_labels_var], [],
                    updates=G_updates + Gs.updates,
                    on_unused_input='ignore')

        # Invoke training funcs.
        if not separate_funcs:
            assert D_training_repeats == 1
            mb_reals, mb_labels = training_set.get_random_minibatch(
                minibatch_size, lod=cur_lod, shrink_based_on_lod=True, labels=True)
            mb_train_out = GD_train_fn(mb_reals, mb_labels,
                                       random_latents(minibatch_size, G.input_shape),
                                       random_labels(minibatch_size, training_set))
            cur_nimg += minibatch_size
            tick_train_out.append(mb_train_out)
        else:
            for idx in xrange(D_training_repeats):
                mb_reals, mb_labels = training_set.get_random_minibatch(
                    minibatch_size, lod=cur_lod, shrink_based_on_lod=True, labels=True)
                mb_train_out = D_train_fn(mb_reals, mb_labels,
                                          random_latents(minibatch_size, G.input_shape),
                                          random_labels(minibatch_size, training_set))
                cur_nimg += minibatch_size
                tick_train_out.append(mb_train_out)
            G_train_fn(random_latents(minibatch_size, G.input_shape),
                       random_labels(minibatch_size, training_set))

        # Fade in D noise if we're close to becoming unstable.
        fake_score_cur = np.clip(np.mean(mb_train_out[1]), 0.0, 1.0)
        fake_score_avg = fake_score_avg * gdrop_beta + fake_score_cur * (1.0 - gdrop_beta)
        gdrop_strength = gdrop_coef * (max(fake_score_avg - gdrop_lim, 0.0) ** gdrop_exp)
        if hasattr(D, 'gdrop_strength'): D.gdrop_strength.set_value(np.float32(gdrop_strength))

        # Perform maintenance operations once per tick.
        if cur_nimg >= tick_start_nimg + tick_duration_kimg * 1000 or cur_nimg >= total_kimg * 1000:
            cur_tick += 1
            cur_time = time.time()
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = cur_time - tick_start_time
            tick_start_time = cur_time
            tick_train_avg = tuple(np.mean(np.concatenate([np.asarray(v).flatten() for v in vals]))
                                   for vals in zip(*tick_train_out))
            tick_train_out = []

            # Print progress.
            print 'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-9.1f sec/kimg %-6.1f Dgdrop %-8.4f Gloss %-8.4f Dloss %-8.4f Dreal %-8.4f Dfake %-8.4f' % (
                (cur_tick, cur_nimg / 1000.0, cur_lod, minibatch_size,
                 misc.format_time(cur_time - train_start_time), tick_time,
                 tick_time / tick_kimg, gdrop_strength) + tick_train_avg)

            # Visualize generated images.
            if cur_tick % image_snapshot_ticks == 0 or cur_nimg >= total_kimg * 1000:
                snapshot_fake_images = gen_fn(snapshot_fake_latents, snapshot_fake_labels)
                misc.save_image_grid(snapshot_fake_images,
                                     os.path.join(result_subdir, 'fakes%06d.png' % (cur_nimg / 1000)),
                                     drange=drange_viz, grid_size=image_grid_size)

            # Save network snapshot every N ticks.
            if cur_tick % network_snapshot_ticks == 0 or cur_nimg >= total_kimg * 1000:
                misc.save_pkl((G, D, Gs),
                              os.path.join(result_subdir, 'network-snapshot-%06d.pkl' % (cur_nimg / 1000)))

    # Write final results.
    misc.save_pkl((G, D, Gs), os.path.join(result_subdir, 'network-final.pkl'))
    training_set.close()
    print 'Done.'
    with open(os.path.join(result_subdir, '_training-done.txt'), 'wt'):
        pass
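For reference, the snippet below is not part of the original code; it isolates the LOD arithmetic from the top of the training loop so the progressive-growing schedule can be inspected on its own. The helper name lod_schedule and the demo constants (resolution_log2 = 10, i.e. a 1024x1024 generator, and lod_initial_resolution = 4 giving initial_lod = 6) are illustrative assumptions; the formula itself is copied from the loop above.

# Minimal sketch of the progressive-growing LOD schedule (illustrative only).
import numpy as np

def lod_schedule(cur_nimg, initial_lod=6, lod_training_kimg=400, lod_transition_kimg=400):
    # Mirrors the cur_lod computation at the top of the training loop.
    cur_lod = float(initial_lod)
    if lod_training_kimg or lod_transition_kimg:
        tlod = (cur_nimg / 1000.0) / (lod_training_kimg + lod_transition_kimg)
        cur_lod -= np.floor(tlod)
        if lod_transition_kimg:
            cur_lod -= max(1.0 + (np.fmod(tlod, 1.0) - 1.0) *
                           (lod_training_kimg + lod_transition_kimg) / lod_transition_kimg, 0.0)
        cur_lod = max(cur_lod, 0.0)
    return cur_lod

# With resolution_log2 = 10 and initial_lod = 6, training starts at 2**(10-6) = 16x16;
# each 400 kimg training phase plus 400 kimg transition phase removes one LOD level.
for kimg in (0, 400, 600, 800, 1600):
    lod = lod_schedule(kimg * 1000)
    print 'kimg %-5d lod %.2f res %d' % (kimg, lod, 2 ** (10 - int(np.floor(lod))))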
def train_gan(separate_funcs=False,
              D_training_repeats=1,
              G_learning_rate_max=0.0010,
              D_learning_rate_max=0.0010,
              G_smoothing=0.999,
              adam_beta1=0.0,
              adam_beta2=0.99,
              adam_epsilon=1e-8,
              minibatch_default=16,
              minibatch_overrides={},
              rampup_kimg=40 / speed_factor,  # speed_factor is assumed to be a module-level constant in this modified version
              rampdown_kimg=0,
              lod_initial_resolution=4,
              lod_training_kimg=400 / speed_factor,
              lod_transition_kimg=400 / speed_factor,
              #lod_training_kimg = 40,
              #lod_transition_kimg = 40,
              total_kimg=10000 / speed_factor,
              dequantize_reals=False,
              gdrop_beta=0.9,
              gdrop_lim=0.5,
              gdrop_coef=0.2,
              gdrop_exp=2.0,
              drange_net=[-1, 1],
              drange_viz=[-1, 1],
              image_grid_size=None,
              #tick_kimg_default = 1,
              tick_kimg_default=50 / speed_factor,
              tick_kimg_overrides={32: 20, 64: 10, 128: 10, 256: 5, 512: 2, 1024: 1},
              image_snapshot_ticks=4,
              network_snapshot_ticks=40,
              image_grid_type='default',
              #resume_network_pkl = '006-celeb128-progressive-growing/network-snapshot-002009.pkl',
              resume_network_pkl=None,
              resume_kimg=0,
              resume_time=0.0):

    # Load dataset and build networks.
    training_set, drange_orig = load_dataset()
    print "*************** test the format of dataset ***************"
    print training_set
    print drange_orig
    # training_set is the object produced by the dataset module after parsing the h5 file;
    # drange_orig is training_set.get_dynamic_range().
    if resume_network_pkl:
        print 'Resuming', resume_network_pkl
        G, D, _ = misc.load_pkl(os.path.join(config.result_dir, resume_network_pkl))
    else:
        G = network.Network(num_channels=training_set.shape[1],
                            resolution=training_set.shape[2],
                            label_size=training_set.labels.shape[1],
                            **config.G)
        D = network.Network(num_channels=training_set.shape[1],
                            resolution=training_set.shape[2],
                            label_size=training_set.labels.shape[1],
                            **config.D)
    Gs = G.create_temporally_smoothed_version(beta=G_smoothing, explicit_updates=True)
    # G and D can either be restored from a pkl via misc or built from scratch by the network module.
    print G
    print D
    #misc.print_network_topology_info(G.output_layers)
    #misc.print_network_topology_info(D.output_layers)

    # Setup snapshot image grid.
    # Configure the grid layout used for the intermediate snapshot images.
    if image_grid_type == 'default':
        if image_grid_size is None:
            w, h = G.output_shape[3], G.output_shape[2]
            image_grid_size = np.clip(1920 / w, 3, 16), np.clip(1080 / h, 2, 16)
        example_real_images, snapshot_fake_labels = training_set.get_random_minibatch(
            np.prod(image_grid_size), labels=True)
        snapshot_fake_latents = random_latents(np.prod(image_grid_size), G.input_shape)
    else:
        raise ValueError('Invalid image_grid_type', image_grid_type)

    # Theano input variables and compile generation func.
    print 'Setting up Theano...'
    real_images_var = T.TensorType('float32', [False] * len(D.input_shape))('real_images_var')
    # <class 'theano.tensor.var.TensorVariable'>
    # print type(real_images_var), real_images_var
    real_labels_var = T.TensorType('float32', [False] * len(training_set.labels.shape))('real_labels_var')
    fake_latents_var = T.TensorType('float32', [False] * len(G.input_shape))('fake_latents_var')
    fake_labels_var = T.TensorType('float32', [False] * len(training_set.labels.shape))('fake_labels_var')
    # Everything suffixed with _var is an input tensor.
    G_lrate = theano.shared(np.float32(0.0))
    D_lrate = theano.shared(np.float32(0.0))
    # theano.shared sets an initial value and returns a shared variable object.
    gen_fn = theano.function([fake_latents_var, fake_labels_var],
                             Gs.eval_nd(fake_latents_var, fake_labels_var, ignore_unused_inputs=True),
                             on_unused_input='ignore')
    # gen_fn is the compiled generation function: its inputs are [fake_latents_var, fake_labels_var]
    # and its output is Gs.eval_nd(fake_latents_var, fake_labels_var, ignore_unused_inputs=True).

    # Misc init.
    # Read the final output resolution.
    resolution_log2 = int(np.round(np.log2(G.output_shape[2])))
    # lod = level of detail.
    initial_lod = max(resolution_log2 - int(np.round(np.log2(lod_initial_resolution))), 0)
    cur_lod = 0.0
    min_lod, max_lod = -1.0, -2.0
    fake_score_avg = 0.0

    # Save example images.
    snapshot_fake_images = gen_fn(snapshot_fake_latents, snapshot_fake_labels)
    result_subdir = misc.create_result_subdir(config.result_dir, config.run_desc)
    misc.save_image_grid(example_real_images, os.path.join(result_subdir, 'reals.png'),
                         drange=drange_orig, grid_size=image_grid_size)
    misc.save_image_grid(snapshot_fake_images, os.path.join(result_subdir, 'fakes%06d.png' % 0),
                         drange=drange_viz, grid_size=image_grid_size)

    # Training loop.
    # This is the main training entry point. Note that execution never leaves the outermost
    # while loop during training, so resolution switches and similar operations all happen inside it.
    # cur_nimg: number of images seen so far.
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    tick_start_time = time.time()
    tick_train_out = []
    train_start_time = tick_start_time - resume_time
    while cur_nimg < total_kimg * 1000:

        # Calculate current LOD.
        cur_lod = initial_lod
        if lod_training_kimg or lod_transition_kimg:
            tlod = (cur_nimg / (1000.0 / speed_factor)) / (lod_training_kimg + lod_transition_kimg)
            cur_lod -= np.floor(tlod)
            if lod_transition_kimg:
                cur_lod -= max(1.0 + (np.fmod(tlod, 1.0) - 1.0) *
                               (lod_training_kimg + lod_transition_kimg) / lod_transition_kimg, 0.0)
            cur_lod = max(cur_lod, 0.0)

        # Look up resolution-dependent parameters.
        cur_res = 2 ** (resolution_log2 - int(np.floor(cur_lod)))  # current resolution
        minibatch_size = minibatch_overrides.get(cur_res, minibatch_default)
        tick_duration_kimg = tick_kimg_overrides.get(cur_res, tick_kimg_default)

        # Update network config.
        lrate_coef = misc.rampup(cur_nimg / 1000.0, rampup_kimg)
        lrate_coef *= misc.rampdown_linear(cur_nimg / 1000.0, total_kimg, rampdown_kimg)
        G_lrate.set_value(np.float32(lrate_coef * G_learning_rate_max))
        D_lrate.set_value(np.float32(lrate_coef * D_learning_rate_max))
        if hasattr(G, 'cur_lod'): G.cur_lod.set_value(np.float32(cur_lod))
        if hasattr(D, 'cur_lod'): D.cur_lod.set_value(np.float32(cur_lod))

        # Setup training func for current LOD.
        new_min_lod, new_max_lod = int(np.floor(cur_lod)), int(np.ceil(cur_lod))
        #print " cur_lod %f\n min_lod %f\n new_min_lod %f\n max_lod %f\n new_max_lod %f\n" % (cur_lod, min_lod, new_min_lod, max_lod, new_max_lod)
        if min_lod != new_min_lod or max_lod != new_max_lod:
            print 'Compiling training funcs...'
            min_lod, max_lod = new_min_lod, new_max_lod

            # Pre-process reals.
            real_images_expr = real_images_var
            if dequantize_reals:
                rnd = theano.sandbox.rng_mrg.MRG_RandomStreams(
                    lasagne.random.get_rng().randint(1, 2147462579))
                epsilon_noise = rnd.uniform(size=real_images_expr.shape, low=-0.5, high=0.5, dtype='float32')
                real_images_expr = T.cast(real_images_expr, 'float32') + epsilon_noise  # match original implementation of Improved Wasserstein
            real_images_expr = misc.adjust_dynamic_range(real_images_expr, drange_orig, drange_net)
            if min_lod > 0:  # compensate for shrink_based_on_lod
                real_images_expr = T.extra_ops.repeat(real_images_expr, 2**min_lod, axis=2)
                real_images_expr = T.extra_ops.repeat(real_images_expr, 2**min_lod, axis=3)

            # Optimize loss.
            G_loss, D_loss, real_scores_out, fake_scores_out = evaluate_loss(
                G, D, min_lod, max_lod, real_images_expr, real_labels_var,
                fake_latents_var, fake_labels_var, **config.loss)
            G_updates = adam(G_loss, G.trainable_params(), learning_rate=G_lrate,
                             beta1=adam_beta1, beta2=adam_beta2, epsilon=adam_epsilon).items()
            D_updates = adam(D_loss, D.trainable_params(), learning_rate=D_lrate,
                             beta1=adam_beta1, beta2=adam_beta2, epsilon=adam_epsilon).items()
            D_train_fn = theano.function(
                [real_images_var, real_labels_var, fake_latents_var, fake_labels_var],
                [G_loss, D_loss, real_scores_out, fake_scores_out],
                updates=D_updates,
                on_unused_input='ignore')
            G_train_fn = theano.function(
                [fake_latents_var, fake_labels_var], [],
                updates=G_updates + Gs.updates,
                on_unused_input='ignore')

        # Invoke training funcs.
        for idx in xrange(D_training_repeats):
            mb_reals, mb_labels = training_set.get_random_minibatch(
                minibatch_size, lod=cur_lod, shrink_based_on_lod=True, labels=True)
            print "******* test minibatch"
            print "mb_reals"
            print idx, D_training_repeats
            print mb_reals.shape, mb_labels.shape
            #print mb_reals
            print "mb_labels"
            #print mb_labels
            mb_train_out = D_train_fn(mb_reals, mb_labels,
                                      random_latents(minibatch_size, G.input_shape),
                                      random_labels(minibatch_size, training_set))
            cur_nimg += minibatch_size
            tick_train_out.append(mb_train_out)
        G_train_fn(random_latents(minibatch_size, G.input_shape),
                   random_labels(minibatch_size, training_set))

        # Fade in D noise if we're close to becoming unstable.
        fake_score_cur = np.clip(np.mean(mb_train_out[1]), 0.0, 1.0)
        fake_score_avg = fake_score_avg * gdrop_beta + fake_score_cur * (1.0 - gdrop_beta)
        gdrop_strength = gdrop_coef * (max(fake_score_avg - gdrop_lim, 0.0) ** gdrop_exp)
        if hasattr(D, 'gdrop_strength'): D.gdrop_strength.set_value(np.float32(gdrop_strength))

        # Perform maintenance operations once per tick.
        if cur_nimg >= tick_start_nimg + tick_duration_kimg * 1000 or cur_nimg >= total_kimg * 1000:
            cur_tick += 1
            cur_time = time.time()
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = cur_time - tick_start_time
            tick_start_time = cur_time
            tick_train_avg = tuple(np.mean(np.concatenate([np.asarray(v).flatten() for v in vals]))
                                   for vals in zip(*tick_train_out))
            tick_train_out = []

            # Print progress.
            print 'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-9.1f sec/kimg %-6.1f Dgdrop %-8.4f Gloss %-8.4f Dloss %-8.4f Dreal %-8.4f Dfake %-8.4f' % (
                (cur_tick, cur_nimg / 1000.0, cur_lod, minibatch_size,
                 misc.format_time(cur_time - train_start_time), tick_time,
                 tick_time / tick_kimg, gdrop_strength) + tick_train_avg)

            # Visualize generated images.
            if cur_tick % image_snapshot_ticks == 0 or cur_nimg >= total_kimg * 1000:
                snapshot_fake_images = gen_fn(snapshot_fake_latents, snapshot_fake_labels)
                misc.save_image_grid(snapshot_fake_images,
                                     os.path.join(result_subdir, 'fakes%06d.png' % (cur_nimg / 1000)),
                                     drange=drange_viz, grid_size=image_grid_size)

            # Save network snapshot every N ticks.
            if cur_tick % network_snapshot_ticks == 0 or cur_nimg >= total_kimg * 1000:
                misc.save_pkl((G, D, Gs),
                              os.path.join(result_subdir, 'network-snapshot-%06d.pkl' % (cur_nimg / 1000)))

    # Write final results.
    misc.save_pkl((G, D, Gs), os.path.join(result_subdir, 'network-final.pkl'))
    training_set.close()
    print 'Done.'
    with open(os.path.join(result_subdir, '_training-done.txt'), 'wt'):
        pass
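The adaptive noise ("gdrop") schedule used in both versions can also be simulated in isolation. The sketch below is not from the repository; it simply replays the fake_score_avg moving average and the gdrop_strength formula from the training loop on a synthetic stream of discriminator fake scores, to show that the noise only fades in once the running average exceeds gdrop_lim.

# Minimal sketch of the adaptive gdrop schedule (illustrative only).
import numpy as np

gdrop_beta, gdrop_lim, gdrop_coef, gdrop_exp = 0.9, 0.5, 0.2, 2.0
fake_score_avg = 0.0
for step in xrange(1, 13):
    fake_score_cur = np.clip(0.9, 0.0, 1.0)  # pretend D keeps scoring fakes at 0.9
    fake_score_avg = fake_score_avg * gdrop_beta + fake_score_cur * (1.0 - gdrop_beta)
    gdrop_strength = gdrop_coef * (max(fake_score_avg - gdrop_lim, 0.0) ** gdrop_exp)
    # The EMA crosses gdrop_lim = 0.5 around the eighth step, after which
    # gdrop_strength grows quadratically with the excess.
    print 'step %-3d score %.2f avg %.3f gdrop_strength %.6f' % (step, fake_score_cur, fake_score_avg, gdrop_strength)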