Ejemplo n.º 1
0
def training_loop(config: Config):
    """Train an Encoder/Decoder VAE with a graph-Laplacian smoothness penalty.

    The per-batch loss is KL divergence + Bernoulli reconstruction loss +
    ``config.laplace_lambda * batch_laplacian(s_w, z)``, where ``s_w`` is
    ``smoother_weight(rep, 'heat', ...)`` computed from the representations
    stored in ``record_dir/Mnist20_rep.tfrecords``. Every
    ``print_loss_per_steps`` steps the losses are printed, every
    ``eval_per_steps`` steps sample/reconstruction grids are written under
    ``model_dir/<step>/``, and every ``save_per_steps`` steps both networks
    are checkpointed (``en.ckpt`` / ``de.ckpt``).

    Args:
        config: Run configuration (dataset paths, network sizes, optimiser
            settings and the logging/eval/checkpoint schedule).
    """
    timer = Timer()
    print('Task name %s' % config.task_name)
    print('Loading %s dataset...' % config.dataset_name)
    dataset = load_mnist_from_record(
        config.record_dir + '/Mnist20_rep.tfrecords', config.batch_size)
    dataset = dataset.make_initializable_iterator()
    # Heat-kernel bandwidth: the stored sigma2 divided by -log(laplace_a).
    laplace_sigma2 = np.load(config.record_dir +
                             '/sigma2.npy') / (-np.log(config.laplace_a))

    global_step = tf.get_variable(name='global_step',
                                  initializer=tf.constant(0),
                                  trainable=False)
    print("Constructing networks...")
    Encoder = vae.Encoder(config.dim_z,
                          config.e_hidden_num,
                          exceptions=['opt'],
                          name='Encoder')
    Decoder = vae.Decoder(config.img_shape,
                          config.d_hidden_num,
                          exceptions=['opt'],
                          name='Decoder')
    learning_rate = tf.train.exponential_decay(config.lr,
                                               global_step,
                                               config.decay_step,
                                               config.decay_coef,
                                               staircase=False)
    solver = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                    name='opt',
                                    beta2=config.beta2)
    print("Building tensorflow graph...")

    def train_step(data):
        """Build one optimisation step.

        Returns (loss, recon_loss, kl_divergence, smooth_loss, s_w), each
        gated on the optimiser update so fetching them runs the step.
        """
        image, rep, label = data
        mu_z, log_sigma_z, z = Encoder(image, is_training=True)
        x = Decoder(z, is_training=True, flatten=False)
        with tf.variable_scope('kl_divergence'):
            # -0.5 * sum(1 + log_sigma_z - mu^2 - exp(log_sigma_z)) over the
            # latent dimension, averaged over the batch.
            kl_divergence = -tf.reduce_mean(
                tf.reduce_sum(
                    0.5 *
                    (1 + log_sigma_z - mu_z**2 - tf.exp(log_sigma_z)), 1))
        with tf.variable_scope('reconstruction_loss'):
            # Bernoulli negative log-likelihood; x is used as a probability.
            recon_loss = -tf.reduce_mean(
                tf.reduce_sum(
                    image * tf.log(x + EPS) +
                    (1 - image) * tf.log(1 - x + EPS), [1, 2, 3]))
        with tf.variable_scope('smooth_loss'):
            s_w = smoother_weight(rep, 'heat', sigma2=laplace_sigma2)
            smooth_loss = batch_laplacian(s_w, z) * config.laplace_lambda
        loss = kl_divergence + recon_loss + smooth_loss
        add_global = global_step.assign_add(1)
        # Ops registered in UPDATE_OPS (e.g. normalisation statistics, if any)
        # and the step counter run together with each update.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies([add_global] + update_ops):
            opt = solver.minimize(loss,
                                  var_list=Encoder.trainable_variables +
                                  Decoder.trainable_variables)
            # identity() makes the fetched tensors depend on the update op.
            with tf.control_dependencies([opt]):
                return tf.identity(loss), tf.identity(recon_loss), \
                       tf.identity(kl_divergence), tf.identity(smooth_loss), tf.identity(s_w)

    loss, r_loss, kl_loss, s_loss, s_w = train_step(dataset.get_next())
    print("Building eval module...")

    # Fixed latent codes / image placeholders reused at every eval so the
    # written grids are comparable across steps.
    fixed_z = tf.constant(
        np.random.normal(size=[config.example_nums, config.dim_z]),
        dtype=tf.float32)
    fixed_z0 = tf.constant(
        np.random.normal(size=[config.example_nums, config.dim_z]),
        dtype=tf.float32)
    fixed_z1 = tf.constant(
        np.random.normal(size=[config.example_nums, config.dim_z]),
        dtype=tf.float32)
    fixed_x = tf.placeholder(tf.float32,
                             [config.example_nums] + config.img_shape)
    fixed_x0 = tf.placeholder(tf.float32,
                              [config.example_nums] + config.img_shape)
    fixed_x1 = tf.placeholder(tf.float32,
                              [config.example_nums] + config.img_shape)
    input_dict = {
        'fixed_z': fixed_z,
        'fixed_z0': fixed_z0,
        'fixed_z1': fixed_z1,
        'fixed_x': fixed_x,
        'fixed_x0': fixed_x0,
        'fixed_x1': fixed_x1,
        'num_midpoints': config.num_midpoints
    }

    def eval_step():
        """Build the dict of sample/reconstruction tensors saved as grids."""
        out_dict = generate_sample(Decoder, input_dict)
        out_dict.update(reconstruction_sample(Encoder, Decoder, input_dict))
        out_dict.update({
            'fixed_x': fixed_x,
            'fixed_x0': fixed_x0,
            'fixed_x1': fixed_x1
        })
        return out_dict

    o_dict = eval_step()

    print("Building init module...")
    with tf.init_scope():
        init = [tf.global_variables_initializer(), dataset.initializer]
        saver_e = tf.train.Saver(Encoder.restore_variables)
        saver_d = tf.train.Saver(Decoder.restore_variables)

    print('Starting training...')
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        sess.run(init)
        if config.resume:
            print("Restore vae...")
            saver_e.restore(sess, config.restore_e_dir)
            saver_d.restore(sess, config.restore_d_dir)
        timer.update()
        print('Preparing eval utils...')

        fixed_x_, _ = get_fixed_x(sess, dataset, config.example_nums,
                                  config.batch_size)
        fixed_x0_, _ = get_fixed_x(sess, dataset, config.example_nums,
                                   config.batch_size)
        fixed_x1_, _ = get_fixed_x(sess, dataset, config.example_nums,
                                   config.batch_size)
        print("Completing all work, iteration now start, consuming %s " %
              timer.runing_time_format)
        print("Start iterations...")
        for iteration in range(config.total_step):
            loss_, r_loss_, kl_loss_, s_loss_, sw_sum_, lr_ = \
                sess.run([loss, r_loss, kl_loss, s_loss, s_w, learning_rate])
            if iteration % config.print_loss_per_steps == 0:
                timer.update()
                # Same quantity as np.prod(s_w) ** (1 / 255**2), accumulated
                # in log space: a direct product over this many weights can
                # underflow to 0 (or overflow) before the root is taken.
                # Zero weights still give 0 (log -> -inf -> exp -> 0).
                # Assumes non-negative weights (heat kernel) - TODO confirm.
                with np.errstate(divide='ignore', invalid='ignore'):
                    sw_prod_ = np.exp(
                        np.sum(np.log(np.asarray(sw_sum_, dtype=np.float64)))
                        / 255**2)
                print(
                    "step %d, loss %f, r_loss_ %f, kl_loss_ %f, s_loss_ %f, sw_prod %f, "
                    "learning_rate % f, consuming time %s" %
                    (iteration, loss_, r_loss_, kl_loss_, s_loss_, sw_prod_,
                     lr_, timer.runing_time_format))
            if iteration % config.eval_per_steps == 0:
                o_dict_ = sess.run(o_dict, {
                    fixed_x: fixed_x_,
                    fixed_x0: fixed_x0_,
                    fixed_x1: fixed_x1_
                })
                # One directory per evaluated step, created once (not per key).
                step_dir = config.model_dir + '/%06d' % iteration
                os.makedirs(step_dir, exist_ok=True)
                for key in o_dict:
                    save_image_grid(o_dict_[key], step_dir + '/%s.jpg' % key)
            if iteration % config.save_per_steps == 0:
                saver_e.save(sess,
                             save_path=config.model_dir + '/en.ckpt',
                             global_step=iteration,
                             write_meta_graph=False)
                saver_d.save(sess,
                             save_path=config.model_dir + '/de.ckpt',
                             global_step=iteration,
                             write_meta_graph=False)
Ejemplo n.º 2
0
def training_loop(config: Config):
    """Train a VAE with a KNN-masked Laplacian penalty and track PPL metrics.

    The per-batch loss is ``config.wae_lambda`` * KL + Bernoulli
    reconstruction loss + ``config.laplace_lambda * batch_laplacian(s_w, z)``.
    The heat-kernel weights ``s_w`` are masked by ``make_mask`` built from
    each record's ``neighbour``/``index`` fields. A separately restored
    encoder (``VAE_En``, loaded from ``config.restore_s_dir``) is the
    distance metric for the ``ppl.PPL_mnist`` evaluation.

    Schedule:
      * every ``print_loss_per_steps`` steps, losses are printed and collected;
      * every 1000 steps, PPL/Lipschitz estimates are computed;
      * every ``eval_per_steps`` steps, sample grids are written under
        ``model_dir/<step>/``;
      * every ``save_per_steps`` steps, checkpoints are written along with a
        metrics dict saved via ``np.save``.

    Args:
        config: Run configuration.
    """
    timer = Timer()
    opts = w_config.config_mnist
    print('Task name %s' % config.task_name)
    print('Loading %s dataset...' % config.dataset_name)
    dataset = load_mnist_KNN_from_record(config.record_dir + '/Mnist20knn5_rep.tfrecords', config.batch_size)
    dataset = dataset.make_initializable_iterator()
    # laplace_sigma2 = np.load(config.record_dir + '/knn5sigma2.npy') / (-np.log(config.laplace_a))
    # Fixed unit bandwidth rescaled by -log(laplace_a).
    laplace_sigma2 = 1.0 / (-np.log(config.laplace_a))

    global_step = tf.get_variable(name='global_step', initializer=tf.constant(0), trainable=False)
    print("Constructing networks...")
    Encoder = vae.Encoder(config.dim_z, config.e_hidden_num, exceptions=['opt'], name='Encoder')
    Decoder = vae.Decoder(config.img_shape, config.d_hidden_num, exceptions=['opt'], name='Decoder')
    # Reference encoder for the PPL metric; it is not in the solver's
    # var_list, so training never updates it.
    valina_encoder = vae.Encoder(config.dim_z, config.e_hidden_num, exceptions=['opt'], name='VAE_En')

    def lip_metric(inputs):
        # Identity metric: distances are taken directly on the encoder output.
        return inputs

    def d_metric(inputs):
        # Distance space for PPL: the reference encoder's sampled code.
        _, _, outputs = valina_encoder(inputs, True)
        return outputs

    def generator(inputs):
        outputs = Decoder(inputs, True, False)
        return outputs

    def lip_generator(inputs):
        _, _, outputs = Encoder(inputs, True)
        return outputs

    PPL = ppl.PPL_mnist(epsilon=0.01, sampling='full', generator=generator, d_metric=d_metric)
    Lip_PPL = ppl.PPL_mnist(epsilon=0.01, sampling='full', generator=lip_generator, d_metric=lip_metric)

    learning_rate = tf.train.exponential_decay(config.lr, global_step, config.decay_step,
                                               config.decay_coef, staircase=False)
    solver = tf.train.AdamOptimizer(learning_rate=learning_rate, name='opt', beta1=opts['adam_beta1'])
    print("Building tensorflow graph...")

    def train_step(data):
        """Build one optimisation step.

        Returns (loss, recon_loss, kl_divergence, smooth_loss, s_w_mean),
        each gated on the optimiser update.
        """
        image, rep, label, neighbour, index = data
        mu_z, log_sigma_z, z = Encoder(image, is_training=True)
        x = Decoder(z, is_training=True, flatten=False)
        with tf.variable_scope('kl_divergence'):
            kl_divergence = - config.wae_lambda * tf.reduce_mean(tf.reduce_sum(
                0.5 * (1 + log_sigma_z - mu_z ** 2 - tf.exp(log_sigma_z)), 1))
        with tf.variable_scope('reconstruction_loss'):
            recon_loss = - tf.reduce_mean(tf.reduce_sum(
                image * tf.log(x + EPS) + (1 - image) * tf.log(1 - x + EPS), [1, 2, 3]))
            # recon_loss = 0.05 * tf.reduce_mean(tf.reduce_sum(tf.square(image - x), [1, 2, 3]))
        with tf.variable_scope('smooth_loss'):
            mask = make_mask(neighbour, index)
            s_w = mask * smoother_weight(rep, 'heat', sigma2=laplace_sigma2, mask=mask)
            smooth_loss = batch_laplacian(s_w, z) * config.laplace_lambda
            # Mean weight over the unmasked entries only (for logging).
            s_w_mean = tf.reduce_mean(s_w) * config.batch_size * config.batch_size / (tf.reduce_sum(mask) + EPS)

        loss = kl_divergence + recon_loss + smooth_loss
        # loss = loss_match + recon_loss
        add_global = global_step.assign_add(1)
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies([add_global] + update_ops):
            opt = solver.minimize(loss, var_list=Encoder.trainable_variables + Decoder.trainable_variables)
            # identity() makes the fetched tensors depend on the update op.
            with tf.control_dependencies([opt]):
                l1, l2, l3, l4, l5 = tf.identity(loss), tf.identity(recon_loss), \
                       tf.identity(kl_divergence), tf.identity(smooth_loss), tf.identity(s_w_mean)

        return l1, l2, l3, l4, l5

    loss, r_loss, kl_loss, s_loss, s_w = train_step(dataset.get_next())

    # def pretrain(data):
    #     image, rep, label, neighbour, index = data
    #     mu_z, log_sigma_z, z = Encoder(image, is_training=True)
    #     Pz = tf.random.normal(shape=[config.batch_size, config.dim_z], mean=0.0, stddev=1.0)
    #     mean_pz = tf.reduce_mean(Pz, axis=0, keep_dims=True)
    #     mean_qz = tf.reduce_mean(z, axis=0, keep_dims=True)
    #     mean_loss = tf.reduce_mean(tf.square(mean_pz - mean_qz))
    #     cov_pz = tf.matmul(Pz - mean_pz, Pz - mean_pz,
    #                        transpose_a=True) / (config.batch_size - 1)
    #     cov_qz = tf.matmul(z - mean_qz, z - mean_qz,
    #                        transpose_a=True) / (config.batch_size - 1)
    #     cov_loss = tf.reduce_mean(tf.square(cov_pz - cov_qz))
    #     pretrain_loss = cov_loss + mean_loss
    #     update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    #     with tf.control_dependencies(update_ops):
    #         opt = solver.minimize(pretrain_loss, var_list=Encoder.trainable_variables)
    #         with tf.control_dependencies([opt]):
    #             p_loss = tf.identity(pretrain_loss)
    #     return p_loss
    # p_loss = pretrain(dataset.get_next())

    print("Building eval module...")

    fixed_z = tf.constant(np.random.normal(size=[config.example_nums, config.dim_z]), dtype=tf.float32)
    fixed_z0 = tf.constant(np.random.normal(size=[config.example_nums, config.dim_z]), dtype=tf.float32)
    fixed_z1 = tf.constant(np.random.normal(size=[config.example_nums, config.dim_z]), dtype=tf.float32)
    fixed_x = tf.placeholder(tf.float32, [config.example_nums] + config.img_shape)
    fixed_x0 = tf.placeholder(tf.float32, [config.example_nums] + config.img_shape)
    fixed_x1 = tf.placeholder(tf.float32, [config.example_nums] + config.img_shape)
    input_dict = {'fixed_z': fixed_z, 'fixed_z0': fixed_z0, 'fixed_z1': fixed_z1, 'fixed_x': fixed_x,
                  'fixed_x0': fixed_x0, 'fixed_x1': fixed_x1, 'num_midpoints': config.num_midpoints}

    def sample_step():
        """Build the dict of sample/reconstruction tensors saved as grids."""
        out_dict = generate_sample(Decoder, input_dict)
        out_dict.update(reconstruction_sample(Encoder, Decoder, input_dict))
        out_dict.update({'fixed_x': fixed_x, 'fixed_x0': fixed_x0, 'fixed_x1': fixed_x1})
        return out_dict

    o_dict = sample_step()

    def eval_step(img1, img2):
        """Build PPL tensors: prior samples, encoded image pairs, Lipschitz."""
        z0 = tf.random.normal(shape=[config.batch_size, config.dim_z], mean=0.0, stddev=1.0)
        z1 = tf.random.normal(shape=[config.batch_size, config.dim_z], mean=0.0, stddev=1.0)
        _, _, img1_z = Encoder(img1, True)
        _, _, img2_z = Encoder(img2, True)
        ppl_sample_loss = PPL(z0, z1)
        ppl_de_loss = PPL(img1_z, img2_z)
        lip_loss = Lip_PPL(img1, img2)
        return ppl_sample_loss, ppl_de_loss, lip_loss

    img_1, _, _, _, _ = dataset.get_next()
    img_2, _, _, _, _ = dataset.get_next()
    ppl_sa_loss, ppl_de_loss, lip_loss = eval_step(img_1, img_2)

    print("Building init module...")
    with tf.init_scope():
        init = [tf.global_variables_initializer(), dataset.initializer]
        saver_e = tf.train.Saver(Encoder.restore_variables, max_to_keep=10)
        saver_d = tf.train.Saver(Decoder.restore_variables, max_to_keep=10)
        saver_v = tf.train.Saver(valina_encoder.restore_variables)

    print('Starting training...')
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        sess.run(init)
        saver_v.restore(sess, config.restore_s_dir)
        if config.resume:
            print("Restore vae...")
            saver_e.restore(sess, config.restore_e_dir)
            saver_d.restore(sess, config.restore_d_dir)
        timer.update()
        print('Preparing eval utils...')

        fixed_x_, _ = get_fixed_x(sess, dataset, config.example_nums, config.batch_size)
        fixed_x0_, _ = get_fixed_x(sess, dataset, config.example_nums, config.batch_size)
        fixed_x1_, _ = get_fixed_x(sess, dataset, config.example_nums, config.batch_size)
        print("Completing all work, iteration now start, consuming %s " % timer.runing_time_format)
        print("Start iterations...")
        # print("Start pretraining of Encoder...")
        # for iteration in range(500):
        #     p_loss_ = sess.run(p_loss)
        #     if iteration % 50 == 0:
        #         print("Pretrain_step %d, p_loss %f" % (iteration, p_loss_))
        # print("Pretraining of Encoder Done! p_loss %f. Now start training..." % p_loss_)
        loss_list = []
        r_loss_list = []
        kl_loss_list = []
        s_loss_list = []
        ppl_sa_list = []
        ppl_re_list = []
        lip_list = []
        # PPL/Lipschitz metrics: how often they are evaluated and how many
        # sess.run batches each estimate accumulates. The normaliser below is
        # derived from ppl_eval_batches so the two cannot drift apart.
        ppl_eval_every = 1000
        ppl_eval_batches = 200
        for iteration in range(config.total_step):
            loss_, r_loss_, kl_loss_, s_loss_, sw_sum_, lr_ = \
                sess.run([loss, r_loss, kl_loss, s_loss, s_w, learning_rate])
            if iteration % config.print_loss_per_steps == 0:
                loss_list.append(loss_)
                r_loss_list.append(r_loss_)
                kl_loss_list.append(kl_loss_)
                s_loss_list.append(s_loss_)
                timer.update()
                print("step %d, loss %f, r_loss_ %f, kl_loss_ %f, s_loss_ %f, sw %f, "
                      "learning_rate % f, consuming time %s" %
                      (iteration, loss_, r_loss_, kl_loss_, s_loss_, np.mean(sw_sum_),
                       lr_, timer.runing_time_format))
            if iteration % ppl_eval_every == 0:
                sa_loss_ = 0.0
                de_loss_ = 0.0
                lip_loss_ = 0.0
                for _ in range(ppl_eval_batches):
                    sa_p, de_p, lip_p = sess.run([ppl_sa_loss, ppl_de_loss, lip_loss])
                    sa_loss_ += sa_p
                    de_loss_ += de_p
                    lip_loss_ += lip_p
                # Per-sample average over every accumulated batch.
                # NOTE(review): assumes each PPL call returns a sum over its
                # batch_size samples - confirm against ppl.PPL_mnist.
                n_ppl_samples = config.batch_size * ppl_eval_batches
                sa_loss_ /= n_ppl_samples
                de_loss_ /= n_ppl_samples
                lip_loss_ /= n_ppl_samples
                ppl_re_list.append(de_loss_)
                ppl_sa_list.append(sa_loss_)
                lip_list.append(lip_loss_)
                print("ppl_sample %f, ppl_resample %f, lipschitze %f" % (sa_loss_, de_loss_, lip_loss_))
            if iteration % config.eval_per_steps == 0:
                o_dict_ = sess.run(o_dict, {fixed_x: fixed_x_, fixed_x0: fixed_x0_, fixed_x1: fixed_x1_})
                # One directory per evaluated step, created once (not per key).
                step_dir = config.model_dir + '/%06d' % iteration
                os.makedirs(step_dir, exist_ok=True)
                for key in o_dict:
                    save_image_grid(o_dict_[key], step_dir + '/%s.jpg' % key)
            if iteration % config.save_per_steps == 0:
                saver_e.save(sess, save_path=config.model_dir + '/en.ckpt',
                             global_step=iteration, write_meta_graph=False)
                saver_d.save(sess, save_path=config.model_dir + '/de.ckpt',
                             global_step=iteration, write_meta_graph=False)
                metric_dict = {'r': r_loss_list, 'm': kl_loss_list, 's': s_loss_list,
                               'psa': ppl_sa_list, 'pre': ppl_re_list, 'lip': lip_list}
                np.save(config.model_dir + '/%06d' % iteration + 'metric.npy', metric_dict)
Ejemplo n.º 3
0
def training_loop(config: Config):
    """Train a plain VAE (KL + reconstruction) under MirroredStrategy.

    The step is replicated with ``strategy.experimental_run_v2`` and the
    per-replica scalar losses are mean-reduced across replicas. Every
    ``print_loss_per_steps`` steps the losses are printed, every
    ``eval_per_steps`` steps sample/reconstruction grids are written to
    ``model_dir``, and every ``save_per_steps`` steps both networks are
    checkpointed (``en.ckpt`` / ``de.ckpt``).

    Args:
        config: Run configuration (dataset, ``gpu_nums``, network sizes,
            optimiser settings and the logging/eval/checkpoint schedule).
    """
    timer = Timer()
    print('Task name %s' % config.task_name)
    strategy = tf.distribute.MirroredStrategy()
    print('Loading %s dataset...' % config.dataset_name)
    dset = get_dataset(config.dataset_name, config.tfds_dir,
                       config.gpu_nums * 2)
    dataset = dset.input_fn(config.batch_size, mode='train')
    dataset = strategy.experimental_distribute_dataset(dataset)
    dataset = dataset.make_initializable_iterator()

    eval_dataset = dset.input_fn(config.batch_size, mode='eval')
    eval_dataset = strategy.experimental_distribute_dataset(eval_dataset)
    eval_dataset = eval_dataset.make_initializable_iterator()
    with strategy.scope():
        global_step = tf.get_variable(
            name='global_step',
            initializer=tf.constant(0),
            trainable=False,
            aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
        print("Constructing networks...")
        Encoder = vae.Encoder(config.dim_z,
                              config.e_hidden_num,
                              exceptions=['opt'],
                              name='Encoder')
        Decoder = vae.Decoder(config.img_shape,
                              config.d_hidden_num,
                              exceptions=['opt'],
                              name='Decoder')
        learning_rate = tf.train.exponential_decay(config.lr,
                                                   global_step,
                                                   config.decay_step,
                                                   config.decay_coef,
                                                   staircase=False)
        solver = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                        name='opt',
                                        beta2=config.beta2)
        print("Building tensorflow graph...")

        def train_step(image):
            """Per-replica step; returns scalar (loss, recon_loss, kl)."""
            mu_z, log_sigma_z, z = Encoder(image, is_training=True)
            x = Decoder(z, is_training=True, flatten=False)
            with tf.variable_scope('kl_divergence'):
                # KL(q(z|x) || N(0, I)) =
                #   -0.5 * sum(1 + log_sigma_z - mu^2 - exp(log_sigma_z))
                # summed over the latent dimension and averaged over the
                # batch. The leading minus makes it a non-negative scalar the
                # optimiser minimises.
                kl_divergence = -tf.reduce_mean(
                    tf.reduce_sum(
                        0.5 * (1 + log_sigma_z - mu_z**2 -
                               tf.exp(log_sigma_z)), 1))
            with tf.variable_scope('reconstruction_loss'):
                # NOTE(review): the decoder output is passed as *logits*;
                # confirm vae.Decoder returns pre-sigmoid values here.
                pixel_ce = tf.nn.sigmoid_cross_entropy_with_logits(
                    labels=image, logits=x)
                # Sum over every non-batch axis (per-image NLL), then average
                # over the batch so the scale matches the per-sample KL term.
                recon_loss = tf.reduce_mean(
                    tf.reduce_sum(
                        tf.reshape(pixel_ce, [tf.shape(pixel_ce)[0], -1]), 1))
            loss = kl_divergence + recon_loss
            add_global = global_step.assign_add(1)
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies([add_global] + update_ops):
                opt = solver.minimize(loss,
                                      var_list=Encoder.trainable_variables +
                                      Decoder.trainable_variables)
                # identity() makes the fetched tensors depend on the update.
                with tf.control_dependencies([opt]):
                    return tf.identity(loss), tf.identity(recon_loss), \
                           tf.identity(kl_divergence)

        loss, r_loss, kl_loss = strategy.experimental_run_v2(
            train_step, (dataset.get_next()[0], ))
        loss = strategy.reduce(tf.distribute.ReduceOp.MEAN, loss, axis=None)
        r_loss = strategy.reduce(tf.distribute.ReduceOp.MEAN,
                                 r_loss,
                                 axis=None)
        kl_loss = strategy.reduce(tf.distribute.ReduceOp.MEAN,
                                  kl_loss,
                                  axis=None)
        print("Building eval module...")

        fixed_z = tf.constant(
            np.random.normal(size=[config.example_nums, config.dim_z]),
            dtype=tf.float32)
        fixed_z0 = tf.constant(
            np.random.normal(size=[config.example_nums, config.dim_z]),
            dtype=tf.float32)
        fixed_z1 = tf.constant(
            np.random.normal(size=[config.example_nums, config.dim_z]),
            dtype=tf.float32)
        fixed_x = tf.placeholder(tf.float32,
                                 [config.example_nums] + config.img_shape)
        fixed_x0 = tf.placeholder(tf.float32,
                                  [config.example_nums] + config.img_shape)
        fixed_x1 = tf.placeholder(tf.float32,
                                  [config.example_nums] + config.img_shape)
        input_dict = {
            'fixed_z': fixed_z,
            'fixed_z0': fixed_z0,
            'fixed_z1': fixed_z1,
            'fixed_x': fixed_x,
            'fixed_x0': fixed_x0,
            'fixed_x1': fixed_x1,
            'num_midpoints': config.num_midpoints
        }

        def eval_step():
            """Build the dict of sample/reconstruction tensors."""
            out_dict = generate_sample(Decoder, input_dict)
            out_dict.update(reconstruction_sample(Encoder, Decoder,
                                                  input_dict))
            return out_dict

        if config.gpu_nums == 1:
            o_dict = strategy.experimental_run_v2(eval_step, ())
        else:
            o_dict = concate_PerReplica(
                strategy.experimental_run_v2(eval_step, ()))

        print("Building init module...")
        with tf.init_scope():
            init = [
                tf.global_variables_initializer(), dataset.initializer,
                eval_dataset.initializer
            ]
            saver_e = tf.train.Saver(Encoder.restore_variables)
            saver_d = tf.train.Saver(Decoder.restore_variables)

        print('Starting training...')
        with tf.Session(config=tf.ConfigProto(
                allow_soft_placement=True)) as sess:
            sess.run(init)
            if config.resume:
                print("Restore vae...")
                saver_e.restore(sess, config.restore_e_dir)
                saver_d.restore(sess, config.restore_d_dir)
            timer.update()
            print('Preparing eval utils...')

            fixed_x_, _ = get_fixed_x(sess, dataset, config.example_nums,
                                      config.batch_size)
            fixed_x0_, _ = get_fixed_x(sess, dataset, config.example_nums,
                                       config.batch_size)
            fixed_x1_, _ = get_fixed_x(sess, dataset, config.example_nums,
                                       config.batch_size)
            print("Completing all work, iteration now start, consuming %s " %
                  timer.runing_time_format)
            print("Start iterations...")
            for iteration in range(config.total_step):
                loss_, r_loss_, kl_loss_, lr_ = sess.run(
                    [loss, r_loss, kl_loss, learning_rate])
                # Logging, sampling and checkpointing run inside the loop on
                # their own schedules.
                if iteration % config.print_loss_per_steps == 0:
                    timer.update()
                    print(
                        "step %d, loss %f, r_loss_ %f, kl_loss_ %f, learning_rate % f, consuming time %s"
                        % (iteration, loss_, r_loss_, kl_loss_, lr_,
                           timer.runing_time_format))
                if iteration % config.eval_per_steps == 0:
                    # Feed the numpy batches (not the placeholders) and save
                    # the evaluated arrays rather than the symbolic tensors.
                    o_dict_ = sess.run(o_dict, {
                        fixed_x: fixed_x_,
                        fixed_x0: fixed_x0_,
                        fixed_x1: fixed_x1_
                    })
                    for key in o_dict_:
                        save_image_grid(
                            o_dict_[key],
                            config.model_dir + '/%s%06d.jpg' % (key, iteration))
                if iteration % config.save_per_steps == 0:
                    saver_e.save(sess,
                                 save_path=config.model_dir + '/en.ckpt',
                                 global_step=iteration,
                                 write_meta_graph=False)
                    saver_d.save(sess,
                                 save_path=config.model_dir + '/de.ckpt',
                                 global_step=iteration,
                                 write_meta_graph=False)
Ejemplo n.º 4
0
def training_loop(config: Config):
    """Export encoder representations of a dataset to a TFRecord file.

    No optimisation happens here. The VAE (restored from
    ``restore_e_dir``/``restore_d_dir`` when ``config.resume`` is set) is
    used as follows:

    * Sample/reconstruction grids are written to ``model_dir``.
    * ``train_input_fn(dset, ...)`` is then iterated until it is exhausted.
      Each ``(image, label, z)`` row is serialised into
      ``model_dir/Mnist20_rep.tfrecords``.
    * The per-batch ``compute_sigma2(z)`` values are summed, divided by
      ``count * batch_size`` and saved to ``model_dir/sigma2.npy``.

    Args:
        config: Run configuration.

    Raises:
        RuntimeError: If the dataset yields no batches, so no sigma2 average
            can be computed.
    """
    timer = Timer()
    print('Task name %s' % config.task_name)
    print('Loading %s dataset...' % config.dataset_name)
    dset = get_dataset(config.dataset_name, config.tfds_dir,
                       config.gpu_nums * 2)
    dataset = train_input_fn(dset, config.batch_size)
    dataset = dataset.make_initializable_iterator()
    print("Constructing networks...")
    Encoder = vae.Encoder(config.dim_z,
                          config.e_hidden_num,
                          exceptions=['opt'],
                          name='VAE_En')
    Decoder = vae.Decoder(config.img_shape,
                          config.d_hidden_num,
                          exceptions=['opt'],
                          name='VAE_De')
    print("Building tensorflow graph...")
    image, label = dataset.get_next()
    _, _, z = Encoder(image, is_training=True)
    sigma2_plus = compute_sigma2(z)
    print("Building eval module...")

    fixed_z = tf.constant(
        np.random.normal(size=[config.example_nums, config.dim_z]),
        dtype=tf.float32)
    fixed_z0 = tf.constant(
        np.random.normal(size=[config.example_nums, config.dim_z]),
        dtype=tf.float32)
    fixed_z1 = tf.constant(
        np.random.normal(size=[config.example_nums, config.dim_z]),
        dtype=tf.float32)
    fixed_x = tf.placeholder(tf.float32,
                             [config.example_nums] + config.img_shape)
    fixed_x0 = tf.placeholder(tf.float32,
                              [config.example_nums] + config.img_shape)
    fixed_x1 = tf.placeholder(tf.float32,
                              [config.example_nums] + config.img_shape)
    input_dict = {
        'fixed_z': fixed_z,
        'fixed_z0': fixed_z0,
        'fixed_z1': fixed_z1,
        'fixed_x': fixed_x,
        'fixed_x0': fixed_x0,
        'fixed_x1': fixed_x1,
        'num_midpoints': config.num_midpoints
    }

    def sample_step():
        """Build the dict of sample/reconstruction tensors saved as grids."""
        out_dict = generate_sample(Decoder, input_dict)
        out_dict.update(reconstruction_sample(Encoder, Decoder, input_dict))
        out_dict.update({
            'fixed_x': fixed_x,
            'fixed_x0': fixed_x0,
            'fixed_x1': fixed_x1
        })
        return out_dict

    o_dict = sample_step()

    print("Building init module...")
    with tf.init_scope():
        init = [tf.global_variables_initializer(), dataset.initializer]
        saver_e = tf.train.Saver(Encoder.restore_variables)
        saver_d = tf.train.Saver(Decoder.restore_variables)

    print('Starting training...')
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        sess.run(init)
        if config.resume:
            print("Restore vae...")
            saver_e.restore(sess, config.restore_e_dir)
            saver_d.restore(sess, config.restore_d_dir)
        timer.update()
        print('Preparing sample utils...')

        fixed_x_, _ = get_fixed_x(sess, dataset, config.example_nums,
                                  config.batch_size)
        fixed_x0_, _ = get_fixed_x(sess, dataset, config.example_nums,
                                   config.batch_size)
        fixed_x1_, _ = get_fixed_x(sess, dataset, config.example_nums,
                                   config.batch_size)
        o_dict_ = sess.run(o_dict, {
            fixed_x: fixed_x_,
            fixed_x0: fixed_x0_,
            fixed_x1: fixed_x1_
        })
        for key in o_dict_:
            save_image_grid(o_dict_[key], config.model_dir + '/%s.jpg' % key)
        # NOTE(review): get_fixed_x is handed the export iterator and
        # presumably pulls batches from it. Rewind so the export below starts
        # from the beginning instead of silently skipping those batches.
        sess.run(dataset.initializer)
        print("Completing all work, iteration now start, consuming %s " %
              timer.runing_time_format)

        print("Start iterations...")
        sigma2 = 0.0
        count = 0
        with tf.io.TFRecordWriter(config.model_dir +
                                  '/Mnist20_rep.tfrecords') as writer:
            while True:
                # Only the fetch can signal end-of-data; keep the try narrow.
                try:
                    image_, label_, sigma2_plus_, rep_ = sess.run(
                        [image, label, sigma2_plus, z])
                except tf.errors.OutOfRangeError:
                    break
                sigma2 += sigma2_plus_
                count += 1
                # One serialized record per row of the batch.
                for image_n, label_n, rep_n in zip(image_, label_, rep_):
                    writer.write(serialize_example(image_n, label_n, rep_n))
                if count % 100 == 0:
                    timer.update()
                    print('Complete %d bathes, consuming time %s' %
                          (count, timer.runing_time_format))
            if count == 0:
                raise RuntimeError(
                    'Dataset yielded no batches; cannot compute sigma2.')
            np.save(config.model_dir + '/sigma2.npy',
                    sigma2 / (count * config.batch_size))
            print('Done!')