Example #1
import os

import numpy as np
import pylab

# `opt` (config), `vis` (plot.Visualizer), and `pickle_load` are assumed to be
# provided by the surrounding project, as in the other examples on this page.


def compare_indices(indices_name, sub_number=1, description=None):
    """
    Compare one physical indicator across different parameter settings.

    Keyword arguments:
        indices_name, sub_number: physical process name and sub-process index
        description: the plot title

    Saves the resulting figure into opt.history_result_path.
    """
    if indices_name == 'Q':
        if sub_number > 1:
            raise Exception(
                'the number you want does not exist; it should be 0 or 1')
    elif indices_name == 'r':
        if sub_number > 6:
            raise Exception(
                'the number you want does not exist; it should be 0-6')
    else:
        raise Exception(
            'the indices name you want is not valid; try again')

    # obtain file names
    file_names = os.listdir(opt.history_test_path)
    series = []  # renamed from `set` to avoid shadowing the builtin
    names = []
    # baseline
    path_name = os.path.join(opt.history_test_path, file_names[0])
    ensemble, _ = pickle_load(path_name)
    series.append(ensemble['actual_' + indices_name][:, sub_number])
    names.append('origin')

    # collect the result of every run
    for file_name in file_names:
        path = os.path.join(opt.history_test_path, file_name)
        ensemble, _ = pickle_load(path)
        series.append(ensemble['result_' + indices_name][:, sub_number])
        names.append(file_name)

    # plot
    series = np.array(series).T
    names = [item.replace('.pkl', '') for item in names]
    # fall back to the indices name when no title is given
    title = 'the comparison of ' + (description or indices_name)
    vis.mtsplot(series, names, title, opt.history_result_path)
    pylab.close()
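
A hypothetical call, assuming the directories in `opt` already contain saved test runs (the argument values here are illustrative only):

# e.g. compare all runs on sub-process 4 of indicator 'r'
compare_indices('r', sub_number=3, description='R4')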
Example #2
import os
import sys

import seaborn as sns

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "../")))
from data_providers.data_provider import SuvsDataProvider
from plot import Visualizer
from sequential.utils import pickle_load

# `config` is assumed to be the project's configuration object
# (cf. BaseConfig in Example #4).
vis = Visualizer()
dp = SuvsDataProvider(num_validation=config.num_vad, shuffle='every_epoch')
config.is_train = False
config.batch_size = dp.valid.num_sample
test_times = 60

# Gaussian mixture
description1 = 'logs/hybrid_GAN_lin_res-Dz=0.01-R=1-Lat=1.5-Tv=0.01-d-gm-bc-gs-2018-02-04-metric.pkl'
[q_errors1, r_adjs1, z_adjs1], name1 = pickle_load(description1)

# Swiss Roll
description2 = 'logs/hybrid_GAN_lin_res-Dz=0.01-R=1-Lat=1.5-Tv=0.01-d-sr-bc-gs-2018-02-04-metric.pkl'
[q_errors2, r_adjs2, z_adjs2], name2 = pickle_load(description2)

# Uniform Disk
description3 = 'logs/hybrid_GAN_lin_res-Dz=0.01-R=1-Lat=1.5-Tv=0.01-d-ud-bc-gs-2018-02-04-metric.pkl'
[q_errors3, r_adjs3, z_adjs3], name3 = pickle_load(description3)

# Uniform Square
description4 = 'logs/hybrid_GAN_lin_res-Dz=0.01-R=1-Lat=1.5-Tv=0.01-d-us-bc-gs-2018-02-04-metric.pkl'
[q_errors4, r_adjs4, z_adjs4], name4 = pickle_load(description4)

# Gaussian
description5 = 'logs/hybrid_GAN_lin_res-Dz=0.01-R=1-Lat=1.5-Tv=0.01-d-gs-bc-gs-2018-02-04-metric.pkl'
[q_errors5, r_adjs5, z_adjs5], name5 = pickle_load(description5)
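
These call sites suggest that sequential.utils stores a (values, names) pair in each pickle; a minimal sketch of compatible helpers under that assumption (the project's real pickle_load also accepts use_pd, as Example #3 shows, which this sketch omits):

import pickle


def pickle_save(values, names, path):
    # persist the value list together with its labels, as the call sites imply
    with open(path, 'wb') as f:
        pickle.dump((values, names), f)


def pickle_load(path):
    # return the stored (values, names) pair
    with open(path, 'rb') as f:
        return pickle.load(f)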
Example #3
        vis.cplot(actual_Q[:, 1], result_Q[:, 1], ['Q2', 'origin', 'modify'],
                  config.t_p)
        for num in range(6):
            vis.cplot(actual_r[:, num], result_r[:, num],
                      ['R{}'.format(num + 1), 'origin', 'modify'], config.t_p)


if __name__ == "__main__":
    # test result
    main()

    # process the recorded loss history
    path = os.path.join(config.logs_path, config.description + '-train.pkl')

    logger.info("{}".format(path))
    hist_value, hist_head = pickle_load(path, use_pd=True)
    for loss_name in [
            'R_err',
            'GE_err',
            'EG_err',
            'GPt_err',
            'GP_err',
            'Pt_err',
            'P_err',
    ]:
        vis.tsplot(hist_value[loss_name], loss_name, config.loss_path)

    vis.dyplot(hist_value['Dz_err'], hist_value['Ez_err'], 'z',
               config.loss_path)
    vis.dyplot(hist_value['Di_err'], hist_value['Gi_err'], 'img',
               config.loss_path)
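
Visualizer.tsplot and dyplot are project code; a minimal matplotlib stand-in with the same (values, name, path) call shape, purely as an assumption about what tsplot does:

import os

import matplotlib.pyplot as plt


def tsplot_sketch(values, name, path):
    # plot one loss curve over epochs and save it under the given directory
    fig, ax = plt.subplots()
    ax.plot(values)
    ax.set_xlabel('epoch')
    ax.set_ylabel(name)
    fig.savefig(os.path.join(path, name + '.png'))
    plt.close(fig)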
Example #4
import os  # used by the commented-out opt.logs_path variant below

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np

from base_options import BaseConfig
from plot import Visualizer
from sequential.utils import pickle_load

vis = Visualizer()
opt = BaseConfig()

# load data
# path = os.path.join(opt.logs_path, opt.description + '-test.pkl')
path = 'logs/hybrid_GAN_cc_res-Dz=0.01-R=1-Lat=1.5-Tv=0.01-d-gs-bc-gs-2018-02-04-test.pkl'
# path = 'logs/hybrid_GAN_cc_vgg-Dz=0.01-R=10-Lat=0.5-Tv=0.01-x-gs-bc-gs-2018-02-04-test.pkl'
# path = 'logs/hybrid_GAN_cc_res-Dz=0.01-R=10-Lat=0.5-Tv=0.01-x-gs-bc-gs-2018-02-04-test.pkl'
d, _ = pickle_load(path)
t = np.arange(0, len(d["actual_Q"][:, 0])) + 1
"""
Plot 3x3 Sub-figures
"""
# GLOBAL SETTING
# 'classic-znz' is a custom style sheet; matplotlib looks for it in its
# stylelib directory, e.g.
# C:\Program Files\Anaconda3\Lib\site-packages\matplotlib\mpl-data\stylelib
plt.style.use('classic-znz')
# mpl.rcParams['lines.linewidth'] = 1
mpl.rc('lines', linewidth=1, markersize=4)
mpl.rc('font', family='Times New Roman')
mpl.rcParams['legend.labelspacing'] = 0.05
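# note: the mpl.rc() calls above are shorthand for direct rcParams
# assignments, e.g. mpl.rc('lines', linewidth=1, markersize=4) is equivalent
# to setting mpl.rcParams['lines.linewidth'] = 1 and
# mpl.rcParams['lines.markersize'] = 4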


def animate(axs, tt):
    ylim = axs.get_ylim()
Example #5
        r_adjs = np.array(r_adjs).reshape(-1, config.ndim_x)
        z_adjs = np.array(z_adjs).reshape(-1, config.ndim_z)

        # four labels for the four saved objects; the fourth label,
        # "true_latent_variables", is an assumed name for z_true
        pickle_save([q_errors, r_adjs, z_adjs, z_true],
                    ["productions", "adjustment", "latent_variables",
                     "true_latent_variables"],
                    '{}/{}-metric_plus.pkl'.format(config.logs_path,
                                                   description))


if __name__ == "__main__":
    # main
    main()

    # load the saved metrics
    file_path = '{}/{}-metric_plus.pkl'.format(config.logs_path,
                                               config.description)
    [q_errors, r_adjs, z_adjs, z_true], name = pickle_load(file_path)

    # latent space
    label = np.tile(np.array(dp.test.e), [test_times, 1])
    scatter_labeled_z(z_adjs,
                      label,
                      dir=config.latent_path,
                      filename='latent-Q2')
    vis.kdeplot(z_adjs[:, 0],
                z_adjs[:, 1],
                config.latent_path,
                name='Kde_Latent_Space')
    vis.jointplot(z_adjs[:, 0],
                  z_adjs[:, 1],
                  config.latent_path,
Example #6
import os
import sys

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "../")))
from data_providers.data_provider import SuvsDataProvider
from plot import Visualizer
from sequential.utils import pickle_load

# `config` is assumed to be the project's configuration object.
vis = Visualizer()
dp = SuvsDataProvider(num_validation=config.num_vad, shuffle='every_epoch')
config.is_train = False
config.batch_size = dp.valid.num_sample
test_times = 60

# Gaussian mixture
description1 = 'logs/hybrid_GAN_lin_res-Dz=0.01-R=1-Lat=1.5-Tv=0.01-d-gm-bc-gs-2018-02-04-metric_plus3.pkl'
[q_errors1, r_adjs1, z_adjs1, z_true1,
 z_train1], name1 = pickle_load(description1)

# Swiss Roll
description2 = 'logs/hybrid_GAN_lin_res-Dz=0.01-R=1-Lat=1.5-Tv=0.01-d-sr-bc-gs-2018-02-04-metric_plus3.pkl'
[q_errors2, r_adjs2, z_adjs2, z_true2,
 z_train2], name2 = pickle_load(description2)

# Uniform Disk
description3 = 'logs/hybrid_GAN_lin_res-Dz=0.01-R=1-Lat=1.5-Tv=0.01-d-ud-bc-gs-2018-02-04-metric_plus3.pkl'
[q_errors3, r_adjs3, z_adjs3, z_true3,
 z_train3], name3 = pickle_load(description3)

# Uniform Square
description4 = 'logs/hybrid_GAN_lin_res-Dz=0.01-R=1-Lat=1.5-Tv=0.01-d-us-bc-gs-2018-02-04-metric_plus3.pkl'
[q_errors4, r_adjs4, z_adjs4, z_true4,
 z_train4], name4 = pickle_load(description4)
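
The four loads above share one path template; a compact equivalent, assuming the dataset tag ('gm', 'sr', 'ud', 'us') is the only varying part:

runs = {}
for tag in ['gm', 'sr', 'ud', 'us']:
    path = ('logs/hybrid_GAN_lin_res-Dz=0.01-R=1-Lat=1.5-Tv=0.01-d-'
            '{}-bc-gs-2018-02-04-metric_plus3.pkl'.format(tag))
    # each entry holds ([q_errors, r_adjs, z_adjs, z_true, z_train], name)
    runs[tag] = pickle_load(path)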
Example #7
import os

import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (tf.Session, tf.train.Saver)

# BaseConfig, SuvsDataProvider, build_graph, sampler, Process,
# create_lr_schedule, pickle_load, pickle_save, and copy_file are assumed to
# come from the surrounding project.


def main(run_load_from_file=False):
    config = BaseConfig()
    config.folder_init()
    dp = SuvsDataProvider(num_validation=config.num_vad, shuffle='every_epoch')
    max_epoch = 500
    batch_size_l = config.batch_size
    path = os.path.join(config.logs_path, config.description + '-train.pkl')

    # training
    with tf.device(config.device):
        h = build_graph()

    sess_config = tf.ConfigProto(allow_soft_placement=True,
                                 log_device_placement=True)
    sess_config.gpu_options.allow_growth = True
    sess_config.gpu_options.per_process_gpu_memory_fraction = 0.9
    saver = tf.train.Saver(max_to_keep=2)

    with tf.Session(config=sess_config) as sess:
        '''
        Load from checkpoint or start a new session
        '''
        if run_load_from_file:
            saver.restore(sess, tf.train.latest_checkpoint(config.ckpt_path))
            training_epoch_loss, _ = pickle_load(path)
        else:
            sess.run(tf.global_variables_initializer())
            training_epoch_loss = []

        # Recording loss per epoch
        process = Process()
        lr_schedule = create_lr_schedule(lr_base=2e-4,
                                         decay_rate=0.1,
                                         decay_epochs=500,
                                         truncated_epoch=2000,
                                         mode=config.lr_schedule)
        for epoch in range(max_epoch):
            process.start_epoch()
            '''
            Learning rate generator
            '''
            learning_rate = lr_schedule(epoch)
            # Recording loss per iteration
            training_iteration_loss = []
            sum_loss_rest = 0
            sum_loss_dcm = 0
            sum_loss_gen = 0

            process_iteration = Process()
            data_size = dp.train_l.num_sample
            num_batch = data_size // config.batch_size
            for i in range(num_batch + 1):
                process_iteration.start_epoch()
                # Inputs
                # sample from data distribution
                batch_l = dp.train_l.next_batch(batch_size_l)
                z_prior = sampler.sampler_switch(config)
                # adversarial phase for discriminator_z
                _, Dz_err = sess.run([h.opt_dz, h.loss_dz],
                                     feed_dict={
                                         h.x: batch_l.x,
                                         h.z_p: z_prior,
                                         h.lr: learning_rate,
                                     })
                z_latent = sampler.sampler_switch(config)
                _, Di_err = sess.run(
                    [h.opt_dimg, h.loss_dimg],
                    feed_dict={
                        h.x_c: batch_l.c,
                        h.z_l: z_latent,
                        h.z_e: batch_l.e,
                        h.x_s: batch_l.x,
                        h.lr: learning_rate,
                    })
                z_latent = sampler.sampler_switch(config)
                # reconstruction_phase
                _, R_err, Ez_err, Gi_err, GE_err, EG_err = sess.run(
                    fetches=[
                        h.opt_r, h.loss_r, h.loss_e, h.loss_d, h.loss_l,
                        h.loss_eg
                    ],
                    feed_dict={
                        h.x: batch_l.x,
                        h.z_p: z_prior,
                        h.x_c: batch_l.c,
                        h.z_l: z_latent,
                        h.z_e: batch_l.e,
                        h.x_s: batch_l.x,
                        h.lr: learning_rate,
                    })
                # process phase
                _, P_err = sess.run([h.opt_p, h.loss_p],
                                    feed_dict={
                                        h.p_i: batch_l.rd,
                                        h.p_ot: batch_l.q,
                                        h.lr: learning_rate
                                    })
                # push process to normal
                z_latent = sampler.sampler_switch(config)
                _, GP_err = sess.run(
                    [h.opt_q, h.loss_q],
                    feed_dict={
                        h.x_c: batch_l.c,
                        h.z_l: z_latent,
                        h.z_e: batch_l.e,
                        h.p_in: batch_l.rd,
                        h.p_ot: batch_l.q,
                        h.lr: learning_rate,
                    })
                # recording loss function
                training_iteration_loss.append([
                    R_err, Ez_err, Gi_err, GE_err, EG_err, Dz_err, Di_err,
                    P_err, GP_err
                ])
                sum_loss_rest += R_err
                sum_loss_dcm += Dz_err + Di_err
                sum_loss_gen += Gi_err + Ez_err

                # per-iteration display is switched off via `and False`;
                # drop the `and False` to re-enable progress printing
                if i % 10 == 0 and False:
                    process_iteration.display_current_results(
                        i, num_batch, {
                            'reconstruction': sum_loss_rest / (i + 1),
                            'discriminator': sum_loss_dcm / (i + 1),
                            'generator': sum_loss_gen / (i + 1),
                        })

            # at the end of each epoch, summarize the loss
            average_loss_per_epoch = np.mean(np.array(training_iteration_loss),
                                             axis=0)

            # validation phase
            num_test = dp.valid.num_sample // config.batch_size
            testing_iteration_loss = []
            for batch in range(num_test):
                z_latent = sampler.sampler_switch(config)
                batch_v = dp.valid.next_batch(config.batch_size)
                GPt_err = sess.run(h.loss_q,
                                   feed_dict={
                                       h.x_c: batch_v.c,
                                       h.z_l: z_latent,
                                       h.z_e: batch_v.e,
                                       h.p_in: batch_v.rd,
                                       h.p_ot: batch_v.q,
                                   })
                Pt_err = sess.run(h.loss_p,
                                  feed_dict={
                                      h.p_i: batch_v.rd,
                                      h.p_ot: batch_v.q,
                                  })
                testing_iteration_loss.append([GPt_err, Pt_err])
            average_test_loss = np.mean(np.array(testing_iteration_loss),
                                        axis=0)

            average_per_epoch = np.concatenate(
                (average_loss_per_epoch, average_test_loss), axis=0)
            training_epoch_loss.append(average_per_epoch)

            # training loss name
            training_loss_name = [
                'R_err',
                'Ez_err',
                'Gi_err',
                'GE_err',
                'EG_err',
                'Dz_err',
                'Di_err',
                'P_err',
                'GP_err',
                'GPt_err',
                'Pt_err',
            ]

            if epoch % 10 == 0:
                process.format_meter(
                    epoch, max_epoch, {
                        'R_err': average_per_epoch[0],
                        'Ez_err': average_per_epoch[1],
                        'Gi_err': average_per_epoch[2],
                        'GE_err': average_per_epoch[3],
                        'EG_err': average_per_epoch[4],
                        'Dz_err': average_per_epoch[5],
                        'Di_err': average_per_epoch[6],
                        'P_err': average_per_epoch[7],
                        'GP_err': average_per_epoch[8],
                        'GPt_err': average_per_epoch[9],
                        'Pt_err': average_per_epoch[10],
                    })

            if (epoch % 1000 == 0 or epoch == max_epoch - 1) and epoch != 0:
                saver.save(sess,
                           os.path.join(config.ckpt_path, 'model_checkpoint'),
                           global_step=epoch)
                pickle_save(training_epoch_loss, training_loss_name, path)
                copy_file(path, config.history_train_path)
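
create_lr_schedule comes from the surrounding project; a minimal sketch of one plausible implementation, assuming a truncated exponential decay (the real helper, and in particular its mode argument, may differ):

def create_lr_schedule(lr_base, decay_rate, decay_epochs,
                       truncated_epoch, mode=None):
    # hypothetical sketch: exponential decay in epoch, frozen past
    # truncated_epoch
    def schedule(epoch):
        epoch = min(epoch, truncated_epoch)
        return lr_base * decay_rate ** (epoch / decay_epochs)
    return schedule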