예제 #1
0
def analysis_and_write(params,
                       weights_path,
                       fig_directory,
                       run_name,
                       no_rec_noise=True):
    """Simulate a trained network and write its analysis figures to a PDF.

    Loads the weight archive at ``weights_path``, draws one batch from
    ``generate_train_trials(params)``, runs every trial of that batch through
    a ``Simulator`` and saves Figures 0-12 to
    ``<fig_directory>/<run_name>.pdf``.

    Args:
        params: Parameter mapping handed to ``generate_train_trials`` and
            ``Simulator``. The caller's object is left untouched; the
            function works on a shallow copy.
        weights_path: Path passed to ``np.load`` and to ``Simulator``. The
            archive must provide ``W_rec``, ``W_in``, ``W_out`` and ``b_rec``.
        fig_directory: Output directory; created if it does not exist.
        run_name: Base name (without extension) of the PDF.
        no_rec_noise: When true, ``params['rec_noise']`` is set to 0.0 for
            the simulation. Figure 0 still shows the parameters as passed in.
    """
    from matplotlib.backends.backend_pdf import PdfPages
    import os
    import copy

    # Snapshot for Figure 0, taken before any override below.
    original_params = copy.deepcopy(params)

    # Work on a copy so the override does not leak back into the caller.
    params = copy.copy(params)
    if no_rec_noise:
        params['rec_noise'] = 0.0

    # Explicit existence check instead of a bare ``except:`` around os.stat;
    # makedirs also handles nested output paths.
    if not os.path.isdir(fig_directory):
        os.makedirs(fig_directory)

    generator = generate_train_trials(params)
    weights = np.load(weights_path)

    W = weights['W_rec']
    Win = weights['W_in']
    Wout = weights['W_out']
    brec = weights['b_rec']

    # Number of recurrent units, taken from the weights rather than the
    # previously hard-coded 100.
    n_rec = W.shape[0]

    # next() works on both Python 2 and Python 3 generators.
    data = next(generator)
    # data[0] is indexed [trial, step, channel] below -- presumably
    # (batch, time, input); TODO confirm against generate_train_trials.
    trials = data[0]
    n_trials, n_steps = trials.shape[0], trials.shape[1]

    sim = Simulator(params, weights_path=weights_path)
    output, states = sim.run_trial(trials[0, :, :], t_connectivity=False)

    # State trajectories for every trial of the batch: [step, trial, unit].
    s = np.zeros([n_steps, n_trials, n_rec])
    for ii in range(n_trials):
        s[:, ii, :] = sim.run_trial(trials[ii, :, :],
                                    t_connectivity=False)[1].reshape(
                                        [n_steps, n_rec])

    pp = PdfPages(os.path.join(fig_directory, run_name + '.pdf'))
    try:
        #Figure 0 (Plot Params)
        fig0 = plot_params(original_params)
        pp.savefig(fig0)

        #Figure 1 (Single Trial (Input Output State))
        fig1 = plot_single_trial(data, states, output)
        pp.savefig(fig1)

        #Figure 2 (Plot structural measures of W against random matrix R)
        fig2 = plot_structure_Wrec(W)
        pp.savefig(fig2)

        #Figure 3 (Activity sorted by time of max firing rate)
        fig3 = plt.figure(figsize=(7, 3))
        plot_by_max(s[:, 0, :])
        plt.xlabel('Time')
        plt.ylabel('Neuron')
        pp.savefig(fig3)

        #Figure 4 (Principal Angle Analysis)
        fig4 = plot_principal_angles(W, s, data)
        pp.savefig(fig4)

        #Figure 5 Plot long term state activity for in, and in+go conditions
        fig5 = plot_long_term_state(sim)
        pp.savefig(fig5)

        #Figure 6 Plot PC projection
        fig6 = plot_state_pcs(sim)
        pp.savefig(fig6)

        #Figure 7 Hamming Distance btw Putative Fixed Points
        fig7 = plot_hamming_dist(s, W, brec)
        pp.savefig(fig7)

        #Figure 8 Plot long delayed go cue
        fig8 = plot_long_delayed_go_cue(sim)
        pp.savefig(fig8)

        #Figure 9 Plot angles between input mapping and output mapping
        fig9 = plot_input_output_angles(Win, W, Wout, brec)
        pp.savefig(fig9)

        #Figure 10 Plot biclustered recurrent weights
        fig10 = plot_biclustered_weights(W)
        pp.savefig(fig10)

        #Figure 11 Plot biases
        fig11 = plot_biases(weights)
        pp.savefig(fig11)

        #Figure 12 Plot Eigenspectrum
        fig12 = plot_eig_dist(s[:, 0, :], W)
        pp.savefig(fig12)
    finally:
        # Always finalise the PDF, even if one of the plots raised.
        pp.close()
예제 #2
0
    # Session used for training and for reading the trained variables back.
    sess = tf.Session()

    print('first training')
    # Train the model on trials from `generator`; weights_path is forwarded
    # to train() -- presumably where the trained weights get written, since
    # the Simulator below is built from the same path (TODO confirm).
    model.train(sess,
                generator,
                learning_rate=learning_rate,
                training_iters=training_iters,
                weights_path=weights_path)
    #print('second training')
    #model.train(sess, generator, learning_rate = learning_rate, training_iters = training_iters, weights_path = weights_path, initialize_variables=False)

    # Draw one batch. NOTE(review): `.next()` is the Python 2 generator
    # method; on Python 3 this needs to be `next(generator)`.
    data = generator.next()
    #output,states = model.test(sess, input, weights_path = weights_path)

    # Evaluate the model's variables in the session to get their values.
    W = model.W_rec.eval(session=sess)
    U = model.W_in.eval(session=sess)
    Z = model.W_out.eval(session=sess)
    brec = model.b_rec.eval(session=sess)
    bout = model.b_out.eval(session=sess)

    # Run the first trial of the batch through a Simulator built from the
    # weights at weights_path.
    sim = Simulator(params, weights_path=weights_path)
    output, states = sim.run_trial(data[0][0, :, :], t_connectivity=False)

    # Collect the state output (second return value of run_trial) for every
    # trial in the batch; s is indexed [data[0] axis 1, trial, 50].
    # NOTE(review): 50 is hard-coded -- presumably the number of recurrent
    # units; confirm it matches the trained model.
    s = np.zeros([data[0].shape[1], data[0].shape[0], 50])
    for ii in range(data[0].shape[0]):
        s[:, ii, :] = sim.run_trial(data[0][ii, :, :],
                                    t_connectivity=False)[1].reshape(
                                        [data[0].shape[1], 50])

    sess.close()
예제 #3
0
def analysis_and_write(params,weights_path):
    """Simulate a trained network and write seven demo figures to a PDF.

    Loads the weight archive at ``weights_path``, draws one batch from
    ``generate_train_trials(params)``, runs every trial of that batch through
    a ``Simulator`` and saves Figures 1-7 to
    ``demo_figures/demo_analysis_figures.pdf``.

    Args:
        params: Parameter mapping handed to ``generate_train_trials`` and
            ``Simulator``.
        weights_path: Path passed to ``np.load`` and to ``Simulator``. The
            archive must provide ``W_rec`` and ``b_rec``.
    """
    from matplotlib.backends.backend_pdf import PdfPages
    import os

    # Explicit existence check instead of a bare ``except:`` around os.stat.
    if not os.path.isdir('demo_figures'):
        os.makedirs('demo_figures')

    generator = generate_train_trials(params)
    weights = np.load(weights_path)

    W = weights['W_rec']
    brec = weights['b_rec']
    # Number of recurrent units, taken from the weights rather than the
    # previously hard-coded 100.
    n_rec = W.shape[0]

    # next() works on both Python 2 and Python 3 generators.
    data = next(generator)
    # data[0] is indexed [trial, step, channel] below -- presumably
    # (batch, time, input); TODO confirm against generate_train_trials.
    trials = data[0]
    n_trials, n_steps = trials.shape[0], trials.shape[1]

    sim = Simulator(params, weights_path=weights_path)
    output,states = sim.run_trial(trials[0,:,:],t_connectivity=False)

    # State trajectories for every trial of the batch: [step, trial, unit].
    s = np.zeros([n_steps,n_trials,n_rec])
    for ii in range(n_trials):
        s[:,ii,:] = sim.run_trial(trials[ii,:,:],t_connectivity=False)[1].reshape([n_steps,n_rec])

    pp = PdfPages('demo_figures/demo_analysis_figures.pdf')
    try:
        #Figure 1 (Single Trial (Input Output State))
        fig1 = plt.figure(figsize=(5,5))
        plt.subplot(3,1,1)
        plt.plot(output[:,0,:])
        plt.title('Out')
        plt.subplot(3,1,2)
        plt.plot(states[:,0,:])
        plt.title('State')
        plt.subplot(3,1,3)
        plt.plot(trials[0,:,:])
        plt.title('Input')
        plt.tight_layout()

        pp.savefig(fig1)

        #Figure 2 (Plot structural measures of W against random matrix R)
        # R: Gaussian random matrix rescaled so that its largest eigenvalue
        # magnitude is 1.1.
        R = np.random.randn(n_rec,n_rec)/float(n_rec)
        R = 1.1*R/np.max(np.abs(np.linalg.eig(R)[0]))

        #calculate the norm of trained rec matrix W and random gaussian matrix R
        normW = calc_norm(W)
        normR = calc_norm(R)
        min_norm = np.min([np.min(normW),np.min(normR)])
        max_norm = np.max([np.max(normW),np.max(normR)])
        xx_norm = np.linspace(min_norm,max_norm,50)
        histnormW, _ = np.histogram(normW,xx_norm)
        histnormR, _ = np.histogram(normR,xx_norm)

        #calculate hists for angles between columns
        # clip guards arccos against values pushed just outside [-1, 1] by
        # rounding.
        angle_W = np.arccos(np.clip((W.T.dot(W))/np.outer(normW,normW),-1.,1.))
        angle_R = np.arccos(np.clip((R.T.dot(R))/np.outer(normR,normR),-1.,1.))
        min_val = np.min([np.min(angle_W),np.min(angle_R)])
        max_val = np.max([np.max(angle_W),np.max(angle_R)])
        xx = np.linspace(min_val,max_val,50)
        # Only the strictly lower triangle enters the histograms, so each
        # column pair is counted once and the diagonal is skipped.
        histW, bin_edgesW = np.histogram(angle_W[np.tril(np.ones_like(W),-1)>0],xx)
        histR, bin_edgesR = np.histogram(angle_R[np.tril(np.ones_like(R),-1)>0],xx)

        fig2 = plt.figure(figsize=(8,5))
        plt.subplot(2,2,1)
        plt.pcolormesh(angle_W)
        plt.colorbar()
        # Raw string: same text as before, without the invalid "\m" escape.
        plt.title(r'$\measuredangle$ W')

        plt.subplot(2,2,2)
        plt.pcolormesh(angle_R)
        plt.colorbar()
        plt.title(r'$\measuredangle$ R')

        # W histogram upwards, random-matrix histogram mirrored downwards.
        plt.subplot(2,2,3)
        plt.bar(xx[:-1],histW,width=bin_edgesW[1]-bin_edgesW[0])
        plt.bar(xx[:-1],-histR,width=bin_edgesR[1]-bin_edgesR[0],color='g')

        plt.legend(['W','Random'],fontsize=10,loc='lower left')
        plt.title('Hist of Angles')

        plt.subplot(2,2,4)
        plt.bar(xx_norm[:-1],histnormW,width=xx_norm[1]-xx_norm[0])
        plt.bar(xx_norm[:-1],-histnormR,width=xx_norm[1]-xx_norm[0],color='g')

        plt.legend(['W','Random'],fontsize=10,loc='lower left')
        plt.title('Hist of Norms')
        plt.tight_layout()

        pp.savefig(fig2)

        #Figure 3 (Activity sorted by time of max firing rate)
        fig3 = plt.figure(figsize=(7,3))
        plot_by_max(s[:,0,:])
        plt.xlabel('Time')
        plt.ylabel('Neuron')

        pp.savefig(fig3)

        #Figure 4 (Principal Angle Analysis)
        # masks[:, t] flags the units of trial 0 whose state is positive at
        # step t; it is reused for Figure 7.
        masks = s[:,0,:].T>0

        # Per step: the 10 eigenvectors of (W*mask - I) whose eigenvalues
        # have the smallest absolute real part.
        leading = []
        for ii in range(n_steps):
            evals,evecs = np.linalg.eig(W*masks[:,ii]-np.eye(n_rec))
            leading.append(evecs[:,np.argsort(np.abs(evals.real))[:10]])

        # Mean principal angle between the subspaces of every pair of steps;
        # entries stay 0 when either basis is empty.
        pa = np.zeros([n_steps,n_steps])
        for ii in range(n_steps):
            for jj in range(n_steps):
                if leading[ii].shape[1]*leading[jj].shape[1]>0:
                    pas = principal_angle(leading[ii],leading[jj])
                    pa[ii,jj] = np.nanmean(pas)

        fig4 = plt.figure()
        plt.pcolormesh(pa,vmin=0,vmax=90)
        plt.colorbar()
        plt.ylim([0,pa.shape[0]])
        plt.xlim([0,pa.shape[1]])

        plt.title('Principal Angle Analysis')
        plt.xlabel('Time')
        plt.ylabel('Time')

        pp.savefig(fig4)

        #Figure 5 Plot long term state activity for in, and in+go conditions
        # Low-amplitude noise background with unit pulses written into
        # individual channels.
        d = .01*np.random.randn(2000,3)
        d[50:60,0] = 1.
        o_in0,s_in0 = sim.run_trial(d,t_connectivity=False)

        # NOTE(review): d still carries the channel-0 pulse from above, so
        # this trial has pulses on channels 0 AND 1 -- confirm intended.
        d[50:60,1] = 1.
        o_in1,s_in1 = sim.run_trial(d,t_connectivity=False)

        d = .01*np.random.randn(2000,3)
        d[50:60,0] = 1.
        d[150:160,2] = 1.
        o_go0,s_go0 = sim.run_trial(d,t_connectivity=False)

        # NOTE(review): same as above, the channel-0 pulse is still present.
        d[50:60,1] = 1.
        d[150:160,2] = 1.
        o_go1,s_go1 = sim.run_trial(d,t_connectivity=False)

        fig5 = plt.figure(figsize=(8,6))

        # NOTE(review): the left column labels s_in1 as "Input 1" while the
        # right column labels s_go0 as "Input 1" -- titles kept as they
        # were; confirm which numbering is intended.
        plt.subplot(4,2,1)
        plt.plot(s_in1[:500,0,:])
        plt.title('Long Input 1')
        plt.subplot(4,2,3)
        plt.plot(s_in0[:500,0,:])
        plt.title('Long Input 2')
        plt.subplot(4,2,5)
        plt.plot(s_in0[:500,0,:] - s_in1[:500,0,:])
        plt.title('Difference')
        plt.subplot(4,2,7)
        plt.plot(o_in1[:500,0,:])
        plt.plot(o_in0[:500,0,:])
        plt.title('Output')

        plt.subplot(4,2,2)
        plt.plot(s_go0[:500,0,:])
        plt.title('Long Input 1 + Go Cue')
        plt.subplot(4,2,4)
        plt.plot(s_go1[:500,0,:])
        plt.title('Long Input 2 + Go Cue')
        plt.subplot(4,2,6)
        plt.plot(s_go0[:500,0,:] - s_go1[:500,0,:])
        plt.title('Difference')
        plt.subplot(4,2,8)
        plt.plot(o_go0[:500,0,:])
        plt.plot(o_go1[:500,0,:])
        plt.title('Output')

        plt.tight_layout()

        pp.savefig(fig5)

        #Figure 6 Plot PC projection
        s_pca = np.concatenate((s_go0[:500,0,:],s_go1[:500,0,:]),axis=0)
        s_pca = demean(s_pca)
        c_pca = np.cov(s_pca.T)
        # np.linalg.eig returns eigenvalues in no guaranteed order, so taking
        # its first two columns does not reliably give the top two PCs. The
        # covariance is symmetric: use eigh and sort by descending variance.
        evals,evecs = np.linalg.eigh(c_pca)
        pcs = evecs[:,np.argsort(evals)[::-1][:2]]

        # Projections of each trajectory onto (pc1, pc2), one row per step.
        p_go0 = s_go0[:,0,:].dot(pcs)
        p_go1 = s_go1[:,0,:].dot(pcs)
        p_in0 = s_in0[:,0,:].dot(pcs)
        p_in1 = s_in1[:,0,:].dot(pcs)

        fig6 = plt.figure()
        plt.plot(p_go0[:,0],p_go0[:,1],'g',alpha=.5)
        plt.plot(p_go1[:,0],p_go1[:,1],'b',alpha=.5)
        plt.plot(p_in0[:,0],p_in0[:,1],'c',alpha=.5)
        plt.plot(p_in1[:,0],p_in1[:,1],'r',alpha=.5)

        # Starting point of the trajectory.
        plt.plot(p_go1[0,0],p_go1[0,1],'kx',markersize=10)

        # States at steps 49 and 149, i.e. the last step before the pulses
        # written at d[50:60] and d[150:160].
        plt.plot(p_go0[49,0],p_go0[49,1],'og',markersize=5)
        plt.plot(p_go0[149,0],p_go0[149,1],'og',markersize=5)

        plt.plot(p_go1[49,0],p_go1[49,1],'xb',markersize=8)
        plt.plot(p_go1[149,0],p_go1[149,1],'xb',markersize=8)

        plt.xlabel('pc1')
        plt.ylabel('pc2')

        plt.legend(['in_go_0','in_go_1','in_0','in_1'],loc='lower left',fontsize=8)

        pp.savefig(fig6)

        #Figure 7 Hamming Distance btw Putative Fixed Points
        # For each step, solve x = Weff x + brec using that step's mask.
        x_hat = np.zeros(masks.shape)
        for ii in range(masks.shape[1]):
            Weff = W*masks[:,ii]
            x_hat[:,ii] = np.linalg.inv(np.eye(n_rec)-Weff).dot(brec)

        fig7 = plt.figure()
        # Distances are computed on the sign pattern of each solution.
        plt.pcolormesh(squareform(pdist(np.sign(x_hat[:,:]).T,metric='hamming')))
        plt.colorbar()
        plt.ylim([0,x_hat.shape[1]])
        plt.xlim([0,x_hat.shape[1]])

        plt.title('Hamming Distance Between Putative FPs')
        plt.ylabel('Time')
        plt.xlabel('Time')

        pp.savefig(fig7)
    finally:
        # Always finalise the PDF, even if one of the plots raised.
        pp.close()
예제 #4
0
def analysis_and_write(params,weights_path,fig_directory,run_name,no_rec_noise=True):
    """Simulate a trained network and write its analysis figures to a PDF.

    Loads the weight archive at ``weights_path``, draws one batch from
    ``generate_train_trials`` (with ``sample_size`` forced to 2000), runs
    every trial plus one long trial per input channel through a
    ``Simulator`` and saves Figures 0-9 to
    ``<fig_directory>/<run_name>.pdf``.

    Args:
        params: Parameter mapping handed to ``generate_train_trials`` and
            ``Simulator``. The caller's object is left untouched; the
            function works on a shallow copy.
        weights_path: Path passed to ``np.load`` and to ``Simulator``. The
            archive must provide ``W_rec``, ``W_in`` and ``b_rec``; the
            loaded archive is also handed to several plotting helpers.
        fig_directory: Output directory; created if it does not exist.
        run_name: Base name (without extension) of the PDF.
        no_rec_noise: When true, ``params['rec_noise']`` is set to 0.0 for
            the simulation. Figure 0 still shows the parameters as passed in.
    """
    from matplotlib.backends.backend_pdf import PdfPages
    import os
    import copy
    import warnings

    # Snapshot for Figure 0, taken before any override below.
    original_params = copy.deepcopy(params)

    # Work on a copy so the overrides do not leak back into the caller.
    params = copy.copy(params)
    if no_rec_noise:
        params['rec_noise'] = 0.0
    params['sample_size'] = 2000

    # Explicit existence check instead of a bare ``except:`` around os.stat;
    # makedirs also handles nested output paths.
    if not os.path.isdir(fig_directory):
        os.makedirs(fig_directory)

    generator = generate_train_trials(params)
    weights = np.load(weights_path)

    W = weights['W_rec']
    Win = weights['W_in']
    brec = weights['b_rec']

    #Generate Input Data
    # next() works on both Python 2 and Python 3 generators.
    data = next(generator)
    # data[0] is indexed [trial, step, channel] below -- presumably
    # (batch, time, input); TODO confirm against generate_train_trials.
    trials = data[0]
    n_trials, n_steps, n_in = trials.shape
    n_rec = W.shape[0]

    #Find Input/Target One-Hot
    # Index of the largest channel at step 40 of each trial.
    # NOTE(review): 40 is hard-coded -- presumably a step inside the
    # stimulus window; confirm against the trial generator.
    inp = np.argmax(trials[:,40,:],axis=1)

    sim = Simulator(params, weights_path=weights_path)
    output,states = sim.run_trial(trials[0,:,:],t_connectivity=False)

    #generate trials
    # State trajectories for every trial of the batch: [step, trial, unit].
    s = np.zeros([n_steps,n_trials,n_rec])
    for ii in range(n_trials):
        s[:,ii,:] = sim.run_trial(trials[ii,:,:],t_connectivity=False)[1].reshape([n_steps,n_rec])

    #generate long duration trials
    # One 10000-step trial per input channel, with a unit pulse on that
    # channel during steps 10-79 and zeros elsewhere.
    long_in = np.zeros([10000,n_in,n_in])
    for ii in range(n_in):
        long_in[10:80,ii,ii] = 1

    s_long = np.zeros([long_in.shape[0],n_in,n_rec])
    for ii in range(n_in):
        s_long[:,ii,:] = sim.run_trial(long_in[:,ii,:],t_connectivity=False)[1].reshape([long_in.shape[0],n_rec])

    pp = PdfPages(os.path.join(fig_directory, run_name + '.pdf'))
    try:
        #Figure 0 (Plot Params)
        fig0 = plot_params(original_params)
        pp.savefig(fig0)

        #Figure 1 (Single Trial (Input Output State))
        fig1 = plot_single_trial(data,states,output)
        pp.savefig(fig1)

        #Figure 2 (plot fixed points - activity at end of trial)
        fig2 = plot_fps_vs_activity(s,W,brec)
        pp.savefig(fig2)

        #Figure 3 (Plot output activity)
        # Best effort: this figure may fail without aborting the report, but
        # the failure is reported instead of being silently discarded.
        try:
            fig3 = plot_outputs_by_input(s,data,weights,n=Win.shape[1])
            pp.savefig(fig3)
        except Exception as err:
            warnings.warn('Figure 3 (outputs by input) skipped: %r' % (err,))

        #Figure 4 (Plot 2D PCA projection)
        fig4 = pca_plot(n_in,s_long,s,inp,brec)
        pp.savefig(fig4)

        #Figure5 (Plot Long Output)
        fig5 = plot_long_output_by_input(n_in,n_rec,s_long,weights)
        pp.savefig(fig5)

        #Figure6 (Plot ablation analysis)
        fig6 = ablation_analysis(n_rec,n_in,weights,sim)
        pp.savefig(fig6)

        #Figure7 (Plot W Structure)
        fig7 = plot_structure_Wrec(W)
        pp.savefig(fig7)

        #Figure8 (Bar Plot of distance to nearest partition)
        fig8 = plot_dist_to_fp(s_long)
        pp.savefig(fig8)

        #Figure9 (Fixed point partitions)
        fig9 = plot_fp_partitions(s_long)
        pp.savefig(fig9)
    finally:
        # Always finalise the PDF, even if one of the plots raised.
        pp.close()
예제 #5
0
    # Train the model on trials from `generator`; weights_path is forwarded
    # to train() -- presumably where the trained weights get written, since
    # the Simulator below is built from the same path (TODO confirm).
    model.train(sess,
                generator,
                learning_rate=learning_rate,
                training_iters=training_iters,
                weights_path=weights_path,
                display_step=display_step)

    # Draw one batch. NOTE(review): `.next()` is the Python 2 generator
    # method; on Python 3 this needs to be `next(generator)`.
    data = generator.next()

    #W = model.W_rec.eval(session=sess)
    #U = model.W_in.eval(session=sess)
    #Z = model.W_out.eval(session=sess)
    #brec = model.b_rec.eval(session=sess)
    #bout = model.b_out.eval(session=sess)

    # Run the first trial of the batch through a Simulator built from the
    # weights at weights_path.
    sim = Simulator(params, weights_path=weights_path)
    output, states = sim.run_trial(data[0][0, :, :], t_connectivity=False)

    # Additional test runs on the simulator; their results are not used in
    # the lines visible here.
    x_test, y_test, mask = build_test_trials(params)
    mup, mdown, choice, resp = white_noise_test(sim, x_test)
    # Sweep of values from -0.2 up to (not including) 0.2 in steps of 0.01.
    coh_out = coherence_test(sim, np.arange(-.2, .2, .01))

    # For the first five trials of the batch, compute fixed points and plot
    # them together with `states`.
    # NOTE(review): `states` comes from trial 0 only, while `points` is
    # computed for trial i -- confirm this pairing is intended.
    for i in range(5):
        trial = data[0][i, :, :]

        points = analysis.hahnloser_fixed_point(sim, trial)

        analysis.plot_states(states=states, I=points)

    sess.close()
예제 #6
0
                generator,
                learning_rate=learning_rate,
                training_iters=training_iters,
                weights_path=weights_path,
                display_step=display_step)

    # Draw one batch. NOTE(review): `.next()` is the Python 2 generator
    # method; on Python 3 this needs to be `next(generator)`.
    data = generator.next()
    #output,states = model.test(sess, input, weights_path = weights_path)

    # Evaluate the model's variables in the session to get their values.
    W = model.W_rec.eval(session=sess)
    U = model.W_in.eval(session=sess)
    Z = model.W_out.eval(session=sess)
    brec = model.b_rec.eval(session=sess)
    bout = model.b_out.eval(session=sess)

    # Run the first trial of the batch through a Simulator built from the
    # weights at weights_path; states.shape[0] sizes the arrays below.
    sim = Simulator(params, weights_path=weights_path)
    output, states = sim.run_trial(data[0][0, :, :], t_connectivity=False)

    # State output (second return value of run_trial) for each of the first
    # batch_size trials; s is indexed [states.shape[0], trial, n_hidden].
    s = np.zeros([states.shape[0], batch_size, n_hidden])
    for ii in range(batch_size):
        s[:, ii, :] = sim.run_trial(data[0][ii, :, :],
                                    t_connectivity=False)[1].reshape(
                                        [states.shape[0], n_hidden])

    # Same collection for 1000 pure-noise inputs: standard-normal samples
    # scaled by stim_noise, one [states.shape[0], n_in] array per trial.
    n_noise = 1000
    s_noise = np.zeros([states.shape[0], n_noise, n_hidden])
    data_noise = stim_noise * np.random.randn(n_noise, states.shape[0], n_in)
    for ii in range(n_noise):
        s_noise[:, ii, :] = sim.run_trial(data_noise[ii, :, :],
                                          t_connectivity=False)[1].reshape(
                                              [states.shape[0], n_hidden])