Beispiel #1
0
def run_session(subject, condition, sim_params, output_file=None, plot=False,
                coherence_levels=None, trials_per_level=20):
    """
    Run session in subject
    subject = subject object; used via subject.wta_network, subject.wta_params
              (mu_0, p_a, p_b), subject.wta_monitor and
              subject.run_trial(sim_params, task_input_rates)
    condition = session label (only printed)
    sim_params = simulation params; sim_params.ntrials trials are simulated
    output_file = if not none, writes h5 output to filename
    plot = plots session data if True
    coherence_levels = coherence levels to test on
                       (default: [0.032, 0.064, 0.128, 0.256, 0.512])
    trials_per_level = trials built per coherence level; the first
                       trials_per_level // 2 favour input 0, the rest input 1

    Raises ValueError if sim_params.ntrials exceeds the number of trials built
    (trials_per_level * len(coherence_levels)).
    """
    print("** Condition: %s **" % condition)

    if coherence_levels is None:
        coherence_levels = [0.032, 0.064, 0.128, 0.256, 0.512]

    n_available = trials_per_level * len(coherence_levels)
    # Fail before simulating anything rather than with an IndexError mid-session
    if sim_params.ntrials > n_available:
        raise ValueError("sim_params.ntrials=%d exceeds the %d trials available"
                         % (sim_params.ntrials, n_available))

    # Create session monitor
    session_monitor = SessionMonitor(
        subject.wta_network, sim_params, {}, record_connections=[], conv_window=40, record_firing_rates=True
    )

    # One row per trial: column 0 / column 1 are the rates of the two task inputs
    wta = subject.wta_params
    half = trials_per_level // 2  # integer division keeps slice bounds ints (Py2 and Py3)
    trial_inputs = np.zeros((n_available, 2))
    for level_idx, coherence in enumerate(coherence_levels):
        start = level_idx * trials_per_level
        mid = start + half
        stop = start + trials_per_level
        high = wta.mu_0 + wta.p_a * coherence * 100.0
        low = wta.mu_0 - wta.p_b * coherence * 100.0
        # Left: input 0 is the stronger input
        trial_inputs[start:mid, 0] = high
        trial_inputs[start:mid, 1] = low
        # Right: input 1 is the stronger input
        trial_inputs[mid:stop, 0] = low
        trial_inputs[mid:stop, 1] = high

    # Shuffle trials (row order only; each row keeps its input pair)
    trial_inputs = np.random.permutation(trial_inputs)

    # Simulate each trial
    for t in range(sim_params.ntrials):
        print("Trial %d" % t)

        # The correct response is the index of the strongest input
        task_input_rates = trial_inputs[t, :]
        correct_input = np.where(task_input_rates == np.max(task_input_rates))[0]

        subject.run_trial(sim_params, task_input_rates)

        session_monitor.record_trial(t, task_input_rates, correct_input, subject.wta_network, subject.wta_monitor)

    # Write output
    if output_file is not None:
        session_monitor.write_output(output_file)

    # Plot the whole session, or the single trial when only one was run
    if plot:
        if sim_params.ntrials > 1:
            session_monitor.plot()
        else:
            subject.wta_monitor.plot()
        plt.show()
Beispiel #2
0
def run_session(subject, condition, sim_params, coherence_levels, output_file=None, plot=False):
    """
    Run session in subject
    subject = subject object; used via subject.wta_network, subject.wta_params
              (mu_0, p_a, p_b), subject.wta_monitor, subject.net and
              subject.run_trial(sim_params, task_input_rates)
    condition = session label; when equal to 'training' the trial schedule is
                chosen by the module-level global `training`
    sim_params = simulation params; sim_params.ntrials trials are simulated
    coherence_levels = coherence levels to test on
    output_file = if not none, writes h5 output to filename
    plot = plots session data if True

    Trial schedules (20 trials are built per coherence level, half favouring
    input 0 and half input 1; every schedule is shuffled):
      'control'  - all levels, each trial once (also used when
                   condition != 'training')
      'diff'     - first two coherence levels, each trial repeated 3 times
      'easy'     - fifth and sixth coherence levels, each trial repeated 3 times
      '1' .. '6' - a single coherence level, each trial repeated 6 times

    NOTE(review): `training` and `plasticity_params` are read from module
    scope, not passed in -- confirm the caller's module defines them. The
    'easy' and '4'..'6' schedules assume at least six coherence levels; with
    fewer they come out short or empty.

    Raises ValueError if condition == 'training' and `training` names no
    schedule, or if sim_params.ntrials exceeds the selected schedule's length.
    """
    print('** Condition: %s **' % condition)

    # Record input connection weights
    record_connections = ['t0->e0_ampa', 't1->e1_ampa', 't0->e1_ampa', 't1->e0_ampa']
    # Create session monitor (plasticity_params is a module-level global)
    session_monitor = SessionMonitor(subject.wta_network, sim_params, plasticity_params,
        record_connections=record_connections, conv_window=40, record_firing_rates=True)

    # One row per trial: column 0 / column 1 are the rates of the two task inputs
    trials_per_level = 20
    half = trials_per_level // 2  # integer division keeps slice bounds ints (Py2 and Py3)
    wta = subject.wta_params
    trial_inputs = np.zeros((trials_per_level * len(coherence_levels), 2))
    for level_idx, coherence in enumerate(coherence_levels):
        start = level_idx * trials_per_level
        mid = start + half
        stop = start + trials_per_level
        high = wta.mu_0 + wta.p_a * coherence * 100.0
        low = wta.mu_0 - wta.p_b * coherence * 100.0
        # Left: input 0 is the stronger input
        trial_inputs[start:mid, 0] = high
        trial_inputs[start:mid, 1] = low
        # Right: input 1 is the stronger input
        trial_inputs[mid:stop, 0] = low
        trial_inputs[mid:stop, 1] = high

    def levels(first, last):
        """Unshuffled trial rows for coherence-level indices first .. last-1."""
        return trial_inputs[first * trials_per_level:last * trials_per_level, :]

    # Candidate schedules, listed in the order in which they are shuffled
    unshuffled = [
        ('control', trial_inputs),
        ('easy', np.repeat(levels(4, 6), 3, 0)),
        ('diff', np.repeat(levels(0, 2), 3, 0)),
    ]
    unshuffled += [(str(k), np.repeat(levels(k - 1, k), 6, 0)) for k in range(1, 7)]

    # Every schedule is shuffled, in this fixed order, even though only one is
    # used: this keeps the global NumPy RNG stream identical for seeded runs.
    schedules = {}
    for name, rows in unshuffled:
        schedules[name] = np.random.permutation(rows)

    if condition == 'training':
        # `training` is a module-level global selecting the schedule
        if training not in schedules:
            raise ValueError('unknown training schedule %r; expected one of %s'
                             % (training, sorted(schedules)))
        session_inputs = schedules[training]
    else:
        session_inputs = schedules['control']

    # Fail before simulating anything rather than with an IndexError mid-session
    if sim_params.ntrials > len(session_inputs):
        raise ValueError('sim_params.ntrials=%d exceeds the %d trials available'
                         % (sim_params.ntrials, len(session_inputs)))

    # Simulate each trial
    for t in range(sim_params.ntrials):
        print('Trial %d' % t)

        # The correct response is the index of the strongest input
        task_input_rates = session_inputs[t, :]
        correct_input = np.where(task_input_rates == np.max(task_input_rates))[0]

        # Reset network state, then run the trial
        subject.net.reinit(states=True)
        subject.run_trial(sim_params, task_input_rates)
        print(task_input_rates)

        # Record trial
        session_monitor.record_trial(t, task_input_rates, correct_input, subject.wta_network, subject.wta_monitor)

    # Write output
    if output_file is not None:
        session_monitor.write_output(output_file)

    # Plot
    if plot:
        session_monitor.plot()
        plt.show()