Exemple #1
0
def main(config):
    """Encode the training split and dump embeddings and labels as TSV files.

    Restores the ``Encoder`` weights from ``<model_dir>/en.ckpt-248000``, runs
    ``config['total_step']`` batches of the training split through it and
    writes one row per sample:

    * ``<model_dir>/embeddings.tsv`` -- the latent code ``z`` (tab separated),
    * ``<model_dir>/labels.tsv``     -- the matching label, one per line.

    Args:
        config: mapping with keys ``dataset_name``, ``batch_size``, ``dim_z``,
            ``model_dir`` and ``total_step``.
    """
    print('Loading %s dataset...' % config['dataset_name'])
    dset = get_dataset(config['dataset_name'], '/gdata/tfds', 2)
    dataset = dset.input_fn(config['batch_size'], mode='train')
    dataset = dataset.make_initializable_iterator()
    Encoder = nn.Encoder(config['dim_z'], exceptions=['opt'], name='Encoder')
    image, label = dataset.get_next()
    _, _, z = Encoder(image, is_training=True)
    saver = tf.train.Saver(Encoder.restore_variables)
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        sess.run([tf.global_variables_initializer(), dataset.initializer])
        print("Restore Encoder...")
        # NOTE(review): the checkpoint step (248000) is hard-coded; confirm it
        # matches the run stored in config['model_dir'].
        saver.restore(sess, config['model_dir'] + '/en.ckpt-248000')
        print('Generate embeddings...')

        # newline='' lets csv control line endings itself (the csv docs
        # require it), and `with` closes both files even if sess.run raises.
        with open(config['model_dir'] + '/embeddings.tsv', 'wt',
                  newline='') as f, \
                open(config['model_dir'] + '/labels.tsv', 'wt',
                     newline='') as g:
            f_writer = csv.writer(f, delimiter='\t')
            g_writer = csv.writer(g, delimiter='\t')

            for _ in tqdm(range(config['total_step'])):
                z_, l_ = sess.run([z, label])
                # One embedding vector per sample -> one TSV row.
                f_writer.writerows(z_)
                # Each label is wrapped in a list so it is a single field.
                g_writer.writerows([row] for row in l_)
Exemple #2
0
def training_loop(config: Config):
    """Encode the whole dataset with a VAE encoder and export it as TFRecords.

    Despite its name, no optimiser step is taken here. The function:

    1. builds the input pipeline and the ``Encoder``/``Decoder`` graph,
    2. optionally restores both networks (``config.resume``),
    3. writes sample / reconstruction image grids for three fixed batches into
       ``config.model_dir``,
    4. iterates over the dataset until it raises ``OutOfRangeError``. Each
       ``(image, label, z)`` triple is written to
       ``<model_dir>/CelebA64_rep.tfrecords``, and the per-batch value of
       ``compute_sigma2(z)`` is accumulated,
    5. saves the accumulated value divided by ``count * config.batch_size``
       to ``<model_dir>/sigma2.npy``. If no batch was produced, nothing is
       saved.

    Args:
        config: run configuration; reads ``task_name``, ``dataset_name``,
            ``tfds_dir``, ``gpu_nums``, ``batch_size``, ``dim_z``,
            ``example_nums``, ``num_midpoints``, ``resume``,
            ``restore_e_dir``, ``restore_d_dir`` and ``model_dir``.
    """
    timer = Timer()
    print('Task name %s' % config.task_name)
    print('Loading %s dataset...' % config.dataset_name)
    dset = get_dataset(config.dataset_name, config.tfds_dir,
                       config.gpu_nums * 2)
    dataset = train_input_fn(dset, config.batch_size)
    dataset = dataset.make_initializable_iterator()
    print("Constructing networks...")
    Encoder = vae.Encoder(config.dim_z, exceptions=['opt'], name='Encoder')
    Decoder = vae.Decoder(dset.image_shape, exceptions=['opt'], name='Decoder')
    print("Building tensorflow graph...")
    image, label = dataset.get_next()
    _, _, z = Encoder(image, is_training=True)
    sigma2_plus = compute_sigma2(z)
    print("Building eval module...")

    # Three fixed latent batches (drawn once, at graph-build time) for the
    # sample grids, and three placeholders fed with fixed real images below.
    fixed_z = tf.constant(
        np.random.normal(size=[config.example_nums, config.dim_z]),
        dtype=tf.float32)
    fixed_z0 = tf.constant(
        np.random.normal(size=[config.example_nums, config.dim_z]),
        dtype=tf.float32)
    fixed_z1 = tf.constant(
        np.random.normal(size=[config.example_nums, config.dim_z]),
        dtype=tf.float32)
    fixed_x = tf.placeholder(tf.float32,
                             (config.example_nums, ) + dset.image_shape)
    fixed_x0 = tf.placeholder(tf.float32,
                              (config.example_nums, ) + dset.image_shape)
    fixed_x1 = tf.placeholder(tf.float32,
                              (config.example_nums, ) + dset.image_shape)
    input_dict = {
        'fixed_z': fixed_z,
        'fixed_z0': fixed_z0,
        'fixed_z1': fixed_z1,
        'fixed_x': fixed_x,
        'fixed_x0': fixed_x0,
        'fixed_x1': fixed_x1,
        'num_midpoints': config.num_midpoints
    }

    def sample_step():
        """Collect generated samples, reconstructions and the raw inputs."""
        out_dict = generate_sample(Decoder, input_dict)
        out_dict.update(reconstruction_sample(Encoder, Decoder, input_dict))
        out_dict.update({
            'fixed_x': fixed_x,
            'fixed_x0': fixed_x0,
            'fixed_x1': fixed_x1
        })
        return out_dict

    o_dict = sample_step()

    print("Building init module...")
    with tf.init_scope():
        init = [tf.global_variables_initializer(), dataset.initializer]
        saver_e = tf.train.Saver(Encoder.restore_variables)
        saver_d = tf.train.Saver(Decoder.restore_variables)

    print('Starting training...')
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        sess.run(init)
        if config.resume:
            print("Restore vae...")
            saver_e.restore(sess, config.restore_e_dir)
            saver_d.restore(sess, config.restore_d_dir)
        timer.update()
        print('Preparing sample utils...')

        fixed_x_, _ = get_fixed_x(sess, dataset, config.example_nums,
                                  config.batch_size)
        fixed_x0_, _ = get_fixed_x(sess, dataset, config.example_nums,
                                   config.batch_size)
        fixed_x1_, _ = get_fixed_x(sess, dataset, config.example_nums,
                                   config.batch_size)
        o_dict_ = sess.run(o_dict, {
            fixed_x: fixed_x_,
            fixed_x0: fixed_x0_,
            fixed_x1: fixed_x1_
        })
        for key in o_dict:
            # Rank-5 outputs are (N, K, H, W, C) grids; others are NHWC.
            # save_image_grid is fed channel-first layouts.
            if o_dict_[key].ndim == 5:
                img = o_dict_[key].transpose([0, 1, 4, 2, 3])
            else:
                img = o_dict_[key].transpose([0, 3, 1, 2])
            save_image_grid(img, config.model_dir + '/%s.jpg' % key)
        print("Completing all work, iteration now start, consuming %s " %
              timer.runing_time_format)

        print("Start iterations...")
        sigma2 = 0.0
        count = 0  # number of batches processed
        with tf.io.TFRecordWriter(config.model_dir +
                                  '/CelebA64_rep.tfrecords') as writer:
            while True:
                # Keep the try tight: only the end-of-data signal from
                # sess.run should terminate the loop.
                try:
                    image_, label_, sigma2_plus_, rep_ = sess.run(
                        [image, label, sigma2_plus, z])
                except tf.errors.OutOfRangeError:
                    break
                sigma2 += sigma2_plus_
                count += 1
                for n in range(image_.shape[0]):
                    tf_example = serialize_example(image_[n], label_[n],
                                                   rep_[n])
                    writer.write(tf_example)
                if count % 100 == 0:
                    timer.update()
                    print('Complete %d batches, consuming time %s' %
                          (count, timer.runing_time_format))
            if count == 0:
                # An empty dataset would otherwise raise ZeroDivisionError.
                print('No batches were produced; sigma2.npy not written.')
            else:
                # NOTE(review): the divisor assumes every batch holds exactly
                # config.batch_size samples; confirm the pipeline drops or
                # pads a partial final batch.
                np.save(config.model_dir + '/sigma2.npy',
                        sigma2 / (count * config.batch_size))
            print('Done!')
Exemple #3
0
def training_loop(config: Config):
    """Train a WAE-style auto-encoder with a graph-Laplacian smoothing penalty.

    The training objective combines:

    * a squared-error reconstruction loss (weight 0.05);
    * an adversarial latent-matching term. ``Discriminator`` is trained to
      separate ``Pz ~ N(0, I)`` from encoder codes ``z``, while the
      encoder/decoder minimise the flipped-label loss on ``z``. Both terms are
      scaled by ``config.wae_lambda``;
    * a smoothing loss ``batch_laplacian(s_w, z) * config.laplace_lambda``.
      The weights ``s_w`` come from ``smoother_weight(rep, 'heat', ...)``,
      masked by ``make_mask(neighbour, index)``, using fields read from
      ``<record_dir>/CelebA64knn5_rep.tfrecords``.

    Before the main loop, the encoder is pretrained for 500 steps to match
    the mean and covariance of its codes to a batch drawn from ``N(0, I)``.
    During training the loop periodically:

    * prints the losses,
    * estimates PPL / Lipschitz metrics,
    * writes image grids,
    * checkpoints the encoder and decoder together with the collected
      metrics.

    Args:
        config: run configuration. Reads ``task_name``, ``dataset_name``,
            ``record_dir``, ``batch_size``, ``laplace_a``, ``laplace_lambda``,
            ``lr``, ``decay_step``, ``decay_coef``, ``beta2``, ``dim_z``,
            ``wae_lambda``, ``example_nums``, ``num_midpoints``,
            ``restore_s_dir``, ``resume``, ``restore_e_dir``,
            ``restore_d_dir``, ``total_step``, ``print_loss_per_steps``,
            ``eval_per_steps``, ``save_per_steps`` and ``model_dir``.
    """
    timer = Timer()
    opts = wae_opts.config_celebA

    print('Task name %s' % config.task_name)
    print('Loading %s dataset...' % config.dataset_name)
    dataset = load_CelebA_KNN_from_record(
        config.record_dir + '/CelebA64knn5_rep.tfrecords', config.batch_size)
    dataset = dataset.make_initializable_iterator()
    # NOTE(review): presumably chosen so that, with an exp(-d / sigma2)
    # kernel, a pair at distance 1 gets weight config.laplace_a; confirm
    # against smoother_weight.
    laplace_sigma2 = 1.0 / (-np.log(config.laplace_a))
    global_step = tf.get_variable(
        name='global_step',
        initializer=tf.constant(0),
        trainable=False,
        aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
    print("Constructing networks...")

    Encoder = vae.Encoder(opts, exceptions=['opt'], name='VAE_En')
    Decoder = vae.Decoder(opts, exceptions=['opt'], name='VAE_De')
    Discriminator = vae.Discriminator(opts, exceptions=['opt'])

    # Separately trained encoder (restored from config.restore_s_dir) used
    # only as the distance metric for PPL.
    valina_encoder = vae_.Encoder(256, exceptions=['opt'], name='Encoder')

    def lip_metric(inputs):
        return inputs

    def d_metric(inputs):
        _, _, outputs = valina_encoder(inputs, True)
        return outputs

    def generator(inputs):
        outputs = Decoder(inputs, True, False)
        return outputs

    def lip_generator(inputs):
        _, _, outputs = Encoder(inputs, True)
        return outputs

    PPL = ppl.PPL_mnist(epsilon=0.01,
                        sampling='full',
                        generator=generator,
                        d_metric=d_metric)
    Lip_PPL = ppl.PPL_mnist(epsilon=0.01,
                            sampling='full',
                            generator=lip_generator,
                            d_metric=lip_metric)
    learning_rate = tf.train.exponential_decay(config.lr,
                                               global_step,
                                               config.decay_step,
                                               config.decay_coef,
                                               staircase=False)
    # Both optimisers are named 'opt' so their slot variables match the
    # `exceptions=['opt']` filter given to the networks.
    solver = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                    name='opt',
                                    beta2=config.beta2)
    adv_solver = tf.train.AdamOptimizer(learning_rate=2 * learning_rate,
                                        name='opt',
                                        beta1=opts['adam_beta1'])

    print("Building tensorflow graph...")

    def train_step(data):
        """Build one encoder/decoder update and one discriminator update."""
        image, rep, label, neighbour, index = data
        _, _, z = Encoder(image, is_training=True)
        x = Decoder(z, is_training=True, flatten=False)
        with tf.variable_scope('reconstruction_loss'):
            recon_loss = 0.05 * tf.reduce_mean(
                tf.reduce_sum(tf.square(image - x), [1, 2, 3]))
        with tf.variable_scope('smooth_loss'):
            mask = make_mask(neighbour, index)
            s_w = mask * smoother_weight(
                rep, 'heat', sigma2=laplace_sigma2, mask=mask)
            smooth_loss = batch_laplacian(s_w, z) * config.laplace_lambda
            # Mean weight over the masked (non-zero) entries only; used
            # for logging.
            s_w_mean = tf.reduce_mean(
                s_w) * config.batch_size * config.batch_size / (
                    tf.reduce_sum(mask) + EPS)
        with tf.variable_scope('wae_penalty'):
            Pz = tf.random.normal(shape=[config.batch_size, config.dim_z],
                                  mean=0.0,
                                  stddev=1.0)
            logits_Pz = Discriminator(Pz, True)
            logits_Qz = Discriminator(z, True)
            loss_Pz = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(
                    logits=logits_Pz, labels=tf.ones_like(logits_Pz)))
            loss_Qz = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(
                    logits=logits_Qz, labels=tf.zeros_like(logits_Qz)))
            # Non-saturating generator loss: encoder wants D(z) -> "prior".
            loss_Qz_trick = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(
                    logits=logits_Qz, labels=tf.ones_like(logits_Qz)))
            loss_adv = config.wae_lambda * (loss_Pz + loss_Qz)
            loss_match = config.wae_lambda * loss_Qz_trick
        loss = loss_match + recon_loss + smooth_loss
        # NOTE(review): both the AE op and the D op depend on add_global, so
        # running them in separate sess.run calls (as the loop below does)
        # increments global_step twice per iteration and draws two batches.
        add_global = global_step.assign_add(1)
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies([add_global] + update_ops):
            opt = solver.minimize(loss,
                                  var_list=Encoder.trainable_variables +
                                  Decoder.trainable_variables)
            with tf.control_dependencies([opt]):
                l1, l2, l3, l4, l5 = tf.identity(loss), tf.identity(recon_loss), \
                                     tf.identity(loss_match), tf.identity(smooth_loss), tf.identity(s_w_mean)
        with tf.control_dependencies([add_global] + update_ops):
            d_opt = adv_solver.minimize(
                loss_adv, var_list=Discriminator.trainable_variables)
            with tf.control_dependencies([d_opt]):
                l6 = tf.identity(loss_adv)
        return l1, l2, l3, l4, l5, l6

    loss, r_loss, m_loss, s_loss, s_w, a_loss = train_step(dataset.get_next())

    def pretrain(data):
        """Match the first two moments of the encoder codes to N(0, I)."""
        image, rep, label, neighbour, index = data
        _, _, z = Encoder(image, is_training=True)
        Pz = tf.random.normal(shape=[config.batch_size, config.dim_z],
                              mean=0.0,
                              stddev=1.0)
        mean_pz = tf.reduce_mean(Pz, axis=0, keepdims=True)
        mean_qz = tf.reduce_mean(z, axis=0, keepdims=True)
        mean_loss = tf.reduce_mean(tf.square(mean_pz - mean_qz))
        # Unbiased sample covariances of prior samples and codes.
        cov_pz = tf.matmul(Pz - mean_pz, Pz - mean_pz,
                           transpose_a=True) / (config.batch_size - 1)
        cov_qz = tf.matmul(z - mean_qz, z - mean_qz,
                           transpose_a=True) / (config.batch_size - 1)
        cov_loss = tf.reduce_mean(tf.square(cov_pz - cov_qz))
        pretrain_loss = cov_loss + mean_loss
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            opt = solver.minimize(pretrain_loss,
                                  var_list=Encoder.trainable_variables)
            with tf.control_dependencies([opt]):
                p_loss = tf.identity(pretrain_loss)
        return p_loss

    p_loss = pretrain(dataset.get_next())

    print("Building eval module...")

    fixed_z = tf.constant(
        np.random.normal(size=[config.example_nums, config.dim_z]),
        dtype=tf.float32)
    fixed_z0 = tf.constant(
        np.random.normal(size=[config.example_nums, config.dim_z]),
        dtype=tf.float32)
    fixed_z1 = tf.constant(
        np.random.normal(size=[config.example_nums, config.dim_z]),
        dtype=tf.float32)
    fixed_x = tf.placeholder(tf.float32, (config.example_nums, ) + (64, 64, 3))
    fixed_x0 = tf.placeholder(tf.float32,
                              (config.example_nums, ) + (64, 64, 3))
    fixed_x1 = tf.placeholder(tf.float32,
                              (config.example_nums, ) + (64, 64, 3))
    input_dict = {
        'fixed_z': fixed_z,
        'fixed_z0': fixed_z0,
        'fixed_z1': fixed_z1,
        'fixed_x': fixed_x,
        'fixed_x0': fixed_x0,
        'fixed_x1': fixed_x1,
        'num_midpoints': config.num_midpoints
    }

    def sample_step():
        """Collect generated samples, reconstructions and the raw inputs."""
        out_dict = generate_sample(Decoder, input_dict)
        out_dict.update(reconstruction_sample(Encoder, Decoder, input_dict))
        out_dict.update({
            'fixed_x': fixed_x,
            'fixed_x0': fixed_x0,
            'fixed_x1': fixed_x1
        })
        return out_dict

    o_dict = sample_step()

    def eval_step(img1, img2):
        """Build PPL on prior samples, PPL on codes, and encoder Lipschitz."""
        z0 = tf.random.normal(shape=[config.batch_size, config.dim_z],
                              mean=0.0,
                              stddev=1.0)
        z1 = tf.random.normal(shape=[config.batch_size, config.dim_z],
                              mean=0.0,
                              stddev=1.0)
        _, _, img1_z = Encoder(img1, True)
        _, _, img2_z = Encoder(img2, True)
        ppl_sample_loss = PPL(z0, z1)
        ppl_de_loss = PPL(img1_z, img2_z)
        lip_loss = Lip_PPL(img1, img2)
        return ppl_sample_loss, ppl_de_loss, lip_loss

    img_1, _, _, _, _ = dataset.get_next()
    img_2, _, _, _, _ = dataset.get_next()
    ppl_sa_loss, ppl_de_loss, lip_loss = eval_step(img_1, img_2)

    print("Building init module...")
    with tf.init_scope():
        init = [tf.global_variables_initializer(), dataset.initializer]
        saver_e = tf.train.Saver(Encoder.restore_variables, max_to_keep=10)
        saver_d = tf.train.Saver(Decoder.restore_variables, max_to_keep=10)
        saver_v = tf.train.Saver(valina_encoder.restore_variables)

    print('Starting training...')
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        sess.run(init)
        saver_v.restore(sess, config.restore_s_dir)
        if config.resume:
            print("Restore vae...")
            saver_e.restore(sess, config.restore_e_dir)
            saver_d.restore(sess, config.restore_d_dir)
        timer.update()
        print('Preparing sample utils...')

        fixed_x_, _ = get_fixed_x(sess, dataset, config.example_nums,
                                  config.batch_size)
        fixed_x0_, _ = get_fixed_x(sess, dataset, config.example_nums,
                                   config.batch_size)
        fixed_x1_, _ = get_fixed_x(sess, dataset, config.example_nums,
                                   config.batch_size)
        print("Completing all work, iteration now start, consuming %s " %
              timer.runing_time_format)
        print("Start iterations...")
        print("Start pretraining of Encoder...")
        for iteration in range(500):
            p_loss_ = sess.run(p_loss)
            if iteration % 50 == 0:
                print("Pretrain_step %d, p_loss %f" % (iteration, p_loss_))
        print("Pretraining of Encoder Done! p_loss %f. Now start training..." %
              p_loss_)
        loss_list = []
        r_loss_list = []
        m_loss_list = []
        s_loss_list = []
        a_loss_list = []
        ppl_sa_list = []
        ppl_re_list = []
        lip_list = []
        # Number of batches averaged for every PPL / Lipschitz estimate.
        num_ppl_batches = 200
        for iteration in range(config.total_step):
            loss_, r_loss_, m_loss_, s_loss_, sw_sum_, lr_ = \
                sess.run([loss, r_loss, m_loss, s_loss, s_w, learning_rate])
            a_loss_ = sess.run(a_loss)
            if iteration % config.print_loss_per_steps == 0:
                loss_list.append(loss_)
                r_loss_list.append(r_loss_)
                m_loss_list.append(m_loss_)
                s_loss_list.append(s_loss_)
                a_loss_list.append(a_loss_)
                timer.update()
                print(
                    "step %d, loss %f, r_loss_ %f, m_loss_ %f, s_loss_ %f, sw %f,  a_loss %f "
                    "learning_rate % f, consuming time %s" %
                    (iteration, loss_, r_loss_, m_loss_, s_loss_,
                     np.mean(sw_sum_), a_loss_, lr_, timer.runing_time_format))
            if iteration % 1000 == 0:
                sa_loss_ = 0.0
                de_loss_ = 0.0
                lip_loss_ = 0.0
                for _ in range(num_ppl_batches):
                    sa_p, de_p, lip_p = sess.run(
                        [ppl_sa_loss, ppl_de_loss, lip_loss])
                    sa_loss_ += sa_p
                    de_loss_ += de_p
                    lip_loss_ += lip_p
                # The divisor uses the same batch count as the loop above
                # (the original summed 200 batches but divided by 256).
                # NOTE(review): assumes PPL(...) returns a per-batch sum;
                # confirm against ppl.PPL_mnist.
                num_ppl_samples = config.batch_size * num_ppl_batches
                sa_loss_ /= num_ppl_samples
                de_loss_ /= num_ppl_samples
                lip_loss_ /= num_ppl_samples
                ppl_re_list.append(de_loss_)
                ppl_sa_list.append(sa_loss_)
                lip_list.append(lip_loss_)
                print("ppl_sample %f, ppl_resample %f, lipschitze %f" %
                      (sa_loss_, de_loss_, lip_loss_))
            if iteration % config.eval_per_steps == 0:
                o_dict_ = sess.run(o_dict, {
                    fixed_x: fixed_x_,
                    fixed_x0: fixed_x0_,
                    fixed_x1: fixed_x1_
                })
                sample_dir = config.model_dir + '/%06d' % iteration
                os.makedirs(sample_dir, exist_ok=True)
                for key in o_dict:
                    # Rank-5 outputs are (N, K, H, W, C) grids; others are
                    # NHWC. save_image_grid is fed channel-first layouts.
                    if o_dict_[key].ndim == 5:
                        img = o_dict_[key].transpose([0, 1, 4, 2, 3])
                    else:
                        img = o_dict_[key].transpose([0, 3, 1, 2])
                    save_image_grid(img, sample_dir + '/%s.jpg' % key)
            if iteration % config.save_per_steps == 0:
                saver_e.save(sess,
                             save_path=config.model_dir + '/en.ckpt',
                             global_step=iteration,
                             write_meta_graph=False)
                saver_d.save(sess,
                             save_path=config.model_dir + '/de.ckpt',
                             global_step=iteration,
                             write_meta_graph=False)
                metric_dict = {
                    'r': r_loss_list,
                    'm': m_loss_list,
                    's': s_loss_list,
                    'a': a_loss_list,
                    'psa': ppl_sa_list,
                    'pre': ppl_re_list,
                    'lip': lip_list
                }
                # NOTE(review): produces '<model_dir>/NNNNNNmetric.npy' (no
                # separator). np.save pickles the dict, so loading needs
                # allow_pickle=True.
                np.save(config.model_dir + '/%06d' % iteration + 'metric.npy',
                        metric_dict)
Exemple #4
0
def training_loop(config: Config):
    """Train the VAE Encoder/Decoder pair on reconstruction loss only.

    The loss is ``config.sigma**2 * mean(sum((image - x)**2))``. The KL term
    is not included. At regular intervals the loop:

    * prints the training loss and the reconstruction MSE on the 'eval'
      split,
    * writes sample / reconstruction grids under ``<model_dir>/NNNNNN/``,
    * checkpoints encoder and decoder to ``<model_dir>/{en,de}.ckpt-N``.

    Args:
        config: run configuration; reads ``task_name``, ``dataset_name``,
            ``tfds_dir``, ``gpu_nums``, ``batch_size``, ``lr``,
            ``decay_step``, ``decay_coef``, ``beta2``, ``dim_z``, ``sigma``,
            ``example_nums``, ``num_midpoints``, ``resume``,
            ``restore_e_dir``, ``restore_d_dir``, ``total_step``,
            ``print_loss_per_steps``, ``eval_per_steps``,
            ``save_per_steps`` and ``model_dir``.
    """
    timer = Timer()
    print('Task name %s' % config.task_name)
    print('Loading %s dataset...' % config.dataset_name)
    dset = get_dataset(config.dataset_name, config.tfds_dir,
                       config.gpu_nums * 2)
    dataset = dset.input_fn(config.batch_size, mode='train')
    dataset = dataset.make_initializable_iterator()

    eval_dataset = dset.input_fn(config.batch_size, mode='eval')
    eval_dataset = eval_dataset.make_initializable_iterator()
    global_step = tf.get_variable(
        name='global_step',
        initializer=tf.constant(0),
        trainable=False,
        aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
    print("Constructing networks...")

    Encoder = vae.Encoder(config.dim_z, exceptions=['opt'], name='Encoder')
    Decoder = vae.Decoder(dset.image_shape, exceptions=['opt'], name='Decoder')
    learning_rate = tf.train.exponential_decay(config.lr,
                                               global_step,
                                               config.decay_step,
                                               config.decay_coef,
                                               staircase=False)
    solver = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                    name='opt',
                                    beta2=config.beta2)
    print("Building tensorflow graph...")

    def train_step(image):
        """Build one optimiser step and return the (post-step) loss tensor."""
        _, _, z = Encoder(image, is_training=True)
        x = Decoder(z, is_training=True)
        with tf.variable_scope('reconstruction_loss'):
            recon_loss = config.sigma**2 * tf.reduce_mean(
                tf.reduce_sum(tf.square(image - x), [1, 2, 3]))
        loss = recon_loss
        add_global = global_step.assign_add(1)
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies([add_global] + update_ops):
            opt = solver.minimize(loss,
                                  var_list=Encoder.trainable_variables +
                                  Decoder.trainable_variables)
            with tf.control_dependencies([opt]):
                return tf.identity(loss)

    loss = train_step(dataset.get_next()[0])
    print("Building eval module...")

    fixed_z = tf.constant(
        np.random.normal(size=[config.example_nums, config.dim_z]),
        dtype=tf.float32)
    fixed_z0 = tf.constant(
        np.random.normal(size=[config.example_nums, config.dim_z]),
        dtype=tf.float32)
    fixed_z1 = tf.constant(
        np.random.normal(size=[config.example_nums, config.dim_z]),
        dtype=tf.float32)
    fixed_x = tf.placeholder(tf.float32,
                             (config.example_nums, ) + dset.image_shape)
    fixed_x0 = tf.placeholder(tf.float32,
                              (config.example_nums, ) + dset.image_shape)
    fixed_x1 = tf.placeholder(tf.float32,
                              (config.example_nums, ) + dset.image_shape)
    input_dict = {
        'fixed_z': fixed_z,
        'fixed_z0': fixed_z0,
        'fixed_z1': fixed_z1,
        'fixed_x': fixed_x,
        'fixed_x0': fixed_x0,
        'fixed_x1': fixed_x1,
        'num_midpoints': config.num_midpoints
    }

    def sample_step():
        """Collect generated samples, reconstructions and the raw inputs."""
        out_dict = generate_sample(Decoder, input_dict)
        out_dict.update(reconstruction_sample(Encoder, Decoder, input_dict))
        out_dict.update({
            'fixed_x': fixed_x,
            'fixed_x0': fixed_x0,
            'fixed_x1': fixed_x1
        })
        return out_dict

    def eval_step(image):
        """Per-sample summed squared reconstruction error, batch-averaged."""
        # NOTE(review): is_training=True is kept as-is; if the networks use
        # batch-norm/dropout this is not inference-mode evaluation.
        _, _, z = Encoder(image, is_training=True)
        x = Decoder(z, is_training=True)
        mse = tf.reduce_mean(
            tf.reduce_sum(tf.square(image - x), axis=[1, 2, 3]))
        return mse

    # Measure reconstruction on the held-out 'eval' split rather than the
    # training iterator, so it does not consume training batches and it
    # reflects generalisation.
    mse = eval_step(eval_dataset.get_next()[0])

    o_dict = sample_step()

    print("Building init module...")
    with tf.init_scope():
        init = [
            tf.global_variables_initializer(), dataset.initializer,
            eval_dataset.initializer
        ]
        saver_e = tf.train.Saver(Encoder.restore_variables)
        saver_d = tf.train.Saver(Decoder.restore_variables)

    print('Starting training...')
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        sess.run(init)
        if config.resume:
            print("Restore vae...")
            saver_e.restore(sess, config.restore_e_dir)
            saver_d.restore(sess, config.restore_d_dir)
        timer.update()
        print('Preparing sample utils...')

        fixed_x_, _ = get_fixed_x(sess, dataset, config.example_nums,
                                  config.batch_size)
        fixed_x0_, _ = get_fixed_x(sess, dataset, config.example_nums,
                                   config.batch_size)
        fixed_x1_, _ = get_fixed_x(sess, dataset, config.example_nums,
                                   config.batch_size)
        print("Completing all work, iteration now start, consuming %s " %
              timer.runing_time_format)
        print("Start iterations...")
        for iteration in range(config.total_step):
            loss_, lr_ = sess.run([loss, learning_rate])
            if iteration % config.print_loss_per_steps == 0:
                try:
                    mse_ = sess.run(mse)
                except tf.errors.OutOfRangeError:
                    # The eval split may be finite: rewind it and retry.
                    sess.run(eval_dataset.initializer)
                    mse_ = sess.run(mse)
                timer.update()
                print("step %d, loss %f, mse %f, "
                      "learning_rate % f, consuming time %s" %
                      (iteration, loss_, mse_, lr_, timer.runing_time_format))
            if iteration % config.eval_per_steps == 0:
                o_dict_ = sess.run(o_dict, {
                    fixed_x: fixed_x_,
                    fixed_x0: fixed_x0_,
                    fixed_x1: fixed_x1_
                })
                sample_dir = config.model_dir + '/%06d' % iteration
                os.makedirs(sample_dir, exist_ok=True)
                for key in o_dict:
                    # Rank-5 outputs are (N, K, H, W, C) grids; others are
                    # NHWC. save_image_grid is fed channel-first layouts.
                    if o_dict_[key].ndim == 5:
                        img = o_dict_[key].transpose([0, 1, 4, 2, 3])
                    else:
                        img = o_dict_[key].transpose([0, 3, 1, 2])
                    save_image_grid(img, sample_dir + '/%s.jpg' % key)
            if iteration % config.save_per_steps == 0:
                saver_e.save(sess,
                             save_path=config.model_dir + '/en.ckpt',
                             global_step=iteration,
                             write_meta_graph=False)
                saver_d.save(sess,
                             save_path=config.model_dir + '/de.ckpt',
                             global_step=iteration,
                             write_meta_graph=False)