def main(gpu_id=None):
    """Train ``Model`` on batches from ``Stimulus`` for ``par['iterations']`` steps.

    Every 10 iterations the network behaviour is plotted via ``behavior`` and a
    one-line summary (task loss, task accuracy, activity loss) is printed.

    Args:
        gpu_id: Optional GPU identifier.  When given, ``CUDA_VISIBLE_DEVICES``
            is set to it so only that device is visible.  Accepts a str or an
            int (converted to str, since ``os.environ`` only stores strings).
    """
    # Isolate requested GPU
    if gpu_id is not None:
        # os.environ values must be str; an int id would raise TypeError.
        os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)

    # Initialize the model and stimulus
    model = Model()
    stim = Stimulus()

    for i in range(par['iterations']):
        # Generate a batch of trials
        trial_info = stim.make_batch()

        # Run the model on the provided batch of trials
        task_loss, act_loss, outputs = model.train_model(trial_info)

        # Intermittently report feedback on the network
        if i % 10 == 0:
            # Accuracy is only needed for the report, so it is computed here.
            # It is scored where train_mask is set and the target class is not
            # index 0 (presumably the fixation output -- TODO confirm).
            target = np.argmax(trial_info['desired_output'], axis=-1)
            acc_mask = trial_info['train_mask'] * (target != 0)
            correct = np.argmax(outputs['y'], axis=-1) == target
            num_scored = np.sum(acc_mask)
            # Avoid a 0/0 RuntimeWarning; the result is NaN either way.
            accuracy = (np.sum(acc_mask * correct) / num_scored
                        if num_scored > 0 else np.nan)

            # Plot the network's behavior
            behavior(trial_info, outputs, i)

            # Output the network's performance on the task
            print('{:>5} | Task Loss: {:6.3f} | Task Acc: {:5.3f} | Act. Loss: {:6.3f} |'.format(
                i, task_loss.numpy(), accuracy, act_loss.numpy()))
def main():
    """Train an ensemble of networks with a gradient-free learning method.

    The method is selected by ``par['learning_method']``:

    * ``'GA'`` -- ``control.breed_models_genetic()``
    * ``'TA'`` -- temperature decayed by ``par['temperature_decay']**i``, then
      ``control.breed_models_thermal(i)``
    * ``'ES'`` -- ``control.breed_models_evo_search(i)``

    For GA/TA the mutation strength is scaled by the survivors' mean loss
    relative to the loss measured on the first batch (``loss_baseline``).

    Every ``par['iters_per_output']`` iterations a performance summary is
    appended to ``save_record``, which is pickled to
    ``par['save_dir'] + par['save_fn'] + '.pkl'``.  Every ten outputs the
    CPU copy of ``control.var_dict`` is pickled to ``..._weights.pkl``.

    Raises:
        Exception: if ``par['learning_method']`` is not GA, TA or ES.
    """
    # Start the model run by loading the network controller and stimulus
    print('\nStarting model run: {}'.format(par['save_fn']))
    control = NetworkController()
    stim = Stimulus()

    # Select whether to get losses ranked, according to learning method
    if par['learning_method'] in ['GA', 'TA']:
        is_ranked = True
    elif par['learning_method'] == 'ES':
        is_ranked = False
    else:
        raise Exception('Unknown learning method: {}'.format(
            par['learning_method']))

    # Get loss baseline and update the ensemble reference accordingly
    control.load_batch(stim.make_batch())
    control.run_models()
    control.judge_models()

    # NOTE: nanmean skips NaN losses.  If every loss is NaN, the baseline is
    # NaN (with a RuntimeWarning), and that NaN propagates into the GA/TA
    # mutation strength below.  An earlier debugging note linked this to
    # later shape/memory errors.
    loss_baseline = np.nanmean(control.get_losses(is_ranked))
    print("loss_baseline")
    print(loss_baseline)
    control.update_constant('loss_baseline', loss_baseline)

    # Establish records for training loop
    save_record = {'iter': [], 'mean_task_acc': [], 'mean_full_acc': [],
                   'top_task_acc': [], 'top_full_acc': [], 'loss': [],
                   'mut_str': [], 'spiking': [], 'loss_factors': []}
    t0 = time.time()

    # Run the training loop
    for i in range(par['iterations']):

        # Process a batch of stimulus using the current models
        control.load_batch(stim.make_batch())
        control.run_models()
        control.judge_models()

        # Get the current loss scores
        loss = control.get_losses(is_ranked)

        # Apply optimizations based on the current learning method(s)
        mutation_strength = 0.
        if par['learning_method'] in ['GA', 'TA']:
            survivor_loss = np.nanmean(loss[:par['num_survivors']])
            mutation_strength = par['mutation_strength'] * (
                survivor_loss / loss_baseline)
            control.update_constant('mutation_strength', mutation_strength)

        if par['learning_method'] == 'GA':
            control.breed_models_genetic()
        elif par['learning_method'] == 'TA':
            control.update_constant(
                'temperature',
                par['temperature'] * par['temperature_decay']**i)
            control.breed_models_thermal(i)
        elif par['learning_method'] == 'ES':
            control.breed_models_evo_search(i)

        # Print and save network performance as desired
        if i % par['iters_per_output'] == 0:
            task_accuracy, full_accuracy = control.get_performance()
            loss_dict = control.get_losses_by_type(is_ranked)
            spikes = control.get_spiking()

            n_surv = par['num_survivors']
            task_loss = np.mean(loss_dict['task'][:n_surv])
            freq_loss = np.mean(loss_dict['freq'][:n_surv])
            reci_loss = np.mean(loss_dict['reci'][:n_surv])

            mean_loss = np.mean(loss[:n_surv])
            task_acc = np.mean(task_accuracy[:n_surv])
            full_acc = np.mean(full_accuracy[:n_surv])
            spiking = np.mean(spikes[:n_surv])

            if par['learning_method'] in ['GA', 'TA']:
                top_task_acc = task_accuracy.max()
                top_full_acc = full_accuracy.max()
            elif par['learning_method'] == 'ES':
                top_task_acc = task_accuracy[0]
                top_full_acc = full_accuracy[0]

            save_record['iter'].append(i)
            save_record['top_task_acc'].append(top_task_acc)
            save_record['top_full_acc'].append(top_full_acc)
            save_record['mean_task_acc'].append(task_acc)
            save_record['mean_full_acc'].append(full_acc)
            save_record['loss'].append(mean_loss)
            save_record['loss_factors'].append(loss_dict)
            save_record['mut_str'].append(mutation_strength)
            save_record['spiking'].append(spiking)

            # Use context managers so the pickle files are flushed and closed.
            with open(par['save_dir'] + par['save_fn'] + '.pkl', 'wb') as f:
                pickle.dump(save_record, f)

            if i % (10 * par['iters_per_output']) == 0:
                print('Saving weights for iteration {}... ({})\n'.format(
                    i, par['save_fn']))
                with open(par['save_dir'] + par['save_fn'] + '_weights.pkl',
                          'wb') as f:
                    pickle.dump(to_cpu(control.var_dict), f)

            status_stringA = 'Iter: {:4} | Task Loss: {:5.3f} | Freq Loss: {:5.3f} | Reci Loss: {:5.3f}'.format(
                i, task_loss, freq_loss, reci_loss)
            status_stringB = 'Opt: {:>4} | Full Loss: {:5.3f} | Mut Str: {:7.5f} | Spiking: {:5.2f} Hz'.format(
                par['learning_method'], mean_loss, mutation_strength, spiking)
            status_stringC = 'S/O: {:4} | Top Acc (Task/Full): {:5.3f} / {:5.3f} | Mean Acc (Task/Full): {:5.3f} / {:5.3f}'.format(
                int(time.time() - t0), top_task_acc, top_full_acc, task_acc, full_acc)
            print(status_stringA + '\n' + status_stringB + '\n' + status_stringC + '\n')
            # "S/O" is seconds elapsed since the previous output.
            t0 = time.time()
# Make a new model and stimulus (which use the loaded parameters) print('\nLoading and running model.') model = Model() stim = Stimulus() runs = 8 c_all = [] d_all = [] v_all = [] s_all = [] # Run a couple batches to generate sufficient data points for i in range(runs): print('R:{:>2}'.format(i), end='\r') trial_info = stim.make_batch(var_delay=False) model.run_model(trial_info, testing=True) c_all.append(trial_info['sample_cat']) d_all.append(trial_info['sample_dir']) v_all.append(to_cpu(model.v)) s_all.append(to_cpu(model.s)) del model del stim batch_size = runs * par['batch_size'] c = np.concatenate(c_all, axis=0) d = np.concatenate(d_all, axis=0) v = np.concatenate(v_all, axis=1)
def run_SVM_analysis(runs=8, train_pct=0.75):
    """Decode the sample category over time from voltages and synaptic efficacies.

    A fresh ``Model`` and ``Stimulus`` are built, ``runs`` batches are run
    through the model (``stim.make_batch(var_delay=False)``), and the CPU
    copies of ``model.v`` and ``model.s`` are collected for each batch.

    For every time step from ``end_dead_time`` to ``par['num_time_steps'] - 1``,
    a linear SVM is trained on a random ``train_pct`` split of the trials to
    predict ``trial_info['sample_cat']``.  It is scored on the held-out trials.
    Earlier time steps keep a score of 0.

    Both accuracy curves are plotted and saved to
    ``./analysis/svm_decoding.png``.

    Args:
        runs: Number of batches to collect (default 8).
        train_pct: Fraction of trials used for training, in (0, 1)
            (default 0.75).

    Raises:
        ValueError: if ``runs`` < 1 or ``train_pct`` is not in (0, 1).
    """
    if runs < 1:
        raise ValueError('runs must be at least 1, got {}'.format(runs))
    if not 0. < train_pct < 1.:
        raise ValueError('train_pct must be in (0, 1), got {}'.format(train_pct))

    print('\nLoading and running model.')
    model = Model()
    stim = Stimulus()

    m_all = []
    v_all = []
    s_all = []
    for i in range(runs):
        print('R:{:>2}'.format(i), end='\r')
        trial_info = stim.make_batch(var_delay=False)
        model.run_model(trial_info)
        m_all.append(trial_info['sample_cat'])
        v_all.append(to_cpu(model.v))
        s_all.append(to_cpu(model.s))

    # Drop the model and stimulus once the data has been copied off.
    del model
    del stim

    batch_size = runs * par['batch_size']
    m = np.concatenate(m_all, axis=0)
    # States are indexed below as [t, trial, :], so trials lie on axis 1.
    v = np.concatenate(v_all, axis=1)
    s = np.concatenate(s_all, axis=1)
    print('Performing SVM decoding on {} trials.\n'.format(batch_size))

    # Initialize linear classifier
    args = {
        'kernel': 'linear',
        'decision_function_shape': 'ovr',
        'shrinking': False,
        'tol': 1e-3
    }
    lin_clf_v = SVC(**args)
    lin_clf_s = SVC(**args)
    score_v = np.zeros([par['num_time_steps']])
    score_s = np.zeros([par['num_time_steps']])

    # Choose training and testing indices
    num_train_inds = int(batch_size * train_pct)
    shuffled = np.random.permutation(batch_size)
    train_inds = shuffled[:num_train_inds]
    test_inds = shuffled[num_train_inds:]

    for t in range(end_dead_time, par['num_time_steps']):
        print('T:{:>4}'.format(t), end='\r')
        lin_clf_v.fit(v[t, train_inds, :], m[train_inds])
        lin_clf_s.fit(s[t, train_inds, :], m[train_inds])
        dec_v = lin_clf_v.predict(v[t, test_inds, :])
        dec_s = lin_clf_s.predict(s[t, test_inds, :])
        score_v[t] = np.mean(m[test_inds] == dec_v)
        score_s[t] = np.mean(m[test_inds] == dec_s)

    fig, ax = plt.subplots(1, figsize=(12, 8))
    ax.plot(score_v, c=[241 / 255, 153 / 255, 1 / 255], label='Voltage')
    ax.plot(score_s, c=[58 / 255, 79 / 255, 65 / 255], label='Syn. Eff.')
    # Reference line at 0.5 (presumably chance for a binary sample_cat -- confirm)
    ax.axhline(0.5, c='k', ls='--')
    # Event markers taken from the first trial of the last collected batch
    ax.axvline(trial_info['timings'][0, 0], c='k', ls='--')
    ax.axvline(trial_info['timings'][1, 0], c='k', ls='--')
    ax.set_title('SVM Decoding of Sample Category')
    ax.set_xlabel('Time')
    ax.set_ylabel('Decoding Accuracy')
    ax.set_yticks([0., 0.25, 0.5, 0.75, 1.])
    ax.grid()
    ax.set_xlim(0, par['num_time_steps'] - 1)
    ax.legend()
    fig.savefig('./analysis/svm_decoding.png', bbox_inches='tight')
    # Release the figure; pyplot otherwise keeps it alive after every call.
    plt.close(fig)

    print('SVM decoding complete.')
def main():
    """Train a single ``Model`` with periodic diagnostics, plots and checkpoints.

    Each iteration runs one batch, calls ``model.optimize()``, and records
    accuracies, ``model.I_sqr``, and two ``W_rnn`` statistics.

    Every 20 iterations:

    * plots are produced via ``pf``;
    * ``W_rnn`` and ``par`` are pickled to ``./savedir/{savefn}_weights.pkl``;
    * a fixed-delay test batch is run for ``show_output_behavior``.

    Every 100 iterations:

    * ``model.visualize_delta`` is called;
    * when ``par['save_data_files']`` is set, all weights are pickled.

    Training stops early once the mean task accuracy over the last 100
    iterations exceeds 0.9.  This is checked every 100 iterations, and only
    once 100 records exist.
    """
    # Start the model run by loading the network controller and stimulus
    print('\nLoading model...')
    model = Model()
    stim = Stimulus()

    print('Starting training.\n')

    full_acc_record = []
    task_acc_record = []
    iter_record = []
    I_sqr_record = []
    W_rnn_grad_sum_record = []
    W_rnn_grad_norm_record = []

    # Run the training loop
    for i in range(par['iterations']):

        # Process a batch of stimulus using the current models
        trial_info = stim.make_batch()
        model.run_model(trial_info)
        model.optimize()

        losses = model.get_losses()
        mean_spiking = model.get_mean_spiking()
        task_accuracy, full_accuracy = model.get_performance()

        full_acc_record.append(full_accuracy)
        task_acc_record.append(task_accuracy)
        iter_record.append(i)
        I_sqr_record.append(model.I_sqr)
        # NOTE(review): despite its name, this records the sum of the W_rnn
        # *weights* (var_dict), whereas the norm record below reads
        # grad_dict.  Confirm which one EI_testing_plots expects.
        W_rnn_grad_sum_record.append(cp.sum(model.var_dict['W_rnn']))
        W_rnn_grad_norm_record.append(LA.norm(model.grad_dict['W_rnn']))

        info_str0 = 'Iter {:>5} | Task Loss: {:5.3f} | Task Acc: {:5.3f} | '.format(
            i, losses['task'], task_accuracy)
        info_str1 = 'Full Acc: {:5.3f} | Mean Spiking: {:6.3f} Hz'.format(
            full_accuracy, mean_spiking)
        print('Aggregating data...', end='\r')

        if i % 20 == 0:
            if par['plot_EI_testing']:
                pf.EI_testing_plots(i, I_sqr_record, W_rnn_grad_sum_record,
                                    W_rnn_grad_norm_record)

            pf.run_pev_analysis(trial_info['sample'], to_cpu(model.su * model.sx),
                                to_cpu(model.z), to_cpu(cp.stack(I_sqr_record)), i)

            weights = to_cpu(model.var_dict['W_rnn'])
            fn = './savedir/{}_weights.pkl'.format(par['savefn'])
            data = {'weights': weights, 'par': par}
            # Context manager so the checkpoint file is flushed and closed.
            with open(fn, 'wb') as f:
                pickle.dump(data, f)

            pf.activity_plots(i, model)
            pf.clopath_update_plot(i, model.clopath_W_in, model.clopath_W_rnn,
                                   model.grad_dict['W_in'], model.grad_dict['W_rnn'])
            pf.plot_grads_and_epsilons(i, trial_info, model, model.h,
                                       model.eps_v_rec, model.eps_w_rec,
                                       model.eps_ir_rec)

            if i != 0:
                pf.training_curve(i, iter_record, full_acc_record, task_acc_record)

            if i % 100 == 0:
                model.visualize_delta(i)

                if par['save_data_files']:
                    data = {'par': par, 'weights': to_cpu(model.var_dict)}
                    data_fn = './savedir/{}_data_iter{:0>6}.pkl'.format(
                        par['savefn'], i)
                    with open(data_fn, 'wb') as f:
                        pickle.dump(data, f)

            # Run a fixed-delay test batch to visualize output behavior
            trial_info = stim.make_batch(var_delay=False)
            model.run_model(trial_info, testing=True)
            model.show_output_behavior(i, trial_info)

        # Print output info (after all saving of data is complete)
        print(info_str0 + info_str1)

        # Early stop.  Require a full 100-iteration window so that i == 0
        # cannot trigger on a single batch.
        if i % 100 == 0 and len(task_acc_record) >= 100:
            if np.mean(task_acc_record[-100:]) > 0.9:
                print(
                    '\nMean accuracy greater than 0.9 over last 100 iters.\nMoving on to next model.\n'
                )
                break