Example #1
# Assumed context for this snippet (not shown on the original page): TF 1.x,
# with tfgan aliased to tf.contrib.gan, FLAGS from tf.app.flags, and the
# project-local utils and model modules providing the SNGAN pieces.
import os

import tensorflow as tf

import model   # project-local SNGAN model definition (assumed)
import utils   # project-local noise helpers (assumed)

tfgan = tf.contrib.gan
FLAGS = tf.app.flags.FLAGS


def main(_):
    tf.gfile.MakeDirs(FLAGS.checkpoint_dir)
    model_dir = '%s_%s' % ('imagenet', FLAGS.batch_size)
    logdir = os.path.join(FLAGS.checkpoint_dir, model_dir)
    tf.gfile.MakeDirs(logdir)
    graph = tf.Graph()
    with graph.as_default():
        global_step = tf.train.create_global_step()
        devices = [
            '/gpu:{}'.format(tower) for tower in range(FLAGS.num_towers)
        ]

        # Normal noise for the generator, one batch per tower.
        noise_tensor = utils.make_z_normal(FLAGS.num_towers, FLAGS.batch_size,
                                           FLAGS.z_dim)

        # Build the multi-tower SNGAN; it exposes the generator and
        # discriminator train ops used below.
        model_object = model.SNGAN(noise_tensor=noise_tensor,
                                   config=FLAGS,
                                   global_step=global_step,
                                   devices=devices)

        # Hand the model's train ops to the TF-GAN training loop interface.
        train_ops = tfgan.GANTrainOps(
            generator_train_op=model_object.g_optim,
            discriminator_train_op=model_object.d_optim,
            global_step_inc_op=model_object.increment_global_step)

        # allow_soft_placement lets ops without a GPU kernel fall back to the CPU.
        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False)

        # One generator step and one discriminator step per training iteration.
        train_steps = tfgan.GANTrainSteps(1, 1)
        tfgan.gan_train(train_ops,
                        get_hooks_fn=tfgan.get_sequential_train_hooks(
                            train_steps=train_steps),
                        hooks=[tf.train.StopAtStepHook(num_steps=2000000)],
                        logdir=logdir,
                        master=FLAGS.master,
                        is_chief=(FLAGS.task == 0),
                        save_summaries_steps=FLAGS.save_summaries_steps,
                        save_checkpoint_secs=FLAGS.save_checkpoint_secs,
                        config=session_config)
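
The snippet reads several command-line flags (checkpoint_dir, batch_size,
num_towers, z_dim, master, task, save_summaries_steps, save_checkpoint_secs)
that are defined elsewhere in the script. A minimal sketch of how they might be
declared with tf.app.flags follows; the default values are placeholders, not
the ones used by the original script.

flags = tf.app.flags

# Flag names come from the snippet above; defaults here are illustrative only.
flags.DEFINE_string('checkpoint_dir', '/tmp/sngan', 'Where to write checkpoints and summaries.')
flags.DEFINE_integer('batch_size', 64, 'Per-tower batch size.')
flags.DEFINE_integer('num_towers', 1, 'Number of GPU towers.')
flags.DEFINE_integer('z_dim', 128, 'Dimension of the generator noise vector.')
flags.DEFINE_string('master', '', 'Address of the TensorFlow master to use.')
flags.DEFINE_integer('task', 0, 'Task id of this replica; task 0 is the chief.')
flags.DEFINE_integer('save_summaries_steps', 100, 'Steps between summary writes.')
flags.DEFINE_integer('save_checkpoint_secs', 600, 'Seconds between checkpoints.')

if __name__ == '__main__':
    tf.app.run(main)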
Example #2
# Same assumed context as Example #1 (TF 1.x, tfgan = tf.contrib.gan, flags
# defined via tf.app.flags), plus distributed-training flags such as ps_tasks,
# sync_replicas, num_workers and task.
import os

import tensorflow as tf

import model   # project-local SNGAN model definition (assumed)
import utils   # project-local noise helpers (assumed)

tfgan = tf.contrib.gan
gfile = tf.gfile
FLAGS = tf.app.flags.FLAGS


def main(_, is_test=False):
    print('d_learning_rate', FLAGS.discriminator_learning_rate)
    print('g_learning_rate', FLAGS.generator_learning_rate)
    print('data_dir', FLAGS.data_dir)
    print(FLAGS.loss_type, FLAGS.batch_size, FLAGS.beta1)
    print('gf_df_dim', FLAGS.gf_dim, FLAGS.df_dim)
    print('Starting the program...')
    gfile.MakeDirs(FLAGS.checkpoint_dir)

    model_dir = '%s_%s' % ('celebA', FLAGS.batch_size)
    logdir = os.path.join(FLAGS.checkpoint_dir, model_dir)
    gfile.MakeDirs(logdir)

    graph = tf.Graph()
    with graph.as_default():

        with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
            # Instantiate global_step.
            global_step = tf.train.create_global_step()

        # Create model with FLAGS, global_step, and devices.
        devices = [
            '/gpu:{}'.format(tower) for tower in range(FLAGS.num_towers)
        ]

        # Create noise tensors
        zs = utils.make_z_normal(FLAGS.num_towers, FLAGS.batch_size,
                                 FLAGS.z_dim)

        print('save_summaries_steps', FLAGS.save_summaries_steps)

        dcgan = model.SNGAN(zs=zs,
                            config=FLAGS,
                            global_step=global_step,
                            devices=devices)

        with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
            # Create sync_hooks when needed.
            if FLAGS.sync_replicas and FLAGS.num_workers > 1:
                print('condition 1')
                sync_hooks = [
                    dcgan.d_opt.make_session_run_hook(FLAGS.task == 0),
                    dcgan.g_opt.make_session_run_hook(FLAGS.task == 0)
                ]
            else:
                print('condition 2')
                sync_hooks = []

        train_ops = tfgan.GANTrainOps(
            generator_train_op=dcgan.g_optim,
            discriminator_train_op=dcgan.d_optim,
            global_step_inc_op=dcgan.increment_global_step)

        # We set allow_soft_placement to True because the Saver for the DCGAN
        # model otherwise gets placed on the GPU.
        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False)

        # In tests, return the constructed graph without starting training.
        if is_test:
            return graph

        print("G step: ", FLAGS.g_step)
        print("D_step: ", FLAGS.d_step)
        train_steps = tfgan.GANTrainSteps(FLAGS.g_step, FLAGS.d_step)

        tfgan.gan_train(
            train_ops,
            get_hooks_fn=tfgan.get_sequential_train_hooks(
                train_steps=train_steps),
            hooks=([tf.train.StopAtStepHook(num_steps=2000000)] + sync_hooks),
            logdir=logdir,
            # master=FLAGS.master,
            # scaffold=scaffold, # load from google checkpoint
            is_chief=(FLAGS.task == 0),
            save_summaries_steps=FLAGS.save_summaries_steps,
            save_checkpoint_secs=FLAGS.save_checkpoint_secs,
            config=session_config)
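
The sync_hooks above come from optimizers that the project's SNGAN model
exposes as d_opt and g_opt. When FLAGS.sync_replicas is set, these are
typically tf.train.SyncReplicasOptimizer wrappers, whose make_session_run_hook
provides the hook passed to tfgan.gan_train. The sketch below shows one way
such an optimizer could be built; the helper name and the exact wiring are
assumptions, not code from the original model.

def make_discriminator_optimizer(flags_obj):
    # Hypothetical helper: wraps Adam in SyncReplicasOptimizer for
    # synchronous multi-worker training.
    opt = tf.train.AdamOptimizer(flags_obj.discriminator_learning_rate,
                                 beta1=flags_obj.beta1)
    if flags_obj.sync_replicas and flags_obj.num_workers > 1:
        # Aggregate gradients from all workers before applying each update;
        # make_session_run_hook(is_chief) then yields the hook used above.
        opt = tf.train.SyncReplicasOptimizer(
            opt,
            replicas_to_aggregate=flags_obj.num_workers,
            total_num_replicas=flags_obj.num_workers)
    return opt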