# The entry points below assume the repo's usual TF 1.x setup: `model` and
# `utils` are repo-local modules, and the FLAGS are defined elsewhere in
# the repo.
import os

import tensorflow as tf

import model
import utils

tfgan = tf.contrib.gan
FLAGS = tf.flags.FLAGS


def main(_):
  """Trains SNGAN on ImageNet."""
  tf.gfile.MakeDirs(FLAGS.checkpoint_dir)
  model_dir = '%s_%s' % ('imagenet', FLAGS.batch_size)
  logdir = os.path.join(FLAGS.checkpoint_dir, model_dir)
  tf.gfile.MakeDirs(logdir)

  graph = tf.Graph()
  with graph.as_default():
    # Instantiate the global step and assign one tower per GPU.
    global_step = tf.train.create_global_step()
    devices = ['/gpu:{}'.format(tower) for tower in range(FLAGS.num_towers)]

    # Latent noise fed to the generator, one batch per tower.
    noise_tensor = utils.make_z_normal(
        FLAGS.num_towers, FLAGS.batch_size, FLAGS.z_dim)
    model_object = model.SNGAN(
        noise_tensor=noise_tensor,
        config=FLAGS,
        global_step=global_step,
        devices=devices)

    train_ops = tfgan.GANTrainOps(
        generator_train_op=model_object.g_optim,
        discriminator_train_op=model_object.d_optim,
        global_step_inc_op=model_object.increment_global_step)

    # allow_soft_placement lets ops without a GPU kernel fall back to CPU.
    session_config = tf.ConfigProto(
        allow_soft_placement=True, log_device_placement=False)

    # Alternate one generator step with one discriminator step.
    train_steps = tfgan.GANTrainSteps(1, 1)
    tfgan.gan_train(
        train_ops,
        get_hooks_fn=tfgan.get_sequential_train_hooks(train_steps=train_steps),
        hooks=[tf.train.StopAtStepHook(num_steps=2000000)],
        logdir=logdir,
        master=FLAGS.master,
        is_chief=(FLAGS.task == 0),
        save_summaries_steps=FLAGS.save_summaries_steps,
        save_checkpoint_secs=FLAGS.save_checkpoint_secs,
        config=session_config)
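# `utils.make_z_normal` is a repo-local helper. The sketch below is a
# minimal implementation consistent with how it is called above; the
# stacked [num_towers, batch_size, z_dim] shape is an assumption, not the
# repo's confirmed code.
def make_z_normal(num_towers, batch_size, z_dim):
  """Draws standard-normal generator noise for all towers at once."""
  return tf.random_normal(
      [num_towers, batch_size, z_dim], name='z', dtype=tf.float32)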
def main(_, is_test=False):
  """Trains SNGAN on CelebA."""
  print('d_learning_rate', FLAGS.discriminator_learning_rate)
  print('g_learning_rate', FLAGS.generator_learning_rate)
  print('data_dir', FLAGS.data_dir)
  print(FLAGS.loss_type, FLAGS.batch_size, FLAGS.beta1)
  print('gf_df_dim', FLAGS.gf_dim, FLAGS.df_dim)
  print('Starting the program...')

  tf.gfile.MakeDirs(FLAGS.checkpoint_dir)
  model_dir = '%s_%s' % ('celebA', FLAGS.batch_size)
  logdir = os.path.join(FLAGS.checkpoint_dir, model_dir)
  tf.gfile.MakeDirs(logdir)

  graph = tf.Graph()
  with graph.as_default():
    with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
      # Instantiate the global step.
      global_step = tf.train.create_global_step()

      # One tower per GPU.
      devices = ['/gpu:{}'.format(tower) for tower in range(FLAGS.num_towers)]

      # Create the noise tensors, one batch per tower.
      zs = utils.make_z_normal(FLAGS.num_towers, FLAGS.batch_size, FLAGS.z_dim)
      print('save_summaries_steps', FLAGS.save_summaries_steps)
      dcgan = model.SNGAN(
          zs=zs, config=FLAGS, global_step=global_step, devices=devices)

    with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
      # Create sync hooks only for synchronous multi-worker training.
      if FLAGS.sync_replicas and FLAGS.num_workers > 1:
        print('Using sync replica hooks')
        sync_hooks = [
            dcgan.d_opt.make_session_run_hook(FLAGS.task == 0),
            dcgan.g_opt.make_session_run_hook(FLAGS.task == 0)
        ]
      else:
        print('Not using sync replica hooks')
        sync_hooks = []

    train_ops = tfgan.GANTrainOps(
        generator_train_op=dcgan.g_optim,
        discriminator_train_op=dcgan.d_optim,
        global_step_inc_op=dcgan.increment_global_step)

    # We set allow_soft_placement to True because the Saver for the model
    # otherwise gets misplaced on the GPU.
    session_config = tf.ConfigProto(
        allow_soft_placement=True, log_device_placement=False)

    if is_test:
      return graph

    print('G step:', FLAGS.g_step)
    print('D step:', FLAGS.d_step)
    train_steps = tfgan.GANTrainSteps(FLAGS.g_step, FLAGS.d_step)
    tfgan.gan_train(
        train_ops,
        get_hooks_fn=tfgan.get_sequential_train_hooks(train_steps=train_steps),
        hooks=[tf.train.StopAtStepHook(num_steps=2000000)] + sync_hooks,
        logdir=logdir,
        # master=FLAGS.master,
        # scaffold=scaffold,  # load from Google checkpoint
        is_chief=(FLAGS.task == 0),
        save_summaries_steps=FLAGS.save_summaries_steps,
        save_checkpoint_secs=FLAGS.save_checkpoint_secs,
        config=session_config)
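# The `make_session_run_hook` calls above imply that, when
# FLAGS.sync_replicas is set, the model wraps its optimizers in
# tf.train.SyncReplicasOptimizer. A minimal sketch of that wrapping
# follows; the Adam optimizer choice and this helper's name are
# assumptions, not the repo's confirmed code.
def make_sync_replicas_optimizer(learning_rate, num_workers):
  """Wraps Adam so gradients from all workers are aggregated per update."""
  opt = tf.train.AdamOptimizer(learning_rate, beta1=FLAGS.beta1)
  # The returned optimizer exposes make_session_run_hook(is_chief), which
  # sets up the token queue used to synchronize the replicas.
  return tf.train.SyncReplicasOptimizer(
      opt,
      replicas_to_aggregate=num_workers,
      total_num_replicas=num_workers)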