def transitionAtoB(
        run_id=102,         # Run ID or network pkl passed to misc.locate_network_pkl().
        snapshot=None,      # Snapshot selector passed to misc.locate_network_pkl().
        num_gen_noise=100,  # Number of latent draws discarded between drawing A and B.
        num_split=39,       # Number of segments the B -> A line is divided into.
        num_frames=30):     # Number of frames rendered, starting from B.
    """Render frames that linearly interpolate between two random latents.

    Two latent vectors A and B are drawn from a fixed-seed RandomState
    (seed 10). ``num_gen_noise`` draws are thrown away between them, so
    changing that value selects a different B. Frame ``i`` uses
    ``B + (A - B) * i / num_split`` and is written to
    ``<config.result_dir>/video/photoNNNN.png``.

    With the defaults (39 segments, 30 frames) the sequence stops at 29/39 of
    the way from B to A; pass ``num_frames=num_split + 1`` to reach A.

    Adapted from http://cedro3.com/ai/stylegan/
    """
    # Initialize TensorFlow.
    tflib.init_tf()

    # Load pre-trained network.
    network_pkl = misc.locate_network_pkl(run_id, snapshot)
    print('Loading networks from "%s"...' % network_pkl)
    _G, _D, Gs = misc.load_pkl(network_pkl)

    # Print network details.
    Gs.print_layers()

    # Pick the two end-point latent vectors.
    rnd = np.random.RandomState(10)  # seed = 10
    latent_shape = (1, Gs.input_shape[1])
    latentsA = rnd.randn(*latent_shape)
    for _ in range(num_gen_noise):
        # Discarded draw: only advances the RNG so that B depends on num_gen_noise.
        rnd.randn(*latent_shape)
    latentsB = rnd.randn(*latent_shape)

    # Loop invariants: output directory and image conversion settings.
    video_dir = os.path.join(config.result_dir, 'video')
    os.makedirs(video_dir, exist_ok=True)
    fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)

    for i in range(num_frames):
        latents = latentsB + (latentsA - latentsB) * i / num_split

        # Generate image.
        images = Gs.run(latents, None, truncation_psi=0.7, randomize_noise=True, output_transform=fmt)

        # Save image.
        png_filename = os.path.join(video_dir, 'photo' + '{0:04d}'.format(i) + '.png')
        PIL.Image.fromarray(images[0], 'RGB').save(png_filename)
def run_snapshot(submit_config, metric_args, run_id, snapshot):
    """Evaluate one metric on one network snapshot of a run.

    Args:
        submit_config: Submit configuration; its ``num_gpus`` is forwarded to
            the metric and it is used to create the ``dnnlib.RunContext``.
        metric_args: Keyword arguments for ``dnnlib.util.call_func_by_name``
            that construct the metric object; ``metric_args.name`` is printed.
        run_id: Run identifier resolved through ``misc.locate_run_dir``.
        snapshot: Snapshot selector passed to ``misc.locate_network_pkl``.
    """
    ctx = dnnlib.RunContext(submit_config)
    try:
        tflib.init_tf()
        print('Evaluating %s metric on run_id %s, snapshot %s...' % (metric_args.name, run_id, snapshot))
        run_dir = misc.locate_run_dir(run_id)
        network_pkl = misc.locate_network_pkl(run_dir, snapshot)
        metric = dnnlib.util.call_func_by_name(**metric_args)
        print()
        metric.run(network_pkl, run_dir=run_dir, num_gpus=submit_config.num_gpus)
        print()
    finally:
        # Close the run context even when locating the pkl or running the metric fails.
        ctx.close()
def test_d(submit_config, resume_run_id, dataset_args, tf_config={}, resume_snapshot=None):
    """Print discriminator outputs for real and generated images.

    Loads (G, D, Gs) from a network pickle, builds a graph that feeds one
    real minibatch and one Gs-generated image through D, then for 15
    iterations prints both D outputs and saves the real image to
    ``<submit_config.run_dir>/real_<i>.nii.gz`` via ``misc.save_mri_image``.

    Args:
        submit_config: Submit configuration (provides ``run_dir``).
        resume_run_id: Run ID or network pkl passed to ``misc.locate_network_pkl``.
        dataset_args: Keyword arguments for ``dataset.load_dataset``.
        tf_config: Options passed to ``tflib.init_tf``. The shared default
            dict is only passed through here, never modified.
        resume_snapshot: Snapshot selector passed to ``misc.locate_network_pkl``.
    """
    ctx = dnnlib.RunContext(submit_config, train)
    try:
        tflib.init_tf(tf_config)
        network_pkl = misc.locate_network_pkl(resume_run_id, resume_snapshot)
        print('Loading networks from "%s"...' % network_pkl)
        G, D, Gs = misc.load_pkl(network_pkl)

        # Generator branch: latents -> mapping -> synthesis, unconditional (no labels).
        latents_1 = tf.placeholder(tf.float32)
        labels_1 = None
        training_set = dataset.load_dataset(data_dir=config.data_dir, verbose=True, **dataset_args)
        w_1 = Gs.components.mapping.get_output_for(latents_1, labels_1, is_validation=True)
        fake_image_1_op = Gs.components.synthesis.get_output_for(
            w_1, is_validation=True, randomize_noise=False)

        # Real branch: dataset minibatch pre-processed to the [-1, 1] range.
        reals, _labels = training_set.get_minibatch_tf()
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        reals = process_reals(reals, lod_in, False, training_set.dynamic_range, [-1, 1])

        d_pred_real = D.get_output_for(reals, labels_1)
        d_pred_fake = D.get_output_for(fake_image_1_op, labels_1)

        # Arguments are (minibatch size, lod) as used by the other callers of configure().
        training_set.configure(1, 0)

        for i in range(15):
            latents_1_val = np.random.randn(1, *G.input_shape[1:])
            d_pred_real_, d_pred_fake_, real_image = tflib.run(
                [d_pred_real, d_pred_fake, reals],
                feed_dict={latents_1: latents_1_val, lod_in: 0})
            print(d_pred_real_, d_pred_fake_)
            misc.save_mri_image(
                real_image,
                os.path.join(submit_config.run_dir, 'real_{}.nii.gz'.format(i)),
                drange=[-1, 1])
    finally:
        # The original never closed the run context; close it on every exit path.
        ctx.close()
def main(
        run_id=101,            # Run ID or network pkl passed to misc.locate_network_pkl().
        snapshot=None,         # Snapshot selector passed to misc.locate_network_pkl().
        grid_size=[1, 1],      # Unused by this function.
        image_shrink=1,        # Unused by this function.
        image_zoom=1,          # Unused by this function.
        duration_sec=3.0,      # Together with mp4_fps determines the number of frames.
        smoothing_sec=1.0,     # Gaussian sigma along the frame axis, in seconds.
        mp4=None,              # Overwritten below; the value passed in is ignored.
        mp4_fps=30,            # Frames per second used for frame count and sigma.
        mp4_codec='libx265',   # Unused by this function.
        mp4_bitrate='16M',     # Unused by this function.
        random_seed=1000,      # Unused by this function (the seed is hard-coded to 5).
        minibatch_size=8):     # Unused by this function.
    """Write PNG frames from temporally smoothed dlatents.

    Draws ``rint(duration_sec * mp4_fps)`` random latents, maps them with
    ``Gs.components.mapping``, smooths the result along the frame axis with a
    wrapping Gaussian filter, and writes one PNG per frame to
    ``<config.result_dir>/frame_NNNNNNNN.png``.
    """
    # Initialize TensorFlow.
    tflib.init_tf()

    # Load pre-trained network.
    network_pkl = misc.locate_network_pkl(run_id, snapshot)
    print('Loading networks from "%s"...' % network_pkl)
    _G, _D, Gs = misc.load_pkl(network_pkl)

    # NOTE(review): this name is computed but never used; no mp4 is written here.
    mp4 = '%s-lerp.mp4' % network_pkl
    num_frames = int(np.rint(duration_sec * mp4_fps))

    # Print network details.
    Gs.print_layers()

    # Pick latent vectors, one per frame.
    print('Generating latent vectors...')
    rng = np.random.RandomState(5)
    latent_shape = [num_frames, Gs.input_shape[1]]
    all_latents = rng.randn(*latent_shape).astype(np.float32)

    # Map to dlatents and smooth only along the first (frame) axis.
    all_dlatents = Gs.components.mapping.run(all_latents, None)
    sigma = [smoothing_sec * mp4_fps, 0, 0]
    all_dlatents = scipy.ndimage.gaussian_filter(all_dlatents, sigma, mode='wrap')
    print(latent_shape, all_latents.shape, all_dlatents.shape)

    image_fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
    for frame_idx, frame_dlatents in enumerate(all_dlatents):
        # NOTE(review): dlatents are passed to Gs.run (the full network) with
        # truncation_psi=0.0 - confirm this is intended rather than a call to
        # the synthesis component.
        images = Gs.run(frame_dlatents, None, truncation_psi=0.0, randomize_noise=True,
                        output_transform=image_fmt)
        frame_name = 'frame_%s.png' % str(frame_idx).zfill(8)
        PIL.Image.fromarray(images[0], 'RGB').save(os.path.join(config.result_dir, frame_name))
def mixing(submit_config, resume_run_id, tf_config = {}, resume_snapshot=None):
    """Save style-mixing images built from two random latents.

    Two latents are mapped to dlatents ``w_1`` and ``w_2`` under the same
    fixed one-hot label. The function saves the two unmixed images and then,
    for every crossover index ``i`` in ``range(15)``, an image synthesized
    from ``concat(w_a[:, :i], w_b[:, i:])`` for both orderings
    (``fake_mix_12_<i>.png`` and ``fake_mix_21_<i>.png``) into
    ``submit_config.run_dir``.

    Args:
        submit_config: Submit configuration (provides ``run_dir``).
        resume_run_id: Run ID or network pkl passed to ``misc.locate_network_pkl``.
        tf_config: Options passed to ``tflib.init_tf``. The shared default
            dict is only passed through here, never modified.
        resume_snapshot: Snapshot selector passed to ``misc.locate_network_pkl``.
    """
    ctx = dnnlib.RunContext(submit_config, train)
    try:
        tflib.init_tf(tf_config)
        network_pkl = misc.locate_network_pkl(resume_run_id, resume_snapshot)
        print('Loading networks from "%s"...' % network_pkl)
        G, D, Gs = misc.load_pkl(network_pkl)

        latents_1_val = np.random.randn(1, *G.input_shape[1:])
        latents_2_val = np.random.randn(1, *G.input_shape[1:])

        latents_1 = tf.placeholder(tf.float32)
        labels_1 = tf.constant([[0,0,0,0,1,0]])
        latents_2 = tf.placeholder(tf.float32)
        labels_2 = tf.constant([[0,0,0,0,1,0]])

        w_1 = Gs.components.mapping.get_output_for(latents_1, labels_1, is_validation=True)
        w_2 = Gs.components.mapping.get_output_for(latents_2, labels_2, is_validation=True)

        fake_image_1_op = Gs.components.synthesis.get_output_for(w_1, is_validation=True, randomize_noise=False)
        fake_image_2_op = Gs.components.synthesis.get_output_for(w_2, is_validation=True, randomize_noise=False)

        # Every run uses the same two latent values.
        feed = {latents_1: latents_1_val, latents_2: latents_2_val}
        run_dir = submit_config.run_dir

        fake_image_1 = tflib.run(fake_image_1_op, feed_dict=feed)
        fake_image_2 = tflib.run(fake_image_2_op, feed_dict=feed)
        misc.save_image(fake_image_1[0], os.path.join(run_dir, 'fake_image_1.png'), drange=[-1,1])
        misc.save_image(fake_image_2[0], os.path.join(run_dir, 'fake_image_2.png'), drange=[-1,1])

        # '12': first i layers from w_1, the rest from w_2; '21' is the reverse.
        # All '12' images are written before the '21' images, as before.
        for tag, w_head, w_tail in (('12', w_1, w_2), ('21', w_2, w_1)):
            for i in range(15):
                w_mix = tf.concat([w_head[:, :i], w_tail[:, i:]], axis=1)
                fake_mix_op = Gs.components.synthesis.get_output_for(
                    w_mix, is_validation=True, randomize_noise=False)
                fake_mix_image = tflib.run(fake_mix_op, feed_dict=feed)
                misc.save_image(
                    fake_mix_image[0],
                    os.path.join(run_dir, 'fake_mix_{}_{}.png'.format(tag, i)),
                    drange=[-1,1])
    finally:
        # The original never closed the run context; close it on every exit path.
        ctx.close()
def transitionAtoB_v2(
        run_id=102,            # Run ID or network pkl passed to misc.locate_network_pkl().
        snapshot=None,         # Snapshot selector passed to misc.locate_network_pkl().
        num_frames=20,         # Number of frames / offsets to render.
        interpolate_dim=350):  # Index of the latent component that is varied.
    """Render frames that sweep a single latent component.

    A base latent is drawn from a fixed-seed RandomState (seed 10). For each
    of ``num_frames`` evenly spaced offsets in [-15, 15], a copy of the base
    latent gets that offset added to component ``interpolate_dim``. All
    latents are rendered in one ``Gs.run`` call and saved as
    ``<config.result_dir>/<pkl basename>/frame_NNNN.png``.
    """
    # Initialize TensorFlow.
    tflib.init_tf()

    # Load pre-trained network.
    network_pkl = misc.locate_network_pkl(run_id, snapshot)
    print('Loading networks from "%s"...' % network_pkl)
    _G, _D, Gs = misc.load_pkl(network_pkl)

    # Print network details.
    Gs.print_layers()

    # Base latent vector.
    rng = np.random.RandomState(10)  # seed = 10
    init_latent = rng.randn(1, Gs.input_shape[1])[0]

    # One latent per offset; only `interpolate_dim` differs from the base.
    offsets = np.linspace(0., 30., num_frames) - 15
    shifted = []
    for offset in offsets:
        latent = np.copy(init_latent)
        latent[interpolate_dim] += offset
        shifted.append(latent)
    latents = np.array(shifted)

    image_fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
    images = Gs.run(latents, None, truncation_psi=0.7, randomize_noise=True, output_transform=image_fmt)

    # Output directory named after the network pickle.
    out_dir = os.path.join(config.result_dir, os.path.basename(network_pkl).replace(".mp4",""))
    os.makedirs(out_dir, exist_ok=True)

    # Save images.
    for idx in range(num_frames):
        frame_path = os.path.join(out_dir, 'frame_' + '{0:04d}'.format(idx) + '.png')
        PIL.Image.fromarray(images[idx], 'RGB').save(frame_path)
resume_snapshot = None, # Snapshot index to resume training from, None = autodetect. resume_kimg = 0.0 # Assumed training progress at the beginning. Affects reporting and training schedule. #resume_kimg = 5645, resume_time = 0.0): # Assumed wallclock time at the beginning. Affects reporting. # Initialize dnnlib and TensorFlow. ctx = dnnlib.RunContext(submit_config, train) tflib.init_tf(tf_config) # Load training set. training_set = dataset.load_dataset(data_dir=config.data_dir, verbose=True, **dataset_args) # Construct networks. with tf.device('/gpu:0'): if resume_run_id is not None: network_pkl = misc.locate_network_pkl(resume_run_id, resume_snapshot) print('Loading networks from "%s"...' % network_pkl) G, D, Gs = misc.load_pkl(network_pkl) else: print('Constructing networks...') G = tflib.Network('G', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **G_args) D = tflib.Network('D', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **D_args) Gs = G.clone('Gs') G.print_layers(); D.print_layers() print('Building TensorFlow graph...') with tf.name_scope('Inputs'), tf.device('/cpu:0'): lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[]) lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[]) minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[]) minibatch_split = minibatch_in // submit_config.num_gpus
def training_loop(
        submit_config,
        Encoder_args={},
        E_opt_args={},
        D_opt_args={},
        E_loss_args={},
        D_loss_args={},
        lr_args=EasyDict(),
        tf_config={},
        dataset_args=EasyDict(),
        decoder_pkl=EasyDict(),
        drange_data=[0, 255],
        drange_net=[-1, 1],            # Dynamic range used when feeding image data to the networks.
        mirror_augment=False,
        resume_run_id=None,            # Run ID or network pkl to resume training from, None = start from scratch.
        resume_snapshot=None,          # Snapshot index to resume training from, None = autodetect.
        image_snapshot_ticks=1,        # How often to export image snapshots?
        network_snapshot_ticks=10,     # How often to export network snapshots?
        save_tf_graph=False,           # Include full TensorFlow computation graph in the tfevents file?
        save_weight_histograms=False,  # Include weight histograms in the tfevents file?
        max_iters=150000,
        E_smoothing=0.999):
    """Train encoder E and discriminator D for ``max_iters`` iterations.

    When ``resume_run_id`` is given, (E, G, D, Gs, NE) are loaded from that
    pickle and the start iteration is derived from the number embedded in
    the pickle file name. Otherwise (G, D, Gs, NE) come from
    ``decoder_pkl.decoder_pkl`` and a new E network is constructed.

    A "tick" elapses every 65000 images. Reconstructions of a test batch are
    written every ``image_snapshot_ticks`` ticks and network pickles every
    ``network_snapshot_ticks`` ticks. ``network-final.pkl`` is written when
    the loop ends.

    ``E_smoothing`` is accepted but not used by this function.
    """
    tflib.init_tf(tf_config)

    num_gpus = submit_config.num_gpus
    batch_size = submit_config.batch_size
    image_size = submit_config.image_size
    run_dir = submit_config.run_dir

    with tf.name_scope('input'):
        real_train = tf.placeholder(
            tf.float32, [batch_size, 3, image_size, image_size], name='real_image_train')
        real_test = tf.placeholder(
            tf.float32, [submit_config.batch_size_test, 3, image_size, image_size], name='real_image_test')
        real_split = tf.split(real_train, num_or_size_splits=num_gpus, axis=0)

    with tf.device('/gpu:0'):
        if resume_run_id is None:
            print('Constructing networks...')
            G, D, Gs, NE = misc.load_pkl(decoder_pkl.decoder_pkl)
            E = tflib.Network('E', size=image_size, filter=64, filter_max=1024, phase=True, **Encoder_args)
            start = 0
        else:
            network_pkl = misc.locate_network_pkl(resume_run_id, resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            E, G, D, Gs, NE = misc.load_pkl(network_pkl)
            # The file name ends in "-<image count>.pkl"; convert images to iterations.
            start = int(network_pkl.split('-')[-1].split('.')[0]) // batch_size

    Gs.print_layers()
    E.print_layers()
    D.print_layers()

    # Learning-rate schedule driven by a step counter that starts at `start`.
    global_step = tf.Variable(start, trainable=False, name='learning_rate_step')
    learning_rate = tf.train.exponential_decay(
        lr_args.learning_rate, global_step, lr_args.decay_step, lr_args.decay_rate,
        staircase=lr_args.stair)
    add_global = global_step.assign_add(1)

    E_opt = tflib.Optimizer(name='TrainE', learning_rate=learning_rate, **E_opt_args)
    D_opt = tflib.Optimizer(name='TrainD', learning_rate=learning_rate, **D_opt_args)

    # Per-GPU loss terms are summed here and averaged after the loop.
    E_loss_rec = 0.
    E_loss_adv = 0.
    D_loss_real = 0.
    D_loss_fake = 0.
    D_loss_grad = 0.
    for gpu in range(num_gpus):
        print('build graph on gpu %s' % str(gpu))
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            # GPU 0 uses the networks themselves; other GPUs use "_shadow" clones.
            E_gpu = E if gpu == 0 else E.clone(E.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            G_gpu = Gs if gpu == 0 else Gs.clone(Gs.name + '_shadow')
            perceptual_model = PerceptualModel(img_size=[image_size, image_size], multi_layers=False)
            real_gpu = process_reals(real_split[gpu], mirror_augment, drange_data, drange_net)
            with tf.name_scope('E_loss'), tf.control_dependencies(None):
                E_loss, recon_loss, adv_loss = dnnlib.util.call_func_by_name(
                    E=E_gpu, G=G_gpu, D=D_gpu, perceptual_model=perceptual_model,
                    reals=real_gpu, **E_loss_args)
                E_loss_rec += recon_loss
                E_loss_adv += adv_loss
            with tf.name_scope('D_loss'), tf.control_dependencies(None):
                D_loss, loss_fake, loss_real, loss_gp = dnnlib.util.call_func_by_name(
                    E=E_gpu, G=G_gpu, D=D_gpu, reals=real_gpu, **D_loss_args)
                D_loss_real += loss_real
                D_loss_fake += loss_fake
                D_loss_grad += loss_gp
            with tf.control_dependencies([add_global]):
                E_opt.register_gradients(E_loss, E_gpu.trainables)
                D_opt.register_gradients(D_loss, D_gpu.trainables)

    E_loss_rec /= num_gpus
    E_loss_adv /= num_gpus
    D_loss_real /= num_gpus
    D_loss_fake /= num_gpus
    D_loss_grad /= num_gpus

    E_train_op = E_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    print('building testing graph...')
    fake_X_val = test(E, Gs, real_test, submit_config)

    sess = tf.get_default_session()

    print('Getting training data...')
    image_batch_train = get_train_data(
        sess, data_dir=dataset_args.data_train, submit_config=submit_config, mode='train')
    image_batch_test = get_train_data(
        sess, data_dir=dataset_args.data_test, submit_config=submit_config, mode='test')

    summary_log = tf.summary.FileWriter(run_dir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        E.setup_weight_histograms()
        D.setup_weight_histograms()

    cur_nimg = start * batch_size
    cur_tick = 0
    tick_start_nimg = cur_nimg
    start_time = time.time()

    print('Optimization starts!!!')
    for it in range(start, max_iters):
        # One encoder step and one discriminator step on the same batch.
        train_feed = {real_train: sess.run(image_batch_train)}
        sess.run([E_train_op, E_loss_rec, E_loss_adv], train_feed)
        sess.run([D_train_op, D_loss_real, D_loss_fake, D_loss_grad], train_feed)
        cur_nimg += batch_size

        if it % 100 == 0:
            elapsed = dnnlib.util.format_time(time.time() - start_time)
            print("Iter: %06d kimg: %-8.1f time: %-12s" % (it, cur_nimg / 1000, elapsed))
            sys.stdout.flush()
            tflib.autosummary.save_summaries(summary_log, it)

        # Everything below runs once per tick (every 65000 images).
        if cur_nimg < tick_start_nimg + 65000:
            continue
        cur_tick += 1
        tick_start_nimg = cur_nimg

        if cur_tick % image_snapshot_ticks == 0:
            batch_images_test = sess.run(image_batch_test)
            batch_images_test = misc.adjust_dynamic_range(
                batch_images_test.astype(np.float32), [0, 255], [-1., 1.])
            recon_test = sess.run(fake_X_val, feed_dict={real_test: batch_images_test})
            # Reorder axes from (N, C, H, W) to (N, H, W, C) before merging.
            recon_test = recon_test.transpose(0, 2, 3, 1)
            batch_images_test = batch_images_test.transpose(0, 2, 3, 1)
            orin_recon = np.concatenate([batch_images_test, recon_test], axis=0)
            imwrite(immerge(orin_recon, 2, submit_config.batch_size_test),
                    '%s/iter_%08d.png' % (run_dir, cur_nimg))

        if cur_tick % network_snapshot_ticks == 0:
            pkl = os.path.join(run_dir, 'network-snapshot-%08d.pkl' % (cur_nimg))
            misc.save_pkl((E, G, D, Gs, NE), pkl)

    misc.save_pkl((E, G, D, Gs, NE), os.path.join(run_dir, 'network-final.pkl'))
    summary_log.close()
def training_loop(
        submit_config,
        Encoder_args = {},
        E_opt_args = {},
        D_opt_args = {},
        E_loss_args = EasyDict(),
        D_loss_args = {},
        lr_args = EasyDict(),
        tf_config = {},
        dataset_args = EasyDict(),
        decoder_pkl = EasyDict(),
        drange_data = [0, 255],
        drange_net = [-1,1],                        # Dynamic range used when feeding image data to the networks.
        mirror_augment = False,
        resume_run_id = config.ENCODER_PICKLE_DIR,  # Run ID or network pkl to resume training from, None = start from scratch.
        resume_snapshot = None,                     # Snapshot index to resume training from, None = autodetect.
        image_snapshot_ticks = 1,                   # How often to export image snapshots?
        network_snapshot_ticks = 4,                 # How often to export network snapshots?
        max_iters = 150000):
    """Train encoder E and discriminator D, mirroring outputs to ``config.GDRIVE_PATH``.

    When ``resume_run_id`` is not None, (E, G, D, Gs) are loaded from that
    pickle and the start iteration is derived from the number embedded in
    the pickle file name. Otherwise (G, D, Gs) come from
    ``decoder_pkl.decoder_pkl`` and a new E is built with ``num_layers``
    taken from ``Gs.components.synthesis.input_shape[1]``.

    Every 50 iterations the losses are printed and summaries saved. Every
    500 iterations a reconstruction image of a test batch is written to both
    ``submit_config.run_dir`` and ``<GDRIVE_PATH>/images``. A tick elapses
    every 65000 images, and every ``network_snapshot_ticks`` ticks a network
    pickle is written to both ``run_dir`` and ``<GDRIVE_PATH>/snapshots``.
    ``network-final.pkl`` is written to ``run_dir`` when the loop ends.

    ``image_snapshot_ticks`` is accepted but not used by this function.
    """
    tflib.init_tf(tf_config)

    num_gpus = submit_config.num_gpus
    batch_size = submit_config.batch_size
    image_size = submit_config.image_size

    with tf.name_scope('input'):
        real_train = tf.placeholder(
            tf.float32, [batch_size, 3, image_size, image_size], name='real_image_train')
        real_test = tf.placeholder(
            tf.float32, [submit_config.batch_size_test, 3, image_size, image_size], name='real_image_test')
        real_split = tf.split(real_train, num_or_size_splits=num_gpus, axis=0)

    with tf.device('/gpu:0'):
        if resume_run_id is not None:
            network_pkl = misc.locate_network_pkl(resume_run_id, resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            E, G, D, Gs = misc.load_pkl(network_pkl)
            # The file name ends in "-<image count>.pkl"; convert images to iterations.
            start = int(network_pkl.split('-')[-1].split('.')[0]) // batch_size
            print('Start: ', start)
        else:
            print('Constructing networks...')
            G, D, Gs = misc.load_pkl(decoder_pkl.decoder_pkl)
            num_layers = Gs.components.synthesis.input_shape[1]
            E = tflib.Network('E', size=image_size, filter=64, filter_max=1024,
                              num_layers=num_layers, phase=True, **Encoder_args)
            start = 0

    E.print_layers(); Gs.print_layers(); D.print_layers()

    # Learning-rate schedule driven by a step counter that starts at `start`.
    global_step0 = tf.Variable(start, trainable=False, name='learning_rate_step')
    learning_rate = tf.train.exponential_decay(
        lr_args.learning_rate, global_step0, lr_args.decay_step, lr_args.decay_rate,
        staircase=lr_args.stair)
    add_global0 = global_step0.assign_add(1)

    E_opt = tflib.Optimizer(name='TrainE', learning_rate=learning_rate, **E_opt_args)
    D_opt = tflib.Optimizer(name='TrainD', learning_rate=learning_rate, **D_opt_args)

    # Per-GPU loss terms are summed here and averaged after the loop.
    E_loss_rec = 0.
    E_loss_adv = 0.
    D_loss_real = 0.
    D_loss_fake = 0.
    D_loss_grad = 0.
    for gpu in range(num_gpus):
        print('build graph on gpu %s' % str(gpu))
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            # GPU 0 uses the networks themselves; other GPUs use "_shadow" clones.
            E_gpu = E if gpu == 0 else E.clone(E.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            G_gpu = Gs if gpu == 0 else Gs.clone(Gs.name + '_shadow')
            perceptual_model = PerceptualModel(
                img_size=[E_loss_args.perceptual_img_size, E_loss_args.perceptual_img_size],
                multi_layers=False)
            real_gpu = process_reals(real_split[gpu], mirror_augment, drange_data, drange_net)
            with tf.name_scope('E_loss'), tf.control_dependencies(None):
                E_loss, recon_loss, adv_loss = dnnlib.util.call_func_by_name(
                    E=E_gpu, G=G_gpu, D=D_gpu, perceptual_model=perceptual_model,
                    reals=real_gpu, **E_loss_args)
                E_loss_rec += recon_loss
                E_loss_adv += adv_loss
            with tf.name_scope('D_loss'), tf.control_dependencies(None):
                D_loss, loss_fake, loss_real, loss_gp = dnnlib.util.call_func_by_name(
                    E=E_gpu, G=G_gpu, D=D_gpu, reals=real_gpu, **D_loss_args)
                D_loss_real += loss_real
                D_loss_fake += loss_fake
                D_loss_grad += loss_gp
            with tf.control_dependencies([add_global0]):
                E_opt.register_gradients(E_loss, E_gpu.trainables)
                D_opt.register_gradients(D_loss, D_gpu.trainables)

    E_loss_rec /= num_gpus
    E_loss_adv /= num_gpus
    D_loss_real /= num_gpus
    D_loss_fake /= num_gpus
    D_loss_grad /= num_gpus

    E_train_op = E_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    print('building testing graph...')
    fake_X_val = test(E, Gs, real_test, submit_config)

    sess = tf.get_default_session()

    print('Getting training data...')
    image_batch_train = get_train_data(
        sess, data_dir=dataset_args.data_train, submit_config=submit_config, mode='train')
    image_batch_test = get_train_data(
        sess, data_dir=dataset_args.data_test, submit_config=submit_config, mode='test')

    summary_log = tf.summary.FileWriter(config.GDRIVE_PATH)

    cur_nimg = start * batch_size
    cur_tick = 0
    tick_start_nimg = cur_nimg
    start_time = time.time()

    # Explicitly initialize the step counter. tf.variables_initializer replaces
    # the deprecated tf.initialize_variables and builds the same op.
    init_pascal = tf.variables_initializer([global_step0], name='init_pascal')
    sess.run(init_pascal)

    print('Optimization starts!!!')
    for it in range(start, max_iters):
        # One encoder step and one discriminator step on the same batch.
        batch_images = sess.run(image_batch_train)
        feed_dict_1 = {real_train: batch_images}
        _, recon_, adv_ = sess.run([E_train_op, E_loss_rec, E_loss_adv], feed_dict_1)
        _, d_r_, d_f_, d_g_ = sess.run([D_train_op, D_loss_real, D_loss_fake, D_loss_grad], feed_dict_1)
        cur_nimg += batch_size

        if it % 50 == 0:
            print('Iter: %06d recon_loss: %-6.4f adv_loss: %-6.4f d_r_loss: %-6.4f d_f_loss: %-6.4f d_reg: %-6.4f time:%-12s' % (
                it, recon_, adv_, d_r_, d_f_, d_g_, dnnlib.util.format_time(time.time() - start_time)))
            sys.stdout.flush()
            tflib.autosummary.save_summaries(summary_log, it)

        if it % 500 == 0:
            batch_images_test = sess.run(image_batch_test)
            batch_images_test = misc.adjust_dynamic_range(
                batch_images_test.astype(np.float32), [0, 255], [-1., 1.])
            samples2 = sess.run(fake_X_val, feed_dict={real_test: batch_images_test})
            orin_recon = np.concatenate([batch_images_test, samples2], axis=0)
            orin_recon = adjust_pixel_range(orin_recon)
            orin_recon = fuse_images(orin_recon, row=2, col=submit_config.batch_size_test)
            # Save image results during training: the first row is the original
            # images and the second row is the reconstructed images.
            save_image('%s/iter_%08d.png' % (submit_config.run_dir, cur_nimg), orin_recon)
            # Save the same image to gdrive.
            img_path = os.path.join(config.GDRIVE_PATH, 'images', ('iter_%08d.png' % (cur_nimg)))
            save_image(img_path, orin_recon)

        if cur_nimg >= tick_start_nimg + 65000:
            cur_tick += 1
            tick_start_nimg = cur_nimg
            if cur_tick % network_snapshot_ticks == 0:
                snapshot_name = 'network-snapshot-%08d.pkl' % (cur_nimg)
                pkl = os.path.join(submit_config.run_dir, snapshot_name)
                misc.save_pkl((E, G, D, Gs), pkl)
                # Save the network snapshot to gdrive as well.
                pkl_drive = os.path.join(config.GDRIVE_PATH, 'snapshots', snapshot_name)
                misc.save_pkl((E, G, D, Gs), pkl_drive)

    misc.save_pkl((E, G, D, Gs), os.path.join(submit_config.run_dir, 'network-final.pkl'))
    summary_log.close()
def run(dataset, data_dir, result_dir, config_id, num_gpus, total_kimg, gamma, mirror_augment, metrics, resume_id):
    """Assemble the training configuration and submit a local run.

    Builds the option dictionaries for the training loop, networks,
    optimizers, losses, schedule and snapshot grid, applies the per-config
    overrides selected by ``config_id``, and calls ``dnnlib.submit_run``.

    Training resumes from the pickle located for ``resume_id``.
    ``resume_kimg`` is parsed from the third dash-separated field of the
    pickle's base name.

    Args:
        dataset: Dataset name; used as ``tfrecord_dir`` and in the run description.
        data_dir: Stored as ``train.data_dir``.
        result_dir: Root directory for run directories.
        config_id: One of ``_valid_configs``.
        num_gpus: Must be 1, 2, 4 or 8.
        total_kimg: Stored as ``train.total_kimg``.
        gamma: Overrides ``D_loss.gamma`` when not None.
        mirror_augment: Stored as ``train.mirror_augment``.
        metrics: Keys into ``metric_defaults``.
        resume_id: Run ID or network pkl passed to ``misc.locate_network_pkl``.

    Raises:
        AssertionError: If ``num_gpus`` or ``config_id`` is not an accepted value.
    """
    train     = EasyDict(run_func_name='training.training_loop.training_loop')  # Options for training loop.
    G         = EasyDict(func_name='training.networks_stylegan2.G_main')        # Options for generator network.
    D         = EasyDict(func_name='training.networks_stylegan2.D_stylegan2')   # Options for discriminator network.
    G_opt     = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8)                   # Options for generator optimizer.
    D_opt     = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8)                   # Options for discriminator optimizer.
    G_loss    = EasyDict(func_name='training.loss.G_logistic_ns_pathreg_face')  # Options for generator loss.
    D_loss    = EasyDict(func_name='training.loss.D_logistic_r1')               # Options for discriminator loss.
    sched     = EasyDict()                                                      # Options for TrainingSchedule.
    grid      = EasyDict(size='8k', layout='random')                            # Options for setup_snapshot_image_grid().
    sc        = dnnlib.SubmitConfig()                                           # Options for dnnlib.submit_run().
    tf_config = {'rnd.np_random_seed': 1000}                                    # Options for tflib.init_tf().

    # Resume from an existing snapshot; the kimg count is read from its file name.
    train.resume_pkl = misc.locate_network_pkl(resume_id, None)
    train.resume_kimg = int(os.path.basename(train.resume_pkl).split('.')[0].split('-')[2])

    train.data_dir = data_dir
    train.total_kimg = total_kimg
    train.mirror_augment = mirror_augment
    # Snapshot cadence in ticks. The earlier `= 1,` assignments were removed:
    # their trailing comma produced 1-tuples, and they were overwritten here.
    train.image_snapshot_ticks = train.network_snapshot_ticks = 10
    sched.G_lrate_base = sched.D_lrate_base = 0.002
    sched.minibatch_size_base = 32
    sched.minibatch_gpu_base = 4
    D_loss.gamma = 10
    metrics = [metric_defaults[x] for x in metrics]
    desc = 'stylegan2'

    desc += '-' + dataset
    dataset_args = EasyDict(tfrecord_dir=dataset)

    assert num_gpus in [1, 2, 4, 8]
    sc.num_gpus = num_gpus
    desc += '-%dgpu' % num_gpus

    assert config_id in _valid_configs
    desc += '-' + config_id

    # Configs A-E: Shrink networks to match original StyleGAN.
    if config_id != 'config-f':
        G.fmap_base = D.fmap_base = 8 << 10

    # Config E: Set gamma to 100 and override G & D architecture.
    if config_id.startswith('config-e'):
        D_loss.gamma = 100
        if 'Gorig' in config_id: G.architecture = 'orig'
        if 'Gskip' in config_id: G.architecture = 'skip'  # (default)
        if 'Gresnet' in config_id: G.architecture = 'resnet'
        if 'Dorig' in config_id: D.architecture = 'orig'
        if 'Dskip' in config_id: D.architecture = 'skip'
        if 'Dresnet' in config_id: D.architecture = 'resnet'  # (default)

    # Configs A-D: Enable progressive growing and switch to networks that support it.
    if config_id in ['config-a', 'config-b', 'config-c', 'config-d']:
        sched.lod_initial_resolution = 8
        sched.G_lrate_base = sched.D_lrate_base = 0.001
        sched.G_lrate_dict = sched.D_lrate_dict = {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003}
        sched.minibatch_size_base = 32  # (default)
        sched.minibatch_size_dict = {8: 256, 16: 128, 32: 64, 64: 32}
        sched.minibatch_gpu_base = 4  # (default)
        sched.minibatch_gpu_dict = {8: 32, 16: 16, 32: 8, 64: 4}
        G.synthesis_func = 'G_synthesis_stylegan_revised'
        D.func_name = 'training.networks_stylegan2.D_stylegan'

    # Configs A-C: Disable path length regularization.
    if config_id in ['config-a', 'config-b', 'config-c']:
        G_loss = EasyDict(func_name='training.loss.G_logistic_ns')

    # Configs A-B: Disable lazy regularization.
    if config_id in ['config-a', 'config-b']:
        train.lazy_regularization = False

    # Config A: Switch to original StyleGAN networks.
    if config_id == 'config-a':
        G = EasyDict(func_name='training.networks_stylegan.G_style')
        D = EasyDict(func_name='training.networks_stylegan.D_basic')

    if gamma is not None:
        D_loss.gamma = gamma

    sc.submit_target = dnnlib.SubmitTarget.LOCAL
    sc.local.do_not_copy_source_files = True
    kwargs = EasyDict(train)
    kwargs.update(G_args=G, D_args=D, G_opt_args=G_opt, D_opt_args=D_opt, G_loss_args=G_loss, D_loss_args=D_loss)
    kwargs.update(dataset_args=dataset_args, sched_args=sched, grid_args=grid, metric_arg_list=metrics, tf_config=tf_config)
    kwargs.submit_config = copy.deepcopy(sc)
    kwargs.submit_config.run_dir_root = result_dir
    kwargs.submit_config.run_desc = desc
    dnnlib.submit_run(**kwargs)
def training_loop(
        submit_config,
        G_args                  = {},       # Options for the generator network.
        D_args                  = {},       # Options for the discriminator network.
        G_opt_args              = {},       # Options for the generator optimizer.
        D_opt_args              = {},       # Options for the discriminator optimizer.
        G_loss_args             = {},       # Options for the generator loss.
        D_loss_args             = {},       # Options for the discriminator loss.
        dataset_args            = {},       # Options for the dataset.
        sched_args              = {},       # Options for the training schedule.
        grid_args               = {},       # Options for setup_snapshot_image_grid().
        metric_arg_list         = [],       # Options for the metrics.
        tf_config               = {},       # Options for tflib.init_tf().
        G_smoothing_kimg        = 10.0,     # Half-life of the running average of generator weights.
        D_repeats               = 1,        # How many times the discriminator is trained per G iteration.
        minibatch_repeats       = 4,        # Number of minibatches to run before adjusting training parameters.
        reset_opt_for_new_lod   = True,     # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
        total_kimg              = 15000,    # Total length of the training, measured in thousands of real images.
        mirror_augment          = False,    # Enable mirror augmentation?
        drange_net              = [-1,1],   # Dynamic range used when feeding image data to the networks.
        image_snapshot_ticks    = 1,        # How often to export image snapshots?
        network_snapshot_ticks  = 10,       # How often to export network snapshots?
        save_tf_graph           = False,    # Include full TensorFlow computation graph in the tfevents file?
        save_weight_histograms  = False,    # Include weight histograms in the tfevents file?
        resume_run_id           = None,     # Run ID or network pkl to resume training from, None = start from scratch.
        resume_snapshot         = None,     # Snapshot index to resume training from, None = autodetect.
        resume_kimg             = 0.0,      # Assumed training progress at the beginning. Affects reporting and training schedule.
        resume_time             = 0.0):     # Assumed wallclock time at the beginning. Affects reporting.
    """Main G/D training loop.

    Loads or constructs (G, D, Gs), builds the multi-GPU training graph, then
    alternates discriminator and generator updates until ``total_kimg``
    thousand images have been processed or the run context asks to stop.
    Once per tick it reports progress, optionally exports an image grid and a
    network snapshot (which also triggers the metrics), and updates
    summaries. ``network-final.pkl`` is written at the end.
    """
    # Initialize dnnlib and TensorFlow.
    ctx = dnnlib.RunContext(submit_config, train)
    tflib.init_tf(tf_config)

    num_gpus = submit_config.num_gpus
    run_dir = submit_config.run_dir

    # Load training set.
    training_set = dataset.load_dataset(data_dir=config.data_dir, verbose=True, **dataset_args)

    # Construct networks.
    with tf.device('/gpu:0'):
        if resume_run_id is None:
            print('Constructing networks...')
            G = tflib.Network('G', num_channels=training_set.shape[0], resolution=training_set.shape[1],
                              label_size=training_set.label_size, **G_args)
            D = tflib.Network('D', num_channels=training_set.shape[0], resolution=training_set.shape[1],
                              label_size=training_set.label_size, **D_args)
            Gs = G.clone('Gs')
        else:
            network_pkl = misc.locate_network_pkl(resume_run_id, resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            G, D, Gs = misc.load_pkl(network_pkl)
    G.print_layers(); D.print_layers()

    # Build the computation graph and the optimizers.
    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        # Placeholders are fed with concrete values at run time.
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        minibatch_split = minibatch_in // num_gpus
        if G_smoothing_kimg > 0.0:
            Gs_beta = 0.5 ** tf.div(tf.cast(minibatch_in, tf.float32), G_smoothing_kimg * 1000.0)
        else:
            Gs_beta = 0.0

    G_opt = tflib.Optimizer(name='TrainG', learning_rate=lrate_in, **G_opt_args)
    D_opt = tflib.Optimizer(name='TrainD', learning_rate=lrate_in, **D_opt_args)
    for gpu in range(num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            # GPU 0 uses the networks themselves; other GPUs use "_shadow" clones.
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            lod_assign_ops = [tf.assign(G_gpu.find_var('lod'), lod_in),
                              tf.assign(D_gpu.find_var('lod'), lod_in)]
            reals, labels = training_set.get_minibatch_tf()
            reals = process_reals(reals, lod_in, mirror_augment, training_set.dynamic_range, drange_net)
            with tf.name_scope('G_loss'), tf.control_dependencies(lod_assign_ops):
                G_loss = dnnlib.util.call_func_by_name(
                    G=G_gpu, D=D_gpu, opt=G_opt, training_set=training_set,
                    minibatch_size=minibatch_split, **G_loss_args)
            with tf.name_scope('D_loss'), tf.control_dependencies(lod_assign_ops):
                D_loss = dnnlib.util.call_func_by_name(
                    G=G_gpu, D=D_gpu, opt=D_opt, training_set=training_set,
                    minibatch_size=minibatch_split, reals=reals, labels=labels, **D_loss_args)
            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)

    # Set up the snapshot image grid.
    print('Setting up snapshot image grid...')
    grid_size, grid_reals, grid_labels, grid_latents = misc.setup_snapshot_image_grid(G, training_set, **grid_args)
    sched = training_schedule(cur_nimg=total_kimg*1000, training_set=training_set, num_gpus=num_gpus, **sched_args)
    grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch//num_gpus)

    # Set up the run directory.
    print('Setting up run dir...')
    misc.save_image_grid(grid_reals, os.path.join(run_dir, 'reals.png'),
                         drange=training_set.dynamic_range, grid_size=grid_size)
    misc.save_image_grid(grid_fakes, os.path.join(run_dir, 'fakes%06d.png' % resume_kimg),
                         drange=drange_net, grid_size=grid_size)
    summary_log = tf.summary.FileWriter(run_dir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms(); D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    # Train.
    print('Training...\n')
    ctx.update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = ctx.get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    while cur_nimg < total_kimg * 1000:
        if ctx.should_stop():
            break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg, training_set=training_set, num_gpus=num_gpus, **sched_args)
        training_set.configure(sched.minibatch // num_gpus, sched.lod)
        if reset_opt_for_new_lod:
            floor_changed = np.floor(sched.lod) != np.floor(prev_lod)
            ceil_changed = np.ceil(sched.lod) != np.ceil(prev_lod)
            if floor_changed or ceil_changed:
                G_opt.reset_optimizer_state(); D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        for _mb_repeat in range(minibatch_repeats):
            for _D_repeat in range(D_repeats):
                tflib.run([D_train_op, Gs_update_op],
                          {lod_in: sched.lod, lrate_in: sched.D_lrate, minibatch_in: sched.minibatch})
                cur_nimg += sched.minibatch
            tflib.run([G_train_op],
                      {lod_in: sched.lod, lrate_in: sched.G_lrate, minibatch_in: sched.minibatch})

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = ctx.get_time_since_last_update()
            total_time = ctx.get_time_since_start() + resume_time

            # Report progress. The autosummary calls stay in this exact order.
            print('tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %-4.1f' % (
                autosummary('Progress/tick', cur_tick),
                autosummary('Progress/kimg', cur_nimg / 1000.0),
                autosummary('Progress/lod', sched.lod),
                autosummary('Progress/minibatch', sched.minibatch),
                dnnlib.util.format_time(autosummary('Timing/total_sec', total_time)),
                autosummary('Timing/sec_per_tick', tick_time),
                autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                autosummary('Timing/maintenance_sec', maintenance_time),
                autosummary('Resources/peak_gpu_mem_gb', peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if cur_tick % image_snapshot_ticks == 0 or done:
                grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True,
                                    minibatch_size=sched.minibatch//num_gpus)
                misc.save_image_grid(grid_fakes, os.path.join(run_dir, 'fakes%06d.png' % (cur_nimg // 1000)),
                                     drange=drange_net, grid_size=grid_size)
            if cur_tick % network_snapshot_ticks == 0 or done or cur_tick == 1:
                pkl = os.path.join(run_dir, 'network-snapshot-%06d.pkl' % (cur_nimg // 1000))
                misc.save_pkl((G, D, Gs), pkl)
                metrics.run(pkl, run_dir=run_dir, num_gpus=num_gpus, tf_config=tf_config)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            ctx.update('%.2f' % sched.lod, cur_epoch=cur_nimg // 1000, max_epoch=total_kimg)
            maintenance_time = ctx.get_last_update_interval() - tick_time

    # Write final results.
    misc.save_pkl((G, D, Gs), os.path.join(run_dir, 'network-final.pkl'))
    summary_log.close()

    ctx.close()
def training_loop(
        submit_config,
        Encoder_args={},
        E_opt_args={},
        D_opt_args={},
        E_loss_args=EasyDict(),
        D_loss_args={},
        lr_args=EasyDict(),
        tf_config={},
        dataset_args=EasyDict(),
        decoder_pkl=EasyDict(),
        inversion_pkl=EasyDict(),
        drange_data=[0, 255],
        drange_net=[-1, 1],  # Dynamic range used when feeding image data to the networks.
        mirror_augment=False,
        resume_run_id=config.ENCODER_PICKLE_DIR,  # Run ID or network pkl to resume training from, None = start from scratch.
        resume_snapshot=None,  # Snapshot index to resume training from, None = autodetect.
        image_snapshot_ticks=1,  # How often to export image snapshots?
        network_snapshot_ticks=4,  # How often to export network snapshots?
        max_iters=150000):
    """Train encoder ``E`` and discriminator ``D`` on portrait/landmark pairs.

    One loss tower is built per GPU. Each iteration then runs one encoder
    update followed by one discriminator update on the same feed. Every 50
    iterations the losses are printed and summaries are saved; every 500
    iterations a debug image grid is written to the run dir and to the
    Google-Drive path; every ``network_snapshot_ticks`` ticks (one tick is
    65000 images) the networks are pickled to both locations.

    Args:
        submit_config: Run configuration. This function reads ``batch_size``,
            ``batch_size_test``, ``image_size``, ``num_gpus`` and ``run_dir``.
        Encoder_args: Extra keyword arguments for the encoder network.
        E_opt_args: Extra keyword arguments for the encoder optimizer.
        D_opt_args: Extra keyword arguments for the discriminator optimizer.
        E_loss_args: Keyword arguments forwarded to the encoder loss via
            ``dnnlib.util.call_func_by_name``; ``perceptual_img_size`` is
            additionally read here.
        D_loss_args: Keyword arguments forwarded to the discriminator loss.
        lr_args: Provides ``learning_rate``, ``decay_step``, ``decay_rate``
            and ``stair`` for ``tf.train.exponential_decay``.
        tf_config: Options for ``tflib.init_tf()``.
        dataset_args: Provides the ``data_train`` and ``data_test`` dirs.
        decoder_pkl: ``decoder_pkl.decoder_pkl`` is the pickle the generator
            is loaded from when training starts from scratch.
        inversion_pkl: ``inversion_pkl.inversion_pkl`` is the pickle whose
            first entry is used as the ``Inv`` network.
        drange_data: Input dynamic range handed to ``process_reals``.
        drange_net: Dynamic range used when feeding image data to the
            networks.
        mirror_augment: Forwarded to ``process_reals``.
        resume_run_id: Run ID or network pkl to resume training from,
            ``None`` = start from scratch.
        resume_snapshot: Snapshot index to resume from, ``None`` = autodetect.
        image_snapshot_ticks: Not read by this function (debug images are
            written every 500 iterations instead).
        network_snapshot_ticks: Save a network snapshot every this many ticks.
        max_iters: Iteration index at which training stops.
    """
    tflib.init_tf(tf_config)

    # Every image placeholder is [batch, 3, image_size, image_size]; only the
    # batch dimension differs between the training and the test inputs.
    train_shape = [submit_config.batch_size, 3,
                   submit_config.image_size, submit_config.image_size]
    test_shape = [submit_config.batch_size_test, 3,
                  submit_config.image_size, submit_config.image_size]

    with tf.name_scope('input'):
        placeholder_real_portraits_train = tf.placeholder(
            tf.float32, train_shape, name='placeholder_real_portraits_train')
        placeholder_real_landmarks_train = tf.placeholder(
            tf.float32, train_shape, name='placeholder_real_landmarks_train')
        placeholder_real_shuffled_train = tf.placeholder(
            tf.float32, train_shape, name='placeholder_real_shuffled_train')
        placeholder_landmarks_shuffled_train = tf.placeholder(
            tf.float32, train_shape, name='placeholder_landmarks_shuffled_train')

        placeholder_real_portraits_test = tf.placeholder(
            tf.float32, test_shape, name='placeholder_real_portraits_test')
        placeholder_real_landmarks_test = tf.placeholder(
            tf.float32, test_shape, name='placeholder_real_landmarks_test')
        placeholder_real_shuffled_test = tf.placeholder(
            tf.float32, test_shape, name='placeholder_real_shuffled_test')
        # NOTE(review): this placeholder is created but never referenced again.
        placeholder_real_landmarks_shuffled_test = tf.placeholder(
            tf.float32, test_shape, name='placeholder_real_landmarks_shuffled_test')

        # Split every training input along the batch axis, one slice per GPU.
        real_split_landmarks = tf.split(
            placeholder_real_landmarks_train, num_or_size_splits=submit_config.num_gpus, axis=0)
        real_split_portraits = tf.split(
            placeholder_real_portraits_train, num_or_size_splits=submit_config.num_gpus, axis=0)
        real_split_shuffled = tf.split(
            placeholder_real_shuffled_train, num_or_size_splits=submit_config.num_gpus, axis=0)
        real_split_lm_shuffled = tf.split(
            placeholder_landmarks_shuffled_train, num_or_size_splits=submit_config.num_gpus, axis=0)

        placeholder_training_flag = tf.placeholder(tf.string, name='placeholder_training_flag')

    with tf.device('/gpu:0'):
        if resume_run_id is not None:
            network_pkl = misc.locate_network_pkl(resume_run_id, resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            E, G, D, Gs = misc.load_pkl(network_pkl)
            # Snapshot file names end in the image count; convert it back to
            # the iteration index at which training resumes.
            start = int(network_pkl.split('-')[-1].split('.')[0]) // submit_config.batch_size
            print('Start: ', start)
        else:
            print('Constructing networks...')
            # The discriminator stored with the decoder is deliberately dropped.
            G, _, Gs = misc.load_pkl(decoder_pkl.decoder_pkl)
            num_layers = Gs.components.synthesis.input_shape[1]
            # A fresh discriminator is constructed instead.
            # NOTE(review): resolution is hard-coded to 128 while the encoder
            # uses submit_config.image_size -- confirm the two always agree.
            D = tflib.Network('D',
                              num_channels=3, resolution=128, label_size=0,
                              func_name="training.networks_stylegan.D_basic")
            print("Created new Discriminator!")
            E = tflib.Network('E', size=submit_config.image_size, filter=64, filter_max=1024,
                              num_layers=num_layers, phase=True, **Encoder_args)
            start = 0

        # Needed on both paths: every GPU tower below receives `Inv`.
        Inv, _, _, _ = misc.load_pkl(inversion_pkl.inversion_pkl)

    E.print_layers()
    Gs.print_layers()
    D.print_layers()

    # Learning-rate schedule driven by a dedicated step counter.
    global_step0 = tf.Variable(start, trainable=False, name='learning_rate_step')
    learning_rate = tf.train.exponential_decay(lr_args.learning_rate, global_step0, lr_args.decay_step,
                                               lr_args.decay_rate, staircase=lr_args.stair)
    add_global0 = global_step0.assign_add(1)

    E_opt = tflib.Optimizer(name='TrainE', learning_rate=learning_rate, **E_opt_args)
    D_opt = tflib.Optimizer(name='TrainD', learning_rate=learning_rate, **D_opt_args)

    # Loss terms are summed over the GPU towers and averaged afterwards.
    E_loss_rec = 0.
    E_loss_adv = 0.
    D_loss_real = 0.
    D_loss_fake = 0.
    D_loss_grad = 0.
    for gpu in range(submit_config.num_gpus):
        print('build graph on gpu %s' % str(gpu))
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            # GPU 0 uses the original networks, every other GPU a clone.
            E_gpu = E if gpu == 0 else E.clone(E.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            G_gpu = Gs if gpu == 0 else Gs.clone(Gs.name + '_shadow')
            Inv_gpu = Inv if gpu == 0 else Inv.clone(Inv.name + '_shadow')

            perceptual_model = PerceptualModel(
                img_size=[E_loss_args.perceptual_img_size, E_loss_args.perceptual_img_size],
                multi_layers=False)

            real_portraits_gpu = process_reals(
                real_split_portraits[gpu], mirror_augment, drange_data, drange_net)
            shuffled_portraits_gpu = process_reals(
                real_split_shuffled[gpu], mirror_augment, drange_data, drange_net)
            real_landmarks_gpu = process_reals(
                real_split_landmarks[gpu], mirror_augment, drange_data, drange_net)
            shuffled_landmarks_gpu = process_reals(
                real_split_lm_shuffled[gpu], mirror_augment, drange_data, drange_net)

            with tf.name_scope('E_loss'), tf.control_dependencies(None):
                E_loss, recon_loss, adv_loss = dnnlib.util.call_func_by_name(
                    E=E_gpu, G=G_gpu, D=D_gpu, Inv=Inv_gpu,
                    perceptual_model=perceptual_model,
                    real_portraits=real_portraits_gpu,
                    shuffled_portraits=shuffled_portraits_gpu,
                    real_landmarks=real_landmarks_gpu,
                    shuffled_landmarks=shuffled_landmarks_gpu,
                    training_flag=placeholder_training_flag,
                    **E_loss_args)
                E_loss_rec += recon_loss
                E_loss_adv += adv_loss

            with tf.name_scope('D_loss'), tf.control_dependencies(None):
                D_loss, loss_fake, loss_real, loss_gp = dnnlib.util.call_func_by_name(
                    E=E_gpu, G=G_gpu, D=D_gpu, Inv=Inv_gpu,
                    real_portraits=real_portraits_gpu,
                    shuffled_portraits=shuffled_portraits_gpu,
                    real_landmarks=real_landmarks_gpu,
                    training_flag=placeholder_training_flag,
                    **D_loss_args)
                D_loss_real += loss_real
                D_loss_fake += loss_fake
                D_loss_grad += loss_gp

            # Registering the gradients under this dependency ties the
            # step-counter increment to the gradient ops.
            with tf.control_dependencies([add_global0]):
                E_opt.register_gradients(E_loss, E_gpu.trainables)
                D_opt.register_gradients(D_loss, D_gpu.trainables)

    E_loss_rec /= submit_config.num_gpus
    E_loss_adv /= submit_config.num_gpus
    D_loss_real /= submit_config.num_gpus
    D_loss_fake /= submit_config.num_gpus
    D_loss_grad /= submit_config.num_gpus

    E_train_op = E_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    print('building testing graph...')
    fake_X_val = test(E, Gs, Inv, placeholder_real_portraits_test,
                      placeholder_real_landmarks_test, placeholder_real_shuffled_test, submit_config)
    inv_X_val = test_inversion(E, Gs, Inv, placeholder_real_portraits_test,
                               placeholder_real_landmarks_test, placeholder_real_shuffled_test,
                               submit_config)

    sess = tf.get_default_session()

    print('Getting training data...')
    # Each fetched batch stacks two records per sample along axis 1
    # (index 0 = portrait, index 1 = landmark image).
    stack_batch_train = get_train_data(
        sess, data_dir=dataset_args.data_train, submit_config=submit_config, mode='train')
    stack_batch_test = get_train_data(
        sess, data_dir=dataset_args.data_test, submit_config=submit_config, mode='test')
    stack_batch_train_secondary = get_train_data(
        sess, data_dir=dataset_args.data_train, submit_config=submit_config, mode='train_secondary')
    stack_batch_test_secondary = get_train_data(
        sess, data_dir=dataset_args.data_test, submit_config=submit_config, mode='test_secondary')

    summary_log = tf.summary.FileWriter(config.getGdrivePath())

    cur_nimg = start * submit_config.batch_size
    cur_tick = 0
    tick_start_nimg = cur_nimg
    start_time = time.time()

    # Initialize the step counter explicitly. tf.variables_initializer is the
    # supported replacement for the deprecated tf.initialize_variables.
    init_fix = tf.variables_initializer([global_step0], name='init_fix')
    sess.run(init_fix)

    def run_test_graph(op, portraits, landmarks):
        # Evaluate a test-graph output for one portrait/landmark batch.
        return sess.run(op, feed_dict={placeholder_real_portraits_test: portraits,
                                       placeholder_real_landmarks_test: landmarks})

    def to_net_range(batch):
        # Test batches are mapped from [0, 255] to [-1, 1] before being fed.
        return misc.adjust_dynamic_range(batch.astype(np.float32), [0, 255], [-1., 1.])

    print('Optimization starts!!!')

    for it in range(start, max_iters):
        batch_stacks = sess.run(stack_batch_train)
        batch_portraits = batch_stacks[:, 0, :, :, :]
        batch_landmarks = batch_stacks[:, 1, :, :, :]

        batch_stacks_secondary = sess.run(stack_batch_train_secondary)
        batch_shuffled = batch_stacks_secondary[:, 0, :, :, :]
        batch_lm_shuffled = batch_stacks_secondary[:, 1, :, :, :]

        training_flag = "pose"
        feed_dict_1 = {placeholder_real_portraits_train: batch_portraits,
                       placeholder_real_landmarks_train: batch_landmarks,
                       placeholder_real_shuffled_train: batch_shuffled,
                       placeholder_landmarks_shuffled_train: batch_lm_shuffled,
                       placeholder_training_flag: training_flag}

        # One encoder step, then one discriminator step, on the same batch.
        _, recon_, adv_ = sess.run([E_train_op, E_loss_rec, E_loss_adv], feed_dict_1)
        _, d_r_, d_f_, d_g_ = sess.run(
            [D_train_op, D_loss_real, D_loss_fake, D_loss_grad], feed_dict_1)

        cur_nimg += submit_config.batch_size

        if it % 50 == 0:
            print('Iter: %06d recon_loss: %-6.4f adv_loss: %-6.4f d_r_loss: %-6.4f d_f_loss: %-6.4f d_reg: %-6.4f time:%-12s' % (
                it, recon_, adv_, d_r_, d_f_, d_g_, dnnlib.util.format_time(time.time() - start_time)))
            sys.stdout.flush()
            tflib.autosummary.save_summaries(summary_log, it)

        if it % 500 == 0:
            batch_stacks_test = sess.run(stack_batch_test)
            batch_portraits_test = batch_stacks_test[:, 0, :, :, :]
            batch_landmarks_test = batch_stacks_test[:, 1, :, :, :]

            batch_stacks_test_secondary = sess.run(stack_batch_test_secondary)
            batch_shuffled_test = batch_stacks_test_secondary[:, 0, :, :, :]
            batch_shuffled_lm_test = batch_stacks_test_secondary[:, 1, :, :, :]

            batch_portraits_test = to_net_range(batch_portraits_test)
            batch_landmarks_test = to_net_range(batch_landmarks_test)
            # NOTE(review): batch_shuffled_test is prepared but never fed.
            batch_shuffled_test = to_net_range(batch_shuffled_test)
            batch_shuffled_lm_test = to_net_range(batch_shuffled_lm_test)

            # 1) input portrait + target landmarks -> manipulated image
            samples_manipulated = run_test_graph(
                fake_X_val, batch_portraits_test, batch_shuffled_lm_test)
            # 2) manipulated image + original landmarks -> cycle reconstruction
            samples_reconstructed = run_test_graph(
                fake_X_val, samples_manipulated, batch_landmarks_test)
            # 3) input portrait + its own landmarks -> direct reconstruction
            samples_direct_rec = run_test_graph(
                fake_X_val, batch_portraits_test, batch_landmarks_test)
            # 4) output of the inversion test graph for the same inputs
            portraits_inverted = run_test_graph(
                inv_X_val, batch_portraits_test, batch_landmarks_test)

            # One grid row per entry. The row count is taken from this list so
            # the grid layout always matches the number of groups (the
            # previous hard-coded row=6 did not cover the seventh group).
            debug_rows = [
                batch_landmarks_test,    # original landmarks
                batch_portraits_test,    # original portraits
                samples_direct_rec,      # direct reconstruction
                batch_shuffled_lm_test,  # shuffled landmarks
                samples_manipulated,     # manipulated images
                samples_reconstructed,   # cycle-reconstructed images
                portraits_inverted,      # inversion results
            ]
            debug_img = np.concatenate(debug_rows, axis=0)
            debug_img = adjust_pixel_range(debug_img)
            debug_img = fuse_images(debug_img, row=len(debug_rows),
                                    col=submit_config.batch_size_test)

            # Save the grid to the run dir ...
            save_image('%s/iter_%08d.png' % (submit_config.run_dir, cur_nimg), debug_img)
            # ... and to the Google-Drive path.
            img_path = os.path.join(config.getGdrivePath(), 'images', ('iter_%08d.png' % (cur_nimg)))
            save_image(img_path, debug_img)

        # A tick ends once 65000 further images have been processed.
        if cur_nimg >= tick_start_nimg + 65000:
            cur_tick += 1
            tick_start_nimg = cur_nimg

            if cur_tick % network_snapshot_ticks == 0:
                pkl = os.path.join(submit_config.run_dir, 'network-snapshot-%08d.pkl' % (cur_nimg))
                misc.save_pkl((E, G, D, Gs), pkl)
                # Mirror the snapshot to the Google-Drive path.
                pkl_drive = os.path.join(config.getGdrivePath(), 'snapshots',
                                         'network-snapshot-%08d.pkl' % (cur_nimg))
                misc.save_pkl((E, G, D, Gs), pkl_drive)

    misc.save_pkl((E, G, D, Gs), os.path.join(submit_config.run_dir, 'network-final.pkl'))
    summary_log.close()
def training_loop(
        submit_config,
        HP_args={},  # Options for the Hessian Penalty.
        G_args={},  # Options for generator network.
        D_args={},  # Options for discriminator network.
        G_opt_args={},  # Options for generator optimizer.
        D_opt_args={},  # Options for discriminator optimizer.
        G_loss_args={},  # Options for generator loss.
        D_loss_args={},  # Options for discriminator loss.
        dataset_args={},  # Options for dataset.load_dataset().
        sched_args={},  # Options for train.TrainingSchedule.
        grid_args={},  # Options for train.setup_snapshot_image_grid().
        metric_arg_list=[],  # Options for MetricGroup.
        tf_config={},  # Options for tflib.init_tf().
        G_smoothing_kimg=10.0,  # Half-life of the running average of generator weights.
        D_repeats=1,  # How many times the discriminator is trained per G iteration.
        minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
        reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
        total_kimg=15000,  # Total length of the training, measured in thousands of real images.
        mirror_augment=False,  # Enable mirror augment?
        drange_net=[-1, 1],  # Dynamic range used when feeding image data to the networks.
        image_snapshot_ticks=1,  # How often to export image snapshots?
        interp_snapshot_ticks=20,  # How often to generate interpolation visualizations in TensorBoard?
        network_snapshot_ticks=5,  # How often to export network snapshots?
        network_metric_ticks=5,  # How often to evaluate network snapshots on specified metrics?
        save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
        save_weight_histograms=False,  # Include weight histograms in the tfevents file?
        resume_run_id=None,  # Run ID or network pkl to resume training from, None = start from scratch.
        resume_snapshot=None,  # Snapshot index to resume training from, None = autodetect.
        resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
        resume_time=0.0):  # Assumed wallclock time at the beginning. Affects reporting.
    """Run progressive GAN training with the Hessian Penalty / InfoGAN terms.

    The function builds (or resumes) ``G``, ``D`` and the averaged generator
    ``Gs``, creates one loss tower per GPU, and then alternates
    ``D_repeats`` discriminator steps with one generator step until
    ``total_kimg`` thousand images have been shown or the run context asks
    to stop. Once per tick it reports progress, writes image grids,
    interpolation visualizations and network snapshots, and evaluates the
    configured metrics on each freshly written snapshot.

    Per-parameter meaning is given by the inline comments in the signature.
    Attribute access is used on ``HP_args``, ``G_args``, ``D_args`` and the
    deep copy of ``dataset_args``, so callers must pass attribute-style
    dictionaries; the plain-``dict`` defaults do not support that.
    """
    # Initialize dnnlib and TensorFlow.
    ctx = dnnlib.RunContext(submit_config, train)
    tflib.init_tf(tf_config)

    # Load training set.
    training_set = dataset.load_dataset(data_dir=config.data_dir, verbose=True, **dataset_args)

    # Create a copy of dataset_args for running the metrics:
    metrics_dataset_args = deepcopy(dataset_args)
    metrics_dataset_args.shuffle_mb = 0
    print('Saving interp videos every %s ticks' % interp_snapshot_ticks)
    print('Saving network snapshot every %s ticks' % network_snapshot_ticks)

    # Construct networks.
    with tf.device('/gpu:0'):
        if resume_run_id is not None:
            network_pkl = misc.locate_network_pkl(resume_run_id, resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            G, D, Gs = misc.load_pkl(network_pkl)
        else:
            print('Constructing networks...')
            G = tflib.Network('G',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              **G_args)
            D = tflib.Network('D',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              **D_args)
            Gs = G.clone('Gs')

    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        minibatch_split = minibatch_in // submit_config.num_gpus
        Gs_beta = 0.5**tf.div(tf.cast(minibatch_in, tf.float32),
                              G_smoothing_kimg * 1000.0) if G_smoothing_kimg > 0.0 else 0.0
        # The loss weighting of the Hessian Penalty can be dynamic over
        # training, so we need to use a placeholder:
        hessian_weight = tf.placeholder(tf.float32, name='hessian_weight', shape=[])

    G_opt = tflib.Optimizer(name='TrainG', learning_rate=lrate_in, **G_opt_args)
    D_opt = tflib.Optimizer(name='TrainD', learning_rate=lrate_in, **D_opt_args)
    # Tensors fetched together with the generator step so that the Hessian
    # Penalty / InfoGAN values can be registered in TensorBoard.
    reg_ops = []
    for gpu in range(submit_config.num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            # GPU 0 uses the original networks, every other GPU a clone.
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            lod_assign_ops = [
                tf.assign(G_gpu.find_var('lod'), lod_in),
                tf.assign(D_gpu.find_var('lod'), lod_in),
            ]
            reals, labels = training_set.get_minibatch_tf()
            reals = process_reals(reals, lod_in, mirror_augment,
                                  training_set.dynamic_range, drange_net)

            with tf.name_scope('G_loss'), tf.control_dependencies(lod_assign_ops):
                G_loss, G_hessian_penalty = dnnlib.util.call_func_by_name(
                    G=G_gpu, D=D_gpu, opt=G_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    hp_lambda=hessian_weight,
                    HP_args=HP_args,
                    gpu_ix=gpu,
                    lod_in=lod_in,
                    max_lod=training_set.resolution_log2,
                    **G_loss_args)
                if HP_args.hp_lambda > 0:
                    reg_ops += [G_hessian_penalty]

            with tf.name_scope('D_loss'), tf.control_dependencies(lod_assign_ops):
                D_loss, mutual_info = dnnlib.util.call_func_by_name(
                    G=G_gpu, D=D_gpu, opt=D_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    reals=reals, labels=labels,
                    gpu_ix=gpu,
                    infogan_nz=D_args.infogan_nz,
                    **D_loss_args)
                if G_args.infogan_lambda > 0 or D_args.infogan_lambda > 0:
                    reg_ops += [mutual_info]

            # Note, even though we are adding mutual_info loss here, the only
            # time the loss is non-zero is when infogan_lambda > 0 (in Hessian
            # Penalty experiments, we always set it to zero):
            G_opt.register_gradients(
                G_loss + G_args.infogan_lambda * mutual_info, G_gpu.trainables)
            D_opt.register_gradients(
                tf.reduce_mean(D_loss) + D_args.infogan_lambda * mutual_info, D_gpu.trainables)

    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            # Fall back to a constant so the progress report still works.
            peak_gpu_mem_op = tf.constant(0)

    print('Setting up snapshot image grid...')
    grid_size, grid_reals, grid_labels, grid_latents = misc.setup_snapshot_image_grid(
        G, training_set, **grid_args)
    # The schedule is evaluated at the end of training here to pick the
    # minibatch size used for the initial fake-image grid.
    sched = training_schedule(cur_nimg=total_kimg * 1000,
                              training_set=training_set,
                              num_gpus=submit_config.num_gpus,
                              **sched_args)
    grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True,
                        minibatch_size=sched.minibatch // submit_config.num_gpus)

    print('Setting up snapshot interpolation...')
    nz = G.input_shapes[0][1]
    interp_steps = 24  # Number of frames in the visualization
    interp_batch_size = 8  # Number of gifs per row of visualization
    interp_z = vis_tools.sample_interp_zs(nz, interp_batch_size, interp_steps)
    interp_labels = np.zeros(
        [interp_steps * interp_batch_size * nz, training_set.label_size],
        dtype=training_set.label_dtype)

    print('Setting up run dir...')
    reals_path = os.path.join(submit_config.run_dir, 'reals.png')
    initial_fakes_path = os.path.join(submit_config.run_dir, 'fakes%06d.png' % resume_kimg)
    misc.save_image_grid(grid_reals, reals_path,
                         drange=training_set.dynamic_range, grid_size=grid_size)
    misc.save_image_grid(grid_fakes, initial_fakes_path,
                         drange=drange_net, grid_size=grid_size)
    summary_log = tf.summary.FileWriter(submit_config.run_dir)
    summary_log.add_summary(build_image_summary(reals_path, 'samples/real'), 0)
    summary_log.add_summary(build_image_summary(initial_fakes_path, 'samples/Gs'), resume_kimg)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    if interp_snapshot_ticks != -1 and interp_snapshot_ticks < 9999:
        print('Generating initial interpolations...')
        vis_tools.make_and_save_interpolation_gifs(
            Gs, interp_z, interp_labels,
            minibatch_size=sched.minibatch // submit_config.num_gpus,
            interp_steps=interp_steps,
            interp_batch_size=interp_batch_size,
            cur_kimg=resume_kimg,
            summary_log=summary_log)

    print('Training...\n')
    ctx.update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = ctx.get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    num_G_grad_steps = 0
    while cur_nimg < total_kimg * 1000:
        if ctx.should_stop():
            break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg,
                                  training_set=training_set,
                                  num_gpus=submit_config.num_gpus,
                                  **sched_args)
        training_set.configure(sched.minibatch // submit_config.num_gpus, sched.lod)
        if reset_opt_for_new_lod:
            # The lod crossed an integer boundary since the last iteration.
            if (np.floor(sched.lod) != np.floor(prev_lod)
                    or np.ceil(sched.lod) != np.ceil(prev_lod)):
                G_opt.reset_optimizer_state()
                D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        for _mb_repeat in range(minibatch_repeats):
            for _D_repeat in range(D_repeats):
                tflib.run([D_train_op, Gs_update_op], {
                    lod_in: sched.lod,
                    lrate_in: sched.D_lrate,
                    minibatch_in: sched.minibatch,
                })
                cur_nimg += sched.minibatch
            cur_hessian_weight = get_current_hessian_penalty_loss_weight(
                HP_args.hp_lambda, HP_args.hp_start_nimg, cur_nimg, HP_args.warmup_nimg)
            tflib.run([G_train_op] + reg_ops, {
                lod_in: sched.lod,
                lrate_in: sched.G_lrate,
                minibatch_in: sched.minibatch,
                hessian_weight: cur_hessian_weight,
            })
            num_G_grad_steps += 1

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = ctx.get_time_since_last_update()
            total_time = ctx.get_time_since_start() + resume_time

            # Report progress.
            print('tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d hessian_weight %s time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %-4.1f' % (
                autosummary('Progress/tick', cur_tick),
                autosummary('Progress/kimg', cur_nimg / 1000.0),
                autosummary('Progress/lod', sched.lod),
                autosummary('Progress/minibatch', sched.minibatch),
                autosummary('Progress/hessian_weight', cur_hessian_weight),
                dnnlib.util.format_time(autosummary('Timing/total_sec', total_time)),
                autosummary('Timing/sec_per_tick', tick_time),
                autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                autosummary('Timing/maintenance_sec', maintenance_time),
                autosummary('Resources/peak_gpu_mem_gb', peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))
            autosummary('Progress/G_grad_steps', num_G_grad_steps)

            per_gpu_minibatch = sched.minibatch // submit_config.num_gpus

            # Save snapshots and fake image samples (for both Gs and G):
            if cur_tick % image_snapshot_ticks == 0 or done:
                # Named so that the builtin iter() is not shadowed.
                snapshot_kimg = cur_nimg // 1000
                grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True,
                                    minibatch_size=per_gpu_minibatch)
                grid_fakes_inst = G.run(grid_latents, grid_labels, is_validation=True,
                                        minibatch_size=per_gpu_minibatch)
                fake_path = os.path.join(submit_config.run_dir, 'fakes%06d.png' % snapshot_kimg)
                ifake_path = os.path.join(submit_config.run_dir, 'ifakes%06d.png' % snapshot_kimg)
                misc.save_image_grid(grid_fakes, fake_path,
                                     drange=drange_net, grid_size=grid_size)
                misc.save_image_grid(grid_fakes_inst, ifake_path,
                                     drange=drange_net, grid_size=grid_size)
                summary_log.add_summary(
                    build_image_summary(fake_path, 'samples/Gs'), snapshot_kimg)
                summary_log.add_summary(
                    build_image_summary(ifake_path, 'samples/G'), snapshot_kimg)

            # Generate/Save Interpolation Visualizations:
            if interp_snapshot_ticks != -1 and cur_tick % interp_snapshot_ticks == 0:
                vis_tools.make_and_save_interpolation_gifs(
                    Gs, interp_z, interp_labels,
                    minibatch_size=per_gpu_minibatch,
                    interp_steps=interp_steps,
                    interp_batch_size=interp_batch_size,
                    cur_kimg=cur_nimg // 1000,
                    summary_log=summary_log)

            # Save snapshot and run metrics:
            if cur_tick % network_snapshot_ticks == 0 or done or cur_tick == 1:
                pkl = os.path.join(submit_config.run_dir,
                                   'network-snapshot-%06d.pkl' % (cur_nimg // 1000))
                misc.save_pkl((G, D, Gs), pkl)
                # Metrics are evaluated on the pickle written just above, so
                # they only run on ticks that produced a snapshot.
                if cur_tick % network_metric_ticks == 0 or done or cur_tick == 1:
                    metrics.run(pkl,
                                dataset_args=metrics_dataset_args,
                                mirror_augment=mirror_augment,
                                num_gpus=submit_config.num_gpus,
                                tf_config=tf_config)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            ctx.update('%.2f' % sched.lod, cur_epoch=cur_nimg // 1000, max_epoch=total_kimg)
            maintenance_time = ctx.get_last_update_interval() - tick_time

    # Write final results.
    misc.save_pkl((G, D, Gs),
                  os.path.join(submit_config.run_dir, 'network-snapshot-%06d.pkl' % total_kimg))
    summary_log.close()
    ctx.close()
def mixing(resume_run_id, resume_snapshot=None):
    """Locate the network pickle for a run.

    Args:
        resume_run_id: Run ID or network pkl passed to
            ``misc.locate_network_pkl``.
        resume_snapshot: Snapshot index passed to ``misc.locate_network_pkl``.

    Returns:
        None. The located path is not used any further.
    """
    # NOTE(review): the result is bound but never read or returned; the
    # function body looks unfinished -- confirm against the callers.
    located_pkl = misc.locate_network_pkl(resume_run_id, resume_snapshot)
def training_loop(
        submit_config,
        Encoder_args={},
        D_args={},
        G_args={},
        E_opt_args={},
        D_opt_args={},
        E_loss_args=EasyDict(),
        D_loss_args={},
        lr_args=EasyDict(),
        tf_config={},
        dataset_args=EasyDict(),
        decoder_pkl=EasyDict(),
        drange_data=[0, 255],
        drange_net=[-1, 1],  # Dynamic range used when feeding image data to the networks.
        mirror_augment=False,
        filter=64,  # Minimum number of feature maps in any layer.
        filter_max=512,  # Maximum number of feature maps in any layer.
        resume_run_id=None,  # Run ID or network pkl to resume training from, None = start from scratch.
        resume_snapshot=None,  # Snapshot index to resume training from, None = autodetect.
        image_snapshot_ticks=1,  # How often to export image snapshots?
        network_snapshot_ticks=10,  # How often to export network snapshots?
        d_scale=0.1,  # Decide whether to update discriminator.
        pretrained_D=True,  # Whether to use pre trained Discriminator.
        max_iters=150000):
    """Train encoder ``E`` (and optionally discriminator ``D``) on real images.

    A ``G_StyleMod`` network is created and filled with the weights of
    ``Gs.components.synthesis``; it serves as the generator in the per-GPU
    loss towers and in the test graph. Each iteration runs one encoder
    update and, when ``d_scale > 0``, one discriminator update. Every 50
    iterations the losses are printed and summaries saved. One tick is 65000
    images; reconstructions are exported every ``image_snapshot_ticks`` ticks
    and network pickles every ``network_snapshot_ticks`` ticks.

    Args:
        submit_config: Run configuration. This function reads ``batch_size``,
            ``batch_size_test``, ``image_size``, ``num_gpus`` and ``run_dir``.
        Encoder_args: Extra keyword arguments for the encoder network.
        D_args: Keyword arguments for a newly built discriminator; only used
            when starting from scratch with ``pretrained_D=False``.
        G_args: Keyword arguments for the ``G_StyleMod`` network.
        E_opt_args: Extra keyword arguments for the encoder optimizer.
        D_opt_args: Extra keyword arguments for the discriminator optimizer.
        E_loss_args: Keyword arguments forwarded to the encoder loss via
            ``dnnlib.util.call_func_by_name``; ``perceptual_img_size`` is
            additionally read here.
        D_loss_args: Keyword arguments forwarded to the discriminator loss.
        lr_args: Provides ``learning_rate``, ``decay_step``, ``decay_rate``
            and ``stair`` for ``tf.train.exponential_decay``.
        tf_config: Options for ``tflib.init_tf()``.
        dataset_args: Provides the ``data_train`` and ``data_test`` dirs.
        decoder_pkl: ``decoder_pkl.decoder_pkl`` is the pickle that G, the
            pre-trained D and Gs are loaded from when starting from scratch.
        drange_data: Input dynamic range handed to ``process_reals``.
        drange_net: Dynamic range used when feeding image data to the
            networks.
        mirror_augment: Forwarded to ``process_reals``.
        filter: Minimum number of feature maps in any encoder layer.
        filter_max: Maximum number of feature maps in any encoder layer.
        resume_run_id: Run ID or network pkl to resume training from,
            ``None`` = start from scratch.
        resume_snapshot: Snapshot index to resume from, ``None`` = autodetect.
        image_snapshot_ticks: Export reconstructions every this many ticks.
        network_snapshot_ticks: Save a network snapshot every this many ticks.
        d_scale: The discriminator is only optimized when this is ``> 0``.
        pretrained_D: Use the discriminator stored in ``decoder_pkl`` instead
            of building a new one (ignored when resuming).
        max_iters: Iteration index at which training stops.
    """
    tflib.init_tf(tf_config)

    with tf.name_scope('Input'):
        real_train = tf.placeholder(
            tf.float32,
            [submit_config.batch_size, 3, submit_config.image_size, submit_config.image_size],
            name='real_image_train')
        real_test = tf.placeholder(
            tf.float32,
            [submit_config.batch_size_test, 3, submit_config.image_size, submit_config.image_size],
            name='real_image_test')
        # One slice of the training batch per GPU.
        real_split = tf.split(real_train, num_or_size_splits=submit_config.num_gpus, axis=0)

    with tf.device('/gpu:0'):
        if resume_run_id is not None:
            network_pkl = misc.locate_network_pkl(resume_run_id, resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            E, G, D, Gs = misc.load_pkl(network_pkl)
            G_style_mod = tflib.Network('G_StyleMod',
                                        resolution=submit_config.image_size,
                                        label_size=0,
                                        **G_args)
            # Snapshot file names end in the image count; convert it back to
            # the iteration index at which training resumes.
            start = int(network_pkl.split('-')[-1].split('.')[0]) // submit_config.batch_size
            print('Start: ', start)
        else:
            print('Constructing networks...')
            G, PreD, Gs = misc.load_pkl(decoder_pkl.decoder_pkl)
            num_layers = Gs.components.synthesis.input_shape[1]
            E = tflib.Network('E_gpu0',
                              size=submit_config.image_size,
                              filter=filter,
                              filter_max=filter_max,
                              num_layers=num_layers,
                              is_training=True,
                              num_gpus=submit_config.num_gpus,
                              **Encoder_args)
            if pretrained_D:
                D = PreD
            else:
                # Build the fresh discriminator only when it is actually
                # used, so no unused network is created and D_args is not
                # needed for the pre-trained case.
                D = tflib.Network('D_ori',
                                  resolution=submit_config.image_size,
                                  label_size=0,
                                  **D_args)
            G_style_mod = tflib.Network('G_StyleMod',
                                        resolution=submit_config.image_size,
                                        label_size=0,
                                        **G_args)
            start = 0

        # Copy the synthesis weights of Gs into G_StyleMod, matching
        # variables by name.
        Gs_vars_pairs = {
            name: tflib.run(val)
            for name, val in Gs.components.synthesis.vars.items()
        }
        for g_name, g_val in G_style_mod.vars.items():
            tflib.set_vars({g_val: Gs_vars_pairs[g_name]})

    E.print_layers()
    Gs.print_layers()
    D.print_layers()

    # Learning-rate schedule driven by a dedicated step counter.
    global_step0 = tf.Variable(start, trainable=False, name='learning_rate_step')
    learning_rate = tf.train.exponential_decay(lr_args.learning_rate, global_step0,
                                               lr_args.decay_step, lr_args.decay_rate,
                                               staircase=lr_args.stair)
    add_global0 = global_step0.assign_add(1)

    E_opt = tflib.Optimizer(name='TrainE', learning_rate=learning_rate, **E_opt_args)
    if d_scale > 0:
        D_opt = tflib.Optimizer(name='TrainD', learning_rate=learning_rate, **D_opt_args)

    # Loss terms are summed over the GPU towers and averaged afterwards.
    E_loss_rec = 0.
    E_loss_adv = 0.
    D_loss_real = 0.
    D_loss_fake = 0.
    D_loss_grad = 0.

    for gpu in range(submit_config.num_gpus):
        print('Building Graph on GPU %s' % str(gpu))
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            # GPU 0 uses the original networks. The encoder clone replaces the
            # last character of the encoder name with the GPU index.
            E_gpu = E if gpu == 0 else E.clone(E.name[:-1] + str(gpu))
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            G_gpu = G_style_mod if gpu == 0 else G_style_mod.clone(G_style_mod.name + '_shadow')

            feature_model = PerceptualModel(
                img_size=[E_loss_args.perceptual_img_size, E_loss_args.perceptual_img_size],
                multi_layers=False)
            real_gpu = process_reals(real_split[gpu], mirror_augment, drange_data, drange_net)

            with tf.name_scope('E_loss'), tf.control_dependencies(None):
                E_loss, recon_loss, adv_loss = dnnlib.util.call_func_by_name(
                    E=E_gpu, G=G_gpu, D=D_gpu,
                    feature_model=feature_model,
                    reals=real_gpu,
                    **E_loss_args)
                E_loss_rec += recon_loss
                E_loss_adv += adv_loss

            with tf.name_scope('D_loss'), tf.control_dependencies(None):
                D_loss, loss_fake, loss_real, loss_gp = dnnlib.util.call_func_by_name(
                    E=E_gpu, G=G_gpu, D=D_gpu,
                    reals=real_gpu,
                    **D_loss_args)
                D_loss_real += loss_real
                D_loss_fake += loss_fake
                D_loss_grad += loss_gp

            # Registering the gradients under this dependency ties the
            # step-counter increment to the gradient ops.
            with tf.control_dependencies([add_global0]):
                E_opt.register_gradients(E_loss, E_gpu.trainables)
                if d_scale > 0:
                    D_opt.register_gradients(D_loss, D_gpu.trainables)

    E_loss_rec /= submit_config.num_gpus
    E_loss_adv /= submit_config.num_gpus
    D_loss_real /= submit_config.num_gpus
    D_loss_fake /= submit_config.num_gpus
    D_loss_grad /= submit_config.num_gpus

    E_train_op = E_opt.apply_updates()
    if d_scale > 0:
        D_train_op = D_opt.apply_updates()

    print('Building testing graph...')
    fake_X_val = test(E, G_style_mod, real_test, submit_config)

    sess = tf.get_default_session()

    print('Getting training data...')
    image_batch_train = get_train_data(sess,
                                       data_dir=dataset_args.data_train,
                                       submit_config=submit_config,
                                       mode='train')
    image_batch_test = get_train_data(sess,
                                      data_dir=dataset_args.data_test,
                                      submit_config=submit_config,
                                      mode='test')

    summary_log = tf.summary.FileWriter(submit_config.run_dir)

    cur_nimg = start * submit_config.batch_size
    cur_tick = 0
    tick_start_nimg = cur_nimg
    start_time = time.time()

    print('Optimization starts!!!')
    for it in range(start, max_iters):
        batch_images = sess.run(image_batch_train)
        feed_dict = {real_train: batch_images}

        _, recon_, adv_, lr = sess.run(
            [E_train_op, E_loss_rec, E_loss_adv, learning_rate], feed_dict)
        if d_scale > 0:
            _, d_r_, d_f_, d_g_ = sess.run(
                [D_train_op, D_loss_real, D_loss_fake, D_loss_grad], feed_dict)

        cur_nimg += submit_config.batch_size

        # ETA = mean time per finished iteration * iterations still to run.
        run_time = time.time() - start_time
        iter_time = run_time / (it - start + 1)
        eta_time = iter_time * (max_iters - it - 1)

        if it % 50 == 0:
            print('Iter: %06d/%d, lr: %-.8f recon_loss: %-6.4f adv_loss: %-6.4f run_time:%-12s eta_time:%-12s' % (
                it, max_iters, lr, recon_, adv_,
                dnnlib.util.format_time(run_time),
                dnnlib.util.format_time(eta_time)))
            if d_scale > 0:
                print('d_r_loss: %-6.4f d_f_loss: %-6.4f d_reg: %-6.4f ' % (d_r_, d_f_, d_g_))
            sys.stdout.flush()
            tflib.autosummary.save_summaries(summary_log, it)

        # A tick ends once 65000 further images have been processed.
        if cur_nimg >= tick_start_nimg + 65000:
            cur_tick += 1
            tick_start_nimg = cur_nimg

            if cur_tick % image_snapshot_ticks == 0:
                batch_images_test = sess.run(image_batch_test)
                # Test batches are mapped from [0, 255] to [-1, 1] before
                # being fed to the test graph.
                batch_images_test = misc.adjust_dynamic_range(
                    batch_images_test.astype(np.float32), [0, 255], [-1., 1.])

                recon = sess.run(fake_X_val, feed_dict={real_test: batch_images_test})

                # First row: original images; second row: reconstructions.
                orin_recon = np.concatenate([batch_images_test, recon], axis=0)
                orin_recon = adjust_pixel_range(orin_recon)
                orin_recon = fuse_images(orin_recon, row=2, col=submit_config.batch_size_test)
                save_image('%s/iter_%09d.png' % (submit_config.run_dir, cur_nimg), orin_recon)

            if cur_tick % network_snapshot_ticks == 0:
                pkl = os.path.join(submit_config.run_dir,
                                   'network-snapshot-%09d.pkl' % (cur_nimg))
                misc.save_pkl((E, G, D, Gs), pkl)

    misc.save_pkl((E, G, D, Gs), os.path.join(submit_config.run_dir, 'network-final.pkl'))
    summary_log.close()
def training_loop(
        submit_config,
        #---------------------------------------------------------------
        # Modified by Deng et al.
        noise_dim=32,  # Width of the random noise appended to the generator input (latent_size = 254 + noise_dim).
        weight_args={},  # Forwarded as-is to the train-stage function.
        train_stage_args={},  # Options for the train-stage function and restore_weights_and_initialize().
        #---------------------------------------------------------------
        G_args={},  # Options for generator network.
        D_args={},  # Options for discriminator network.
        G_opt_args={},  # Options for generator optimizer.
        D_opt_args={},  # Options for discriminator optimizer.
        G_loss_args={},  # Options for generator loss.
        D_loss_args={},  # Options for discriminator loss.
        dataset_args={},  # Options for dataset.load_dataset().
        sched_args={},  # Options for train.TrainingSchedule.
        grid_args={},  # Options for train.setup_snapshot_image_grid().
        metric_arg_list=[],  # Options for MetricGroup.
        tf_config={},  # Options for tflib.init_tf().
        G_smoothing_kimg=10.0,  # Half-life of the running average of generator weights.
        D_repeats=1,  # How many times the discriminator is trained per G iteration.
        minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
        reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
        total_kimg=15000,  # Total length of the training, measured in thousands of real images.
        mirror_augment=True,  # Enable mirror augment?
        drange_net=[-1, 1],  # Dynamic range used when feeding image data to the networks.
        image_snapshot_ticks=1,  # How often to export image snapshots?
        network_snapshot_ticks=10,  # How often to export network snapshots?
        save_tf_graph=True,  # Include full TensorFlow computation graph in the tfevents file?
        save_weight_histograms=False,  # Include weight histograms in the tfevents file?
        resume_run_id=87,  # Run ID or network pkl to resume training from, None = start from scratch.
        resume_snapshot=2364,  # Snapshot index to resume training from, None = autodetect.
        resume_kimg=2364,  # Assumed training progress at the beginning. Affects reporting and training schedule.
        resume_time=0.0,  # Assumed wallclock time at the beginning. Affects reporting.
        **_kwargs):  # Extra keyword arguments are accepted and ignored.
    """Build the training graph and run the training loop (variant modified by Deng et al.).

    The function, in order:
      1. initializes the run context and TensorFlow, and loads the training set;
      2. creates a ``Face3D`` block and either loads (G, D, Gs) from a
         pickle (``resume_run_id`` is not None) or constructs them;
      3. builds per-GPU losses through ``dnnlib.util.call_func_by_name`` with
         ``train_stage_args`` and registers gradients with both optimizers;
      4. calls ``restore_weights_and_initialize(train_stage_args)``;
      5. renders a fixed snapshot grid (generator output concatenated with
         the rendered images along axis 3) and saves it with the real images;
      6. alternates D and G updates until ``total_kimg`` is reached or the
         run context requests a stop, exporting image and network snapshots
         and summaries once per tick;
      7. saves ``network-final.pkl`` and closes the summary writer and context.

    Returns None; results are written into ``submit_config.run_dir``.

    NOTE: the dict/list defaults are shared between calls (standard Python
    behavior). This function itself never mutates them.
    """
    # Initialize dnnlib and TensorFlow.
    ctx = dnnlib.RunContext(submit_config, train)
    tflib.init_tf(tf_config)

    # Load training set.
    training_set = dataset.load_dataset(data_dir=config.data_dir,
                                        verbose=True,
                                        **dataset_args)

    # Create 3d face reconstruction block
    FaceRender = Face3D()

    # Construct networks.
    with tf.device('/gpu:0'):
        if resume_run_id is not None:
            network_pkl = misc.locate_network_pkl(resume_run_id,
                                                  resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            G, D, Gs = misc.load_pkl(network_pkl)
        else:
            print('Constructing networks...')
            #---------------------------------------------------------------
            # Modified by Deng et al.
            G = tflib.Network('G',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              latent_size=254 + noise_dim,
                              **G_args)
            #---------------------------------------------------------------
            D = tflib.Network('D',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              **D_args)
            Gs = G.clone('Gs')
    G.print_layers()
    D.print_layers()

    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        resolution = tf.placeholder(tf.float32, name='resolution', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        minibatch_split = minibatch_in // submit_config.num_gpus
        # Per-step decay of the Gs moving average, derived from the half-life.
        Gs_beta = 0.5**tf.div(tf.cast(minibatch_in, tf.float32),
                              G_smoothing_kimg *
                              1000.0) if G_smoothing_kimg > 0.0 else 0.0

    G_opt = tflib.Optimizer(name='TrainG',
                            learning_rate=lrate_in,
                            **G_opt_args)
    D_opt = tflib.Optimizer(name='TrainD',
                            learning_rate=lrate_in,
                            **D_opt_args)
    for gpu in range(submit_config.num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % (gpu)):
            # GPU 0 uses the master networks; the other GPUs use clones.
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            lod_assign_ops = [
                tf.assign(G_gpu.find_var('lod'), lod_in),
                tf.assign(D_gpu.find_var('lod'), lod_in)
            ]
            reals, labels = training_set.get_minibatch_tf()
            reals = process_reals(reals, lod_in, mirror_augment,
                                  training_set.dynamic_range, drange_net)

            #---------------------------------------------------------------
            # Modified by Deng et al.
            # Both losses come from one train-stage function.
            # NOTE(review): presumably selected by a function-name entry in
            # train_stage_args - confirm against the caller's config.
            G_loss, D_loss = dnnlib.util.call_func_by_name(
                FaceRender=FaceRender,
                noise_dim=noise_dim,
                weight_args=weight_args,
                G_gpu=G_gpu,
                D_gpu=D_gpu,
                G_opt=G_opt,
                D_opt=D_opt,
                training_set=training_set,
                G_loss_args=G_loss_args,
                D_loss_args=D_loss_args,
                lod_assign_ops=lod_assign_ops,
                reals=reals,
                labels=labels,
                minibatch_split=minibatch_split,
                resolution=resolution,
                drange_net=drange_net,
                lod_in=lod_in,
                **train_stage_args)
            #---------------------------------------------------------------

            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            # Fall back to a constant so the progress report still works.
            peak_gpu_mem_op = tf.constant(0)

    #---------------------------------------------------------------
    # Modified by Deng et al.
    restore_weights_and_initialize(train_stage_args)

    print('Setting up snapshot image grid...')
    # Schedule evaluated at the end of training; used for the initial grid.
    sched = training_schedule(cur_nimg=total_kimg * 1000,
                              training_set=training_set,
                              num_gpus=submit_config.num_gpus,
                              **sched_args)

    grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid(
        G, training_set, **grid_args)
    num_grid_images = np.prod(grid_size)
    grid_latents = tf.random_normal([num_grid_images, 128 + 32 + 16 + 3])
    grid_INPUTcoeff = z_to_lambda_mapping(grid_latents)
    # Three zero columns are appended before rendering.
    grid_INPUTcoeff_w_t = tf.concat(
        [grid_INPUTcoeff, tf.zeros([num_grid_images, 3])], axis=1)
    with tf.name_scope('FaceRender'):
        grid_render_img, _, _, _ = FaceRender.Reconstruction_Block(
            grid_INPUTcoeff_w_t, 256, num_grid_images, progressive=False)
        # assumes the renderer returns NHWC; transposed to NCHW - TODO confirm
        grid_render_img = tf.transpose(grid_render_img, perm=[0, 3, 1, 2])
        grid_render_img = process_reals(grid_render_img, lod_in, False,
                                        training_set.dynamic_range,
                                        drange_net)

    # Evaluate the coefficients and renders once; both are reused for every
    # image snapshot so that the grids are comparable across ticks.
    grid_INPUTcoeff_, grid_renders = tflib.run(
        [grid_INPUTcoeff, grid_render_img], {lod_in: sched.lod})
    # The noise width must follow noise_dim (the generator is built with
    # latent_size = 254 + noise_dim); it was previously hard-coded to 32.
    grid_noise = np.random.randn(num_grid_images, noise_dim)
    grid_INPUTcoeff_w_noise = np.concatenate([grid_INPUTcoeff_, grid_noise],
                                             axis=1)

    grid_fakes = Gs.run(grid_INPUTcoeff_w_noise,
                        grid_labels,
                        is_validation=True,
                        minibatch_size=sched.minibatch //
                        submit_config.num_gpus)
    # Place generated and rendered images side by side (axis 3).
    grid_fakes = np.concatenate([grid_fakes, grid_renders], axis=3)
    misc.save_image_grid(grid_fakes,
                         os.path.join(submit_config.run_dir,
                                      'fakes%06d.png' % resume_kimg),
                         drange=drange_net,
                         grid_size=grid_size)
    misc.save_image_grid(grid_reals,
                         os.path.join(submit_config.run_dir, 'reals.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)
    #---------------------------------------------------------------

    summary_log = tf.summary.FileWriter(submit_config.run_dir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    print('Training...\n')
    ctx.update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = ctx.get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    while cur_nimg < total_kimg * 1000:
        if ctx.should_stop():
            break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg,
                                  training_set=training_set,
                                  num_gpus=submit_config.num_gpus,
                                  **sched_args)
        training_set.configure(sched.minibatch // submit_config.num_gpus,
                               sched.lod)
        if reset_opt_for_new_lod:
            # Reset when lod crosses an integer boundary.
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(
                    sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state()
                D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        for _mb_repeat in range(minibatch_repeats):
            for _D_repeat in range(D_repeats):
                tflib.run(
                    [D_train_op, Gs_update_op], {
                        lod_in: sched.lod,
                        lrate_in: sched.D_lrate,
                        minibatch_in: sched.minibatch,
                        resolution: sched.resolution
                    })
                cur_nimg += sched.minibatch
            tflib.run(
                [G_train_op], {
                    lod_in: sched.lod,
                    lrate_in: sched.G_lrate,
                    minibatch_in: sched.minibatch,
                    resolution: sched.resolution
                })

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = ctx.get_time_since_last_update()
            total_time = ctx.get_time_since_start() + resume_time

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %-4.1f'
                % (autosummary('Progress/tick', cur_tick),
                   autosummary('Progress/kimg', cur_nimg / 1000.0),
                   autosummary('Progress/lod', sched.lod),
                   autosummary('Progress/minibatch', sched.minibatch),
                   dnnlib.util.format_time(
                       autosummary('Timing/total_sec', total_time)),
                   autosummary('Timing/sec_per_tick', tick_time),
                   autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                   autosummary('Timing/maintenance_sec', maintenance_time),
                   autosummary('Resources/peak_gpu_mem_gb',
                               peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days',
                        total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if cur_tick % image_snapshot_ticks == 0 or done:
                #---------------------------------------------------------------
                # Modified by Deng et al.
                grid_fakes = Gs.run(grid_INPUTcoeff_w_noise,
                                    grid_labels,
                                    is_validation=True,
                                    minibatch_size=sched.minibatch //
                                    submit_config.num_gpus)
                grid_fakes = np.concatenate([grid_fakes, grid_renders],
                                            axis=3)
                misc.save_image_grid(grid_fakes,
                                     os.path.join(
                                         submit_config.run_dir,
                                         'fakes%06d.png' % (cur_nimg // 1000)),
                                     drange=drange_net,
                                     grid_size=grid_size)
                #---------------------------------------------------------------
            if cur_tick % network_snapshot_ticks == 0 or done or cur_tick == 1:
                pkl = os.path.join(
                    submit_config.run_dir,
                    'network-snapshot-%06d.pkl' % (cur_nimg // 1000))
                misc.save_pkl((G, D, Gs), pkl)
                metrics.run(pkl,
                            run_dir=submit_config.run_dir,
                            num_gpus=submit_config.num_gpus,
                            tf_config=tf_config)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            ctx.update('%.2f' % sched.lod,
                       cur_epoch=cur_nimg // 1000,
                       max_epoch=total_kimg)
            maintenance_time = ctx.get_last_update_interval() - tick_time

    # Write final results.
    misc.save_pkl((G, D, Gs),
                  os.path.join(submit_config.run_dir, 'network-final.pkl'))
    summary_log.close()

    ctx.close()


#----------------------------------------------------------------------------
def training_loop(
        submit_config,
        G_args={},  # Options for generator network.
        D_args={},  # Options for discriminator network.
        G_opt_args={},  # Options for generator optimizer.
        D_opt_args={},  # Options for discriminator optimizer.
        G_loss_args={},  # Options for generator loss.
        D_loss_args={},  # Options for discriminator loss.
        dataset_args={},  # Options for dataset.load_dataset().
        sched_args={},  # Options for train.TrainingSchedule.
        grid_args={},  # Options for train.setup_snapshot_image_grid().
        metric_arg_list=[],  # Options for MetricGroup.
        tf_config={},  # Options for tflib.init_tf().
        G_smoothing_kimg=10.0,  # Half-life of the running average of generator weights.
        D_repeats=1,  # How many times the discriminator is trained per G iteration.
        minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
        reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
        total_kimg=15000,  # Total length of the training, measured in thousands of real images.
        mirror_augment=False,  # Enable mirror augment?
        drange_net=[-1, 1],  # Dynamic range used when feeding image data to the networks.
        image_snapshot_ticks=1,  # How often to export image snapshots?
        network_snapshot_ticks=10,  # How often to export network snapshots?
        save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
        save_weight_histograms=False,  # Include weight histograms in the tfevents file?
        resume_run_id=None,  # Run ID or network pkl to resume training from, None = load the pretrained pickle.
        resume_snapshot=None,  # Snapshot index to resume training from, None = autodetect.
        resume_kimg=10000.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
        resume_time=0.0):  # Assumed wallclock time at the beginning. Affects reporting.
    """Continue training (G, D, Gs) and copy image snapshots to a GCS bucket.

    The networks are loaded from a previous run when ``resume_run_id`` is
    given; otherwise they are unpickled from a hard-coded download URL
    (``G_args`` and ``D_args`` are then unused, because network construction
    is disabled in this variant). The loss graph is built once per GPU and
    D/G updates alternate until ``total_kimg`` is reached or the run context
    requests a stop. Once per tick the function reports progress and,
    according to ``image_snapshot_ticks`` / ``network_snapshot_ticks``,
    writes a fake-image grid (also copied to ``gs://stylegan_out`` with
    ``gsutil``) and a network pickle followed by a metrics run.

    Returns None; results are written into ``submit_config.run_dir``.

    NOTE: the dict/list defaults are shared between calls (standard Python
    behavior). This function itself never mutates them.
    """

    def _upload_to_bucket(local_path):
        # Best-effort copy of one file to the bucket. An argument list (no
        # shell) keeps paths with spaces or shell metacharacters intact.
        # A missing gsutil binary is reported instead of aborting training.
        try:
            subprocess.run(['gsutil', 'cp', local_path, 'gs://stylegan_out'])
        except OSError as err:
            print('Could not run gsutil for "%s": %s' % (local_path, err))

    # Initialize dnnlib and TensorFlow.
    ctx = dnnlib.RunContext(submit_config, train)
    tflib.init_tf(tf_config)

    # Load training set.
    training_set = dataset.load_dataset(data_dir=config.data_dir,
                                        verbose=True,
                                        **dataset_args)

    # Construct networks.
    with tf.device('/gpu:0'):
        if resume_run_id is not None:
            network_pkl = misc.locate_network_pkl(resume_run_id,
                                                  resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            G, D, Gs = misc.load_pkl(network_pkl)
        else:
            # Building fresh networks is disabled in this variant; training
            # starts from the pickle downloaded below.
            # NOTE(security): unpickling executes arbitrary code, so this URL
            # must point to a trusted file.
            url = 'https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ'
            with dnnlib.util.open_url(url, cache_dir=config.cache_dir) as f:
                G, D, Gs = pickle.load(f)
            print('Loading pretrained FFHQ network')
    G.print_layers()
    D.print_layers()

    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        minibatch_split = minibatch_in // submit_config.num_gpus
        # Per-step decay of the Gs moving average, derived from the half-life.
        Gs_beta = 0.5**tf.div(tf.cast(minibatch_in, tf.float32),
                              G_smoothing_kimg *
                              1000.0) if G_smoothing_kimg > 0.0 else 0.0

    G_opt = tflib.Optimizer(name='TrainG',
                            learning_rate=lrate_in,
                            **G_opt_args)
    D_opt = tflib.Optimizer(name='TrainD',
                            learning_rate=lrate_in,
                            **D_opt_args)
    for gpu in range(submit_config.num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            # GPU 0 uses the master networks; the other GPUs use clones.
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            lod_assign_ops = [
                tf.assign(G_gpu.find_var('lod'), lod_in),
                tf.assign(D_gpu.find_var('lod'), lod_in)
            ]
            reals, labels = training_set.get_minibatch_tf()
            reals = process_reals(reals, lod_in, mirror_augment,
                                  training_set.dynamic_range, drange_net)
            # The control dependencies make the losses run only after the
            # lod variables have been assigned.
            with tf.name_scope('G_loss'), tf.control_dependencies(
                    lod_assign_ops):
                G_loss = dnnlib.util.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    opt=G_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    **G_loss_args)
            with tf.name_scope('D_loss'), tf.control_dependencies(
                    lod_assign_ops):
                D_loss = dnnlib.util.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    opt=D_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    reals=reals,
                    labels=labels,
                    **D_loss_args)
            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            # Fall back to a constant so the progress report still works.
            peak_gpu_mem_op = tf.constant(0)

    print('Setting up snapshot image grid...')
    grid_size, grid_reals, grid_labels, grid_latents = misc.setup_snapshot_image_grid(
        G, training_set, **grid_args)
    # Schedule evaluated at the end of training; used for the initial grid.
    sched = training_schedule(cur_nimg=total_kimg * 1000,
                              training_set=training_set,
                              num_gpus=submit_config.num_gpus,
                              **sched_args)
    grid_fakes = Gs.run(grid_latents,
                        grid_labels,
                        is_validation=True,
                        minibatch_size=sched.minibatch //
                        submit_config.num_gpus)

    print('Setting up run dir...')
    misc.save_image_grid(grid_reals,
                         os.path.join(submit_config.run_dir, 'reals.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)
    initial_fakes_png = os.path.join(submit_config.run_dir,
                                     'fakes%06d.png' % resume_kimg)
    misc.save_image_grid(grid_fakes,
                         initial_fakes_png,
                         drange=drange_net,
                         grid_size=grid_size)
    _upload_to_bucket(initial_fakes_png)

    summary_log = tf.summary.FileWriter(submit_config.run_dir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    print('Training...\n')
    ctx.update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = ctx.get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    while cur_nimg < total_kimg * 1000:
        if ctx.should_stop():
            break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg,
                                  training_set=training_set,
                                  num_gpus=submit_config.num_gpus,
                                  **sched_args)
        training_set.configure(sched.minibatch // submit_config.num_gpus,
                               sched.lod)
        if reset_opt_for_new_lod:
            # Reset when lod crosses an integer boundary.
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(
                    sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state()
                D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        for _mb_repeat in range(minibatch_repeats):
            for _D_repeat in range(D_repeats):
                tflib.run(
                    [D_train_op, Gs_update_op], {
                        lod_in: sched.lod,
                        lrate_in: sched.D_lrate,
                        minibatch_in: sched.minibatch
                    })
                cur_nimg += sched.minibatch
            tflib.run(
                [G_train_op], {
                    lod_in: sched.lod,
                    lrate_in: sched.G_lrate,
                    minibatch_in: sched.minibatch
                })

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = ctx.get_time_since_last_update()
            total_time = ctx.get_time_since_start() + resume_time

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %-4.1f'
                % (autosummary('Progress/tick', cur_tick),
                   autosummary('Progress/kimg', cur_nimg / 1000.0),
                   autosummary('Progress/lod', sched.lod),
                   autosummary('Progress/minibatch', sched.minibatch),
                   dnnlib.util.format_time(
                       autosummary('Timing/total_sec', total_time)),
                   autosummary('Timing/sec_per_tick', tick_time),
                   autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                   autosummary('Timing/maintenance_sec', maintenance_time),
                   autosummary('Resources/peak_gpu_mem_gb',
                               peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days',
                        total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if cur_tick % image_snapshot_ticks == 0 or done:
                grid_fakes = Gs.run(grid_latents,
                                    grid_labels,
                                    is_validation=True,
                                    minibatch_size=sched.minibatch //
                                    submit_config.num_gpus)
                fakes_png = os.path.join(submit_config.run_dir,
                                         'fakes%06d.png' % (cur_nimg // 1000))
                misc.save_image_grid(grid_fakes,
                                     fakes_png,
                                     drange=drange_net,
                                     grid_size=grid_size)
                _upload_to_bucket(fakes_png)
            if cur_tick % network_snapshot_ticks == 0 or done or cur_tick == 1:
                pkl = os.path.join(
                    submit_config.run_dir,
                    'network-snapshot-%06d.pkl' % (cur_nimg // 1000))
                misc.save_pkl((G, D, Gs), pkl)
                metrics.run(pkl,
                            run_dir=submit_config.run_dir,
                            num_gpus=submit_config.num_gpus,
                            tf_config=tf_config)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            ctx.update('%.2f' % sched.lod,
                       cur_epoch=cur_nimg // 1000,
                       max_epoch=total_kimg)
            maintenance_time = ctx.get_last_update_interval() - tick_time

    # Write final results.
    misc.save_pkl((G, D, Gs),
                  os.path.join(submit_config.run_dir, 'network-final.pkl'))
    summary_log.close()

    ctx.close()
def training_loop(
        submit_config,
        G_args={},  # Options for generator network.
        D_args={},  # Options for discriminator network.
        G_opt_args={},  # Options for generator optimizer.
        D_opt_args={},  # Options for discriminator optimizer.
        G_loss_args={},  # Options for generator loss.
        D_loss_args={},  # Options for discriminator loss.
        dataset_args={},  # Options for dataset.load_dataset().
        sched_args={},  # Options for train.TrainingSchedule.
        grid_args={},  # Options for train.setup_snapshot_image_grid().
        metric_arg_list=[],  # Options for MetricGroup.
        tf_config={},  # Options for tflib.init_tf().
        G_smoothing_kimg=10.0,  # Half-life of the running average of generator weights.
        D_repeats=1,  # How many times the discriminator is trained per G iteration.
        minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
        total_kimg=15000,  # Total length of the training, measured in thousands of real images.
        mirror_augment=False,  # Enable mirror augment?
        drange_net=[-1, 1],  # Dynamic range used when feeding image data to the networks.
        image_snapshot_ticks=1,  # How often to export image snapshots?
        network_snapshot_ticks=10,  # How often to export network snapshots?
        save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
        save_weight_histograms=False,  # Include weight histograms in the tfevents file?
        resume_run_id=None,  # Run ID or network pkl to resume training from, None = start from scratch.
        resume_snapshot=None,  # Snapshot index to resume training from, None = autodetect.
        resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
        resume_time=0.0,  # Assumed wallclock time at the beginning. Affects reporting.
):
    """Train (G, D, Gs) with a fixed schedule and multi-scale image snapshots.

    This variant has no level-of-detail input: the training schedule is
    evaluated once (at the end-of-training image count) and its minibatch
    size and learning rates are used for the whole run. ``Gs.run`` output is
    saved as one image grid per entry, each in its own ``<N>x<N>``
    sub-directory of the run dir (N = 4, 8, ...).

    Returns None; results are written into ``submit_config.run_dir``.
    """
    # Initialize dnnlib and TensorFlow.
    ctx = dnnlib.RunContext(submit_config, train)
    tflib.init_tf(tf_config)

    # Load training set.
    training_set = dataset.load_dataset(data_dir=config.data_dir,
                                        verbose=True,
                                        **dataset_args)

    # Construct networks (or restore them from a previous run).
    with tf.device("/gpu:0"):
        if resume_run_id is None:
            print("Constructing networks...")
            G = tflib.Network(
                name="G",
                num_inputs=2,  # one for latents and one for labels
                num_channels=training_set.shape[0],
                resolution=training_set.shape[1],
                label_size=training_set.label_size,
                **G_args)
            num_scales = int(np.log2(training_set.shape[1])) - 1
            D = tflib.Network(
                name="D",
                num_inputs=num_scales + 1,  # +1 for labels
                num_channels=training_set.shape[0],
                resolution=training_set.shape[1],
                label_size=training_set.label_size,
                **D_args)
            Gs = G.clone("Gs")
        else:
            resume_pkl = misc.locate_network_pkl(resume_run_id,
                                                 resume_snapshot)
            print('Loading networks from "%s"...' % resume_pkl)
            G, D, Gs = misc.load_pkl(resume_pkl)
    G.print_layers()
    D.print_layers()

    print("Building TensorFlow graph...")
    with tf.name_scope("Inputs"), tf.device("/cpu:0"):
        lrate_in = tf.placeholder(tf.float32, name="lrate_in", shape=[])
        minibatch_in = tf.placeholder(tf.int32, name="minibatch_in", shape=[])
        minibatch_split = minibatch_in // submit_config.num_gpus
        # Per-step decay of the Gs moving average, derived from the half-life.
        if G_smoothing_kimg > 0.0:
            Gs_beta = 0.5**tf.div(tf.cast(minibatch_in, tf.float32),
                                  G_smoothing_kimg * 1000.0)
        else:
            Gs_beta = 0.0

    G_opt = tflib.Optimizer(name="TrainG",
                            learning_rate=lrate_in,
                            **G_opt_args)
    D_opt = tflib.Optimizer(name="TrainD",
                            learning_rate=lrate_in,
                            **D_opt_args)
    for gpu_idx in range(submit_config.num_gpus):
        with tf.name_scope("GPU%d" % gpu_idx), tf.device(
                "/gpu:%d" % gpu_idx):
            # GPU 0 uses the master networks; the other GPUs use clones.
            if gpu_idx == 0:
                G_gpu, D_gpu = G, D
            else:
                G_gpu = G.clone(G.name + "_shadow")
                D_gpu = D.clone(D.name + "_shadow")
            real_images, real_labels = training_set.get_minibatch_tf()
            real_images = process_reals(
                real_images,
                mirror_augment,
                training_set.dynamic_range,
                drange_net,
                depth=training_set.resolution_log2 - 1,
            )
            with tf.name_scope("G_loss"):
                G_loss = dnnlib.util.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    opt=G_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    **G_loss_args)
            with tf.name_scope("D_loss"):
                D_loss = dnnlib.util.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    opt=D_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    reals=real_images,
                    labels=real_labels,
                    **D_loss_args)
            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)
    with tf.device("/gpu:0"):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            # Fall back to a constant so the progress report still works.
            peak_gpu_mem_op = tf.constant(0)

    # The schedule is evaluated once and never refreshed inside the loop.
    sched = training_schedule(cur_nimg=total_kimg * 1000,
                              training_set=training_set,
                              **sched_args)
    per_gpu_minibatch = sched.minibatch_size // submit_config.num_gpus

    print("Setting up snapshot image grid...")
    grid_size, grid_reals, grid_labels, grid_latents = misc.setup_snapshot_image_grid(
        G, training_set, **grid_args)
    grid_fakes = Gs.run(
        grid_latents,
        grid_labels,
        is_validation=True,
        minibatch_size=per_gpu_minibatch,
    )

    print("Setting up run dir...")
    # One output directory per entry of grid_fakes: 4x4, 8x8, ...
    fake_multi_scale_dirs = []
    for res in range(2, 2 + len(grid_fakes)):
        side = 2**res
        fake_multi_scale_dirs.append(
            os.path.join(submit_config.run_dir,
                         str(side) + "x" + str(side)))

    def _fakes_paths(kimg):
        # File name of the fake-image grid for `kimg`, one per scale dir.
        return [
            os.path.join(scale_dir, "fakes%06d.png" % kimg)
            for scale_dir in fake_multi_scale_dirs
        ]

    misc.save_image_grid(
        grid_reals,
        os.path.join(submit_config.run_dir, "reals.png"),
        drange=training_set.dynamic_range,
        grid_size=grid_size,
    )
    misc.save_image_grids(
        grid_fakes,
        _fakes_paths(resume_kimg),
        drange=drange_net,
        grid_size=grid_size,
    )

    summary_log = tf.summary.FileWriter(submit_config.run_dir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    print("Training...\n")
    ctx.update("", cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = ctx.get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg

    # configure the training_set to a proper minibatch size
    training_set.configure(per_gpu_minibatch)

    # The feeds never change because the schedule is fixed for the run.
    d_feed = {lrate_in: sched.D_lrate, minibatch_in: sched.minibatch_size}
    g_feed = {lrate_in: sched.G_lrate, minibatch_in: sched.minibatch_size}
    total_nimg = total_kimg * 1000

    while cur_nimg < total_nimg and not ctx.should_stop():
        # Run training ops.
        for _mb_repeat in range(minibatch_repeats):
            for _D_repeat in range(D_repeats):
                tflib.run([D_train_op, Gs_update_op], d_feed)
                cur_nimg += sched.minibatch_size
            tflib.run([G_train_op], g_feed)

        # Perform maintenance tasks once per tick.
        done = cur_nimg >= total_nimg
        tick_due = cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000
        if not (tick_due or done):
            continue

        cur_tick += 1
        tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
        tick_start_nimg = cur_nimg
        tick_time = ctx.get_time_since_last_update()
        total_time = ctx.get_time_since_start() + resume_time
        cur_kimg = cur_nimg // 1000

        # Report progress.
        progress_fields = (
            autosummary("Progress/tick", cur_tick),
            autosummary("Progress/kimg", cur_nimg / 1000.0),
            autosummary("Progress/minibatch", sched.minibatch_size),
            dnnlib.util.format_time(
                autosummary("Timing/total_sec", total_time)),
            autosummary("Timing/sec_per_tick", tick_time),
            autosummary("Timing/sec_per_kimg", tick_time / tick_kimg),
            autosummary("Timing/maintenance_sec", maintenance_time),
            autosummary("Resources/peak_gpu_mem_gb",
                        peak_gpu_mem_op.eval() / 2**30),
        )
        print(
            "tick %-5d kimg %-8.1f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %-4.1f"
            % progress_fields)
        autosummary("Timing/total_hours", total_time / (60.0 * 60.0))
        autosummary("Timing/total_days", total_time / (24.0 * 60.0 * 60.0))

        # Save snapshots.
        if cur_tick % image_snapshot_ticks == 0 or done:
            grid_fakes = Gs.run(
                grid_latents,
                grid_labels,
                is_validation=True,
                minibatch_size=per_gpu_minibatch,
            )
            misc.save_image_grids(
                grid_fakes,
                _fakes_paths(cur_kimg),
                drange=drange_net,
                grid_size=grid_size,
            )
        if cur_tick % network_snapshot_ticks == 0 or done or cur_tick == 1:
            snapshot_pkl = os.path.join(
                submit_config.run_dir,
                "network-snapshot-%06d.pkl" % cur_kimg,
            )
            misc.save_pkl((G, D, Gs), snapshot_pkl)
            metrics.run(
                snapshot_pkl,
                run_dir=submit_config.run_dir,
                num_gpus=submit_config.num_gpus,
                tf_config=tf_config,
            )

        # Update summaries and RunContext.
        metrics.update_autosummaries()
        tflib.autosummary.save_summaries(summary_log, cur_nimg)
        ctx.update(cur_epoch=cur_kimg, max_epoch=total_kimg)
        maintenance_time = ctx.get_last_update_interval() - tick_time

    # Write final results.
    misc.save_pkl((G, D, Gs),
                  os.path.join(submit_config.run_dir, "network-final.pkl"))
    summary_log.close()

    ctx.close()