Example #1
# Reconstructed imports; Model, Stimulus, par, and behavior are assumed
# to come from the surrounding project modules.
import os

import numpy as np

def main(gpu_id=None):

    # Isolate the requested GPU (must happen before CUDA is initialized)
    if gpu_id is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)

    # Initialize the model and stimulus
    model = Model()
    stim = Stimulus()

    for i in range(par['iterations']):

        # Generate a batch of trials
        trial_info = stim.make_batch()

        # Run the model on the provided batch of trials
        task_loss, act_loss, outputs = model.train_model(trial_info)

        # Calculate the network's accuracy, masking out time steps where
        # no response is expected (desired output class 0)
        acc_mask = trial_info['train_mask'] * (np.argmax(
            trial_info['desired_output'], axis=-1) != 0)
        accuracy = np.sum(
            acc_mask * (np.argmax(outputs['y'], axis=-1) == np.argmax(
                trial_info['desired_output'], axis=-1))) / np.sum(acc_mask)

        # Intermittently report feedback on the network
        if i % 10 == 0:

            # Plot the network's behavior
            behavior(trial_info, outputs, i)

            # Output the network's performance on the task
            print('{:>5} | Task Loss: {:6.3f} | Task Acc: {:5.3f} | Act. Loss: {:6.3f} |'.format(
                i, task_loss.numpy(), accuracy, act_loss.numpy()))
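
A minimal entry point for this example might read the GPU id from the command line; the flag name below is an illustrative assumption, not part of the original code. CUDA_VISIBLE_DEVICES must be set before any CUDA context is created, which is why main() takes the id as an argument.

# Hypothetical entry point for Example #1
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', default=None,
                        help='id of the GPU to run on, e.g. "0"')
    args = parser.parse_args()
    main(gpu_id=args.gpu)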
Example #2
# Reconstructed imports; NetworkController, Stimulus, par, and to_cpu are
# assumed to come from the surrounding project modules.
import pickle
import time

import numpy as np

def main():

    # Start the model run by loading the network controller and stimulus
    print('\nStarting model run: {}'.format(par['save_fn']))
    control = NetworkController()
    stim = Stimulus()

    # Select whether losses should be ranked, according to the learning method
    if par['learning_method'] in ['GA', 'TA']:
        is_ranked = True
    elif par['learning_method'] == 'ES':
        is_ranked = False
    else:
        raise Exception('Unknown learning method: {}'.format(
            par['learning_method']))

    # Get the loss baseline and update the ensemble reference accordingly;
    # nanmean guards against models that return NaN losses on this first batch
    control.load_batch(stim.make_batch())
    control.run_models()
    control.judge_models()
    loss_baseline = np.nanmean(control.get_losses(is_ranked))
    print('Loss baseline: {}'.format(loss_baseline))
    control.update_constant('loss_baseline', loss_baseline)

    # Establish records for the training loop
    save_record = {'iter': [], 'mean_task_acc': [], 'mean_full_acc': [],
                   'top_task_acc': [], 'top_full_acc': [], 'loss': [],
                   'mut_str': [], 'spiking': [], 'loss_factors': []}

    t0 = time.time()
    # Run the training loop
    for i in range(par['iterations']):

        # Process a batch of stimulus using the current models
        control.load_batch(stim.make_batch())
        control.run_models()
        control.judge_models()

        # Get the current loss scores
        loss = control.get_losses(is_ranked)

        # Apply optimizations based on the current learning method(s)
        mutation_strength = 0.
        if par['learning_method'] in ['GA', 'TA']:
            mutation_strength = par['mutation_strength'] * (
                np.nanmean(loss[:par['num_survivors']]) / loss_baseline)
            control.update_constant('mutation_strength', mutation_strength)
            """
            thresholds = [0.25, 0.1, 0.05, 0]
            modifiers  = [1/2, 1/4, 1/8]
            for t in range(len(thresholds))[:-1]:
                if thresholds[t] > mutation_strength > thresholds[t+1]:
                    mutation_strength = par['mutation_strength']*np.nanmean(loss)/loss_baseline * modifiers[t]
                    break
            """

            if par['learning_method'] == 'GA':
                control.breed_models_genetic()
            elif par['learning_method'] == 'TA':
                control.update_constant(
                    'temperature',
                    par['temperature'] * par['temperature_decay']**i)
                control.breed_models_thermal(i)

        elif par['learning_method'] == 'ES':
            control.breed_models_evo_search(i)

        # Print and save network performance as desired
        if i % par['iters_per_output'] == 0:
            task_accuracy, full_accuracy = control.get_performance()
            loss_dict = control.get_losses_by_type(is_ranked)
            spikes = control.get_spiking()

            task_loss = np.mean(loss_dict['task'][:par['num_survivors']])
            freq_loss = np.mean(loss_dict['freq'][:par['num_survivors']])
            reci_loss = np.mean(loss_dict['reci'][:par['num_survivors']])

            mean_loss = np.mean(loss[:par['num_survivors']])
            task_acc = np.mean(task_accuracy[:par['num_survivors']])
            full_acc = np.mean(full_accuracy[:par['num_survivors']])
            spiking = np.mean(spikes[:par['num_survivors']])

            if par['learning_method'] in ['GA', 'TA']:
                top_task_acc = task_accuracy.max()
                top_full_acc = full_accuracy.max()
            elif par['learning_method'] == 'ES':
                top_task_acc = task_accuracy[0]
                top_full_acc = full_accuracy[0]

            save_record['iter'].append(i)
            save_record['top_task_acc'].append(top_task_acc)
            save_record['top_full_acc'].append(top_full_acc)
            save_record['mean_task_acc'].append(task_acc)
            save_record['mean_full_acc'].append(full_acc)
            save_record['loss'].append(mean_loss)
            save_record['loss_factors'].append(loss_dict)
            save_record['mut_str'].append(mutation_strength)
            save_record['spiking'].append(spiking)
            pickle.dump(save_record,
                        open(par['save_dir'] + par['save_fn'] + '.pkl', 'wb'))
            if i % (10 * par['iters_per_output']) == 0:
                print('Saving weights for iteration {}... ({})\n'.format(
                    i, par['save_fn']))
                pickle.dump(
                    to_cpu(control.var_dict),
                    open(par['save_dir'] + par['save_fn'] + '_weights.pkl',
                         'wb'))

            status_stringA = 'Iter: {:4} | Task Loss: {:5.3f} | Freq Loss: {:5.3f} | Reci Loss: {:5.3f}'.format(
                i, task_loss, freq_loss, reci_loss)
            status_stringB = 'Opt:  {:>4} | Full Loss: {:5.3f} | Mut Str: {:7.5f} | Spiking: {:5.2f} Hz'.format(
                par['learning_method'], mean_loss, mutation_strength, spiking)
            status_stringC = 'S/O:  {:4} | Top Acc (Task/Full): {:5.3f} / {:5.3f}  | Mean Acc (Task/Full): {:5.3f} / {:5.3f}'.format(
                int(time.time() - t0), top_task_acc, top_full_acc, task_acc, full_acc)
            print(status_stringA + '\n' + status_stringB + '\n' +
                  status_stringC + '\n')
            t0 = time.time()
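
The GA/TA branch above anneals the mutation strength by the survivors' mean loss, normalized against the starting baseline, so exploration shrinks as the ensemble improves. A standalone sketch of that schedule (the function name is illustrative):

import numpy as np

def anneal_mutation_strength(loss, loss_baseline, base_strength, num_survivors):
    # Scale the base strength by how far the survivors' mean loss still
    # sits from the baseline; tends toward 0 as the loss approaches 0
    return base_strength * np.nanmean(loss[:num_survivors]) / loss_baseline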
Example #3
# Fragment of a larger analysis routine; the signature below is a
# hypothetical addition so the snippet parses on its own (np, par, Model,
# Stimulus, and to_cpu are assumed as in the neighboring examples).
def run_sample_analysis():

    # Make a new model and stimulus (which use the loaded parameters)
    print('\nLoading and running model.')
    model = Model()
    stim = Stimulus()
    runs = 8

    c_all = []
    d_all = []
    v_all = []
    s_all = []

    # Run several batches to generate sufficient data points
    for i in range(runs):
        print('R:{:>2}'.format(i), end='\r')
        trial_info = stim.make_batch(var_delay=False)
        model.run_model(trial_info, testing=True)

        c_all.append(trial_info['sample_cat'])
        d_all.append(trial_info['sample_dir'])
        v_all.append(to_cpu(model.v))
        s_all.append(to_cpu(model.s))

    del model
    del stim

    batch_size = runs * par['batch_size']

    # Labels are indexed by trial (axis 0); v and s are recorded as
    # [time, trials, units], so trials are concatenated along axis 1
    c = np.concatenate(c_all, axis=0)
    d = np.concatenate(d_all, axis=0)
    v = np.concatenate(v_all, axis=1)
    s = np.concatenate(s_all, axis=1)
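
A quick sanity check of the shapes produced above (the assertions are illustrative): labels carry one entry per trial, while the voltage and synaptic-efficacy recordings are stacked as [time, trials, units].

    assert c.shape[0] == batch_size and d.shape[0] == batch_size
    assert v.shape[1] == batch_size and s.shape[1] == batch_size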
Example #4
# Reconstructed imports; Model, Stimulus, par, to_cpu, and end_dead_time
# are assumed to come from the surrounding project modules.
import numpy as np
import matplotlib.pyplot as plt
from sklearn.svm import SVC

def run_SVM_analysis():

    print('\nLoading and running model.')
    model = Model()
    stim = Stimulus()
    runs = 8

    m_all = []
    v_all = []
    s_all = []

    for i in range(runs):
        print('R:{:>2}'.format(i), end='\r')
        trial_info = stim.make_batch(var_delay=False)
        model.run_model(trial_info)

        m_all.append(trial_info['sample_cat'])
        v_all.append(to_cpu(model.v))
        s_all.append(to_cpu(model.s))

    del model
    del stim

    batch_size = runs * par['batch_size']

    m = np.concatenate(m_all, axis=0)
    v = np.concatenate(v_all, axis=1)
    s = np.concatenate(s_all, axis=1)

    print('Performing SVM decoding on {} trials.\n'.format(batch_size))
    # Initialize linear classifier
    args = {
        'kernel': 'linear',
        'decision_function_shape': 'ovr',
        'shrinking': False,
        'tol': 1e-3
    }
    lin_clf_v = SVC(**args)
    lin_clf_s = SVC(**args)

    score_v = np.zeros([par['num_time_steps']])
    score_s = np.zeros([par['num_time_steps']])

    # Choose training and testing indices once, across trials, so every
    # time step is scored on the same held-out trials
    train_pct = 0.75
    num_train_inds = int(batch_size * train_pct)

    shuffled = np.random.permutation(batch_size)
    train_inds = shuffled[:num_train_inds]
    test_inds = shuffled[num_train_inds:]

    # Fit and score a fresh decoder at each time step after the dead period
    for t in range(end_dead_time, par['num_time_steps']):
        print('T:{:>4}'.format(t), end='\r')

        lin_clf_v.fit(v[t, train_inds, :], m[train_inds])
        lin_clf_s.fit(s[t, train_inds, :], m[train_inds])

        dec_v = lin_clf_v.predict(v[t, test_inds, :])
        dec_s = lin_clf_s.predict(s[t, test_inds, :])

        score_v[t] = np.mean(m[test_inds] == dec_v)
        score_s[t] = np.mean(m[test_inds] == dec_s)

    fig, ax = plt.subplots(1, figsize=(12, 8))
    ax.plot(score_v, c=[241 / 255, 153 / 255, 1 / 255], label='Voltage')
    ax.plot(score_s, c=[58 / 255, 79 / 255, 65 / 255], label='Syn. Eff.')

    ax.axhline(0.5, c='k', ls='--')
    ax.axvline(trial_info['timings'][0, 0], c='k', ls='--')
    ax.axvline(trial_info['timings'][1, 0], c='k', ls='--')

    ax.set_title('SVM Decoding of Sample Category')
    ax.set_xlabel('Time')
    ax.set_ylabel('Decoding Accuracy')
    ax.set_yticks([0., 0.25, 0.5, 0.75, 1.])
    ax.grid()
    ax.set_xlim(0, par['num_time_steps'] - 1)

    ax.legend()
    plt.savefig('./analysis/svm_decoding.png', bbox_inches='tight')

    print('SVM decoding complete.')
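
The per-time-step decoding pattern above can be exercised end to end on synthetic data; everything in this sketch is illustrative and independent of the project code. Accuracy should sit near chance (0.5) early on and rise once the signal appears.

import numpy as np
from sklearn.svm import SVC

rng = np.random.default_rng(0)
T, trials, units = 20, 64, 10
labels = rng.integers(0, 2, trials)

# Activity that becomes label-informative halfway through the trial
activity = rng.normal(size=(T, trials, units))
activity[T // 2:, :, 0] += labels

train, test = np.arange(0, trials, 2), np.arange(1, trials, 2)
scores = np.zeros(T)
clf = SVC(kernel='linear')
for t in range(T):
    clf.fit(activity[t, train], labels[train])
    scores[t] = np.mean(clf.predict(activity[t, test]) == labels[test])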
Example #5
# Reconstructed imports; Model, Stimulus, par, to_cpu, pf (plotting
# functions), and LA (a linear-algebra module) are assumed to come from
# the surrounding project modules.
import pickle
import time

import numpy as np
import cupy as cp

def main():

    # Start the model run by loading the network controller and stimulus
    print('\nLoading model...')
    model = Model()
    stim = Stimulus()

    t0 = time.time()
    print('Starting training.\n')

    full_acc_record = []
    task_acc_record = []
    iter_record = []
    I_sqr_record = []
    W_rnn_grad_sum_record = []
    W_rnn_grad_norm_record = []

    # Run the training loop
    for i in range(par['iterations']):

        # Process a batch of stimulus using the current models
        trial_info = stim.make_batch()
        model.run_model(trial_info)
        model.optimize()

        losses = model.get_losses()
        mean_spiking = model.get_mean_spiking()
        task_accuracy, full_accuracy = model.get_performance()

        full_acc_record.append(full_accuracy)
        task_acc_record.append(task_accuracy)
        iter_record.append(i)
        I_sqr_record.append(model.I_sqr)
        # Note: despite its name, this record sums the recurrent weights
        # themselves; only the norm record below reads from grad_dict
        W_rnn_grad_sum_record.append(cp.sum(model.var_dict['W_rnn']))
        W_rnn_grad_norm_record.append(LA.norm(model.grad_dict['W_rnn']))

        W_exc_mean = cp.mean(
            cp.maximum(0, model.var_dict['W_rnn'][:par['n_exc'], :]))
        W_inh_mean = cp.mean(
            cp.maximum(0, model.var_dict['W_rnn'][par['n_exc']:, :]))

        info_str0 = 'Iter {:>5} | Task Loss: {:5.3f} | Task Acc: {:5.3f} | '.format(
            i, losses['task'], task_accuracy)
        info_str1 = 'Full Acc: {:5.3f} | Mean Spiking: {:6.3f} Hz'.format(
            full_accuracy, mean_spiking)
        print('Aggregating data...', end='\r')

        if i % 20 == 0:

            # print('Mean EXC w_rnn ', W_exc_mean, 'mean INH w_rnn', W_inh_mean)
            if par['plot_EI_testing']:
                pf.EI_testing_plots(i, I_sqr_record, W_rnn_grad_sum_record,
                                    W_rnn_grad_norm_record)

            pf.run_pev_analysis(trial_info['sample'],
                                to_cpu(model.su * model.sx), to_cpu(model.z),
                                to_cpu(cp.stack(I_sqr_record)), i)
            weights = to_cpu(model.var_dict['W_rnn'])
            fn = './savedir/{}_weights.pkl'.format(par['savefn'])
            data = {'weights': weights, 'par': par}
            pickle.dump(data, open(fn, 'wb'))

            pf.activity_plots(i, model)
            pf.clopath_update_plot(i, model.clopath_W_in, model.clopath_W_rnn,
                                   model.grad_dict['W_in'],
                                   model.grad_dict['W_rnn'])
            pf.plot_grads_and_epsilons(i, trial_info, model, model.h,
                                       model.eps_v_rec, model.eps_w_rec,
                                       model.eps_ir_rec)

            if i != 0:
                pf.training_curve(i, iter_record, full_acc_record,
                                  task_acc_record)

            if i % 100 == 0:
                model.visualize_delta(i)

                if par['save_data_files']:
                    data = {'par': par, 'weights': to_cpu(model.var_dict)}
                    pickle.dump(
                        data,
                        open(
                            './savedir/{}_data_iter{:0>6}.pkl'.format(
                                par['savefn'], i), 'wb'))

            trial_info = stim.make_batch(var_delay=False)
            model.run_model(trial_info, testing=True)
            model.show_output_behavior(i, trial_info)

        # Print output info (after all saving of data is complete)
        print(info_str0 + info_str1)

        if i % 100 == 0 and i != 0:
            if np.mean(task_acc_record[-100:]) > 0.9:
                print('\nMean task accuracy greater than 0.9 over the last '
                      '100 iterations.\nMoving on to the next model.\n')
                break
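
The stopping rule at the bottom of the loop can be factored into a small helper so the window and threshold are explicit; a sketch, with illustrative defaults:

import numpy as np

def should_stop(acc_record, window=100, threshold=0.9):
    # Stop only once a full window of accuracies averages above threshold
    if len(acc_record) < window:
        return False
    return np.mean(acc_record[-window:]) > threshold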