示例#1
0
def training_loop(config: Config):
    """Train an Encoder to invert a restored Generator.

    Only ``Encoder.trainable_variables`` are optimized; the Generator weights
    are restored from ``config.restore_g_dir``. The per-replica loss is the
    pixel MSE between an image and ``Generator(Encoder(image))`` plus
    ``config.g_loss_scale`` times the MSE between ``Encoder(sample_img)`` and
    the target ``sample_w`` drawn from a second ("SSGAN128") dataset.

    Args:
        config: run configuration. Fields read here: task_name, h5root,
            batch_size, gpu_nums, load_in_mem, load_num, lr, beta2,
            g_loss_scale, restore_g_dir, model_dir, total_step,
            print_loss_per_steps, eval_per_steps, save_per_steps.
    """
    timer = Timer()
    print("Start task {}".format(config.task_name))
    strategy = tf.distribute.MirroredStrategy()
    print('Loading Imagenet2012 dataset...')
    # dataset = load_from_h5(root=config.h5root, batch_size=config.batch_size)
    # fixed_img is a batch of real images kept for the preview grids below.
    dataset, fixed_img = build_np_dataset(root=config.h5root, batch_size=config.batch_size, gpu_nums=config.gpu_nums)
    dataset = strategy.experimental_distribute_dataset(dataset)
    dataset = dataset.make_initializable_iterator()
    # The sample-set path is derived from h5root by substituting the dataset
    # name; if 'ILSVRC128' does not occur in h5root the path is unchanged.
    ssgan_sample_root = config.h5root.replace('ILSVRC128', 'SSGAN128')
    sample_dset = npdt.build_np_dataset(root=ssgan_sample_root, batch_size=config.batch_size,
                                        gpu_nums=config.gpu_nums, load_in_mem=config.load_in_mem,
                                        load_num=config.load_num)
    sample_dset = strategy.experimental_distribute_dataset(sample_dset)
    sample_dset = sample_dset.make_initializable_iterator()
    with strategy.scope():
        # Only the first replica's increment is kept when replicas update it.
        global_step = tf.get_variable(name='global_step', initializer=tf.constant(0), trainable=False,
                                      aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
        # dataset = get_dataset(name=config.dataset,
        #                       seed=config.seed).train_input_fn(params=data_iter_params)
        # dataset = strategy.experimental_distribute_dataset(dataset)
        # data_iter = dataset.make_initializable_iterator()
        print("Constructing networks...")
        # Feeds fixed_img during evaluation.
        fixed_x = tf.placeholder(tf.float32, [None, 128, 128, 3])
        Encoder = ImagenetModel(resnet_size=50, num_classes=120, name='Encoder')
        Generator = resnet_biggan.Generator(image_shape=[128, 128, 3], embed_y=False,
                                            embed_z=False,
                                            batch_norm_fn=arch_ops.self_modulated_batch_norm,
                                            spectral_norm=True)
        # Continuous (non-staircase) decay by 0.8 every 150000 / gpu_nums steps.
        learning_rate = tf.train.exponential_decay(config.lr, global_step, 150000 / config.gpu_nums,
                                                   0.8, staircase=False)
        E_solver = tf.train.AdamOptimizer(learning_rate=learning_rate, name='e_opt', beta2=config.beta2)

        print("Building tensorflow graph...")
        def train_step(image, sample_img, sample_w):
            """Build one per-replica Encoder update.

            Returns (e_loss, recon_loss_pixel, sample_loss) as identity tensors
            that only evaluate after the Encoder update has been applied.
            """

            w = Encoder(image, training=True)
            x = Generator(w, y=None, is_training=True)
            ww_ = Encoder(sample_img, training=True)
            with tf.variable_scope('recon_loss'):
                recon_loss_pixel = tf.reduce_mean(tf.square(x - image))
                sample_loss = tf.reduce_mean(tf.square(ww_ - sample_w))
                e_loss = recon_loss_pixel + sample_loss * config.g_loss_scale

            add_global = global_step.assign_add(1)
            # Run step increment and batch-norm style update ops before the optimizer.
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies([add_global] + update_ops):
                # Generator weights are frozen: only the Encoder is in var_list.
                E_opt = E_solver.minimize(e_loss, var_list=Encoder.trainable_variables)
                with tf.control_dependencies([E_opt]):
                    return tf.identity(e_loss), tf.identity(recon_loss_pixel), tf.identity(sample_loss)
        # NOTE(review): compute_loss is defined elsewhere; presumably it runs
        # train_step per replica, unpacking each sample_dset element into
        # (sample_img, sample_w), and reduces the three losses — confirm.
        e_loss, r_loss, s_loss = compute_loss(train_step, dataset.get_next(), sample_dset.get_next(), strategy)
        print("Building eval module...")
        with tf.init_scope():
            # def eval_fn():
            fixed_w = Encoder(fixed_x, training=False)
            # NOTE(review): the Generator is called with is_training=True during
            # evaluation; confirm this batch-norm mode is intended here.
            fixed_sample = Generator(z=fixed_w, y=None, is_training=True)
                # return fixed_sample
            # fixed_sample = strategy.experimental_run_v2(eval_fn, ())

        print('Building init module...')
        with tf.init_scope():
            init = [tf.global_variables_initializer(), dataset.initializer, sample_dset.initializer]
            # Generator weights only: excludes optimizer slots (names containing
            # 'opt', e.g. 'e_opt') and Adam beta-power accumulators.
            restore_g = [v for v in tf.global_variables() if 'opt' not in v.name
                         and 'beta1_power' not in v.name
                         and 'beta2_power' not in v.name
                         and 'generator' in v.name]
            saver_g = tf.train.Saver(restore_g, restore_sequentially=True)
            # NOTE(review): only trainable Encoder variables are checkpointed; if
            # the Encoder keeps non-trainable state used by training=False
            # (e.g. batch-norm moving statistics), it is not saved — confirm.
            saver_e = tf.train.Saver(Encoder.trainable_variables, restore_sequentially=True)
        print("Start training...")
        with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
            sess.run(init)
            print("Restore generator...")
            saver_g.restore(sess, config.restore_g_dir)
            save_image_grid(fixed_img, filename=config.model_dir + '/reals.png')
            timer.update()

            print("Completing all work, iteration now start, consuming %s " % timer.runing_time_format)

            print("Start iterations...")
            for iteration in range(config.total_step):
                # Each run executes one training step (the losses depend on the update).
                e_loss_, r_loss_, s_loss_, lr_ = sess.run(
                    [e_loss, r_loss, s_loss, learning_rate])
                if iteration % config.print_loss_per_steps == 0:
                    timer.update()
                    print("step %d, e_loss %f, r_loss %f, s_loss %f, "
                          "learning_rate % f, consuming time %s" %
                          (iteration, e_loss_, r_loss_, s_loss_,
                           lr_, timer.runing_time_format))
                if iteration % config.eval_per_steps == 0:
                    # The timer is refreshed here but its reading is not printed.
                    timer.update()
                    fixed_ = sess.run(fixed_sample, {fixed_x: fixed_img})
                    save_image_grid(fixed_, filename=config.model_dir + '/fakes%06d.png' % iteration)
                if iteration % config.save_per_steps == 0:
                    saver_e.save(sess, save_path=config.model_dir + '/en.ckpt',
                                 global_step=iteration, write_meta_graph=False)
示例#2
0
def training_loop(config: Config):
    """Fine-tune a Generator and an InvMap latent mapper against a Discriminator.

    InvMap maps Gaussian latents ``z`` to the Generator input ``w``. The
    Generator and Discriminator start from restored checkpoints. Each outer
    iteration evaluates the ``d_loss`` tensor ``config.disc_iter`` times and
    the ``g_loss`` tensor once; presumably ``compute_loss`` (defined
    elsewhere) gates each of them on its optimizer update. Inception score
    and FID from ``tfmetric.call_metric`` are reported every
    ``config.eval_per_steps`` iterations. Generator/Discriminator weights are
    saved every ``config.save_per_steps`` iterations.

    Args:
        config: run configuration. Reads task_name, batch_size, seed, dataset,
            dim_z, gpu_nums, run_dir, model_dir, total_step, disc_iter,
            print_loss_per_steps, eval_per_steps and save_per_steps.
            ``restore_g_dir`` / ``restore_d_dir`` are used when present and
            non-empty; otherwise the original hard-coded checkpoints are used.
    """
    timer = Timer()
    print("Start task {}".format(config.task_name))
    strategy = tf.distribute.MirroredStrategy()
    data_iter_params = {"batch_size": config.batch_size, "seed": config.seed}
    # Checkpoint locations: prefer the config (as the other training loops
    # do), falling back to the historical hard-coded paths.
    restore_g_dir = (getattr(config, 'restore_g_dir', None)
                     or '/ghome/fengrl/gen_ckpt/gen-0')
    restore_d_dir = (getattr(config, 'restore_d_dir', None)
                     or '/ghome/fengrl/disc_ckpt/disc-0')
    with strategy.scope():
        global_step = tf.get_variable(
            name='global_step',
            initializer=tf.constant(0),
            trainable=False,
            aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
        dataset = get_dataset(
            name=config.dataset,
            seed=config.seed).train_input_fn(params=data_iter_params)
        dataset = strategy.experimental_distribute_dataset(dataset)
        data_iter = dataset.make_initializable_iterator()
        print("Constructing networks...")
        InvMap = invert.InvMap(latent_size=config.dim_z)
        Generator = resnet_biggan.Generator(
            image_shape=[128, 128, 3],
            embed_y=False,
            embed_z=True,
            batch_norm_fn=arch_ops.self_modulated_batch_norm,
            spectral_norm=True)
        Discriminator = resnet_biggan.Discriminator(spectral_norm=True,
                                                    project_y=False)
        I_opt = tf.train.AdamOptimizer(learning_rate=0.0005,
                                       name='i_opt',
                                       beta1=0.0,
                                       beta2=0.999)
        G_opt = tf.train.AdamOptimizer(learning_rate=0.00001,
                                       name='g_opt',
                                       beta1=0.0,
                                       beta2=0.999)
        D_opt = tf.train.AdamOptimizer(learning_rate=0.00005,
                                       name='d_opt',
                                       beta1=0.0,
                                       beta2=0.999)
        # Per-replica training latents, z ~ N(0, 1).
        train_z = tf.random.normal(
            [config.batch_size // config.gpu_nums, config.dim_z],
            stddev=1.0,
            name='train_z')
        fixed_sample_z = tf.placeholder(tf.float32, name='fixed_sample_z')

        print("Building tensorflow graph...")

        def train_step(training_who="G", step=None, z=None, data=None):
            """Build one per-replica update of G+InvMap ("G") or D (otherwise).

            Returns (loss, inception_score, fid). The loss is an identity
            tensor gated on the corresponding optimizer update.
            """
            img, _ = data
            w = InvMap(z)
            samples = Generator(z=w, y=None, is_training=True)
            d_real, d_real_logits, _ = Discriminator(x=img,
                                                     y=None,
                                                     is_training=True)
            d_fake, d_fake_logits, _ = Discriminator(x=samples,
                                                     y=None,
                                                     is_training=True)
            d_loss, _, _, g_loss = loss_lib.get_losses(
                d_real=d_real,
                d_fake=d_fake,
                d_real_logits=d_real_logits,
                d_fake_logits=d_fake_logits)

            inception_score = tfmetric.call_metric(run_dir_root=config.run_dir,
                                                   name="is",
                                                   images=samples)
            fid = tfmetric.call_metric(run_dir_root=config.run_dir,
                                       name="fid",
                                       reals=img,
                                       fakes=samples)

            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                if training_who == "G":
                    # G and InvMap form one logical update, so only one of the
                    # two optimizers advances `step` (previously both did,
                    # double-counting every generator step).
                    train_op = tf.group(
                        G_opt.minimize(g_loss,
                                       var_list=Generator.trainable_variables,
                                       global_step=step),
                        I_opt.minimize(g_loss,
                                       var_list=InvMap.trainable_variables))
                    with tf.control_dependencies([train_op]):
                        return tf.identity(g_loss), inception_score, fid
                else:
                    train_op = D_opt.minimize(
                        d_loss,
                        var_list=Discriminator.trainable_variables,
                        global_step=step)
                    with tf.control_dependencies([train_op]):
                        return tf.identity(d_loss), inception_score, fid

        g_loss, d_loss, IS, FID = compute_loss(train_step, strategy,
                                               global_step, train_z, data_iter)
        print("Building eval module...")
        with tf.init_scope():
            fixed_sample_w = InvMap(fixed_sample_z)
            eval_sample = Generator(z=fixed_sample_w,
                                    y=None,
                                    is_training=False)

        def _restorable(scope_word):
            # Model weights whose name contains `scope_word`, excluding
            # optimizer slots ('opt' in the name) and Adam beta-power values.
            return [
                v for v in tf.global_variables()
                if 'opt' not in v.name and 'beta1_power' not in v.name
                and 'beta2_power' not in v.name and scope_word in v.name
            ]

        print('Building init module...')
        with tf.init_scope():
            init = [tf.global_variables_initializer(), data_iter.initializer]
            restore_g = _restorable('generator')
            restore_d = _restorable('discriminator')
            saver_g = tf.train.Saver(restore_g, restore_sequentially=True)
            saver_d = tf.train.Saver(restore_d, restore_sequentially=True)
        print("Start training...")
        with tf.Session(config=tf.ConfigProto(
                allow_soft_placement=True)) as sess:
            sess.run(init)
            # NOTE(review): training latents are N(0, 1) but these preview
            # latents are U(-1, 1); confirm the mismatch is intentional.
            fixed_z = np.random.uniform(
                low=-1.0,
                high=1.0,
                size=[config.batch_size * 2 // config.gpu_nums, config.dim_z])
            print("Restore generator and discriminator...")
            saver_g.restore(sess, restore_g_dir)
            saver_d.restore(sess, restore_d_dir)
            print("Start iterations...")
            # Bound up front so logging still works when config.disc_iter == 0.
            D_loss = float('nan')
            for iteration in range(config.total_step):
                for _ in range(config.disc_iter):
                    D_loss = sess.run(d_loss)
                G_loss = sess.run(g_loss)
                if iteration % config.print_loss_per_steps == 0:
                    print("step %d, G_loss %f, D_loss %f" %
                          (iteration, G_loss, D_loss))
                if iteration % config.eval_per_steps == 0:
                    timer.update()
                    fixed_sample = sess.run(eval_sample,
                                            {fixed_sample_z: fixed_z})
                    save_image_grid(fixed_sample,
                                    filename=config.model_dir +
                                    '/fakes%06d.png' % iteration)
                    is_eval, fid_eval = sess.run([IS, FID])
                    print(
                        "Time %s, fid %f, inception_score %f , G_loss %f, D_loss %f, step %d"
                        % (timer.runing_time, fid_eval, is_eval, G_loss,
                           D_loss, iteration))
                if iteration % config.save_per_steps == 0:
                    saver_g.save(sess,
                                 save_path=config.model_dir + '/gen.ckpt',
                                 global_step=iteration,
                                 write_meta_graph=False)
                    saver_d.save(sess,
                                 save_path=config.model_dir + '/disc.ckpt',
                                 global_step=iteration,
                                 write_meta_graph=False)
示例#3
0
def training_loop(config: Config):
    """Train D_embed, a dense map from Discriminator features to Generator latents.

    The Generator and Discriminator are restored from ``config.restore_g_dir``
    and ``config.restore_d_dir``. Only ``D_embed.trainable_variables`` are in
    the optimizer's var_list. Per replica, the loss combines:

    * ``d_loss``: mean(relu(1 - logits)) of the Discriminator on the
      reconstruction ``Generator(D_embed(real_h))`` of a real image;
    * ``sample_loss``: MSE between ``w = Z_embed(z)`` and
      ``D_embed(fake_h)``, scaled by ``config.s_loss_scale`` and then
      ``config.alpha``;
    * ``recon_loss_pixel``: despite the name, the MSE between the
      Discriminator feature outputs of the real image and of its
      reconstruction, scaled by ``config.beta``.

    Args:
        config: run configuration. Fields read here: task_name, h5root,
            batch_size, gpu_nums, load_in_mem, lr, beta2, dim_z, s_loss_scale,
            alpha, beta, restore_g_dir, restore_d_dir, model_dir, total_step,
            print_loss_per_steps, eval_per_steps, save_per_steps.
    """
    timer = Timer()
    print("Start task {}".format(config.task_name))
    strategy = tf.distribute.MirroredStrategy()
    print('Loading Imagenet2012 dataset...')
    # dataset = load_from_h5(root=config.h5root, batch_size=config.batch_size)
    # fixed_img is a batch of real images kept for the preview grids below.
    dataset, fixed_img = build_np_dataset(root=config.h5root,
                                          batch_size=config.batch_size,
                                          gpu_nums=config.gpu_nums,
                                          load_in_mem=config.load_in_mem)
    dataset = strategy.experimental_distribute_dataset(dataset)
    dataset = dataset.make_initializable_iterator()
    with strategy.scope():
        global_step = tf.get_variable(
            name='global_step',
            initializer=tf.constant(0),
            trainable=False,
            aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
        print("Constructing networks...")
        fixed_x = tf.placeholder(tf.float32, [None, 128, 128, 3])
        Generator = resnet_biggan.Generator(
            image_shape=[128, 128, 3],
            embed_y=False,
            embed_z=False,
            batch_norm_fn=arch_ops.self_modulated_batch_norm,
            spectral_norm=True)
        Discriminator = resnet_biggan.Discriminator(spectral_norm=True,
                                                    project_y=False)
        # Although Z_embed is created outside the Generator, it is built under
        # the Generator's scope and treated as part of the Generator (if that
        # scope name contains 'generator', restore_g below picks it up).
        Z_embed = dense(120, False, name='embed_z', scope=Generator.name)
        D_embed = dense(120, True, name='embed_d', scope='Embed_D')
        # Continuous (non-staircase) decay by 0.8 every 60000 steps.
        learning_rate = tf.train.exponential_decay(config.lr,
                                                   global_step,
                                                   60000,
                                                   0.8,
                                                   staircase=False)
        # NOTE: named 'd_opt' although it optimizes D_embed; the 'opt'
        # substring keeps its slots out of restore_g / restore_d.
        Embed_solver = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                              name='d_opt',
                                              beta1=0.0,
                                              beta2=config.beta2)
        print("Building tensorflow graph...")

        def train_step(image):
            """Build one per-replica D_embed update.

            Returns (final_loss, d_loss, recon_loss_pixel, sample_loss) as
            identity tensors that evaluate only after the update is applied.
            """
            # Per-replica latent batch, z ~ N(0, 1).
            z = tf.random.normal(
                [config.batch_size // config.gpu_nums, config.dim_z],
                stddev=1.0,
                name='sample_z')
            w = Z_embed(z)
            fake = Generator(w, y=None, is_training=True)
            # The third Discriminator output (*_h) is the feature fed to D_embed;
            # the *_out / *_logits outputs here are unused.
            fake_out, fake_logits, fake_h = Discriminator(x=fake,
                                                          y=None,
                                                          is_training=True)
            real_out, real_logits, real_h = Discriminator(x=image,
                                                          y=None,
                                                          is_training=True)
            fake_w = D_embed(fake_h)
            real_w = D_embed(real_h)
            # x is the reconstruction of image
            x = Generator(real_w, None, True)
            _, real_logits_, real_h_ = Discriminator(x, None, True)
            d_loss = tf.reduce_mean(tf.nn.relu(1.0 - real_logits_))
            with tf.variable_scope('recon_loss'):
                # Feature-space (not pixel-space) reconstruction error.
                recon_loss_pixel = tf.reduce_mean(tf.square(real_h - real_h_))
                sample_loss = tf.reduce_mean(
                    tf.square(w - fake_w)) * config.s_loss_scale
            final_loss = d_loss + sample_loss * config.alpha + recon_loss_pixel * config.beta
            add_global = global_step.assign_add(1)
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies([add_global] + update_ops):
                # Generator, Discriminator and Z_embed stay frozen: only D_embed
                # is in var_list.
                Embed_opt = Embed_solver.minimize(
                    final_loss, var_list=D_embed.trainable_variables)
                with tf.control_dependencies([Embed_opt]):
                    return tf.identity(final_loss), tf.identity(d_loss), tf.identity(recon_loss_pixel),\
                           tf.identity(sample_loss)

        # NOTE(review): compute_loss is defined elsewhere; presumably it runs
        # train_step on every replica and reduces the four losses — confirm.
        final_loss, d_loss, r_loss, s_loss = compute_loss(
            train_step, dataset.get_next(), strategy)
        print("Building eval module...")
        with tf.init_scope():
            # NOTE(review): both networks are called in training mode (True)
            # during evaluation; confirm normalization-state updates here are
            # intended.
            _, _, fixed_h = Discriminator(fixed_x, None, True)
            fixed_w = D_embed(fixed_h)
            fixed_sample = Generator(z=fixed_w, y=None, is_training=True)

        print('Building init module...')
        with tf.init_scope():
            init = [tf.global_variables_initializer(), dataset.initializer]
            # Model weights only: excludes optimizer slots ('opt' in the name)
            # and Adam beta-power accumulators.
            restore_g = [
                v for v in tf.global_variables()
                if 'opt' not in v.name and 'beta1_power' not in v.name
                and 'beta2_power' not in v.name and 'generator' in v.name
            ]
            restore_d = [
                v for v in tf.global_variables()
                if 'opt' not in v.name and 'beta1_power' not in v.name
                and 'beta2_power' not in v.name and 'discriminator' in v.name
            ]
            saver_g = tf.train.Saver(restore_g, restore_sequentially=True)
            saver_d = tf.train.Saver(restore_d, restore_sequentially=True)
            # Only D_embed's trainable variables are checkpointed during training.
            saver_embed = tf.train.Saver(var_list=D_embed.trainable_variables)
        print("Start training...")
        with tf.Session(config=tf.ConfigProto(
                allow_soft_placement=True)) as sess:
            sess.run(init)
            # Restores both the Generator and the Discriminator.
            print("Restore generator...")
            saver_g.restore(sess, config.restore_g_dir)
            saver_d.restore(sess, config.restore_d_dir)
            save_image_grid(fixed_img,
                            filename=config.model_dir + '/reals.png')
            timer.update()

            print("Completing all work, iteration now start, consuming %s " %
                  timer.runing_time_format)

            print("Start iterations...")
            for iteration in range(config.total_step):
                # Each run executes one training step (the losses depend on the update).
                final_loss_, d_loss_, r_loss_, s_loss_, lr_ = sess.run(
                    [final_loss, d_loss, r_loss, s_loss, learning_rate])
                if iteration % config.print_loss_per_steps == 0:
                    timer.update()
                    print(
                        "step %d, final_loss %f, d_loss %f, r_loss %f, s_loss %f, "
                        "learning_rate % f, consuming time %s" %
                        (iteration, final_loss_, d_loss_, r_loss_, s_loss_,
                         lr_, timer.runing_time_format))
                if iteration % config.eval_per_steps == 0:
                    timer.update()
                    fixed_ = sess.run(fixed_sample, {fixed_x: fixed_img})
                    save_image_grid(fixed_,
                                    filename=config.model_dir +
                                    '/fakes%06d.png' % iteration)
                if iteration % config.save_per_steps == 0:
                    saver_embed.save(sess,
                                     save_path=config.model_dir +
                                     '/embed.ckpt',
                                     global_step=iteration,
                                     write_meta_graph=False)
示例#4
0
def training_loop(config: Config):
    """Train an Encoder to invert a restored Generator with auxiliary losses.

    The Generator, Discriminator and a feature network (``VGG_alter``) are
    restored from ``config.restore_g_dir``, ``config.restore_d_dir`` and
    ``config.restore_v_dir``. Only ``Encoder.trainable_variables`` are
    optimized. The per-replica Encoder loss is the sum of:

    * ``recon_loss_pixel``: pixel MSE between image and
      ``Generator(Encoder(image))``;
    * ``adv_loss``: mean(softplus(-fake_logits)) of the Discriminator on the
      reconstruction, scaled by ``config.g_loss_scale``;
    * ``vgg_loss``: ``tf.nn.l2_loss`` between VGG_alter features of the
      reconstruction and the image, scaled by ``config.r_loss_scale`` and
      divided by ``config.batch_size * feature_size``.

    Args:
        config: run configuration. Fields read here: task_name, h5root,
            batch_size, gpu_nums, beta2, g_loss_scale, r_loss_scale,
            restore_g_dir, restore_d_dir, restore_v_dir, model_dir,
            total_step, print_loss_per_steps, eval_per_steps, save_per_steps.
    """
    timer = Timer()
    print("Start task {}".format(config.task_name))
    strategy = tf.distribute.MirroredStrategy()
    print('Loading Imagenet2012 dataset...')
    dataset = build_np_dataset(root=config.h5root, batch_size=config.batch_size, gpu_nums=config.gpu_nums)
    dataset = dataset.make_initializable_iterator()
    with strategy.scope():
        global_step = tf.get_variable(name='global_step', initializer=tf.constant(0), trainable=False,
                                      aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
        print("Constructing networks...")
        fixed_x = tf.placeholder(tf.float32, [None, 128, 128, 3])
        # Batch op run once in the session to grab the fixed preview images.
        img = dataset.get_next()
        Encoder = ImagenetModel(resnet_size=50, num_classes=120, name='Encoder')
        VGG_alter = ImagenetModel(resnet_size=50, num_classes=120, name='vgg_alter')
        Generator = resnet_biggan.Generator(image_shape=[128, 128, 3], embed_y=False,
                                            embed_z=False,
                                            batch_norm_fn=arch_ops.self_modulated_batch_norm,
                                            spectral_norm=True)
        Discriminator = resnet_biggan.Discriminator(spectral_norm=True, project_y=False)
        # Continuous (non-staircase) decay by 0.8 every 150000 / gpu_nums steps.
        learning_rate = tf.train.exponential_decay(0.0001, global_step, 150000 / config.gpu_nums,
                                                   0.8, staircase=False)
        E_solver = tf.train.AdamOptimizer(learning_rate=learning_rate, name='e_opt', beta2=config.beta2)

        print("Building tensorflow graph...")

        def train_step(image):
            """Build one per-replica Encoder update.

            Returns (e_loss, adv_loss, recon_loss_pixel, vgg_loss) as identity
            tensors that evaluate only after the Encoder update is applied.
            """
            w = Encoder(image, training=True)
            x = Generator(w, y=None, is_training=True)
            _, fake_logits, _ = Discriminator(x, y=None, is_training=True)
            fake_logits = fp32(fake_logits)
            with tf.variable_scope('recon_loss'):
                recon_loss_pixel = tf.reduce_mean(tf.square(x - image))
                adv_loss = tf.reduce_mean(tf.nn.softplus(-fake_logits)) * config.g_loss_scale
                vgg_real = VGG_alter(image, training=True)
                vgg_fake = VGG_alter(x, training=True)
                # Number of feature elements per example, used to normalize the L2 term.
                feature_scale = tf.cast(tf.reduce_prod(vgg_real.shape[1:]), dtype=tf.float32)
                vgg_loss = config.r_loss_scale * tf.nn.l2_loss(vgg_fake - vgg_real) / (config.batch_size * feature_scale)
                e_loss = recon_loss_pixel + adv_loss + vgg_loss

            add_global = global_step.assign_add(1)
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies([add_global] + update_ops):
                # Generator, Discriminator and VGG_alter stay frozen.
                E_opt = E_solver.minimize(e_loss, var_list=Encoder.trainable_variables)
                with tf.control_dependencies([E_opt]):
                    return (tf.identity(e_loss), tf.identity(adv_loss),
                            tf.identity(recon_loss_pixel), tf.identity(vgg_loss))

        # The component losses have to come back out of train_step: names bound
        # inside it are local to that function and were previously referenced
        # below without ever being defined here (NameError on the first step).
        e_loss, adv_loss, recon_loss_pixel, vgg_loss = compute_loss(
            train_step, dataset.get_next(), strategy)
        print("Building eval module...")
        with tf.init_scope():
            fixed_w = Encoder(fixed_x, training=False)
            fixed_sample = Generator(z=fixed_w, y=None, is_training=False)
        print('Building init module...')
        with tf.init_scope():
            init = [tf.global_variables_initializer(), dataset.initializer]
            # Model weights only: excludes optimizer slots ('opt' in the name)
            # and Adam beta-power accumulators.
            restore_g = [v for v in tf.global_variables() if 'opt' not in v.name
                         and 'beta1_power' not in v.name
                         and 'beta2_power' not in v.name
                         and 'generator' in v.name]
            restore_d = [v for v in tf.global_variables() if 'opt' not in v.name
                         and 'beta1_power' not in v.name
                         and 'beta2_power' not in v.name
                         and 'discriminator' in v.name]
            saver_g = tf.train.Saver(restore_g, restore_sequentially=True)
            saver_d = tf.train.Saver(restore_d, restore_sequentially=True)
            saver_v = tf.train.Saver(VGG_alter.trainable_variables, restore_sequentially=True)
            saver_e = tf.train.Saver(Encoder.trainable_variables, restore_sequentially=True)
        print("Start training...")
        with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
            sess.run(init)
            print("Restore generator and discriminator...")
            saver_g.restore(sess, config.restore_g_dir)
            saver_d.restore(sess, config.restore_d_dir)
            saver_v.restore(sess, config.restore_v_dir)
            timer.update()
            # Reuse the get_next op built above instead of adding a new op to
            # the graph after the session has started.
            fixed_img = sess.run(img)
            save_image_grid(fixed_img, filename=config.model_dir + '/reals.png')
            print("Completing all work, iteration now start, consuming %s " % timer.runing_time_format)

            print("Start iterations...")
            for iteration in range(config.total_step):
                e_loss_, adv_loss_, recon_loss_pixel_, vgg_loss_, lr_ = sess.run(
                    [e_loss, adv_loss, recon_loss_pixel, vgg_loss, learning_rate])
                if iteration % config.print_loss_per_steps == 0:
                    timer.update()
                    print("step %d, e_loss %f, adv_loss %f, recon_loss_pixel %f, vgg_loss %f, "
                          "learning_rate % f, consuming time %s" %
                          (iteration, e_loss_, adv_loss_, recon_loss_pixel_, vgg_loss_,
                           lr_, timer.runing_time_format))
                if iteration % config.eval_per_steps == 0:
                    timer.update()
                    fixed_ = sess.run(fixed_sample, {fixed_x: fixed_img})
                    save_image_grid(fixed_, filename=config.model_dir + '/fakes%06d.png' % iteration)
                if iteration % config.save_per_steps == 0:
                    saver_e.save(sess, save_path=config.model_dir + '/en.ckpt',
                                 global_step=iteration, write_meta_graph=False)