Example #1
0
# Repeat the saccade experiment N_sim times. Every run draws a fresh training
# set, builds a new HER network, trains it, and then evaluates it on 100
# freshly generated trials.
N_sim = 10
stop = True  # forwarded unchanged as the last argument of training_saccade

# One row (or entry) per simulation.
E_fix = np.zeros((N_sim, N_trial))
E_go = np.zeros((N_sim, N_trial))
conv_tr = np.zeros(N_sim)
perc_fix = np.zeros(N_sim)
perc_go = np.zeros(N_sim)

for n in range(N_sim):

    print('SIMULATION ', n + 1)

    # perc_training=1: only the first two returned values are used here.
    S_tr, O_tr, _, _, _, _ = data_construction(N=N_trial, perc_training=1)

    HER = HER_arch(NL, S, P, learn_rate_vec, learn_rate_memory, beta_vec,
                   gamma, elig_decay_vec, dic_stim, dic_resp)
    _fix_err, _go_err, _conv = HER.training_saccade(S_tr, O_tr, bias,
                                                    'softmax', stop)
    E_fix[n] = _fix_err
    E_go[n] = _go_err
    conv_tr[n] = _conv

    S_test, O_test, _, _, _, _ = data_construction(N=100, perc_training=1)
    _fix_score, _go_score = HER.test_saccade(S_test, O_test, bias, 'softmax')
    perc_fix[n] = _fix_score
    perc_go[n] = _go_score

    print('\t Percentage of correct fix responses: ', perc_fix[n], '%')
    print('\t Percentage of correct go responses: ', perc_go[n], '%')

# Average the error traces over consecutive groups of 50 entries (the 2-D
# arrays are flattened row by row before grouping).
E_fix_mean = E_fix.reshape(-1, 50).mean(axis=1)
E_go_mean = E_go.reshape(-1, 50).mean(axis=1)

# Write the three result files: fix errors, go errors, convergence values.
_out_prefix = data_folder + '/HER_long_' + task
str_err_fix = _out_prefix + 'error_fix_2.txt'
np.savetxt(str_err_fix, E_fix_mean)
str_err_go = _out_prefix + 'error_go_2.txt'
np.savetxt(str_err_go, E_go_mean)
str_conv = _out_prefix + '_conv_2.txt'
np.savetxt(str_conv, conv_tr)
Example #2
0
def _bin_error_counts(errors, bin_size):
    """Sum per-trial errors over consecutive bins of ``bin_size`` trials.

    When the number of trials is not a multiple of ``bin_size`` the trailing
    partial bin is zero-padded, so its total only counts the trials that
    exist. A plain ``reshape(-1, bin_size)`` would raise ``ValueError`` in
    that case.
    """
    errors = np.ravel(errors)
    pad = -len(errors) % bin_size
    if pad:
        errors = np.concatenate((errors, np.zeros(pad, dtype=errors.dtype)))
    return np.sum(np.reshape(errors, (-1, bin_size)), axis=1)


def HER_task_saccades(params_bool, params_task):
    """Run the HER model on the saccade task: train (or reload), test, plot.

    Args:
        params_bool: ``None`` to enable every stage, otherwise an indexable
            whose first four entries are read as ``do_training``,
            ``do_test``, ``do_weight_plots`` and ``do_error_plots``.
        params_task: ``None`` for the defaults, otherwise an indexable whose
            first two entries are the number of trials and the training
            fraction. An entry equal to ``''`` falls back to its default
            (20000 trials, training fraction 0.8).

    Side effects:
        Seeds NumPy's global random generator with 1234, prints progress,
        writes (or reads, when training is disabled) text files under
        ``HER/DATA`` and, when plotting is enabled, shows the figures and
        saves them as PNG files under ``HER/IMAGES``. Both folders are
        created if they do not exist yet.
    """
    import os

    from TASKS.task_saccades import data_construction
    task = 'saccade'

    np.random.seed(1234)
    cues_vec = ['empty', 'P', 'A', 'L', 'R']
    pred_vec = ['LC', 'LW', 'FC', 'FW', 'RC', 'RW']

    if params_task is None:
        N_trial = 20000
        perc_tr = 0.8
    else:
        N_trial = int(params_task[0]) if params_task[0] != '' else 20000
        perc_tr = float(params_task[1]) if params_task[1] != '' else 0.8

    S_tr, O_tr, S_test, O_test, dic_stim, dic_resp = data_construction(
        N=N_trial, perc_training=perc_tr, model='1')

    ## CONSTRUCTION OF THE HER MULTI-LEVEL NETWORK
    NL = 3  # number of levels (<= 3)
    S = np.shape(S_tr)[1]  # dimension of the input
    P = np.shape(O_tr)[1]  # dimension of the prediction vector

    ### Parameter values come from Table 1 of the Supplementary Material of the paper "Frontal cortex function derives from hierarchical predictive coding", W. Alexander, J. Brown
    learn_rate_vec = [0.1, 0.02, 0.02]  # learning rates
    learn_rate_memory = [0.1, 0.1, 0.1]
    beta_vec = [12, 12, 12]  # gain parameter for memory dynamics
    gamma = 12  # gain parameter for response making
    elig_decay_vec = [0.3, 0.5, 0.9]  # decay factors for eligibility trace
    bias = [0, 0, 0]

    gate = 'softmax'

    verb = 0

    if params_bool is None:
        do_training = True
        do_test = True
        do_weight_plots = True
        do_error_plots = True
    else:
        do_training = params_bool[0]
        do_test = params_bool[1]
        do_weight_plots = params_bool[2]
        do_error_plots = params_bool[3]

    HER = HER_arch(NL, S, P, learn_rate_vec, learn_rate_memory, beta_vec,
                   gamma, elig_decay_vec, dic_stim, dic_resp)
    HER.print_HER(False)

    ## TRAINING
    data_folder = 'HER/DATA'
    N_training = np.around(N_trial * perc_tr).astype(int)
    if do_training:

        # Create the output folder before training, so that a missing
        # directory cannot make the run fail only once training is done.
        os.makedirs(data_folder, exist_ok=True)

        E_fix, E_go, conv_iter = HER.training_saccade(N_training, S_tr, O_tr,
                                                      bias, gate)

        # save trained model
        str_err = data_folder + '/' + task + '_error_fix.txt'
        np.savetxt(str_err, E_fix)
        str_err = data_folder + '/' + task + '_error_go.txt'
        np.savetxt(str_err, E_go)
        str_conv = data_folder + '/' + task + '_conv.txt'
        # np.savetxt only accepts 1-D/2-D input; atleast_1d keeps a scalar
        # convergence value from raising and leaves 1-D input untouched.
        np.savetxt(str_conv, np.atleast_1d(conv_iter))
        for level in range(NL):
            str_mem = data_folder + '/' + task + '_weights_memory_' + str(
                level) + '.txt'
            np.savetxt(str_mem, HER.H[level].X)
            str_pred = data_folder + '/' + task + '_weights_prediction_' + str(
                level) + '.txt'
            np.savetxt(str_pred, HER.H[level].W)
        print("\nSaved model to disk.\n")

    else:

        # Reload the error traces, the convergence value and the weights of
        # every level from the files written by a previous training run.
        str_err = data_folder + '/' + task + '_error_fix.txt'
        E_fix = np.loadtxt(str_err)
        str_err = data_folder + '/' + task + '_error_go.txt'
        E_go = np.loadtxt(str_err)

        str_conv = data_folder + '/' + task + '_conv.txt'
        conv_iter = np.loadtxt(str_conv)
        for level in range(NL):
            str_mem = data_folder + '/' + task + '_weights_memory_' + str(
                level) + '.txt'
            HER.H[level].X = np.loadtxt(str_mem)
            str_pred = data_folder + '/' + task + '_weights_prediction_' + str(
                level) + '.txt'
            HER.H[level].W = np.loadtxt(str_pred)
        print("\nLoaded model from disk.\n")

    # The banner is printed whether or not the test stage is enabled.
    print(
        '\n----------------------------------------------------\nTEST\n------------------------------------------------------\n'
    )

    ## TEST
    if do_test:
        N_test = N_trial - N_training
        HER.test_saccade(N_test, S_test, O_test, bias, verb, gate)
        print(conv_iter)

    ## PLOTS
    image_folder = 'HER/IMAGES'
    fontTitle = 26
    fontTicks = 22
    fontLabel = 22

    if do_weight_plots or do_error_plots:
        # savefig does not create missing directories.
        os.makedirs(image_folder, exist_ok=True)

    if do_weight_plots:
        # Tick positions at the centre of each of the S cells.
        cue_ticks = np.linspace(0.5, S - 0.5, S, endpoint=True)

        # plot of the memory weights, one panel per level
        fig1 = plt.figure(figsize=(10 * NL, 8))
        for level in range(NL):
            X = HER.H[level].X
            plt.subplot(1, NL, level + 1)
            plt.pcolor(np.flipud(X), edgecolors='k', linewidths=1)
            plt.set_cmap('Blues')
            plt.colorbar()
            tit = 'MEMORY WEIGHTS: Level ' + str(level)
            plt.title(tit, fontweight="bold", fontsize=fontTitle)
            plt.xticks(cue_ticks, cues_vec, fontsize=fontTicks)
            # Rows are flipped for display, so the labels are flipped too.
            plt.yticks(cue_ticks, np.flipud(cues_vec), fontsize=fontTicks)
        plt.show()
        savestr = image_folder + '/' + task + '_weights_memory.png'
        if gate == 'free':
            savestr = image_folder + '/' + task + '_weights_memory_nomemory.png'
        fig1.savefig(savestr)

        # plot of the prediction weights, one panel per level
        fig2 = plt.figure(figsize=(10 * NL, 8))
        for level in range(NL):
            W = HER.H[level].W
            n_cols = np.shape(W)[1]
            plt.subplot(1, NL, level + 1)
            plt.pcolor(np.flipud(W), edgecolors='k', linewidths=1)
            plt.set_cmap('Blues')
            plt.colorbar()
            tit = 'PREDICTION WEIGHTS: Level ' + str(level)
            plt.title(tit, fontweight="bold", fontsize=fontTitle)
            if level == 0:
                # Level 0: one tick per prediction unit.
                plt.xticks(np.linspace(0.5, n_cols - 0.5, P, endpoint=True),
                           pred_vec,
                           fontsize=fontTicks)
            else:
                # Higher levels: S evenly spaced ticks, each centred on an
                # equal-width group of columns.
                dx = n_cols / (2 * S)
                plt.xticks(np.linspace(dx, n_cols - dx, S, endpoint=True),
                           cues_vec,
                           fontsize=fontTicks)
            plt.yticks(cue_ticks, np.flipud(cues_vec), fontsize=fontTicks)
        plt.show()
        savestr = image_folder + '/' + task + '_weights_prediction.png'
        if gate == 'free':
            savestr = image_folder + '/' + task + '_weights_prediction_nomemory.png'
        fig2.savefig(savestr)

    if do_error_plots:

        N = len(E_fix)
        # Bin width is 2% of the training trials, but never below 1 so that
        # very short runs can still be binned.
        bin_size = max(1, round(N * 0.02))
        print(bin_size)
        trial_ticks = np.linspace(0, N, 5, endpoint=True)

        E_fix_bin = _bin_error_counts(E_fix, bin_size)
        E_fix_cum = np.cumsum(E_fix)

        E_go_bin = _bin_error_counts(E_go, bin_size)
        E_go_cum = np.cumsum(E_go)

        # FIX errors: binned counts (left) and cumulative error (right).
        # NOTE(review): x=225 is a hard-coded marker position and the y
        # limits are fixed; presumably tuned for the default run - confirm.
        figE_fix = plt.figure(figsize=(22, 8))
        plt.subplot(1, 2, 1)
        plt.bar(bin_size * np.arange(len(E_fix_bin)),
                E_fix_bin,
                width=bin_size,
                color='blue',
                edgecolor='black',
                label='fix',
                alpha=0.6)
        plt.axvline(x=225, linewidth=5, ls='dashed', color='orange')
        plt.axvline(x=0, linewidth=5, color='b')
        tit = 'SAS: Training Convergence for FIX'
        plt.title(tit, fontweight="bold", fontsize=fontTitle)
        plt.xlabel('Training Trials', fontsize=fontLabel)
        plt.ylabel('Number of Errors per bin', fontsize=fontLabel)
        plt.xticks(trial_ticks, fontsize=fontTicks)
        plt.yticks(fontsize=fontTicks)
        text = 'Bin = ' + str(bin_size)
        plt.ylim((0, 130))
        plt.figtext(x=0.37,
                    y=0.78,
                    s=text,
                    fontsize=fontLabel,
                    bbox={
                        'facecolor': 'white',
                        'alpha': 0.5,
                        'pad': 10
                    })

        plt.subplot(1, 2, 2)
        plt.axvline(x=225, linewidth=5, ls='dashed', color='orange')
        plt.plot(np.arange(N),
                 E_fix_cum,
                 color='blue',
                 linewidth=7,
                 label='fix',
                 alpha=0.6)
        plt.axvline(x=0, linewidth=5, color='b')
        tit = 'SAS: Cumulative Training Error for FIX'
        plt.title(tit, fontweight="bold", fontsize=fontTitle)
        plt.xticks(trial_ticks, fontsize=fontTicks)
        plt.yticks(fontsize=fontTicks)
        plt.xlabel('Training Trials', fontsize=fontLabel)
        plt.ylabel('Cumulative Error', fontsize=fontLabel)
        plt.ylim((0, 550))
        plt.show()

        savestr = image_folder + '/' + task + '_error_fix.png'
        if gate == 'free':
            savestr = image_folder + '/' + task + '_error_nomemory_fix.png'
        figE_fix.savefig(savestr)

        # GO errors: same layout; the solid line marks conv_iter when it is
        # non-zero.
        # NOTE(review): x=4100 is a hard-coded marker position - confirm.
        figE_go = plt.figure(figsize=(22, 8))
        plt.subplot(1, 2, 1)
        plt.bar(bin_size * np.arange(len(E_go_bin)),
                E_go_bin,
                width=bin_size,
                color='blue',
                edgecolor='black',
                alpha=0.6)
        plt.axvline(x=4100, linewidth=5, ls='dashed', color='green')
        if conv_iter != 0:
            plt.axvline(x=conv_iter, linewidth=5, color='b')
        tit = 'SAS: Training Convergence for GO'
        plt.title(tit, fontweight="bold", fontsize=fontTitle)
        plt.xlabel('Training Trials', fontsize=fontLabel)
        plt.ylabel('Number of Errors per bin', fontsize=fontLabel)
        plt.xticks(trial_ticks, fontsize=fontTicks)
        plt.yticks(fontsize=fontTicks)
        text = 'Bin = ' + str(bin_size)
        plt.figtext(x=0.37,
                    y=0.78,
                    s=text,
                    fontsize=fontLabel,
                    bbox={
                        'facecolor': 'white',
                        'alpha': 0.5,
                        'pad': 10
                    })

        plt.subplot(1, 2, 2)
        plt.axvline(x=4100, linewidth=5, ls='dashed', color='green')
        plt.plot(np.arange(N), E_go_cum, color='blue', linewidth=7, alpha=0.6)
        if conv_iter != 0:
            plt.axvline(x=conv_iter, linewidth=5, color='b')
        tit = 'SAS: Cumulative Training Error for GO'
        plt.title(tit, fontweight="bold", fontsize=fontTitle)
        plt.xticks(trial_ticks, fontsize=fontTicks)
        plt.yticks(fontsize=fontTicks)
        plt.xlabel('Training Trials', fontsize=fontLabel)
        plt.ylabel('Cumulative Error', fontsize=fontLabel)
        plt.show()

        savestr = image_folder + '/' + task + '_error_go.png'
        if gate == 'free':
            savestr = image_folder + '/' + task + '_error_nomemory_go.png'
        figE_go.savefig(savestr)