def run_session(subject, condition, sim_params, output_file=None, plot=False):
    """
    Run session in subject

    subject = subject object (its wta_network, wta_params, wta_monitor and run_trial() are used)
    condition = condition label, only printed at the start of the session
    sim_params = simulation params (sim_params.ntrials trials are simulated)
    output_file = if not none, writes h5 output to filename
    plot = plots session data if True

    Raises IndexError before simulating anything if sim_params.ntrials is larger than
    the number of generated trial inputs (5 coherence levels x 20 trials).

    NOTE(review): a second run_session taking an extra coherence_levels argument appears
    later in this source; if both live in the same module the later one shadows this
    one - confirm which is intended.
    """
    print("** Condition: %s **" % condition)

    # Create session monitor
    session_monitor = SessionMonitor(
        subject.wta_network,
        sim_params,
        {},
        record_connections=[],
        conv_window=40,
        record_firing_rates=True
    )

    # Run on five coherence levels
    coherence_levels = [0.032, 0.064, 0.128, 0.256, 0.512]

    # Trials per coherence level
    trials_per_level = 20
    # Floor division keeps the slice bounds integral under Python 3 and under
    # "from __future__ import division" (a float slice bound is rejected by NumPy)
    half_level = trials_per_level // 2

    # Create inputs for each trial: one row per trial, one column per input
    trial_inputs = np.zeros((trials_per_level * len(coherence_levels), 2))

    # Create left and right directions for each coherence level
    wta_params = subject.wta_params
    for level_idx, coherence in enumerate(coherence_levels):
        rate_plus = wta_params.mu_0 + wta_params.p_a * coherence * 100.0
        rate_minus = wta_params.mu_0 - wta_params.p_b * coherence * 100.0

        level_start = level_idx * trials_per_level
        level_mid = level_start + half_level
        level_end = level_start + trials_per_level

        # Left: first half of the level
        trial_inputs[level_start:level_mid, 0] = rate_plus
        trial_inputs[level_start:level_mid, 1] = rate_minus

        # Right: second half of the level, rates swapped between the inputs
        trial_inputs[level_mid:level_end, 0] = rate_minus
        trial_inputs[level_mid:level_end, 1] = rate_plus

    # Shuffle trials
    trial_inputs = np.random.permutation(trial_inputs)

    # Fail before any simulation time is spent rather than part-way through the session
    n_available = trial_inputs.shape[0]
    if sim_params.ntrials > n_available:
        raise IndexError(
            "sim_params.ntrials (%d) exceeds the %d generated trial inputs"
            % (sim_params.ntrials, n_available)
        )

    # Simulate each trial
    for t in range(sim_params.ntrials):
        print("Trial %d" % t)

        # Get task input for trial and figure out which is correct
        task_input_rates = trial_inputs[t, :]
        correct_input = np.where(task_input_rates == np.max(task_input_rates))[0]

        # Run trial
        subject.run_trial(sim_params, task_input_rates)

        # Record trial
        session_monitor.record_trial(t, task_input_rates, correct_input, subject.wta_network,
                                     subject.wta_monitor)

    # Write output
    if output_file is not None:
        session_monitor.write_output(output_file)

    # Plot: session summary for multi-trial runs, otherwise the single trial's monitor
    if plot:
        if sim_params.ntrials > 1:
            session_monitor.plot()
        else:
            subject.wta_monitor.plot()
        plt.show()
def run_session(subject, condition, sim_params, coherence_levels, output_file=None, plot=False):
    """
    Run session in subject

    subject = subject object (its wta_network, wta_params, wta_monitor, net and run_trial() are used)
    condition = condition label; 'training' selects the trial set named by the module-level
                `training` setting, any other value uses the mixed-coherence trial set
    sim_params = simulation params (sim_params.ntrials trials are simulated)
    coherence_levels = coherence levels to test on
    output_file = if not none, writes h5 output to filename
    plot = plots session data if True

    `plasticity_params` and `training` are not parameters; they are read from the
    enclosing module scope.

    The blocked training sets ('easy', 'diff', '1'..'6') are cut from rows up to
    6 * trials_per_level of the ordered trial matrix, so they are only fully populated
    when six coherence levels are supplied.

    Raises IndexError before simulating anything if sim_params.ntrials is larger than
    the number of rows in the selected trial set.
    """
    print('** Condition: %s **' % condition)

    # Record input connection weights
    record_connections = ['t0->e0_ampa', 't1->e1_ampa', 't0->e1_ampa', 't1->e0_ampa']

    # Create session monitor
    session_monitor = SessionMonitor(subject.wta_network, sim_params, plasticity_params,
                                     record_connections=record_connections, conv_window=40,
                                     record_firing_rates=True)

    # Trials per coherence level
    trials_per_level = 20
    # Floor division keeps the slice bounds integral under Python 3 and under
    # "from __future__ import division" (a float slice bound is rejected by NumPy)
    half_level = trials_per_level // 2

    # Create inputs for each trial: one row per trial, one column per input
    trial_inputs = np.zeros((trials_per_level * len(coherence_levels), 2))

    # Create left and right directions for each coherence level
    wta_params = subject.wta_params
    for level_idx, coherence in enumerate(coherence_levels):
        rate_plus = wta_params.mu_0 + wta_params.p_a * coherence * 100.0
        rate_minus = wta_params.mu_0 - wta_params.p_b * coherence * 100.0

        level_start = level_idx * trials_per_level
        level_mid = level_start + half_level
        level_end = level_start + trials_per_level

        # Left: first half of the level
        trial_inputs[level_start:level_mid, 0] = rate_plus
        trial_inputs[level_start:level_mid, 1] = rate_minus

        # Right: second half of the level, rates swapped between the inputs
        trial_inputs[level_mid:level_end, 0] = rate_minus
        trial_inputs[level_mid:level_end, 1] = rate_plus

    # Blocked training sets, cut from the still-ordered matrix (rows grouped by level).
    # Two levels repeated 3x and one level repeated 6x both give 6 * trials_per_level rows.
    difficult_inputs = np.repeat(trial_inputs[0:2 * trials_per_level, :], 3, 0)
    easy_inputs = np.repeat(trial_inputs[4 * trials_per_level:6 * trials_per_level, :], 3, 0)
    single_level_inputs = [
        np.repeat(trial_inputs[k * trials_per_level:(k + 1) * trials_per_level, :], 6, 0)
        for k in range(6)
    ]

    # Shuffle trials. Every set is shuffled, in the same order as before (mixed, easy,
    # difficult, 1..6), so the global NumPy RNG is consumed in the same sequence.
    trial_inputs = np.random.permutation(trial_inputs)
    # trained on easy
    easy_inputs = np.random.permutation(easy_inputs)
    # trained on difficult
    difficult_inputs = np.random.permutation(difficult_inputs)
    single_level_inputs = [np.random.permutation(level_set) for level_set in single_level_inputs]

    # Training regime name -> trial set
    training_sets = {'control': trial_inputs, 'easy': easy_inputs, 'diff': difficult_inputs}
    for k, level_set in enumerate(single_level_inputs):
        training_sets[str(k + 1)] = level_set

    # Pick the trial set once instead of re-testing the regime on every trial.
    # An unrecognised regime falls back to the mixed set rather than leaving the
    # trial input unbound.
    if condition == 'training':
        session_inputs = training_sets.get(training, trial_inputs)
    else:
        session_inputs = trial_inputs

    # Fail before any simulation time is spent rather than part-way through the session
    n_available = session_inputs.shape[0]
    if sim_params.ntrials > n_available:
        raise IndexError(
            'sim_params.ntrials (%d) exceeds the %d trial inputs available for this session'
            % (sim_params.ntrials, n_available)
        )

    # Simulate each trial
    for t in range(sim_params.ntrials):
        print('Trial %d' % t)

        # Get task input for trial and figure out which is correct
        task_input_rates = session_inputs[t, :]
        correct_input = np.where(task_input_rates == np.max(task_input_rates))[0]

        # Run trial
        subject.net.reinit(states=True)
        subject.run_trial(sim_params, task_input_rates)
        # Parenthesised form prints the same under Python 2 and is valid Python 3
        print(task_input_rates)

        # Record trial
        session_monitor.record_trial(t, task_input_rates, correct_input, subject.wta_network,
                                     subject.wta_monitor)

    # Write output
    if output_file is not None:
        session_monitor.write_output(output_file)

    # Plot
    if plot:
        session_monitor.plot()
        plt.show()