예제 #1
0
def main():
    """Evaluate the restored model on the test split, then save and plot results.

    Builds the test graph, restores the newest checkpoint under
    ``config.ckpt_path`` and predicts ``t.p_o`` (compared against
    ``dp.test.q``) and ``t.p_i`` (compared against ``dp.test.rd``) for the
    test split. Predictions and ground truth are mapped back to their
    original scale with ``anti_norm`` using the statistics in ``dp.out``.
    The four arrays are pickled to ``<logs_path>/<description>-test.pkl``
    (and copied to ``config.history_test_path``). One comparison plot is
    drawn for each of Q1, Q2 and R1..R6.
    """
    with tf.device(config.device):
        t = build_graph(is_test=True)

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                          log_device_placement=True)) as sess:

        logger.info(config.ckpt_path)
        saver = tf.train.Saver()
        saver.restore(sess, tf.train.latest_checkpoint(config.ckpt_path))
        logger.info("Loading model completely")
        z_latent = sampler_switch(config)
        # Fetch both outputs in one run: the graph is evaluated once and both
        # results come from the same forward pass (the two former calls used
        # an identical feed_dict).
        d_q, r_p = sess.run([t.p_o, t.p_i],
                            feed_dict={
                                t.z_e: dp.test.e,
                                t.x_c: dp.test.c,
                                t.z_l: z_latent,
                                t.p_in: dp.test.rd,
                            })

        # inverse the scaled output
        qm, qr, rdm, rdr = dp.out.qm, dp.out.qr, dp.out.rdm, dp.out.rdr
        actual_Q = anti_norm(dp.test.q, qm, qr)
        result_Q = anti_norm(d_q, qm, qr)
        actual_r = anti_norm(dp.test.rd, rdm, rdr)
        result_r = anti_norm(r_p, rdm, rdr)

        # save the result
        ensemble = {
            'actual_Q': actual_Q,
            'result_Q': result_Q,
            'actual_r': actual_r,
            'result_r': result_r
        }

        path = os.path.join(config.logs_path, config.description + '-test.pkl')
        pickle_save(ensemble, 'test_result', path)
        copy_file(path, config.history_test_path)

        # visualize the process: first 2 Q columns, then first 6 R columns
        for num in range(2):
            vis.cplot(actual_Q[:, num], result_Q[:, num],
                      ['Q{}'.format(num + 1), 'origin', 'modify'], config.t_p)
        for num in range(6):
            vis.cplot(actual_r[:, num], result_r[:, num],
                      ['R{}'.format(num + 1), 'origin', 'modify'], config.t_p)
예제 #2
0
def main():
    """Sample the restored model ``test_times`` times on the test split and pickle metrics.

    ``z_true`` is ``t.z_img`` evaluated on the unadjusted test input
    ``dp.test.rd``. For each draw of the latent sampler this records:

    - ``t.dq``, fed with the test targets ``dp.test.q``;
    - the adjusted input ``t.x_lat``;
    - ``t.z_img`` of that adjusted input.

    All results are pickled to ``<logs_path>/<description>-metric_plus.pkl``.
    """
    with tf.device(config.device):
        t = build_graph(is_test=True)

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                          log_device_placement=True)) as sess:
        saver = tf.train.Saver()
        saver.restore(sess, tf.train.latest_checkpoint(config.ckpt_path))

        q_errors = []
        r_adjs = []
        z_adjs = []

        z_true = sess.run(t.z_img, feed_dict={t.x: dp.test.rd})

        # The draw index is unused; `_` also avoids shadowing the stdlib
        # `time` module name inside this function.
        for _ in range(test_times):
            z_latent = sampler_switch(config)
            q_error = sess.run(t.dq,
                               feed_dict={
                                   t.z_e: dp.test.e,
                                   t.x_c: dp.test.c,
                                   t.z_l: z_latent,
                                   t.p_in: dp.test.rd,
                                   t.p_t: dp.test.q,
                               })
            r_adj = sess.run(t.x_lat,
                             feed_dict={
                                 t.x_c: dp.test.c,
                                 t.z_l: z_latent,
                                 t.z_e: dp.test.e,
                             })
            z_adj = sess.run(t.z_img, feed_dict={t.x: r_adj})

            q_errors.append(q_error)
            r_adjs.append(r_adj)
            z_adjs.append(z_adj)

        # NOTE(review): the squared error is taken against dp.test.e although
        # t.dq was fed dp.test.q as its target -- confirm which is intended.
        q_errors = (np.array(q_errors) - np.expand_dims(dp.test.e, axis=0))**2
        # Stack every draw into one matrix of width ndim_x / ndim_z.
        r_adjs = np.array(r_adjs).reshape(-1, config.ndim_x)
        z_adjs = np.array(z_adjs).reshape(-1, config.ndim_z)

        # NOTE(review): four arrays are saved but only three names are given
        # (z_true has none) -- confirm what pickle_save expects.
        pickle_save([q_errors, r_adjs, z_adjs, z_true],
                    ["productions", "adjustment", "latent_variables"],
                    '{}/{}-metric_plus.pkl'.format(config.logs_path,
                                                   config.description))
def _rebuild_test_graph(batch_size):
    """Reset the default graph, set ``config.batch_size`` and build the test graph.

    Returns whatever ``build_graph(is_test=True)`` returns.
    """
    tf.reset_default_graph()
    config.batch_size = batch_size
    with tf.device(config.device):
        return build_graph(is_test=True)


def _restore_latest_checkpoint(sess):
    """Restore the newest checkpoint under ``config.ckpt_path`` into ``sess``."""
    saver = tf.train.Saver()
    saver.restore(sess, tf.train.latest_checkpoint(config.ckpt_path))


def main(db='gs'):
    """Collect sampling metrics on the test split plus latent codes of the training data.

    Args:
        db: value stored in ``config.distribution_sampler`` before the graph
            is built.

    The first pass (batch size ``dp.valid.num_sample``) evaluates the
    restored model ``test_times`` times on ``dp.test``. The second pass
    rebuilds the graph with batch size ``dp.train_l.num_sample`` and encodes
    ``dp.train.rd``. Everything is pickled to
    ``<logs_path>/<get_description()>-metric_plus3.pkl``.
    """
    config.distribution_sampler = db
    # NOTE(review): the batch size is taken from dp.valid, but dp.test is fed
    # below -- this only works if both splits have the same number of rows.
    t = _rebuild_test_graph(dp.valid.num_sample)
    sess_config = tf.ConfigProto(allow_soft_placement=True,
                                 log_device_placement=True)

    with tf.Session(config=sess_config) as sess:
        _restore_latest_checkpoint(sess)

        q_errors = []
        r_adjs = []
        z_adjs = []

        z_true = sess.run(t.z_img, feed_dict={t.x: dp.test.rd})

        # The draw index is unused; `_` also avoids shadowing `time`.
        for _ in range(test_times):
            z_latent = sampler_switch(config)
            q_error = sess.run(t.dq, feed_dict={
                t.z_e: dp.test.e,
                t.x_c: dp.test.c,
                t.z_l: z_latent,
                t.p_in: dp.test.rd,
                t.p_t: dp.test.q,
            })
            r_adj = sess.run(t.x_lat, feed_dict={
                t.x_c: dp.test.c,
                t.z_l: z_latent,
                t.z_e: dp.test.e,
            })
            z_adj = sess.run(t.z_img, feed_dict={t.x: r_adj})

            q_errors.append(q_error)
            r_adjs.append(r_adj)
            z_adjs.append(z_adj)

        # NOTE(review): squared error against dp.test.e although t.dq was fed
        # dp.test.q as its target -- confirm which reference is intended.
        q_errors = (np.array(q_errors) - np.expand_dims(dp.test.e, axis=0))**2
        r_adjs = np.array(r_adjs).reshape(-1, config.ndim_x)
        z_adjs = np.array(z_adjs).reshape(-1, config.ndim_z)

    # revise the number of batch size and rebuild the graph for the training set
    # NOTE(review): the batch size is taken from dp.train_l, but dp.train.rd is
    # fed below -- confirm both have the same number of rows.
    t = _rebuild_test_graph(dp.train_l.num_sample)
    with tf.Session(config=sess_config) as sess:
        _restore_latest_checkpoint(sess)
        z_train = sess.run(t.z_img, feed_dict={t.x: dp.train.rd})

    out_path = '{}/{}-metric_plus3.pkl'.format(config.logs_path,
                                               config.get_description())
    # NOTE(review): five arrays are saved with only three names.
    pickle_save([q_errors, r_adjs, z_adjs, z_true, z_train],
                ["productions", "adjustment", "latent_variables"],
                out_path)
    print('{} have been saved'.format(out_path))
예제 #4
0
def main(run_load_from_file=False, verbose=False):
    """Train the model and record per-epoch training and validation losses.

    Each training iteration runs, in order, these steps:

    1. ``h.opt_dz``;
    2. ``h.opt_dimg``;
    3. the reconstruction step ``h.opt_r``;
    4. the process step ``h.opt_p``;
    5. the ``h.opt_q`` step.

    After every epoch, ``h.loss_q`` and ``h.loss_p`` are averaged over
    ``dp.valid``. Averages are displayed every 10 epochs. A checkpoint and
    the loss history are written on the last epoch and on every non-zero
    multiple of 1000. The history goes to
    ``<logs_path>/<description>-train.pkl`` and is copied to
    ``config.history_train_path``.

    Args:
        run_load_from_file: if True, restore the newest checkpoint under
            ``config.ckpt_path`` and append to the loss history loaded from
            the pickle above. Otherwise initialise all variables and start
            an empty history.
        verbose: if True, display running iteration averages every 10
            iterations. This display was previously disabled unconditionally
            (``and False``), so it stays off by default.
    """
    config = BaseConfig()
    config.folder_init()
    dp = SuvsDataProvider(num_validation=config.num_vad, shuffle='every_epoch')
    max_epoch = 500
    batch_size_l = config.batch_size
    path = os.path.join(config.logs_path, config.description + '-train.pkl')

    # Column names of each row appended to training_epoch_loss: nine training
    # losses followed by the two validation losses.
    training_loss_name = [
        'R_err',
        'Ez_err',
        'Gi_err',
        'GE_err',
        'EG_err',
        'Dz_err',
        'Di_err',
        'P_err',
        'GP_err',
        'GPt_err',
        'Pt_err',
    ]

    # training
    with tf.device(config.device):
        h = build_graph()

    sess_config = tf.ConfigProto(allow_soft_placement=True,
                                 log_device_placement=True)
    sess_config.gpu_options.allow_growth = True
    sess_config.gpu_options.per_process_gpu_memory_fraction = 0.9
    saver = tf.train.Saver(max_to_keep=2)

    with tf.Session(config=sess_config) as sess:
        # Load from checkpoint or start a new session.
        # NOTE(review): on resume the epoch counter below restarts at 0.
        if run_load_from_file:
            saver.restore(sess, tf.train.latest_checkpoint(config.ckpt_path))
            training_epoch_loss, _ = pickle_load(path)
        else:
            sess.run(tf.global_variables_initializer())
            training_epoch_loss = []

        # Recording loss per epoch
        process = Process()
        lr_schedule = create_lr_schedule(lr_base=2e-4,
                                         decay_rate=0.1,
                                         decay_epochs=500,
                                         truncated_epoch=2000,
                                         mode=config.lr_schedule)

        data_size = dp.train_l.num_sample
        # Ceil division: a partial remainder still gets one batch, but no
        # extra batch is drawn when data_size is an exact multiple.
        num_batch = (data_size + config.batch_size - 1) // config.batch_size
        num_test = dp.valid.num_sample // config.batch_size

        for epoch in range(max_epoch):
            process.start_epoch()
            # Learning rate generator
            learning_rate = lr_schedule(epoch)
            # Recording loss per iteration
            training_iteration_loss = []
            sum_loss_rest = 0
            sum_loss_dcm = 0
            sum_loss_gen = 0

            process_iteration = Process()
            for i in range(num_batch):
                process_iteration.start_epoch()
                # Inputs
                # sample from data distribution
                batch_l = dp.train_l.next_batch(batch_size_l)
                z_prior = sampler.sampler_switch(config)
                # adversarial phase for discriminator_z
                _, Dz_err = sess.run([h.opt_dz, h.loss_dz],
                                     feed_dict={
                                         h.x: batch_l.x,
                                         h.z_p: z_prior,
                                         h.lr: learning_rate,
                                     })
                z_latent = sampler.sampler_switch(config)
                _, Di_err = sess.run(
                    [h.opt_dimg, h.loss_dimg],
                    feed_dict={
                        h.x_c: batch_l.c,
                        h.z_l: z_latent,
                        h.z_e: batch_l.e,
                        h.x_s: batch_l.x,
                        h.lr: learning_rate,
                    })
                z_latent = sampler.sampler_switch(config)
                # reconstruction_phase (reuses z_prior from the step above)
                _, R_err, Ez_err, Gi_err, GE_err, EG_err = sess.run(
                    fetches=[
                        h.opt_r, h.loss_r, h.loss_e, h.loss_d, h.loss_l,
                        h.loss_eg
                    ],
                    feed_dict={
                        h.x: batch_l.x,
                        h.z_p: z_prior,
                        h.x_c: batch_l.c,
                        h.z_l: z_latent,
                        h.z_e: batch_l.e,
                        h.x_s: batch_l.x,
                        h.lr: learning_rate,
                    })
                # process phase
                _, P_err = sess.run([h.opt_p, h.loss_p],
                                    feed_dict={
                                        h.p_i: batch_l.rd,
                                        h.p_ot: batch_l.q,
                                        h.lr: learning_rate
                                    })
                # push process to normal
                z_latent = sampler.sampler_switch(config)
                _, GP_err = sess.run(
                    [h.opt_q, h.loss_q],
                    feed_dict={
                        h.x_c: batch_l.c,
                        h.z_l: z_latent,
                        h.z_e: batch_l.e,
                        h.p_in: batch_l.rd,
                        h.p_ot: batch_l.q,
                        h.lr: learning_rate,
                    })
                # recording loss function
                training_iteration_loss.append([
                    R_err, Ez_err, Gi_err, GE_err, EG_err, Dz_err, Di_err,
                    P_err, GP_err
                ])
                sum_loss_rest += R_err
                sum_loss_dcm += Dz_err + Di_err
                sum_loss_gen += Gi_err + Ez_err

                if verbose and i % 10 == 0:
                    process_iteration.display_current_results(
                        i, num_batch, {
                            'reconstruction': sum_loss_rest / (i + 1),
                            'discriminator': sum_loss_dcm / (i + 1),
                            'generator': sum_loss_gen / (i + 1),
                        })

            # In end of epoch, summary the loss
            average_loss_per_epoch = np.mean(np.array(training_iteration_loss),
                                             axis=0)

            # validation phase
            testing_iteration_loss = []
            for _ in range(num_test):
                z_latent = sampler.sampler_switch(config)
                batch_v = dp.valid.next_batch(config.batch_size)
                GPt_err = sess.run(h.loss_q,
                                   feed_dict={
                                       h.x_c: batch_v.c,
                                       h.z_l: z_latent,
                                       h.z_e: batch_v.e,
                                       h.p_in: batch_v.rd,
                                       h.p_ot: batch_v.q,
                                   })
                Pt_err = sess.run(h.loss_p,
                                  feed_dict={
                                      h.p_i: batch_v.rd,
                                      h.p_ot: batch_v.q,
                                  })
                testing_iteration_loss.append([GPt_err, Pt_err])
            if testing_iteration_loss:
                average_test_loss = np.mean(np.array(testing_iteration_loss),
                                            axis=0)
            else:
                # Validation split smaller than one batch: the mean of an
                # empty list is 0-d and np.concatenate would raise, so record
                # NaN for both validation losses instead.
                average_test_loss = np.full(2, np.nan)

            average_per_epoch = np.concatenate(
                (average_loss_per_epoch, average_test_loss), axis=0)
            training_epoch_loss.append(average_per_epoch)

            if epoch % 10 == 0:
                process.format_meter(
                    epoch, max_epoch,
                    dict(zip(training_loss_name, average_per_epoch)))

            # Always save after the final epoch (also when max_epoch == 1),
            # plus every non-zero multiple of 1000.
            is_last_epoch = epoch == max_epoch - 1
            if is_last_epoch or (epoch % 1000 == 0 and epoch != 0):
                saver.save(sess,
                           os.path.join(config.ckpt_path, 'model_checkpoint'),
                           global_step=epoch)
                pickle_save(training_epoch_loss, training_loss_name, path)
                copy_file(path, config.history_train_path)