## MULTIPLE SIMULATIONS
# Runs N_sim independent training/test cycles and saves binned error traces.
# NOTE: this block expects N_trial, the network hyperparameters, bias,
# data_construction, dic_stim, dic_resp, task and data_folder to be defined
# by the surrounding script.
N_sim = 10
E_fix = np.zeros((N_sim, N_trial))
E_go = np.zeros((N_sim, N_trial))
conv_tr = np.zeros(N_sim)
perc_fix = np.zeros(N_sim)
perc_go = np.zeros(N_sim)
stop = True

for n in np.arange(N_sim):
    print('SIMULATION ', n + 1)

    S_tr, O_tr, _, _, _, _ = data_construction(N=N_trial, perc_training=1)
    HER = HER_arch(NL, S, P, learn_rate_vec, learn_rate_memory, beta_vec,
                   gamma, elig_decay_vec, dic_stim, dic_resp)
    E_fix[n, :], E_go[n, :], conv_tr[n] = HER.training_saccade(
        S_tr, O_tr, bias, 'softmax', stop)

    S_test, O_test, _, _, _, _ = data_construction(N=100, perc_training=1)
    perc_fix[n], perc_go[n] = HER.test_saccade(S_test, O_test, bias, 'softmax')
    print('\t Percentage of correct fix responses: ', perc_fix[n], '%')
    print('\t Percentage of correct go responses: ', perc_go[n], '%')

# Bin the error traces in windows of 50 trials and save them to disk.
E_fix_mean = np.mean(np.reshape(E_fix, (-1, 50)), axis=1)
str_err_fix = data_folder + '/HER_long_' + task + '_error_fix_2.txt'
np.savetxt(str_err_fix, E_fix_mean)

E_go_mean = np.mean(np.reshape(E_go, (-1, 50)), axis=1)
str_err_go = data_folder + '/HER_long_' + task + '_error_go_2.txt'
np.savetxt(str_err_go, E_go_mean)

str_conv = data_folder + '/HER_long_' + task + '_conv_2.txt'
np.savetxt(str_conv, conv_tr)
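
# A minimal sketch (editor's addition): reload the binned error traces saved
# by the multiple-simulation block above and plot them. The file names mirror
# the np.savetxt calls above; each file holds one error value per 50-trial
# bin, concatenated over the N_sim simulations. 'plot_binned_errors' is a
# hypothetical helper, not part of the original pipeline.
def plot_binned_errors(data_folder='HER/DATA', task='saccade'):
    import numpy as np
    import matplotlib.pyplot as plt

    e_fix = np.loadtxt(data_folder + '/HER_long_' + task + '_error_fix_2.txt')
    e_go = np.loadtxt(data_folder + '/HER_long_' + task + '_error_go_2.txt')
    trials = 50 * np.arange(len(e_fix))  # x-axis in units of training trials

    plt.plot(trials, e_fix, label='fix')
    plt.plot(trials, e_go, label='go')
    plt.xlabel('Training trials (binned)')
    plt.ylabel('Mean error per 50-trial bin')
    plt.legend()
    plt.show()
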
def HER_task_saccades(params_bool, params_task):

    from TASKS.task_saccades import data_construction

    task = 'saccade'

    np.random.seed(1234)

    cues_vec = ['empty', 'P', 'A', 'L', 'R']
    pred_vec = ['LC', 'LW', 'FC', 'FW', 'RC', 'RW']

    if params_task is None:
        #N_trial = 15000
        N_trial = 20000
        perc_tr = 0.8
    else:
        N_trial = int(params_task[0]) if params_task[0] != '' else 20000
        perc_tr = float(params_task[1]) if params_task[1] != '' else 0.8

    S_tr, O_tr, S_test, O_test, dic_stim, dic_resp = data_construction(
        N=N_trial, perc_training=perc_tr, model='1')

    ## CONSTRUCTION OF THE HER MULTI-LEVEL NETWORK
    NL = 3                   # number of levels (<= 3)
    S = np.shape(S_tr)[1]    # dimension of the input
    P = np.shape(O_tr)[1]    # dimension of the prediction vector

    # Parameter values come from Table 1 of the Supplementary Material of
    # "Frontal cortex function derives from hierarchical predictive coding",
    # W. Alexander, J. Brown.
    learn_rate_vec = [0.1, 0.02, 0.02]   # learning rates (one per level)
    learn_rate_memory = [0.1, 0.1, 0.1]
    beta_vec = [12, 12, 12]              # gain parameter for memory dynamics
    gamma = 12                           # gain parameter for response making
    elig_decay_vec = [0.3, 0.5, 0.9]     # decay factors for eligibility traces
    bias = [0, 0, 0]
    gate = 'softmax'
    verb = 0

    if params_bool is None:
        do_training = True
        do_test = True
        do_weight_plots = True
        do_error_plots = True
    else:
        do_training = params_bool[0]
        do_test = params_bool[1]
        do_weight_plots = params_bool[2]
        do_error_plots = params_bool[3]

    HER = HER_arch(NL, S, P, learn_rate_vec, learn_rate_memory, beta_vec,
                   gamma, elig_decay_vec, dic_stim, dic_resp)
    HER.print_HER(False)
    #print(S_tr[:20,:])

    ## TRAINING
    data_folder = 'HER/DATA'
    N_training = np.around(N_trial * perc_tr).astype(int)

    if do_training:
        E_fix, E_go, conv_iter = HER.training_saccade(N_training, S_tr, O_tr,
                                                      bias, gate)
        # save trained model
        str_err = data_folder + '/' + task + '_error_fix.txt'
        np.savetxt(str_err, E_fix)
        str_err = data_folder + '/' + task + '_error_go.txt'
        np.savetxt(str_err, E_go)
        str_conv = data_folder + '/' + task + '_conv.txt'
        np.savetxt(str_conv, conv_iter)
        for l in np.arange(NL):
            str_mem = data_folder + '/' + task + '_weights_memory_' + str(l) + '.txt'
            np.savetxt(str_mem, HER.H[l].X)
            str_pred = data_folder + '/' + task + '_weights_prediction_' + str(l) + '.txt'
            np.savetxt(str_pred, HER.H[l].W)
        print("\nSaved model to disk.\n")
    else:
        # load a previously trained model
        str_err = data_folder + '/' + task + '_error_fix.txt'
        E_fix = np.loadtxt(str_err)
        str_err = data_folder + '/' + task + '_error_go.txt'
        E_go = np.loadtxt(str_err)
        str_conv = data_folder + '/' + task + '_conv.txt'
        conv_iter = np.loadtxt(str_conv)
        for l in np.arange(NL):
            str_mem = data_folder + '/' + task + '_weights_memory_' + str(l) + '.txt'
            HER.H[l].X = np.loadtxt(str_mem)
            str_pred = data_folder + '/' + task + '_weights_prediction_' + str(l) + '.txt'
            HER.H[l].W = np.loadtxt(str_pred)
        print("\nLoaded model from disk.\n")

    print('\n----------------------------------------------------'
          '\nTEST\n'
          '------------------------------------------------------\n')

    ## TEST
    if do_test:
        N_test = N_trial - N_training
        HER.test_saccade(N_test, S_test, O_test, bias, verb, gate)
        print('Convergence iteration: ', conv_iter)

    ## PLOTS
    # plot of the memory weights
    image_folder = 'HER/IMAGES'
    fontTitle = 26
    fontTicks = 22
    fontLabel = 22

    if do_weight_plots:
        fig1 = plt.figure(figsize=(10 * NL, 8))
        for l in np.arange(NL):
            X = HER.H[l].X
            plt.subplot(1, NL, l + 1)
            plt.pcolor(np.flipud(X), edgecolors='k', linewidths=1)
            plt.set_cmap('Blues')
            plt.colorbar()
            tit = 'MEMORY WEIGHTS: Level ' + str(l)
            plt.title(tit, fontweight="bold", fontsize=fontTitle)
            plt.xticks(np.linspace(0.5, S - 0.5, S, endpoint=True),
                       cues_vec, fontsize=fontTicks)
            plt.yticks(np.linspace(0.5, S - 0.5, S, endpoint=True),
                       np.flipud(cues_vec), fontsize=fontTicks)
        plt.show()
        savestr = image_folder + '/' + task + '_weights_memory.png'
        if gate == 'free':
            savestr = image_folder + '/' + task + '_weights_memory_nomemory.png'
        fig1.savefig(savestr)

        # plot of the prediction weights
        fig2 = plt.figure(figsize=(10 * NL, 8))
        for l in np.arange(NL):
            W = HER.H[l].W
            plt.subplot(1, NL, l + 1)
            plt.pcolor(np.flipud(W), edgecolors='k', linewidths=1)
            plt.set_cmap('Blues')
            plt.colorbar()
            tit = 'PREDICTION WEIGHTS: Level ' + str(l)
            plt.title(tit, fontweight="bold", fontsize=fontTitle)
            if l == 0:
                plt.xticks(np.linspace(0.5, np.shape(W)[1] - 0.5, P, endpoint=True),
                           pred_vec, fontsize=fontTicks)
            else:
                dx = np.shape(W)[1] / (2 * S)
                plt.xticks(np.linspace(dx, np.shape(W)[1] - dx, S, endpoint=True),
                           cues_vec, fontsize=fontTicks)
            plt.yticks(np.linspace(0.5, S - 0.5, S, endpoint=True),
                       np.flipud(cues_vec), fontsize=fontTicks)
        plt.show()
        savestr = image_folder + '/' + task + '_weights_prediction.png'
        if gate == 'free':
            savestr = image_folder + '/' + task + '_weights_prediction_nomemory.png'
        fig2.savefig(savestr)

    if do_error_plots:
        N = len(E_fix)
        # 2% of the training trials per bin ('bin' renamed to bin_size to
        # avoid shadowing the Python built-in).
        bin_size = round(N * 0.02)
        print('Bin size: ', bin_size)

        E_fix_bin = np.reshape(E_fix, (-1, bin_size))
        E_fix_bin = np.sum(E_fix_bin, axis=1)
        E_fix_cum = np.cumsum(E_fix)
        E_fix_norm = 100 * E_fix_cum / (np.arange(N) + 1)
        C_fix = np.where(E_fix == 0, 1, 0)
        C_fix_cum = 100 * np.cumsum(C_fix) / (np.arange(N) + 1)

        E_go_bin = np.reshape(E_go, (-1, bin_size))
        E_go_bin = np.sum(E_go_bin, axis=1)
        E_go_cum = np.cumsum(E_go)
        E_go_norm = 100 * E_go_cum / (np.arange(N) + 1)
        C_go = np.where(E_go == 0, 1, 0)
        C_go_cum = 100 * np.cumsum(C_go) / (np.arange(N) + 1)

        figE_fix = plt.figure(figsize=(22, 8))
        plt.subplot(1, 2, 1)
        plt.bar(bin_size * np.arange(len(E_fix_bin)), E_fix_bin, width=bin_size,
                color='blue', edgecolor='black', label='fix', alpha=0.6)
        plt.axvline(x=225, linewidth=5, ls='dashed', color='orange')
        plt.axvline(x=0, linewidth=5, color='b')
        tit = 'SAS: Training Convergence for FIX'
        plt.title(tit, fontweight="bold", fontsize=fontTitle)
        plt.xlabel('Training Trials', fontsize=fontLabel)
        plt.ylabel('Number of Errors per bin', fontsize=fontLabel)
        plt.xticks(np.linspace(0, N, 5, endpoint=True), fontsize=fontTicks)
        plt.yticks(fontsize=fontTicks)
        text = 'Bin = ' + str(bin_size)
        plt.ylim((0, 130))
        plt.figtext(x=0.37, y=0.78, s=text, fontsize=fontLabel,
                    bbox={'facecolor': 'white', 'alpha': 0.5, 'pad': 10})

        plt.subplot(1, 2, 2)
        plt.axvline(x=225, linewidth=5, ls='dashed', color='orange')
        plt.plot(np.arange(N), E_fix_cum, color='blue', linewidth=7,
                 label='fix', alpha=0.6)
        plt.axvline(x=0, linewidth=5, color='b')
        tit = 'SAS: Cumulative Training Error for FIX'
        plt.title(tit, fontweight="bold", fontsize=fontTitle)
        plt.xticks(np.linspace(0, N, 5, endpoint=True), fontsize=fontTicks)
        plt.yticks(fontsize=fontTicks)
        plt.xlabel('Training Trials', fontsize=fontLabel)
        plt.ylabel('Cumulative Error', fontsize=fontLabel)
        plt.ylim((0, 550))
        plt.show()
        savestr = image_folder + '/' + task + '_error_fix.png'
        if gate == 'free':
            savestr = image_folder + '/' + task + '_error_nomemory_fix.png'
        figE_fix.savefig(savestr)

        figE_go = plt.figure(figsize=(22, 8))
        plt.subplot(1, 2, 1)
        plt.bar(bin_size * np.arange(len(E_go_bin)), E_go_bin, width=bin_size,
                color='blue', edgecolor='black', alpha=0.6)
        plt.axvline(x=4100, linewidth=5, ls='dashed', color='green')
        if conv_iter != 0:
            plt.axvline(x=conv_iter, linewidth=5, color='b')
        tit = 'SAS: Training Convergence for GO'
fontweight="bold", fontsize=fontTitle) plt.xlabel('Training Trials', fontsize=fontLabel) plt.ylabel('Number of Errors per bin', fontsize=fontLabel) plt.xticks(np.linspace(0, N, 5, endpoint=True), fontsize=fontTicks) plt.yticks(fontsize=fontTicks) text = 'Bin = ' + str(bin) plt.figtext(x=0.37, y=0.78, s=text, fontsize=fontLabel, bbox={ 'facecolor': 'white', 'alpha': 0.5, 'pad': 10 }) plt.subplot(1, 2, 2) plt.axvline(x=4100, linewidth=5, ls='dashed', color='green') plt.plot(np.arange(N), E_go_cum, color='blue', linewidth=7, alpha=0.6) if conv_iter != 0: plt.axvline(x=conv_iter, linewidth=5, color='b') tit = 'SAS: Cumulative Training Error for GO' plt.title(tit, fontweight="bold", fontsize=fontTitle) plt.xticks(np.linspace(0, N, 5, endpoint=True), fontsize=fontTicks) plt.yticks(fontsize=fontTicks) plt.xlabel('Training Trials', fontsize=fontLabel) plt.ylabel('Cumulative Error', fontsize=fontLabel) plt.show() savestr = image_folder + '/' + task + '_error_go.png' if gate == 'free': savestr = image_folder + '/' + task + '_error_nomemory_go.png' figE_go.savefig(savestr)