コード例 #1
0
def _dichotomy_indices(dichotomies, candidates, n_cond):
    """Locate each dichotomy among a list of candidate condition splits.

    Parameters
    ----------
    dichotomies : iterable of sequences
        One side of each dichotomy to look up (presumably condition labels
        in ``range(n_cond)`` -- TODO confirm against the task objects).
    candidates : sequence of sequences
        Candidate splits, each given by the conditions on one side.
    n_cond : int
        Number of conditions; the other side of a candidate is its
        complement in ``range(n_cond)``.

    Returns
    -------
    list of int
        For each dichotomy, the index of the first candidate whose listed
        side, or whose complement, equals the dichotomy element-for-element.
        The complement comes from ``np.setdiff1d`` and is therefore in
        ascending order.

    Raises
    ------
    ValueError
        If a dichotomy matches no candidate.
    """
    # A candidate's complement does not depend on which dichotomy is being
    # looked up, so build both sides of every candidate once.
    sides = [(list(p), list(np.setdiff1d(range(n_cond), p))) for p in candidates]
    indices = []
    for d in dichotomies:
        target = list(d)
        idx = next((k for k, (pos, neg) in enumerate(sides)
                    if pos == target or neg == target), None)
        if idx is None:
            raise ValueError(f"dichotomy {target} not found among the candidate splits")
        indices.append(idx)
    return indices


# NaN-ignoring standard deviation of all_SD along axis 0 (divide by
# np.sqrt(len(all_SD)) to get a standard error instead).
# NOTE(review): SD_err is not passed to dichotomy_plot below -- confirm it is
# used later in the script or pass it through.
SD_err = np.nanstd(np.squeeze(all_SD), axis=0)

# Position within pos_conds of each task dichotomy; a split matches whichever
# of its two sides is listed.
output_dics = _dichotomy_indices(task.positives, pos_conds, num_cond)
input_dics = _dichotomy_indices(input_task.positives, pos_conds, num_cond)

dicplt.dichotomy_plot(PS,
                      CCGP,
                      SD,
                      input_dics=input_dics,
                      output_dics=output_dics,
                      # NOTE(review): hard-coded extra split to highlight; its
                      # meaning is not documented here.
                      other_dics=[pos_conds.index((0, 2, 5, 7))],
                      out_MI=out_MI.mean(0))

#%%
# The first layer (W1, b1) is common to both architectures; evaluate it once.
# matmul(W1, inputs.T) puts one column per row of `inputs`; the final .T
# below turns those back into rows aligned with `inputs`.
_first_layer = nonlinearity(torch.matmul(W1, inputs.T) + b1)
if not two_layers:
    z = _first_layer.detach().numpy().T
else:
    # Detached first-layer activations, re-wrapped as a tensor and fed
    # through the second layer (W2, b2).
    z1 = _first_layer.detach().numpy()
    z = nonlinearity(torch.matmul(W2, torch.tensor(z1)) + b2).detach().numpy().T

x_ = np.stack([
    inputs[inp_condition == i, :].mean(0).detach().numpy()
    for i in np.unique(conds)
コード例 #2
0
ファイル: analysis_script.py プロジェクト: Kelarion/repler
    def _dichotomy_indices(dichotomies, candidates, n_cond):
        """Locate each dichotomy among a list of candidate condition splits.

        Parameters
        ----------
        dichotomies : iterable of sequences
            One side of each dichotomy to look up (presumably condition
            labels in ``range(n_cond)`` -- TODO confirm against the task).
        candidates : sequence of sequences
            Candidate splits, each given by the conditions on one side.
        n_cond : int
            Number of conditions; the other side of a candidate is its
            complement in ``range(n_cond)``.

        Returns
        -------
        list of int
            For each dichotomy, the index of the first candidate whose listed
            side, or whose complement, equals the dichotomy element-for-element.
            The complement comes from ``np.setdiff1d`` and is therefore in
            ascending order.

        Raises
        ------
        ValueError
            If a dichotomy matches no candidate.
        """
        # A candidate's complement does not depend on which dichotomy is being
        # looked up, so build both sides of every candidate once.
        sides = [(list(p), list(np.setdiff1d(range(n_cond), p))) for p in candidates]
        indices = []
        for d in dichotomies:
            target = list(d)
            idx = next((k for k, (pos, neg) in enumerate(sides)
                        if pos == target or neg == target), None)
            if idx is None:
                raise ValueError(f"dichotomy {target} not found among the candidate splits")
            indices.append(idx)
        return indices

    # NaN-ignoring standard deviation of all_SD along axis 0 (divide by
    # np.sqrt(len(all_SD)) to get a standard error instead).
    SD_err = np.nanstd(np.squeeze(all_SD), axis=0)

    # Position within pos_conds of each task dichotomy; a split matches
    # whichever of its two sides is listed.
    output_dics = _dichotomy_indices(task.positives, pos_conds, num_cond)
    # Only tasks that expose `bits` have input dichotomies to mark.
    if hasattr(task, 'bits'):
        input_dics = _dichotomy_indices(task.bits, pos_conds, num_cond)
    else:
        input_dics = None

    dicplt.dichotomy_plot(PS, CCGP, SD, PS_err=PS_err, CCGP_err=CCGP_err, SD_err=SD_err,
                          input_dics=input_dics, output_dics=output_dics,
                          out_MI=out_MI.mean(0))

else:
    nrow = int(np.sqrt(len(all_PS)))
    ncol = len(all_PS)//nrow
    
    plt.figure()
    for i, (PS, CCGP, SD) in enumerate(zip(all_PS, almost_all_CCGP, all_SD)):
        output_dics = []
        for d in task.positives:
            output_dics.append(np.where([(list(p) == list(d)) or (list(np.setdiff1d(range(num_cond),p))==list(d))\
                                        for p in dic_pos[i]])[0][0])
        
        if 'bits' in dir(task):
            input_dics = []
コード例 #3
0
    def _dichotomy_indices(dichotomies, candidates, n_cond):
        """Locate each dichotomy among a list of candidate condition splits.

        Parameters
        ----------
        dichotomies : iterable of sequences
            One side of each dichotomy to look up (presumably condition
            labels in ``range(n_cond)`` -- TODO confirm against the task).
        candidates : sequence of sequences
            Candidate splits, each given by the conditions on one side.
        n_cond : int
            Number of conditions; the other side of a candidate is its
            complement in ``range(n_cond)``.

        Returns
        -------
        list of int
            For each dichotomy, the index of the first candidate whose listed
            side, or whose complement, equals the dichotomy element-for-element.
            The complement comes from ``np.setdiff1d`` and is therefore in
            ascending order.

        Raises
        ------
        ValueError
            If a dichotomy matches no candidate.
        """
        # A candidate's complement does not depend on which dichotomy is being
        # looked up, so build both sides of every candidate once.
        sides = [(list(p), list(np.setdiff1d(range(n_cond), p))) for p in candidates]
        indices = []
        for d in dichotomies:
            target = list(d)
            idx = next((k for k, (pos, neg) in enumerate(sides)
                        if pos == target or neg == target), None)
            if idx is None:
                raise ValueError(f"dichotomy {target} not found among the candidate splits")
            indices.append(idx)
        return indices

    # These indices depend only on the tasks and pos_conds, not on the panel,
    # so compute them once rather than on every iteration.
    output_dics = _dichotomy_indices(output_task.positives, pos_conds, num_cond)
    input_dics = _dichotomy_indices(input_task.positives, pos_conds, num_cond)
    # NOTE(review): hard-coded extra split to highlight; meaning not documented here.
    other_dics = [pos_conds.index((0, 2, 5, 7))]

    for i, (PS, CCGP, SD) in enumerate(zip(all_PS, almost_all_CCGP, all_SD)):
        # plt.subplot numbers panels row-major across an nrow x ncol grid.
        plt.subplot(nrow, ncol, i + 1)
        dicplt.dichotomy_plot(PS,
                              CCGP,
                              SD,
                              input_dics=input_dics,
                              output_dics=output_dics,
                              other_dics=other_dics,
                              out_MI=out_MI[i],
                              include_legend=(i == 0),
                              include_cbar=False,
                              s=10)
        row, col = divmod(i, ncol)
        # Only the left-most column keeps its y tick labels.
        if col > 0:
            plt.yticks([])
        # NOTE(review): the y label is cleared on every panel, including the
        # left column, while y ticks are kept there -- confirm this is intended.
        plt.ylabel('')
        # Only the bottom row keeps its x label and ticks. The row index must
        # come from ncol (row-major numbering), not nrow.
        if row < nrow - 1:
            plt.xlabel('')
            plt.xticks([])

#%%
n_compute = 10000
lag = 3