Example #1
def training_loop(config: Config):
    timer = Timer()
    print("Start task {}".format(config.task_name))
    strategy = tf.distribute.MirroredStrategy()
    print('Loading Imagenet2012 dataset...')
    dataset = np_dataset.build_np_dataset(root=config.h5root, batch_size=config.batch_size,
                                          gpu_nums=config.gpu_nums, load_in_mem=config.load_in_mem,
                                          load_num=config.load_num)
    dataset = strategy.experimental_distribute_dataset(dataset)
    dataset = dataset.make_initializable_iterator()
    with strategy.scope():
        global_step = tf.get_variable(name='global_step', initializer=tf.constant(0), trainable=False,
                                      aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
        # dataset = get_dataset(name=config.dataset,
        #                       seed=config.seed).train_input_fn(params=data_iter_params)
        # dataset = strategy.experimental_distribute_dataset(dataset)
        # data_iter = dataset.make_initializable_iterator()
        print("Constructing networks...")
        Encoder = ImagenetModel(resnet_size=50, num_classes=120, name='vgg_alter')
        learning_rate = tf.train.exponential_decay(0.0001, global_step, 150000 / config.gpu_nums,
                                                   0.8, staircase=False)
        E_solver = tf.train.AdamOptimizer(learning_rate=learning_rate, name='e_opt', beta1=config.beta1)

        print("Building tensorflow graph...")

        def train_step(image, W):
            E_w = Encoder(image, training=True)
            with tf.variable_scope('recon_loss'):
                recon_loss_pixel = tf.reduce_mean(tf.square(E_w - W))
                e_loss = recon_loss_pixel

            add_global = global_step.assign_add(1)
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies([add_global] + update_ops):
                E_opt = E_solver.minimize(e_loss, var_list=Encoder.trainable_variables)
                with tf.control_dependencies([E_opt]):
                    return tf.identity(e_loss)
        e_loss = compute_loss(train_step, dataset.get_next(), strategy)
        print('Building init module...')
        with tf.init_scope():
            init = [tf.global_variables_initializer(), dataset.initializer]
            saver_e = tf.train.Saver(Encoder.trainable_variables, restore_sequentially=True)
        print("Start training...")
        with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
            sess.run(init)
            if config.finalize:
                sess.graph.finalize()
            timer.update()
            print("Completing all work, iteration now start, consuming %s " % timer.runing_time_format)

            print("Start iterations...")
            for iteration in range(config.total_step):
                e_loss_, lr_ = sess.run([e_loss, learning_rate])
                if iteration % config.print_loss_per_steps == 0:
                    timer.update()
                    print("step %d, e_loss %f, learning_rate % f, consuming time %s" %
                          (iteration, e_loss_, lr_, timer.runing_time_format))
                if iteration % config.save_per_steps == 0:
                    saver_e.save(sess, save_path=config.model_dir + '/vgg.ckpt',
                                 global_step=iteration, write_meta_graph=False)
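All of these training loops lean on a compute_loss helper that the listing never shows. Example #4 below inlines the same pattern directly (strategy.experimental_run_v2 followed by strategy.reduce), so a minimal sketch consistent with every call site here might look like the following; this is a reconstruction assuming TF 1.15 (for tf.nest and tf.distribute), not the original implementation:

import tensorflow as tf  # TF 1.15 API assumed, as the surrounding examples require

def compute_loss(step_fn, *args):
    # Hypothetical reconstruction: the trailing positional argument is the
    # tf.distribute.Strategy, matching calls such as
    # compute_loss(train_step, dataset.get_next(), strategy).
    *inputs, strategy = args
    # Flatten nested structures, e.g. the (image, w) pairs one iterator yields,
    # so the leaves line up with step_fn's positional parameters.
    flat = tf.nest.flatten(inputs)
    per_replica = strategy.experimental_run_v2(step_fn, args=tuple(flat))

    def reduce_mean(v):
        return strategy.reduce(tf.distribute.ReduceOp.MEAN, v, axis=None)

    if isinstance(per_replica, (list, tuple)):
        # step_fn returned several losses, as in Examples #2 and #3.
        return tuple(reduce_mean(v) for v in per_replica)
    return reduce_mean(per_replica)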
Example #2
def training_loop(config: Config):
    timer = Timer()
    print("Start task {}".format(config.task_name))
    strategy = tf.distribute.MirroredStrategy()
    print('Loading Imagenet2012 dataset...')
    # dataset = load_from_h5(root=config.h5root, batch_size=config.batch_size)
    dataset, fixed_img = build_np_dataset(root=config.h5root, batch_size=config.batch_size, gpu_nums=config.gpu_nums)
    dataset = strategy.experimental_distribute_dataset(dataset)
    dataset = dataset.make_initializable_iterator()
    ssgan_sample_root = config.h5root.replace('ILSVRC128', 'SSGAN128')
    sample_dset = npdt.build_np_dataset(root=ssgan_sample_root, batch_size=config.batch_size,
                                        gpu_nums=config.gpu_nums, load_in_mem=config.load_in_mem,
                                        load_num=config.load_num)
    sample_dset = strategy.experimental_distribute_dataset(sample_dset)
    sample_dset = sample_dset.make_initializable_iterator()
    with strategy.scope():
        global_step = tf.get_variable(name='global_step', initializer=tf.constant(0), trainable=False,
                                      aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
        # dataset = get_dataset(name=config.dataset,
        #                       seed=config.seed).train_input_fn(params=data_iter_params)
        # dataset = strategy.experimental_distribute_dataset(dataset)
        # data_iter = dataset.make_initializable_iterator()
        print("Constructing networks...")
        fixed_x = tf.placeholder(tf.float32, [None, 128, 128, 3])
        Encoder = ImagenetModel(resnet_size=50, num_classes=120, name='Encoder')
        Generator = resnet_biggan.Generator(image_shape=[128, 128, 3], embed_y=False,
                                            embed_z=False,
                                            batch_norm_fn=arch_ops.self_modulated_batch_norm,
                                            spectral_norm=True)
        learning_rate = tf.train.exponential_decay(config.lr, global_step, 150000 / config.gpu_nums,
                                                   0.8, staircase=False)
        E_solver = tf.train.AdamOptimizer(learning_rate=learning_rate, name='e_opt', beta2=config.beta2)

        print("Building tensorflow graph...")
        def train_step(image, sample_img, sample_w):

            w = Encoder(image, training=True)
            x = Generator(w, y=None, is_training=True)
            ww_ = Encoder(sample_img, training=True)
            with tf.variable_scope('recon_loss'):
                recon_loss_pixel = tf.reduce_mean(tf.square(x - image))
                sample_loss = tf.reduce_mean(tf.square(ww_ - sample_w))
                e_loss = recon_loss_pixel + sample_loss * config.g_loss_scale

            add_global = global_step.assign_add(1)
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies([add_global] + update_ops):
                E_opt = E_solver.minimize(e_loss, var_list=Encoder.trainable_variables)
                with tf.control_dependencies([E_opt]):
                    return tf.identity(e_loss), tf.identity(recon_loss_pixel), tf.identity(sample_loss)
        e_loss, r_loss, s_loss = compute_loss(train_step, dataset.get_next(), sample_dset.get_next(), strategy)
        print("Building eval module...")
        with tf.init_scope():
            fixed_w = Encoder(fixed_x, training=False)
            fixed_sample = Generator(z=fixed_w, y=None, is_training=True)

        print('Building init module...')
        with tf.init_scope():
            init = [tf.global_variables_initializer(), dataset.initializer, sample_dset.initializer]
            restore_g = [v for v in tf.global_variables() if 'opt' not in v.name
                         and 'beta1_power' not in v.name
                         and 'beta2_power' not in v.name
                         and 'generator' in v.name]
            saver_g = tf.train.Saver(restore_g, restore_sequentially=True)
            saver_e = tf.train.Saver(Encoder.trainable_variables, restore_sequentially=True)
        print("Start training...")
        with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
            sess.run(init)
            print("Restore generator...")
            saver_g.restore(sess, config.restore_g_dir)
            save_image_grid(fixed_img, filename=config.model_dir + '/reals.png')
            timer.update()

            print("Completing all work, iteration now start, consuming %s " % timer.runing_time_format)

            print("Start iterations...")
            for iteration in range(config.total_step):
                e_loss_, r_loss_, s_loss_, lr_ = sess.run(
                    [e_loss, r_loss, s_loss, learning_rate])
                if iteration % config.print_loss_per_steps == 0:
                    timer.update()
                    print("step %d, e_loss %f, r_loss %f, s_loss %f, "
                          "learning_rate % f, consuming time %s" %
                          (iteration, e_loss_, r_loss_, s_loss_,
                           lr_, timer.runing_time_format))
                if iteration % config.eval_per_steps == 0:
                    timer.update()
                    fixed_ = sess.run(fixed_sample, {fixed_x: fixed_img})
                    save_image_grid(fixed_, filename=config.model_dir + '/fakes%06d.png' % iteration)
                if iteration % config.save_per_steps == 0:
                    saver_e.save(sess, save_path=config.model_dir + '/en.ckpt',
                                 global_step=iteration, write_meta_graph=False)
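The Timer object used by every loop is another helper the listing omits. The attribute names below (including the runing_time_format spelling) are kept exactly as the loops reference them; the actual formatting is an assumption, so treat this as an illustrative sketch only:

import time

class Timer:
    # Hypothetical reconstruction: update() marks a checkpoint, and the
    # *_format properties render the time elapsed between the last two
    # checkpoints as a human-readable string.
    def __init__(self):
        self._last = time.time()
        self._delta = 0.0

    def update(self):
        now = time.time()
        self._delta, self._last = now - self._last, now

    def _fmt(self):
        m, s = divmod(int(self._delta), 60)
        h, m = divmod(m, 60)
        return '%dh %dm %ds' % (h, m, s)

    @property
    def runing_time_format(self):  # spelling matches the call sites above
        return self._fmt()

    @property
    def duration_format(self):  # used by the eval timing in Example #4
        return self._fmt()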
Example #3
def training_loop(config: Config):
    timer = Timer()
    print("Start task {}".format(config.task_name))
    strategy = tf.distribute.MirroredStrategy()
    print('Loading Imagenet2012 dataset...')
    # dataset = load_from_h5(root=config.h5root, batch_size=config.batch_size)
    dataset, fixed_img = build_np_dataset(root=config.h5root, batch_size=config.batch_size, gpu_nums=config.gpu_nums,
                                          load_in_mem=config.load_in_mem)
    dataset = strategy.experimental_distribute_dataset(dataset)
    dataset = dataset.make_initializable_iterator()
    with strategy.scope():
        global_step = tf.get_variable(name='global_step', initializer=tf.constant(0), trainable=False,
                                      aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
        print("Constructing networks...")
        fixed_x = tf.placeholder(tf.float32, [None, 128, 128, 3])
        Encoder = ImagenetModel(resnet_size=50, num_classes=None, name='vgg_alter')
        Assgin_net = assgin_net(x0_ch=512+256, scope='Assgin')
        BN_net = BNlayer(scope='Ebn', z0_ch=1536*16)
        Generator = resnet_biggan_ssgan.Generator(image_shape=[128, 128, 3], embed_y=False,
                                                  embed_z=False,
                                                  batch_norm_fn=arch_ops.self_modulated_batch_norm,
                                                  spectral_norm=True)
        learning_rate = tf.train.exponential_decay(config.lr, global_step, config.lr_decay_step,
                                                   0.8, staircase=False)
        E_solver = tf.train.AdamOptimizer(learning_rate=learning_rate, name='e_opt', beta2=config.beta2)
        G_embed_np = np.load('/ghome/fengrl/ssgan/invSSGAN/G_embed.npy')
        G_embed = tf.convert_to_tensor(G_embed_np, dtype=tf.float32, name='G_embed')
        print("Building tensorflow graph...")

        def train_step(image):
            sample_z = tf.random.normal([config.batch_size // config.gpu_nums, config.dim_z],
                                        stddev=1.0, name='sample_z')
            sample_w = tf.matmul(sample_z, G_embed, name='sample_w')
            sample_img, sample_w_out = Generator(sample_w, y=None, is_training=True)
            ww_ = Encoder(sample_img, training=True)
            ww_ = Assgin_net(ww_)
            ww_ = BN_net(ww_, is_training=True)

            w = Encoder(image, training=True)
            w = Assgin_net(w)
            w = BN_net(w, is_training=True)
            x, _ = Generator(w, y=None, is_training=True)
            with tf.variable_scope('recon_loss'):
                recon_loss_pixel = tf.reduce_mean(tf.square(x - image))
                sample_loss = tf.reduce_mean(tf.square(ww_[:, :1536*16] - sample_w_out[:, :1536*16])) * 0.7
                sample_loss += tf.reduce_mean(tf.square(ww_[:, 1536*16:] - sample_w_out[:, 1536*16:])) * 0.3
                e_loss = recon_loss_pixel + sample_loss * config.s_loss_scale

            add_global = global_step.assign_add(1)
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies([add_global] + update_ops):
                E_opt = E_solver.minimize(e_loss,
                                          var_list=Encoder.trainable_variables + Assgin_net.trainable_variables + BN_net.trainable_variables)
                with tf.control_dependencies([E_opt]):
                    return tf.identity(e_loss), tf.identity(recon_loss_pixel), tf.identity(sample_loss)
        e_loss, r_loss, s_loss = compute_loss(train_step, dataset.get_next(), strategy)
        print("Building eval module...")
        with tf.init_scope():
            fixed_w = Encoder(fixed_x, training=True)
            fixed_w = Assgin_net(fixed_w)
            fixed_w = BN_net(fixed_w, is_training=True)
            fixed_sample, _ = Generator(z=fixed_w, y=None, is_training=True)

        print('Building init module...')
        with tf.init_scope():
            init = [tf.global_variables_initializer(), dataset.initializer]
            restore_g = [v for v in tf.global_variables() if 'opt' not in v.name
                         and 'beta1_power' not in v.name
                         and 'beta2_power' not in v.name
                         and 'generator' in v.name]
            saver_g = tf.train.Saver(restore_g, restore_sequentially=True)
            saver_e = tf.train.Saver(Encoder.trainable_variables, restore_sequentially=True)
            saver_assgin = tf.train.Saver(Assgin_net.trainable_variables, restore_sequentially=True)
            saver_ebn = tf.train.Saver(BN_net.trainable_variables, restore_sequentially=True)
        print("Start training...")
        with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
            sess.run(init)
            print("Restore generator...")
            saver_g.restore(sess, config.restore_g_dir)
            if config.resume:
                saver_e.restore(sess, config.restore_v_dir)
                if config.resume_assgin:
                    saver_assgin.restore(sess, config.restore_assgin_dir)
                    if config.resume_ebn:
                        saver_ebn.restore(sess, config.restore_ebn_dir)
            save_image_grid(fixed_img, filename=config.model_dir + '/reals.png')
            timer.update()

            print("Completing all work, iteration now start, consuming %s " % timer.runing_time_format)

            print("Start iterations...")
            for iteration in range(config.total_step):
                e_loss_, r_loss_, s_loss_, lr_ = sess.run(
                    [e_loss, r_loss, s_loss, learning_rate])
                if iteration % config.print_loss_per_steps == 0:
                    timer.update()
                    print("step %d, e_loss %f, r_loss %f, s_loss %f, "
                          "learning_rate % f, consuming time %s" %
                          (iteration, e_loss_, r_loss_, s_loss_,
                           lr_, timer.runing_time_format))
                if iteration % config.eval_per_steps == 0:
                    timer.update()
                    fixed_ = sess.run(fixed_sample, {fixed_x: fixed_img})
                    save_image_grid(fixed_, filename=config.model_dir + '/fakes%06d.png' % iteration)
                if iteration % config.save_per_steps == 0:
                    saver_e.save(sess, save_path=config.model_dir + '/en.ckpt',
                                 global_step=iteration, write_meta_graph=False)
                    saver_assgin.save(sess, save_path=config.model_dir + '/assgin.ckpt',
                                      global_step=iteration, write_meta_graph=False)
                    saver_ebn.save(sess, save_path=config.model_dir + '/bn.ckpt',
                                  global_step=iteration, write_meta_graph=False)
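save_image_grid is likewise undefined in the listing. A plausible stand-in, assuming the batch is a float array in [-1, 1] as BigGAN-style generators emit (both the value range and the grid layout are assumptions):

import numpy as np
from PIL import Image

def save_image_grid(images, filename):
    # Hypothetical sketch: tile an [N, H, W, C] batch into a roughly square
    # grid and write it out as a single image file.
    n, h, w, c = images.shape
    grid_w = int(np.ceil(np.sqrt(n)))
    grid_h = int(np.ceil(n / grid_w))
    canvas = np.zeros((grid_h * h, grid_w * w, c), dtype=np.uint8)
    imgs = np.clip((images + 1.0) * 127.5, 0, 255).astype(np.uint8)
    for i in range(n):
        row, col = divmod(i, grid_w)
        canvas[row * h:(row + 1) * h, col * w:(col + 1) * w] = imgs[i]
    Image.fromarray(canvas).save(filename)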
Example #4
def training_loop(config: Config):
    timer = Timer()
    print("Start task {}".format(config.task_name))
    strategy = tf.distribute.MirroredStrategy()
    print('Loading Imagenet2012 dataset...')
    # dataset = load_from_h5(root=config.h5root, batch_size=config.batch_size)
    dataset, _, fixed_img = datasets.build_data_input_pipeline_from_hdf5(
        root=config.h5root, batch_size=config.batch_size, gpu_nums=config.gpu_nums,
        load_in_mem=config.load_in_mem, labeled_per_class=config.labeled_per_class, save_index_dir=config.model_dir)
    dataset = strategy.experimental_distribute_dataset(dataset)
    dataset = dataset.make_initializable_iterator()
    eval_dset = datasets.build_eval_dset(config.eval_h5root, batch_size=config.batch_size, gpu_nums=config.gpu_nums)
    eval_dset = strategy.experimental_distribute_dataset(eval_dset)
    eval_dset = eval_dset.make_initializable_iterator()
    with strategy.scope():
        global_step = tf.get_variable(name='global_step', initializer=tf.constant(0), trainable=False,
                                      aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
        print("Constructing networks...")
        Encoder = ImagenetModel(resnet_size=50, num_classes=None, name='vgg_alter')
        Dense = tf.layers.Dense(1000, name='Final_dense')
        learning_rate = tf.train.exponential_decay(config.lr, global_step, 60000,
                                                   config.lr_decay_coef, staircase=False)
        Dense_solver = tf.train.AdamOptimizer(learning_rate=learning_rate, name='e_opt', beta2=config.beta2)
        print("Building tensorflow graph...")

        def train_step(image, label):
            w = Encoder(image, training=True)
            w = Dense(w)
            label = tf.one_hot(label, 1000)
            loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=label, logits=w)
            loss = tf.reduce_mean(loss)

            add_global = global_step.assign_add(1)
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies([add_global] + update_ops):
                Dense_opt = Dense_solver.minimize(loss, var_list=Dense.trainable_variables)
                with tf.control_dependencies([Dense_opt]):
                    return tf.identity(loss)
        loss_run = strategy.experimental_run_v2(train_step, dataset.get_next())
        loss = strategy.reduce(tf.distribute.ReduceOp.MEAN, loss_run, axis=None)
        print("Building eval module...")

        def eval_step(image, label):
            w = Encoder(image, training=True)
            w = Dense(w)
            p = tf.math.argmax(w, 1)
            p = tf.cast(p, tf.int32)
            precise = tf.reduce_mean(tf.cast(tf.equal(p, label), tf.float32))
            return precise

        precise = strategy.experimental_run_v2(eval_step, eval_dset.get_next())
        precise = strategy.reduce(tf.distribute.ReduceOp.MEAN, precise, axis=None)
        print('Building init module...')
        with tf.init_scope():
            init = [tf.global_variables_initializer(), dataset.initializer, eval_dset.initializer]
            saver_e = tf.train.Saver(Encoder.trainable_variables, restore_sequentially=True)
            saver_dense = tf.train.Saver(Dense.trainable_variables, restore_sequentially=True)
        print("Start training...")
        with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
            sess.run(init)
            print("Restore Encoder...")
            saver_e.restore(sess, config.restore_v_dir)
            timer.update()
            print("Completing all work, iteration now start, consuming %s " % timer.runing_time_format)
            print("Start iterations...")
            for iteration in range(config.total_step):
                loss_, lr_ = sess.run([loss, learning_rate])
                if iteration % config.print_loss_per_steps == 0:
                    timer.update()
                    print("step %d, loss %f, learning_rate % f, consuming time %s" %
                          (iteration, loss_, lr_, timer.runing_time_format))
                if iteration % config.eval_per_steps == 0:
                    timer.update()
                    print('Starting eval...')
                    precise_ = 0.0
                    eval_iters = 50000 // config.batch_size
                    for _ in range(2 * eval_iters):
                        precise_ += sess.run(precise)
                    precise_ = precise_ / (2 * eval_iters)
                    timer.update()
                    print('Eval took %s' % timer.duration_format)
                    print('step %d, precision %f in eval dataset of length %d' %
                          (iteration, precise_, eval_iters * config.batch_size))
                if iteration % config.save_per_steps == 0:
                    saver_dense.save(sess, save_path=config.model_dir + '/dense.ckpt',
                                     global_step=iteration, write_meta_graph=False)
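A detail worth noting in all of these loops: each train_step chains global_step.assign_add, the UPDATE_OPS collection, and the optimizer op onto the returned loss through nested tf.control_dependencies blocks, so merely fetching the loss tensor executes one optimizer step. That is why the session loops above only ever sess.run the losses. A self-contained toy reproduction of the pattern (assuming the tensorflow.compat.v1 shim):

import numpy as np
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

x = tf.constant(np.random.randn(64, 3).astype(np.float32))
true_w = np.array([[1.0], [-2.0], [0.5]], dtype=np.float32)
y = tf.matmul(x, true_w)
w = tf.get_variable('w', [3, 1])
loss = tf.reduce_mean(tf.square(tf.matmul(x, w) - y))
opt_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)
with tf.control_dependencies([opt_step]):
    loss_step = tf.identity(loss)  # fetching this tensor also runs opt_step

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(200):
        loss_val = sess.run(loss_step)  # one training step per run call
    print('final loss %.6f' % loss_val)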
Example #5
# (truncated above; the parser setup is reconstructed here: the flag name
#  --encoder_name is inferred from the args.encoder_name reference below)
parser = argparse.ArgumentParser()
parser.add_argument('--encoder_name',
                    type=str,
                    default='vae_alter',
                    help='encoder variable scope name')
parser.add_argument('--save_name',
                    type=str,
                    default='vae-16000',
                    help='checkpoint name')
parser.add_argument('--embed_dir',
                    type=str,
                    default='/ghome/fengrl/ssgan/invSSGAN/G_embed.npy',
                    help='path to the G_embed.npy file')

args = parser.parse_args()

Encoder = ImagenetModel(resnet_size=50,
                        num_classes=None,
                        name=args.encoder_name)
Generator = resnet_biggan_ssgan.Generator(
    image_shape=[128, 128, 3],
    embed_y=False,
    embed_z=False,
    batch_norm_fn=arch_ops.self_modulated_batch_norm,
    spectral_norm=True)
Assgin_net = assgin_net(x0_ch=512 + 256, scope='Assgin')
BN_net = BNlayer(scope='Ebn', z0_ch=1536 * 16)
G_embed_np = np.load(args.embed_dir)
G_embed = tf.convert_to_tensor(G_embed_np, dtype=tf.float32, name='G_embed')

z = tf.random.normal([args.batch_size, 120], stddev=1.0, name='z')
wz = tf.matmul(z, G_embed)
index = np.arange(1000000)
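For orientation: G_embed plays the role of the generator's learned input embedding, mapping the 120-dim noise z into the flat w space the Generator consumes (Example #3 builds sample_w with the identical tf.matmul(sample_z, G_embed)). A quick shape walk-through with a stand-in matrix; dim_w = 1536 * 16 matches the BN_net z0_ch above, but that correspondence is an assumption:

import numpy as np

batch_size, dim_z, dim_w = 8, 120, 1536 * 16
G_embed_np = np.random.randn(dim_z, dim_w).astype(np.float32)  # stand-in for G_embed.npy
z_np = np.random.randn(batch_size, dim_z).astype(np.float32)
wz_np = z_np @ G_embed_np
print(wz_np.shape)  # (8, 24576): one flat w vector per sample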
Example #6
def training_loop(config: Config):
    timer = Timer()
    print("Start task {}".format(config.task_name))
    strategy = tf.distribute.MirroredStrategy()
    print('Loading Imagenet2012 dataset...')
    # dataset = load_from_h5(root=config.h5root, batch_size=config.batch_size)
    dataset = build_np_dataset(root=config.h5root, batch_size=config.batch_size, gpu_nums=config.gpu_nums)
    dataset = dataset.make_initializable_iterator()
    with strategy.scope():
        global_step = tf.get_variable(name='global_step', initializer=tf.constant(0), trainable=False,
                                      aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
        # dataset = get_dataset(name=config.dataset,
        #                       seed=config.seed).train_input_fn(params=data_iter_params)
        # dataset = strategy.experimental_distribute_dataset(dataset)
        # data_iter = dataset.make_initializable_iterator()
        print("Constructing networks...")
        # img = tf.placeholder(tf.float32, [None, 128, 128, 3])
        fixed_x = tf.placeholder(tf.float32, [None, 128, 128, 3])
        img = dataset.get_next()
        Encoder = ImagenetModel(resnet_size=50, num_classes=120, name='Encoder')
        VGG_alter = ImagenetModel(resnet_size=50, num_classes=120, name='vgg_alter')
        Generator = resnet_biggan.Generator(image_shape=[128, 128, 3], embed_y=False,
                                            embed_z=False,
                                            batch_norm_fn=arch_ops.self_modulated_batch_norm,
                                            spectral_norm=True)
        Discriminator = resnet_biggan.Discriminator(spectral_norm=True, project_y=False)
        learning_rate = tf.train.exponential_decay(0.0001, global_step, 150000 / config.gpu_nums,
                                                   0.8, staircase=False)
        E_solver = tf.train.AdamOptimizer(learning_rate=learning_rate, name='e_opt', beta2=config.beta2)
        # D_solver = tf.train.AdamOptimizer(learning_rate=learning_rate * 5, name='d_opt', beta1=config.beta1)

        print("Building tensorflow graph...")
        def train_step(image):

            w = Encoder(image, training=True)
            x = Generator(w, y=None, is_training=True)
            # _, real_logits, _ = Discriminator(img, y=None, is_training=True)
            _, fake_logits, _ = Discriminator(x, y=None, is_training=True)
            # real_logits = fp32(real_logits)
            fake_logits = fp32(fake_logits)
            with tf.variable_scope('recon_loss'):
                recon_loss_pixel = tf.reduce_mean(tf.square(x - image))
                adv_loss = tf.reduce_mean(tf.nn.softplus(-fake_logits)) * config.g_loss_scale
                vgg_real = VGG_alter(image, training=True)
                vgg_fake = VGG_alter(x, training=True)
                feature_scale = tf.cast(tf.reduce_prod(vgg_real.shape[1:]), dtype=tf.float32)
                vgg_loss = config.r_loss_scale * tf.nn.l2_loss(vgg_fake - vgg_real) / (config.batch_size * feature_scale)
                e_loss = recon_loss_pixel + adv_loss + vgg_loss
            # with tf.variable_scope('d_loss'):
            #     d_loss_real = tf.reduce_mean(tf.nn.relu(1.0 - real_logits))
            #     d_loss_fake = tf.reduce_mean(tf.nn.relu(1.0 + fake_logits))
            #     d_loss = d_loss_real + d_loss_fake

            add_global = global_step.assign_add(1)
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies([add_global] + update_ops):
                E_opt = E_solver.minimize(e_loss, var_list=Encoder.trainable_variables)
                with tf.control_dependencies([E_opt]):
                    # Return every tracked loss so the outer loop can fetch them;
                    # the original returned only e_loss, which left adv_loss,
                    # recon_loss_pixel, and vgg_loss undefined at the sess.run
                    # call below (a NameError).
                    return (tf.identity(e_loss), tf.identity(adv_loss),
                            tf.identity(recon_loss_pixel), tf.identity(vgg_loss))
        e_loss, adv_loss, recon_loss_pixel, vgg_loss = compute_loss(
            train_step, dataset.get_next(), strategy)
        print("Building eval module...")
        with tf.init_scope():
            fixed_w = Encoder(fixed_x, training=False)
            fixed_sample = Generator(z=fixed_w, y=None, is_training=False)
        print('Building init module...')
        with tf.init_scope():
            init = [tf.global_variables_initializer(), dataset.initializer]
            restore_g = [v for v in tf.global_variables() if 'opt' not in v.name
                         and 'beta1_power' not in v.name
                         and 'beta2_power' not in v.name
                         and 'generator' in v.name]
            restore_d = [v for v in tf.global_variables() if 'opt' not in v.name
                         and 'beta1_power' not in v.name
                         and 'beta2_power' not in v.name
                         and 'discriminator' in v.name]
            saver_g = tf.train.Saver(restore_g, restore_sequentially=True)
            saver_d = tf.train.Saver(restore_d, restore_sequentially=True)
            saver_v = tf.train.Saver(VGG_alter.trainable_variables, restore_sequentially=True)
            saver_e = tf.train.Saver(Encoder.trainable_variables, restore_sequentially=True)
        print("Start training...")
        with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
            sess.run(init)
            print("Restore generator and discriminator...")
            saver_g.restore(sess, config.restore_g_dir)
            saver_d.restore(sess, config.restore_d_dir)
            saver_v.restore(sess, config.restore_v_dir)
            timer.update()
            fixed_img = sess.run(dataset.get_next())
            save_image_grid(fixed_img, filename=config.model_dir + '/reals.png')
            print("Completing all work, iteration now start, consuming %s " % timer.runing_time_format)

            print("Start iterations...")
            for iteration in range(config.total_step):
                e_loss_, adv_loss_, recon_loss_pixel_, vgg_loss_, lr_ = sess.run(
                    [e_loss, adv_loss, recon_loss_pixel, vgg_loss, learning_rate])
                if iteration % config.print_loss_per_steps == 0:
                    timer.update()
                    print("step %d, e_loss %f, adv_loss %f, recon_loss_pixel %f, vgg_loss %f, "
                          "learning_rate % f, consuming time %s" %
                          (iteration, e_loss_, adv_loss_, recon_loss_pixel_, vgg_loss_,
                           lr_, timer.runing_time_format))
                if iteration % config.eval_per_steps == 0:
                    timer.update()
                    fixed_ = sess.run(fixed_sample, {fixed_x: fixed_img})
                    save_image_grid(fixed_, filename=config.model_dir + '/fakes%06d.png' % iteration)
                if iteration % config.save_per_steps == 0:
                    saver_e.save(sess, save_path=config.model_dir + '/en.ckpt',
                                 global_step=iteration, write_meta_graph=False)
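Finally, the Config object threaded through every loop is never defined in the listing either. Collecting the fields the code actually reads gives a sketch like the one below; the defaults are illustrative placeholders, not the original values:

from dataclasses import dataclass

@dataclass
class Config:
    # Hypothetical reconstruction: each field below is referenced somewhere
    # in the training loops above.
    task_name: str = 'invSSGAN'
    h5root: str = '/path/to/ILSVRC128.hdf5'
    eval_h5root: str = '/path/to/eval.hdf5'
    model_dir: str = '/path/to/model_dir'
    restore_g_dir: str = ''
    restore_d_dir: str = ''
    restore_v_dir: str = ''
    restore_assgin_dir: str = ''
    restore_ebn_dir: str = ''
    batch_size: int = 256
    gpu_nums: int = 8
    load_in_mem: bool = False
    load_num: int = 0
    labeled_per_class: int = 0
    dim_z: int = 120
    lr: float = 1e-4
    lr_decay_step: int = 150000
    lr_decay_coef: float = 0.8
    beta1: float = 0.0
    beta2: float = 0.999
    g_loss_scale: float = 1.0
    s_loss_scale: float = 1.0
    r_loss_scale: float = 1.0
    total_step: int = 250000
    print_loss_per_steps: int = 100
    eval_per_steps: int = 2000
    save_per_steps: int = 2000
    resume: bool = False
    resume_assgin: bool = False
    resume_ebn: bool = False
    finalize: bool = True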