def plot_2Dquiver_paths(self,
                            session,
                            Xdata,
                            Xvar_name='X:0',
                            scope="",
                            rlt_dir=TEST_DIR + addDateTime() + '/',
                            rslt_file='quiver_plot',
                            with_inflow=False,
                            savefig=False,
                            draw=False,
                            pause=False,
                            skipped=1,
                            feed_range=True,
                            range_xs=20.0,
                            Id=0):
        """
        Plot the 2D quiver (flow) field superimposed on the latent paths.

        Useful to check that the trajectories roughly follow the dynamics.
        The paths are drawn first by ``self.plot2D_sampleX`` on a new figure;
        the quiver field is then drawn on that same figure by
        ``self.quiver2D_flow`` (called with ``newfig=False``).

        Args:
            session: forwarded to ``self.quiver2D_flow``.
            Xdata: the paths to draw, forwarded to ``self.plot2D_sampleX``.
            Xvar_name, scope, with_inflow, Id: forwarded to
                ``self.quiver2D_flow``.
            rlt_dir: directory the figure is saved into when ``savefig`` is
                True; created if it does not exist. NOTE: this default is
                evaluated once, when the ``def`` statement runs, so every
                call relying on it shares the same timestamped directory.
            rslt_file: file name appended to ``rlt_dir`` for the saved figure.
            savefig: if True, save the combined figure with ``plt.savefig``.
            draw, pause: forwarded to both plotting helpers.
            skipped: forwarded to ``self.plot2D_sampleX``.
            feed_range: if True, the quiver ranges are the x/y axis limits of
                the paths plot; otherwise both ranges are
                ``(-range_xs, range_xs)``.
            range_xs: half-width of the symmetric range used when
                ``feed_range`` is False.
        """
        if savefig:
            if not os.path.exists(rlt_dir):
                os.makedirs(rlt_dir)
            rslt_file = rlt_dir + rslt_file

        import matplotlib.pyplot as plt
        try:
            axes = self.plot2D_sampleX(Xdata,
                                       pause=pause,
                                       draw=draw,
                                       newfig=True,
                                       skipped=skipped)
            if feed_range:
                x1range, x2range = axes.get_xlim(), axes.get_ylim()
            else:
                x1range = x2range = (-range_xs, range_xs)
            # The scale handed to quiver2D_flow grows with the larger of the
            # two axis spans (sum of the absolute limits on each axis).
            s = int(5 * max(
                abs(x1range[0]) + abs(x1range[1]),
                abs(x2range[0]) + abs(x2range[1])) / 3)

            self.quiver2D_flow(session,
                               Xvar_name=Xvar_name,
                               scope=scope,
                               pause=pause,
                               x1range=x1range,
                               x2range=x2range,
                               scale=s,
                               newfig=False,
                               with_inflow=with_inflow,
                               draw=draw,
                               Id=Id)
            if savefig:
                plt.savefig(rslt_file)
        finally:
            # Release the current figure even if plotting or saving failed.
            plt.close()
Exemple #2
0
def create_RLT_DIR(Experiment_params):
    """Create (if needed) and return a results directory for an experiment.

    The path has the form::

        './rslts/' + <rslt_dir_name> + '/' + <params> + <date> + '/'

    where ``<params>`` is ``<name>_<value>_`` repeated for every entry of
    ``Experiment_params`` except ``'rslt_dir_name'`` (in the dict's iteration
    order), and ``<date>`` is whatever ``addDateTime()`` returns.

    Args:
        Experiment_params: dict mapping parameter names to values. It must
            contain the key ``'rslt_dir_name'``.

    Returns:
        The directory path (a string ending in '/').

    Raises:
        KeyError: if ``'rslt_dir_name'`` is missing from Experiment_params.
        OSError: if the directory cannot be created.
    """
    cur_date = addDateTime()

    local_rlt_root = './rslts/' + Experiment_params['rslt_dir_name'] + '/'

    # One "<name>_<value>_" chunk per parameter; join instead of repeated +=.
    params_str = ''.join(param_name + '_' + str(param) + '_'
                         for param_name, param in Experiment_params.items()
                         if param_name != 'rslt_dir_name')

    RLT_DIR = local_rlt_root + params_str + cur_date + '/'

    # Checking exists() and then calling makedirs() is racy if another
    # process creates the directory in between; instead, try to create it
    # and only re-raise if it still isn't a directory. (This avoids
    # exist_ok, which Python 2's os.makedirs does not accept.)
    try:
        os.makedirs(RLT_DIR)
    except OSError:
        if not os.path.isdir(RLT_DIR):
            raise

    return RLT_DIR
    def plot_2Dquiver_paths(self,
                            session,
                            Xdata,
                            Xvar_name,
                            rlt_dir=TEST_DIR + addDateTime() + '/',
                            rslt_file='quiver_plot',
                            with_inflow=False,
                            savefig=False,
                            draw=False,
                            pause=False):
        """
        Plot the 2D quiver (flow) field superimposed on the latent paths.

        NOTE(review): as laid out in this file, this ``def`` follows
        ``return RLT_DIR`` inside ``create_RLT_DIR`` at the same indentation,
        so it is unreachable there; it looks like a method pasted in from
        another example.

        The paths are drawn by ``self.plot2D_sampleX`` on a new figure, then
        the quiver field is drawn on that same figure by
        ``self.quiver2D_flow`` (``newfig=False``), using the axis limits of
        the paths plot as the x/y ranges.

        Args:
            session, Xvar_name, with_inflow: forwarded to
                ``self.quiver2D_flow``.
            Xdata: the paths to draw, forwarded to ``self.plot2D_sampleX``.
            rlt_dir: directory the figure is saved into when ``savefig`` is
                True; created if it does not exist. The default is evaluated
                once, when the ``def`` statement runs.
            rslt_file: file name appended to ``rlt_dir`` for the saved figure.
            savefig: if True, save the combined figure with ``plt.savefig``.
            draw, pause: forwarded to both plotting helpers.
        """
        if savefig:
            if not os.path.exists(rlt_dir):
                os.makedirs(rlt_dir)
            rslt_file = rlt_dir + rslt_file

        import matplotlib.pyplot as plt
        try:
            axes = self.plot2D_sampleX(Xdata, pause=pause, draw=draw,
                                       newfig=True)
            x1range, x2range = axes.get_xlim(), axes.get_ylim()
            # The scale handed to quiver2D_flow grows with the larger of the
            # two axis spans (sum of the absolute limits on each axis).
            s = int(5 * max(
                abs(x1range[0]) + abs(x1range[1]),
                abs(x2range[0]) + abs(x2range[1])) / 3)

            self.quiver2D_flow(session,
                               Xvar_name,
                               pause=pause,
                               x1range=x1range,
                               x2range=x2range,
                               scale=s,
                               newfig=False,
                               with_inflow=with_inflow,
                               draw=draw)
            if savefig:
                plt.savefig(rslt_file)
        finally:
            # Release the current figure even if plotting or saving failed.
            plt.close()
Exemple #4
0
import sys
import os
from optparse import OptionParser
import cPickle as pickle

import numpy as np
import theano
import lasagne
from lasagne.nonlinearities import softmax, linear, softplus

# Local imports
from LatEvModels import LocallyLinearEvolution
from ObservationModels import GaussianObsTSGM, PoissonObsTSGM

from datetools import addDateTime
# Evaluated once, at import time, via the project's datetools helper.
# NOTE(review): presumably a date/time string used to tag output — confirm.
cur_date = addDateTime()

# ==============================================================================

# Root folder holding the datasets. Defaults to the original author's local
# path; set the VAE_NLDS_RLT_ROOT environment variable to use another one.
LOCAL_RLT_ROOT = os.environ.get(
    "VAE_NLDS_RLT_ROOT",
    "/Users/danielhernandez/Work/time_series/vae_nlds_rec_algo_v2/data/")
RLT_DIR = "poisson_data_008/"  # an earlier run used "gaussian_data_002/"
# os.path.join gives the same string as plain concatenation for the default
# root, and still works if an overriding root lacks a trailing separator.
LOAD_DIR = os.path.join(LOCAL_RLT_ROOT, RLT_DIR)
LOAD_FILE = "datadict"

# Experiment settings. Meanings below are inferred from the names only —
# NOTE(review): confirm against the model/training code.
YDIM = 10  # presumably observation dimension
XDIM = 2  # presumably latent-state dimension
TBINS = 30  # presumably time bins per trial
NSAMPS = 100  # presumably number of trials/samples
OBSERVATIONS = 'Poisson'
NNODES = 60  # presumably hidden units per network layer
ALPHA = 0.0