def train_progressive_gan(
        G_smoothing=0.999,  # Exponential running average of generator weights.
        D_repeats=1,  # How many times the discriminator is trained per G iteration.
        minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
        reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
        total_kimg=15000,  # Total length of the training, measured in thousands of real images.
        mirror_augment=False,  # Enable mirror augment?
        drange_net=[-1, 1],  # Dynamic range used when feeding image data to the networks.
        image_snapshot_ticks=1,  # How often to export image snapshots?
        network_snapshot_ticks=10,  # How often to export network snapshots?
        save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
        save_weight_histograms=False,  # Include weight histograms in the tfevents file?
        resume_run_id=None,  # Run ID or network pkl to resume training from, None = start from scratch.
        resume_snapshot=None,  # Snapshot index to resume training from, None = autodetect.
        resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
        resume_time=0.0):  # Assumed wallclock time at the beginning. Affects reporting.
    """Train the G / D / E networks with a progressively scheduled level of detail.

    Stages, in order:
      1. Load the training set and either restore (G, D, Gs, E) from a
         pickle or construct them from ``config``.
      2. Build one loss/gradient tower per GPU and one optimizer per network.
      3. Export an initial real/fake image grid into a new result directory.
      4. Loop until ``total_kimg`` thousand images have been shown, running
         ``D_repeats`` discriminator steps followed by one joint G+E step,
         ``minibatch_repeats`` times per schedule evaluation.
      5. Once per tick: report progress, write summaries and export
         image / network snapshots.
      6. Save the final networks and drop a ``_training-done.txt`` marker.

    Returns None; all results are written under the created result directory.
    """
    maintenance_start_time = time.time()
    training_set = dataset.load_dataset(data_dir=config.data_dir, verbose=True, **config.dataset)
    total_nimg = total_kimg * 1000

    # ------------------------------------------------------------------
    # Networks: restore from a snapshot or build from scratch.
    # ------------------------------------------------------------------
    with tf.device('/gpu:0'):
        if resume_run_id is None:
            print('Constructing networks...')
            G = tfutil.Network('G',
                               num_channels=training_set.shape[0],
                               resolution=training_set.shape[1],
                               label_size=training_set.label_size,
                               **config.G)
            D = tfutil.Network('D',
                               num_channels=training_set.shape[0],
                               resolution=training_set.shape[1],
                               label_size=training_set.label_size,
                               **config.D)
            E = tfutil.Network('E',
                               num_channels=training_set.shape[0],
                               resolution=training_set.shape[1],
                               label_size=training_set.label_size,
                               **config.E)
            Gs = G.clone('Gs')
        else:
            network_pkl = misc.locate_network_pkl(resume_run_id, resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            G, D, Gs, E = misc.load_pkl(network_pkl)
        # Gs tracks a moving average of G's weights.
        Gs_update_op = Gs.setup_as_moving_average_of(G, beta=G_smoothing)
    G.print_layers()
    D.print_layers()
    E.print_layers()

    # ------------------------------------------------------------------
    # Graph: inputs, optimizers and one loss tower per GPU.
    # ------------------------------------------------------------------
    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        minibatch_split = minibatch_in // config.num_gpus
        reals, labels = training_set.get_minibatch_tf()
        reals_split = tf.split(reals, config.num_gpus)
        labels_split = tf.split(labels, config.num_gpus)
    G_opt = tfutil.Optimizer(name='TrainG', learning_rate=lrate_in, **config.G_opt)
    D_opt = tfutil.Optimizer(name='TrainD', learning_rate=lrate_in, **config.D_opt)
    E_opt = tfutil.Optimizer(name='TrainE', learning_rate=lrate_in, **config.E_opt)
    for gpu_idx in range(config.num_gpus):
        with tf.name_scope('GPU%d' % gpu_idx), tf.device('/gpu:%d' % gpu_idx):
            # GPU 0 uses the primary networks; every other GPU gets a clone.
            if gpu_idx == 0:
                G_gpu, D_gpu, E_gpu = G, D, E
            else:
                G_gpu = G.clone(G.name + '_shadow')
                D_gpu = D.clone(D.name + '_shadow')
                E_gpu = E.clone(E.name + '_shadow')
            # Every loss below depends on these assignments, so each network's
            # 'lod' variable is set from lod_in before its loss is evaluated.
            lod_assign_ops = [
                tf.assign(net.find_var('lod'), lod_in)
                for net in (G_gpu, D_gpu, E_gpu)
            ]
            reals_gpu = process_reals(reals_split[gpu_idx], lod_in, mirror_augment,
                                      training_set.dynamic_range, drange_net)
            labels_gpu = labels_split[gpu_idx]
            with tf.name_scope('G_loss'), tf.control_dependencies(lod_assign_ops):
                G_loss = tfutil.call_func_by_name(
                    G=G_gpu, D=D_gpu, E=E_gpu, opt=G_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    **config.G_loss)
            with tf.name_scope('D_loss'), tf.control_dependencies(lod_assign_ops):
                D_loss = tfutil.call_func_by_name(
                    G=G_gpu, D=D_gpu, E=E_gpu, opt=D_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    reals=reals_gpu, labels=labels_gpu,
                    **config.D_loss)
            with tf.name_scope('E_loss'), tf.control_dependencies(lod_assign_ops):
                E_loss = tfutil.call_func_by_name(
                    G=G_gpu, D=D_gpu, E=E_gpu, opt=E_opt,
                    training_set=training_set,
                    reals=reals_gpu,
                    minibatch_size=minibatch_split,
                    **config.E_loss)
            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)
            E_opt.register_gradients(tf.reduce_mean(E_loss), E_gpu.trainables)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()
    E_train_op = E_opt.apply_updates()

    # ------------------------------------------------------------------
    # Initial snapshot grid and result directory.
    # ------------------------------------------------------------------
    print('Setting up snapshot image grid...')
    grid_size, grid_reals, grid_labels, grid_latents = setup_snapshot_image_grid(
        G, training_set, **config.grid)
    # The schedule evaluated at the end of training supplies the minibatch
    # size used for this first Gs.run() call.
    sched = TrainingSchedule(total_nimg, training_set, **config.sched)
    grid_fakes = Gs.run(grid_latents, grid_labels,
                        minibatch_size=sched.minibatch // config.num_gpus)

    print('Setting up result dir...')
    result_subdir = misc.create_result_subdir(config.result_dir, config.desc)
    misc.save_image_grid(grid_reals,
                         os.path.join(result_subdir, 'reals.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)
    misc.save_image_grid(grid_fakes,
                         os.path.join(result_subdir, 'fakes%06d.png' % 0),
                         drange=drange_net,
                         grid_size=grid_size)
    summary_log = tf.summary.FileWriter(result_subdir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()

    # ------------------------------------------------------------------
    # Training loop.
    # ------------------------------------------------------------------
    print('Training...')
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    tick_start_time = time.time()
    train_start_time = tick_start_time - resume_time
    prev_lod = -1.0
    while cur_nimg < total_nimg:

        # Choose training parameters and configure training ops.
        sched = TrainingSchedule(cur_nimg, training_set, **config.sched)
        training_set.configure(sched.minibatch, sched.lod)
        if reset_opt_for_new_lod:
            # The optimizers are reset whenever floor(lod) or ceil(lod)
            # differs from the previous iteration's value.
            lod_bracket_unchanged = (np.floor(sched.lod) == np.floor(prev_lod)
                                     and np.ceil(sched.lod) == np.ceil(prev_lod))
            if not lod_bracket_unchanged:
                G_opt.reset_optimizer_state()
                D_opt.reset_optimizer_state()
                E_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        d_feed = {lod_in: sched.lod, lrate_in: sched.D_lrate, minibatch_in: sched.minibatch}
        ge_feed = {lod_in: sched.lod, lrate_in: sched.G_lrate, minibatch_in: sched.minibatch}
        for _ in range(minibatch_repeats):
            for _ in range(D_repeats):
                tfutil.run([D_train_op, Gs_update_op], d_feed)
                cur_nimg += sched.minibatch
            # G and E are stepped together, both at the generator learning rate.
            tfutil.run([G_train_op, E_train_op], ge_feed)

        # Perform maintenance tasks once per tick.
        done = cur_nimg >= total_nimg
        tick_elapsed = cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000
        if not (tick_elapsed or done):
            continue

        cur_tick += 1
        cur_time = time.time()
        tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
        tick_start_nimg = cur_nimg
        tick_time = cur_time - tick_start_time
        total_time = cur_time - train_start_time
        # Time between the previous tick's report and the start of this tick's
        # training (for the first tick: everything before the loop started).
        maintenance_time = tick_start_time - maintenance_start_time
        maintenance_start_time = cur_time

        # Report progress.
        rep_tick = tfutil.autosummary('Progress/tick', cur_tick)
        rep_kimg = tfutil.autosummary('Progress/kimg', cur_nimg / 1000.0)
        rep_lod = tfutil.autosummary('Progress/lod', sched.lod)
        rep_minibatch = tfutil.autosummary('Progress/minibatch', sched.minibatch)
        rep_total = misc.format_time(tfutil.autosummary('Timing/total_sec', total_time))
        rep_sec_tick = tfutil.autosummary('Timing/sec_per_tick', tick_time)
        rep_sec_kimg = tfutil.autosummary('Timing/sec_per_kimg', tick_time / tick_kimg)
        rep_maint = tfutil.autosummary('Timing/maintenance_sec', maintenance_time)
        print(
            'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %.1f'
            % (rep_tick, rep_kimg, rep_lod, rep_minibatch, rep_total,
               rep_sec_tick, rep_sec_kimg, rep_maint))
        tfutil.autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
        tfutil.autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))
        tfutil.save_summaries(summary_log, cur_nimg)

        # Save snapshots.
        snapshot_kimg = cur_nimg // 1000
        if cur_tick % image_snapshot_ticks == 0 or done:
            per_gpu_minibatch = sched.minibatch // config.num_gpus
            grid_fakes = Gs.run(grid_latents, grid_labels, minibatch_size=per_gpu_minibatch)
            misc.save_image_grid(grid_fakes,
                                 os.path.join(result_subdir, 'fakes%06d.png' % snapshot_kimg),
                                 drange=drange_net,
                                 grid_size=grid_size)
            misc.save_all_res(training_set.shape[1], Gs, result_subdir, 50,
                              minibatch_size=per_gpu_minibatch)
        if cur_tick % network_snapshot_ticks == 0 or done:
            misc.save_pkl((G, D, Gs, E),
                          os.path.join(result_subdir,
                                       'network-snapshot-%06d.pkl' % snapshot_kimg))

        # Record start time of the next tick.
        tick_start_time = time.time()

    # Write final results.
    misc.save_pkl((G, D, Gs, E), os.path.join(result_subdir, 'network-final.pkl'))
    summary_log.close()
    # An empty marker file signals that the run finished.
    with open(os.path.join(result_subdir, '_training-done.txt'), 'wt'):
        pass
def train_detector(
        minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
        total_kimg=1,  # Total length of the training, measured in thousands of real images.
        drange_net=[-1, 1],  # Dynamic range used when feeding image data to the networks.
        snapshot_size=16,  # Size of the snapshot image
        snapshot_ticks=2**13,  # Number of images before maintenance
        image_snapshot_ticks=1,  # How often to export image snapshots?
        network_snapshot_ticks=1,  # How often to export network snapshots?
        save_tf_graph=True,  # Include full TensorFlow computation graph in the tfevents file?
        save_weight_histograms=False,  # Include weight histograms in the tfevents file?
        resume_run_id=None,  # Run ID or network pkl to resume training from, None = start from scratch.
        resume_snapshot=None,  # Snapshot index to resume training from, None = autodetect.
        resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
        resume_time=0.0):  # Assumed wallclock time at the beginning. Affects reporting.
    """Train the detector network ``N`` and periodically evaluate it on the test set.

    The training schedule is evaluated once before the loop (there is no
    progressive growing here).  Each iteration runs one optimizer step; every
    ``snapshot_ticks`` images (and at the end of training) a maintenance tick
    runs the test loss over the whole test set, reports progress, writes
    summaries and exports image / network snapshots.

    ``minibatch_repeats`` is accepted for interface compatibility but is
    currently unused.

    Returns None; all results are written under the created result directory.
    """
    maintenance_start_time = time.time()
    total_nimg = total_kimg * 1000

    # Load the datasets.
    training_set = dataset.load_dataset(tfrecord=config.tfrecord_train,
                                        verbose=True,
                                        **config.dataset)
    testing_set = dataset.load_dataset(tfrecord=config.tfrecord_test,
                                       verbose=True,
                                       repeat=False,
                                       shuffle_mb=0,
                                       **config.dataset)
    testing_set_len = len(testing_set)
    # TODO: data augmentation
    # TODO: testing set

    # Construct networks.
    with tf.device('/gpu:0'):
        if resume_run_id is not None:  # TODO: save methods
            network_pkl = misc.locate_network_pkl(resume_run_id, resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            N = misc.load_pkl(network_pkl)
        else:
            print('Constructing the network...')  # TODO: better network (like lod-wise network)
            N = tfutil.Network('N',
                               num_channels=training_set.shape[0],
                               resolution=training_set.shape[1],
                               **config.N)
    N.print_layers()

    print('Building TensorFlow graph...')
    # Training set up.
    with tf.name_scope('Inputs'):
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        # TODO: increase the size of the batch by several loss computation and mean
        reals, labels, bboxes = training_set.get_minibatch_tf()
    N_opt = tfutil.Optimizer(name='TrainN', learning_rate=lrate_in, **config.N_opt)
    with tf.device('/gpu:0'):
        reals, labels, gt_outputs, gt_ref = pre_process(
            reals, labels, bboxes, training_set.dynamic_range,
            [0, training_set.shape[-2]], drange_net)
        with tf.name_scope('N_loss'):  # TODO: loss inadapted
            N_loss = tfutil.call_func_by_name(N=N,
                                              reals=reals,
                                              gt_outputs=gt_outputs,
                                              gt_ref=gt_ref,
                                              **config.N_loss)
        N_opt.register_gradients(tf.reduce_mean(N_loss), N.trainables)
    N_train_op = N_opt.apply_updates()

    # Testing set up: same loss function, fed from the test set with
    # is_training=False.
    with tf.device('/gpu:0'):
        test_reals_tf, test_labels_tf, test_bboxes_tf = testing_set.get_minibatch_tf()
        test_reals_tf, test_labels_tf, test_gt_outputs_tf, test_gt_ref_tf = pre_process(
            test_reals_tf, test_labels_tf, test_bboxes_tf,
            testing_set.dynamic_range, [0, testing_set.shape[-2]], drange_net)
        with tf.name_scope('N_test_loss'):
            test_loss = tfutil.call_func_by_name(N=N,
                                                 reals=test_reals_tf,
                                                 gt_outputs=test_gt_outputs_tf,
                                                 gt_ref=test_gt_ref_tf,
                                                 is_training=False,
                                                 **config.N_loss)

    print('Setting up result dir...')
    result_subdir = misc.create_result_subdir(config.result_dir, config.desc)
    summary_log = tf.summary.FileWriter(result_subdir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        N.setup_weight_histograms()

    # A fixed batch of test images is reused for every image snapshot.
    test_reals, _, test_bboxes = testing_set.get_minibatch_np(snapshot_size)
    misc.save_img_bboxes(test_reals, test_bboxes,
                         os.path.join(result_subdir, 'reals.png'),
                         snapshot_size,
                         adjust_range=False)
    test_reals = misc.adjust_dynamic_range(test_reals, training_set.dynamic_range, drange_net)
    test_preds, _ = N.run(test_reals, minibatch_size=snapshot_size)
    misc.save_img_bboxes(test_reals, test_preds,
                         os.path.join(result_subdir, 'fakes.png'),
                         snapshot_size)

    print('Training...')
    if resume_run_id is None:
        tfutil.run(tf.global_variables_initializer())
    # Honour the resume parameters: start counting images and wall-clock time
    # from the assumed previous progress instead of from zero.
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    tick_start_time = time.time()
    train_start_time = tick_start_time - resume_time

    # Choose training parameters and configure training ops.
    sched = TrainingSchedule(cur_nimg, training_set, **config.sched)
    training_set.configure(sched.minibatch)
    _train_loss = 0
    while cur_nimg < total_nimg:

        # Run training ops.
        _, loss = tfutil.run([N_train_op, N_loss], {lrate_in: sched.N_lrate})
        _train_loss += loss
        cur_nimg += sched.minibatch

        # Perform maintenance tasks once per tick.  The tick is triggered by
        # the number of images seen since the previous tick rather than by
        # `cur_nimg % snapshot_ticks == 0`, which never fires when the
        # minibatch size does not divide snapshot_ticks (or when resuming from
        # an unaligned image count).
        done = cur_nimg >= total_nimg
        if not (done or cur_nimg >= tick_start_nimg + snapshot_ticks):
            continue

        cur_tick += 1
        cur_time = time.time()
        # Accumulated loss divided by the number of images shown this tick.
        _train_loss = _train_loss / (cur_nimg - tick_start_nimg)
        tick_start_nimg = cur_nimg
        tick_time = cur_time - tick_start_time
        total_time = cur_time - train_start_time
        maintenance_time = tick_start_time - maintenance_start_time
        maintenance_start_time = cur_time

        # Evaluate the test loss over one pass of the test set.
        # NOTE(review): the test set is loaded with repeat=False; this assumes
        # configure() rewinds it on every tick -- confirm against dataset.py.
        testing_set.configure(sched.minibatch)
        _test_loss = 0
        for _ in range(0, testing_set_len, sched.minibatch):
            _test_loss += tfutil.run(test_loss)
        if testing_set_len:  # An empty test set reports 0 instead of dividing by zero.
            _test_loss /= testing_set_len

        # Report progress.
        # TODO: improved report display
        print(
            'tick %-5d kimg %-6.1f time %-10s sec/tick %-3.1f maintenance %-7.2f train_loss %.4f test_loss %.4f'
            % (tfutil.autosummary('Progress/tick', cur_tick),
               tfutil.autosummary('Progress/kimg', cur_nimg / 1000.0),
               misc.format_time(tfutil.autosummary('Timing/total_sec', total_time)),
               tfutil.autosummary('Timing/sec_per_tick', tick_time),
               tfutil.autosummary('Timing/maintenance', maintenance_time),
               tfutil.autosummary('TrainN/train_loss', _train_loss),
               tfutil.autosummary('TrainN/test_loss', _test_loss)))
        tfutil.autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
        tfutil.autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))
        tfutil.save_summaries(summary_log, cur_nimg)

        # Save snapshots.
        if cur_tick % image_snapshot_ticks == 0:
            test_bboxes, test_refs = N.run(test_reals, minibatch_size=snapshot_size)
            misc.save_img_bboxes_ref(
                test_reals, test_bboxes, test_refs,
                os.path.join(result_subdir, 'fakes%06d.png' % (cur_nimg // 1000)),
                snapshot_size)
        if cur_tick % network_snapshot_ticks == 0:
            misc.save_pkl(
                N,
                os.path.join(result_subdir,
                             'network-snapshot-%06d.pkl' % (cur_nimg // 1000)))

        # Start a fresh loss accumulator for the next tick.
        _train_loss = 0

        # Record start time of the next tick.
        tick_start_time = time.time()

    # Write final results.  No separate 'network-final.pkl' is written; the
    # last network is only saved through the per-tick snapshot above.
    summary_log.close()
    # An empty marker file signals that the run finished.
    with open(os.path.join(result_subdir, '_training-done.txt'), 'wt'):
        pass
def _validate_discriminator(D, class_dirs, ndim):
    """Classify every image under ``class_dirs`` with ``D`` and tally the results.

    ``class_dirs[i]`` holds the images whose correct label is ``i``.  Each
    image must decode to an ``ndim`` x ``ndim`` RGB array.

    Returns ``(correct, guesses, misses)`` where ``misses[i][0]`` counts the
    images of class ``i`` that were assigned a different label.

    Raises ValueError if an image does not have shape (ndim, ndim, 3).
    """
    correct = 0
    guesses = 0
    misses = [[0] for _ in class_dirs]
    for indx, directory in enumerate(class_dirs):
        for filename in os.listdir(directory):
            guesses += 1
            print(filename)
            # os.path.join works with or without a trailing separator on the
            # directory; the context manager closes the file handle promptly.
            with PIL.Image.open(os.path.join(directory, filename)) as image:
                pixels = np.array(image)
            if pixels.shape != (ndim, ndim, 3):
                raise ValueError(
                    'validation image %r has shape %r, expected %r'
                    % (filename, pixels.shape, (ndim, ndim, 3)))
            # PIL yields height x width x channel; move the channel axis to the
            # front.  (A plain reshape would reinterpret the buffer and
            # scramble the pixels instead of reordering the axes.)
            img = np.expand_dims(pixels.transpose(2, 0, 1), axis=0)  # (1, 3, ndim, ndim)
            K_logits_out, _fake_logit_out, _features_out = test_discriminator(D, img)
            # NOTE(review): this adds a new softmax op to the graph for every
            # image on every tick; test_discriminator may do the same --
            # consider building the evaluation ops once.
            sample_probs = tf.nn.softmax(K_logits_out)
            label = np.argmax(sample_probs.eval()[0], axis=0)
            if label == indx:
                correct += 1
            else:
                misses[indx][0] += 1
            print("-----------------------------------")
            print("GUESSED LABEL: ", label)
            print("CORRECT LABEL: ", indx)
    return correct, guesses, misses


def train_progressive_gan(
        G_smoothing=0.999,  # Exponential running average of generator weights.
        D_repeats=1,  # How many times the discriminator is trained per G iteration.
        minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
        reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
        total_kimg=15000,  # Total length of the training, measured in thousands of real images.
        mirror_augment=False,  # Enable mirror augment?
        drange_net=[-1, 1],  # Dynamic range used when feeding image data to the networks.
        image_snapshot_ticks=1,  # How often to export image snapshots?
        network_snapshot_ticks=10,  # How often to export network snapshots?
        save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
        save_weight_histograms=False,  # Include weight histograms in the tfevents file?
        resume_run_id=None,  # Run ID or network pkl to resume training from, None = start from scratch.
        resume_snapshot=None,  # Snapshot index to resume training from, None = autodetect.
        resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
        resume_time=0.0):  # Assumed wallclock time at the beginning. Affects reporting.
    """Train G / D on a labeled and an unlabeled dataset with two loss pairs.

    Two sets of train ops are built for each network: the ``*_pggan`` ops are
    run while the scheduled level of detail (lod) is non-zero, and the ops
    built from ``config.G_loss`` / ``config.D_loss`` (which additionally
    receive labels and unlabeled reals) are run once lod reaches 0.

    Once per tick the function reports progress, measures classification
    accuracy of ``D`` on the images under ``config.validation_dog`` and
    ``config.validation_cat``, writes summaries and exports image / network
    snapshots.  Final networks are saved to ``network-final.pkl``.

    ``TrainingSpeedInt`` / ``TrainingSpeedFloat`` are module-level constants
    defined outside this block (presumably images-per-kimg as int and float).

    Returns None; all results are written under the created result directory.
    """
    maintenance_start_time = time.time()
    training_set = dataset.load_dataset(data_dir=config.data_dir,
                                        verbose=True,
                                        **config.dataset)
    unlabeled_training_set = dataset.load_dataset(
        data_dir=config.unlabeled_data_dir,
        verbose=True,
        **config.unlabeled_dataset)

    # Construct networks.
    with tf.device('/gpu:0'):
        if resume_run_id is not None:
            network_pkl = misc.locate_network_pkl(resume_run_id, resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            G, D, Gs = misc.load_pkl(network_pkl)
        else:
            print('Constructing networks...')
            print("Training-Set Label Size: ", training_set.label_size)
            print("Unlabeled-Training-Set Label Size: ",
                  unlabeled_training_set.label_size)
            G = tfutil.Network('G',
                               num_channels=training_set.shape[0],
                               resolution=training_set.shape[1],
                               label_size=training_set.label_size,
                               **config.G)
            D = tfutil.Network('D',
                               num_channels=training_set.shape[0],
                               resolution=training_set.shape[1],
                               label_size=training_set.label_size,
                               **config.D)
            Gs = G.clone('Gs')
        Gs_update_op = Gs.setup_as_moving_average_of(G, beta=G_smoothing)
    G.print_layers()
    D.print_layers()

    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        minibatch_split = minibatch_in // config.num_gpus
        reals, labels = training_set.get_minibatch_tf()
        unlabeled_reals, _ = unlabeled_training_set.get_minibatch_tf()
        reals_split = tf.split(reals, config.num_gpus)
        unlabeled_reals_split = tf.split(unlabeled_reals, config.num_gpus)
        labels_split = tf.split(labels, config.num_gpus)
    G_opt = tfutil.Optimizer(name='TrainG', learning_rate=lrate_in, **config.G_opt)
    G_opt_pggan = tfutil.Optimizer(name='TrainG_pggan', learning_rate=lrate_in, **config.G_opt)
    D_opt_pggan = tfutil.Optimizer(name='TrainD_pggan', learning_rate=lrate_in, **config.D_opt)
    D_opt = tfutil.Optimizer(name='TrainD', learning_rate=lrate_in, **config.D_opt)
    # .get() so that an unset CUDA_VISIBLE_DEVICES prints None instead of
    # aborting the run with a KeyError.
    print("CUDA_VISIBLE_DEVICES: ", os.environ.get('CUDA_VISIBLE_DEVICES'))
    for gpu in range(config.num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            # GPU 0 uses the primary networks; every other GPU gets a clone.
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            # Every loss below depends on these assignments, so each network's
            # 'lod' variable is set from lod_in before its loss is evaluated.
            lod_assign_ops = [
                tf.assign(G_gpu.find_var('lod'), lod_in),
                tf.assign(D_gpu.find_var('lod'), lod_in)
            ]
            reals_gpu = process_reals(reals_split[gpu], lod_in, mirror_augment,
                                      training_set.dynamic_range, drange_net)
            unlabeled_reals_gpu = process_reals(
                unlabeled_reals_split[gpu], lod_in, mirror_augment,
                unlabeled_training_set.dynamic_range, drange_net)
            labels_gpu = labels_split[gpu]
            with tf.name_scope('G_loss'), tf.control_dependencies(lod_assign_ops):
                G_loss = tfutil.call_func_by_name(
                    G=G_gpu, D=D_gpu, opt=G_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    unlabeled_reals=unlabeled_reals_gpu,
                    **config.G_loss)
            # NOTE(review): the two *_pggan losses receive G_opt / D_opt while
            # their gradients are registered on G_opt_pggan / D_opt_pggan.
            # Left unchanged; confirm whether the loss functions use `opt`
            # (e.g. for loss scaling) and whether this pairing is intended.
            with tf.name_scope('G_loss_pggan'), tf.control_dependencies(lod_assign_ops):
                G_loss_pggan = tfutil.call_func_by_name(
                    G=G_gpu, D=D_gpu, opt=G_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    **config.G_loss_pggan)
            with tf.name_scope('D_loss'), tf.control_dependencies(lod_assign_ops):
                D_loss = tfutil.call_func_by_name(
                    G=G_gpu, D=D_gpu, opt=D_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    reals=reals_gpu, labels=labels_gpu,
                    unlabeled_reals=unlabeled_reals_gpu,
                    **config.D_loss)
            with tf.name_scope('D_loss_pggan'), tf.control_dependencies(lod_assign_ops):
                D_loss_pggan = tfutil.call_func_by_name(
                    G=G_gpu, D=D_gpu, opt=D_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    unlabeled_reals=unlabeled_reals_gpu,
                    **config.D_loss_pggan)
            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            G_opt_pggan.register_gradients(tf.reduce_mean(G_loss_pggan), G_gpu.trainables)
            D_opt_pggan.register_gradients(tf.reduce_mean(D_loss_pggan), D_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)
            print('GPU %d loaded!' % gpu)
    G_train_op = G_opt.apply_updates()
    G_train_op_pggan = G_opt_pggan.apply_updates()
    D_train_op_pggan = D_opt_pggan.apply_updates()
    D_train_op = D_opt.apply_updates()

    print('Setting up snapshot image grid...')
    grid_size, grid_reals, grid_labels, grid_latents = setup_snapshot_image_grid(
        G, training_set, **config.grid)
    # The schedule evaluated at the end of training supplies the minibatch
    # size used for this first Gs.run() call.
    sched = TrainingSchedule(total_kimg * TrainingSpeedInt, training_set, **config.sched)
    grid_fakes = Gs.run(grid_latents, grid_labels,
                        minibatch_size=sched.minibatch // config.num_gpus)

    print('Setting up result dir...')
    result_subdir = misc.create_result_subdir(config.result_dir, config.desc)
    misc.save_image_grid(grid_reals,
                         os.path.join(result_subdir, 'reals.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)
    misc.save_image_grid(grid_fakes,
                         os.path.join(result_subdir, 'fakes%06d.png' % 0),
                         drange=drange_net,
                         grid_size=grid_size)
    summary_log = tf.compat.v1.summary.FileWriter(result_subdir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()

    print("Start Time: ", datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
    print('Training...')
    cur_nimg = int(resume_kimg * TrainingSpeedInt)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    tick_start_time = time.time()
    train_start_time = tick_start_time - resume_time
    prev_lod = -1.0
    while cur_nimg < total_kimg * TrainingSpeedInt:

        # Choose training parameters and configure training ops.
        sched = TrainingSchedule(cur_nimg, training_set, **config.sched)
        sched2 = TrainingSchedule(cur_nimg, unlabeled_training_set, **config.sched)
        training_set.configure(sched.minibatch, sched.lod)
        unlabeled_training_set.configure(sched2.minibatch, sched2.lod)
        if reset_opt_for_new_lod:
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(
                    sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state()
                D_opt.reset_optimizer_state()
                G_opt_pggan.reset_optimizer_state()
                D_opt_pggan.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.  The *_pggan ops are used while lod != 0; once
        # lod reaches 0 the config.D_loss / config.G_loss ops take over.
        at_full_resolution = sched.lod == 0
        d_op = D_train_op if at_full_resolution else D_train_op_pggan
        g_op = G_train_op if at_full_resolution else G_train_op_pggan
        for _repeat in range(minibatch_repeats):
            for _ in range(D_repeats):
                tfutil.run(
                    [d_op, Gs_update_op], {
                        lod_in: sched.lod,
                        lrate_in: sched.D_lrate,
                        minibatch_in: sched.minibatch
                    })
                cur_nimg += sched.minibatch
            tfutil.run(
                [g_op], {
                    lod_in: sched.lod,
                    lrate_in: sched.G_lrate,
                    minibatch_in: sched.minibatch
                })

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * TrainingSpeedInt)
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * TrainingSpeedInt or done:
            cur_tick += 1
            cur_time = time.time()
            tick_kimg = (cur_nimg - tick_start_nimg) / TrainingSpeedFloat
            tick_start_nimg = cur_nimg
            tick_time = cur_time - tick_start_time
            total_time = cur_time - train_start_time
            maintenance_time = tick_start_time - maintenance_start_time
            maintenance_start_time = cur_time

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %.1f date %s'
                % (tfutil.autosummary('Progress/tick', cur_tick),
                   tfutil.autosummary('Progress/kimg', cur_nimg / TrainingSpeedFloat),
                   tfutil.autosummary('Progress/lod', sched.lod),
                   tfutil.autosummary('Progress/minibatch', sched.minibatch),
                   misc.format_time(tfutil.autosummary('Timing/total_sec', total_time)),
                   tfutil.autosummary('Timing/sec_per_tick', tick_time),
                   tfutil.autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                   tfutil.autosummary('Timing/maintenance_sec', maintenance_time),
                   datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")))
            tfutil.autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            tfutil.autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            #######################
            # VALIDATION ACCURACY #
            #######################
            # example ndim = 512 for an image that is 512x512 pixels
            # All images for SSL-PGGAN must be square
            ndim = 256
            dir_tuple = (config.validation_dog, config.validation_cat)
            # FP_RATE[i][0]: images of class i that were given another label,
            # to see whether wrong guesses are biased towards one class.
            correct, guesses, FP_RATE = _validate_discriminator(D, dir_tuple, ndim)
            # Empty validation directories report 0.0 instead of raising
            # ZeroDivisionError in the middle of training.
            validation = (correct / guesses) if guesses else 0.0
            print("Total Correct: ", correct, "\n", "Total Guesses: ", guesses,
                  "\n", "Percent correct: ", validation)
            print("False Positives: Dog, Cat", FP_RATE)
            print()
            tfutil.autosummary('Accuracy/Validation', validation)
            tfutil.save_summaries(summary_log, cur_nimg)

            # Save snapshots.
            if cur_tick % image_snapshot_ticks == 0 or done:
                grid_fakes = Gs.run(grid_latents, grid_labels,
                                    minibatch_size=sched.minibatch // config.num_gpus)
                misc.save_image_grid(grid_fakes,
                                     os.path.join(
                                         result_subdir, 'fakes%06d.png' %
                                         (cur_nimg // TrainingSpeedInt)),
                                     drange=drange_net,
                                     grid_size=grid_size)
            if cur_tick % network_snapshot_ticks == 0 or done:
                misc.save_pkl((G, D, Gs),
                              os.path.join(
                                  result_subdir, 'network-snapshot-%06d.pkl' %
                                  (cur_nimg // TrainingSpeedInt)))

            # Record start time of the next tick.
            tick_start_time = time.time()

    # Write final results.
    misc.save_pkl((G, D, Gs), os.path.join(result_subdir, 'network-final.pkl'))
    summary_log.close()
    # An empty marker file signals that the run finished.
    with open(os.path.join(result_subdir, '_training-done.txt'), 'wt'):
        pass
target = tf.Variable(np.zeros((1, 3, 1024, 1024)), dtype=np.float32) # Define network Gs = load_network() fea_ext_name_to_var_name = {'inception_v3': 'InceptionV3'} feature_extractor = get_feature_extractor(args) seed_label = np.zeros([seed_latent.shape[0]] + Gs.input_shapes[1][1:], dtype=np.float32) # Placeholder noise_loss, fake_images_out = get_nosie_output(args, Gs, seed_latent, seed_label, target, feature_extractor) # Apply optimizers lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[]) noise_opt = tfutil.Optimizer(name='Noise', learning_rate=lrate_in, **config.G_opt) noise_opt.register_gradients(tf.reduce_mean(noise_loss), OrderedDict([('latents', seed_latent)])) noise_update_op = noise_opt.apply_updates() # Get TF Session sess = tf.get_default_session() # Initialize variables uninstizalized_vars = [seed_latent, target] init_op = tf.variables_initializer(uninstizalized_vars) sess.run([init_op]) if args.feature_extractor is not None and args.feature_extractor != 'D': extractor_variables = slim.get_variables_to_restore() extractor_variables = [
def train_classifier(
        smoothing=0.999,  # Exponential running average of encoder weights.
        minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
        reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
        total_kimg=25000,  # Total length of the training, measured in thousands of real images.
        lr_mirror_augment=True,  # Enable mirror augment?
        ud_mirror_augment=False,  # Enable up-down mirror augment?
        drange_net=[-1, 1],  # Dynamic range used when feeding image data to the networks.
        image_snapshot_ticks=10,  # How often to export image snapshots?
        save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
        save_weight_histograms=False):  # Include weight histograms in the tfevents file?
    """Train the EG network together with its D_rec network.

    The function first tries to resume from the snapshot reported by
    ``misc.locate_network_pkl()``; if that fails for any ordinary reason the
    networks are built from ``config.EG`` / ``config.D_rec`` and training
    starts at 0 kimg.  During training it periodically prints the argmax
    accuracy of the averaged network ``EGs`` on fixed training and validation
    grids, exports image grids, and pickles ``(EG, D_rec, EGs)``.

    Network snapshots are written every ``total_kimg // 100`` ticks (at least
    every tick) and once more when training finishes.
    """
    maintenance_start_time = time.time()
    training_set = dataset.load_dataset(data_dir=config.data_dir,
                                        verbose=True,
                                        **config.training_set)
    validation_set = dataset.load_dataset(data_dir=config.data_dir,
                                          verbose=True,
                                          **config.validation_set)

    # How often to export network snapshots?  Clamped to at least 1 because
    # total_kimg < 100 would otherwise yield 0 and make the
    # `cur_tick % network_snapshot_ticks` test below divide by zero.
    network_snapshot_ticks = max(total_kimg // 100, 1)

    # Construct networks.
    with tf.device('/gpu:0'):
        try:
            network_pkl = misc.locate_network_pkl()
            resume_kimg, resume_time = misc.resume_kimg_time(network_pkl)
            print('Loading networks from "%s"...' % network_pkl)
            EG, D_rec, EGs = misc.load_pkl(network_pkl)
        except Exception as err:
            # Resuming is best-effort, but say why it failed so that e.g. a
            # corrupt snapshot does not silently restart training from zero.
            # KeyboardInterrupt / SystemExit are deliberately not caught.
            print('Could not resume from a snapshot (%s: %s)' %
                  (type(err).__name__, err))
            print('Constructing networks...')
            resume_kimg = 0.0
            resume_time = 0.0
            EG = tfutil.Network('EG',
                                num_channels=training_set.shape[0],
                                resolution=training_set.shape[1],
                                label_size=training_set.label_size,
                                **config.EG)
            D_rec = tfutil.Network('D_rec',
                                   num_channels=training_set.shape[0],
                                   resolution=training_set.shape[1],
                                   **config.D_rec)
            EGs = EG.clone('EGs')
        EGs_update_op = EGs.setup_as_moving_average_of(EG, beta=smoothing)
    EG.print_layers()
    D_rec.print_layers()

    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        minibatch_split = minibatch_in // config.num_gpus
        reals, labels = training_set.get_minibatch_tf()
        reals_split = tf.split(reals, config.num_gpus)
        labels_split = tf.split(labels, config.num_gpus)
    EG_opt = tfutil.Optimizer(name='TrainEG',
                              learning_rate=lrate_in,
                              **config.EG_opt)
    D_rec_opt = tfutil.Optimizer(name='TrainD_rec',
                                 learning_rate=lrate_in,
                                 **config.D_rec_opt)
    for gpu in range(config.num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            # GPU 0 trains the master copies; other GPUs get per-GPU clones.
            EG_gpu = EG if gpu == 0 else EG.clone(EG.name + '_shadow_%d' % gpu)
            D_rec_gpu = D_rec if gpu == 0 else D_rec.clone(
                D_rec.name + '_shadow_%d' % gpu)
            # Only the second return value is consumed by the losses below.
            reals_fade_gpu, reals_orig_gpu = process_reals(
                reals_split[gpu], lod_in, lr_mirror_augment,
                ud_mirror_augment, training_set.dynamic_range, drange_net)
            labels_gpu = labels_split[gpu]
            with tf.name_scope('EG_loss'):
                EG_loss = tfutil.call_func_by_name(EG=EG_gpu,
                                                   D_rec=D_rec_gpu,
                                                   reals_orig=reals_orig_gpu,
                                                   labels=labels_gpu,
                                                   **config.EG_loss)
            with tf.name_scope('D_rec_loss'):
                D_rec_loss = tfutil.call_func_by_name(
                    EG=EG_gpu,
                    D_rec=D_rec_gpu,
                    D_rec_opt=D_rec_opt,
                    minibatch_size=minibatch_split,
                    reals_orig=reals_orig_gpu,
                    **config.D_rec_loss)
            EG_opt.register_gradients(tf.reduce_mean(EG_loss),
                                      EG_gpu.trainables)
            D_rec_opt.register_gradients(tf.reduce_mean(D_rec_loss),
                                         D_rec_gpu.trainables)
    EG_train_op = EG_opt.apply_updates()
    D_rec_train_op = D_rec_opt.apply_updates()

    def _evaluate(tag, images, onehot_labels, minibatch):
        # Run EGs on a fixed image set, print its argmax accuracy, and return
        # the (reconstructions, fingerprints) it produced.
        recs, fingerprints, logits = EGs.run(
            images, minibatch_size=minibatch // config.num_gpus)
        preds = np.argmax(logits, axis=1)
        gt = np.argmax(onehot_labels, axis=1)
        acc = np.float32(np.sum(gt == preds)) / np.float32(len(gt))
        print('%s Accuracy = %f' % (tag, acc))
        return recs, fingerprints

    print('Setting up snapshot image grid...')
    grid_size, train_reals, train_labels = setup_snapshot_image_grid(
        training_set, drange_net, [450, 10], **config.grid)
    grid_size, val_reals, val_labels = setup_snapshot_image_grid(
        validation_set, drange_net, [450, 10], **config.grid)
    sched = TrainingSchedule(total_kimg * 1000, training_set, **config.sched)
    train_recs, train_fingerprints = _evaluate('Training', train_reals,
                                               train_labels, sched.minibatch)
    val_recs, val_fingerprints = _evaluate('Validation', val_reals,
                                           val_labels, sched.minibatch)

    print('Setting up result dir...')
    result_subdir = misc.create_result_subdir(config.result_dir, config.desc)

    def _save_thumbs(images, filename):
        # Every 30th image of the set, laid out on a 15x10 grid.
        misc.save_image_grid(images[::30, :, :, :],
                             os.path.join(result_subdir, filename),
                             drange=drange_net,
                             grid_size=[15, 10])

    def _export_outputs(stage, t_recs, t_fps, v_recs, v_fps):
        # Write reconstruction / fingerprint grids plus the weights of the
        # 'Conv_fingerprints' layer.  The 'fingerrints' spelling is kept
        # because it is part of the on-disk file names.
        _save_thumbs(t_recs, 'train_recs-%s.png' % stage)
        _save_thumbs(t_fps, 'train_fingerrints-%s.png' % stage)
        _save_thumbs(v_recs, 'val_recs-%s.png' % stage)
        _save_thumbs(v_fps, 'val_fingerrints-%s.png' % stage)
        est_fingerprints = np.transpose(
            EGs.vars['Conv_fingerprints/weight'].eval(), axes=[3, 2, 0, 1])
        misc.save_image_grid(
            est_fingerprints,
            os.path.join(result_subdir, 'est_fingerrints-%s.png' % stage),
            drange=[np.amin(est_fingerprints),
                    np.amax(est_fingerprints)],
            grid_size=[est_fingerprints.shape[0], 1])

    _save_thumbs(train_reals, 'train_reals.png')
    _save_thumbs(val_reals, 'val_reals.png')
    _export_outputs('init', train_recs, train_fingerprints, val_recs,
                    val_fingerprints)

    summary_log = tf.summary.FileWriter(result_subdir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        EG.setup_weight_histograms()
        D_rec.setup_weight_histograms()

    print('Training...')
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    tick_start_time = time.time()
    train_start_time = tick_start_time - resume_time
    prev_lod = -1.0
    while cur_nimg < total_kimg * 1000:

        # Choose training parameters and configure training ops.
        sched = TrainingSchedule(cur_nimg, training_set, **config.sched)
        training_set.configure(sched.minibatch, sched.lod)
        if reset_opt_for_new_lod:
            lod_changed = (np.floor(sched.lod) != np.floor(prev_lod)
                           or np.ceil(sched.lod) != np.ceil(prev_lod))
            if lod_changed:
                EG_opt.reset_optimizer_state()
                D_rec_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        feed = {
            lod_in: sched.lod,
            lrate_in: sched.lrate,
            minibatch_in: sched.minibatch
        }
        for repeat in range(minibatch_repeats):
            tfutil.run([D_rec_train_op], feed)
            tfutil.run([EG_train_op], feed)
            tfutil.run([EGs_update_op], {})
            cur_nimg += sched.minibatch

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            cur_time = time.time()
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = cur_time - tick_start_time
            total_time = cur_time - train_start_time
            maintenance_time = tick_start_time - maintenance_start_time
            maintenance_start_time = cur_time

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f lod %-5.2f resolution %-4d minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %.1f'
                % (tfutil.autosummary('Progress/tick', cur_tick),
                   tfutil.autosummary('Progress/kimg', cur_nimg / 1000.0),
                   tfutil.autosummary('Progress/lod', sched.lod),
                   tfutil.autosummary('Progress/resolution',
                                      sched.resolution),
                   tfutil.autosummary('Progress/minibatch', sched.minibatch),
                   misc.format_time(
                       tfutil.autosummary('Timing/total_sec', total_time)),
                   tfutil.autosummary('Timing/sec_per_tick', tick_time),
                   tfutil.autosummary('Timing/sec_per_kimg',
                                      tick_time / tick_kimg),
                   tfutil.autosummary('Timing/maintenance_sec',
                                      maintenance_time)))
            tfutil.autosummary('Timing/total_hours',
                               total_time / (60.0 * 60.0))
            tfutil.autosummary('Timing/total_days',
                               total_time / (24.0 * 60.0 * 60.0))
            tfutil.save_summaries(summary_log, cur_nimg)

            # Print accuracy and refresh the '-final' image grids.
            if cur_tick % image_snapshot_ticks == 0 or done:
                train_recs, train_fingerprints = _evaluate(
                    'Training', train_reals, train_labels, sched.minibatch)
                val_recs, val_fingerprints = _evaluate(
                    'Validation', val_reals, val_labels, sched.minibatch)
                _export_outputs('final', train_recs, train_fingerprints,
                                val_recs, val_fingerprints)

            if cur_tick % network_snapshot_ticks == 0 or done:
                misc.save_pkl(
                    (EG, D_rec, EGs),
                    os.path.join(
                        result_subdir,
                        'network-snapshot-%06d.pkl' % (cur_nimg // 1000)))

            # Record start time of the next tick.
            tick_start_time = time.time()

    # Write final results.
    misc.save_pkl((EG, D_rec, EGs),
                  os.path.join(result_subdir, 'network-final.pkl'))
    summary_log.close()
    # Empty marker file signalling that training ran to completion.
    with open(os.path.join(result_subdir, '_training-done.txt'), 'wt'):
        pass
def train_progressive_gan(
        G_smoothing=0.999,  # Exponential running average of generator weights.
        D_repeats=1,  # How many times the discriminator is trained per G iteration.
        minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
        reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
        total_kimg=15000,  # Total length of the training, measured in thousands of real images.
        mirror_augment=False,  # Enable mirror augment?
        drange_net=[-1, 1],  # Dynamic range used when feeding image data to the networks.
        image_snapshot_ticks=1,  # How often to export image snapshots?
        network_snapshot_ticks=10,  # How often to export network snapshots?
        compute_fid_score=False,  # Compute FID during training once sched.lod=0.0
        minimum_fid_kimg=0,  # Compute FID after
        fid_snapshot_ticks=1,  # How often to compute FID
        fid_patience=2,  # When to end training based on FID
        save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
        save_weight_histograms=False,  # Include weight histograms in the tfevents file?
        resume_run_id=None,  # Run ID or network pkl to resume training from, None = start from scratch.
        resume_snapshot=None,  # Snapshot index to resume training from, None = autodetect.
        resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
        resume_time=0.0,  # Assumed wallclock time at the beginning. Affects reporting.
        result_subdir="./"):
    """Train G/D progressively, with optional FID-based early stopping.

    When ``compute_fid_score`` is enabled, FID is measured on every
    ``fid_snapshot_ticks``-th tick once ``sched.lod == 0.0`` and at least
    ``minimum_fid_kimg`` kimg have been shown.  Each time the FID improves on
    all previous values the networks are saved as
    ``network-final-full-conv.pkl``; after ``fid_patience`` consecutive
    measurements without improvement training stops.  If training instead
    runs to ``total_kimg``, a final FID is computed, appended to
    ``results_full_conv.csv`` two directory levels above the result dir, and
    the final networks are saved under the same pkl name.

    ``resume_run_id`` may be ``None`` or the string ``"None"`` to start from
    scratch.  The ``result_subdir`` argument is accepted for compatibility
    but is overwritten by ``misc.create_result_subdir`` before first use.
    """
    maintenance_start_time = time.time()
    training_set = dataset.load_dataset(data_dir=config.data_dir,
                                        verbose=True,
                                        **config.dataset)

    # The documented default is the None object, which the former
    # `resume_run_id != "None"` test treated as "resume" and then failed to
    # locate a snapshot.  Accept both spellings of "start from scratch".
    resuming = resume_run_id is not None and resume_run_id != "None"

    # Construct networks.
    with tf.device('/gpu:0'):
        if resuming:
            network_pkl = misc.locate_network_pkl(resume_run_id,
                                                  resume_snapshot)
            resume_pkl_name = os.path.splitext(
                os.path.basename(network_pkl))[0]
            # Snapshot names end in '-<kimg>'; anything else (e.g. '-final')
            # keeps the caller-supplied resume_kimg.
            try:
                resume_kimg = int(resume_pkl_name.split('-')[-1])
                print('** Setting resume kimg to', resume_kimg, flush=True)
            except ValueError:
                print('** Keeping resume kimg as:', resume_kimg, flush=True)
            print('Loading networks from "%s"...' % network_pkl, flush=True)
            G, D, Gs = misc.load_pkl(network_pkl)
        else:
            print('Constructing networks...', flush=True)
            G = tfutil.Network('G',
                               num_channels=training_set.shape[0],
                               resolution=training_set.shape[1],
                               label_size=training_set.label_size,
                               **config.G)
            D = tfutil.Network('D',
                               num_channels=training_set.shape[0],
                               resolution=training_set.shape[1],
                               label_size=training_set.label_size,
                               **config.D)
            Gs = G.clone('Gs')
        Gs_update_op = Gs.setup_as_moving_average_of(G, beta=G_smoothing)
    G.print_layers()
    D.print_layers()

    print('Building TensorFlow graph...', flush=True)
    with tf.name_scope('Inputs'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        minibatch_split = minibatch_in // config.num_gpus
        reals, labels = training_set.get_minibatch_tf()
        reals_split = tf.split(reals, config.num_gpus)
        labels_split = tf.split(labels, config.num_gpus)
    G_opt = tfutil.Optimizer(name='TrainG',
                             learning_rate=lrate_in,
                             **config.G_opt)
    D_opt = tfutil.Optimizer(name='TrainD',
                             learning_rate=lrate_in,
                             **config.D_opt)
    for gpu in range(config.num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            reals_gpu = process_reals(reals_split[gpu], lod_in,
                                      mirror_augment,
                                      training_set.dynamic_range, drange_net)
            labels_gpu = labels_split[gpu]
            with tf.name_scope('G_loss'):
                G_loss = tfutil.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    opt=G_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    **config.G_loss)
            with tf.name_scope('D_loss'):
                D_loss = tfutil.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    opt=D_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    reals=reals_gpu,
                    labels=labels_gpu,
                    **config.D_loss)
            G_opt.register_gradients(tf.reduce_mean(G_loss),
                                     G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss),
                                     D_gpu.trainables)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    print('Setting up snapshot image grid...', flush=True)
    grid_size, grid_reals, grid_labels, grid_latents = setup_snapshot_image_grid(
        G, training_set, **config.grid)
    sched = TrainingSchedule(total_kimg * 1000, training_set, **config.sched)
    grid_fakes = Gs.run(grid_latents,
                        grid_labels,
                        minibatch_size=sched.minibatch // config.num_gpus)

    print('Setting up result dir...', flush=True)
    # NOTE: this replaces whatever was passed in as `result_subdir`.
    result_subdir = misc.create_result_subdir(config.result_dir, config.desc)
    best_pkl_path = os.path.join(result_subdir, 'network-final-full-conv.pkl')
    misc.save_image_grid(grid_reals,
                         os.path.join(result_subdir, 'reals.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)
    misc.save_image_grid(grid_fakes,
                         os.path.join(result_subdir, 'fakes%06d.png' % 0),
                         drange=drange_net,
                         grid_size=grid_size)
    summary_log = tf.summary.FileWriter(result_subdir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()

    print('Training...', flush=True)

    # FID patience parameters:
    fid_list = []  # Every FID measured so far, oldest first.
    fid_stop = False  # Set once early stopping has been triggered.
    fid_patience_step = 0  # Consecutive measurements without improvement.

    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    tick_start_time = time.time()
    train_start_time = tick_start_time - resume_time
    prev_lod = -1.0
    while cur_nimg < total_kimg * 1000:

        # Choose training parameters and configure training ops.
        sched = TrainingSchedule(cur_nimg, training_set, **config.sched)
        training_set.configure(sched.minibatch, sched.lod)
        if reset_opt_for_new_lod:
            lod_changed = (np.floor(sched.lod) != np.floor(prev_lod)
                           or np.ceil(sched.lod) != np.ceil(prev_lod))
            if lod_changed:
                G_opt.reset_optimizer_state()
                D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        for repeat in range(minibatch_repeats):
            for _ in range(D_repeats):
                tfutil.run(
                    [D_train_op, Gs_update_op], {
                        lod_in: sched.lod,
                        lrate_in: sched.D_lrate,
                        minibatch_in: sched.minibatch
                    })
                cur_nimg += sched.minibatch
            tfutil.run(
                [G_train_op], {
                    lod_in: sched.lod,
                    lrate_in: sched.G_lrate,
                    minibatch_in: sched.minibatch
                })

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            cur_time = time.time()
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = cur_time - tick_start_time
            total_time = cur_time - train_start_time
            maintenance_time = tick_start_time - maintenance_start_time
            maintenance_start_time = cur_time

            # Neither cur_tick, cur_nimg nor sched changes for the rest of
            # this tick (until a possible early stop), so the FID conditions
            # are evaluated once here.  `== True` is kept on purpose so
            # non-boolean flag values behave exactly as before.
            # NOTE(review): confirm whether callers ever pass non-bool flags.
            fid_tick = ((compute_fid_score == True)  # noqa: E712
                        and (cur_tick % fid_snapshot_ticks == 0)
                        and (cur_nimg >= minimum_fid_kimg * 1000))
            fid_due = fid_tick and (sched.lod == 0.0)

            if fid_due:
                fid = compute_fid(Gs=Gs,
                                  minibatch_size=sched.minibatch,
                                  dataset_obj=training_set,
                                  iter_number=cur_nimg / 1000,
                                  lod=0.0,
                                  num_images=10000,
                                  printing=False)
                fid_list.append(fid)

            # Report progress without FID.
            print(
                'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %.1f'
                % (tfutil.autosummary('Progress/tick', cur_tick),
                   tfutil.autosummary('Progress/kimg', cur_nimg / 1000.0),
                   tfutil.autosummary('Progress/lod', sched.lod),
                   tfutil.autosummary('Progress/minibatch', sched.minibatch),
                   misc.format_time(
                       tfutil.autosummary('Timing/total_sec', total_time)),
                   tfutil.autosummary('Timing/sec_per_tick', tick_time),
                   tfutil.autosummary('Timing/sec_per_kimg',
                                      tick_time / tick_kimg),
                   tfutil.autosummary('Timing/maintenance_sec',
                                      maintenance_time)),
                flush=True)
            tfutil.autosummary('Timing/total_hours',
                               total_time / (60.0 * 60.0))
            tfutil.autosummary('Timing/total_days',
                               total_time / (24.0 * 60.0 * 60.0))
            tfutil.save_summaries(summary_log, cur_nimg)

            # Save image snapshots.
            if cur_tick % image_snapshot_ticks == 0 or done:
                grid_fakes = Gs.run(grid_latents,
                                    grid_labels,
                                    minibatch_size=sched.minibatch //
                                    config.num_gpus)
                misc.save_image_grid(grid_fakes,
                                     os.path.join(
                                         result_subdir, 'fakes%06d.png' %
                                         (cur_nimg // 1000)),
                                     drange=drange_net,
                                     grid_size=grid_size)

            # Save network snapshots: on the regular cadence, at the end, and
            # on every FID tick (regardless of lod).
            if (cur_tick % network_snapshot_ticks == 0) or done or fid_tick:
                misc.save_pkl(
                    (G, D, Gs),
                    os.path.join(
                        result_subdir,
                        'network-snapshot-%06d.pkl' % (cur_nimg // 1000)))

            # End training based on FID patience.
            if fid_due:
                fid_patience_step += 1
                # The first measurement always counts as an improvement; the
                # short-circuit keeps np.min away from an empty history.
                improved = (len(fid_list) == 1
                            or fid_list[-1] < np.min(fid_list[:-1]))
                if improved:
                    fid_patience_step = 0
                    misc.save_pkl((G, D, Gs), best_pkl_path)
                    print(
                        "Save network-final-full-conv for FID: %.3f at kimg %-8.1f."
                        % (fid_list[-1], cur_nimg // 1000),
                        flush=True)
                else:
                    print("No improvement for FID: %.3f at kimg %-8.1f." %
                          (fid_list[-1], cur_nimg // 1000),
                          flush=True)
                if fid_patience_step == fid_patience:
                    fid_stop = True
                    print("Training stopped due to FID early-stopping.",
                          flush=True)
                    # Forces the enclosing while-loop to terminate.
                    cur_nimg = total_kimg * 1000

            # Record start time of the next tick.
            tick_start_time = time.time()

    # Write final results.
    # Save final only if FID-stopping has not happened; otherwise the best
    # networks are already on disk under the same file name.
    if not fid_stop:
        fid = compute_fid(Gs=Gs,
                          minibatch_size=sched.minibatch,
                          dataset_obj=training_set,
                          iter_number=cur_nimg / 1000,
                          lod=0.0,
                          num_images=10000,
                          printing=False)
        print("Final FID: %.3f at kimg %-8.1f." % (fid, cur_nimg // 1000),
              flush=True)

        # Append "<parent>/<run dir>, fid" to the CSV two levels above the
        # result dir.
        csv_file = os.path.join(
            os.path.dirname(os.path.dirname(result_subdir)),
            "results_full_conv.csv")
        list_to_append = ["/".join(result_subdir.split("/")[-2:]), fid]
        # newline='' is what the csv module asks for; the `with` block closes
        # the file, so no explicit close() is needed.
        with open(csv_file, 'a', newline='') as f_object:
            writer(f_object).writerow(list_to_append)

        misc.save_pkl((G, D, Gs), best_pkl_path)
        print("Save network-final-full-conv.", flush=True)

    summary_log.close()
    # Empty marker file signalling that training ran to completion.
    with open(os.path.join(result_subdir, '_training-done.txt'), 'wt'):
        pass
def recovery(name,
             pkl_path1,
             pkl_path2,
             out_dir,
             target_latents_dir,
             num_init=20,
             num_total_sample=100,
             image_shrink=1,
             random_seed=2020,
             minibatch_size=1,
             noise_sigma=0):
    """Recover latents of a source generator that reproduce target images.

    For each of ``num_total_sample`` samples, the target generator (loaded
    from ``pkl_path2``) renders one latent, Gaussian noise of strength
    ``noise_sigma`` is added, and ``num_init`` random starting latents of the
    source generator (loaded from ``pkl_path1``) are optimised for 500 steps
    against the mean squared error to that image.

    Outputs, all under ``out_dir/name``:
      * ``<k>.png``    - target image next to the best reconstruction,
      * ``losses.txt`` - one appended line per sample with the per-start losses,
      * ``z_init.npy`` / ``z_re.npy`` - target and recovered latents.

    Args:
        name: Sub-directory of ``out_dir`` to write into.
        pkl_path1: Network pickle of the source model.
        pkl_path2: Network pickle of the target model.
        out_dir: Parent output directory (created if missing).
        target_latents_dir: Path of a ``.npy`` file with target latents, or
            ``None`` to draw a random latent for every sample.
        num_init: Number of random starting points per sample.
        num_total_sample: Number of target samples to process.
        image_shrink: Forwarded to ``Gs.run`` as ``out_shrink``.
        random_seed: Currently unused.
        minibatch_size: Currently unused.
        noise_sigma: Forwarded to ``addGaussianNoise`` as ``sigma``.
    """
    print('num_init:' + str(num_init))

    # Load source model.
    print('Loading network1...' + pkl_path1)
    _, _, G_sorce = misc.load_network_pkl(pkl_path1)

    # Load target model.
    print('Loading network2...' + pkl_path2)
    _, _, G_target = misc.load_network_pkl(pkl_path2)

    # Build Gt from the target model.
    Gt = tfutil.Network('Gt',
                        num_samples=num_init,
                        num_channels=3,
                        resolution=128,
                        func='networks.G_recovery')
    latents = misc.random_latents(num_init, Gt, random_state=None)
    labels = np.zeros([latents.shape[0], 0], np.float32)
    Gt.copy_vars_from_with_input(G_target, latents)

    # Build Gs from the source model.
    Gs = tfutil.Network('Gs',
                        num_samples=num_init,
                        num_channels=3,
                        resolution=128,
                        func='networks.G_recovery')
    Gs.copy_vars_from_with_input(G_sorce, latents)

    out_dir = os.path.join(out_dir, name)
    # makedirs also creates missing parents and tolerates an existing dir.
    os.makedirs(out_dir, exist_ok=True)

    def G_loss(G, target_images):
        # MSE between the rescaled output of G and the target images.
        tmp_latents = tfutil.run(G.trainables['Input/weight'])
        G_out = G.get_output_for(tmp_latents, labels, is_training=True)
        G_out = rescale_output(G_out)
        return tf.losses.mean_squared_error(target_images, G_out)

    z_init = []
    z_recovered = []

    # Load target z.
    if target_latents_dir is not None:
        print('using latents:' + target_latents_dir)
        pre_latents = np.load(target_latents_dir)

    num_steps = 500
    for k in range(num_total_sample):
        result_dir = os.path.join(out_dir, str(k) + '.png')

        # ---- sample target image
        if target_latents_dir is not None:
            latent = pre_latents[k]
        else:
            latents = misc.random_latents(1, Gs, random_state=None)
            latent = latents[0]
        z_init.append(latent)

        # Same target latent replicated for every starting point.
        latents = np.zeros((num_init, 512))
        latents[:] = latent
        Gt.change_input(inputs=latents)

        # ---- add noise
        target_images = Gt.get_output_for(latents, labels, is_training=False)
        target_images = tfutil.run(rescale_output(target_images))
        target_images = addGaussianNoise(target_images, sigma=noise_sigma)
        target_images = tf.cast(target_images, dtype='float32')

        # ---- select random start point
        latents_2 = misc.random_latents(num_init, Gs, random_state=None)
        Gs.change_input(inputs=latents_2)

        # ---- define loss & optimizer
        loss = G_loss(G=Gs, target_images=target_images)
        G_opt = tfutil.Optimizer(name='latent_recovery', learning_rate=0.01)
        G_opt.register_gradients(loss, Gs.trainables)
        G_train_op = G_opt.apply_updates()

        # ---- recovery
        # NOTE(review): the optimizer state is reset before every single
        # step, as in the original code - confirm this is intended.
        for _ in range(num_steps):
            G_opt.reset_optimizer_state()
            tfutil.run([G_train_op])

        # One fetch serves both the image rendering and the stored result;
        # nothing trains in between.
        learned_latent = tfutil.run(Gs.trainables['Input/weight'])
        result_images = Gs.run(learned_latent,
                               labels,
                               minibatch_size=config.num_gpus * 256,
                               num_gpus=config.num_gpus,
                               out_mul=127.5,
                               out_add=127.5,
                               out_shrink=image_shrink,
                               out_dtype=np.float32)

        # Per-start-point reconstruction error.
        G_out = Gs.get_output_for(learned_latent, labels, is_training=True)
        G_out = rescale_output(G_out)
        sample_losses = []
        for i in range(num_init):
            sample_loss = tf.losses.mean_squared_error(
                target_images[i], G_out[i])
            sample_losses.append(tfutil.run(sample_loss))

        # ---- save best optimized image
        plt.subplot(1, 2, 1)
        plt.imshow(tfutil.run(target_images)[0].transpose(1, 2, 0) / 255.0)
        plt.subplot(1, 2, 2)
        plt.imshow(result_images[np.argmin(sample_losses)].transpose(1, 2, 0)
                   / 255.0)
        plt.savefig(result_dir)
        # Drop the figure; otherwise every sample draws on top of the same
        # axes and the images accumulate in memory.
        plt.close()

        # ---- store optimized z
        z_recovered.append(learned_latent)

        # ---- save losses
        with open(out_dir + "/losses.txt", "a") as f:
            for value in sample_losses:
                f.write(str(value) + ' ')
            f.write('\n')

        # Rewritten after every sample so partial results survive an
        # interrupted run; the final files are the same either way.
        np.save(out_dir + '/z_init', np.array(z_init))
        np.save(out_dir + '/z_re', np.array(z_recovered))
def train_progressive_gan( G_smoothing=0.999, # Exponential running average of generator weights. D_repeats=1, # How many times the discriminator is trained per G iteration. minibatch_repeats=4, # Number of minibatches to run before adjusting training parameters. reset_opt_for_new_lod=True, # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced? total_kimg=15000, # Total length of the training, measured in thousands of real images. mirror_augment=False, # Enable mirror augment? drange_net=[ -1, 1 ], # Dynamic range used when feeding image data to the networks. image_snapshot_ticks=1, # How often to export image snapshots? network_snapshot_ticks=10, # How often to export network snapshots? save_tf_graph=False, # Include full TensorFlow computation graph in the tfevents file? save_weight_histograms=False, # Include weight histograms in the tfevents file? resume_run_id=None, # Run ID or network pkl to resume training from, None = start from scratch. resume_snapshot=None, # Snapshot index to resume training from, None = autodetect. resume_kimg=0.0, # Assumed training progress at the beginning. Affects reporting and training schedule. resume_time=0.0 ): # Assumed wallclock time at the beginning. Affects reporting. maintenance_start_time = time.time() training_set = dataset.load_dataset(data_dir=config.data_dir, verbose=True, **config.dataset) #resume_run_id = '/dresden/users/mk1391/evl/pggan_logs/logs_celeba128cc/fsg16_results_0/000-pgan-celeba-preset-v2-2gpus-fp32/network-snapshot-010211.pkl' resume_with_new_nets = False # Construct networks. 
with tf.device('/gpu:0'): if resume_run_id is None or resume_with_new_nets: print('Constructing networks...') G = tfutil.Network('G', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **config.G) D = tfutil.Network('D', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **config.D) Gs = G.clone('Gs') if resume_run_id is not None: network_pkl = misc.locate_network_pkl(resume_run_id, resume_snapshot) print('Loading networks from "%s"...' % network_pkl) rG, rD, rGs = misc.load_pkl(network_pkl) if resume_with_new_nets: G.copy_vars_from(rG) D.copy_vars_from(rD) Gs.copy_vars_from(rGs) else: G = rG D = rD Gs = rGs Gs_update_op = Gs.setup_as_moving_average_of(G, beta=G_smoothing) G.print_layers() D.print_layers() ### pyramid draw fsg (comment out for actual training to happen) #draw_gen_fsg(Gs, 10, os.path.join(config.result_dir, 'pggan_fsg_draw.png')) #print('>>> done printing fsgs.') #return print('Building TensorFlow graph...') with tf.name_scope('Inputs'): lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[]) lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[]) minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[]) minibatch_split = minibatch_in // config.num_gpus reals, labels = training_set.get_minibatch_tf() reals_split = tf.split(reals, config.num_gpus) labels_split = tf.split(labels, config.num_gpus) G_opt = tfutil.Optimizer(name='TrainG', learning_rate=lrate_in, **config.G_opt) D_opt = tfutil.Optimizer(name='TrainD', learning_rate=lrate_in, **config.D_opt) for gpu in range(config.num_gpus): with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu): G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow') D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow') lod_assign_ops = [ tf.assign(G_gpu.find_var('lod'), lod_in), tf.assign(D_gpu.find_var('lod'), lod_in) ] reals_gpu = process_reals(reals_split[gpu], lod_in, 
mirror_augment, training_set.dynamic_range, drange_net) labels_gpu = labels_split[gpu] with tf.name_scope('G_loss'), tf.control_dependencies( lod_assign_ops): G_loss = tfutil.call_func_by_name( G=G_gpu, D=D_gpu, opt=G_opt, training_set=training_set, minibatch_size=minibatch_split, **config.G_loss) with tf.name_scope('D_loss'), tf.control_dependencies( lod_assign_ops): D_loss = tfutil.call_func_by_name( G=G_gpu, D=D_gpu, opt=D_opt, training_set=training_set, minibatch_size=minibatch_split, reals=reals_gpu, labels=labels_gpu, **config.D_loss) G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables) D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables) G_train_op = G_opt.apply_updates() D_train_op = D_opt.apply_updates() print('Setting up snapshot image grid...') grid_size, grid_reals, grid_labels, grid_latents = setup_snapshot_image_grid( G, training_set, **config.grid) ### shift reals print('>>> reals shape: ', grid_reals.shape) fc_x = 0.5 fc_y = 0.5 im_size = grid_reals.shape[-1] kernel_loc = 2.*np.pi*fc_x * np.arange(im_size).reshape((1, 1, im_size)) + \ 2.*np.pi*fc_y * np.arange(im_size).reshape((1, im_size, 1)) kernel_cos = np.cos(kernel_loc) kernel_sin = np.sin(kernel_loc) reals_t = (grid_reals / 255.) * 2. - 1 reals_t *= kernel_cos grid_reals_sh = np.rint( (reals_t + 1.) * 255. 
/ 2.).clip(0, 255).astype(np.uint8) ### end shift reals sched = TrainingSchedule(total_kimg * 1000, training_set, **config.sched) grid_fakes = Gs.run(grid_latents, grid_labels, minibatch_size=sched.minibatch // config.num_gpus) ### fft drawing #sys.path.insert(1, '/home/mahyar/CV_Res/ganist') #from fig_draw import apply_fft_win #data_size = 1000 #latents = np.random.randn(data_size, *Gs.input_shapes[0][1:]) #labels = np.zeros([latents.shape[0]] + Gs.input_shapes[1][1:]) #g_samples = Gs.run(latents, labels, minibatch_size=sched.minibatch//config.num_gpus) #g_samples = g_samples.transpose(0, 2, 3, 1) #print('>>> g_samples shape: {}'.format(g_samples.shape)) #apply_fft_win(g_samples, 'fft_pggan_hann.png') ### end fft drawing print('Setting up result dir...') result_subdir = misc.create_result_subdir(config.result_dir, config.desc) misc.save_image_grid(grid_reals, os.path.join(result_subdir, 'reals.png'), drange=training_set.dynamic_range, grid_size=grid_size) ### drawing shifted real images misc.save_image_grid(grid_reals_sh, os.path.join(result_subdir, 'reals_sh.png'), drange=training_set.dynamic_range, grid_size=grid_size) misc.save_image_grid(grid_fakes, os.path.join(result_subdir, 'fakes%06d.png' % 0), drange=drange_net, grid_size=grid_size) ### drawing shifted fake images misc.save_image_grid(grid_fakes * kernel_cos, os.path.join(result_subdir, 'fakes%06d_sh.png' % 0), drange=drange_net, grid_size=grid_size) summary_log = tf.summary.FileWriter(result_subdir) if save_tf_graph: summary_log.add_graph(tf.get_default_graph()) if save_weight_histograms: G.setup_weight_histograms() D.setup_weight_histograms() #### True cosine fft eval #fft_data_size = 1000 #im_size = training_set.shape[1] #freq_centers = [(64/128., 64/128.)] #true_samples = sample_true(training_set, fft_data_size, dtype=training_set.dtype, batch_size=32).transpose(0, 2, 3, 1) / 255. * 2. - 1. 
#true_fft, true_fft_hann, true_hist = cosine_eval(true_samples, 'true', freq_centers, log_dir=result_subdir) #fractal_eval(true_samples, f'koch_snowflake_true', result_subdir) print('Training...') cur_nimg = int(resume_kimg * 1000) cur_tick = 0 tick_start_nimg = cur_nimg tick_start_time = time.time() train_start_time = tick_start_time - resume_time prev_lod = -1.0 while cur_nimg < total_kimg * 1000: # Choose training parameters and configure training ops. sched = TrainingSchedule(cur_nimg, training_set, **config.sched) training_set.configure(sched.minibatch, sched.lod) if reset_opt_for_new_lod: if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil( sched.lod) != np.ceil(prev_lod): G_opt.reset_optimizer_state() D_opt.reset_optimizer_state() prev_lod = sched.lod # Run training ops. for repeat in range(minibatch_repeats): for _ in range(D_repeats): tfutil.run( [D_train_op, Gs_update_op], { lod_in: sched.lod, lrate_in: sched.D_lrate, minibatch_in: sched.minibatch }) cur_nimg += sched.minibatch tfutil.run( [G_train_op], { lod_in: sched.lod, lrate_in: sched.G_lrate, minibatch_in: sched.minibatch }) # Perform maintenance tasks once per tick. done = (cur_nimg >= total_kimg * 1000) if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done: cur_tick += 1 cur_time = time.time() tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0 tick_start_nimg = cur_nimg tick_time = cur_time - tick_start_time total_time = cur_time - train_start_time maintenance_time = tick_start_time - maintenance_start_time maintenance_start_time = cur_time # Report progress. 
print( 'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %.1f' % (tfutil.autosummary('Progress/tick', cur_tick), tfutil.autosummary('Progress/kimg', cur_nimg / 1000.0), tfutil.autosummary('Progress/lod', sched.lod), tfutil.autosummary('Progress/minibatch', sched.minibatch), misc.format_time( tfutil.autosummary('Timing/total_sec', total_time)), tfutil.autosummary('Timing/sec_per_tick', tick_time), tfutil.autosummary('Timing/sec_per_kimg', tick_time / tick_kimg), tfutil.autosummary('Timing/maintenance_sec', maintenance_time))) tfutil.autosummary('Timing/total_hours', total_time / (60.0 * 60.0)) tfutil.autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0)) tfutil.save_summaries(summary_log, cur_nimg) # Save snapshots. if cur_tick % image_snapshot_ticks == 0 or done: grid_fakes = Gs.run(grid_latents, grid_labels, minibatch_size=sched.minibatch // config.num_gpus) misc.save_image_grid(grid_fakes, os.path.join( result_subdir, 'fakes%06d.png' % (cur_nimg // 1000)), drange=drange_net, grid_size=grid_size) ### drawing shifted fake images misc.save_image_grid( grid_fakes * kernel_cos, os.path.join(result_subdir, 'fakes%06d_sh.png' % (cur_nimg // 1000)), drange=drange_net, grid_size=grid_size) ### drawing fsg #draw_gen_fsg(Gs, 10, os.path.join(config.result_dir, 'fakes%06d_fsg_draw.png' % (cur_nimg // 1000))) ### Gen fft eval #gen_samples = sample_gen(Gs, fft_data_size).transpose(0, 2, 3, 1) #print(f'>>> fake_samples: max={np.amax(grid_fakes)} min={np.amin(grid_fakes)}') #print(f'>>> gen_samples: max={np.amax(gen_samples)} min={np.amin(gen_samples)}') #misc.save_image_grid(gen_samples[:25], os.path.join(result_subdir, 'fakes%06d_gsample.png' % (cur_nimg // 1000)), drange=drange_net, grid_size=grid_size) #cosine_eval(gen_samples, f'gen_{cur_nimg//1000:06d}', freq_centers, log_dir=result_subdir, true_fft=true_fft, true_fft_hann=true_fft_hann, true_hist=true_hist) #fractal_eval(gen_samples, 
f'koch_snowflake_fakes{cur_nimg//1000:06d}', result_subdir) if cur_tick % network_snapshot_ticks == 0 or done: misc.save_pkl( (G, D, Gs), os.path.join( result_subdir, 'network-snapshot-%06d.pkl' % (cur_nimg // 1000))) # Record start time of the next tick. tick_start_time = time.time() # Write final results. misc.save_pkl((G, D, Gs), os.path.join(result_subdir, 'network-final.pkl')) summary_log.close() open(os.path.join(result_subdir, '_training-done.txt'), 'wt').close()