def setup_snapshot_image_grid(training_set, drange_net, grid_size=None,
                              size='1080p',      # '1080p' = to be viewed on 1080p display, '4k' = to be viewed on 4k display.
                              layout='random'):  # 'random' = grid contents are selected randomly, 'row_per_class' = each row corresponds to one class label.
    """Pick the real images (and their labels) shown in the snapshot grid.

    Args:
        training_set: dataset object. This function reads ``shape``,
            ``label_size``, ``label_dtype`` and ``dynamic_range``, and calls
            ``get_minibatch_np(1)``. ``shape[2]`` / ``shape[1]`` are used as
            image width / height when sizing the grid.
        drange_net: value range the returned reals are mapped into.
        grid_size: optional ``(width, height)`` of the grid in cells; when
            ``None`` it is derived from ``size``.
        size: ``'1080p'`` or ``'4k'``, the display the grid is sized for.
        layout: ``'random'`` or ``'row_per_class'``. In the latter case (only
            when the dataset has labels) grid row ``y`` only accepts samples
            whose label column ``y % label_size`` is non-zero.

    Returns:
        ``((gw, gh), reals, labels)`` where ``reals`` is float32 with shape
        ``[gw * gh] + list(training_set.shape)`` and ``labels`` has shape
        ``[gw * gh, training_set.label_size]``.

    Raises:
        ValueError: if ``grid_size`` is None and ``size`` is not recognised.
    """
    # Select size.
    if grid_size is None:
        if size == '1080p':
            gw = np.clip(1920 // training_set.shape[2], 3, 32)
            gh = np.clip(1080 // training_set.shape[1], 2, 32)
        elif size == '4k':
            gw = np.clip(3840 // training_set.shape[2], 7, 32)
            gh = np.clip(2160 // training_set.shape[1], 4, 32)
        else:
            # Fail loudly instead of hitting an unbound gw/gh below.
            raise ValueError('Invalid size: %r' % (size,))
    else:
        gw = grid_size[0]
        gh = grid_size[1]

    # Fill in reals and labels.
    # list() so a tuple-valued shape works as well as a list-valued one.
    reals = np.zeros([gw * gh] + list(training_set.shape), dtype=np.float32)
    labels = np.zeros([gw * gh, training_set.label_size], dtype=training_set.label_dtype)
    for idx in range(gw * gh):
        y = idx // gw  # grid row of this cell
        while True:
            real, label = training_set.get_minibatch_np(1)
            real = real.astype(np.float32)
            real = misc.adjust_dynamic_range(real, training_set.dynamic_range, drange_net)
            if layout == 'row_per_class' and training_set.label_size > 0:
                # Keep drawing until the sample belongs to this row's class.
                if label[0, y % training_set.label_size] == 0.0:
                    continue
            reals[idx] = real[0]
            labels[idx] = label[0]
            break
    return (gw, gh), reals, labels
def process_reals(x, lod, mirror_augment, drange_data, drange_net):
    """Prepare a batch of real images for the networks.

    Steps, each under its own TF name scope inside 'ProcessReals':
      1. cast to float32 and remap values from ``drange_data`` to ``drange_net``;
      2. if ``mirror_augment``, reverse roughly half of the samples along axis 3;
      3. blend each image towards its 2x2-average-pooled (then re-expanded)
         version by the fractional part of ``lod``;
      4. repeat every pixel ``2**floor(lod)`` times along axes 2 and 3.

    The indexing assumes a rank-4 batch. NCHW layout (axis 3 = width) is an
    assumption to confirm against the data pipeline.

    Returns:
        The processed image tensor.
    """
    with tf.name_scope('ProcessReals'):
        with tf.name_scope('DynamicRange'):
            x = misc.adjust_dynamic_range(tf.cast(x, tf.float32), drange_data, drange_net)

        if mirror_augment:
            with tf.name_scope('MirrorAugment'):
                dims = tf.shape(x)
                # One uniform draw per sample, broadcast to the full image shape.
                coin = tf.tile(tf.random_uniform([dims[0], 1, 1, 1], 0.0, 1.0),
                               [1, dims[1], dims[2], dims[3]])
                x = tf.where(coin < 0.5, x, tf.reverse(x, axis=[3]))

        with tf.name_scope('FadeLOD'):  # Smooth crossfade between consecutive levels-of-detail.
            dims = tf.shape(x)
            pooled = tf.reshape(x, [-1, dims[1], dims[2] // 2, 2, dims[3] // 2, 2])
            pooled = tf.reduce_mean(pooled, axis=[3, 5], keepdims=True)
            coarse = tf.reshape(tf.tile(pooled, [1, 1, 1, 2, 1, 2]),
                                [-1, dims[1], dims[2], dims[3]])
            x = tfutil.lerp(x, coarse, lod - tf.floor(lod))

        with tf.name_scope('UpscaleLOD'):  # Upscale to match the expected input/output size of the networks.
            dims = tf.shape(x)
            factor = tf.cast(2 ** tf.floor(lod), tf.int32)
            expanded = tf.reshape(x, [-1, dims[1], dims[2], 1, dims[3], 1])
            expanded = tf.tile(expanded, [1, 1, 1, factor, 1, factor])
            x = tf.reshape(expanded, [-1, dims[1], dims[2] * factor, dims[3] * factor])
        return x
def __init__(self, filename=None, model=None):
    """Load a generator from a pickle and prepare it for image generation.

    Args:
        filename: path passed to ``misc.load_pkl``. The third element of the
            unpickled tuple is used as the generator ``G``.
        model: not supported. When given, the constructor returns immediately
            without initialising any attributes.
    """
    if model is not None:
        # error: constructing from an in-memory model is not implemented.
        # NOTE(review): this returns silently and leaves ``self.net`` unset;
        # consider raising once callers have been checked.
        return
    _, _, G = misc.load_pkl(filename)
    self.net = Net()
    self.net.G = G
    self.have_compiled = False
    # experiment
    num_example_latents = 10
    self.net.example_latents = train.random_latents(
        num_example_latents, self.net.G.input_shape)
    # Experimental: the latents double as the label input.
    self.net.example_labels = self.net.example_latents
    # Symbolic inputs: non-broadcastable float32 tensors with the latents' rank
    # (the labels variable mirrors the latents, matching example_labels).
    self.net.latents_var = T.TensorType(
        'float32', [False] * len(self.net.example_latents.shape))('latents_var')
    self.net.labels_var = T.TensorType(
        'float32', [False] * len(self.net.example_latents.shape))('labels_var')
    # Symbolic generator output, remapped from [-1, 1] to [0, 1].
    self.net.images_expr = self.net.G.eval(self.net.latents_var,
                                           self.net.labels_var,
                                           ignore_unused_inputs=True)
    self.net.images_expr = misc.adjust_dynamic_range(
        self.net.images_expr, [-1, 1], [0, 1])
    train.imgapi_compile_gen_fn(self.net)
    self.invert_models = def_invert_models(self.net, layer='conv4', alpha=0.002)
def imgapi_load_net(run_id, snapshot=None, random_seed=1000, num_example_latents=1000, load_dataset=True, compile_gen_fn=True):
    """Locate a training run's network pickle and build a generation API object.

    Args:
        run_id: run identifier passed to ``misc.locate_result_subdir``.
        snapshot: snapshot selector passed to ``misc.locate_network_pkl``.
        random_seed: seed handed to ``np.random.seed`` before drawing the
            example latents (this reseeds NumPy's global RNG).
        num_example_latents: number of example latents (and zero-width
            example labels) to create.
        load_dataset: when true, ``imgapi_load_dataset`` is called on the result.
        compile_gen_fn: when true, ``imgapi_compile_gen_fn`` is called on the result.

    Returns:
        An attribute container holding the generator ``G``, example
        latents/labels, the symbolic inputs, ``lod`` and ``images_expr``
        (generator output remapped from [-1, 1] to ``dynamic_range``).
    """
    class Net:
        pass

    api = Net()
    api.result_subdir = misc.locate_result_subdir(run_id)
    api.network_pkl = misc.locate_network_pkl(api.result_subdir, snapshot)
    _, _, api.G = misc.load_pkl(api.network_pkl)

    # Generate example latents and labels.
    np.random.seed(random_seed)
    api.example_latents = random_latents(num_example_latents, api.G.input_shape)
    api.example_labels = np.zeros((num_example_latents, 0), dtype=np.float32)
    api.dynamic_range = [0, 255]
    if load_dataset:
        imgapi_load_dataset(api)

    # Compile Theano func.
    def _symbolic_input(example, name):
        # Non-broadcastable float32 tensor with the same rank as `example`.
        return T.TensorType('float32', [False] * len(example.shape))(name)

    api.latents_var = _symbolic_input(api.example_latents, 'latents_var')
    api.labels_var = _symbolic_input(api.example_labels, 'labels_var')

    eval_kwargs = {'ignore_unused_inputs': True}
    if hasattr(api.G, 'cur_lod'):
        # Evaluate at the generator's current level of detail.
        api.lod = api.G.cur_lod.get_value()
        eval_kwargs['min_lod'] = api.lod
        eval_kwargs['max_lod'] = api.lod
    else:
        api.lod = 0.0
    api.images_expr = api.G.eval(api.latents_var, api.labels_var, **eval_kwargs)
    api.images_expr = misc.adjust_dynamic_range(api.images_expr, [-1, 1], api.dynamic_range)

    if compile_gen_fn:
        imgapi_compile_gen_fn(api)
    return api
def process_reals(x, lod, mirror_augment, drange_data, drange_net):
    """Prepare a batch of real images for the networks.

    Steps, each under its own TF name scope inside 'ProcessReals':
      1. cast to float32 and remap values from ``drange_data`` to ``drange_net``;
      2. if ``mirror_augment``, reverse roughly half of the samples along axis 3;
      3. blend each image towards its 2x2-average-pooled (then re-expanded)
         version by the fractional part of ``lod``;
      4. repeat every pixel ``2**floor(lod)`` times along axes 2 and 3.

    The indexing assumes a rank-4 batch. NCHW layout (axis 3 = width) is an
    assumption to confirm against the data pipeline.

    Returns:
        The processed image tensor.
    """
    with tf.name_scope('ProcessReals'):
        with tf.name_scope('DynamicRange'):
            x = misc.adjust_dynamic_range(tf.cast(x, tf.float32), drange_data, drange_net)

        if mirror_augment:
            with tf.name_scope('MirrorAugment'):
                dims = tf.shape(x)
                # One uniform draw per sample, broadcast to the full image shape.
                coin = tf.tile(tf.random_uniform([dims[0], 1, 1, 1], 0.0, 1.0),
                               [1, dims[1], dims[2], dims[3]])
                x = tf.where(coin < 0.5, x, tf.reverse(x, axis=[3]))

        with tf.name_scope('FadeLOD'):  # Smooth crossfade between consecutive levels-of-detail.
            dims = tf.shape(x)
            pooled = tf.reshape(x, [-1, dims[1], dims[2] // 2, 2, dims[3] // 2, 2])
            pooled = tf.reduce_mean(pooled, axis=[3, 5], keepdims=True)
            coarse = tf.reshape(tf.tile(pooled, [1, 1, 1, 2, 1, 2]),
                                [-1, dims[1], dims[2], dims[3]])
            x = tfutil.lerp(x, coarse, lod - tf.floor(lod))

        with tf.name_scope('UpscaleLOD'):  # Upscale to match the expected input/output size of the networks.
            dims = tf.shape(x)
            factor = tf.cast(2 ** tf.floor(lod), tf.int32)
            expanded = tf.reshape(x, [-1, dims[1], dims[2], 1, dims[3], 1])
            expanded = tf.tile(expanded, [1, 1, 1, factor, 1, factor])
            x = tf.reshape(expanded, [-1, dims[1], dims[2] * factor, dims[3] * factor])
        return x
def train(self):
    """Run the training loop until ``1000 * self.total_kimg`` images are seen.

    Each iteration draws one real batch, runs ``self.D_repeat`` discriminator
    updates and one generator update, and logs both losses to stdout and
    ``self.writer``. Every ``save_kimg_tick[phase]`` ticks (default 1) it saves
    image grids of generated and real samples. It then calls
    ``self.sche.update()`` and, when the scheduler reports a new phase,
    checkpoints and extends both networks.
    """
    while self.sche.cur_img < 1000 * self.total_kimg:
        real = self.dataloader.get_batch().cuda()
        # Adapt the real batch to the current lod / phase.
        real = propress_real(real, self.sche.lod, self.sche.phase)
        for repeat in range(self.D_repeat):
            # NOTE(review): the same `real` batch is reused for every D step.
            loss_d = self.update_D(real, self.sche.batchsize)
            #self.Gs.update_Gs(self.G)
            self.sche.cur_img += self.sche.batchsize
        loss_g = self.update_G(self.sche.batchsize)
        self.sche.cur_img += self.sche.batchsize
        # NOTE(review): `loss_d` is unbound here if self.D_repeat < 1.
        print("loss_d:%.2f loss_g:%.2f" % (loss_d, loss_g))
        self.writer.add_scalar('Loss/loss_d', loss_d, self.sche.cur_img)
        self.writer.add_scalar('Loss/loss_g', loss_g, self.sche.cur_img)
        # Save image grids every `save_kimg_tick[phase]` ticks (default 1).
        if self.sche.tick % self.sche.save_kimg_tick.get(
                self.sche.phase, 1) == 0:
            with torch.no_grad():
                img = self.G(self.z, self.sche.lod, self.sche.phase,
                             self.sche.TRANSITION)
                img = upscale_phase(img, self.sche.phase,
                                    self.sche.max_resolution)
            os.makedirs(os.path.join(self.tag_dir, "images"), exist_ok=True)
            # Generated samples, remapped from [-1, 1] to [0, 1] and clamped.
            save_image_grid(
                adjust_dynamic_range(img, [-1, 1], [0, 1]).clamp(0, 1),
                os.path.join(
                    self.tag_dir, "images", "phase%d_cur%d.jpg" %
                    (self.sche.phase, self.sche.last_cur_phase_kimg)))
            # First 24 reals of the current batch, for side-by-side comparison.
            save_image_grid(
                adjust_dynamic_range(
                    upscale_phase(real[:24], self.sche.phase,
                                  self.sche.max_resolution), [-1, 1],
                    [0, 1]).clamp(0, 1),
                os.path.join(
                    self.tag_dir, "images", "phase%d_cur%d_real.jpg" %
                    (self.sche.phase, self.sche.last_cur_phase_kimg)))
        self.sche.update()
        # On entering a new phase: save a checkpoint, add the phase's
        # to-RGB / from-RGB layers to G / D, then call self.renew().
        if self.sche.NEW_PHASE is True:
            self.save_model(self.sche.phase)
            self.G.Add_To_RGB_Layer(self.sche.phase)
            self.D.Add_From_RGB_Layer(self.sche.phase)
            self.renew()
# Standalone script: load a generator snapshot, generate one image from a
# random latent and save it to "fake4c.png".
import theano
import lasagne
import dataset
import network
from theano import tensor as T
import config
import misc
import numpy as np
import scipy.ndimage

# The pickle holds a 3-tuple; only the third element is kept as the generator.
_, _, G = misc.load_pkl("network-snapshot-009041.pkl")


class Net:
    pass  # plain attribute container for the generator and its inputs


net = Net()
net.G = G
import train

num_example_latents = 10
net.example_latents = train.random_latents(num_example_latents, net.G.input_shape)
# Experimental: the latents are also fed as the labels input.
net.example_labels = net.example_latents
# Symbolic inputs: non-broadcastable float32 tensors with the latents' rank.
net.latents_var = T.TensorType('float32', [False] * len(net.example_latents.shape))('latents_var')
net.labels_var = T.TensorType('float32', [False] * len(net.example_latents.shape)) ('labels_var')
print("HIYA", net.example_latents[:1].shape, net.example_labels[:1].shape)
# Symbolic generator output, remapped from [-1, 1] to [0, 1].
net.images_expr = net.G.eval(net.latents_var, net.labels_var, ignore_unused_inputs=True)
net.images_expr = misc.adjust_dynamic_range(net.images_expr, [-1,1], [0,1])
train.imgapi_compile_gen_fn(net)
# Generate a single image from the first example latent/label and save it.
images = net.gen_fn(net.example_latents[:1], net.example_labels[:1])
misc.save_image(images[0], "fake4c.png", drange=[0,1])
def train_gan(separate_funcs=False,
              D_training_repeats=1,
              G_learning_rate_max=0.0010,
              D_learning_rate_max=0.0010,
              G_smoothing=0.999,
              adam_beta1=0.0,
              adam_beta2=0.99,
              adam_epsilon=1e-8,
              minibatch_default=16,
              minibatch_overrides={},
              rampup_kimg=40,
              rampdown_kimg=0,
              lod_initial_resolution=4,
              lod_training_kimg=400,
              lod_transition_kimg=400,
              total_kimg=10000,
              dequantize_reals=False,
              gdrop_beta=0.9,
              gdrop_lim=0.5,
              gdrop_coef=0.2,
              gdrop_exp=2.0,
              drange_net=[-1, 1],
              drange_viz=[-1, 1],
              image_grid_size=None,
              tick_kimg_default=50,
              tick_kimg_overrides={
                  32: 20,
                  64: 10,
                  128: 10,
                  256: 5,
                  512: 2,
                  1024: 1
              },
              image_snapshot_ticks=4,
              network_snapshot_ticks=40,
              image_grid_type='default',
              resume_network_pkl=None,
              resume_kimg=0.0,
              resume_time=0.0):
    """Train a progressive GAN with Theano/Lasagne (Python 2 code).

    Builds (or resumes) G and D plus a temporally smoothed copy Gs of G, and
    writes initial real/fake image grids. It then alternates D/G updates
    while the level of detail (LOD) is lowered from ``initial_lod`` towards
    0. Once per tick it prints progress. Every ``image_snapshot_ticks`` /
    ``network_snapshot_ticks`` ticks it saves a fake-image grid / a
    ``(G, D, Gs)`` pickle. The final networks go to ``network-final.pkl``.

    Selected args:
        separate_funcs: compile separate D and G training functions so D can
            be updated ``D_training_repeats`` times per G update. Otherwise a
            single joint function is used and ``D_training_repeats`` must be 1.
        lod_training_kimg / lod_transition_kimg: thousands of images spent at
            a fixed LOD and fading to the next one, respectively.
        rampup_kimg / rampdown_kimg: learning-rate ramp lengths passed to
            ``misc.rampup`` / ``misc.rampdown_linear``.
        gdrop_*: parameters of the D-noise strength updated every step.
        image_grid_type: ``'default'`` (random reals/labels) or
            ``'category'`` (grid column x shows class x).
        resume_network_pkl / resume_kimg / resume_time: resume from a pickle
            under ``config.result_dir`` with the given progress counters.
    """
    # Load dataset and build networks.
    training_set, drange_orig = load_dataset()
    if resume_network_pkl:
        print 'Resuming', resume_network_pkl
        G, D, _ = misc.load_pkl(
            os.path.join(config.result_dir, resume_network_pkl))
    else:
        G = network.Network(num_channels=training_set.shape[1],
                            resolution=training_set.shape[2],
                            label_size=training_set.labels.shape[1],
                            **config.G)
        D = network.Network(num_channels=training_set.shape[1],
                            resolution=training_set.shape[2],
                            label_size=training_set.labels.shape[1],
                            **config.D)
    # Gs is rebuilt from G even when resuming (the pickled Gs is discarded).
    Gs = G.create_temporally_smoothed_version(beta=G_smoothing,
                                              explicit_updates=True)
    misc.print_network_topology_info(G.output_layers)
    misc.print_network_topology_info(D.output_layers)

    # Setup snapshot image grid.
    if image_grid_type == 'default':
        if image_grid_size is None:
            w, h = G.output_shape[3], G.output_shape[2]
            image_grid_size = np.clip(1920 / w, 3, 16), np.clip(1080 / h, 2, 16)
        example_real_images, snapshot_fake_labels = training_set.get_random_minibatch(
            np.prod(image_grid_size), labels=True)
        snapshot_fake_latents = random_latents(np.prod(image_grid_size),
                                               G.input_shape)
    elif image_grid_type == 'category':
        W = training_set.labels.shape[1]
        H = W if image_grid_size is None else image_grid_size[1]
        image_grid_size = W, H
        snapshot_fake_latents = random_latents(W * H, G.input_shape)
        snapshot_fake_labels = np.zeros((W * H, W),
                                        dtype=training_set.labels.dtype)
        example_real_images = np.zeros((W * H, ) + training_set.shape[1:],
                                       dtype=training_set.dtype)
        for x in xrange(W):
            # Column x: one-hot label for class x, reals drawn from class x.
            snapshot_fake_labels[x::W, x] = 1.0
            indices = np.arange(
                training_set.shape[0])[training_set.labels[:, x] != 0]
            for y in xrange(H):
                example_real_images[x + y * W] = training_set.h5_lods[0][
                    np.random.choice(indices)]
    else:
        raise ValueError('Invalid image_grid_type', image_grid_type)

    # Theano input variables and compile generation func.
    print 'Setting up Theano...'
    real_images_var = T.TensorType('float32', [False] * len(D.input_shape))('real_images_var')
    real_labels_var = T.TensorType(
        'float32', [False] * len(training_set.labels.shape))('real_labels_var')
    fake_latents_var = T.TensorType('float32', [False] * len(G.input_shape))('fake_latents_var')
    fake_labels_var = T.TensorType(
        'float32', [False] * len(training_set.labels.shape))('fake_labels_var')
    # Learning rates are shared variables, set every iteration from the ramp schedule.
    G_lrate = theano.shared(np.float32(0.0))
    D_lrate = theano.shared(np.float32(0.0))
    # Snapshot images are generated with the smoothed generator Gs.
    gen_fn = theano.function([fake_latents_var, fake_labels_var],
                             Gs.eval_nd(fake_latents_var, fake_labels_var, ignore_unused_inputs=True),
                             on_unused_input='ignore')

    # Misc init.
    resolution_log2 = int(np.round(np.log2(G.output_shape[2])))
    initial_lod = max(
        resolution_log2 - int(np.round(np.log2(lod_initial_resolution))), 0)
    cur_lod = 0.0
    # Impossible LOD bracket so the first iteration always compiles the training funcs.
    min_lod, max_lod = -1.0, -2.0
    fake_score_avg = 0.0

    if config.D.get('mbdisc_kernels', None):
        # Data-dependent init of D's layer init_updates using 500 reals at the initial LOD.
        print 'Initializing minibatch discrimination...'
        if hasattr(D, 'cur_lod'): D.cur_lod.set_value(np.float32(initial_lod))
        D.eval(real_images_var, deterministic=False, init=True)
        init_layers = lasagne.layers.get_all_layers(D.output_layers)
        init_updates = [
            update for layer in init_layers
            for update in getattr(layer, 'init_updates', [])
        ]
        init_fn = theano.function(inputs=[real_images_var],
                                  outputs=None,
                                  updates=init_updates)
        init_reals = training_set.get_random_minibatch(500, lod=initial_lod)
        init_reals = misc.adjust_dynamic_range(init_reals, drange_orig, drange_net)
        init_fn(init_reals)
        del init_reals

    # Save example images.
    snapshot_fake_images = gen_fn(snapshot_fake_latents, snapshot_fake_labels)
    result_subdir = misc.create_result_subdir(config.result_dir, config.run_desc)
    misc.save_image_grid(example_real_images,
                         os.path.join(result_subdir, 'reals.png'),
                         drange=drange_orig,
                         grid_size=image_grid_size)
    misc.save_image_grid(snapshot_fake_images,
                         os.path.join(result_subdir, 'fakes%06d.png' % 0),
                         drange=drange_viz,
                         grid_size=image_grid_size)

    # Training loop.
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    tick_start_time = time.time()
    tick_train_out = []
    train_start_time = tick_start_time - resume_time
    while cur_nimg < total_kimg * 1000:

        # Calculate current LOD.
        # LOD drops by 1 per (training + transition) period; during the
        # transition part of each period it decreases linearly.
        cur_lod = initial_lod
        if lod_training_kimg or lod_transition_kimg:
            tlod = (cur_nimg / 1000.0) / (lod_training_kimg + lod_transition_kimg)
            cur_lod -= np.floor(tlod)
            if lod_transition_kimg:
                cur_lod -= max(
                    1.0 + (np.fmod(tlod, 1.0) - 1.0) *
                    (lod_training_kimg + lod_transition_kimg) /
                    lod_transition_kimg, 0.0)
            cur_lod = max(cur_lod, 0.0)

        # Look up resolution-dependent parameters.
        cur_res = 2**(resolution_log2 - int(np.floor(cur_lod)))
        minibatch_size = minibatch_overrides.get(cur_res, minibatch_default)
        tick_duration_kimg = tick_kimg_overrides.get(cur_res, tick_kimg_default)

        # Update network config.
        lrate_coef = misc.rampup(cur_nimg / 1000.0, rampup_kimg)
        lrate_coef *= misc.rampdown_linear(cur_nimg / 1000.0, total_kimg,
                                           rampdown_kimg)
        G_lrate.set_value(np.float32(lrate_coef * G_learning_rate_max))
        D_lrate.set_value(np.float32(lrate_coef * D_learning_rate_max))
        if hasattr(G, 'cur_lod'): G.cur_lod.set_value(np.float32(cur_lod))
        if hasattr(D, 'cur_lod'): D.cur_lod.set_value(np.float32(cur_lod))

        # Setup training func for current LOD.
        # Recompile only when the integer LOD bracket [floor, ceil] changes.
        new_min_lod, new_max_lod = int(np.floor(cur_lod)), int(
            np.ceil(cur_lod))
        if min_lod != new_min_lod or max_lod != new_max_lod:
            print 'Compiling training funcs...'
            min_lod, max_lod = new_min_lod, new_max_lod

            # Pre-process reals.
            real_images_expr = real_images_var
            if dequantize_reals:
                # Add uniform noise in [-0.5, 0.5) to the raw reals.
                rnd = theano.sandbox.rng_mrg.MRG_RandomStreams(
                    lasagne.random.get_rng().randint(1, 2147462579))
                epsilon_noise = rnd.uniform(size=real_images_expr.shape,
                                            low=-0.5,
                                            high=0.5,
                                            dtype='float32')
                real_images_expr = T.cast(
                    real_images_expr, 'float32'
                ) + epsilon_noise  # match original implementation of Improved Wasserstein
            real_images_expr = misc.adjust_dynamic_range(
                real_images_expr, drange_orig, drange_net)
            if min_lod > 0:  # compensate for shrink_based_on_lod
                real_images_expr = T.extra_ops.repeat(real_images_expr,
                                                      2**min_lod,
                                                      axis=2)
                real_images_expr = T.extra_ops.repeat(real_images_expr,
                                                      2**min_lod,
                                                      axis=3)

            # Optimize loss.
            G_loss, D_loss, real_scores_out, fake_scores_out = evaluate_loss(
                G, D, min_lod, max_lod, real_images_expr, real_labels_var,
                fake_latents_var, fake_labels_var, **config.loss)
            G_updates = adam(G_loss,
                             G.trainable_params(),
                             learning_rate=G_lrate,
                             beta1=adam_beta1,
                             beta2=adam_beta2,
                             epsilon=adam_epsilon).items()
            D_updates = adam(D_loss,
                             D.trainable_params(),
                             learning_rate=D_lrate,
                             beta1=adam_beta1,
                             beta2=adam_beta2,
                             epsilon=adam_epsilon).items()

            # Compile training funcs.
            # Outputs are ordered [G_loss, D_loss, real_scores_out, fake_scores_out].
            if not separate_funcs:
                GD_train_fn = theano.function([
                    real_images_var, real_labels_var, fake_latents_var,
                    fake_labels_var
                ], [G_loss, D_loss, real_scores_out, fake_scores_out],
                                              updates=G_updates + D_updates +
                                              Gs.updates,
                                              on_unused_input='ignore')
            else:
                D_train_fn = theano.function([
                    real_images_var, real_labels_var, fake_latents_var,
                    fake_labels_var
                ], [G_loss, D_loss, real_scores_out, fake_scores_out],
                                             updates=D_updates,
                                             on_unused_input='ignore')
                G_train_fn = theano.function(
                    [fake_latents_var, fake_labels_var], [],
                    updates=G_updates + Gs.updates,
                    on_unused_input='ignore')

        # Invoke training funcs.
        if not separate_funcs:
            assert D_training_repeats == 1
            mb_reals, mb_labels = training_set.get_random_minibatch(
                minibatch_size,
                lod=cur_lod,
                shrink_based_on_lod=True,
                labels=True)
            mb_train_out = GD_train_fn(
                mb_reals, mb_labels,
                random_latents(minibatch_size, G.input_shape),
                random_labels(minibatch_size, training_set))
            cur_nimg += minibatch_size
            tick_train_out.append(mb_train_out)
        else:
            for idx in xrange(D_training_repeats):
                mb_reals, mb_labels = training_set.get_random_minibatch(
                    minibatch_size,
                    lod=cur_lod,
                    shrink_based_on_lod=True,
                    labels=True)
                mb_train_out = D_train_fn(
                    mb_reals, mb_labels,
                    random_latents(minibatch_size, G.input_shape),
                    random_labels(minibatch_size, training_set))
                cur_nimg += minibatch_size
                tick_train_out.append(mb_train_out)
            G_train_fn(random_latents(minibatch_size, G.input_shape),
                       random_labels(minibatch_size, training_set))

        # Fade in D noise if we're close to becoming unstable
        # NOTE(review): mb_train_out[1] is D_loss (see the output order above),
        # not fake_scores_out, despite the variable name -- confirm intent.
        fake_score_cur = np.clip(np.mean(mb_train_out[1]), 0.0, 1.0)
        fake_score_avg = fake_score_avg * gdrop_beta + fake_score_cur * (
            1.0 - gdrop_beta)
        gdrop_strength = gdrop_coef * (max(fake_score_avg - gdrop_lim, 0.0)**
                                       gdrop_exp)
        if hasattr(D, 'gdrop_strength'):
            D.gdrop_strength.set_value(np.float32(gdrop_strength))

        # Perform maintenance operations once per tick.
        if cur_nimg >= tick_start_nimg + tick_duration_kimg * 1000 or cur_nimg >= total_kimg * 1000:
            cur_tick += 1
            cur_time = time.time()
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = cur_time - tick_start_time
            tick_start_time = cur_time
            # Per-output mean over all training calls made during this tick.
            tick_train_avg = tuple(
                np.mean(np.concatenate([np.asarray(v).flatten() for v in vals]))
                for vals in zip(*tick_train_out))
            tick_train_out = []

            # Print progress.
            print 'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-9.1f sec/kimg %-6.1f Dgdrop %-8.4f Gloss %-8.4f Dloss %-8.4f Dreal %-8.4f Dfake %-8.4f' % (
                (cur_tick, cur_nimg / 1000.0, cur_lod, minibatch_size,
                 misc.format_time(cur_time - train_start_time), tick_time,
                 tick_time / tick_kimg, gdrop_strength) + tick_train_avg)

            # Visualize generated images.
            if cur_tick % image_snapshot_ticks == 0 or cur_nimg >= total_kimg * 1000:
                snapshot_fake_images = gen_fn(snapshot_fake_latents,
                                              snapshot_fake_labels)
                misc.save_image_grid(snapshot_fake_images,
                                     os.path.join(
                                         result_subdir,
                                         'fakes%06d.png' % (cur_nimg / 1000)),
                                     drange=drange_viz,
                                     grid_size=image_grid_size)

            # Save network snapshot every N ticks.
            if cur_tick % network_snapshot_ticks == 0 or cur_nimg >= total_kimg * 1000:
                misc.save_pkl(
                    (G, D, Gs),
                    os.path.join(
                        result_subdir,
                        'network-snapshot-%06d.pkl' % (cur_nimg / 1000)))

    # Write final results.
    misc.save_pkl((G, D, Gs), os.path.join(result_subdir, 'network-final.pkl'))
    training_set.close()
    print 'Done.'
    # Empty marker file signalling that training finished.
    with open(os.path.join(result_subdir, '_training-done.txt'), 'wt'):
        pass
def train_gan(separate_funcs=False,
              D_training_repeats=1,
              G_learning_rate_max=0.0010,
              D_learning_rate_max=0.0010,
              G_smoothing=0.999,
              adam_beta1=0.0,
              adam_beta2=0.99,
              adam_epsilon=1e-8,
              minibatch_default=16,
              minibatch_overrides={},
              rampup_kimg=40 / speed_factor,
              rampdown_kimg=0,
              lod_initial_resolution=4,
              lod_training_kimg=400 / speed_factor,
              lod_transition_kimg=400 / speed_factor,
              total_kimg=10000 / speed_factor,
              dequantize_reals=False,
              gdrop_beta=0.9,
              gdrop_lim=0.5,
              gdrop_coef=0.2,
              gdrop_exp=2.0,
              drange_net=[-1, 1],
              drange_viz=[-1, 1],
              image_grid_size=None,
              tick_kimg_default=50 / speed_factor,
              tick_kimg_overrides={
                  32: 20,
                  64: 10,
                  128: 10,
                  256: 5,
                  512: 2,
                  1024: 1
              },
              image_snapshot_ticks=1,
              network_snapshot_ticks=4,
              image_grid_type='default',
              resume_network=None,
              resume_kimg=0.0,
              resume_time=0.0):
    """Train a progressive GAN with Keras on channels-last data.

    Builds G and D (optionally loading weights from ``resume_network`` under
    ``config.result_dir``) and wraps them with ``PG_GAN`` into the training
    models ``G_train`` / ``D_train``. Each iteration then runs
    ``D_training_repeats`` D steps followed by one G step, while the level of
    detail (LOD) is lowered from ``initial_lod`` towards 0 and the learning
    rate follows ``rampup`` / ``rampdown_linear``. Image grids and weight
    snapshots are written under a new result subdirectory.

    Raises:
        ValueError: if ``config.loss['type']`` or ``image_grid_type`` is not
            recognised.
    """
    training_set, drange_orig = load_dataset()

    print("-" * 50)
    print("resume_kimg: %s" % resume_kimg)
    print("-" * 50)

    # Build networks; shape[1] is used as the resolution and shape[3] as the
    # channel count (channels-last data).
    if resume_network:
        print("Resuming weight from:" + resume_network)
    G = Generator(num_channels=training_set.shape[3],
                  resolution=training_set.shape[1],
                  label_size=training_set.labels.shape[1],
                  **config.G)
    D = Discriminator(num_channels=training_set.shape[3],
                      resolution=training_set.shape[1],
                      label_size=training_set.labels.shape[1],
                      **config.D)
    if resume_network:
        G, D = load_GD_weights(G, D,
                               os.path.join(config.result_dir, resume_network),
                               True)
        print("pre-trained weights loaded!")
        print("-" * 50)

    G_train, D_train = PG_GAN(G, D, config.G['latent_size'], 0,
                              training_set.shape[1], training_set.shape[3])
    print(G.summary())
    print(D.summary())

    # Misc init.
    resolution_log2 = int(np.round(np.log2(G.output_shape[2])))
    initial_lod = max(
        resolution_log2 - int(np.round(np.log2(lod_initial_resolution))), 0)
    cur_lod = 0.0
    # Impossible LOD bracket so the first iteration always registers a change.
    min_lod, max_lod = -1.0, -2.0
    fake_score_avg = 0.0

    # Learning rates start at 0 and are set every iteration from the ramp schedule.
    G_opt = optimizers.Adam(lr=0.0,
                            beta_1=adam_beta1,
                            beta_2=adam_beta2,
                            epsilon=adam_epsilon)
    D_opt = optimizers.Adam(lr=0.0,
                            beta_1=adam_beta1,
                            beta_2=adam_beta2,
                            epsilon=adam_epsilon)

    if config.loss['type'] == 'wass':
        G_loss = wasserstein_loss
        D_loss = wasserstein_loss
        D_loss_weight = None  # Keras default: no per-output loss weighting.
    elif config.loss['type'] == 'iwass':
        G_loss = multiple_loss
        D_loss = [mean_loss, 'mse']
        D_loss_weight = [1.0, config.loss['iwass_lambda']]
    else:
        raise ValueError('Invalid loss type', config.loss['type'])

    G.compile(G_opt, loss=G_loss)
    # D is frozen while compiling the combined G training model and unfrozen
    # again for D's own training model.
    D.trainable = False
    G_train.compile(G_opt, loss=G_loss)
    D.trainable = True
    D_train.compile(D_opt, loss=D_loss, loss_weights=D_loss_weight)

    cur_nimg = int(resume_kimg * 1000)
    print("-" * 50)
    print("resume_kimg: %s" % resume_kimg)
    print("current nimg: %s" % cur_nimg)
    print("-" * 50)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    tick_start_time = time.time()
    tick_train_out = []
    train_start_time = tick_start_time - resume_time

    # Setup snapshot image grid.
    if image_grid_type == 'default':
        if image_grid_size is None:
            # Output assumed channels-last (N, H, W, C), like the reals that
            # are repeated along axes 1 and 2 below.
            h, w = G.output_shape[1], G.output_shape[2]
            print("w:%d,h:%d" % (w, h))
            image_grid_size = np.clip(int(1920 // w), 3, 16).astype('int'), np.clip(
                1080 / h, 2, 16).astype('int')
            print("image_grid_size:", image_grid_size)
        example_real_images, snapshot_fake_labels = training_set.get_random_minibatch_channel_last(
            np.prod(image_grid_size), labels=True)
        snapshot_fake_latents = random_latents(np.prod(image_grid_size),
                                               G.input_shape)
    else:
        raise ValueError('Invalid image_grid_type', image_grid_type)

    result_subdir = misc.create_result_subdir(config.result_dir, config.run_desc)
    print("example_real_images.shape:", example_real_images.shape)
    misc.save_image_grid(example_real_images,
                         os.path.join(result_subdir, 'reals.png'),
                         drange=drange_orig,
                         grid_size=image_grid_size)

    snapshot_fake_latents = random_latents(np.prod(image_grid_size),
                                           G.input_shape)
    snapshot_fake_images = G.predict_on_batch(snapshot_fake_latents)
    misc.save_image_grid(snapshot_fake_images,
                         os.path.join(result_subdir,
                                      'fakes%06d.png' % (cur_nimg / 1000)),
                         drange=drange_viz,
                         grid_size=image_grid_size)

    while cur_nimg < total_kimg * 1000:

        # Calculate current LOD.
        cur_lod = initial_lod
        if lod_training_kimg or lod_transition_kimg:
            # NOTE(review): with the default arguments the kimg values are
            # already divided by speed_factor, and cur_nimg is scaled by
            # speed_factor again here, so the schedule is compressed twice --
            # confirm this is intended.
            tlod = (cur_nimg / (1000.0 / speed_factor)) / (lod_training_kimg +
                                                          lod_transition_kimg)
            cur_lod -= np.floor(tlod)
            if lod_transition_kimg:
                cur_lod -= max(
                    1.0 + (np.fmod(tlod, 1.0) - 1.0) *
                    (lod_training_kimg + lod_transition_kimg) /
                    lod_transition_kimg, 0.0)
            cur_lod = max(cur_lod, 0.0)

        # Look up resolution-dependent parameters.
        cur_res = 2**(resolution_log2 - int(np.floor(cur_lod)))
        minibatch_size = minibatch_overrides.get(cur_res, minibatch_default)
        tick_duration_kimg = tick_kimg_overrides.get(cur_res, tick_kimg_default)

        # Update network config.
        lrate_coef = rampup(cur_nimg / 1000.0, rampup_kimg)
        lrate_coef *= rampdown_linear(cur_nimg / 1000.0, total_kimg,
                                      rampdown_kimg)
        K.set_value(G.optimizer.lr, np.float32(lrate_coef * G_learning_rate_max))
        K.set_value(G_train.optimizer.lr,
                    np.float32(lrate_coef * G_learning_rate_max))
        K.set_value(D_train.optimizer.lr,
                    np.float32(lrate_coef * D_learning_rate_max))
        if hasattr(G_train, 'cur_lod'):
            K.set_value(G_train.cur_lod, np.float32(cur_lod))
        if hasattr(D_train, 'cur_lod'):
            K.set_value(D_train.cur_lod, np.float32(cur_lod))

        new_min_lod, new_max_lod = int(np.floor(cur_lod)), int(
            np.ceil(cur_lod))
        if min_lod != new_min_lod or max_lod != new_max_lod:
            min_lod, max_lod = new_min_lod, new_max_lod

        # Train D.
        d_loss = None
        for idx in range(D_training_repeats):
            mb_reals, mb_labels = training_set.get_random_minibatch_channel_last(
                minibatch_size,
                lod=cur_lod,
                shrink_based_on_lod=True,
                labels=True)
            mb_latents = random_latents(minibatch_size, G.input_shape)
            mb_labels_rnd = random_labels(minibatch_size, training_set)
            if min_lod > 0:  # compensate for shrink_based_on_lod
                mb_reals = np.repeat(mb_reals, 2**min_lod, axis=1)
                mb_reals = np.repeat(mb_reals, 2**min_lod, axis=2)
            mb_fakes = G.predict_on_batch([mb_latents])
            # Move reals into the network's value range BEFORE building the
            # gradient-penalty interpolation, so reals and fakes are mixed in
            # the same range (G's output is assumed to be in drange_net).
            mb_reals = misc.adjust_dynamic_range(mb_reals, drange_orig,
                                                 drange_net)
            epsilon = np.random.uniform(0, 1, size=(minibatch_size, 1, 1, 1))
            interpolation = epsilon * mb_reals + (1 - epsilon) * mb_fakes
            d_loss, d_diff, d_norm = D_train.train_on_batch(
                [mb_fakes, mb_reals, interpolation], [
                    np.ones((minibatch_size, 1, 1, 1)),
                    np.ones((minibatch_size, 1))
                ])
            d_score_real = D.predict_on_batch(mb_reals)
            d_score_fake = D.predict_on_batch(mb_fakes)
            # %f, not %d: the mean scores are floats.
            print("%d real score: %f fake score: %f" %
                  (idx, np.mean(d_score_real), np.mean(d_score_fake)))
            print("-" * 50)
            cur_nimg += minibatch_size

        # Train G.
        mb_latents = random_latents(minibatch_size, G.input_shape)
        mb_labels_rnd = random_labels(minibatch_size, training_set)
        g_loss = G_train.train_on_batch([mb_latents], (-1) * np.ones(
            (mb_latents.shape[0], 1, 1, 1)))

        print("%d images [D loss: %f] [G loss: %f]" % (cur_nimg, d_loss, g_loss))
        print("-" * 50)

        # Fade in D noise if we're close to becoming unstable.
        fake_score_cur = np.clip(np.mean(d_loss), 0.0, 1.0)
        fake_score_avg = fake_score_avg * gdrop_beta + fake_score_cur * (
            1.0 - gdrop_beta)
        gdrop_strength = gdrop_coef * (max(fake_score_avg - gdrop_lim, 0.0)**
                                       gdrop_exp)
        if hasattr(D, 'gdrop_strength'):
            K.set_value(D.gdrop_strength, np.float32(gdrop_strength))

        # Perform maintenance operations once per tick.
        if cur_nimg >= tick_start_nimg + tick_duration_kimg * 1000 or cur_nimg >= total_kimg * 1000:
            cur_tick += 1
            cur_time = time.time()
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = cur_time - tick_start_time
            tick_start_time = cur_time
            # NOTE(review): tick_train_out is never appended to in this
            # function, so tick_train_avg is always an empty tuple.
            tick_train_avg = tuple(
                np.mean(np.concatenate([np.asarray(v).flatten() for v in vals]))
                for vals in zip(*tick_train_out))
            tick_train_out = []

            # Visualize generated images.
            if cur_tick % image_snapshot_ticks == 0 or cur_nimg >= total_kimg * 1000:
                snapshot_fake_images = G.predict_on_batch(
                    snapshot_fake_latents)
                misc.save_image_grid(snapshot_fake_images,
                                     os.path.join(
                                         result_subdir,
                                         'fakes%06d.png' % (cur_nimg / 1000)),
                                     drange=drange_viz,
                                     grid_size=image_grid_size)

            # Save network snapshot every N ticks.
            if cur_tick % network_snapshot_ticks == 0 or cur_nimg >= total_kimg * 1000:
                save_GD_weights(
                    G, D,
                    os.path.join(result_subdir,
                                 'network-snapshot-%06d' % (cur_nimg / 1000)))

    # Write final results.
    save_GD(G, D, os.path.join(result_subdir, 'network-final'))
    training_set.close()
    print('Done.')
def process_colors(x):
    """Rescale colour values from [0, 1] to [-1, 1] under the 'ProcessColors' name scope."""
    with tf.name_scope('ProcessColors'):
        rescaled = misc.adjust_dynamic_range(x, [0, 1], [-1, 1])
    return rescaled
def process_masks(x):
    """Cast mask values to float32 and rescale them from [0, 255] to [-1, 1].

    Runs under the 'ProcessMasks' TF name scope.
    """
    with tf.name_scope('ProcessMasks'):
        as_float = tf.cast(x, tf.float32)
        rescaled = misc.adjust_dynamic_range(as_float, [0, 255], [-1, 1])
    return rescaled
def train_gan(
        separate_funcs=False,
        D_training_repeats=1,
        G_learning_rate_max=0.0010,
        D_learning_rate_max=0.0010,
        G_smoothing=0.999,
        adam_beta1=0.0,
        adam_beta2=0.99,
        adam_epsilon=1e-8,
        minibatch_default=16,
        minibatch_overrides={},
        rampup_kimg=40 / speed_factor,
        rampdown_kimg=0,
        lod_initial_resolution=4,
        lod_training_kimg=400 / speed_factor,
        lod_transition_kimg=400 / speed_factor,
        total_kimg=10000 / speed_factor,
        dequantize_reals=False,
        gdrop_beta=0.9,
        gdrop_lim=0.5,
        gdrop_coef=0.2,
        gdrop_exp=2.0,
        drange_net=[-1, 1],
        drange_viz=[-1, 1],
        image_grid_size=None,
        tick_kimg_default=50 / speed_factor,
        tick_kimg_overrides={32: 20, 64: 10, 128: 10, 256: 5, 512: 2, 1024: 1},
        image_snapshot_ticks=1,
        network_snapshot_ticks=4,
        image_grid_type='default',
        resume_network='000-celeba/network-snapshot-000488',
        resume_kimg=511.6,
        resume_time=0.0):
    """Train a progressive-growing GAN with Keras models.

    Builds `Generator`/`Discriminator`, wraps them with `PG_GAN` into the
    trainable `G_train` / `D_train` models, compiles them with Adam, and then
    alternates D steps (reals, fakes and real/fake interpolations) with G
    steps until `total_kimg * 1000` images have been consumed.  A real-image
    grid, periodic fake-image grids and weight snapshots are written into a
    new result sub-directory; the final models are saved via `save_GD`.

    Note that the defaults resume from
    '000-celeba/network-snapshot-000488' at 511.6 kimg; pass
    `resume_network=None` (and `resume_kimg=0`) to start from scratch.
    `separate_funcs`, `G_smoothing` and `dequantize_reals` are accepted for
    signature compatibility but are not used by this implementation.

    Raises:
        ValueError: for an unknown `config.loss['type']` or `image_grid_type`.
    """
    training_set, drange_orig = load_dataset()

    # Build the networks; when resuming, weights are loaded into fresh models.
    G = Generator(num_channels=training_set.shape[3],
                  resolution=training_set.shape[1],
                  label_size=training_set.labels.shape[1],
                  **config.G)
    D = Discriminator(num_channels=training_set.shape[3],
                      resolution=training_set.shape[1],
                      label_size=training_set.labels.shape[1],
                      **config.D)
    if resume_network:
        print("Resuming weight from:" + resume_network)
        G, D = load_GD_weights(G, D, os.path.join(config.result_dir, resume_network), True)

    G_train, D_train = PG_GAN(G, D, config.G['latent_size'], 0,
                              training_set.shape[1], training_set.shape[3])
    # summary() prints the table itself and returns None.
    G.summary()
    D.summary()

    # Misc init.
    resolution_log2 = int(np.round(np.log2(G.output_shape[2])))
    initial_lod = max(resolution_log2 - int(np.round(np.log2(lod_initial_resolution))), 0)
    cur_lod = 0.0
    min_lod, max_lod = -1.0, -2.0
    fake_score_avg = 0.0

    # Learning rates start at 0 and are set every iteration from the ramp schedule.
    G_opt = optimizers.Adam(lr=0.0, beta_1=adam_beta1, beta_2=adam_beta2, epsilon=adam_epsilon)
    D_opt = optimizers.Adam(lr=0.0, beta_1=adam_beta1, beta_2=adam_beta2, epsilon=adam_epsilon)

    loss_type = config.loss['type']
    if loss_type == 'wass':
        G_loss = wasserstein_loss
        D_loss = wasserstein_loss
        D_loss_weight = None  # was left undefined, breaking D_train.compile
    elif loss_type == 'iwass':
        G_loss = multiple_loss
        D_loss = [mean_loss, 'mse']
        D_loss_weight = [1.0, config.loss['iwass_lambda']]
    else:
        raise ValueError('Invalid loss type', loss_type)

    G.compile(G_opt, loss=G_loss)
    # Keras captures `trainable` at compile time: D is frozen inside the
    # combined G_train model and trainable again for D_train.
    D.trainable = False
    G_train.compile(G_opt, loss=G_loss)
    D.trainable = True
    D_train.compile(D_opt, loss=D_loss, loss_weights=D_loss_weight)

    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    tick_start_time = time.time()
    train_start_time = tick_start_time - resume_time

    # Setup snapshot image grid (outputs are channels-last: N, H, W, C).
    if image_grid_type == 'default':
        if image_grid_size is None:
            h, w = G.output_shape[1], G.output_shape[2]
            print("w:%d,h:%d" % (w, h))
            image_grid_size = (np.clip(int(1920 // w), 3, 16).astype('int'),
                               np.clip(1080 / h, 2, 16).astype('int'))
            print("image_grid_size:", image_grid_size)
        example_real_images, snapshot_fake_labels = training_set.get_random_minibatch_channel_last(
            np.prod(image_grid_size), labels=True)
        snapshot_fake_latents = random_latents(np.prod(image_grid_size), G.input_shape)
    else:
        raise ValueError('Invalid image_grid_type', image_grid_type)

    result_subdir = misc.create_result_subdir(config.result_dir, config.run_desc)
    print("example_real_images.shape:", example_real_images.shape)
    misc.save_image_grid(example_real_images, os.path.join(result_subdir, 'reals.png'),
                         drange=drange_orig, grid_size=image_grid_size)
    snapshot_fake_images = G.predict_on_batch(snapshot_fake_latents)
    misc.save_image_grid(snapshot_fake_images,
                         os.path.join(result_subdir, 'fakes%06d.png' % (cur_nimg / 1000)),
                         drange=drange_viz, grid_size=image_grid_size)

    while cur_nimg < total_kimg * 1000:
        # Calculate current LOD (level of detail; 0 = full resolution).
        cur_lod = initial_lod
        if lod_training_kimg or lod_transition_kimg:
            tlod = (cur_nimg / (1000.0 / speed_factor)) / (lod_training_kimg + lod_transition_kimg)
            cur_lod -= np.floor(tlod)
            if lod_transition_kimg:
                cur_lod -= max(1.0 + (np.fmod(tlod, 1.0) - 1.0) *
                               (lod_training_kimg + lod_transition_kimg) / lod_transition_kimg, 0.0)
            cur_lod = max(cur_lod, 0.0)

        # Look up resolution-dependent parameters.
        cur_res = 2 ** (resolution_log2 - int(np.floor(cur_lod)))
        minibatch_size = minibatch_overrides.get(cur_res, minibatch_default)
        tick_duration_kimg = tick_kimg_overrides.get(cur_res, tick_kimg_default)

        # Update learning rates and the LOD fed to the models.
        lrate_coef = rampup(cur_nimg / 1000.0, rampup_kimg)
        lrate_coef *= rampdown_linear(cur_nimg / 1000.0, total_kimg, rampdown_kimg)
        K.set_value(G.optimizer.lr, np.float32(lrate_coef * G_learning_rate_max))
        K.set_value(G_train.optimizer.lr, np.float32(lrate_coef * G_learning_rate_max))
        K.set_value(D_train.optimizer.lr, np.float32(lrate_coef * D_learning_rate_max))
        if hasattr(G_train, 'cur_lod'):
            K.set_value(G_train.cur_lod, np.float32(cur_lod))
        if hasattr(D_train, 'cur_lod'):
            K.set_value(D_train.cur_lod, np.float32(cur_lod))

        new_min_lod, new_max_lod = int(np.floor(cur_lod)), int(np.ceil(cur_lod))
        if min_lod != new_min_lod or max_lod != new_max_lod:
            min_lod, max_lod = new_min_lod, new_max_lod

        # Train D.
        d_loss = None
        for _ in range(D_training_repeats):
            mb_reals, _ = training_set.get_random_minibatch_channel_last(
                minibatch_size, lod=cur_lod, shrink_based_on_lod=True, labels=True)
            mb_latents = random_latents(minibatch_size, G.input_shape)
            if min_lod > 0:  # compensate for shrink_based_on_lod
                mb_reals = np.repeat(mb_reals, 2 ** min_lod, axis=1)
                mb_reals = np.repeat(mb_reals, 2 ** min_lod, axis=2)
            # Map reals into the network range BEFORE mixing them with fakes so
            # the gradient-penalty samples lie between two points of one space.
            mb_reals = misc.adjust_dynamic_range(mb_reals, drange_orig, drange_net)
            mb_fakes = G.predict_on_batch([mb_latents])
            epsilon = np.random.uniform(0, 1, size=(minibatch_size, 1, 1, 1))
            interpolation = epsilon * mb_reals + (1 - epsilon) * mb_fakes
            d_loss, _, _ = D_train.train_on_batch(
                [mb_fakes, mb_reals, interpolation],
                [np.ones((minibatch_size, 1, 1, 1)), np.ones((minibatch_size, 1))])
            d_score_real = D.predict_on_batch(mb_reals)
            d_score_fake = D.predict_on_batch(mb_fakes)
            print("real score: %f fake score: %f" % (np.mean(d_score_real), np.mean(d_score_fake)))
            cur_nimg += minibatch_size

        # Train G through the combined model (D frozen); target -1 per sample.
        mb_latents = random_latents(minibatch_size, G.input_shape)
        g_loss = G_train.train_on_batch([mb_latents], (-1) * np.ones((mb_latents.shape[0], 1, 1, 1)))
        print("%d [D loss: %f] [G loss: %f]" % (cur_nimg, d_loss, g_loss))

        # Fade in D noise if we're close to becoming unstable.
        fake_score_cur = np.clip(np.mean(d_loss), 0.0, 1.0)
        fake_score_avg = fake_score_avg * gdrop_beta + fake_score_cur * (1.0 - gdrop_beta)
        gdrop_strength = gdrop_coef * (max(fake_score_avg - gdrop_lim, 0.0) ** gdrop_exp)
        if hasattr(D, 'gdrop_strength'):
            K.set_value(D.gdrop_strength, np.float32(gdrop_strength))

        # Perform maintenance operations once per tick.
        is_complete = cur_nimg >= total_kimg * 1000
        if cur_nimg >= tick_start_nimg + tick_duration_kimg * 1000 or is_complete:
            cur_tick += 1
            cur_time = time.time()
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_time = cur_time - tick_start_time
            tick_start_nimg = cur_nimg
            tick_start_time = cur_time
            print('tick %d: %.1f kimg in %.1f sec, lod %.2f, total %.1f sec' %
                  (cur_tick, tick_kimg, tick_time, cur_lod, cur_time - train_start_time))

            # Visualize generated images.
            if cur_tick % image_snapshot_ticks == 0 or is_complete:
                snapshot_fake_images = G.predict_on_batch(snapshot_fake_latents)
                misc.save_image_grid(snapshot_fake_images,
                                     os.path.join(result_subdir, 'fakes%06d.png' % (cur_nimg / 1000)),
                                     drange=drange_viz, grid_size=image_grid_size)
            if cur_tick % network_snapshot_ticks == 0 or is_complete:
                save_GD_weights(G, D, os.path.join(result_subdir,
                                                   'network-snapshot-%06d' % (cur_nimg / 1000)))

    # Write final results.
    save_GD(G, D, os.path.join(result_subdir, 'network-final'))
    training_set.close()
    print('Done.')
def pre_process(
        imgs,  # Images to pre-process
        coords,  # Coords corresponding
        bboxes,  # Bounding boxes
        #threshold, #
        drange_imgs,  # Dynamic range for the images (Typically: [0, 255])
        drange_coords,  # Dynamic range for the coordinates (Typically: [0, image.shape[0]])
        drange_net,  # Dynamic range for the network (Typically: [-1, 1])
        mirror_augment=False,  # Should mirror augment be applied?
        random_dw_conv=False,  # Apply a random depthwise convolution to this input image?
        horizontal_flip=False,  # Randomly flip each sample left-right (and its x coords)?
):
    """Build the TF graph that normalises and augments a detector minibatch.

    Images and coordinates are cast to float32 and mapped into `drange_net`.
    Optional augmentations are then applied in-graph, so they are re-drawn at
    every `session.run`. Ground-truth boxes are finally converted into the
    network's target encoding with `bboxes_fn` and `bbox_refinement`.

    Images are handled as NCHW (width is axis 3), which matches the
    MirrorAugment reverse axis and the depthwise conv `data_format`.

    Returns:
        (imgs, coords, bboxes_transformed, refinement)
    """
    with tf.name_scope('ProcessReals'):
        imgs = tf.cast(imgs, tf.float32)
        coords = tf.cast(coords, tf.float32)
        with tf.name_scope('DynamicRange'):
            imgs = misc.adjust_dynamic_range(imgs, drange_imgs, drange_net)
            coords = misc.adjust_dynamic_range(coords, drange_coords, drange_net)
        if mirror_augment:
            # NOTE(review): this flips images but not `coords`/`bboxes`; confirm
            # that callers only enable it when labels are flip-invariant.
            with tf.name_scope('MirrorAugment'):
                s = tf.shape(imgs)
                mask = tf.random_uniform([s[0], 1, 1, 1], 0.0, 1.0)
                mask = tf.tile(mask, [1, s[1], s[2], s[3]])
                imgs = tf.where(mask < 0.5, imgs, tf.reverse(imgs, axis=[3]))
        if random_dw_conv:
            with tf.name_scope('RandomDWConv'):
                # Parameters of the augmentation: filter weights drawn in [-m, m].
                m = 0.5
                # The (3, 3, 3, 1) filter assumes 3 input channels.
                filt = 2 * m * tf.random_uniform((3, 3, 3, 1)) - m
                imgs = (tf.nn.depthwise_conv2d(imgs, filt, strides=[1, 1, 1, 1],
                                               padding='SAME', data_format='NCHW')
                        + imgs) / (1 + m)
        if horizontal_flip:
            with tf.name_scope('HorizontalFlip'):
                # Draw the coin inside the graph, one per sample. A Python
                # random value would be fixed once when the graph is built.
                flip = tf.cast(tf.random_uniform([tf.shape(imgs)[0]], 0.0, 1.0) < 0.5,
                               tf.float32)
                # Reverse the width axis (3 for NCHW). tf.image.flip_left_right
                # assumes NHWC and would flip the height instead.
                flip_imgs = tf.reshape(flip, [-1, 1, 1, 1])
                imgs = flip_imgs * tf.reverse(imgs, axis=[3]) + (1.0 - flip_imgs) * imgs
                # x coordinates sit at even positions, matching the original
                # [-1, 1, -1, 1, -1, 1] multiplier. Mirror them inside drange_net:
                # x -> lo + hi - x, which is simply -x for [-1, 1].
                # Assumes coords is (batch, 6) — TODO confirm.
                is_x = tf.constant([1.0, 0.0, 1.0, 0.0, 1.0, 0.0], dtype=tf.float32)
                mirrored = coords + is_x * (float(drange_net[0] + drange_net[1]) - 2.0 * coords)
                flip_coords = tf.reshape(flip, [-1, 1])
                coords = flip_coords * mirrored + (1.0 - flip_coords) * coords
                # NOTE(review): `bboxes` are not mirrored here; flipped samples
                # keep unflipped box targets. Fix once the box layout is confirmed.

    # net_bboxes = tf.constant(generate_output(64, np.arange(4,17,2)), tf.float32) # TODO: architecture independance
    # NOTE(review): `net_bboxes` is not assigned in this function (the line
    # above is commented out), so it must resolve to a module-level name.
    with tf.name_scope('ProcessBboxes'):
        bboxes = tf.cast(bboxes, tf.float32)
        bboxes_transformed = tf.map_fn(
            lambda bbox: bboxes_fn(bbox, net_bboxes, config.threshold), bboxes)
        refinement = tf.map_fn(lambda bbox: bbox_refinement(bbox, net_bboxes),
                               bboxes)  # TODO: optimize refinement
        # Should refinement be normalized?
        # net_bboxes.shape = [561, 4]
        # refinement.shape = [?, 561, 4]
    return imgs, coords, bboxes_transformed, refinement
def train_detector(
        minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
        total_kimg=1,  # Total length of the training, measured in thousands of real images.
        drange_net=[
            -1, 1
        ],  # Dynamic range used when feeding image data to the networks.
        snapshot_size=16,  # Size of the snapshot image
        snapshot_ticks=2**13,  # Number of images before maintenance
        image_snapshot_ticks=1,  # How often to export image snapshots?
        network_snapshot_ticks=1,  # How often to export network snapshots?
        save_tf_graph=True,  # Include full TensorFlow computation graph in the tfevents file?
        save_weight_histograms=False,  # Include weight histograms in the tfevents file?
        resume_run_id=None,  # Run ID or network pkl to resume training from, None = start from scratch.
        resume_snapshot=None,  # Snapshot index to resume training from, None = autodetect.
        resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
        resume_time=0.0
):  # Assumed wallclock time at the beginning. Affects reporting.
    """Train the detector network `N` on the train/test tfrecords named in `config`.

    Builds a TF1 graph containing a training op (`N_train_op`, fed through
    `pre_process`) and a test loss (`test_loss`). It then runs training steps
    until `total_kimg * 1000` images have been consumed.

    A "tick" happens when `cur_nimg` is an exact multiple of `snapshot_ticks`,
    or at the end of training. At each tick the function:
      - averages the accumulated train loss,
      - runs the test loss over the whole testing set,
      - prints and records autosummaries,
      - optionally writes image (`fakes%06d.png`) and network
        (`network-snapshot-%06d.pkl`) snapshots into a new result sub-directory.

    NOTE(review):
      - `minibatch_repeats`, `resume_kimg` and `resume_time` are not read by
        this body; `cur_nimg` always starts at 0, even when resuming.
      - If `sched.minibatch` does not divide `snapshot_ticks`, the
        exact-multiple test may never fire before the final step.
    """
    maintenance_start_time = time.time()
    # Load the datasets
    training_set = dataset.load_dataset(tfrecord=config.tfrecord_train,
                                        verbose=True,
                                        **config.dataset)
    # The test set is read once per pass: no repeat, no shuffle.
    testing_set = dataset.load_dataset(tfrecord=config.tfrecord_test,
                                       verbose=True,
                                       repeat=False,
                                       shuffle_mb=0,
                                       **config.dataset)
    testing_set_len = len(testing_set)
    # TODO: data augmentation
    # TODO: testing set
    # Construct networks.
    with tf.device('/gpu:0'):
        if resume_run_id is not None:  # TODO: save methods
            network_pkl = misc.locate_network_pkl(resume_run_id,
                                                  resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            N = misc.load_pkl(network_pkl)
        else:
            print('Constructing the network...'
                  )  # TODO: better network (like lod-wise network)
            # training_set.shape is indexed here as [channels, resolution, ...].
            N = tfutil.Network('N',
                               num_channels=training_set.shape[0],
                               resolution=training_set.shape[1],
                               **config.N)
    N.print_layers()
    print('Building TensorFlow graph...')
    # Training set up
    with tf.name_scope('Inputs'):
        # Learning rate is fed at every step from the training schedule.
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        # minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        reals, labels, bboxes = training_set.get_minibatch_tf(
        )  # TODO: increase the size of the batch by several loss computation and mean
    N_opt = tfutil.Optimizer(name='TrainN',
                             learning_rate=lrate_in,
                             **config.N_opt)
    with tf.device('/gpu:0'):
        # Coordinates are normalised from [0, training_set.shape[-2]] into drange_net.
        reals, labels, gt_outputs, gt_ref = pre_process(
            reals, labels, bboxes, training_set.dynamic_range,
            [0, training_set.shape[-2]], drange_net)
        with tf.name_scope('N_loss'):  # TODO: loss inadapted
            N_loss = tfutil.call_func_by_name(N=N,
                                              reals=reals,
                                              gt_outputs=gt_outputs,
                                              gt_ref=gt_ref,
                                              **config.N_loss)
        N_opt.register_gradients(tf.reduce_mean(N_loss), N.trainables)
    N_train_op = N_opt.apply_updates()
    # Testing set up
    with tf.device('/gpu:0'):
        test_reals_tf, test_labels_tf, test_bboxes_tf = testing_set.get_minibatch_tf(
        )
        test_reals_tf, test_labels_tf, test_gt_outputs_tf, test_gt_ref_tf = pre_process(
            test_reals_tf, test_labels_tf, test_bboxes_tf,
            testing_set.dynamic_range, [0, testing_set.shape[-2]], drange_net)
        with tf.name_scope('N_test_loss'):
            test_loss = tfutil.call_func_by_name(N=N,
                                                 reals=test_reals_tf,
                                                 gt_outputs=test_gt_outputs_tf,
                                                 gt_ref=test_gt_ref_tf,
                                                 is_training=False,
                                                 **config.N_loss)
    print('Setting up result dir...')
    result_subdir = misc.create_result_subdir(config.result_dir, config.desc)
    summary_log = tf.summary.FileWriter(result_subdir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        N.setup_weight_histograms()
    # Fixed snapshot batch: saved once with ground-truth boxes, then reused
    # (after mapping into drange_net) for every prediction snapshot.
    test_reals, _, test_bboxes = testing_set.get_minibatch_np(snapshot_size)
    misc.save_img_bboxes(test_reals,
                         test_bboxes,
                         os.path.join(result_subdir, 'reals.png'),
                         snapshot_size,
                         adjust_range=False)
    test_reals = misc.adjust_dynamic_range(test_reals,
                                           training_set.dynamic_range,
                                           drange_net)
    test_preds, _ = N.run(test_reals, minibatch_size=snapshot_size)
    misc.save_img_bboxes(test_reals, test_preds,
                         os.path.join(result_subdir, 'fakes.png'),
                         snapshot_size)
    print('Training...')
    # Variables are only initialised for a fresh run; a resumed N is unpickled.
    if resume_run_id is None:
        tfutil.run(tf.global_variables_initializer())
    cur_nimg = 0
    cur_tick = 0
    tick_start_nimg = cur_nimg
    tick_start_time = time.time()
    train_start_time = tick_start_time
    # Choose training parameters and configure training ops.
    sched = TrainingSchedule(cur_nimg, training_set, **config.sched)
    training_set.configure(sched.minibatch)
    _train_loss = 0
    while cur_nimg < total_kimg * 1000:
        # Run training ops.
        # for _ in range(minibatch_repeats):
        _, loss = tfutil.run([N_train_op, N_loss], {lrate_in: sched.N_lrate})
        _train_loss += loss
        cur_nimg += sched.minibatch
        # Perform maintenance tasks once per tick.
        if (cur_nimg >= total_kimg * 1000) or (cur_nimg % snapshot_ticks == 0
                                               and cur_nimg > 0):
            cur_tick += 1
            cur_time = time.time()
            # tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            # NOTE(review): the accumulated loss is divided by the number of
            # images (not steps) since the last tick. If N_loss is a
            # per-minibatch mean, this reports loss / minibatch size —
            # TODO confirm the intended normalisation.
            _train_loss = _train_loss / (cur_nimg - tick_start_nimg)
            tick_start_nimg = cur_nimg
            tick_time = cur_time - tick_start_time
            total_time = cur_time - train_start_time
            maintenance_time = tick_start_time - maintenance_start_time
            maintenance_start_time = cur_time
            # One pass over the testing set. The sum is divided by the number
            # of test images (same normalisation caveat as above).
            testing_set.configure(sched.minibatch)
            _test_loss = 0
            # testing_set_len = 1 # TMP
            for _ in range(0, testing_set_len, sched.minibatch):
                _test_loss += tfutil.run(test_loss)
            _test_loss /= testing_set_len
            # Report progress.
            # TODO: improved report display
            print(
                'tick %-5d kimg %-6.1f time %-10s sec/tick %-3.1f maintenance %-7.2f train_loss %.4f test_loss %.4f'
                % (tfutil.autosummary('Progress/tick', cur_tick),
                   tfutil.autosummary('Progress/kimg', cur_nimg / 1000),
                   misc.format_time(
                       tfutil.autosummary('Timing/total_sec', total_time)),
                   tfutil.autosummary('Timing/sec_per_tick', tick_time),
                   tfutil.autosummary('Timing/maintenance', maintenance_time),
                   tfutil.autosummary('TrainN/train_loss', _train_loss),
                   tfutil.autosummary('TrainN/test_loss', _test_loss)))
            tfutil.autosummary('Timing/total_hours',
                               total_time / (60.0 * 60.0))
            tfutil.autosummary('Timing/total_days',
                               total_time / (24.0 * 60.0 * 60.0))
            tfutil.save_summaries(summary_log, cur_nimg)
            if cur_tick % image_snapshot_ticks == 0:
                test_bboxes, test_refs = N.run(test_reals,
                                               minibatch_size=snapshot_size)
                misc.save_img_bboxes_ref(
                    test_reals, test_bboxes, test_refs,
                    os.path.join(result_subdir,
                                 'fakes%06d.png' % (cur_nimg // 1000)),
                    snapshot_size)
            if cur_tick % network_snapshot_ticks == 0:
                misc.save_pkl(
                    N,
                    os.path.join(
                        result_subdir,
                        'network-snapshot-%06d.pkl' % (cur_nimg // 1000)))
            # Restart the per-tick loss accumulator.
            _train_loss = 0
            # Record start time of the next tick.
            tick_start_time = time.time()
    # Write final results.
    # misc.save_pkl(N, os.path.join(result_subdir, 'network-final.pkl'))
    summary_log.close()
    # Empty marker file signalling that training finished.
    open(os.path.join(result_subdir, '_training-done.txt'), 'wt').close()
def train_gan(
        separate_funcs=False,
        D_training_repeats=1,
        G_learning_rate_max=0.0010,
        D_learning_rate_max=0.0010,
        G_smoothing=0.999,
        adam_beta1=0.0,
        adam_beta2=0.99,
        adam_epsilon=1e-8,
        minibatch_default=16,
        minibatch_overrides={},
        rampup_kimg=40/speed_factor,
        rampdown_kimg=0,
        lod_initial_resolution=4,
        lod_training_kimg=400/speed_factor,
        lod_transition_kimg=400/speed_factor,
        total_kimg=10000/speed_factor,
        dequantize_reals=False,
        gdrop_beta=0.9,
        gdrop_lim=0.5,
        gdrop_coef=0.2,
        gdrop_exp=2.0,
        drange_net=[-1, 1],
        drange_viz=[-1, 1],
        image_grid_size=None,
        tick_kimg_default=50/speed_factor,
        tick_kimg_overrides={32: 20, 64: 10, 128: 10, 256: 5, 512: 2, 1024: 1},
        image_snapshot_ticks=1,
        network_snapshot_ticks=4,
        image_grid_type='default',
        # resume_network = '000-celeba/network-snapshot-000488',
        resume_network=None,
        resume_kimg=0.0,
        resume_time=0.0):
    """Train a progressive-growing GAN with the models from build_trainable_model.

    Alternates D steps (reals, fakes and real/fake interpolations) with G
    steps until `total_kimg * 1000` images have been consumed. Each step's
    scores and losses are printed. On every tick the tick timing is printed,
    and image grids / weight snapshots are written into a new result
    sub-directory. The final models are saved via `save_GD`.

    `separate_funcs`, `G_smoothing` and `dequantize_reals` are accepted for
    signature compatibility but are not used by this implementation.

    Raises:
        ValueError: for an unknown `image_grid_type`.
    """
    training_set, drange_orig = load_dataset()

    G, G_train, D, D_train = build_trainable_model(
        resume_network,
        training_set.shape[3],
        training_set.shape[1],
        training_set.labels.shape[1],
        adam_beta1, adam_beta2, adam_epsilon)

    # Misc init.
    resolution_log2 = int(np.round(np.log2(G.output_shape[2])))
    initial_lod = max(resolution_log2 - int(np.round(np.log2(lod_initial_resolution))), 0)
    cur_lod = 0.0
    min_lod, max_lod = -1.0, -2.0
    fake_score_avg = 0.0

    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    tick_start_time = time.time()
    train_start_time = tick_start_time - resume_time

    # Setup snapshot image grid (outputs are channels-last: N, H, W, C).
    if image_grid_type == 'default':
        if image_grid_size is None:
            h, w = G.output_shape[1], G.output_shape[2]
            print('w:%d,h:%d' % (w, h))
            image_grid_size = (np.clip(int(1920 // w), 3, 16).astype('int'),
                               np.clip(1080 / h, 2, 16).astype('int'))
            print('image_grid_size:', image_grid_size)
        example_real_images, snapshot_fake_labels = \
            training_set.get_random_minibatch_channel_last(
                np.prod(image_grid_size), labels=True)
        snapshot_fake_latents = random_latents(
            np.prod(image_grid_size), G.input_shape)
    else:
        raise ValueError('Invalid image_grid_type', image_grid_type)

    result_subdir = Path(misc.create_result_subdir(
        config.result_dir, config.run_desc))

    print('example_real_images.shape:', example_real_images.shape)
    misc.save_image_grid(example_real_images,
                         str(result_subdir / 'reals.png'),
                         drange=drange_orig,
                         grid_size=image_grid_size)
    snapshot_fake_images = G.predict_on_batch(snapshot_fake_latents)
    misc.save_image_grid(snapshot_fake_images,
                         str(result_subdir / f'fakes{cur_nimg // 1000:06}.png'),
                         drange=drange_viz,
                         grid_size=image_grid_size)

    while cur_nimg < total_kimg * 1000:
        # Calculate current LOD (level of detail; 0 = full resolution).
        cur_lod = initial_lod
        if lod_training_kimg or lod_transition_kimg:
            tlod = (cur_nimg / (1000.0 / speed_factor)) / \
                (lod_training_kimg + lod_transition_kimg)
            cur_lod -= np.floor(tlod)
            if lod_transition_kimg:
                cur_lod -= max(1.0 + (np.fmod(tlod, 1.0) - 1.0) *
                               (lod_training_kimg + lod_transition_kimg) /
                               lod_transition_kimg, 0.0)
            cur_lod = max(cur_lod, 0.0)

        # Look up resolution-dependent parameters.
        cur_res = 2 ** (resolution_log2 - int(np.floor(cur_lod)))
        minibatch_size = minibatch_overrides.get(cur_res, minibatch_default)
        tick_duration_kimg = tick_kimg_overrides.get(cur_res, tick_kimg_default)

        # Update learning rates and the LOD fed to the models.
        lrate_coef = rampup(cur_nimg / 1000.0, rampup_kimg)
        lrate_coef *= rampdown_linear(cur_nimg / 1000.0, total_kimg, rampdown_kimg)
        K.set_value(G.optimizer.lr, np.float32(lrate_coef * G_learning_rate_max))
        K.set_value(G_train.optimizer.lr, np.float32(lrate_coef * G_learning_rate_max))
        K.set_value(D_train.optimizer.lr, np.float32(lrate_coef * D_learning_rate_max))
        if hasattr(G_train, 'cur_lod'):
            K.set_value(G_train.cur_lod, np.float32(cur_lod))
        if hasattr(D_train, 'cur_lod'):
            K.set_value(D_train.cur_lod, np.float32(cur_lod))

        new_min_lod, new_max_lod = int(np.floor(cur_lod)), int(np.ceil(cur_lod))
        if min_lod != new_min_lod or max_lod != new_max_lod:
            min_lod, max_lod = new_min_lod, new_max_lod

        # Train D.
        d_loss = None
        for _ in range(D_training_repeats):
            mb_reals, _ = training_set.get_random_minibatch_channel_last(
                minibatch_size, lod=cur_lod, shrink_based_on_lod=True, labels=True)
            mb_latents = random_latents(minibatch_size, G.input_shape)
            # compensate for shrink_based_on_lod
            if min_lod > 0:
                mb_reals = np.repeat(mb_reals, 2 ** min_lod, axis=1)
                mb_reals = np.repeat(mb_reals, 2 ** min_lod, axis=2)
            # Map reals into the network range BEFORE mixing them with fakes so
            # the gradient-penalty samples lie between two points of one space.
            mb_reals = misc.adjust_dynamic_range(mb_reals, drange_orig, drange_net)
            mb_fakes = G.predict_on_batch([mb_latents])
            epsilon = np.random.uniform(0, 1, size=(minibatch_size, 1, 1, 1))
            interpolation = epsilon * mb_reals + (1 - epsilon) * mb_fakes
            d_loss, _, _ = D_train.train_on_batch(
                [mb_fakes, mb_reals, interpolation],
                [np.ones((minibatch_size, 1, 1, 1)), np.ones((minibatch_size, 1))])
            d_score_real = D.predict_on_batch(mb_reals)
            d_score_fake = D.predict_on_batch(mb_fakes)
            print('real score: %f fake score: %f' %
                  (np.mean(d_score_real), np.mean(d_score_fake)))
            cur_nimg += minibatch_size

        # Train G through the combined model; target -1 per sample.
        mb_latents = random_latents(minibatch_size, G.input_shape)
        g_loss = G_train.train_on_batch(
            [mb_latents], (-1) * np.ones((mb_latents.shape[0], 1, 1, 1)))
        print('%d [D loss: %f] [G loss: %f]' % (cur_nimg, d_loss, g_loss))

        # Fade in D noise if we're close to becoming unstable.
        fake_score_cur = np.clip(np.mean(d_loss), 0.0, 1.0)
        fake_score_avg = fake_score_avg * gdrop_beta + \
            fake_score_cur * (1.0 - gdrop_beta)
        gdrop_strength = gdrop_coef * \
            (max(fake_score_avg - gdrop_lim, 0.0) ** gdrop_exp)
        if hasattr(D, 'gdrop_strength'):
            K.set_value(D.gdrop_strength, np.float32(gdrop_strength))

        # Perform maintenance operations once per tick.
        is_complete = cur_nimg >= total_kimg * 1000
        is_generate_a_lot = \
            cur_nimg >= tick_start_nimg + tick_duration_kimg * 1000
        if is_generate_a_lot or is_complete:
            cur_tick += 1
            cur_time = time.time()
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = cur_time - tick_start_time
            print(f'tick time: {tick_time}')
            print(f'tick image: {tick_kimg}k')
            tick_start_time = cur_time

            # Visualize generated images.
            if cur_tick % image_snapshot_ticks == 0 or is_complete:
                snapshot_fake_images = G.predict_on_batch(snapshot_fake_latents)
                misc.save_image_grid(
                    snapshot_fake_images,
                    str(result_subdir / f'fakes{cur_nimg // 1000:06}.png'),
                    drange=drange_viz,
                    grid_size=image_grid_size)

            if cur_tick % network_snapshot_ticks == 0 or is_complete:
                save_GD_weights(G, D, str(
                    result_subdir / f'network-snapshot-{cur_nimg // 1000:06}'))
        # The loop ends via its own `total_kimg` condition. The stray `break`
        # that used to sit here stopped training early.

    save_GD(G, D, str(result_subdir / 'network-final'))
    training_set.close()
    print('Done.')

    train_complete_time = time.time() - train_start_time
    print(f'training time: {train_complete_time}')
def train_gan(
        separate_funcs=False,
        D_training_repeats=1,
        G_learning_rate_max=0.0010,
        D_learning_rate_max=0.0010,
        G_smoothing=0.999,
        adam_beta1=0.0,
        adam_beta2=0.99,
        adam_epsilon=1e-8,
        minibatch_default=16,
        minibatch_overrides={},
        rampup_kimg=40 / speed_factor,
        rampdown_kimg=0,
        lod_initial_resolution=4,
        lod_training_kimg=400 / speed_factor,
        lod_transition_kimg=400 / speed_factor,
        #lod_training_kimg = 40,
        #lod_transition_kimg = 40,
        total_kimg=10000 / speed_factor,
        dequantize_reals=False,
        gdrop_beta=0.9,
        gdrop_lim=0.5,
        gdrop_coef=0.2,
        gdrop_exp=2.0,
        drange_net=[-1, 1],
        drange_viz=[-1, 1],
        image_grid_size=None,
        #tick_kimg_default = 1,
        tick_kimg_default=50 / speed_factor,
        tick_kimg_overrides={
            32: 20,
            64: 10,
            128: 10,
            256: 5,
            512: 2,
            1024: 1
        },
        image_snapshot_ticks=4,
        network_snapshot_ticks=40,
        image_grid_type='default',
        #resume_network_pkl = '006-celeb128-progressive-growing/network-snapshot-002009.pkl',
        resume_network_pkl=None,
        resume_kimg=0,
        resume_time=0.0):
    """Train a progressive-growing GAN with Theano (Python 2 code).

    Loads the dataset, then either unpickles (G, D) from `resume_network_pkl`
    or constructs them with `network.Network`. It creates the temporally
    smoothed generator `Gs`, then trains until `total_kimg * 1000` images have
    been shown. The Theano D/G training functions are recompiled whenever the
    integer LOD bounds (floor/ceil of `cur_lod`) change.

    Outputs written into a new result sub-directory:
      - real and fake image grids;
      - periodic `network-snapshot-%06d.pkl` and a final `network-final.pkl`,
        each holding (G, D, Gs);
      - a `_training-done.txt` marker.
    """
    # Load dataset and build networks.
    training_set, drange_orig = load_dataset()
    print "*************** test the format of dataset ***************"
    print training_set
    print drange_orig
    # training_set is the object the dataset module produces after parsing the h5 file;
    # drange_orig is training_set.get_dynamic_range().
    if resume_network_pkl:
        print 'Resuming', resume_network_pkl
        G, D, _ = misc.load_pkl(
            os.path.join(config.result_dir, resume_network_pkl))
    else:
        G = network.Network(num_channels=training_set.shape[1],
                            resolution=training_set.shape[2],
                            label_size=training_set.labels.shape[1],
                            **config.G)
        D = network.Network(num_channels=training_set.shape[1],
                            resolution=training_set.shape[2],
                            label_size=training_set.labels.shape[1],
                            **config.D)
    Gs = G.create_temporally_smoothed_version(beta=G_smoothing,
                                              explicit_updates=True)
    # G and D are either unpickled by misc (resume) or constructed by the network module.
    print G
    print D
    #misc.print_network_topology_info(G.output_layers)
    #misc.print_network_topology_info(D.output_layers)

    # Setup snapshot image grid.
    # Configures the layout of the intermediate snapshot images.
    if image_grid_type == 'default':
        if image_grid_size is None:
            # Output shape is indexed as (N, C, H, W) here: [3] = width, [2] = height.
            w, h = G.output_shape[3], G.output_shape[2]
            image_grid_size = np.clip(1920 / w, 3, 16), np.clip(1080 / h, 2, 16)
        example_real_images, snapshot_fake_labels = training_set.get_random_minibatch(
            np.prod(image_grid_size), labels=True)
        snapshot_fake_latents = random_latents(np.prod(image_grid_size),
                                               G.input_shape)
    else:
        raise ValueError('Invalid image_grid_type', image_grid_type)

    # Theano input variables and compile generation func.
    print 'Setting up Theano...'
    real_images_var = T.TensorType('float32', [False] * len(D.input_shape))('real_images_var')
    # <class 'theano.tensor.var.TensorVariable'>
    # print type(real_images_var),real_images_var
    real_labels_var = T.TensorType(
        'float32', [False] * len(training_set.labels.shape))('real_labels_var')
    fake_latents_var = T.TensorType('float32', [False] * len(G.input_shape))('fake_latents_var')
    fake_labels_var = T.TensorType(
        'float32', [False] * len(training_set.labels.shape))('fake_labels_var')
    # Every *_var above is a symbolic input tensor.
    G_lrate = theano.shared(np.float32(0.0))
    D_lrate = theano.shared(np.float32(0.0))
    # theano.shared wraps the initial value (0.0) in a shared variable; the
    # training loop updates it each iteration via set_value.
    gen_fn = theano.function([fake_latents_var, fake_labels_var],
                             Gs.eval_nd(fake_latents_var,
                                        fake_labels_var,
                                        ignore_unused_inputs=True),
                             on_unused_input='ignore')
    # gen_fn is the generation function: its inputs are [fake_latents_var, fake_labels_var]
    # and its output is Gs.eval_nd(fake_latents_var, fake_labels_var, ignore_unused_inputs=True),
    # i.e. images from the smoothed generator.

    # Misc init.
    # log2 of the generator's output resolution.
    resolution_log2 = int(np.round(np.log2(G.output_shape[2])))
    # lod = level of detail (0 = full resolution).
    initial_lod = max(
        resolution_log2 - int(np.round(np.log2(lod_initial_resolution))), 0)
    cur_lod = 0.0
    # Sentinel values that force a compile of the training funcs on the first iteration.
    min_lod, max_lod = -1.0, -2.0
    fake_score_avg = 0.0

    # Save example images.
    snapshot_fake_images = gen_fn(snapshot_fake_latents, snapshot_fake_labels)
    result_subdir = misc.create_result_subdir(config.result_dir,
                                              config.run_desc)
    misc.save_image_grid(example_real_images,
                         os.path.join(result_subdir, 'reals.png'),
                         drange=drange_orig,
                         grid_size=image_grid_size)
    misc.save_image_grid(snapshot_fake_images,
                         os.path.join(result_subdir, 'fakes%06d.png' % 0),
                         drange=drange_viz,
                         grid_size=image_grid_size)

    # Training loop.
    # This is the main training entry point.
    # Training never leaves the outer while loop, so switching resolution and
    # similar operations must happen inside it.
    # Number of images shown so far.
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    tick_start_time = time.time()
    tick_train_out = []
    train_start_time = tick_start_time - resume_time
    while cur_nimg < total_kimg * 1000:

        # Calculate current LOD.
        # (compute the current level of detail)
        cur_lod = initial_lod
        if lod_training_kimg or lod_transition_kimg:
            tlod = (cur_nimg / (1000.0 / speed_factor)) / (lod_training_kimg + lod_transition_kimg)
            cur_lod -= np.floor(tlod)
            if lod_transition_kimg:
                cur_lod -= max(
                    1.0 + (np.fmod(tlod, 1.0) - 1.0) *
                    (lod_training_kimg + lod_transition_kimg) /
                    lod_transition_kimg, 0.0)
            cur_lod = max(cur_lod, 0.0)

        # Look up resolution-dependent parameters.
        cur_res = 2**(resolution_log2 - int(np.floor(cur_lod)))  # current resolution
        minibatch_size = minibatch_overrides.get(cur_res, minibatch_default)
        tick_duration_kimg = tick_kimg_overrides.get(cur_res, tick_kimg_default)

        # Update network config.
        # (learning rates from the ramp schedule, and the current LOD)
        lrate_coef = misc.rampup(cur_nimg / 1000.0, rampup_kimg)
        lrate_coef *= misc.rampdown_linear(cur_nimg / 1000.0, total_kimg,
                                           rampdown_kimg)
        G_lrate.set_value(np.float32(lrate_coef * G_learning_rate_max))
        D_lrate.set_value(np.float32(lrate_coef * D_learning_rate_max))
        if hasattr(G, 'cur_lod'): G.cur_lod.set_value(np.float32(cur_lod))
        if hasattr(D, 'cur_lod'): D.cur_lod.set_value(np.float32(cur_lod))

        # Setup training func for current LOD.
        new_min_lod, new_max_lod = int(np.floor(cur_lod)), int(
            np.ceil(cur_lod))
        #print " cur_lod%f\n min_lod %f\n new_min_lod %f\n max_lod %f\n new_max_lod %f\n"%(cur_lod,min_lod,new_min_lod,max_lod,new_max_lod)
        # Recompile only when the integer LOD bounds change.
        if min_lod != new_min_lod or max_lod != new_max_lod:
            print 'Compiling training funcs...'
            min_lod, max_lod = new_min_lod, new_max_lod

            # Pre-process reals.
            real_images_expr = real_images_var
            if dequantize_reals:
                rnd = theano.sandbox.rng_mrg.MRG_RandomStreams(
                    lasagne.random.get_rng().randint(1, 2147462579))
                epsilon_noise = rnd.uniform(size=real_images_expr.shape,
                                            low=-0.5,
                                            high=0.5,
                                            dtype='float32')
                real_images_expr = T.cast(
                    real_images_expr, 'float32'
                ) + epsilon_noise  # match original implementation of Improved Wasserstein
            real_images_expr = misc.adjust_dynamic_range(
                real_images_expr, drange_orig, drange_net)
            if min_lod > 0:  # compensate for shrink_based_on_lod
                real_images_expr = T.extra_ops.repeat(real_images_expr,
                                                      2**min_lod,
                                                      axis=2)
                real_images_expr = T.extra_ops.repeat(real_images_expr,
                                                      2**min_lod,
                                                      axis=3)

            # Optimize loss.
            G_loss, D_loss, real_scores_out, fake_scores_out = evaluate_loss(
                G, D, min_lod, max_lod, real_images_expr, real_labels_var,
                fake_latents_var, fake_labels_var, **config.loss)
            G_updates = adam(G_loss,
                             G.trainable_params(),
                             learning_rate=G_lrate,
                             beta1=adam_beta1,
                             beta2=adam_beta2,
                             epsilon=adam_epsilon).items()
            D_updates = adam(D_loss,
                             D.trainable_params(),
                             learning_rate=D_lrate,
                             beta1=adam_beta1,
                             beta2=adam_beta2,
                             epsilon=adam_epsilon).items()
            D_train_fn = theano.function([
                real_images_var, real_labels_var, fake_latents_var,
                fake_labels_var
            ], [G_loss, D_loss, real_scores_out, fake_scores_out],
                                         updates=D_updates,
                                         on_unused_input='ignore')
            # The G step also applies Gs.updates (the smoothed-generator update).
            G_train_fn = theano.function([fake_latents_var, fake_labels_var],
                                         [],
                                         updates=G_updates + Gs.updates,
                                         on_unused_input='ignore')

        for idx in xrange(D_training_repeats):
            mb_reals, mb_labels = training_set.get_random_minibatch(
                minibatch_size,
                lod=cur_lod,
                shrink_based_on_lod=True,
                labels=True)
            # NOTE(review): debug prints below run on every minibatch.
            print "******* test minibatch"
            print "mb_reals"
            print idx, D_training_repeats
            print mb_reals.shape, mb_labels.shape
            #print mb_reals
            print "mb_labels"
            #print mb_labels
            mb_train_out = D_train_fn(
                mb_reals, mb_labels,
                random_latents(minibatch_size, G.input_shape),
                random_labels(minibatch_size, training_set))
            cur_nimg += minibatch_size
            tick_train_out.append(mb_train_out)
        G_train_fn(random_latents(minibatch_size, G.input_shape),
                   random_labels(minibatch_size, training_set))

        # Fade in D noise if we're close to becoming unstable
        # (mb_train_out[1] is D_loss, per the D_train_fn output list).
        fake_score_cur = np.clip(np.mean(mb_train_out[1]), 0.0, 1.0)
        fake_score_avg = fake_score_avg * gdrop_beta + fake_score_cur * (
            1.0 - gdrop_beta)
        gdrop_strength = gdrop_coef * (max(fake_score_avg - gdrop_lim, 0.0)**
                                       gdrop_exp)
        if hasattr(D, 'gdrop_strength'):
            D.gdrop_strength.set_value(np.float32(gdrop_strength))

        # Perform maintenance operations once per tick.
        if cur_nimg >= tick_start_nimg + tick_duration_kimg * 1000 or cur_nimg >= total_kimg * 1000:
            cur_tick += 1
            cur_time = time.time()
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = cur_time - tick_start_time
            tick_start_time = cur_time
            # Per-output mean over the tick: (G_loss, D_loss, real_scores, fake_scores).
            tick_train_avg = tuple(
                np.mean(np.concatenate([np.asarray(v).flatten()
                                        for v in vals]))
                for vals in zip(*tick_train_out))
            tick_train_out = []

            # Print progress.
            print 'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-9.1f sec/kimg %-6.1f Dgdrop %-8.4f Gloss %-8.4f Dloss %-8.4f Dreal %-8.4f Dfake %-8.4f' % (
                (cur_tick, cur_nimg / 1000.0, cur_lod, minibatch_size,
                 misc.format_time(cur_time - train_start_time), tick_time,
                 tick_time / tick_kimg, gdrop_strength) + tick_train_avg)

            # Visualize generated images.
            if cur_tick % image_snapshot_ticks == 0 or cur_nimg >= total_kimg * 1000:
                snapshot_fake_images = gen_fn(snapshot_fake_latents,
                                              snapshot_fake_labels)
                misc.save_image_grid(snapshot_fake_images,
                                     os.path.join(
                                         result_subdir,
                                         'fakes%06d.png' % (cur_nimg / 1000)),
                                     drange=drange_viz,
                                     grid_size=image_grid_size)

            # Save network snapshot every N ticks.
            if cur_tick % network_snapshot_ticks == 0 or cur_nimg >= total_kimg * 1000:
                misc.save_pkl(
                    (G, D, Gs),
                    os.path.join(
                        result_subdir,
                        'network-snapshot-%06d.pkl' % (cur_nimg / 1000)))

    # Write final results.
    misc.save_pkl((G, D, Gs), os.path.join(result_subdir, 'network-final.pkl'))
    training_set.close()
    print 'Done.'
    # Empty marker file signalling that training finished.
    with open(os.path.join(result_subdir, '_training-done.txt'), 'wt'):
        pass
def _classify_load_image(image_path, resolution):
    """Load one image file as a float array of shape [1, 3, resolution, resolution] in [-1, 1]."""
    # Using the context manager closes the underlying file handle promptly.
    with PIL.Image.open(image_path) as img:
        # convert('RGB') normalises grayscale, palette, alpha and CMYK inputs
        # to exactly three channels (for RGBA it simply drops the alpha band).
        im = np.asarray(img.convert('RGB'), dtype=np.float32) / 255.0
    # Compare both spatial dimensions, not only the height.
    if im.shape[:2] != (resolution, resolution):
        im = skimage.transform.resize(im, (resolution, resolution))
    im = misc.adjust_dynamic_range(im, [0, 1], [-1, 1])
    im = np.transpose(im, axes=[2, 0, 1])  # HWC -> CHW
    return np.reshape(im, [1] + list(im.shape))  # add the minibatch axis


def _classify_match_labels(logits, label_sets):
    """Return the label list whose length equals the classifier's output width.

    Raises:
        ValueError: if no candidate label list matches ``logits.shape[1]``.
    """
    num_outputs = logits.shape[1]
    for labels in label_sets:
        if len(labels) == num_outputs:
            return list(labels)
    raise ValueError('Unexpected number of classifier outputs: %d' %
                     num_outputs)


def _classify_predict(C_im, image_path, label_sets, resolution):
    """Run the classifier on one image; return (label list, predicted label)."""
    im = _classify_load_image(image_path, resolution)
    logits = C_im.run(im, minibatch_size=1, num_gpus=1, out_dtype=np.float32)
    labels = _classify_match_labels(logits, label_sets)
    idx = np.argmax(np.squeeze(logits))
    return labels, labels[idx]


def classify(model_path, testing_data_path, resolution=128):
    """Classify an image, or every file in a directory, by its source.

    Loads the pickled classifier at ``model_path`` with ``misc.load_network_pkl``
    and prints the predicted source for each image. The label set (5 or 11
    classes) is chosen by the width of the classifier's output.

    Args:
        model_path: path to the pickled classifier network.
        testing_data_path: a single ``.png``/``.jpg``/``.jpeg`` file
            (case-insensitive), or a directory whose regular files are all
            classified in sorted name order. For a directory, per-class
            percentages are printed at the end.
        resolution: square side length images are resized to before being fed
            to the network (defaults to 128, the original hard-coded value).

    Raises:
        ValueError: if ``testing_data_path`` is neither a supported image file
            nor a directory, or the classifier's output width matches no known
            label set.
    """
    labels_1 = [
        'CelebA_real_data', 'ProGAN_generated_data', 'SNGAN_generated_data',
        'CramerGAN_generated_data', 'MMDGAN_generated_data'
    ]
    labels_2 = [
        'CelebA_real_data', 'ProGAN_seed_0_generated_data',
        'ProGAN_seed_1_generated_data', 'ProGAN_seed_2_generated_data',
        'ProGAN_seed_3_generated_data', 'ProGAN_seed_4_generated_data',
        'ProGAN_seed_5_generated_data', 'ProGAN_seed_6_generated_data',
        'ProGAN_seed_7_generated_data', 'ProGAN_seed_8_generated_data',
        'ProGAN_seed_9_generated_data'
    ]
    label_sets = (labels_1, labels_2)

    print('Loading network...')
    C_im = misc.load_network_pkl(model_path)

    if testing_data_path.lower().endswith(('.png', '.jpg', '.jpeg')):
        _, prediction = _classify_predict(C_im, testing_data_path, label_sets,
                                          resolution)
        print('The input image is predicted as being sampled from %s' %
              prediction)
    elif os.path.isdir(testing_data_path):
        # Only regular files can be opened as images; sub-directories are skipped.
        name_list = sorted(
            name for name in os.listdir(testing_data_path)
            if os.path.isfile(os.path.join(testing_data_path, name)))
        length = len(name_list)
        if length == 0:
            print('No images found in %s' % testing_data_path)
            return

        labels = None
        count_dict = {}
        for count0, name in enumerate(name_list):
            labels, prediction = _classify_predict(
                C_im, os.path.join(testing_data_path, name), label_sets,
                resolution)
            count_dict[prediction] = count_dict.get(prediction, 0) + 1
            # count0 is zero-based; report a one-based progress counter.
            print(
                'Classifying %d/%d images: %s: predicted as being sampled from %s'
                % (count0 + 1, length, name, prediction))

        for label in labels:
            count = count_dict.get(label, 0)
            print(
                'The percentage of images sampled from %s is %d/%d = %.2f%%' %
                (label, count, length, float(count) / float(length) * 100.0))
    else:
        raise ValueError(
            'testing_data_path must be a .png/.jpg/.jpeg file or a directory: %s'
            % testing_data_path)