Example #1
0
def do_first_em_cycle(pars_init, obs_scheme, y, u=None, eps=1e-30,
                      use_A_flag=True, use_B_flag=False, diag_R_flag=True):
    """ Run one EM cycle (E-step, M-step, E-step) from given parameters.

    This function serves to quickly get the results of one EM-cycle. It is
    mostly intended to generate results that can quickly be compared with
    other EM implementations or different parameter initialisation methods.

    INPUT:
        pars_init:   collection of initial parameters for the LDS (anything
                     accepted by dict(); a shallow copy is used)
        obs_scheme:  observation scheme for given data, stored in dictionary
                     with keys 'sub_pops', 'obs_time', 'obs_pops'
        y:           data array of observed variables; y.shape[0] is taken
                     as y_dim and y.shape[1] as t_tot
        u:           data array of input variables, or None
        eps:         precision (stopping criterion) for deciding on convergence
                     of latent covariance estimates during the E-step
        use_A_flag:  boolean, forwarded to the M-step (fit the LDS with A)
        use_B_flag:  boolean, forwarded to the M-step (fit the LDS with B)
        diag_R_flag: boolean, forwarded to the M-step (represent R as diagonal)
    OUTPUT:
        stats_init:  E-step statistics under pars_init
        ll_init:     log-likelihood returned by the E-step under pars_init
        pars_first:  parameters returned by the single M-step
        stats_first: E-step statistics under pars_first
        ll_first:    log-likelihood returned by the E-step under pars_first
    """
    y_dim, t_tot = y.shape[0], y.shape[1]
    u_dim = ssm_fit._get_u_dim(u)

    # Work on shallow copies so the caller's dictionaries are not rebound.
    pars_init = dict(pars_init)
    ssm_fit.check_pars(pars=pars_init,
                       y_dim=y_dim,
                       u_dim=u_dim)
    # Use the scheme returned by the checker (as run() does), so any
    # normalisation it performs is actually applied.
    obs_scheme = ssm_fit.check_obs_scheme(obs_scheme=dict(obs_scheme),
                                          y_dim=y_dim,
                                          t_tot=t_tot)

    # E-step under the initial parameters (convergence times are not needed)
    stats_init, ll_init, _, _ = ssm_fit.lds_e_step(pars=pars_init,
                                                   y=y,
                                                   u=u,
                                                   obs_scheme=obs_scheme,
                                                   eps=eps)

    # M-step
    pars_first = ssm_fit.lds_m_step(stats=stats_init,
                                    y=y,
                                    u=u,
                                    obs_scheme=obs_scheme,
                                    use_A_flag=use_A_flag,
                                    use_B_flag=use_B_flag,
                                    diag_R_flag=diag_R_flag)

    # E-step under the once-updated parameters
    stats_first, ll_first, _, _ = ssm_fit.lds_e_step(pars=pars_first,
                                                     y=y,
                                                     u=u,
                                                     obs_scheme=obs_scheme,
                                                     eps=eps)

    return stats_init, ll_init, pars_first, stats_first, ll_first
Example #2
0
def run(x_dim,
        y_dim,
        u_dim,
        t_tot,
        obs_scheme=None,

        max_iter=10,
        epsilon=np.log(1.001),
        eps_cov=0,
        plot_flag=False,
        trace_pars_flag=False,
        trace_stats_flag=False,
        diag_R_flag=True,
        use_A_flag=True,
        use_B_flag=False,

        pars_true=None,
        gen_A_true='diagonal',
        lts_true=None,
        gen_B_true='random',
        gen_Q_true='identity',
        gen_mu0_true='random',
        gen_V0_true='identity',
        gen_C_true='random',
        gen_d_true='scaled',
        gen_R_true='fraction',

        pars_init=None,
        gen_A_init='diagonal',
        lts_init=None,
        gen_B_init='random',
        gen_Q_init='identity',
        gen_mu0_init='random',
        gen_V0_init='identity',
        gen_C_init='random',
        gen_d_init='mean',
        gen_R_init='fractionObserved',

        u=None, input_type='pwconst', const_input_length=1,
        y=None,
        x=None,
        interm_store_flag=False,
        save_file='LDS_data.mat'):
    """ Fit an LDS to given or simulated data and save all results.

    If y is None, ground-truth parameters are generated with gen_pars (any
    entries supplied in pars_true are kept), data x, y, u are simulated from
    them with sim_data, and the ground-truth latent statistics are computed
    with one E-step. If y is given, all ground-truth quantities are stored
    as 0 placeholders. In both cases:
      - initial parameters are generated with gen_pars, which also sees x, y
        and u;
      - one reference EM cycle is run with do_first_em_cycle;
      - the model is fitted with the function returned by setup_fit_lds;
      - everything is written to save_file with savemat, for visualisation
        with Matlab code.

    INPUT:
        x_dim : dimensionality of latent states x
        y_dim : dimensionality of observed states y
        u_dim : dimensionality of input states u
        t_tot : trial length (in number of time points)
        obs_scheme: observation scheme for given data, stored in dictionary
                   with keys 'sub_pops', 'obs_time', 'obs_pops'
        max_iter: maximum number of allowed EM steps
        epsilon:  precision (stopping criterion) for deciding on convergence
                  of overall EM algorithm
        eps_cov:  precision (stopping criterion) for deciding on convergence
                  of latent covariance estimates during the E-step
        plot_flag        : boolean specifying if to visualise fitting progress
        trace_pars_flag  : boolean, specifying if entire parameter updates
                           history or only the current state is kept track of
        trace_stats_flag : boolean, specifying if entire history of inferred
                           latents or only the current state is kept track of
        diag_R_flag      : boolean specifying if R is represented as diagonal
        use_A_flag : boolean specifying whether to fit the LDS with parameter A
                     (if False, A is forced to zero in the initialisation)
        use_B_flag : boolean specifying whether to fit the LDS with parameter B
                     (if False, B is forced to zero in the initialisation)

        pars_true : None, or list/np.ndarray/dict containing no, some or all
                   of the desired ground-truth parameters. Missing parameters
                   are filled in by gen_pars according to the strings below.
                   Only used if y is None.
        gen_A_true ... gen_R_true : strings selecting the generation method
                   of each ground-truth parameter (passed to gen_pars as
                   gen_A ... gen_R)
        lts_true  : ndarray with one entry per latent time scale (i.e. x_dim);
                   defaults to np.linspace(0.9, 0.98, x_dim)
        pars_init : None, or list/np.ndarray/dict containing no, some or all
                   of the desired parameter initialisations. Missing
                   parameters are filled in by gen_pars according to the
                   strings below. The caller's object is not modified.
        gen_A_init ... gen_R_init : strings selecting the initialisation
                   method of each parameter (passed to gen_pars)
        lts_init  : ndarray with one entry per latent time scale (i.e. x_dim);
                   defaults to uniform random draws
        u: data array of input variables (overwritten by simulation if y
           is None)
        input_type, const_input_length : passed to sim_data
        y: data array of observed variables; None means simulate data
        x: data array of latent variables
        interm_store_flag : boolean, specifying whether or not to store the
                            intermediate results after each EM cycle to
                            save_file
        save_file : (path to folder and) name of file for storing results.
    OUTPUT:
        y, x, u   : data arrays used for fitting (u is returned as 0 if None)
        pars_hat  : fitted parameters
        pars_init : initial parameters used for fitting
        pars_true : ground-truth parameters (0 placeholders if y was given)
    Raises TypeError (a subclass of Exception) if a boolean flag is not bool.
    """
    for flag_name, flag_value in (('use_B_flag', use_B_flag),
                                  ('use_A_flag', use_A_flag),
                                  ('diag_R_flag', diag_R_flag),
                                  ('interm_store_flag', interm_store_flag)):
        if not isinstance(flag_value, bool):
            raise TypeError('%s has to be a boolean. However, it is %r'
                            % (flag_name, flag_value))

    obs_scheme = ssm_fit.check_obs_scheme(obs_scheme=obs_scheme,
                                          y_dim=y_dim,
                                          t_tot=t_tot)

    n_tr = 1  # fix to always just one repetition for now

    if y is None:
        if lts_true is None:
            lts_true = np.linspace(0.9, 0.98, x_dim)
        pars_true, _ = gen_pars(x_dim=x_dim,
                                y_dim=y_dim,
                                u_dim=u_dim,
                                pars_in=pars_true,
                                obs_scheme=obs_scheme,
                                gen_A=gen_A_true,
                                lts=lts_true,
                                gen_B=gen_B_true,
                                gen_Q=gen_Q_true,
                                gen_mu0=gen_mu0_true,
                                gen_V0=gen_V0_true,
                                gen_C=gen_C_true,
                                gen_d=gen_d_true,
                                gen_R=gen_R_true)

        # generate data from model
        print('generating data from model with ground-truth parameters')
        x, y, u = sim_data(pars=pars_true,
                           t_tot=t_tot,
                           n_tr=n_tr,
                           obs_scheme=obs_scheme,
                           u=u,
                           input_type=input_type,
                           const_input_length=const_input_length)

        # Pi solves A Pi A' - Pi + Q = 0 (scipy's discrete Lyapunov form)
        Pi = sp.linalg.solve_discrete_lyapunov(pars_true['A'],
                                               pars_true['Q'])
        Pi_t = np.dot(pars_true['A'].transpose(), Pi)

        # Same E-step routine as used everywhere else in this file.
        stats_true, _, _, _ = ssm_fit.lds_e_step(pars=pars_true,
                                                 y=y,
                                                 u=u,
                                                 obs_scheme=obs_scheme,
                                                 eps=eps_cov)
    else:
        # Data provided: no ground truth exists; store 0 placeholders so the
        # saved file keeps the same fields.
        pars_true = dict.fromkeys(('A', 'B', 'Q', 'mu0', 'V0', 'C', 'd', 'R'), 0)
        Pi = 0
        Pi_t = 0
        stats_true = 0

    # get initial parameters
    # Copy first so that forcing A/B to zero never mutates the caller's object.
    if isinstance(pars_init, (dict, list, np.ndarray)):
        pars_init = pars_init.copy()

    if not use_A_flag:  # overwrites any other parameter choices for A! Set A = 0
        if isinstance(pars_init, dict) and ('A' in pars_init):
            pars_init['A'] = np.zeros((x_dim, x_dim))
        elif (isinstance(pars_init, (list, np.ndarray)) and
              pars_init[0] is not None):
            pars_init[0] = np.zeros((x_dim, x_dim))
        elif gen_A_init != 'zero':
            print(('Warning: set flag use_A_flag=False, but did not set gen_A_init '
                   'to zero. Will overwrite gen_A_init to zero now.'))
            gen_A_init = 'zero'
    if not use_B_flag:  # overwrites any other parameter choices for B! Set B = 0
        if isinstance(pars_init, dict) and ('B' in pars_init):
            pars_init['B'] = np.zeros((x_dim, u_dim))
        elif (isinstance(pars_init, (list, np.ndarray)) and
              pars_init[1] is not None):
            pars_init[1] = np.zeros((x_dim, u_dim))
        elif gen_B_init != 'zero':
            print(('Warning: set flag use_B_flag=False, but did not set gen_B_init '
                   'to zero. Will overwrite gen_B_init to zero now.'))
            gen_B_init = 'zero'

    if lts_init is None:
        lts_init = np.random.uniform(size=[x_dim])

    pars_init, _ = gen_pars(x_dim=x_dim,
                            y_dim=y_dim,
                            u_dim=u_dim,
                            pars_in=pars_init,
                            obs_scheme=obs_scheme,
                            gen_A=gen_A_init,
                            lts=lts_init,
                            gen_B=gen_B_init,
                            gen_Q=gen_Q_init,
                            gen_mu0=gen_mu0_init,
                            gen_V0=gen_V0_init,
                            gen_C=gen_C_init,
                            gen_d=gen_d_init,
                            gen_R=gen_R_init,
                            x=x, y=y, u=u)

    # check initial goodness of fit for initial parameters
    stats_init, ll_init, pars_first, stats_first, ll_first = do_first_em_cycle(
        pars_init=pars_init,
        obs_scheme=obs_scheme,
        y=y,
        u=u,
        eps=eps_cov,
        use_A_flag=use_A_flag,
        use_B_flag=use_B_flag,
        diag_R_flag=diag_R_flag)

    save_file_interm = save_file if interm_store_flag else None

    fit_lds = setup_fit_lds(y=y,
                            u=u,  # inputs, not observations
                            max_iter=max_iter,
                            epsilon=epsilon,
                            eps_cov=eps_cov,
                            plot_flag=plot_flag,
                            trace_pars_flag=trace_pars_flag,
                            trace_stats_flag=trace_stats_flag,
                            diag_R_flag=diag_R_flag,
                            use_A_flag=use_A_flag,
                            use_B_flag=use_B_flag)

    # fit the model to data
    print('fitting model to data')
    t = time.time()
    pars_hat, ll = fit_lds(x_dim=x_dim,
                           pars=pars_init,
                           obs_scheme=obs_scheme,
                           save_file=save_file_interm)
    elapsed_time = time.time() - t
    print('elapsed time for fitting is')
    print(elapsed_time)

    # Latent statistics under the fitted parameters; inputs only enter the
    # E-step if the model was fitted with B.
    stats_hat, _, _, _ = ssm_fit.lds_e_step(pars=pars_hat,
                                            y=y,
                                            u=u if use_B_flag else None,
                                            obs_scheme=obs_scheme,
                                            eps=eps_cov)

    Pi_h = sp.linalg.solve_discrete_lyapunov(pars_hat['A'],
                                             pars_hat['Q'])
    Pi_t_h = np.dot(pars_hat['A'].transpose(), Pi_h)

    # save results for visualisation (with Matlab code).
    # savemat cannot store None, so missing u / x are saved as 0
    # (u is also returned as 0, as before).
    if u is None:
        u = 0
    save_file_m = {'x': x if x is not None else 0,
                   'y': y, 'u': u, 'll': ll,
                   'T': t_tot, 'Trial': n_tr, 'elapsedTime': elapsed_time,
                   'inputType': input_type,
                   'constInputLngth': const_input_length,
                   'ifUseB': use_B_flag, 'ifUseA': use_A_flag,
                   'epsilon': epsilon,
                   'ifPlotProgress': plot_flag,
                   'ifTraceParamHist': trace_pars_flag,
                   'ifTraceStatsHist': trace_stats_flag,
                   'ifRDiagonal': diag_R_flag,
                   'covConvEps': eps_cov,
                   'truePars': pars_true,
                   'initPars': pars_init,
                   'firstPars': pars_first,
                   'estPars': pars_hat,
                   'stats_0': stats_init,
                   'stats_1': stats_first,
                   'stats_h': stats_hat,
                   'stats_true': stats_true,
                   'Pi': Pi, 'Pi_h': Pi_h, 'Pi_t': Pi_t, 'Pi_t_h': Pi_t_h,
                   'obsScheme': obs_scheme}

    savemat(save_file, save_file_m)  # does the actual saving

    return y, x, u, pars_hat, pars_init, pars_true