コード例 #1
0
ファイル: plot_residuals.py プロジェクト: remtcs/theano_pyglm
def make_weight_residual_plot(N, dt, s_infs_mcmc, s_trues, s_infs_map=None, resdir='.'):
    """Scatter-plot averaged inferred weights against the true weights.

    For every dataset ``d`` the samples in ``s_infs_mcmc[d]`` are averaged
    with ``average_list_of_dicts`` and drawn in red (legend label "N-GLM").
    When ``s_infs_map`` is given, its entries are drawn in blue (legend label
    "L1-GLM"). An identity line is overlaid, and the figure is written to
    ``<resdir>/W_resid.pdf``.

    :param N: Unused; kept for call compatibility.
    :param dt: Forwarded unchanged to ``scatter_plot_weight_residuals``
        (presumably the time-bin size -- confirm against that helper).
    :param s_infs_mcmc: One entry per dataset. Each entry is passed to
        ``average_list_of_dicts`` (presumably a list of sample-state dicts).
    :param s_trues: True states, indexed in parallel with ``s_infs_mcmc``
        and ``s_infs_map``.
    :param s_infs_map: Optional sequence of point-estimate states, one per
        dataset.
    :param resdir: Directory in which ``W_resid.pdf`` is written.
    """
    fig = plt.figure(figsize=(2.5, 2.5))
    try:
        ax = fig.gca()

        # Label only the first dataset of each method, so the legend shows a
        # single entry per method.
        for d, s_inf_mcmc in enumerate(s_infs_mcmc):
            s_avg = average_list_of_dicts(s_inf_mcmc)
            label = None if d > 0 else "N-GLM"
            scatter_plot_weight_residuals(dt, s_avg, s_trues[d], ax=ax, sz=20,
                                          color='r', label=label)

        if s_infs_map is not None:
            for d, s_inf_map in enumerate(s_infs_map):
                label = None if d > 0 else "L1-GLM"
                scatter_plot_weight_residuals(dt, s_inf_map, s_trues[d], ax=ax,
                                              sz=15, color='b', label=label)

        # Plot the identity line across the current x-range.
        xlim = ax.get_xlim()
        ax.plot(xlim, xlim, '-k', linewidth=0.5)

        # Only draw a legend when something was labelled (i.e. at least one
        # dataset was plotted); an empty legend just triggers a warning.
        if ax.get_legend_handles_labels()[1]:
            ax.legend(loc='upper left', prop={'size': 8}, scatterpoints=1)

        # Make room for the axis labels on *this* figure, not on whatever
        # pyplot currently considers the active figure.
        fig.subplots_adjust(left=0.25, bottom=0.25)

        fig.savefig(os.path.join(resdir, 'W_resid.pdf'))
    finally:
        # Release the figure even if plotting or saving raised.
        plt.close(fig)
コード例 #2
0
def make_weight_residual_plot(N,
                              dt,
                              s_infs_mcmc,
                              s_trues,
                              s_infs_map=None,
                              resdir='.'):
    """Scatter-plot averaged inferred weights against the true weights.

    For every dataset ``d`` the samples in ``s_infs_mcmc[d]`` are averaged
    with ``average_list_of_dicts`` and drawn in red (legend label "N-GLM").
    When ``s_infs_map`` is given, its entries are drawn in blue (legend label
    "L1-GLM"). An identity line is overlaid, and the figure is written to
    ``<resdir>/W_resid.pdf``.

    :param N: Unused; kept for call compatibility.
    :param dt: Forwarded unchanged to ``scatter_plot_weight_residuals``
        (presumably the time-bin size -- confirm against that helper).
    :param s_infs_mcmc: One entry per dataset. Each entry is passed to
        ``average_list_of_dicts`` (presumably a list of sample-state dicts).
    :param s_trues: True states, indexed in parallel with ``s_infs_mcmc``
        and ``s_infs_map``.
    :param s_infs_map: Optional sequence of point-estimate states, one per
        dataset.
    :param resdir: Directory in which ``W_resid.pdf`` is written.
    """
    fig = plt.figure(figsize=(2.5, 2.5))
    try:
        ax = fig.gca()

        # Label only the first dataset of each method, so the legend shows a
        # single entry per method.
        for d, s_inf_mcmc in enumerate(s_infs_mcmc):
            s_avg = average_list_of_dicts(s_inf_mcmc)
            label = None if d > 0 else "N-GLM"
            scatter_plot_weight_residuals(dt,
                                          s_avg,
                                          s_trues[d],
                                          ax=ax,
                                          sz=20,
                                          color='r',
                                          label=label)

        if s_infs_map is not None:
            for d, s_inf_map in enumerate(s_infs_map):
                label = None if d > 0 else "L1-GLM"
                scatter_plot_weight_residuals(dt,
                                              s_inf_map,
                                              s_trues[d],
                                              ax=ax,
                                              sz=15,
                                              color='b',
                                              label=label)

        # Plot the identity line across the current x-range.
        xlim = ax.get_xlim()
        ax.plot(xlim, xlim, '-k', linewidth=0.5)

        # Only draw a legend when something was labelled (i.e. at least one
        # dataset was plotted); an empty legend just triggers a warning.
        if ax.get_legend_handles_labels()[1]:
            ax.legend(loc='upper left', prop={'size': 8}, scatterpoints=1)

        # Make room for the axis labels on *this* figure, not on whatever
        # pyplot currently considers the active figure.
        fig.subplots_adjust(left=0.25, bottom=0.25)

        fig.savefig(os.path.join(resdir, 'W_resid.pdf'))
    finally:
        # Release the figure even if plotting or saving raised.
        plt.close(fig)
コード例 #3
0
class _FigureFile(object):
    """Context manager that opens a fresh figure and writes it to a PDF file.

    On entry a new figure is created with ``plt.figure()``, which also makes
    it pyplot's current figure for helpers that draw via ``plt``. The figure
    is returned. On a clean exit it is saved to
    ``os.path.join(resdir, filename)``. On every exit, including when the
    ``with`` body raises, the figure is closed so failed plots do not leave
    open figures behind. Exceptions are never suppressed.
    """

    def __init__(self, resdir, filename):
        self.path = os.path.join(resdir, filename)
        self.fig = None

    def __enter__(self):
        self.fig = plt.figure()
        return self.fig

    def __exit__(self, exc_type, exc_value, tb):
        try:
            if exc_type is None:
                self.fig.savefig(self.path)
        finally:
            plt.close(self.fig)
        return False


def plot_results(population,
                 x_inf,
                 popn_true=None,
                 x_true=None,
                 resdir=None,
                 do_plot_connectivity=True,
                 do_plot_stim_resp=True,
                 do_plot_imp_responses=True,
                 do_plot_firing_rates=True,
                 do_plot_ks=True,
                 do_plot_logpr=True):
    """Plot the inferred model quantities, optionally against ground truth.

    Each parameter setting in ``x_inf`` is evaluated with
    ``population.eval_state``. The resulting states are summarised with
    ``average_list_of_dicts`` / ``std_list_of_dicts`` and plotted. When both
    ``popn_true`` and ``x_true`` are given, the true state
    (``popn_true.eval_state(x_true)``) is overlaid, mostly in black. Every
    plot is written as a PDF into ``resdir``.

    :param population: Model object providing ``eval_state``, ``N`` and
        ``glm`` (whose ``S`` and ``dt`` expose ``get_value()``).
    :param x_inf: One inferred parameter setting, or a list of them.
    :param popn_true: Optional ground-truth population.
    :param x_true: Optional ground-truth parameters for ``popn_true``.
    :param resdir: Output directory; falls back to ``'.'`` when falsy.
    :param do_plot_connectivity: Write ``conn.pdf``.
    :param do_plot_stim_resp: Write ``stim_resp_<n>.pdf`` for each
        ``n in range(population.N)``.
    :param do_plot_imp_responses: Write ``imp_resp.pdf`` and ``imp_basis.pdf``.
    :param do_plot_firing_rates: Write ``firing_rate_<n>.pdf`` for each ``n``.
    :param do_plot_ks: Write ``ks_<n>.pdf`` for each ``n``.
    :param do_plot_logpr: Write ``log_prob.pdf`` and ``log_lkhd.pdf``, plus
        ``log_prior.pdf`` / ``pred_ll.pdf`` when that data is present.

    ``locations.pdf`` is written whenever a ``location_provider`` latent is
    present. ``tuning_curves.pdf`` and ``latent_types.pdf`` are written
    whenever a ``sharedtuningcurve_provider`` latent is present. Neither
    depends on a flag.
    """
    if not resdir:
        resdir = '.'

    true_given = x_true is not None and popn_true is not None

    # Make sure we have a list of x's
    if not isinstance(x_inf, list):
        x_inf = [x_inf]

    # Evaluate the state for each of the parameter settings
    s_inf = [population.eval_state(x) for x in x_inf]

    s_true = popn_true.eval_state(x_true) if true_given else None

    # Average the inferred states
    s_avg = average_list_of_dicts(s_inf)
    s_std = std_list_of_dicts(s_inf, s_avg)
    N = population.N

    # TODO Fix the averaging of W and A
    # E[W] * E[A] != E[W*A]
    # Plot the inferred connectivity matrix
    if do_plot_connectivity:
        print("Plotting connectivity matrix")
        with _FigureFile(resdir, 'conn.pdf'):
            plot_connectivity_matrix(s_inf, s_true)

    latent = s_inf[0]['latent']

    if 'location_provider' in latent:
        with _FigureFile(resdir, 'locations.pdf'):
            plot_locations(s_inf, color='r')
            if true_given:
                plot_locations([s_true], color='k')

    # Plot shared tuning curves
    if 'sharedtuningcurve_provider' in latent:
        print("Plotting shared tuning curves")
        # The tuning-curve plot takes the whole averaged state, and the output
        # file name does not depend on a unit index, so it is drawn exactly
        # once. Looping over units would only rewrite the same file.
        with _FigureFile(resdir, 'tuning_curves.pdf'):
            if true_given:
                plot_spatiotemporal_tuning_curves(
                    s_avg,
                    s_true=s_true,
                    s_std=s_std,
                    color='r')
            else:
                plot_spatiotemporal_tuning_curves(
                    s_avg,
                    s_std=s_std,
                    color='k')

        print("Plotting types")
        with _FigureFile(resdir, 'latent_types.pdf'):
            plot_latent_types(s_inf, s_true)

    # Plot stimulus response functions
    if do_plot_stim_resp:
        print("Plotting stimulus response functions")
        for n in range(N):
            with _FigureFile(resdir, 'stim_resp_%d.pdf' % n):
                plot_stim_response(s_avg['glms'][n],
                                   s_glm_std=s_std['glms'][n],
                                   color='r')
                if true_given:
                    plot_stim_response(s_true['glms'][n],
                                       color='k')

    # Plot the impulse responses and the impulse response basis
    if do_plot_imp_responses:
        print("Plotting impulse response functions")
        with _FigureFile(resdir, 'imp_resp.pdf') as f:
            plot_imp_responses(s_avg,
                               s_std,
                               fig=f,
                               color='r',
                               use_bgcolor=True)
            if true_given:
                plot_imp_responses(s_true,
                                   fig=f,
                                   color='k',
                                   linestyle='--',
                                   use_bgcolor=False)

        with _FigureFile(resdir, 'imp_basis.pdf'):
            plot_basis(s_avg)

    # Plot the firing rates
    if do_plot_firing_rates:
        print("Plotting firing rates")
        T_lim = slice(0, 2000)
        # Fetch the spike matrix once rather than once per unit.
        spikes = population.glm.S.get_value()
        for n in range(N):
            with _FigureFile(resdir, 'firing_rate_%d.pdf' % n):
                plot_firing_rate(s_avg['glms'][n],
                                 s_std['glms'][n],
                                 color='r',
                                 T_lim=T_lim)
                if true_given:
                    plot_firing_rate(s_true['glms'][n], color='k', T_lim=T_lim)

                # Plot the spike times. St holds indices into the T_lim
                # window, and lam is sliced with the same window.
                St = np.nonzero(spikes[T_lim, n])[0]
                plt.plot(St, s_avg['glms'][n]['lam'][T_lim][St], 'ko')

                plt.title('Firing rate %d' % n)

    if do_plot_ks:
        print("Plotting KS test results")
        # Loop-invariant: fetch once for all units.
        spikes = population.glm.S.get_value()
        dt = population.glm.dt.get_value()
        for n in range(N):
            with _FigureFile(resdir, 'ks_%d.pdf' % n):
                St = np.nonzero(spikes[:, n])[0]
                plot_ks(s_avg['glms'][n], St, dt)

    if do_plot_logpr:
        print("Plotting log probability and log likelihood trace")
        with _FigureFile(resdir, 'log_prob.pdf'):
            plot_log_prob(s_inf, s_true=s_true, color='r')

        with _FigureFile(resdir, 'log_lkhd.pdf'):
            plot_log_lkhd(s_inf, s_true=s_true, color='r')

        if 'logprior' in s_inf[0]:
            with _FigureFile(resdir, 'log_prior.pdf'):
                plot_log_prob(s_inf, key='logprior', s_true=s_true, color='r')
                plt.ylabel('Log prior')

        # 'predll' is looked up on the raw parameter settings (x_inf), not on
        # the evaluated states, so this trace is plotted from x_inf / x_true.
        if 'predll' in x_inf[0]:
            with _FigureFile(resdir, 'pred_ll.pdf'):
                plot_log_prob(x_inf, key='predll', s_true=x_true, color='r')
                plt.ylabel('Pred. Log Likelihood')

    print("Plots can be found in directory: %s" % resdir)