Example #1
def network_lesioning(filename):
    """ Lesion individual neurons for linking with gating patterns """

    results, savefile = load_and_replace_parameters(filename)
    model_module, sess, model, x, y, m, g, trial_mask, lid = load_tensorflow_model()

    lesioning_results = np.zeros([par['n_hidden'], par['n_tasks']])
    base_accuracies = np.zeros([par['n_tasks']])

    import stimulus
    stim = stimulus.MultiStimulus()

    for task in range(par['n_tasks']):

        _, stim_in, y_hat, mk, _ = stim.generate_trial(task)
        feed_dict = {x: stim_in, y: y_hat, g: par['gating'][0], m: mk}

        output = sess.run(model.output, feed_dict=feed_dict)
        acc = model_module.get_perf(y_hat, output, mk)

        print('\n' + '-' * 60)
        print('Base accuracy for task {}: {}'.format(task, acc))
        base_accuracies[task] = acc

        for n in range(par['n_hidden']):
            print('Lesioning neuron {}/{}'.format(n, par['n_hidden']),
                  end='\r')
            sess.run(model.lesion_neuron, feed_dict={lid: n})
            lesioning_results[n, task] = model_module.get_perf(
                y_hat, sess.run(model.output, feed_dict=feed_dict), mk)
            load_model_weights(sess)

    return base_accuracies, lesioning_results
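
A minimal usage sketch (the results file path below is hypothetical, and the accuracy-drop plot is illustrative rather than part of the original function):

import numpy as np
import matplotlib.pyplot as plt

base_acc, lesion_acc = network_lesioning('./records/example_run.pkl')  # hypothetical file

# Accuracy drop caused by lesioning each neuron, per task
acc_drop = base_acc[np.newaxis, :] - lesion_acc   # [n_hidden, n_tasks]
plt.imshow(acc_drop.T, aspect='auto', cmap='magma')
plt.xlabel('Neurons')
plt.ylabel('Tasks')
plt.colorbar()
plt.savefig('./records/lesion_accuracy_drop.png')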
Example #2
    def __init__(self, task_name_or_index):
        task_dict = {
            'go': 0,
            'rt_go': 1,
            'dly_go': 2,
            'anti-go': 3,
            'anti-rt_go': 4,
            'anti-dly_go': 5,
            'dm1': 6,
            'dm2': 7,
            'ctx_dm1': 8,
            'ctx_dm2': 9,
            'multsen_dm': 10,
            'dm1_dly': 11,
            'dm2_dly': 12,
            'ctx_dm1_dly': 13,
            'ctx_dm2_dly': 14,
            'multsen_dm_dly': 15,
            'dms': 16,
            'dmc': 17,
            'dnms': 18,
            'dnmc': 19
        }

        if isinstance(task_name_or_index, str):
            self.task_index = task_dict[task_name_or_index]
        else:
            self.task_index = task_name_or_index

        self.stim = stimulus.MultiStimulus()
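
Since the enclosing class name is not shown in this snippet, here is the same name-or-index resolution factored into a standalone helper (a sketch reusing the task_dict above):

def resolve_task_index(task_name_or_index, task_dict):
    """Accept either a task name (e.g. 'dly_go') or a raw integer index."""
    if isinstance(task_name_or_index, str):
        return task_dict[task_name_or_index]
    return task_name_or_index

# e.g. resolve_task_index('ctx_dm1', task_dict) == 8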
Example #3
def train_loop(model, opt, task_name_or_index, n_iters):
    task_dict = {
        'go': 0,
        'rt_go': 1,
        'dly_go': 2,
        'anti-go': 3,
        'anti-rt_go': 4,
        'anti-dly_go': 5,
        'dm1': 6,
        'dm2': 7,
        'ctx_dm1': 8,
        'ctx_dm2': 9,
        'multsen_dm': 10,
        'dm1_dly': 11,
        'dm2_dly': 12,
        'ctx_dm1_dly': 13,
        'ctx_dm2_dly': 14,
        'multsen_dm_dly': 15,
        'dms': 16,
        'dmc': 17,
        'dnms': 18,
        'dnmc': 19
    }

    if isinstance(task_name_or_index, str):
        task_index = task_dict[task_name_or_index]
    else:
        task_index = task_name_or_index

    stim = stimulus.MultiStimulus()

    for i in range(n_iters):
        """
        # add sanity checks here
        if i%50 == 0:
            w_rnn = model.get_layer('rnn').W_rnn
            pct_nonzero_weights = tf.math.count_nonzero(w_rnn) / (w_rnn.shape[0]*w_rnn.shape[1])
            diagonal_sum = tf.reduce_sum([w_rnn[i,i] for i in range(tf.math.reduce_min([w_rnn.shape[0], w_rnn.shape[1]]))])
            tf.print('diag sum: ', diagonal_sum)
        """

        name, input_data, ytrue_data, dead_time_mask, reward_data = \
            stim.generate_trial(task_index)
        input_data = tf.constant(input_data, dtype=tf.float32)
        ytrue_data = tf.constant(ytrue_data, dtype=tf.float32)
        dead_time_mask = tf.constant(dead_time_mask, dtype=tf.float32)
        loss, acc = train_step(model,
                               opt,
                               input_data,
                               ytrue_data,
                               mask=dead_time_mask)
        if i % 50 == 0:
            tf.print('Iter: ', i, '| Loss: ', loss, '| Acc: ', acc, '\n')
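
train_step is not defined in this example. A minimal sketch of what it might look like, assuming model maps [time, batch, n_input] inputs to per-time-step logits shaped like ytrue_data, with loss and accuracy both weighted by the dead-time mask:

import tensorflow as tf

def train_step(model, opt, x, y_true, mask):
    # One gradient step; `mask` zeroes out dead time in loss and accuracy.
    with tf.GradientTape() as tape:
        logits = model(x, training=True)
        loss_per_step = tf.nn.softmax_cross_entropy_with_logits(
            labels=y_true, logits=logits)
        loss = tf.reduce_sum(loss_per_step * mask) / tf.reduce_sum(mask)
    grads = tape.gradient(loss, model.trainable_variables)
    opt.apply_gradients(zip(grads, model.trainable_variables))

    # Masked accuracy: fraction of unmasked time steps with correct argmax
    match = tf.cast(tf.equal(tf.argmax(logits, axis=-1),
                             tf.argmax(y_true, axis=-1)), tf.float32)
    acc = tf.reduce_sum(match * mask) / tf.reduce_sum(mask)
    return loss, acc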
Example #4
def task_variance_analysis(filename, plot=False):

    results, savefile = load_and_replace_parameters(filename)
    model_module, sess, model, x, y, m, g, trial_mask, lid = load_tensorflow_model()

    lesioning_results = np.zeros([par['n_hidden'], par['n_tasks']])
    base_accuracies = np.zeros([par['n_tasks']])

    import stimulus
    stim = stimulus.MultiStimulus()

    task_variance = np.zeros([par['n_tasks'], par['n_hidden']])
    for task in range(par['n_tasks']):

        _, stim_in, y_hat, mk, _ = stim.generate_trial(task)
        feed_dict = {x: stim_in, y: y_hat, g: par['gating'][0], m: mk}

        output, h = sess.run([model.output, model.h], feed_dict=feed_dict)
        acc = model_module.get_perf(y_hat, output, mk)
        h = np.array(h)[par['dead_time'] // par['dt']:, :, :]  # [time, trials, neurons], e.g. [100, 256, 500]

        task_variance[task, :] = np.mean(
            np.mean(np.square(h - np.mean(h, axis=1, keepdims=True)), axis=1),
            axis=0)

    if plot:
        plt.imshow(task_variance / np.amax(task_variance),
                   aspect='auto',
                   cmap='magma')
        plt.colorbar()
        plt.ylabel('Tasks')
        plt.yticks(np.arange(20))
        plt.xlabel('Neurons')
        plt.xticks(np.arange(0, 500, 10))
        plt.title('Normalized Task Variance')
        plt.savefig('./records/task_variance.png')
        plt.clf()
        plt.close()

    return task_variance
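
A sketch of using the returned task_variance to flag task-specialized neurons (the file path and the 0.8 threshold are illustrative, not from the original code):

import numpy as np

tv = task_variance_analysis('./records/example_run.pkl')   # hypothetical file
tv_norm = tv / (np.sum(tv, axis=0, keepdims=True) + 1e-9)  # per-neuron share of variance
specialized = np.where(np.max(tv_norm, axis=0) > 0.8)[0]
print('Neurons with >80% of their variance on one task:', specialized)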
Example #5
def EWC_analysis(filename):
    """ Lesion individual neurons for linking with gating patterns """

    results, savefile = load_and_replace_parameters(filename)
    update_parameters({'stabilization': 'EWC'})
    update_parameters({'batch_size': 8})

    model_module, sess, model, x, y, m, g, trial_mask, lid = load_tensorflow_model()
    EWC_results = []

    import stimulus
    stim = stimulus.MultiStimulus()

    import time
    for task in range(par['n_tasks']):
        print('EWC analysis for task {}.'.format(task))
        sess.run(model.reset_big_omega_vars)

        for n in range(par['EWC_fisher_num_batches']):
            print('EWC batch {}'.format(n), end='\r')
            _, stim_in, y_hat, mk, _ = stim.generate_trial(task)
            _, big_omegas = sess.run(
                [model.update_big_omega, model.big_omega_var],
                feed_dict={
                    x: stim_in,
                    y: y_hat,
                    g: par['gating'][0],
                    m: mk
                })

        EWC_results.append(big_omegas)

    with open('./records/EWC_results_weights_for_multistim_LSTM_without_WTA_gamma0_v0.pkl', 'wb') as f:
        pickle.dump(EWC_results, f)
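
Reading the saved importances back (a sketch, assuming each entry is the dict of variable name to importance array returned by sess.run on model.big_omega_var):

import pickle

with open('./records/EWC_results_weights_for_multistim_LSTM_without_WTA_gamma0_v0.pkl', 'rb') as f:
    EWC_results = pickle.load(f)

for task, omegas in enumerate(EWC_results):
    total = sum(float(v.sum()) for v in omegas.values())
    print('Task {}: total importance {:.3f}'.format(task, total))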
Example #6
def main():

    os.environ['CUDA_VISIBLE_DEVICES'] = '3'

    tf.reset_default_graph()

    x = tf.placeholder(tf.float32, [par['batch_size'], par['forward_shape'][0]], 'stim')
    y = tf.placeholder(tf.float32, [par['batch_size'], par['n_output']], 'out')
    info = tf.placeholder(tf.float32, [par['batch_size'], 1], 'info')
    alpha = tf.placeholder(tf.float32, [], 'alpha')

    #with tf.device('/gpu:0'):
    model = Model(x, y, info, alpha)
    #model = Model(x, y, alpha)

    stim = stimulus.MultiStimulus()

    iteration = []
    accuracy = []

    with tf.Session() as sess:

        sess.run(tf.global_variables_initializer())

        train = True
        while train:
            for i in range(par['n_train_batches_gen']):

                # if i < 5000:
                #     par['subset_loc'] = False
                #     alpha_val = 0
                # if i%2 == 0:
                #     par['subset_loc'] = True
                #     alpha_val = 0.02
                # else:
                #     par['subset_loc'] = False
                #     alpha_val = 0
                alpha_val = 0.1
                name, inputs, neural_inputs, outputs = stim.generate_trial(0, False, False)
                if par['subset_dirs']:
                    task_info = np.float32(inputs[:,2]<6) * alpha_val
                elif par['subset_loc']:
                    task_info = np.float32(np.array(inputs[:,0]<5) * np.array(inputs[:,1]<5)) * alpha_val
                else:
                    # Guard against a NameError when neither subset flag is set
                    task_info = np.zeros(par['batch_size'], dtype=np.float32)
                task_info = task_info[:,np.newaxis]

                feed_dict = {x:neural_inputs, y:outputs, info:task_info, alpha:alpha_val}

                if i%2 == 0:
                    _, loss, recon_loss, latent_loss, task_loss, y_hat, x_hat, mu, sigma, latent_sample, weight = sess.run([model.train_op_recon, model.task_loss, model.recon_loss, model.latent_loss, model.task_loss, model.y, model.x_hat, model.mu, model.si, model.latent_sample, model.var_dict['W_layer_out']], feed_dict=feed_dict)
                else:
                    _, loss, recon_loss, latent_loss, task_loss, y_hat, x_hat, mu, sigma, latent_sample, weight = sess.run([model.train_op_task, model.task_loss, model.recon_loss, model.latent_loss, model.task_loss, model.y, model.x_hat, model.mu, model.si, model.latent_sample, model.var_dict['W_layer_out']], feed_dict=feed_dict)



                if i%100 == 0:
                    acc = get_perf(outputs, y_hat)
                    #ind = np.intersect1d(np.argwhere(inputs[:,0]<5), np.argwhere(inputs[:,1]<5))
                    if par['subset_dirs']:
                        ind = np.where(inputs[:,2]<6)[0]
                    elif par['subset_loc']:
                        ind = np.intersect1d(np.argwhere(inputs[:,0]<5), np.argwhere(inputs[:,1]<5))
                    acc1 = get_perf(outputs[ind],y_hat[ind])
                    ind2 = np.setdiff1d(np.arange(256), ind)
                    acc2 = get_perf(outputs[ind2],y_hat[ind2])
                    print('{} | Reconstr. Loss: {:.3f} | Latent Loss: {:.3f} | Task Loss: {:.3f} | Accuracy: {:.3f} | Accuracy1: {:.3f} | Accuracy2: {:.3f} | <Sig>: {:.3f} +/- {:.3f}'.format( \
                    i, recon_loss, latent_loss, task_loss, acc, acc1, acc2, np.mean(sigma), np.std(sigma)))
                    iteration.append(i)
                    accuracy.append(acc)

                if i%100 == 1:
                    acc = get_perf(outputs, y_hat)
                    #ind = np.intersect1d(np.argwhere(inputs[:,0]<5), np.argwhere(inputs[:,1]<5))
                    if par['subset_dirs']:
                        ind = np.where(inputs[:,2]<6)[0]
                    elif par['subset_loc']:
                        ind = np.intersect1d(np.argwhere(inputs[:,0]<5), np.argwhere(inputs[:,1]<5))
                    acc1 = get_perf(outputs[ind],y_hat[ind])
                    ind2 = np.setdiff1d(np.arange(256), ind)
                    acc2 = get_perf(outputs[ind2],y_hat[ind2])
                    print('{} | Reconstr. Loss: {:.3f} | Latent Loss: {:.3f} | Task Loss: {:.3f} | Accuracy: {:.3f} | Accuracy1: {:.3f} | Accuracy2: {:.3f} | <Sig>: {:.3f} +/- {:.3f}'.format( \
                    i, recon_loss, latent_loss, task_loss, acc, acc1, acc2, np.mean(sigma), np.std(sigma)))
                    iteration.append(i)
                    accuracy.append(acc)

                if i%500 == 0:
                    # ind1 = all trials, ind2 = trials within the quadrant, ind3 = trials outside the quadrant
                    ind1 = np.arange(256)
                    #ind2 = np.intersect1d(np.argwhere(inputs[:,0]<5), np.argwhere(inputs[:,1]<5))
                    if par['subset_dirs']:
                        ind2 = np.where(inputs[:,2]<5)[0]
                    elif par['subset_loc']:
                        ind2 = np.intersect1d(np.argwhere(inputs[:,0]<5), np.argwhere(inputs[:,1]<5))
                    ind3 = np.setdiff1d(np.arange(256), ind2)

                    index = [ind1, ind2, ind3]
                    for ind in index:
                        correlation = np.zeros((par['n_latent'], 7))
                        for l in range(par['n_latent']):
                            # for latent_sample in latent_sample:
                            correlation[l,0] += pearsonr(latent_sample[ind,l], inputs[ind,0])[0] #x
                            correlation[l,1] += pearsonr(latent_sample[ind,l], inputs[ind,1])[0] #y
                            correlation[l,2] += pearsonr(latent_sample[ind,l], inputs[ind,2])[0] #dir_ind
                            correlation[l,3] += pearsonr(latent_sample[ind,l], inputs[ind,3])[0] #m
                            correlation[l,4] += pearsonr(latent_sample[ind,l], inputs[ind,4])[0] #fix
                            correlation[l,5] += pearsonr(latent_sample[ind,l], outputs[ind,0])[0] #motion_x
                            correlation[l,6] += pearsonr(latent_sample[ind,l], outputs[ind,1])[0] #motion_y
                        print(['loc_x','loc_y','dir','m','fix','mot_x','mot_y'])
                        print(np.round(correlation,3))
                        print("")

                if i%500 == 1:
                    # ind1 = all trials, ind2 = trials within the quadrant, ind3 = trials outside the quadrant
                    ind1 = np.arange(256)
                    #ind2 = np.intersect1d(np.argwhere(inputs[:,0]<5), np.argwhere(inputs[:,1]<5))
                    if par['subset_dirs']:
                        ind2 = np.where(inputs[:,2]<5)[0]
                    elif par['subset_loc']:
                        ind2 = np.intersect1d(np.argwhere(inputs[:,0]<5), np.argwhere(inputs[:,1]<5))
                    ind3 = np.setdiff1d(np.arange(256), ind2)

                    index = [ind1, ind2, ind3]
                    for ind in index:
                        correlation = np.zeros((par['n_latent'], 7))
                        for l in range(par['n_latent']):
                            # for latent_sample in latent_sample:
                            correlation[l,0] += pearsonr(latent_sample[ind,l], inputs[ind,0])[0] #x
                            correlation[l,1] += pearsonr(latent_sample[ind,l], inputs[ind,1])[0] #y
                            correlation[l,2] += pearsonr(latent_sample[ind,l], inputs[ind,2])[0] #dir_ind
                            correlation[l,3] += pearsonr(latent_sample[ind,l], inputs[ind,3])[0] #m
                            correlation[l,4] += pearsonr(latent_sample[ind,l], inputs[ind,4])[0] #fix
                            correlation[l,5] += pearsonr(latent_sample[ind,l], outputs[ind,0])[0] #motion_x
                            correlation[l,6] += pearsonr(latent_sample[ind,l], outputs[ind,1])[0] #motion_y
                        print(['loc_x','loc_y','dir','m','fix','mot_x','mot_y'])
                        print(np.round(correlation,3))
                        print("")

                    print("example output") #outside of the quadrant
                    for q in range(10):
                        print(np.round(outputs[ind3][q],3), np.round(y_hat[ind3][q],3))
                    print("")

                    print("Weight: latent x 2")
                    print(weight)
                    print("")

                    print("Activity: latent")
                    m = []
                    for motion in range(par['num_motion_dirs']):
                        m.append(np.where(inputs[:,2]==motion))
                    temp = np.zeros((par['n_latent'],8))
                    for dir in range(par['num_motion_dirs']):
                        temp[:,dir] = np.round(np.mean(latent_sample[m[dir]], axis=0),3)
                    print(np.round(temp,3))
                    print("")

                    plt.figure()
                    plt.imshow(temp, cmap='inferno')
                    plt.colorbar()
                    plt.savefig('./savedir/latent_act_bottom'+str(i)+'.png')
                    plt.close()

                    ind = np.where(inputs[:,2]==3)[0]
                    act = np.zeros((par['n_neurons'],par['n_neurons']))
                    act2 = np.zeros((par['n_neurons'],par['n_neurons']))
                    for n in range(par['n_neurons']):
                        for m in range(par['n_neurons']):
                            temp = np.intersect1d(np.where(inputs[:,0]==n)[0],np.where(inputs[:,1]==m)[0])
                            temp2 = np.intersect1d(ind, temp)
                            act[n,m] = np.mean(latent_sample[temp2,0])
                            act2[n,m] = np.mean(latent_sample[temp2,1])
                    plt.figure()
                    plt.subplot(1,2,1)
                    plt.imshow(act, cmap='inferno')
                    plt.subplot(1,2,2)
                    plt.imshow(act2, cmap='inferno')
                    plt.colorbar()
                    plt.savefig('./savedir/latent_activity_subset_dir_'+str(i)+'.png')
                    #plt.show()
                    #plt.close()



                    var_dict = sess.run(model.generative_vars)
                    with open('./savedir/generative_var_dict_trial.pkl', 'wb') as vf:
                        pickle.dump(var_dict, vf)

                    # visualization(inputs, neural_inputs)

                    for b in range(10):

                        output_string = ''
                        output_string += '\n--- {} ---\n'.format(b)
                        output_string += 'mu:  {}\n'.format(str(mu[b,:]))
                        output_string += 'sig: {}\n'.format(str(sigma[b,:]))

                        if b == 0:
                            rw = 'w'
                        else:
                            rw = 'a'

                        with open('./savedir/recon_data_iter{}.txt'.format(i), rw) as f:
                            f.write(output_string)


                        fig, ax = plt.subplots(2,2,figsize=[8,8])
                        for a in range(2):
                            inp = np.sum(np.reshape(neural_inputs[b], [9,10,10]), axis=a)
                            hat = np.sum(np.reshape(x_hat[b], [9,10,10]), axis=a)

                            ax[a,0].set_title('Actual (Axis {})'.format(a))
                            ax[a,0].imshow(inp, clim=[0,1])
                            ax[a,1].set_title('Reconstructed (Axis {})'.format(a))
                            ax[a,1].imshow(hat, clim=[0,1])

                        plt.savefig('./savedir/recon_iter{}_trial{}.png'.format(i,b))
                        plt.close(fig)

            print("TRAINING ON SUBSET LOC FINISHED\n")
            print("TESTING ON ALL QUADRANTS")
            test(stim, model, 0, sess, x, y, ff=False, gff=True)

            if par['dynamic_training']:
                inp = input("Would you like to continue training? (y/n)\n")
                train = True if (inp in ["y","Y","Yes","yes","True"]) else False

                if train:
                    par['n_train_batches_gen'] = int(input("Enter how many iterations: "))
            else:
                train = False
            print("")

        print("STARTING TO TEST TRAINING ON ALL QUADRANTS")


        for i in range(par['n_train_batches_gen'],par['n_train_batches_gen']+1500):

            name, inputs, neural_inputs, outputs = stim.generate_trial(0, False, False)
            task_info = np.ones([par['batch_size'],1])
            feed_dict = {x:neural_inputs, y:outputs, info:task_info, alpha:0.05}
            _, _, loss, recon_loss, latent_loss, y_hat, x_hat, mu, sigma, latent_sample = sess.run([model.train_op_task, model.train_op_recon, model.task_loss, \
                model.recon_loss, model.latent_loss, model.y, model.x_hat, model.mu, model.si, model.latent_sample], feed_dict=feed_dict)


            if i%50 == 0:
                ind = np.intersect1d(np.argwhere(inputs[:,3]==1), np.argwhere(inputs[:,4]==0))
                acc = get_perf(outputs,y_hat)
                print('{} | Reconstr. Loss: {:.3f} | Latent Loss: {:.3f} | Accuracy: {:.3f} | <Sig>: {:.3f} +/- {:.3f}'.format( \
                    i, recon_loss, latent_loss, acc, np.mean(sigma), np.std(sigma)))
                iteration.append(i)
                accuracy.append(acc)

        # print(iteration)
        # print(accuracy)
        plt.figure()
        plt.plot(iteration, accuracy, linestyle='-', marker='o', linewidth=2)
        plt.savefig('./savedir/gen_model_learning_curve.png')
        plt.show()
        plt.close()






    print('Complete.')
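
The latent/input correlation block above is repeated four times; a sketch of the same computation factored into a helper (names mirror the loop; pearsonr is scipy.stats.pearsonr):

import numpy as np
from scipy.stats import pearsonr

def latent_correlations(latent_sample, inputs, outputs, ind, n_latent):
    # Correlate each latent unit with the five input fields
    # (loc_x, loc_y, dir, m, fix) and the two output fields (mot_x, mot_y)
    targets = [inputs[ind, 0], inputs[ind, 1], inputs[ind, 2],
               inputs[ind, 3], inputs[ind, 4],
               outputs[ind, 0], outputs[ind, 1]]
    correlation = np.zeros((n_latent, len(targets)))
    for l in range(n_latent):
        for j, t in enumerate(targets):
            correlation[l, j] = pearsonr(latent_sample[ind, l], t)[0]
    return correlation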
Example #7
def main(gpu_id=None, save_fn='test.pkl'):

    if gpu_id is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu_id
    """
    Reset TensorFlow before running anything
    """
    tf.reset_default_graph()
    """
    Create the stimulus class to generate trial parameters and input activity
    """
    stim = stimulus.MultiStimulus()
    """
    Define all placeholders
    """
    # Note: `mask` appears twice in this unpacking; the second value wins.
    x, target, pol_target, mask, pred_val, actual_action, advantage, mask, gating = generate_placeholders()

    config = tf.ConfigProto()
    #config.gpu_options.allow_growth=True

    print_key_params()

    with tf.Session(config=config) as sess:

        device = '/cpu:0' if gpu_id is None else '/gpu:0'
        with tf.device(device):
            model = Model(x, target, pol_target, gating, pred_val,
                          actual_action, advantage, mask)

        sess.run(tf.global_variables_initializer())

        # keep track of the model performance across training
        model_performance = {
            'reward': [],
            'entropy_loss': [],
            'val_loss': [],
            'pol_loss': [],
            'spike_loss': [],
            'trial': [],
            'task': []
        }
        reward_matrix = np.zeros((par['n_tasks'], par['n_tasks']))
        accuracy_full = []

        sess.run(model.reset_prev_vars)

        for task in range(2, par['n_tasks']):
            #for task in [0,3]:
            accuracy_above_threshold = 0

            task_start_time = time.time()

            for i in range(par['n_train_batches']):

                # make batch of training data
                name, input_data, desired_output, mk, reward_data = stim.generate_trial(
                    task)
                mk = mk[..., np.newaxis]
                """
                Run the model
                """
                pol_out_list, val_out_list, h_list, action_list, mask_list, reward_list = sess.run([model.pol_out, model.val_out, model.h, model.action, \
                    model.mask, model.reward], {x: input_data, target: reward_data, mask: mk, gating:par['gating'][task]})
                """
                Unpack all lists, calculate predicted value and advantage functions
                """
                val_out, reward, adv, act, predicted_val, stacked_mask = stack_vars(
                    pol_out_list, val_out_list, reward_list, action_list,
                    mask_list, mk)

                T = 4
                #print(reward.shape)
                #print(act.shape)
                #print(stacked_mask.shape)
                pol_out_stacked = np.stack(pol_out_list)

                mk_pol = stacked_mask * mk * np.abs(adv) * 10.
                pol_out_sm_neg = pol_out_stacked - 999 * act
                pol_out_stacked = (adv >= 0) * pol_out_stacked + (
                    adv < 0) * pol_out_sm_neg
                pol_out_sm = np.exp(pol_out_stacked / T) / np.sum(
                    np.exp(pol_out_stacked / T), axis=2, keepdims=True)


                _, pol_d_out_list = sess.run([model.train_op_d, model.pol_d_out], {x: input_data, target: reward_data, \
                    pol_target: pol_out_sm, mask: mk_pol, gating:par['gating'][task]})

                acc_d = get_perf(desired_output, pol_d_out_list, mk)
                """
                Calculate and apply gradients
                """
                if par['stabilization'] == 'pathint':
                    _, _, pol_loss, val_loss, aux_loss, spike_loss, ent_loss = sess.run([model.train_op, \
                         model.update_current_reward, model.pol_loss, model.val_loss, model.aux_loss, model.spike_loss, \
                        model.entropy_loss], feed_dict = {x:input_data, target:reward_data, \
                        gating:par['gating'][task], mask:mk, pred_val: predicted_val, actual_action: act, advantage:adv})
                    if i > 0:
                        sess.run([model.update_small_omega])
                    sess.run([model.update_previous_reward])

                elif par['stabilization'] == 'EWC':
                    _, pol_loss,val_loss, aux_loss, spike_loss, ent_loss = sess.run([model.train_op, model.pol_loss, \
                        model.val_loss, model.aux_loss, model.spike_loss, model.entropy_loss], feed_dict = \
                        {x:input_data, target:reward_data, gating:par['gating'][task], mask:mk, pred_val: predicted_val, \
                        actual_action: act, advantage:adv})

                acc = np.mean(np.sum(reward > 0, axis=0))
                if acc > 0.99:
                    accuracy_above_threshold += 1
                if accuracy_above_threshold >= 3000:
                    break

                sess.run([model.reset_rnn_weights])
                if i % 10 == 0:
                    #print('Iter ', i, 'Task name ', name, ' accuracy', acc, ' aux loss', aux_loss, 'spike_loss', spike_loss, ' h > 0 ', above_zero, 'mean h', np.mean(h_stacked))
                    print('Iter ', i, 'Task name ', name, ' accuracy', acc,
                          acc_d, ' aux loss', aux_loss, 'time ',
                          np.around(time.time() - task_start_time))

            # Update big omegas, and reset other values before starting new task
            if par['stabilization'] == 'pathint':
                """
                _, reset_masks = sess.run([model.reset_shunted_weights, model.reset_masks], feed_dict = \
                    {x:input_data, target: reward_data, gating:par['gating'][task], mask:mk})
                for i in range(len(reset_masks)):
                    print('Mean reset masks ', np.mean(reset_masks[i]))
                """
                big_omegas = sess.run(
                    [model.update_big_omega, model.big_omega_var])

            elif par['stabilization'] == 'EWC':
                for n in range(par['EWC_fisher_num_batches']):
                    name, input_data, _, mk, reward_data = stim.generate_trial(
                        task)
                    mk = mk[..., np.newaxis]
                    big_omegas = sess.run([model.update_big_omega,model.big_omega_var], feed_dict = \
                        {x:input_data, target: reward_data, gating:par['gating'][task], mask:mk})

            # Test all tasks at the end of each learning session
            num_reps = 10
            for (task_prime, r) in product(range(par['n_tasks']),
                                           range(num_reps)):

                # make batch of training data
                name, input_data, _, mk, reward_data = stim.generate_trial(
                    task_prime)
                mk = mk[..., np.newaxis]

                reward_list = sess.run([model.reward], feed_dict = {x:input_data, target: reward_data, \
                    gating:par['gating'][task_prime], mask:mk})
                # TODO: figure out what's with the extra dimension at index 0 in reward
                reward = np.squeeze(np.stack(reward_list))
                reward_matrix[task, task_prime] += np.mean(
                    np.sum(reward > 0, axis=0)) / num_reps

            print('Accuracy grid after task {}:'.format(task))
            print(reward_matrix[task, :])
            results = {'reward_matrix': reward_matrix, 'par': par}
            pickle.dump(results, open(par['save_dir'] + save_fn, 'wb'))
            print('Analysis results saved in ', save_fn)
            print('')

            # Reset the Adam Optimizer, and set the previous parameter values to their current values
            sess.run(model.reset_adam_op)
            sess.run(model.reset_prev_vars)
            if par['stabilization'] == 'pathint':
                sess.run(model.reset_small_omega)
Example #8
def reinforcement_learning(save_fn='test.pkl', gpu_id=None):
    """ Run reinforcement learning training """

    # Isolate requested GPU
    if gpu_id is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu_id

    # Reset Tensorflow graph before running anything
    tf.reset_default_graph()

    # Define all placeholders
    # Note: `mask` appears twice in this unpacking; the second value wins.
    x, target, mask, pred_val, actual_action, advantage, mask = generate_placeholders()

    # Set up stimulus and accuracy recording
    stim = stimulus.MultiStimulus()
    accuracy_full = []
    accuracy_grid = np.zeros([par['n_tasks'], par['n_tasks']])
    full_activity_list = []
    model_performance = {
        'reward': [],
        'entropy_loss': [],
        'val_loss': [],
        'pol_loss': [],
        'spike_loss': [],
        'trial': [],
        'task': []
    }
    reward_matrix = np.zeros((par['n_tasks'], par['n_tasks']))

    # Display relevant parameters
    print_key_info()

    # Start Tensorflow session
    with tf.Session() as sess:

        # Select CPU or GPU
        device = '/cpu:0' if gpu_id is None else '/gpu:0'
        with tf.device(device):
            # Check order against args unpacking in model if editing
            model = Model(x, target, mask)

        # Initialize variables and start the timer
        sess.run(tf.global_variables_initializer())
        t_start = time.time()
        sess.run(model.reset_prev_vars)

        # Begin training loop, iterating over tasks
        #for task in range(par['n_tasks']):
        for task in [0, 3, 5, 2]:
            accuracy_iter = []
            task_start_time = time.time()

            for i in range(par['n_train_batches']):

                # Generate a batch of stimulus data for training
                name, input_data, _, mk, reward_data = stim.generate_trial(
                    task)
                mk = mk[..., np.newaxis]

                # Put together the feed dictionary
                feed_dict = {x: input_data, target: reward_data, mask: mk}

                # Calculate and apply gradients
                if par['stabilization'] == 'pathint':
                    _, _, _, pol_loss, val_loss, aux_loss, spike_loss, ent_loss, pred_err, stim_pred_err, \
                        rew_pred_err, act_pred_err, h_list, reward_list, pred_loss, expected_reward, actual_reward = \
                        sess.run([model.train_op, model.update_current_reward, model.update_small_omega, model.pol_loss, model.val_loss, \
                        model.aux_loss, model.spike_loss, model.entropy_loss, model.total_pred_error, model.stim_pred_error, model.rew_pred_error, model.act_pred_error, \
                        model.h, model.reward, model.pred_loss, model.expected_reward_vector, model.actual_reward_vector], feed_dict = feed_dict)
                    if i > 0:
                        sess.run([model.update_small_omega])
                    sess.run([model.update_previous_reward])
                elif par['stabilization'] == 'EWC':
                    _, _, pol_loss,val_loss, aux_loss, spike_loss, ent_loss, pred_err, stim_pred_err, rew_pred_err, act_pred_err, h_list, reward_list = \
                        sess.run([model.train_op, model.update_current_reward, model.pol_loss, model.val_loss, \
                        model.aux_loss, model.spike_loss, model.entropy_loss, model.total_pred_error, model.stim_pred_error, model.rew_pred_error, model.act_pred_error, \
                        model.h, model.reward], feed_dict = feed_dict)

                # Record accuracies
                reward = np.stack(reward_list)
                acc = np.mean(np.sum(reward > 0, axis=0))
                accuracy_iter.append(acc)
                if i > 5000:
                    if np.mean(accuracy_iter[-5000:]) > 0.98 or (
                            i > 25000 and np.mean(accuracy_iter[-20:]) > 0.95):
                        print('Accuracy reached threshold')
                        break

                # Display network performance
                if i % 200 == 0:

                    fig, ax = plt.subplots(1, 3, figsize=[24, 8])
                    im0 = ax[0].imshow(expected_reward[:, :, 0],
                                       aspect='auto',
                                       clim=(-np.abs(expected_reward).max(),
                                             np.abs(expected_reward).max()))
                    ax[0].set_title('Expected Reward')
                    im1 = ax[1].imshow(actual_reward[:, :, 0],
                                       aspect='auto',
                                       clim=(par['fix_break_penalty'],
                                             par['correct_choice_reward']))
                    ax[1].set_title('Actual Reward')
                    diff = expected_reward[:, :, 0] - actual_reward[:, :, 0]
                    im2 = ax[2].imshow(diff,
                                       aspect='auto',
                                       clim=(-np.abs(diff).max(),
                                             np.abs(diff).max()))
                    ax[2].set_title('Expected - Actual')
                    fig.colorbar(im0,
                                 ax=ax[0],
                                 orientation='horizontal',
                                 ticks=[
                                     -np.abs(expected_reward).max(), 0,
                                     np.abs(expected_reward).max()
                                 ])
                    fig.colorbar(im1,
                                 ax=ax[1],
                                 orientation='horizontal',
                                 ticks=[
                                     par['fix_break_penalty'], 0,
                                     par['correct_choice_reward']
                                 ])
                    fig.colorbar(
                        im2,
                        ax=ax[2],
                        orientation='horizontal',
                        ticks=[-np.abs(diff).max(), 0,
                               np.abs(diff).max()])
                    plt.savefig('./savedir/reward_task{}_iter{}_v2.png'.format(
                        task, i))
                    plt.clf()
                    plt.close()

                    pe = [
                        float('{:7.5f}'.format(np.mean(pred_err[i])))
                        for i in range(len(pred_err))
                    ]
                    spe = [
                        float('{:7.5f}'.format(np.mean(stim_pred_err[i])))
                        for i in range(len(stim_pred_err))
                    ]
                    rpe = [
                        float('{:9.7f}'.format(np.mean(rew_pred_err[i])))
                        for i in range(len(rew_pred_err))
                    ]
                    ape = [
                        float('{:7.5f}'.format(np.mean(act_pred_err[i])))
                        for i in range(len(act_pred_err))
                    ]

                    print('Iter: {:>5} | Task: {} | Accuracy: {:5.3f} | Aux Loss: {:7.5f} | Mean h: {:8.5f} | Time: {}'.format(\
                        i, name, acc, aux_loss, np.mean(np.stack(h_list)), int(np.around(time.time() - task_start_time))))
                    print(
                        '            | Pred Error: {:7.5f} | Total PE: {} | Stim PE: {} | Rew PE: {} | Act PE: {}'
                        .format(pred_loss, pe, spe, rpe, ape))

                    #print('Iter ', i, 'Task name ', name, ' accuracy', acc, ' aux loss', aux_loss, ' pred error', pe, 'pred loss', pred_loss, \
                    #'mean h', np.mean(np.stack(h_list)), 'time ', np.around(time.time() - task_start_time))

            # Update big omegas, and reset other values before starting new task
            if par['stabilization'] == 'pathint':
                big_omegas = sess.run(
                    [model.update_big_omega, model.big_omega_var])

            elif par['stabilization'] == 'EWC':
                for n in range(par['EWC_fisher_num_batches']):
                    name, input_data, _, mk, reward_data = stim.generate_trial(
                        task)
                    mk = mk[..., np.newaxis]
                    big_omegas = sess.run([model.update_big_omega,model.big_omega_var], feed_dict = \
                        {x:input_data, target: reward_data, mask:mk})

            # Test all tasks at the end of each learning session
            num_reps = 10
            task_activity_list = []
            for task_prime in range(task + 1):
                for r in range(num_reps):

                    # make batch of training data
                    name, input_data, _, mk, reward_data = stim.generate_trial(
                        task_prime)
                    mk = mk[..., np.newaxis]

                    reward_list, h = sess.run([model.reward, model.h],
                                              feed_dict={
                                                  x: input_data,
                                                  target: reward_data,
                                                  mask: mk
                                              })

                    reward = np.squeeze(np.stack(reward_list))
                    reward_matrix[task, task_prime] += np.mean(
                        np.sum(reward > 0, axis=0)) / num_reps

                # Record network activity
                task_activity_list.append(h)

            # Aggregate network activity after testing each task set
            # Each of the [n_tasks] elements is [tasks tested, time steps, batch size, hidden size]
            full_activity_list.append(task_activity_list)

            print('Accuracy grid after task {}:'.format(task))
            print(reward_matrix[task, :])

            results = {
                'reward_matrix': reward_matrix,
                'par': par,
                'activity': full_activity_list
            }
            pickle.dump(results, open(par['save_dir'] + save_fn, 'wb'))
            print('Analysis results saved in', save_fn)
            print('')

            # Reset the Adam Optimizer, and set the previous parameter values to their current values
            sess.run(model.reset_adam_op)
            sess.run(model.reset_prev_vars)
            if par['stabilization'] == 'pathint':
                sess.run(model.reset_small_omega)

    print('\nModel execution complete. (Reinforcement)')
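
A sketch for visualizing the saved reward matrix after training (path and keys follow the pickle written above; 'test.pkl' is the default save_fn):

import pickle
import matplotlib.pyplot as plt

with open(par['save_dir'] + 'test.pkl', 'rb') as f:
    results = pickle.load(f)

plt.imshow(results['reward_matrix'], cmap='magma', vmin=0, vmax=1)
plt.xlabel('Task tested')
plt.ylabel('Task trained')
plt.title('Reward matrix (mean fraction rewarded)')
plt.colorbar()
plt.savefig('./savedir/reward_matrix.png')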
Example #9
def supervised_learning(save_fn='test.pkl', gpu_id=None):
    """ Run supervised learning training """

    # Isolate requested GPU
    if gpu_id is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu_id

    # Reset Tensorflow graph before running anything
    tf.reset_default_graph()

    # Define all placeholders
    x = tf.placeholder(
        tf.float32, [par['num_time_steps'], par['batch_size'], par['n_input']],
        'stim')
    y = tf.placeholder(
        tf.float32,
        [par['num_time_steps'], par['batch_size'], par['n_output']], 'out')
    m = tf.placeholder(tf.float32, [par['num_time_steps'], par['batch_size']],
                       'mask')
    g = tf.placeholder(tf.float32, [par['n_hidden']], 'gating')

    # Set up stimulus and accuracy recording
    stim = stimulus.MultiStimulus()
    accuracy_full = []
    accuracy_grid = np.zeros([par['n_tasks'], par['n_tasks']])
    full_activity_list = []

    # Display relevant parameters
    print_key_info()

    # Start Tensorflow session
    with tf.Session() as sess:

        # Select CPU or GPU
        device = '/cpu:0' if gpu_id is None else '/gpu:0'
        with tf.device(device):
            model = Model(x, y, m, g)

        # Initialize variables and start the timer
        sess.run(tf.global_variables_initializer())
        t_start = time.time()
        sess.run(model.reset_prev_vars)

        # Begin training loop, iterating over tasks
        for task in range(par['n_tasks']):
            for i in range(par['n_train_batches']):

                # Generate a batch of stimulus data for training
                name, stim_in, y_hat, mk, _ = stim.generate_trial(task)

                # Put together the feed dictionary
                feed_dict = {
                    x: stim_in,
                    y: y_hat,
                    g: par['gating'][task],
                    m: mk
                }

                # Run the model using one of the available stabilization methods
                if par['stabilization'] == 'pathint':
                    _, _, loss, AL, spike_loss, output = sess.run([model.train_op, \
                        model.update_small_omega, model.pol_loss, model.aux_loss, \
                        model.spike_loss, model.output], feed_dict=feed_dict)
                elif par['stabilization'] == 'EWC':
                    _, loss, AL, output = sess.run([model.train_op, model.pol_loss, \
                        model.aux_loss, model.output], feed_dict=feed_dict)

                # Display network performance
                if i % 500 == 0:
                    acc = get_perf(y_hat, output, mk)
                    print('Iter {} | Task name {} | Accuracy {} | Loss {} | Aux Loss {} | Spike Loss {}'.format(\
                        i, name, acc, loss, AL, spike_loss))

            # Test all tasks at the end of each learning session
            num_reps = 10
            task_activity_list = []
            for task_prime in range(task + 1):
                for r in range(num_reps):

                    # Generate stimulus batch for testing
                    name, stim_in, y_hat, mk, _ = stim.generate_trial(
                        task_prime)

                    # Assemble feed dict and run model
                    feed_dict = {x: stim_in, g: par['gating'][task_prime]}
                    output, h = sess.run([model.output, model.h],
                                         feed_dict=feed_dict)

                    # Record results
                    acc = get_perf(y_hat, output, mk)
                    accuracy_grid[task, task_prime] += acc / num_reps

                # Record network activity
                task_activity_list.append(h)

            # Aggregate network activity after testing each task set
            # Each of the [n_tasks] elements is [tasks tested, time steps, batch size, hidden size]
            full_activity_list.append(task_activity_list)

            # Display accuracy grid after testing is complete
            print('Accuracy grid after task {}:'.format(task))
            print(accuracy_grid[task, :])
            print()

            # Update big omegas
            if par['stabilization'] == 'pathint':
                _, big_omegas = sess.run(
                    [model.update_big_omega, model.big_omega_var])
            elif par['stabilization'] == 'EWC':
                for n in range(par['EWC_fisher_num_batches']):
                    name, stim_in, y_hat, mk, _ = stim.generate_trial(task)
                    feed_dict = {x: stim_in, y: y_hat, g: par['gating'][task], m: mk}
                    _, big_omegas = sess.run([model.update_big_omega, model.big_omega_var], \
                        feed_dict=feed_dict)

            # Reset the Adam Optimizer and save previous parameter values as current ones
            sess.run(model.reset_adam_op)
            sess.run(model.reset_prev_vars)
            if par['stabilization'] == 'pathint':
                sess.run(model.reset_small_omega)

            # Reset weights between tasks if called upon
            if par['reset_weights']:
                sess.run(model.reset_weights)

        if par['save_analysis']:
            save_results = {
                'task': task,
                'accuracy_grid': accuracy_grid,
                'par': par,
                'activity': full_activity_list
            }
            pickle.dump(save_results, open(par['save_dir'] + save_fn, 'wb'))

    print('\nModel execution complete. (Supervised)')
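
get_perf is called throughout these examples but never shown. A sketch of a masked-accuracy implementation consistent with how it is called here, assuming [time, batch, n_output] arrays, a [time, batch] mask, and that output unit 0 is the fixation channel (an assumption):

import numpy as np

def get_perf(target, output, mask):
    # Fraction of unmasked response time steps where the network's
    # argmax output matches the target
    output = np.stack(output) if isinstance(output, list) else np.asarray(output)
    resp_mask = mask * (target[:, :, 0] == 0)   # assumption: unit 0 = fixation
    match = np.argmax(target, axis=2) == np.argmax(output, axis=2)
    return np.sum(match * resp_mask) / np.sum(resp_mask)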
Example #10
def main():

    #os.environ['CUDA_VISIBLE_DEVICES'] = '1'

    tf.reset_default_graph()

    x = tf.placeholder(tf.float32,
                       [par['batch_size'], par['forward_shape'][0]], 'stim')
    y = tf.placeholder(tf.float32, [par['batch_size'], par['n_output']], 'out')
    alpha = tf.placeholder(tf.float32, [], 'alpha')

    #with tf.device('/gpu:0'):
    #    model = Model(x, y)
    model = Model(x, y, alpha)

    stim = stimulus.MultiStimulus()

    iteration = []
    accuracy = []

    with tf.Session() as sess:

        sess.run(tf.global_variables_initializer())
        acc0 = []
        acc1 = []

        for i in range(par['n_train_batches_gen']):

            if i % 2 == 0:
                par['subset_loc'] = True
                alpha_val = 0.2
            else:
                par['subset_loc'] = False
                alpha_val = 0
            name, inputs, neural_inputs, outputs = stim.generate_trial(
                0, par['subset_dirs'], par['subset_loc'])

            feed_dict = {x: neural_inputs, y: outputs, alpha: alpha_val}

            _, task_loss, recon_loss, latent_loss, weight_loss, y_hat, x_hat, mu, sigma, latent_sample, W_sample = sess.run([model.train_op, model.task_loss, \
                model.recon_loss, model.latent_loss, model.weight_loss, model.y, model.x_hat, model.mu, model.si, model.latent_sample, model.W_sample], feed_dict=feed_dict)

            _, pred_loss, latent_sample, pred_sm = sess.run(
                [
                    model.train_op_pred, model.pred_loss, model.latent_sample,
                    model.pred_sm
                ],
                feed_dict=feed_dict)

            acc = get_perf(outputs, y_hat)
            if i % 2 == 0:
                acc0.append(acc)
            else:
                acc1.append(acc)

            if i % 100 == 0:
                print('Mean acc ', np.mean(acc0), ' ', np.mean(acc1))
                acc0 = []
                acc1 = []
                print('{} | Reconstr. Loss: {:.5f} | Latent Loss: {:.5f} | Weight_loss: {:.5f} | Task Loss: {:.5f} | Accuracy: {:.3f} | <Sig>: {:.3f} +/- {:.3f}'.format( \
                    i, recon_loss, latent_loss, weight_loss, task_loss, acc, np.mean(np.abs(sigma)), np.std(sigma)))
                iteration.append(i)
                accuracy.append(acc)
                print('latent mean', np.mean(latent_sample, axis=0))
                print('latent var', np.var(latent_sample, axis=0))
                print('pred 0 ', np.mean(pred_sm[:par['batch_size'], :],
                                         axis=0), ' pred 1 ',
                      np.mean(pred_sm[par['batch_size']:, :], axis=0))

            if i % 100 == 0:

                correlation = np.zeros((par['n_latent'], 7))
                correlation_within = np.zeros(
                    (par['n_latent'], par['n_latent']))
                for l in range(par['n_latent']):
                    # for latent_sample in latent_sample:
                    correlation[l, 0] += pearsonr(latent_sample[:, l],
                                                  inputs[:, 0])[0]  #x
                    correlation[l, 1] += pearsonr(latent_sample[:, l],
                                                  inputs[:, 1])[0]  #y
                    correlation[l, 2] += pearsonr(latent_sample[:, l],
                                                  inputs[:, 2])[0]  #dir_ind
                    correlation[l, 3] += pearsonr(latent_sample[:, l],
                                                  inputs[:, 3])[0]  #m
                    correlation[l, 4] += pearsonr(latent_sample[:, l],
                                                  inputs[:, 4])[0]  #fix
                    correlation[l, 5] += pearsonr(latent_sample[:, l],
                                                  outputs[:, 0])[0]  #motion_x
                    correlation[l, 6] += pearsonr(latent_sample[:, l],
                                                  outputs[:, 1])[0]  #motion_y

                    for l1 in range(par['n_latent']):
                        correlation_within[l, l1] = pearsonr(
                            latent_sample[:, l], latent_sample[:, l1])[0]
                #print(['loc_x','loc_y','dir','m','fix','mot_x','mot_y'])
                #print(np.round(correlation,3))
                #print('Weight matrix from sample...')
                #print(np.round(W_sample,3))
                print('Inter var correlations',
                      np.sum(np.abs(correlation_within)) - par['n_latent'])
                #print(np.round(correlation_within,3))
                print('')
                print('')

        print("TRAINING ON SUBSET LOC FINISHED\n")
        print("TESTING ON ALL QUADRANTS")
        test(stim, model, 0, sess, x, y, ff=False, gff=True)

        print("STARTING TO TEST TRAINING ON ALL QUADRANTS")

        for i in range(par['n_train_batches_gen'],
                       par['n_train_batches_gen'] + 1500):

            name, inputs, neural_inputs, outputs = stim.generate_trial(
                0, False, False)
            feed_dict = {x: neural_inputs, y: outputs, alpha: 0.05}
            _, loss, recon_loss, latent_loss, weight_loss, y_hat, x_hat, mu, sigma, latent_sample = sess.run([model.train_op, model.task_loss, \
                model.recon_loss, model.latent_loss, model.weight_loss, model.y, model.x_hat, model.mu, model.si, model.latent_sample], feed_dict=feed_dict)

            if i % 50 == 0:
                ind = np.intersect1d(np.argwhere(inputs[:, 3] == 1),
                                     np.argwhere(inputs[:, 4] == 0))
                acc = get_perf(outputs, y_hat)
                print('{} | Reconstr. Loss: {:.3f} | Latent Loss: {:.3f} | Weight_loss: {:.3f} | Accuracy: {:.3f} | <Sig>: {:.3f} +/- {:.3f}'.format( \
                    i, recon_loss, latent_loss, weight_loss, acc, np.mean(sigma), np.std(sigma)))
                iteration.append(i)
                accuracy.append(acc)

        print(iteration)
        print(accuracy)
        plt.figure()
        plt.plot(iteration, accuracy, linestyle='-', marker='o', linewidth=2)
        plt.savefig('./savedir/gen_model_learning_curve.png')
        plt.show()
        plt.close()

        # if i%500 == 0:

        #     correlation = np.zeros((par['n_latent'], 7))
        #     for l in range(par['n_latent']):
        #     # for latent_sample in latent_sample:
        #         correlation[l,0] += pearsonr(latent_sample[:,l], inputs[:,0])[0] #x
        #         correlation[l,1] += pearsonr(latent_sample[:,l], inputs[:,1])[0] #y
        #         correlation[l,2] += pearsonr(latent_sample[:,l], inputs[:,2])[0] #dir_ind
        #         correlation[l,3] += pearsonr(latent_sample[:,l], inputs[:,3])[0] #m
        #         correlation[l,4] += pearsonr(latent_sample[:,l], inputs[:,4])[0] #fix
        #         correlation[l,5] += pearsonr(latent_sample[:,l], outputs[:,0])[0] #motion_x
        #         correlation[l,6] += pearsonr(latent_sample[:,l], outputs[:,1])[0] #motion_y
        #     print(['loc_x','loc_y','dir','m','fix','mot_x','mot_y'])
        #     print(np.round(correlation,3))

        # m = []
        # for motion in range(par['num_motion_dirs']):
        #     m.append(np.where(inputs[:,2]==motion))

        # temp = np.zeros((par['n_latent'],8))
        # for dir in range(par['num_motion_dirs']):
        #     temp[:,dir] = np.round(np.mean(latent_sample[m[dir]], axis=0),3)
        # plt.imshow(temp, cmap='inferno')
        # plt.colorbar()
        # plt.show()

        # var_dict = sess.run(model.generative_vars)
        # with open('./savedir/generative_var_dict_trial.pkl', 'wb') as vf:
        #     pickle.dump(var_dict, vf)

        # visualization(inputs, neural_inputs)

        # for b in range(10):

        #     output_string = ''
        #     output_string += '\n--- {} ---\n'.format(b)
        #     output_string += 'mu:  {}\n'.format(str(mu[b,:]))
        #     output_string += 'sig: {}\n'.format(str(sigma[b,:]))

        #     if b == 0:
        #         rw = 'w'
        #     else:
        #         rw = 'a'

        #     with open('./savedir/recon_data_iter{}.txt'.format(i,b), rw) as f:
        #         f.write(output_string)

        #     fig, ax = plt.subplots(2,2,figsize=[8,8])
        #     for a in range(2):
        #         inp = np.sum(np.reshape(neural_inputs[b], [9,10,10]), axis=a)
        #         hat = np.sum(np.reshape(x_hat[b], [9,10,10]), axis=a)

        #         ax[a,0].set_title('Actual (Axis {})'.format(a))
        #         ax[a,0].imshow(inp, clim=[0,1])
        #         ax[a,1].set_title('Reconstructed (Axis {})'.format(a))
        #         ax[a,1].imshow(hat, clim=[0,1])

        #     plt.savefig('./savedir/recon_iter{}_trial{}.png'.format(i,b))
        #     plt.close(fig)

    print('Complete.')
Example #11
def main(save_fn=None, gpu_id=None):

    if gpu_id is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu_id

    print('\nRunning model.\n')

    # Reset TensorFlow graph
    tf.reset_default_graph()
    with open("./generative_var_dict_trial.pkl", "rb") as f:
        par['var_dict'] = pickle.load(f)

    # Create placeholders for the model
    x = tf.placeholder(tf.float32, [par['batch_size'], par['n_input']], 'stim')
    target = tf.placeholder(tf.float32, [par['batch_size'], par['n_output']],
                            'out')
    ys = tf.placeholder(tf.float32, [par['n_ys'], 2], 'stim_y')

    stim = stimulus.MultiStimulus()
    accuracy_full = []
    accuracy_grid = np.zeros((par['n_tasks'], par['n_tasks']))
    accuracy_grid_slow = np.zeros((par['n_tasks'], par['n_tasks']))


    key_info = ['synapse_config','spike_cost','weight_cost','entropy_cost','omega_c','omega_xi',\
        'constrain_input_weights','num_sublayers','n_hidden','noise_rnn_sd','learning_rate','gating_type', 'gate_pct']
    print('Key info')
    for k in key_info:
        print(k, ' ', par[k])

    config = tf.ConfigProto()
    #config.gpu_options.allow_growth = True

    iteration = []
    accuracy = []

    # Model run session
    with tf.Session(config=config) as sess:

        device = '/cpu:0' if gpu_id is None else '/gpu:0'
        with tf.device(device):
            model = Model(x, target, ys)

        sess.run(tf.global_variables_initializer())
        t_start = time.time()

        for task in range(0, par['n_tasks']):

            #################################
            ###     Training FF model     ###
            #################################
            print('FF Model execution starting.\n')
            for i in range(par['n_train_batches']):

                # make batch of training data
                name, stim_real, stim_in, y_hat = stim.generate_trial(
                    task,
                    subset_dirs=par['subset_dirs_ff'],
                    subset_loc=par['subset_loc_ff'])

                # train just ff weights
                _, ff_loss, ff_output = sess.run(
                    [model.train_op_ff, model.ff_loss, model.ff_output],
                    feed_dict={
                        x: stim_in,
                        target: y_hat
                    })

                if i % 50 == 0:
                    ind = np.intersect1d(np.argwhere(stim_real[:, 3] == 1),
                                         np.argwhere(stim_real[:, 4] == 0))
                    ff_acc = get_perf(y_hat[ind], ff_output[ind])
                    # for b in range(20):
                    # print("m: ", stim_real[b,3], ", fix: ": stim_real[b,4])
                    # print("y_hat: ", y_hat[b], ", output: ", ff_output[b], "\n")
                    print('Iter ', i, 'Task name ', name, ' accuracy', ff_acc,
                          ' loss ', ff_loss)
                    iteration.append(i)
                    accuracy.append(ff_acc)
            print('FF Model execution complete.\n')

            # Test all tasks at the end of each learning session
            print("FF Testing Phase")
            test(stim, model, task, sess, x, ys, ff=True, gff=False)

            print("FF TRAINING ON ALL QUADRANTS")
            for i in range(par['n_train_batches'],
                           par['n_train_batches'] + 1500):

                # make batch of training data
                name, stim_real, stim_in, y_hat = stim.generate_trial(
                    task, subset_dirs=False, subset_loc=False)

                # train just ff weights
                _, ff_loss, ff_output = sess.run(
                    [model.train_op_ff, model.ff_loss, model.ff_output],
                    feed_dict={
                        x: stim_in,
                        target: y_hat
                    })

                if i % 50 == 0:
                    ind = np.intersect1d(np.argwhere(stim_real[:, 3] == 1),
                                         np.argwhere(stim_real[:, 4] == 0))
                    ff_acc = get_perf(y_hat[ind], ff_output[ind])
                    # for b in range(20):
                    # print("m: ", stim_real[b,3], ", fix: ": stim_real[b,4])
                    # print("y_hat: ", y_hat[b], ", output: ", ff_output[b], "\n")
                    print('Iter ', i, 'Task name ', name, ' accuracy', ff_acc,
                          ' loss ', ff_loss)
                    iteration.append(i)
                    accuracy.append(ff_acc)
            print('FF Model execution complete.\n')
            test(stim, model, task, sess, x, ys, ff=True, gff=False)

            print(iteration)
            print(accuracy)
            plt.figure()
            plt.plot(iteration, accuracy, '-o', linewidth=2)
            # save before show; some backends discard the figure once its window closes
            plt.savefig('./savedir/ff_model_learning_curve.png')
            plt.show()
            quit()  # debugging stop: ends the run after the first task

            ################################
            ### Training Connected Model ###
            ################################
            # print('Connected Model execution starting.\n')
            # x_hats = []
            # y_samples = []
            # for i in range(par['n_train_batches_full']):

            #     # make batch of training data
            #     name, stim_real, stim_in, y_hat = stim.generate_trial(task, subset_dirs=par['subset_dirs'], subset_loc=par['subset_loc'])
            #     ind = np.random.choice(np.arange(par['batch_size']), size=par['n_ys'])
            #     stim_real = stim_real[ind]
            #     stim_in = stim_in[ind]
            #     y_sample = y_hat[ind]

            #     # train just the conn weights
            #     _, full_loss, latent_loss, full_output, x_hat, mu, si = sess.run([model.train_op_full, model.full_loss, model.latent_loss, model.full_output, model.x_hat, model.mu, model.si], feed_dict = {ys: y_sample})

            #     if i%100 == 0:
            #         conn_acc = get_perf(y_sample, full_output, ff=False)
            #         print('Iter ', i, 'Task name ', name, ' accuracy', conn_acc, ' loss ', full_loss, ' latent_loss ',latent_loss, ' mu ', [np.mean(mu), np.std(mu)], ' si ', [np.mean(si), np.std(si)])
            #     if i%500 == 0 and i!=0:
            #         visualization(stim_real, x_hat, y_sample, full_output, i)

            #     # if i > 500:
            #         # x_hats.append(x_hat)
            #         # y_samples.append(y_sample)

            # print('Connected Model execution complete.\n')

            # # Test all tasks at the end of each learning session
            # # print("Connected Model Testing Phase")
            # # test(stim, model, task, sess, x, ys, ff=True)

            # #####################################
            # ### Training Based on X_hat Model ###
            # #####################################
            # # print('Connected Model execution starting.\n')
            # # for i in range(len(x_hats)):

            # #     # make batch of training data
            # #     # ind = np.random.choice(np.arange(par['batch_size']), size=256)
            # #     x_hat = np.reshape(x_hats[i], (256,9,10,10))
            # #     x_hat[:,5:,5:] = 0
            # #     stim_in = np.reshape(x_hat, (256,900))
            # #     y_hat = y_samples[i]
            # #     # y_hat[ind] = y_samples[i][ind]

            # #     # train just ff weights
            # #     _, ff_loss, ff_output = sess.run([model.train_op_ff, model.ff_loss, model.ff_output], feed_dict = {x:stim_in, target:y_hat})

            # #     if i%50 == 0:
            # #         ff_acc = get_perf(y_hat, ff_output, ff=True)
            # #         print('Iter ', i, 'Task name ', name, ' accuracy', ff_acc, ' loss ', ff_loss)
            # # print('Connected Model execution complete.\n')

            # # # Test all tasks at the end of each learning session
            # # print("FF Testing Phase Final")
            # # test(stim, model, task, sess, x, ys, ff=True)

            # # Reset the Adam Optimizer, and set the previous parameter values to their current values
            # sess.run(model.reset_adam_op_ff)
            # sess.run(model.reset_adam_op_full)
            # if par['stabilization'] == 'pathint':
            #     sess.run(model.reset_small_omega)

            # # reset weights between tasks if called upon
            # if par['reset_weights']:
            #     sess.run(model.reset_weights_ff)
            #     sess.run(model.reset_weights_full)

        if par['save_analysis']:
            # note: `big_omegas` is only produced by the commented-out stabilization
            # block above, so it is omitted from the saved results here
            save_results = {'task': task, 'accuracy': accuracy, 'accuracy_full': accuracy_full, \
                            'accuracy_grid': accuracy_grid, 'par': par}
            pickle.dump(save_results, open(par['save_dir'] + save_fn, 'wb'))

        print('\nModel execution complete.\n')
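
For reference: every snippet here calls a get_perf helper defined elsewhere in the repository. The following is a minimal sketch of what such a masked-accuracy function plausibly looks like (an assumption; the real helper lives in the model module and may, for example, also exclude fixation time steps):

import numpy as np

def get_perf(target, output, mask=None):
    """Sketch: fraction of (unmasked) time steps on which the argmax of the
    network output matches the argmax of the one-hot target."""
    target = np.stack(target) if isinstance(target, list) else np.asarray(target)
    output = np.stack(output) if isinstance(output, list) else np.asarray(output)
    match = np.float32(np.argmax(output, axis=-1) == np.argmax(target, axis=-1))
    if mask is None:
        return float(np.mean(match))
    mask = np.squeeze(np.asarray(mask))
    return float(np.sum(mask * match) / np.sum(mask))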
Beispiel #12
0
def main(gpu_id=None):

    # Isolate requested GPU
    if gpu_id is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu_id

    tf.reset_default_graph()
    stim = stimulus.MultiStimulus()

    mask = tf.placeholder(tf.float32,
                          shape=[par['num_time_steps'], par['batch_size']])
    x = tf.placeholder(tf.float32,
                       shape=[
                           par['num_time_steps'],
                           par['batch_size'],
                           par['n_input'],
                       ])  # input data
    y = tf.placeholder(
        tf.float32,
        shape=[par['num_time_steps'], par['batch_size'],
               par['n_output']])  # target data

    # enter "config=tf.ConfigProto(log_device_placement=True)" inside Session to check whether CPU/GPU in use
    with tf.Session() as sess:

        #with tf.device("/gpu:0"):
        model = Model(x, y, mask)
        init = tf.global_variables_initializer()
        sess.run(init)
        t_start = time.time()

        model_performance = {'accuracy': [], 'loss': [], 'perf_loss': [], 'spike_loss': [], \
            'recotruction_loss': [], 'trial': [], 'time': []}  # 'recotruction' [sic] matches the model attribute

        task = 17

        for i in range(par['n_train_batches']):

            # generate batch of N (batch_size X num_batches) trials
            name, stim_in, target_data, train_mask, _ = stim.generate_trial(
                task)

            _, loss, perf_loss, recotruction_loss, x_hat, output, h, latent = \
                sess.run([model.train_op, model.loss, model.perf_loss, model.recotruction_loss, \
                model.x_hat, model.output, model.h, model.sample_latent], \
                {x: stim_in, y: target_data, mask: train_mask})

            accuracy = get_perf(target_data, output, train_mask)

            iteration_time = time.time() - t_start

            #model_performance = append_model_performance(model_performance, accuracy, loss, perf_loss, \
            #    recotruction_loss, (i+1)*N, iteration_time)
            """
            Save the network model and output model performance to screen
            """
            if i % 100 == 0:
                print_results(i, iteration_time, perf_loss, recotruction_loss,
                              h, accuracy)
                weights = sess.run([model.var_dict])
                pickle.dump(weights[0],
                            open('./savedir/saved_weights.pkl', 'wb'))
                print('Weights saved')

                if i % 1000 == 0:
                    x_hat = np.stack(x_hat, axis=0)
                    f = plt.figure(figsize=(8, 4))
                    for k in range(2):
                        ax = f.add_subplot(2, 2, 1 + k * 2)
                        ax.imshow(stim_in[:, k, :], aspect='auto')
                        ax = f.add_subplot(2, 2, 2 + k * 2)
                        ax.imshow(x_hat[-1::-1, k, :], aspect='auto')
                    plt.show()
        """
        Analyze the network model and save the results
        """
        if par['analyze_model']:
            weights = eval_weights()
            # the original snippet passed the undefined names `trial_info`, `y_hat`,
            # and `state_hist`; the last batch's locals are substituted so this runs
            analysis.analyze_model(target_data, output, x_hat, latent,
                                   h, model_performance, weights)
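
The print_results call above is another helper not shown in the snippet. A plausible reconstruction from its call signature (names and formatting are assumptions; the 'recotruction' spelling is kept to match the model attribute):

import numpy as np

def print_results(i, iteration_time, perf_loss, recotruction_loss, h, accuracy):
    # hypothetical logging helper matching the call site above
    print('Iter {:6d} | time {:7.1f} s | perf loss {:8.5f} | recon loss {:8.5f} | '
          'mean h {:7.4f} | accuracy {:5.3f}'.format(
              i, iteration_time, float(perf_loss), float(recotruction_loss),
              float(np.mean(np.stack(h))), accuracy))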
Beispiel #13
0
def main(save_fn=None, gpu_id=None):
    """ Run supervised learning training """

    savedir = par['save_dir']

    # Isolate requested GPU
    if gpu_id is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu_id

    # Reset Tensorflow graph before running anything
    tf.reset_default_graph()

    # Define all placeholders
    x = tf.placeholder(
        tf.float32, [par['num_time_steps'], par['batch_size'], par['n_input']],
        'stim')
    y = tf.placeholder(
        tf.float32,
        [par['num_time_steps'], par['batch_size'], par['n_output']], 'out')
    m = tf.placeholder(tf.float32,
                       [par['num_time_steps'], par['batch_size'], 1], 'mask')

    # Set up stimulus
    stim = stimulus.MultiStimulus()

    # Start Tensorflow session
    with tf.Session() as sess:

        # Select CPU or GPU
        device = '/cpu:0' if gpu_id is None else '/gpu:0'
        with tf.device(device):
            model = Model(x, y, m)

        # Initialize variables and start the timer
        sess.run(tf.global_variables_initializer())
        t_start = time.time()

        print('\nStarting training.')

        # Training autoencoder
        print('\nRunning VAE:')
        for i in range(par['num_autoencoder_batches']):

            # Generate a batch of stimulus data for training
            # and put together the model's feed dictionary
            name, stim_in, y_hat, mk, _ = stim.generate_trial(0)
            feed_dict = {x: stim_in, y: y_hat, m: mk[..., np.newaxis]}

            # Run the model
            _, recon_loss, latent_loss = sess.run(
                [model.train_VAE, model.recon_loss, model.act_latent_loss],
                feed_dict=feed_dict)

            if i % 200 == 0:
                print('{:4} | Recon: {:5.3f} | Lat: {:5.3f}'.format(
                    i, recon_loss, latent_loss))

        sess.run(model.reset_adam_op)

        # Training generative adversarial network
        print('\nRunning GAN:')
        for i in range(par['num_GAN_batches']):

            for j in range(3):
                if j == 0:
                    trainer = model.train_generator
                    curr = 'G'
                else:
                    trainer = model.train_discriminator
                    curr = 'D'

                # Generate a batch of stimulus data for training
                # and put together the model's feed dictionary
                name, stim_in, y_hat, mk, _ = stim.generate_trial(0)
                feed_dict = {x: stim_in, y: y_hat, m: mk[..., np.newaxis]}

                # Run the model
                _, gen_loss, discr_loss, gen_latent, var_loss = sess.run([trainer, model.generator_loss, \
                    model.discriminator_loss, model.gen_latent_loss, model.gen_var_loss], feed_dict=feed_dict)

                if i % 200 == 0 and j in [0, 2]:

                    outputs_all = sess.run(model.outputs_dict,
                                           feed_dict=feed_dict)

                    gen_out = np.argmax(softmax(
                        np.stack(outputs_all['generator_to_discriminator'])),
                                        axis=-1)
                    enc_out = np.argmax(softmax(
                        np.stack(outputs_all['encoder_to_discriminator'])),
                                        axis=-1)
                    outputs = outputs_all['generator_to_decoder']

                    acc_dis_gen = np.mean(
                        gen_out == np.argmax(par['discriminator_gen_target']))
                    acc_dis_act = np.mean(
                        enc_out == np.argmax(par['discriminator_act_target']))

                    print('{:4} | {} | Gen: {:7.5f} | Discr: {:7.5f} | Lat: {:7.5f} | Var: {:7.5f} | Corr. Real: {:5.3f} | Corr. Gen: {:5.3f}'.format( \
                        i, curr, gen_loss, discr_loss, gen_latent, var_loss, acc_dis_act, acc_dis_gen))
                    outputs = np.stack(outputs, axis=0)

                    fig, ax = plt.subplots(2, 2, figsize=[8, 8])

                    ax[0, 0].set_title('Example A')
                    ax[0, 0].imshow(outputs[:, 0, :], clim=[0, 4])
                    ax[0, 1].set_title('Example B')
                    ax[0, 1].imshow(outputs[:, 1, :], clim=[0, 4])
                    ax[1, 0].set_title('Example C')
                    ax[1, 0].imshow(outputs[:, 2, :], clim=[0, 4])
                    ax[1, 1].set_title('Example D')
                    ax[1, 1].imshow(outputs[:, 3, :], clim=[0, 4])

                    plt.savefig(savedir + 'gan_output_{}.png'.format(i))
                    plt.clf()
                    plt.close()

        sess.run(model.reset_adam_op)

        # Training partial task
        print('\nRunning FF (partial): (Accuracy on SOME directions.)')
        for i in range(par['num_train_batches']):

            # Generate a batch of stimulus data for training
            # and put together the model's feed dictionary
            name, stim_in, y_hat, mk, _ = stim.generate_trial(0, partial=True)
            feed_dict = {x: stim_in, y: y_hat, m: mk[..., np.newaxis]}

            # Feed dict
            _, task_loss, outputs = sess.run([
                model.train_task, model.task_loss,
                model.outputs_dict['encoder_to_solution']
            ],
                                             feed_dict=feed_dict)

            response = np.stack(outputs)
            # parenthesize the comparison (`*` binds tighter than `==`, so the
            # original masked the wrong operand) and normalize by the mask
            match = np.float32(np.argmax(response, axis=-1) == np.argmax(y_hat, axis=-1))
            acc = np.sum(mk * match) / np.sum(mk)

            if i % 200 == 0:
                print('{:4} | Loss: {:7.5f} | Acc: {:5.3f}'.format(
                    i, task_loss, acc))

                fig, ax = plt.subplots(2, 2, figsize=[8, 8])

                ax[0, 0].set_title('Actual A')
                ax[0, 0].imshow(y_hat[:, 0, :], clim=[0, 4])
                ax[0, 1].set_title('Output A')
                ax[0, 1].imshow(response[:, 0, :], clim=[0, 4])
                ax[1, 0].set_title('Actual B')
                ax[1, 0].imshow(y_hat[:, 1, :], clim=[0, 4])
                ax[1, 1].set_title('Output B')
                ax[1, 1].imshow(response[:, 1, :], clim=[0, 4])

                plt.savefig(savedir + 'ff_output_{}.png'.format(i))
                plt.clf()
                plt.close()

        sess.run(model.reset_adam_op)

        # Training entropy maximization
        print('\nRunning entropy maximization: (Accuracy on ALL directions.)')
        acc_list = []
        for i in range(par['num_entropy_batches']):

            # Generate a batch of stimulus data for training
            # and put together the model's feed dictionary
            name, stim_in, y_hat, mk, _ = stim.generate_trial(0, partial=False)
            feed_dict = {x: stim_in, y: y_hat, m: mk[..., np.newaxis]}

            # Feed dict
            _, entropy_loss, task_loss, outputs, W_out = sess.run([model.train_task_entropy, model.entropy_loss, \
                model.task_loss, model.outputs_dict['encoder_to_solution'], model.var_dict['solution']['W_out']],feed_dict=feed_dict)

            response = np.stack(outputs)
            # same precedence fix and mask normalization as in the partial-task loop
            match = np.float32(np.argmax(response, axis=-1) == np.argmax(y_hat, axis=-1))
            acc = np.sum(mk * match) / np.sum(mk)

            if i % 200 == 0:
                np.save(savedir + 'W_out_iter{}'.format(i), W_out)
                acc_list.append(acc)
                print(
                    '{:4} | Entropy Loss: {:7.5f} | Task Loss: {:7.5f} | Acc: {:5.3f}'
                    .format(i, entropy_loss, task_loss, acc))

                fig, ax = plt.subplots(2, 2, figsize=[8, 8])

                ax[0, 0].set_title('Actual A')
                ax[0, 0].imshow(y_hat[:, 0, :], clim=[0, 4])
                ax[0, 1].set_title('Output A')
                ax[0, 1].imshow(response[:, 0, :], clim=[0, 4])
                ax[1, 0].set_title('Actual B')
                ax[1, 0].imshow(y_hat[:, 1, :], clim=[0, 4])
                ax[1, 1].set_title('Output B')
                ax[1, 1].imshow(response[:, 1, :], clim=[0, 4])

                plt.savefig(savedir + 'post_ent_output_{}.png'.format(i))
                plt.clf()
                plt.close()

    model_complete(acc_list)
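
The GAN snippet above applies a softmax helper before taking the argmax of the discriminator outputs. The helper is not defined in the snippet; a minimal numerically stable version (an assumption, not the repository's own code) would be:

import numpy as np

def softmax(z, axis=-1):
    # subtract the per-row max for numerical stability, then normalize
    z = z - np.max(z, axis=axis, keepdims=True)
    e = np.exp(z)
    return e / np.sum(e, axis=axis, keepdims=True)

Since softmax is monotone along the normalized axis, np.argmax(softmax(z), -1) equals np.argmax(z, -1), so the softmax only matters if the probabilities themselves are inspected.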
Beispiel #14
0
def reinforcement_learning(save_fn='test.pkl', gpu_id=None):
    """ Run reinforcement learning training """

    # Isolate requested GPU
    if gpu_id is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu_id

    # Reset Tensorflow graph before running anything
    tf.reset_default_graph()

    # Define all placeholders
    # note: `mask` appears twice in this unpacking; the later binding wins
    x, target, mask, pred_val, actual_action, \
        advantage, mask, gating = generate_placeholders()

    # Set up stimulus and accuracy recording
    stim = stimulus.MultiStimulus()
    accuracy_full = []
    accuracy_grid = np.zeros([par['n_tasks'], par['n_tasks']])
    full_activity_list = []
    model_performance = {
        'reward': [],
        'entropy_loss': [],
        'val_loss': [],
        'pol_loss': [],
        'spike_loss': [],
        'trial': [],
        'task': []
    }
    reward_matrix = np.zeros((par['n_tasks'], par['n_tasks']))

    # Display relevant parameters
    print_key_info()

    # Start Tensorflow session
    with tf.Session() as sess:

        # Select CPU or GPU
        device = '/cpu:0' if gpu_id is None else '/gpu:0'
        with tf.device(device):
            # Check order against args unpacking in model if editing
            model = Model(x, target, mask, gating)

        # Initialize variables and start the timer
        sess.run(tf.global_variables_initializer())
        t_start = time.time()
        sess.run(model.reset_prev_vars)

        # Begin training loop, iterating over tasks
        for task in range(par['n_tasks']):
            accuracy_iter = []
            task_start_time = time.time()

            for i in range(par['n_train_batches']):

                # Generate a batch of stimulus data for training
                name, input_data, _, mk, reward_data = stim.generate_trial(
                    task)
                mk = mk[..., np.newaxis]

                # Put together the feed dictionary
                feed_dict = {
                    x: input_data,
                    target: reward_data,
                    mask: mk,
                    gating: par['gating'][task]
                }

                # Calculate and apply gradients; note that only 'pathint'
                # stabilization is handled in this RL snippet
                if par['stabilization'] == 'pathint':
                    _, _, _, pol_loss, val_loss, aux_loss, spike_loss, ent_loss, h_list, reward_list = \
                        sess.run([model.train_op, model.update_current_reward, model.update_small_omega, model.pol_loss, model.val_loss, \
                        model.aux_loss, model.spike_loss, model.entropy_loss, model.h, model.reward], feed_dict = feed_dict)
                    if i > 0:
                        sess.run([model.update_small_omega])
                    sess.run([model.update_previous_reward])

                # Record accuracies
                reward = np.stack(reward_list)
                acc = np.mean(np.sum(reward > 0, axis=0))
                accuracy_iter.append(acc)
                if i > 2000:
                    if np.mean(accuracy_iter[-2000:]) > 0.985 or (
                            i > 25000
                            and np.mean(accuracy_iter[-2000:]) > 0.98):
                        print('Accuracy reached threshold')
                        break

                # Display network performance
                if i % 500 == 0:
                    print('Iter ', i, 'Task name ', name, ' accuracy', acc, ' aux loss', aux_loss, \
                    'mean h', np.mean(np.stack(h_list)), 'time ', np.around(time.time() - task_start_time))

            # Update big omegas and reset other values before starting a new task
            if par['stabilization'] == 'pathint':
                big_omegas = sess.run(
                    [model.update_big_omega, model.big_omega_var])

            # Test all tasks at the end of each learning session
            num_reps = 10
            task_activity_list = []
            for task_prime in range(task + 1):
                for r in range(num_reps):

                    # make batch of training data
                    name, input_data, _, mk, reward_data = stim.generate_trial(
                        task_prime)
                    mk = mk[..., np.newaxis]

                    reward_list, h = sess.run([model.reward, model.h], feed_dict = {x:input_data, target: reward_data, \
                        gating:par['gating'][task_prime], mask:mk})

                    reward = np.squeeze(np.stack(reward_list))
                    reward_matrix[task, task_prime] += np.mean(
                        np.sum(reward > 0, axis=0)) / num_reps

                # Record network activity
                task_activity_list.append(h)

            # Aggregate task activity after testing each task set
            # Each of the [all tasks] elements is [tasks tested, time steps, batch size, hidden size]
            full_activity_list.append(task_activity_list)

            print('Accuracy grid after task {}:'.format(task))
            print(reward_matrix[task, :])

            results = {
                'reward_matrix': reward_matrix,
                'par': par,
                'activity': full_activity_list
            }
            pickle.dump(results, open(par['save_dir'] + save_fn, 'wb'))
            print('Analysis results saved in', save_fn)
            print('')

            # Reset the Adam Optimizer, and set the previous parameter values to their current values
            sess.run(model.reset_adam_op)
            sess.run(model.reset_prev_vars)
            if par['stabilization'] == 'pathint':
                sess.run(model.reset_small_omega)

    print('\nModel execution complete. (Reinforcement)')
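
generate_placeholders() is likewise defined elsewhere. Below is a sketch consistent with the eight-name unpacking at the top of this function, reusing the global par dict used throughout these snippets; the shapes of the RL-specific placeholders (pred_val, actual_action, advantage) are assumptions:

import tensorflow as tf

def generate_placeholders():
    # hypothetical sketch; mirrors the unpacking above, including the
    # duplicate `mask` slot noted there
    T, B = par['num_time_steps'], par['batch_size']
    x = tf.placeholder(tf.float32, [T, B, par['n_input']], 'stim')
    target = tf.placeholder(tf.float32, [T, B, par['n_output']], 'target')
    mask = tf.placeholder(tf.float32, [T, B, 1], 'mask')
    pred_val = tf.placeholder(tf.float32, [T, B, 1], 'pred_val')
    actual_action = tf.placeholder(tf.float32, [T, B, par['n_output']], 'action')
    advantage = tf.placeholder(tf.float32, [T, B, 1], 'advantage')
    gating = tf.placeholder(tf.float32, [par['n_hidden']], 'gating')
    return x, target, mask, pred_val, actual_action, advantage, mask, gating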
Beispiel #15
0
def main(save_fn=None, gpu_id=None):

    if gpu_id is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu_id

    # train the convolutional layers on the CIFAR-10 / ImageNet dataset;
    # otherwise the saved convolutional weights are loaded from file
    if (par['task'] == 'cifar' or par['task']
            == 'imagenet') and par['train_convolutional_layers']:
        convolutional_layers.ConvolutionalLayers()

    print('\nRunning model.\n')

    # Reset TensorFlow graph
    tf.reset_default_graph()

    # Create placeholders for the model
    # input_data, target_data, gating, mask

    x = tf.placeholder(
        tf.float32, [par['num_time_steps'], par['batch_size'], par['n_input']],
        'stim')
    target = tf.placeholder(
        tf.float32,
        [par['num_time_steps'], par['batch_size'], par['n_output']], 'out')
    mask = tf.placeholder(tf.float32,
                          [par['num_time_steps'], par['batch_size']], 'mask')
    gating = tf.placeholder(tf.float32, [par['n_hidden']], 'gating')

    stim = stimulus.MultiStimulus()
    accuracy_full = []
    accuracy_grid = np.zeros((par['n_tasks'], par['n_tasks']))


    key_info = ['synapse_config','spike_cost','weight_cost','entropy_cost','omega_c','omega_xi',\
        'constrain_input_weights','num_sublayers','n_hidden','noise_rnn_sd','learning_rate','gating_type', 'gate_pct']
    print('Key info')
    for k in key_info:
        print(k, ' ', par[k])

    config = tf.ConfigProto()
    #config.gpu_options.allow_growth = True

    with tf.Session(config=config) as sess:

        device = '/cpu:0' if gpu_id is None else '/gpu:0'
        with tf.device(device):
            model = Model(x, target, gating, mask)

        sess.run(tf.global_variables_initializer())
        t_start = time.time()
        sess.run(model.reset_prev_vars)

        for task in range(0, par['n_tasks']):

            for i in range(par['n_train_batches']):

                # make batch of training data
                name, stim_in, y_hat, mk, _ = stim.generate_trial(task)

                if par['stabilization'] == 'pathint':
                    _, _, loss, AL, spike_loss, ent_loss, output = sess.run([model.train_op, \
                        model.update_small_omega, model.task_loss, model.aux_loss, model.spike_loss, \
                        model.entropy_loss, model.output], \
                        feed_dict = {x:stim_in, target:y_hat, gating:par['gating'][task], mask:mk})
                    sess.run([model.reset_rnn_weights])
                    if loss < 0.005 and AL < 0.0004 + 0.0002 * task:
                        break

                elif par['stabilization'] == 'EWC':
                    # also fetch the spike/entropy losses and the output so the
                    # logging block below works under EWC
                    _, loss, AL, spike_loss, ent_loss, output = sess.run([model.train_op, model.task_loss, \
                        model.aux_loss, model.spike_loss, model.entropy_loss, model.output], feed_dict = \
                        {x:stim_in, target:y_hat, gating:par['gating'][task], mask:mk})

                if i % 100 == 0:
                    acc = get_perf(y_hat, output, mk)
                    print('Iter ', i, 'Task name ', name, ' accuracy', acc, ' loss ', loss, ' aux loss', AL, ' spike loss', spike_loss, \
                        ' entropy loss', ent_loss)

            # Test all tasks at the end of each learning session
            num_reps = 10
            for (task_prime, r) in product(range(task + 1), range(num_reps)):

                # make batch of training data
                name, stim_in, y_hat, mk, _ = stim.generate_trial(task_prime)

                output, _ = sess.run([model.output, model.syn_x_hist],
                                     feed_dict={
                                         x: stim_in,
                                         gating: par['gating'][task_prime]
                                     })
                acc = get_perf(y_hat, output, mk)
                accuracy_grid[task, task_prime] += acc / num_reps

            print('Accuracy grid after task {}:'.format(task))
            print(accuracy_grid[task, :])
            print('')

            # Update big omegas and reset other values before starting a new task
            if par['stabilization'] == 'pathint':
                big_omegas = sess.run(
                    [model.update_big_omega, model.big_omega_var])
            elif par['stabilization'] == 'EWC':
                for n in range(par['EWC_fisher_num_batches']):
                    name, stim_in, y_hat, mk, _ = stim.generate_trial(task)
                    big_omegas = sess.run([model.update_big_omega,model.big_omega_var], feed_dict = \
                        {x:stim_in, target:y_hat, gating:par['gating'][task], mask:mk})

            # Reset the Adam Optimizer, and set the previous parameter values to their current values
            sess.run(model.reset_adam_op)
            sess.run(model.reset_prev_vars)
            if par['stabilization'] == 'pathint':
                sess.run(model.reset_small_omega)

            # reset weights between tasks if called upon
            if par['reset_weights']:
                sess.run(model.reset_weights)

        if par['save_analysis']:
            # note: the original referenced an undefined `accuracy` list here;
            # only quantities defined in this scope are saved
            save_results = {'task': task, 'accuracy_full': accuracy_full, \
                            'accuracy_grid': accuracy_grid, 'big_omegas': big_omegas, 'par': par}
            pickle.dump(save_results, open(par['save_dir'] + save_fn, 'wb'))

    print('\nModel execution complete.')
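
The stabilization ops used throughout (update_small_omega, update_big_omega, reset_small_omega, and the aux_loss they feed) are defined inside the Model class, which these snippets do not show. As a rough, self-contained NumPy sketch of the bookkeeping behind the 'pathint' (path-integral, i.e. synaptic-intelligence) option, using toy stand-ins rather than the repository's implementation:

import numpy as np

# toy stand-ins; in the real code these live in the TensorFlow graph
par = {'omega_xi': 0.001, 'omega_c': 0.1}
weights      = {'W_rnn': np.random.randn(8, 8)}
prev_weights = {k: v.copy() for k, v in weights.items()}
small_omega  = {k: np.zeros_like(v) for k, v in weights.items()}
big_omega    = {k: np.zeros_like(v) for k, v in weights.items()}

def update_small_omega(grads, delta_w):
    # after each optimizer step: accumulate -gradient * weight-update
    for k in small_omega:
        small_omega[k] += -grads[k] * delta_w[k]

def update_big_omega():
    # at a task boundary: fold the accumulator into the importance estimate
    for k in big_omega:
        task_change = weights[k] - prev_weights[k]
        big_omega[k] += np.maximum(0., small_omega[k]) / (task_change ** 2 + par['omega_xi'])
        prev_weights[k] = weights[k].copy()              # reset_prev_vars
        small_omega[k] = np.zeros_like(small_omega[k])   # reset_small_omega

def aux_loss():
    # quadratic penalty added to the task loss while training later tasks
    return par['omega_c'] * sum(np.sum(big_omega[k] * (weights[k] - prev_weights[k]) ** 2)
                                for k in weights)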
Beispiel #16
0
def supervised_learning(save_fn='test.pkl', gpu_id=None):
    """ Run supervised learning training """

    # Isolate requested GPU
    if gpu_id is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu_id

    # Reset Tensorflow graph before running anything
    tf.reset_default_graph()

    # Define all placeholders
    x, y, m, g, trial_mask, lid = get_supervised_placeholders()

    # Set up stimulus and accuracy recording
    stim = stimulus.MultiStimulus()
    accuracy_full = []
    accuracy_grid = np.zeros([par['n_tasks'], par['n_tasks']])
    full_activity_list = []

    # Display relevant parameters
    print('\nRunning model with savename: {}'.format(save_fn))
    print_key_info()

    # Start Tensorflow session
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.8
                                ) if gpu_id == '0' else tf.GPUOptions()
    with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:

        # Select CPU or GPU
        device = '/cpu:0' if gpu_id is None else '/gpu:0'
        with tf.device(device):
            model = Model(x, y, m, g, trial_mask, lid)

        # Initialize variables and start the timer
        saver = tf.train.Saver(max_to_keep=100)
        sess.run(tf.global_variables_initializer())
        t_start = time.time()
        sess.run(model.reset_prev_vars)

        # Begin training loop, iterating over tasks
        task = 0  # For legacy compatibility
        accuracy_record = []
        for i in range(par['n_train_batches']):

            if par['do_k_shot_testing'] and par['load_from_checkpoint']:
                break

            stims = []
            hats = []
            mks = []
            for t in range(par['n_tasks']):
                _, stim_in, hat, mk, _ = stim.generate_trial(t)
                stims.append(stim_in)
                hats.append(hat)
                mks.append(mk)

            stims = np.stack(stims, axis=0)
            hats = np.stack(hats, axis=0)
            mks = np.stack(mks, axis=0)

            base_inds = np.setdiff1d(np.arange(par['n_tasks']), par['k_shot_task']) if \
                par['do_k_shot_testing'] else np.arange(par['n_tasks'])
            inds = np.random.choice(base_inds, size=[par['batch_size']])

            stim_in = np.zeros(
                [par['num_time_steps'], par['batch_size'], par['n_input']])
            y_hat = np.zeros(
                [par['num_time_steps'], par['batch_size'], par['n_output']])
            mk = np.zeros([par['num_time_steps'], par['batch_size']])
            for b in range(par['batch_size']):
                stim_in[:, b, :] = stims[inds[b], :, b, :]
                y_hat[:, b, :] = hats[inds[b], :, b, :]
                mk[:, b] = mks[inds[b], :, b]

            # Put together the feed dictionary
            feed_dict = {x: stim_in, y: y_hat, g: par['gating'][0], m: mk}

            # Run the model using one of the available stabilization methods
            if par['stabilization'] == 'pathint':
                _, _, loss, AL, weight_loss, spike_loss, output, hidden = sess.run([model.train_op, \
                    model.update_small_omega, model.pol_loss, model.aux_loss, \
                    model.weight_loss, model.spike_loss, model.output, model.h], feed_dict=feed_dict)
            elif par['stabilization'] == 'EWC':
                # also fetch weight_loss and the hidden state so the logging
                # block below works under EWC
                _, loss, AL, weight_loss, output, hidden = sess.run([model.train_op, model.pol_loss, \
                    model.aux_loss, model.weight_loss, model.output, model.h], feed_dict=feed_dict)

            # Display network performance
            if i % 10 == 0:
                acc = get_perf(y_hat, output, mk)
                print('Iter {} | Accuracy {:5.3f} | Loss {:5.3f} | Weight Loss {:5.3f} | Mean Activity {:5.3f} +/- {:5.3f}'.format(\
                    i, acc, loss, weight_loss, np.mean(hidden), np.std(hidden)))

                task_accs = []
                task_grads = []
                task_states = []
                off_task_accs = []
                for t in range(par['n_tasks']):
                    _, stim_in, y_hat, mk, _ = stim.generate_trial(t)
                    output, batch_grads, h = sess.run(
                        [model.output, model.batch_grads, model.h],
                        feed_dict={
                            x: stim_in,
                            y: y_hat,
                            g: par['gating'][task],
                            m: mk
                        })

                    perf = get_perf(y_hat, output, mk)
                    if t != par['k_shot_task']:
                        off_task_accs.append(perf)
                    task_accs.append(perf)
                    task_grads.append(batch_grads)
                    task_states.append(h)

                accuracy_record.append(task_accs)
                pickle.dump(
                    accuracy_record,
                    open('./savedir/training_accuracy_{}.pkl'.format(save_fn),
                         'wb'))

                print('Task accuracies:',
                      *['| {:5.3f}'.format(el) for el in task_accs])
                print('Trained Tasks Mean:', np.mean(off_task_accs), '\n')

                if par['use_threshold'] and np.mean(off_task_accs) > 0.95:
                    print('Trained tasks 95% accuracy threshold reached.')
                    break

        if not par['load_from_checkpoint']:
            print('Saving states, parameters, and weights...')
            pickle.dump(task_states,
                        open('./savedir/states_{}.pkl'.format(save_fn), 'wb'))
            pickle.dump(
                {
                    'parameters': par,
                    'weights': sess.run(model.var_dict)
                }, open('./weights/weights_for_' + save_fn + '.pkl', 'wb'))
            print('States, parameters, and weights saved.')

        if par['do_k_shot_testing']:

            print('\nStarting k-shot testing for task {}.'.format(
                par['k_shot_task']))

            if not par['load_from_checkpoint']:
                saver.save(sess, './checkpoints/{}'.format(save_fn))

            overall_task_accs = []
            for i in range(par['testing_iters']):

                saver.restore(sess, './checkpoints/{}'.format(save_fn))

                _, stim_in, y_hat, mk, _ = stim.generate_trial(
                    par['k_shot_task'])

                masking = np.zeros([par['batch_size'], 1])
                masking[:par['num_shots'], :] = 1
                feed_dict = {
                    x: stim_in,
                    y: y_hat,
                    g: par['gating'][0],
                    m: mk,
                    trial_mask: np.float32(masking)
                }

                for _ in range(par['shot_reps']):

                    # Run the model using one of the available stabilization methods
                    if par['stabilization'] == 'pathint':
                        _, _, loss, AL, spike_loss, output = sess.run([model.train_op, \
                            model.update_small_omega, model.pol_loss, model.aux_loss, \
                            model.spike_loss, model.output], feed_dict=feed_dict)
                    elif par['stabilization'] == 'EWC':
                        _, loss, AL, output = sess.run([model.train_op, model.pol_loss, \
                            model.aux_loss, model.output], feed_dict=feed_dict)

                task_accs = []
                task_grads = []
                task_states = []
                for t in range(par['n_tasks']):
                    _, stim_in, y_hat, mk, _ = stim.generate_trial(t)
                    output, batch_grads, h = sess.run(
                        [model.output, model.batch_grads, model.h],
                        feed_dict={
                            x: stim_in,
                            y: y_hat,
                            g: par['gating'][task],
                            m: mk
                        })
                    task_accs.append(get_perf(y_hat, output, mk))
                    task_grads.append(batch_grads)
                    task_states.append(h)

                print(
                    '\nTesting Iter {} | k-shot Trained Task {} Accuracy {:5.3f}'
                    .format(i, par['k_shot_task'],
                            task_accs[par['k_shot_task']]))
                print('Task accuracies:',
                      *['| {:5.3f}'.format(el) for el in task_accs])

                overall_task_accs.append(task_accs)

            overall_task_accs = np.mean(np.array(overall_task_accs), axis=0)
            pickle.dump(
                overall_task_accs,
                open('./savedir/post_kshot_accuracy_{}.pkl'.format(save_fn),
                     'wb'))
            print('\n-----\nOverall task accuracies:\n   ',
                  *['| {:5.3f}'.format(el) for el in overall_task_accs])

    print('\nModel execution for {} complete. (Supervised)'.format(save_fn))
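
One detail worth spelling out from the k-shot block above: the trial_mask placeholder restricts learning to the first num_shots trials of the batch. How the Model class applies it is not shown; the usual pattern (a sketch under that assumption) is to weight the per-trial loss by the mask and renormalize, so gradient flows only from the labeled shots:

import numpy as np

def k_shot_loss(loss_per_trial, trial_mask):
    # zero out trials beyond the labeled shots, then average over the shots only
    m = np.squeeze(trial_mask)
    return float(np.sum(m * loss_per_trial) / max(np.sum(m), 1.0))

# usage: 256 trials in the batch, gradient flows from the first 5 only
batch_size, num_shots = 256, 5
masking = np.zeros([batch_size, 1], dtype=np.float32)
masking[:num_shots, :] = 1.
loss = k_shot_loss(np.random.rand(batch_size), masking)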