コード例 #1
0
  def test_while_single_carry(self):
    """Converts and compares a lax.while_loop whose loop carry is one scalar."""

    def func(start):
      # Increments the carry until it is no longer below 4; the C analogue is
      #   for (i = start; i < 4; i++);
      def keep_going(count):
        return count < 4

      def step(count):
        return count + 1

      return lax.while_loop(keep_going, step, start)

    initial = jnp.int_(0)
    self.ConvertAndCompare(func, initial)
コード例 #2
0
ファイル: spd.py プロジェクト: jschiavon/optispd
 def __init__(self, p, m=1, approx=True):
     """Manifold of (p x p) symmetric positive definite matrices.

     Args:
         p: size of each square matrix; must be an ``int`` or a
             ``jnp.integer`` instance.
         m: number of SPD factors; ``m == 1`` is a single SPD manifold,
             otherwise the name describes a product manifold. Must be an
             ``int`` or a ``jnp.integer`` instance.
         approx: stored unchanged as ``self._approximated``; how it is used
             is decided elsewhere in the class (not visible here).

     Raises:
         AssertionError: if ``p`` or ``m`` is not an integer.
     """
     # NOTE(review): these checks use ``assert`` and therefore disappear under
     # ``python -O``; kept as asserts so callers expecting AssertionError
     # keep working.
     assert isinstance(p, (int, jnp.integer)), "p must be an integer"
     assert isinstance(m, (int, jnp.integer)), "m must be an integer"
     self._p = p
     self._m = m
     if m == 1:
         name = "({0} x {0}) SPD ".format(p)
     else:
         name = "Product {1} ({0} x {0}) SPDs".format(p, m)
     # A symmetric p x p matrix has p * (p + 1) / 2 free entries per factor.
     # p * (p + 1) is always even, so floor division is exact and avoids the
     # int -> float -> int round-trip that true division introduced.
     self._dimension = m * jnp.int_(p * (p + 1) // 2)
     self._name = name
     self._approximated = approx
コード例 #3
0
ファイル: rnn_tasks_g.py プロジェクト: dashirn/FAST_RNN_JAX
def build_tasks(batch_size,
                input_params,
                task_select,
                rand_generator_idx,
                do_plot=False,
                do_save=False):
    """
    Task Builder.
    
    Builds tasks: Delay Pro, Delay Anti, Mem Pro, Mem Anti, MemDm1, MemDm2, ContextMemDm1, ContextMemDm2, MultiMem.
    
    Also builds task primitives for building modular RNN: 
        
        (1) mem_module: build memory modules, as in paper
        (2) build integration modules (experimental)
        
    In:
        - batch_size: size of the batch of task examples to generate
        - input_params: a set of input parameters that define task length, noise std dev, etc.
        - task_select: indicates which task to generate. Possible options are 
        delay_pro, delay_anti, mem_pro, mem_anti, mem_dm1, mem_dm2, 
        context_mem_dm1, context_mem_dm2, multi_mem, mem_module, integration_module
        - rand_generator_idx: seed for the random num generator, new for each batch (Jax)
        - do_plot: whether to plot some examples (default:False)
        - do_save: whether to save example plots (default:False)
    
    Out: 
        - Inputs: Time x Trials x Num Inputs
        - Targets: Time x Trials x Num Outputs
    """
    #Generate a particular new random key for a batch to get fresh batch
    batch_key = random.PRNGKey(rand_generator_idx)
    subkeys = []
    for k in range(15):
        batch_key, subkey = random.split(batch_key)
        subkeys.append(subkey)

    _, bias_val, stddev_val, T, ntime, save_name = input_params
    dt = T / float(ntime)
    stddev = stddev_val / np.sqrt(dt)
    zeros_beginning = 10
    ntrials = batch_size

    def create_angle_vecs(angles, rkey):
        angle_rand_idx = random.randint(rkey, (ntrials, ), 0, np.size(angles))
        cos_angles = np.array(
            [np.cos(np.radians(angles[i])) for i in angle_rand_idx])
        sin_angles = np.array(
            [np.sin(np.radians(angles[i])) for i in angle_rand_idx])
        cos_angles = np.expand_dims(cos_angles, axis=0)
        cos_angles = np.expand_dims(cos_angles, axis=2)
        sin_angles = np.expand_dims(sin_angles, axis=0)
        sin_angles = np.expand_dims(sin_angles, axis=2)
        return cos_angles, sin_angles

    angles = np.arange(0, 360, 10)
    cos_angles, sin_angles = create_angle_vecs(angles, subkeys[0])

    #Opposite angles. Use same random key to create matching angles
    angles_opposite = angles + 180
    cos_angles_opposite, sin_angles_opposite = create_angle_vecs(
        angles_opposite, subkeys[0])

    def draw_noise_input(ntime, zeros_beginning, ntrials, num_inputs, stddev,
                         rkey):
        noise_input = stddev * random.normal(rkey,
                                             (ntime, ntrials, num_inputs))
        noise_input_pluszero = np.concatenate(
            (np.zeros((zeros_beginning, ntrials, num_inputs)),
             noise_input))  #Allow context to establish
        return noise_input_pluszero

    ##### FIXATION INPUT
    input_f = np.concatenate((np.zeros((zeros_beginning, ntrials, 1)),
                              np.ones((np.int_(ntime * 3 / 4), ntrials, 1)),
                              np.zeros(
                                  (np.int_(ntime * (1 / 4)), ntrials, 1))),
                             axis=0)

    inputs_f_plusnoise = input_f + draw_noise_input(
        ntime, zeros_beginning, ntrials, 1, stddev, subkeys[1])

    ##### RULE INPUT
    input_r_zero = np.concatenate((np.zeros(
        (zeros_beginning, ntrials, 1)), np.zeros((ntime, ntrials, 1))),
                                  axis=0)
    input_r_zero_plusnoise = input_r_zero + draw_noise_input(
        ntime, zeros_beginning, ntrials, 1, stddev, subkeys[1])
    input_r_one = np.concatenate((np.zeros(
        (zeros_beginning, ntrials, 1)), np.ones((ntime, ntrials, 1))),
                                 axis=0)
    input_r_one_plusnoise = input_r_one + draw_noise_input(
        ntime, zeros_beginning, ntrials, 1, stddev, subkeys[1])
    inputs_r = (input_r_zero_plusnoise, input_r_one_plusnoise)

    if task_select == 'delay_pro' or task_select == 'delay_anti' or task_select == 'mem_pro' or task_select == 'mem_anti':

        ##### STIMULUS INPUT
        input_s_cos = np.concatenate(
            (np.zeros((zeros_beginning, ntrials, 1)),
             np.zeros((np.int_(ntime * 1 / 4), ntrials, 1)),
             np.multiply(np.repeat(cos_angles, np.int_(ntime * 3 / 4), axis=0),
                         np.ones((np.int_(ntime * 3 / 4), ntrials, 1)))),
            axis=0)
        input_s_cos_plusnoise = input_s_cos + draw_noise_input(
            ntime, zeros_beginning, ntrials, 1, stddev, subkeys[2])  #

        input_s_sin = np.concatenate(
            (np.zeros((zeros_beginning, ntrials, 1)),
             np.zeros((np.int_(ntime * 1 / 4), ntrials, 1)),
             np.multiply(np.repeat(sin_angles, np.int_(ntime * 3 / 4), axis=0),
                         np.ones((np.int_(ntime * 3 / 4), ntrials, 1)))),
            axis=0)
        input_s_sin_plusnoise = input_s_sin + draw_noise_input(
            ntime, zeros_beginning, ntrials, 1, stddev, subkeys[3])  #3

        inputs_s = (input_s_cos_plusnoise, input_s_sin_plusnoise)

        ##### MEMORY INPUT
        duration_val_stim = np.int_(ntime * 2 / 4)
        duration_val_delay = np.int_(ntime * 0)

        input_s_cos_mem = np.concatenate(
            (
                np.zeros((zeros_beginning, ntrials, 1)),
                np.zeros((np.int_(ntime * 1 / 4), ntrials, 1)),
                np.multiply(
                    np.repeat(cos_angles, duration_val_stim,
                              axis=0),  #np.int_(ntime*2/4)
                    np.ones((duration_val_stim, ntrials, 1))),
                np.zeros((duration_val_delay, ntrials, 1)),
                np.zeros((np.int_(ntime * 1 / 4), ntrials, 1))),
            axis=0)
        input_s_cos_mem_plusnoise = input_s_cos_mem + draw_noise_input(
            ntime, zeros_beginning, ntrials, 1, stddev, subkeys[4])  #2

        input_s_sin_mem = np.concatenate(
            (np.zeros((zeros_beginning, ntrials, 1)),
             np.zeros((np.int_(ntime * 1 / 4), ntrials, 1)),
             np.multiply(np.repeat(sin_angles, duration_val_stim, axis=0),
                         np.ones((duration_val_stim, ntrials, 1))),
             np.zeros((duration_val_delay, ntrials, 1)),
             np.zeros((np.int_(ntime * 1 / 4), ntrials, 1))),
            axis=0)
        input_s_sin_mem_plusnoise = input_s_sin_mem + draw_noise_input(
            ntime, zeros_beginning, ntrials, 1, stddev, subkeys[5])

        inputs_s_mem = (input_s_cos_mem_plusnoise, input_s_sin_mem_plusnoise)

        ###### GENERAL TASK TARGETS
        target_cos = np.concatenate(
            (np.zeros((zeros_beginning, ntrials, 1)),
             np.zeros((np.int_(3 / 4 * ntime), ntrials, 1)),
             np.multiply(np.repeat(cos_angles, np.int_(1 / 4 * ntime), axis=0),
                         np.ones((np.int_(1 / 4 * ntime), ntrials, 1)))),
            axis=0)
        target_sin = np.concatenate(
            (np.zeros((zeros_beginning, ntrials, 1)),
             np.zeros((np.int_(3 / 4 * ntime), ntrials, 1)),
             np.multiply(np.repeat(sin_angles, np.int_(1 / 4 * ntime), axis=0),
                         np.ones((np.int_(1 / 4 * ntime), ntrials, 1)))),
            axis=0)
        targets_s = (target_cos, target_sin)

        #Opposite direction
        target_cos_opposite = np.concatenate(
            (np.zeros((zeros_beginning, ntrials, 1)),
             np.zeros((np.int_(3 / 4 * ntime), ntrials, 1)),
             np.multiply(
                 np.repeat(cos_angles_opposite, np.int_(1 / 4 * ntime),
                           axis=0),
                 np.ones((np.int_(1 / 4 * ntime), ntrials, 1)))),
            axis=0)
        target_sin_opposite = np.concatenate(
            (np.zeros((zeros_beginning, ntrials, 1)),
             np.zeros((np.int_(3 / 4 * ntime), ntrials, 1)),
             np.multiply(
                 np.repeat(sin_angles_opposite, np.int_(1 / 4 * ntime),
                           axis=0),
                 np.ones((np.int_(1 / 4 * ntime), ntrials, 1)))),
            axis=0)
        targets_opposite = (target_cos_opposite, target_sin_opposite)

    elif task_select == 'mem_module':

        #Using plus and minus modules
        cos_angles_positive, _ = create_angle_vecs(np.arange(0, 90, 5),
                                                   subkeys[0])
        cos_angles_negative, _ = create_angle_vecs(np.arange(90, 180, 5),
                                                   subkeys[0])

        #Draw a particular delay duration for current batch
        # duration_val_stim = np.int_(10)#np.int_(ntime*random.uniform(subkeys[1],(1,),minval=1/8,maxval=4/8)[0])
        # duration_val_delay = ntime - duration_val_stim - 2*np.int_(ntime*1/4)
        duration_val_stim = np.int_(5)
        duration_val_delay = np.int_(45)

        ##### MEMORY INPUT
        input_s_cos_mem_short = np.concatenate(
            (np.zeros((zeros_beginning, ntrials, 1)),
             np.zeros((np.int_(ntime * 1 / 4), ntrials, 1)),
             np.multiply(
                 np.repeat(cos_angles_positive, duration_val_stim, axis=0),
                 np.ones((duration_val_stim, ntrials, 1))),
             np.zeros((duration_val_delay, ntrials, 1)),
             np.zeros((np.int_(ntime * 1 / 4), ntrials, 1))),
            axis=0)
        input_s_cos_mem_short_plusnoise = input_s_cos_mem_short + draw_noise_input(
            ntime, zeros_beginning, ntrials, 1, stddev, subkeys[6])  #6

        input_s_sin_mem_short = np.concatenate(
            (np.zeros((zeros_beginning, ntrials, 1)),
             np.zeros((np.int_(ntime * 1 / 4), ntrials, 1)),
             np.multiply(
                 np.repeat(cos_angles_negative, duration_val_stim, axis=0),
                 np.ones((duration_val_stim, ntrials, 1))),
             np.zeros((duration_val_delay, ntrials, 1)),
             np.zeros((np.int_(ntime * 1 / 4), ntrials, 1))),
            axis=0)
        input_s_sin_mem_short_plusnoise = input_s_sin_mem_short + draw_noise_input(
            ntime, zeros_beginning, ntrials, 1, stddev, subkeys[7])  #7

        inputs_s_mem_short = (input_s_cos_mem_short_plusnoise,
                              input_s_sin_mem_short_plusnoise)

        ###### GENERAL TASK TARGETS
        target_hold_time = duration_val_delay + np.int_(ntime * 1 / 4)
        target_s_cos_mem_long = np.concatenate(
            (np.zeros((zeros_beginning, ntrials, 1)),
             np.zeros((np.int_(ntime * 1 / 4), ntrials, 1)),
             np.zeros((duration_val_stim, ntrials, 1)),
             np.multiply(
                 np.repeat(cos_angles_positive, target_hold_time, axis=0),
                 np.ones((target_hold_time, ntrials, 1)))),
            axis=0)
        target_s_sin_mem_long = np.concatenate(
            (np.zeros((zeros_beginning, ntrials, 1)),
             np.zeros((np.int_(ntime * 1 / 4), ntrials, 1)),
             np.zeros((duration_val_stim, ntrials, 1)),
             np.multiply(
                 np.repeat(cos_angles_negative, target_hold_time, axis=0),
                 np.ones((target_hold_time, ntrials, 1)))),
            axis=0)
        targets_s_mem_long = (target_s_cos_mem_long, target_s_sin_mem_long)

    elif task_select == 'integration_module':
        nwninputs = 1
        # biases_1xexw = np.expand_dims(bias_val * 2.0 * (npo.random.rand(ntrials,nwninputs) -0.5), axis=0)
        biases_1xexw_pos = np.expand_dims(
            bias_val * 2.0 * (npo.random.rand(ntrials, nwninputs) + 1), axis=0)
        biases_1xexw_neg = np.expand_dims(
            bias_val * 2.0 * (npo.random.rand(ntrials, nwninputs) - 1), axis=0)
        stddev = stddev_val / np.sqrt(dt)
        noise_txexw = stddev * npo.random.randn(ntime, ntrials, nwninputs)
        white_noise_txexw_pos = biases_1xexw_pos + noise_txexw
        white_noise_txexw_neg = biases_1xexw_neg + noise_txexw

        # white_noise_txexw =  np.concatenate((np.zeros((zeros_beginning,ntrials,nwninputs)),
        #                                     white_noise_txexw),axis=0)
        white_noise_txexw_pos = np.concatenate((np.zeros(
            (zeros_beginning, ntrials, nwninputs)), white_noise_txexw_pos),
                                               axis=0)
        white_noise_txexw_neg = np.concatenate((np.zeros(
            (zeros_beginning, ntrials, nwninputs)), white_noise_txexw_neg),
                                               axis=0)

        # Create the desired outputs
        inputs_int = (white_noise_txexw_pos, white_noise_txexw_neg)
        targets_int = (np.cumsum(white_noise_txexw_pos, axis=0),
                       np.cumsum(white_noise_txexw_neg, axis=0))

    elif task_select == 'mem_dm1' or task_select == 'mem_dm2' or task_select == 'context_mem_dm1' or task_select == 'context_mem_dm2' or task_select == 'multi_mem':

        cos_angles2, sin_angles2 = create_angle_vecs(angles, subkeys[10])
        target_vals_cos = np.where(cos_angles2 >= cos_angles, cos_angles2,
                                   cos_angles)
        target_vals_sin = np.where(sin_angles2 >= sin_angles, sin_angles2,
                                   sin_angles)

        #Draw a particular delay duration for current batch
        duration_val_stim = np.int_(10)  #
        duration_val_interstim = 2 * duration_val_stim
        duration_val_delay = ntime - 2 * duration_val_stim - duration_val_interstim - 2 * np.int_(
            ntime * 1 / 4)

        ##### MEMORY INPUT
        input_s_cos_mem_short = np.concatenate(
            (np.zeros((zeros_beginning, ntrials, 1)),
             np.zeros((np.int_(ntime * 1 / 4), ntrials, 1)),
             np.multiply(np.repeat(cos_angles, duration_val_stim, axis=0),
                         np.ones((duration_val_stim, ntrials, 1))),
             np.zeros((duration_val_interstim, ntrials, 1)),
             np.multiply(np.repeat(cos_angles2, duration_val_stim, axis=0),
                         np.ones((duration_val_stim, ntrials, 1))),
             np.zeros((duration_val_delay, ntrials, 1)),
             np.zeros((np.int_(ntime * 1 / 4), ntrials, 1))),
            axis=0)
        input_s_cos_mem_short_plusnoise = input_s_cos_mem_short + draw_noise_input(
            ntime, zeros_beginning, ntrials, 1, stddev, subkeys[8])  #8

        input_s_sin_mem_short = np.concatenate(
            (np.zeros((zeros_beginning, ntrials, 1)),
             np.zeros((np.int_(ntime * 1 / 4), ntrials, 1)),
             np.multiply(np.repeat(sin_angles, duration_val_stim, axis=0),
                         np.ones((duration_val_stim, ntrials, 1))),
             np.zeros((duration_val_interstim, ntrials, 1)),
             np.multiply(np.repeat(sin_angles2, duration_val_stim, axis=0),
                         np.ones((duration_val_stim, ntrials, 1))),
             np.zeros((duration_val_delay, ntrials, 1)),
             np.zeros((np.int_(ntime * 1 / 4), ntrials, 1))),
            axis=0)
        input_s_sin_mem_short_plusnoise = input_s_sin_mem_short + draw_noise_input(
            ntime, zeros_beginning, ntrials, 1, stddev, subkeys[9])  #9

        inputs_s_mem_dm = (input_s_cos_mem_short_plusnoise,
                           input_s_sin_mem_short_plusnoise)

        ###### GENERAL TASK TARGETS
        target_zeros = ntime - 2 * np.int_(ntime * 1 / 4)
        target_s_cos_mem_long = np.concatenate(
            (np.zeros((zeros_beginning, ntrials, 1)),
             np.zeros((np.int_(ntime * 1 / 4), ntrials, 1)),
             np.zeros((target_zeros, ntrials, 1)),
             np.multiply(
                 np.repeat(target_vals_cos, np.int_(ntime * 1 / 4), axis=0),
                 np.ones((np.int_(ntime * 1 / 4), ntrials, 1)))),
            axis=0)
        target_s_sin_mem_long = np.concatenate(
            (np.zeros((zeros_beginning, ntrials, 1)),
             np.zeros((np.int_(ntime * 1 / 4), ntrials, 1)),
             np.zeros((target_zeros, ntrials, 1)),
             np.multiply(
                 np.repeat(target_vals_sin, np.int_(ntime * 1 / 4), axis=0),
                 np.ones((np.int_(ntime * 1 / 4), ntrials, 1)))),
            axis=0)

        target_vals_multi_mem = np.where(
            cos_angles2 + sin_angles2 >= cos_angles + sin_angles,
            cos_angles2 + sin_angles2, cos_angles + sin_angles)
        targets_s_multi_mem = np.concatenate(
            (np.zeros((zeros_beginning, ntrials, 1)),
             np.zeros((np.int_(ntime * 1 / 4), ntrials, 1)),
             np.zeros((target_zeros, ntrials, 1)),
             np.multiply(
                 np.repeat(
                     target_vals_multi_mem, np.int_(ntime * 1 / 4), axis=0),
                 np.ones((np.int_(ntime * 1 / 4), ntrials, 1)))),
            axis=0)
        targets_s_mem_dm = (target_s_cos_mem_long, target_s_sin_mem_long,
                            targets_s_multi_mem)

    ###### Generate input/target variables
    if task_select == 'mem_module':  #Holds input stim during target period only
        inputs = np.concatenate(
            (inputs_f_plusnoise, inputs_s_mem_short[0], inputs_s_mem_short[1]),
            axis=2)
        targets = np.concatenate(
            (input_f, targets_s_mem_long[0], targets_s_mem_long[1]), axis=2)
    elif task_select == 'integration_module':
        inputs = np.concatenate(
            (inputs_f_plusnoise, inputs_int[0], inputs_int[1]), axis=2)
        targets = np.concatenate((input_f, targets_int[0], targets_int[1]),
                                 axis=2)

    elif task_select == 'delay_pro':
        inputs = np.concatenate(
            (inputs_f_plusnoise, inputs_s[0], inputs_s[1], inputs_r[1],
             inputs_r[0], inputs_r[0], inputs_r[0]),
            axis=2)
        targets = np.concatenate((input_f, targets_s[0], targets_s[1]), axis=2)
    elif task_select == 'delay_anti':
        inputs = np.concatenate(
            (inputs_f_plusnoise, inputs_s[0], inputs_s[1], inputs_r[0],
             inputs_r[1], inputs_r[0], inputs_r[0]),
            axis=2)
        targets = np.concatenate(
            (input_f, targets_opposite[0], targets_opposite[1]), axis=2)
    elif task_select == 'mem_pro':
        inputs = np.concatenate(
            (inputs_f_plusnoise, inputs_s_mem[0], inputs_s_mem[1], inputs_r[0],
             inputs_r[0], inputs_r[1], inputs_r[0]),
            axis=2)
        targets = np.concatenate((input_f, targets_s[0], targets_s[1]), axis=2)
    elif task_select == 'mem_anti':
        inputs = np.concatenate(
            (inputs_f_plusnoise, inputs_s_mem[0], inputs_s_mem[1], inputs_r[0],
             inputs_r[0], inputs_r[0], inputs_r[1]),
            axis=2)
        targets = np.concatenate(
            (input_f, targets_opposite[0], targets_opposite[1]), axis=2)
    elif task_select == 'mem_dm1':
        inputs = np.concatenate(
            (inputs_f_plusnoise, inputs_s_mem_dm[0], inputs_r[0], inputs_r[1],
             inputs_r[1], inputs_r[1]),
            axis=2)
        targets = np.concatenate((input_f, targets_s_mem_dm[0]), axis=2)
    elif task_select == 'mem_dm2':
        inputs = np.concatenate(
            (inputs_f_plusnoise, inputs_s_mem_dm[1], inputs_r[1], inputs_r[0],
             inputs_r[1], inputs_r[1]),
            axis=2)
        targets = np.concatenate((input_f, targets_s_mem_dm[1]), axis=2)
    elif task_select == 'context_mem_dm1':
        inputs = np.concatenate(
            (inputs_f_plusnoise, inputs_s_mem_dm[0], inputs_s_mem_dm[1],
             inputs_r[1], inputs_r[1], inputs_r[0], inputs_r[1]),
            axis=2)
        targets = np.concatenate((input_f, targets_s_mem_dm[0]), axis=2)
    elif task_select == 'context_mem_dm2':
        inputs = np.concatenate(
            (inputs_f_plusnoise, inputs_s_mem_dm[0], inputs_s_mem_dm[1],
             inputs_r[1], inputs_r[1], inputs_r[1], inputs_r[0]),
            axis=2)
        targets = np.concatenate((input_f, targets_s_mem_dm[1]), axis=2)
    elif task_select == 'multi_mem':
        inputs = np.concatenate(
            (inputs_f_plusnoise, inputs_s_mem_dm[0], inputs_s_mem_dm[1],
             inputs_r[1], inputs_r[1], inputs_r[1], inputs_r[1]),
            axis=2)
        targets = np.concatenate((input_f, targets_s_mem_dm[2]), axis=2)

    inputs = np.array(inputs)
    targets = np.array(targets)

    #Shuffle elements
    shuffle_idx = random.randint(subkeys[10], (ntrials, ), 0, ntrials)
    inputs = inputs[:, shuffle_idx, :]
    targets = targets[:, shuffle_idx, :]

    #Plot some examples
    if do_plot:

        plot_linewidth = 8
        alpha_val = 1
        trial_1 = 0
        trial_2 = 25
        trial_3 = 50
        trial_4 = 75
        fontsize = 'xx-large'
        fontfam = 'sans-serif'
        time = np.linspace(0, ntime + zeros_beginning, ntime + zeros_beginning)

        if task_select == 'delay_pro' or task_select == 'delay_anti' or task_select == 'mem_pro' or task_select == 'mem_anti':
            #Check dims
            # print('Shape Trial Mat:',np.shape(task_trials_batch))
            # print('Shape Inputs:',np.shape(inputs))
            # print('Shape Targets:',np.shape(targets))

            ## PLOT GENERAL INPUTS/OUTPUTS
            plt.figure(figsize=(10, 7))
            plt.subplot(111)
            plt.box(on=None)
            plt.title('FIXATION INPUT')
            plt.plot(time,
                     inputs_f_plusnoise[:, trial_1, 0],
                     'royalblue',
                     linewidth=plot_linewidth,
                     alpha=alpha_val)
            #plt.xlim([0,T]);#plt.ylim([-1.2,1.2])
            plt.xlabel('Time', fontfamily=fontfam, fontsize=fontsize)
            plt.ylabel('Amplitude', fontfamily=fontfam, fontsize=fontsize)
            plt.legend(['Fixation'], frameon=False, fontsize=fontsize)

            plt.figure(figsize=(10, 7))
            plt.subplot(111)
            plt.box(on=None)
            plt.title('STIMULUS SIMPLE')
            plt.plot(time,
                     input_s_cos_plusnoise[:, trial_1, 0],
                     'darkviolet',
                     linewidth=plot_linewidth,
                     alpha=alpha_val)
            plt.plot(time,
                     input_s_sin_plusnoise[:, trial_1, 0],
                     'fuchsia',
                     linewidth=plot_linewidth,
                     alpha=alpha_val)
            #plt.xlim([0,T]);plt.ylim([-1.2,1.2])
            plt.xlabel('Time', fontfamily=fontfam, fontsize=fontsize)
            plt.ylabel('Amplitude', fontfamily=fontfam, fontsize=fontsize)
            plt.legend(['Input1', 'Input2'], frameon=False, fontsize=fontsize)

            plt.figure(figsize=(10, 7))
            plt.subplot(111)
            plt.box(on=None)
            plt.title('STIMULUS MEMORY')
            plt.plot(time,
                     input_s_cos_mem_plusnoise[:, trial_1, 0],
                     'darkviolet',
                     linewidth=plot_linewidth,
                     alpha=alpha_val)
            plt.plot(time,
                     input_s_sin_mem_plusnoise[:, trial_1, 1],
                     'fuchsia',
                     linewidth=plot_linewidth,
                     alpha=alpha_val)
            #plt.xlim([0,T]);plt.ylim([-1.2,1.2])
            plt.xlabel('Time', fontfamily=fontfam, fontsize=fontsize)
            plt.ylabel('Amplitude', fontfamily=fontfam, fontsize=fontsize)
            plt.legend(['Cos_Mem', 'Sin_Mem'],
                       frameon=False,
                       fontsize=fontsize)

            plt.figure(figsize=(10, 7))
            plt.subplot(111)
            plt.box(on=None)
            plt.title('RULE INPUT')
            plt.plot(time,
                     input_r_one_plusnoise[:, trial_1, 0],
                     'black',
                     linewidth=plot_linewidth,
                     alpha=alpha_val)
            plt.plot(time,
                     input_r_zero_plusnoise[:, trial_1, 0],
                     'grey',
                     linewidth=plot_linewidth,
                     alpha=alpha_val)
            #plt.xlim([0,T]);plt.ylim([-1.2,1.2])
            plt.xlabel('Time', fontfamily=fontfam, fontsize=fontsize)
            plt.ylabel('Amplitude', fontfamily=fontfam, fontsize=fontsize)
            plt.legend(['Rule 1', 'Rule 0'], frameon=False, fontsize=fontsize)

            plt.figure(figsize=(10, 7))
            plt.subplot(111)
            plt.box(on=None)
            plt.title('TARGET')
            plt.plot(time,
                     target_cos[:, trial_1, 0],
                     'red',
                     linewidth=plot_linewidth,
                     alpha=alpha_val)
            plt.plot(time,
                     target_sin[:, trial_1, 0],
                     'tomato',
                     linewidth=plot_linewidth,
                     alpha=alpha_val)
            #plt.xlim([0,T]);plt.ylim([-1.2,1.2])
            plt.xlabel('Time', fontfamily=fontfam, fontsize=fontsize)
            plt.ylabel('Amplitude', fontfamily=fontfam, fontsize=fontsize)
            plt.legend(['Target Cos', 'Target Sin'],
                       frameon=False,
                       fontsize=fontsize)

            #SAVE FIGS
            if do_save: save_figs_multiformats(save_name, '%s' % (task_select))

            plt.show()

            ####PLOT PARTICULAR TASK for one trial
            plt.figure(figsize=(10, 7))
            plt.subplot(111)
            plt.box(on=None)
            plt.plot(time,
                     inputs[:, trial_1, 0],
                     'cornflowerblue',
                     linewidth=plot_linewidth,
                     alpha=alpha_val)
            plt.xlabel('Time (ms)', fontfamily=fontfam, fontsize=fontsize)
            plt.ylabel('Amplitude', fontfamily=fontfam, fontsize=fontsize)
            plt.legend(['Fixation'], frameon=False, fontsize=fontsize)

            plt.figure(figsize=(10, 7))
            plt.subplot(111)
            plt.box(on=None)
            plt.plot(time,
                     inputs[:, trial_1, 1],
                     'darkviolet',
                     linewidth=plot_linewidth,
                     alpha=alpha_val)
            plt.plot(time,
                     inputs[:, trial_1, 2],
                     'fuchsia',
                     linewidth=plot_linewidth,
                     alpha=alpha_val)
            #plt.xlim([0,T]);#plt.ylim([-1.2,1.2])
            plt.xlabel('Time (ms)', fontfamily=fontfam, fontsize=fontsize)
            plt.ylabel('Amplitude', fontfamily=fontfam, fontsize=fontsize)
            plt.legend(['Input 1', 'Input 2'],
                       frameon=False,
                       fontsize=fontsize)

            plt.figure(figsize=(10, 7))
            plt.subplot(111)
            plt.box(on=None)
            plt.plot(time,
                     inputs[:, trial_1, 3],
                     'black',
                     linewidth=plot_linewidth,
                     alpha=alpha_val)
            plt.plot(time,
                     inputs[:, trial_1, 4],
                     'grey',
                     linewidth=plot_linewidth,
                     alpha=alpha_val)
            plt.plot(time,
                     inputs[:, trial_1, 5],
                     'teal',
                     linewidth=plot_linewidth,
                     alpha=alpha_val)
            plt.plot(time,
                     inputs[:, trial_1, 6],
                     'darkturquoise',
                     linewidth=plot_linewidth,
                     alpha=alpha_val)
            #plt.xlim([0,T]);#plt.ylim([-1.2,1.2])
            plt.xlabel('Time (ms)', fontfamily=fontfam, fontsize=fontsize)
            plt.ylabel('Amplitude', fontfamily=fontfam, fontsize=fontsize)
            plt.legend(['Rule 1', 'Rule 2', 'Rule 3', 'Rule 4'],
                       frameon=False,
                       fontsize=fontsize)

            plt.figure(figsize=(10, 7))
            plt.subplot(111)
            plt.box(on=None)
            plt.plot(time,
                     targets[:, trial_1, 0],
                     'cornflowerblue',
                     linewidth=plot_linewidth,
                     alpha=alpha_val)
            plt.xlabel('Time (ms)', fontfamily=fontfam, fontsize=fontsize)
            plt.ylabel('Amplitude', fontfamily=fontfam, fontsize=fontsize)
            plt.legend(['Target 1'], frameon=False, fontsize=fontsize)

            plt.figure(figsize=(10, 7))
            plt.subplot(111)
            plt.box(on=None)
            plt.plot(time,
                     targets[:, trial_1, 1],
                     'maroon',
                     linewidth=plot_linewidth,
                     alpha=alpha_val)
            plt.xlabel('Time (ms)', fontfamily=fontfam, fontsize=fontsize)
            plt.ylabel('Amplitude', fontfamily=fontfam, fontsize=fontsize)
            plt.legend(['Target 2'], frameon=False, fontsize=fontsize)

            plt.figure(figsize=(10, 7))
            plt.subplot(111)
            plt.box(on=None)
            plt.plot(time,
                     targets[:, trial_1, 2],
                     'crimson',
                     linewidth=plot_linewidth,
                     alpha=alpha_val)
            plt.xlabel('Time (ms)', fontfamily=fontfam, fontsize=fontsize)
            plt.ylabel('Amplitude', fontfamily=fontfam, fontsize=fontsize)
            plt.legend(['Target 3'], frameon=False, fontsize=fontsize)

        elif task_select == 'mem_module':

            plt.figure(figsize=(10, 7))
            plt.subplot(111)
            plt.box(on=None)
            plt.plot(time,
                     inputs[:, trial_1, 0],
                     'cornflowerblue',
                     linewidth=plot_linewidth,
                     alpha=alpha_val)
            plt.xlabel('Time (ms)', fontfamily=fontfam, fontsize=fontsize)
            plt.ylabel('Amplitude', fontfamily=fontfam, fontsize=fontsize)
            plt.legend(['Input 1'], frameon=False, fontsize=fontsize)

            plt.figure(figsize=(10, 7))
            plt.subplot(111)
            plt.box(on=None)
            plt.plot(time,
                     inputs[:, trial_1, 1],
                     'darkviolet',
                     linewidth=plot_linewidth,
                     alpha=alpha_val)
            plt.xlabel('Time (ms)', fontfamily=fontfam, fontsize=fontsize)
            plt.ylabel('Amplitude', fontfamily=fontfam, fontsize=fontsize)
            plt.legend(['Input 2'], frameon=False, fontsize=fontsize)

            plt.figure(figsize=(10, 7))
            plt.subplot(111)
            plt.box(on=None)
            plt.plot(time,
                     inputs[:, trial_1, 2],
                     'fuchsia',
                     linewidth=plot_linewidth,
                     alpha=alpha_val)
            plt.xlabel('Time (ms)', fontfamily=fontfam, fontsize=fontsize)
            plt.ylabel('Amplitude', fontfamily=fontfam, fontsize=fontsize)
            plt.legend(['Input 3'], frameon=False, fontsize=fontsize)

            plt.figure(figsize=(10, 7))
            plt.subplot(111)
            plt.box(on=None)
            plt.plot(time,
                     targets[:, trial_1, 0],
                     'cornflowerblue',
                     linewidth=plot_linewidth,
                     alpha=alpha_val)
            plt.xlabel('Time (ms)', fontfamily=fontfam, fontsize=fontsize)
            plt.ylabel('Amplitude', fontfamily=fontfam, fontsize=fontsize)
            plt.legend(['Target 1'], frameon=False, fontsize=fontsize)

            plt.figure(figsize=(10, 7))
            plt.subplot(111)
            plt.box(on=None)
            plt.plot(time,
                     targets[:, trial_1, 1],
                     'darkviolet',
                     linewidth=plot_linewidth,
                     alpha=alpha_val)
            plt.xlabel('Time (ms)', fontfamily=fontfam, fontsize=fontsize)
            plt.ylabel('Amplitude', fontfamily=fontfam, fontsize=fontsize)
            plt.legend(['Target 2'], frameon=False, fontsize=fontsize)

            plt.figure(figsize=(10, 7))
            plt.subplot(111)
            plt.box(on=None)
            plt.plot(time,
                     targets[:, trial_1, 2],
                     'fuchsia',
                     linewidth=plot_linewidth,
                     alpha=alpha_val)
            plt.xlabel('Time (ms)', fontfamily=fontfam, fontsize=fontsize)
            plt.ylabel('Amplitude', fontfamily=fontfam, fontsize=fontsize)
            plt.legend(['Target 3'], frameon=False, fontsize=fontsize)

        elif task_select == 'mem_dm1' or task_select == 'mem_dm2':

            plt.figure(figsize=(20, 10))

            plt.subplot(121)
            plt.plot(time,
                     inputs[:, trial_1, 0],
                     'b',
                     linewidth=plot_linewidth,
                     alpha=alpha_val)
            plt.plot(time,
                     targets[:, trial_1, 0],
                     'r',
                     linewidth=plot_linewidth,
                     alpha=alpha_val)
            #plt.xlim([0,T]);#plt.ylim([-1.2,1.2])
            plt.xlabel('Time')
            plt.ylabel('Amplitude')
            plt.legend(['Input Fixation', 'Target Fixation'], frameon=False)

            plt.subplot(122)
            plt.plot(time,
                     inputs[:, trial_1, 1],
                     'cornflowerblue',
                     linewidth=plot_linewidth,
                     alpha=alpha_val)
            plt.plot(time,
                     targets[:, trial_1, 1],
                     'r',
                     linewidth=plot_linewidth,
                     alpha=alpha_val)
            #plt.xlim([0,T]);#plt.ylim([-1.2,1.2])
            plt.xlabel('Time')
            plt.ylabel('Amplitude')
            plt.legend(['Input Stim', 'Target Stim'], frameon=False)

        elif task_select == 'context_mem_dm1' or task_select == 'context_mem_dm2' or task_select == 'multi_mem':

            plt.figure(figsize=(20, 10))

            plt.subplot(131)
            plt.plot(time,
                     inputs[:, trial_1, 0],
                     'b',
                     linewidth=plot_linewidth,
                     alpha=alpha_val)
            plt.plot(time,
                     targets[:, trial_1, 0],
                     'r',
                     linewidth=plot_linewidth,
                     alpha=alpha_val)
            #plt.xlim([0,T]);#plt.ylim([-1.2,1.2])
            plt.xlabel('Time')
            plt.ylabel('Amplitude')
            plt.legend(['Input Fixation', 'Target Fixation'], frameon=False)

            plt.subplot(132)
            plt.plot(time,
                     inputs[:, trial_1, 1],
                     'cornflowerblue',
                     linewidth=plot_linewidth,
                     alpha=alpha_val)
            plt.plot(time,
                     inputs[:, trial_1, 2],
                     'mediumslateblue',
                     linewidth=plot_linewidth,
                     alpha=alpha_val)
            plt.plot(time,
                     targets[:, trial_1, 1],
                     'r',
                     linewidth=plot_linewidth,
                     alpha=alpha_val)
            #plt.xlim([0,T]);#plt.ylim([-1.2,1.2])
            plt.xlabel('Time')
            plt.ylabel('Amplitude')
            plt.legend(['Input Stim1', 'Input Stim 2', 'Target Stim'],
                       frameon=False)

            plt.subplot(133)
            plt.plot(time,
                     inputs[:, trial_1, 2],
                     'cornflowerblue',
                     linewidth=plot_linewidth,
                     alpha=alpha_val)
            plt.plot(time,
                     targets[:, trial_1, 2],
                     'mediumslateblue',
                     linewidth=plot_linewidth,
                     alpha=alpha_val)
            #plt.xlim([0,T]);#plt.ylim([-1.2,1.2])
            plt.xlabel('Time')
            plt.ylabel('Amplitude')
            plt.legend(['Input Stim2', 'Target Stim2'], frameon=False)

            plt.show()

    return inputs, targets