Code Example #1
def easy_activity_plot(model_dir, rule):
    """A simple plot of neural activity from one task.

    Args:
        model_dir: directory where model file is saved
        rule: string, the rule to plot
    """

    model = Model(model_dir)
    hp = model.hp

    with tf.Session() as sess:
        model.restore()

        trial = generate_trials(rule, hp, mode='test')
        feed_dict = tools.gen_feed_dict(model, trial, hp)
        h, y_hat = sess.run([model.h, model.y_hat], feed_dict=feed_dict)
        # All matrices have shape (n_time, n_condition, n_neuron)

    # Plot only one example trial
    i_trial = 0

    for activity, title in zip([trial.x, h, y_hat],
                               ['input', 'recurrent', 'output']):
        plt.figure()
        plt.imshow(activity[:, i_trial, :].T,
                   aspect='auto',
                   cmap='hot',
                   interpolation='none',
                   origin='lower')
        plt.title(title)
        plt.colorbar()
        plt.show()
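A minimal usage sketch (the model directory below is hypothetical; 'contextdm1' is one of the task names used throughout these examples):

# Hypothetical model directory; any trained model directory works here
easy_activity_plot('data/train_all', 'contextdm1')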
Code Example #2
def quick_statespace(model_dir, save=False):
    """Quick state space analysis using simple PCA."""
    rules = ['contextdm1', 'contextdm2']
    h_lastts = dict()
    model = Model(model_dir)
    hp = model.hp
    with tf.Session() as sess:
        model.restore()
        for rule in rules:
            # Generate a batch of trials in test mode
            trial = generate_trials(rule, hp, mode='test')
            feed_dict = tools.gen_feed_dict(model, trial, hp)
            h = sess.run(model.h, feed_dict=feed_dict)
            lastt = trial.epochs['stim1'][-1]
            h_lastts[rule] = h[lastt, :, :]

    from sklearn.decomposition import PCA
    pca = PCA(n_components=5)
    pca.fit(np.concatenate(list(h_lastts.values()), axis=0))
    fig = plt.figure(figsize=(2,2))
    ax = fig.add_axes([.3, .3, .6, .6])
    for rule, color in zip(rules, ['red', 'blue']):
        data_trans = pca.transform(h_lastts[rule])
        ax.scatter(data_trans[:, 0], data_trans[:, 1], s=1,
                   label=rule_name[rule], color=color)
    plt.tick_params(axis='both', which='major', labelsize=7)
    ax.set_xlabel('PC 1', fontsize=7)
    ax.set_ylabel('PC 2', fontsize=7)
    lg = ax.legend(fontsize=7, ncol=1, bbox_to_anchor=(1,0.3),
                   loc=1, frameon=False)
    if save:
        plt.savefig('figure/choiceatt_quickstatespace.pdf', transparent=True)
Code Example #3
def run_network_replacerule(model_dir, rule, replace_rule, rule_strength):
    """Run the network but with replaced rule input weights.

    Args:
        model_dir: model directory
        rule: the rule to test on
        replace_rule: a list of rule input units to use
        rule_strength: the relative strength of each replace rule unit
    """
    model = Model(model_dir)
    hp = model.hp
    with tf.Session() as sess:
        model.restore()

        # Get performance
        batch_size_test = 1000
        n_rep = 20
        batch_size_test_rep = int(batch_size_test / n_rep)
        perf_rep = list()
        for i_rep in range(n_rep):
            trial = generate_trials(rule,
                                    hp,
                                    'random',
                                    batch_size=batch_size_test_rep,
                                    replace_rule=replace_rule,
                                    rule_strength=rule_strength)
            feed_dict = tools.gen_feed_dict(model, trial, hp)
            y_hat_test = sess.run(model.y_hat, feed_dict=feed_dict)

            perf_rep.append(np.mean(get_perf(y_hat_test, trial.y_loc)))

    return np.mean(perf_rep), rule_strength
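A hedged usage sketch: testing one task while feeding a weighted mixture of two rule inputs (the directory, rules, and strengths below are hypothetical):

# Hypothetical arguments: test 'contextdm1' while providing both the
# 'contextdm1' and 'contextdm2' rule inputs at the given relative strengths
perf, strength = run_network_replacerule('data/train_all', 'contextdm1',
                                         ['contextdm1', 'contextdm2'],
                                         [0.5, 0.5])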
Code Example #4
    def _compute_H(self, model, rule, trial, sess):
        """Compute hidden-unit activity for one rule and cache it to disk."""

        feed_dict = tools.gen_feed_dict(model, trial, self.hp)
        h = sess.run(model.h, feed_dict=feed_dict)

        fname = os.path.join(model.model_dir, 'H_'+rule+'.pkl')
        with open(fname, 'wb') as f:
            pickle.dump(h, f)
Code Example #5
    def __init__(self, model_dir, rules=None):
        """Initialization.

        Args:
            model_dir: str, model directory
            rules: None or a list of rules
        """
        # Stimulus-averaged traces
        h_stimavg_byrule = OrderedDict()
        h_stimavg_byepoch = OrderedDict()
        # Last time points of epochs
        h_lastt_byepoch = OrderedDict()

        model = Model(model_dir)
        hp = model.hp

        if rules is None:
            # Default value
            rules = hp['rules']
        n_rules = len(rules)

        with tf.Session() as sess:
            model.restore()

            for rule in rules:
                trial = generate_trials(rule=rule, hp=hp, mode='test')
                feed_dict = tools.gen_feed_dict(model, trial, hp)
                h = sess.run(model.h, feed_dict=feed_dict)

                # Average across stimulus conditions
                h_stimavg = h.mean(axis=1)


                t_start = int(
                    500 / hp['dt'])  # Important: Ignore the initial transition
                # Average across stimulus conditions
                h_stimavg_byrule[rule] = h_stimavg[t_start:, :]

                for e_name, e_time in trial.epochs.items():
                    if 'fix' in e_name:
                        continue

                    # Take epoch
                    e_time_start = e_time[0] - 1 if e_time[0] > 0 else 0
                    h_stimavg_byepoch[(
                        rule, e_name)] = h_stimavg[e_time_start:e_time[1], :]
                    # Take last time point from epoch
                    h_lastt_byepoch[(rule, e_name)] = h[e_time[1], :, :]

        self.rules = rules
        self.h_stimavg_byrule = h_stimavg_byrule
        self.h_stimavg_byepoch = h_stimavg_byepoch
        self.h_lastt_byepoch = h_lastt_byepoch
        self.model_dir = model_dir
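The shape bookkeeping above is easy to verify in isolation; a small numpy sketch of the stimulus averaging (sizes are illustrative):

import numpy as np

n_time, n_condition, n_neuron = 100, 20, 256  # illustrative sizes
h = np.random.rand(n_time, n_condition, n_neuron)

h_stimavg = h.mean(axis=1)  # (n_time, n_neuron): average over conditions
h_lastt = h[-1, :, :]       # (n_condition, n_neuron): one time point
assert h_stimavg.shape == (n_time, n_neuron)
assert h_lastt.shape == (n_condition, n_neuron)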
Code Example #6
def _psychometric_dm(model_dir, rule, params_list, batch_shape):
    """Base function for computing psychometric performance in 2AFC tasks

    Args:
        model_dir : model name
        rule : task to analyze
        params_list : a list of parameter dictionaries used for the psychometric mode
        batch_shape : shape of each batch. Each batch should have shape (n_rep, ...)
        n_rep is the number of repetitions that will be averaged over

    Return:
        ydatas: list of performances
    """
    print('Starting psychometric analysis of the {:s} task...'.format(
        rule_name[rule]))

    model = Model(model_dir)
    hp = model.hp
    with tf.Session() as sess:
        model.restore()

        ydatas = list()
        for params in params_list:

            trial = generate_trials(rule, hp, 'psychometric', params=params)
            feed_dict = tools.gen_feed_dict(model, trial, hp)
            y_loc_sample = sess.run(model.y_hat_loc, feed_dict=feed_dict)
            y_loc_sample = np.reshape(y_loc_sample[-1], batch_shape)

            stim1_locs_ = np.reshape(params['stim1_locs'], batch_shape)
            stim2_locs_ = np.reshape(params['stim2_locs'], batch_shape)

            # Count choices over the repetition (first) dimension of each batch
            choose1 = (get_dist(y_loc_sample - stim1_locs_) < THETA).sum(
                axis=0)
            choose2 = (get_dist(y_loc_sample - stim2_locs_) < THETA).sum(
                axis=0)
            ydatas.append(choose1 / (choose1 + choose2))

    return ydatas
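Here `get_dist` and `THETA` come from the surrounding project. A plausible sketch of the choice-counting logic, assuming `get_dist` returns the absolute angular distance on a circle and `THETA` is a decision threshold in radians:

import numpy as np

THETA = 0.3  # assumed threshold (radians)

def get_dist(original_dist):
    # Angular distance on a circle, mapped to [0, pi]
    return np.minimum(np.abs(original_dist), 2 * np.pi - np.abs(original_dist))

y_loc = np.array([0.1, 3.0])   # decoded response locations
stim1 = np.array([0.0, 0.0])   # stimulus 1 locations
choose1 = get_dist(y_loc - stim1) < THETA  # array([ True, False])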
Code Example #7
def do_eval_test(sess, model, rule):
    """Do evaluation.

    Args:
        sess: tensorflow session
        model: Model class instance
        rule: string, the rule to evaluate
    """
    hp = model.hp

    trial = generate_trials(rule, hp, 'test')
    feed_dict = tools.gen_feed_dict(model, trial, hp)
    c_lsq, c_reg, y_hat_test = sess.run(
        [model.cost_lsq, model.cost_reg, model.y_hat], feed_dict=feed_dict)

    # Cost is first summed over time,
    # then averaged across batch and units.
    # The averaging over time is handled through c_mask
    perf_test = np.mean(get_perf(y_hat_test, trial.y_loc))
    sys.stdout.flush()

    return c_lsq, c_reg, perf_test
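The `c_mask` convention referenced in the comment can be illustrated with a toy numpy sketch (a hedged version; the actual mask construction lives in the task code):

import numpy as np

n_time, n_batch, n_output = 50, 4, 3
y = np.zeros((n_time, n_batch, n_output))
y_hat = np.random.rand(n_time, n_batch, n_output)

# Toy mask: zero out early time points so they do not contribute to the cost
c_mask = np.ones((n_time, n_batch, n_output))
c_mask[:10] = 0.

# Squared error summed over time via the mask, averaged over batch and units
cost_lsq = np.mean(np.sum(c_mask * (y - y_hat) ** 2, axis=0))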
Code Example #8
def activity_histogram(model_dir,
                       rules,
                       title=None,
                       save_name=None):
    """Plot the activity histogram."""

    if isinstance(rules, str):
        rules = [rules]

    h_all = None
    model = Model(model_dir)
    hp = model.hp
    with tf.Session() as sess:
        model.restore()

        t_start = int(500/hp['dt'])

        for rule in rules:
            # Generate a batch of trials in test mode
            trial = generate_trials(rule, hp, mode='test')
            feed_dict = tools.gen_feed_dict(model, trial, hp)
            h = sess.run(model.h, feed_dict=feed_dict)
            h = h[t_start:, :, :]
            if h_all is None:
                h_all = h
            else:
                h_all = np.concatenate((h_all, h), axis=1)

    h_plot = h_all.flatten()

    fig = plt.figure(figsize=(1.5, 1.2))
    ax = fig.add_axes([0.2, 0.2, 0.7, 0.6])
    ax.hist(h_plot, bins=20, density=True)
    ax.set_xlabel('Activity', fontsize=7)
    [ax.spines[s].set_visible(False) for s in ['left', 'top', 'right']]
    ax.set_yticks([])
Code Example #9
    def lesions(self):
        labels = self.labels

        from network import get_perf
        from task import generate_trials

        # The first will be the intact network
        lesion_units_list = [None]
        for l in self.unique_labels:
            ind_l = np.where(labels == l)[0]
            # Convert back to original unit indices
            lesion_units_list += [self.ind_active[ind_l]]

        perfs_store_list = list()
        perfs_changes = list()
        cost_store_list = list()
        cost_changes = list()

        for i, lesion_units in enumerate(lesion_units_list):
            model = Model(self.model_dir)
            hp = model.hp
            with tf.Session() as sess:
                model.restore()
                model.lesion_units(sess, lesion_units)

                perfs_store = list()
                cost_store = list()
                for rule in self.rules:
                    n_rep = 16
                    batch_size_test = 256
                    batch_size_test_rep = int(batch_size_test / n_rep)
                    clsq_tmp = list()
                    perf_tmp = list()
                    for i_rep in range(n_rep):
                        trial = generate_trials(rule,
                                                hp,
                                                'random',
                                                batch_size=batch_size_test_rep)
                        feed_dict = tools.gen_feed_dict(model, trial, hp)
                        y_hat_test, c_lsq = sess.run(
                            [model.y_hat, model.cost_lsq], feed_dict=feed_dict)

                        # Cost is first summed over time, then averaged
                        # across batch and units. The averaging over time
                        # is handled through c_mask.

                        # Take the overall mean performance
                        perf_test = np.mean(get_perf(y_hat_test, trial.y_loc))
                        clsq_tmp.append(c_lsq)
                        perf_tmp.append(perf_test)

                    perfs_store.append(np.mean(perf_tmp))
                    cost_store.append(np.mean(clsq_tmp))

            perfs_store = np.array(perfs_store)
            cost_store = np.array(cost_store)

            perfs_store_list.append(perfs_store)
            cost_store_list.append(cost_store)

            if i > 0:
                perfs_changes.append(perfs_store - perfs_store_list[0])
                cost_changes.append(cost_store - cost_store_list[0])

        perfs_changes = np.array(perfs_changes)
        cost_changes = np.array(cost_changes)

        return perfs_changes, cost_changes
Code Example #10
File: train.py  Project: eiroW/FDM
def train_sequential(
    model_dir,
    rule_trains,
    hp=None,
    max_steps=1e7,
    display_step=500,
    ruleset='mante',
    seed=0,
):
    '''Train the network sequentially.

    Args:
        model_dir: str, training directory
        rule_trains: a list of list of tasks to train sequentially
        hp: dictionary of hyperparameters
        max_steps: int, maximum number of training steps for each list of tasks
        display_step: int, display steps
        ruleset: the set of rules to train
        seed: int, random seed to be used

    Returns:
        model is stored at model_dir/model.ckpt
        training configuration is stored at model_dir/hp.json
    '''

    tools.mkdir_p(model_dir)

    # Network parameters
    default_hp = get_default_hp(ruleset)
    if hp is not None:
        default_hp.update(hp)
    hp = default_hp
    hp['seed'] = seed
    hp['rng'] = np.random.RandomState(seed)
    hp['rule_trains'] = rule_trains
    # Get all rules by flattening the list of lists
    hp['rules'] = [r for rs in rule_trains for r in rs]

    # Number of training iterations for each rule
    rule_train_iters = [len(r) * max_steps for r in rule_trains]

    tools.save_hp(hp, model_dir)
    # Display hp
    for key, val in hp.items():
        print('{:20s} = '.format(key) + str(val))

    # Intelligent-synapses parameters; continual learning is active when c > 0
    c, ksi = hp['c_intsyn'], hp['ksi_intsyn']

    # Build the model
    model = Model(model_dir, hp=hp)

    grad_unreg = tf.gradients(model.cost_lsq, model.var_list)

    # Store results
    log = defaultdict(list)
    log['model_dir'] = model_dir

    # Record time
    t_start = time.time()

    # tensorboard summaries
    placeholders = list()
    for v_name in ['Omega0', 'omega0', 'vdelta']:
        for v in model.var_list:
            placeholder = tf.placeholder(tf.float32, shape=v.shape)
            tf.summary.histogram(v_name + '/' + v.name, placeholder)
            placeholders.append(placeholder)
    merged = tf.summary.merge_all()
    test_writer = tf.summary.FileWriter(model_dir + '/tb')

    def relu(x):
        return x * (x > 0.)

    # Launch the graph in a session
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        # penalty on deviation from initial weight
        if hp['l2_weight_init'] > 0:
            raise NotImplementedError()

        # Looping
        step_total = 0
        for i_rule_train, rule_train in enumerate(hp['rule_trains']):
            step = 0

            # At the beginning of each new task, snapshot the current
            # variables (used by the intelligent-synapses updates)
            v_current = sess.run(model.var_list)

            if i_rule_train == 0:
                v_anc0 = v_current
                Omega0 = [np.zeros(v.shape, dtype='float32') for v in v_anc0]
                omega0 = [np.zeros(v.shape, dtype='float32') for v in v_anc0]
                v_delta = [np.zeros(v.shape, dtype='float32') for v in v_anc0]
            elif c > 0:
                v_anc0_prev = v_anc0
                v_anc0 = v_current
                v_delta = [
                    v - v_prev for v, v_prev in zip(v_anc0, v_anc0_prev)
                ]

                # Update the penalty matrix; relu makes sure all
                # elements of Omega0 are non-negative
                Omega0 = [
                    relu(O + o / (v_d**2 + ksi))
                    for O, o, v_d in zip(Omega0, omega0, v_delta)
                ]

                # Update cost
                model.cost_reg = tf.constant(0.)
                for v, w, v_val in zip(model.var_list, Omega0, v_current):
                    model.cost_reg += c * tf.reduce_sum(
                        tf.multiply(tf.constant(w),
                                    tf.square(v - tf.constant(v_val))))
                model.set_optimizer()

            # Store Omega0 to tf summary
            feed_dict = dict(zip(placeholders, Omega0 + omega0 + v_delta))
            summary = sess.run(merged, feed_dict=feed_dict)
            test_writer.add_summary(summary, i_rule_train)

            # Reset
            omega0 = [np.zeros(v.shape, dtype='float32') for v in v_anc0]

            # Keep training until reach max iterations
            while (step * hp['batch_size_train'] <=
                   rule_train_iters[i_rule_train]):
                # Validation
                if step % display_step == 0:
                    trial = step_total * hp['batch_size_train']
                    log['trials'].append(trial)
                    log['times'].append(time.time() - t_start)
                    log['rule_now'].append(rule_train)
                    log = do_eval(sess, model, log, rule_train)
                    if log['perf_avg'][-1] > model.hp['target_perf']:
                        print('Perf reached the target: {:0.2f}'.format(
                            hp['target_perf']))
                        break

                # Training
                rule_train_now = hp['rng'].choice(rule_train)
                # Generate a random batch of trials.
                # Each batch has the same trial length
                trial = generate_trials(rule_train_now,
                                        hp,
                                        'random',
                                        batch_size=hp['batch_size_train'])

                # Generating feed_dict.
                feed_dict = tools.gen_feed_dict(model, trial, hp)

                # Continual learning with intelligent synapses
                v_prev = v_current

                # This will compute the gradient BEFORE train step
                _, v_grad = sess.run([model.train_step, grad_unreg],
                                     feed_dict=feed_dict)
                # Get the weight after train step
                v_current = sess.run(model.var_list)

                # Update synaptic importance
                omega0 = [
                    o - (v_c - v_p) * v_g for o, v_c, v_p, v_g in zip(
                        omega0, v_current, v_prev, v_grad)
                ]

                step += 1
                step_total += 1

        print("Optimization Finished!")
Code Example #11
File: train.py  Project: eiroW/FDM
def train_rule_only(
    model_dir,
    rule_trains,
    max_steps,
    hp=None,
    ruleset='all',
    seed=0,
):
    '''Customized training function.

    Train the network sequentially, but only train the rule inputs for the
    second set. First train the network to perform tasks in group 1, then
    train on group 2. When training group 2, only rule input connections
    are trained.

    Args:
        model_dir: str, training directory
        rule_trains: a list of list of tasks to train sequentially
        max_steps: int or list of ints, maximum number of training steps
            for each list of tasks
        hp: dictionary of hyperparameters
        ruleset: the set of rules to train
        seed: int, random seed to be used

    Returns:
        model is stored at model_dir/model.ckpt
        training configuration is stored at model_dir/hp.json
    '''

    tools.mkdir_p(model_dir)

    # Network parameters
    default_hp = get_default_hp(ruleset)
    if hp is not None:
        default_hp.update(hp)
    hp = default_hp
    hp['seed'] = seed
    hp['rng'] = np.random.RandomState(seed)
    hp['rule_trains'] = rule_trains
    # Get all rules by flattening the list of lists
    hp['rules'] = [r for rs in rule_trains for r in rs]

    # Number of training iterations for each rule
    if hasattr(max_steps, '__iter__'):
        rule_train_iters = max_steps
    else:
        rule_train_iters = [len(r) * max_steps for r in rule_trains]

    tools.save_hp(hp, model_dir)
    # Display hp
    for key, val in hp.items():
        print('{:20s} = '.format(key) + str(val))

    # Build the model
    model = Model(model_dir, hp=hp)

    # Store results
    log = defaultdict(list)
    log['model_dir'] = model_dir

    # Record time
    t_start = time.time()

    # Launch the graph in a session
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        # penalty on deviation from initial weight
        if hp['l2_weight_init'] > 0:
            raise NotImplementedError()

        # Looping
        step_total = 0
        for i_rule_train, rule_train in enumerate(hp['rule_trains']):
            step = 0

            if i_rule_train == 0:
                display_step = 200
            else:
                display_step = 50

            if i_rule_train > 0:
                # var_list = [v for v in model.var_list
                #             if ('input' in v.name) and ('rnn' not in v.name)]
                var_list = [
                    v for v in model.var_list if 'rule_input' in v.name
                ]
                model.set_optimizer(var_list=var_list)

            # Keep training until reach max iterations
            while (step * hp['batch_size_train'] <=
                   rule_train_iters[i_rule_train]):
                # Validation
                if step % display_step == 0:
                    trial = step_total * hp['batch_size_train']
                    log['trials'].append(trial)
                    log['times'].append(time.time() - t_start)
                    log['rule_now'].append(rule_train)
                    log = do_eval(sess, model, log, rule_train)
                    if log['perf_avg'][-1] > model.hp['target_perf']:
                        print('Perf reached the target: {:0.2f}'.format(
                            hp['target_perf']))
                        break

                # Training
                rule_train_now = hp['rng'].choice(rule_train)
                # Generate a random batch of trials.
                # Each batch has the same trial length
                trial = generate_trials(rule_train_now,
                                        hp,
                                        'random',
                                        batch_size=hp['batch_size_train'])

                # Generating feed_dict.
                feed_dict = tools.gen_feed_dict(model, trial, hp)

                # Take one training step
                sess.run(model.train_step, feed_dict=feed_dict)

                step += 1
                step_total += 1

        print("Optimization Finished!")
Code Example #12
File: train.py  Project: eiroW/FDM
def do_eval(sess, model, log, rule_train):
    """Do evaluation.

    Args:
        sess: tensorflow session
        model: Model class instance
        log: dictionary that stores the log
        rule_train: string or list of strings, the rules being trained
    """
    hp = model.hp
    if not hasattr(rule_train, '__iter__'):
        rule_name_print = rule_train
    else:
        rule_name_print = ' & '.join(rule_train)

    print('Trial {:7d}'.format(log['trials'][-1]) +
          '  | Time {:0.2f} s'.format(log['times'][-1]) + '  | Now training ' +
          rule_name_print)

    for rule_test in hp['rules']:
        n_rep = 16
        batch_size_test_rep = int(hp['batch_size_test'] / n_rep)
        clsq_tmp = list()
        creg_tmp = list()
        perf_tmp = list()
        for i_rep in range(n_rep):
            trial = generate_trials(rule_test,
                                    hp,
                                    'random',
                                    batch_size=batch_size_test_rep)
            feed_dict = tools.gen_feed_dict(model, trial, hp)
            c_lsq, c_reg, y_hat_test = sess.run(
                [model.cost_lsq, model.cost_reg, model.y_hat],
                feed_dict=feed_dict)

            # Cost is first summed over time,
            # then averaged across batch and units.
            # The averaging over time is handled through c_mask
            perf_test = np.mean(get_perf(y_hat_test, trial.y_loc))
            clsq_tmp.append(c_lsq)
            creg_tmp.append(c_reg)
            perf_tmp.append(perf_test)

        log['cost_' + rule_test].append(np.mean(clsq_tmp, dtype=np.float64))
        log['creg_' + rule_test].append(np.mean(creg_tmp, dtype=np.float64))
        log['perf_' + rule_test].append(np.mean(perf_tmp, dtype=np.float64))
        print('{:15s}'.format(rule_test) +
              '| cost {:0.6f}'.format(np.mean(clsq_tmp)) +
              '| c_reg {:0.6f}'.format(np.mean(creg_tmp)) +
              '  | perf {:0.2f}'.format(np.mean(perf_tmp)))
        sys.stdout.flush()

    # TODO: This needs to be fixed since now rules are strings
    if hasattr(rule_train, '__iter__'):
        rule_tmp = rule_train
    else:
        rule_tmp = [rule_train]
    perf_tests_mean = np.mean([log['perf_' + r][-1] for r in rule_tmp])
    log['perf_avg'].append(perf_tests_mean)

    perf_tests_min = np.min([log['perf_' + r][-1] for r in rule_tmp])
    log['perf_min'].append(perf_tests_min)

    # Saving the model
    model.save()
    tools.save_log(log)

    return log
Code Example #13
File: train.py  Project: eiroW/FDM
def train(
    model_dir,
    hp=None,
    max_steps=1e7,
    display_step=500,
    ruleset='mante',
    rule_trains=None,
    rule_prob_map=None,
    seed=0,
    rich_output=False,
    load_dir=None,
    trainables=None,
):
    """Train the network.

    Args:
        model_dir: str, training directory
        hp: dictionary of hyperparameters
        max_steps: int, maximum number of training steps
        display_step: int, display steps
        ruleset: the set of rules to train
        rule_trains: list of rules to train, if None then all rules possible
        rule_prob_map: None or dictionary of relative rule probability
        seed: int, random seed to be used
        rich_output: bool, whether to display rich output during training
        load_dir: None or str, directory to restore an existing model from
        trainables: None, 'all', 'input' or 'rule'; which variables to train

    Returns:
        model is stored at model_dir/model.ckpt
        training configuration is stored at model_dir/hp.json
    """

    tools.mkdir_p(model_dir)

    # Network parameters
    default_hp = get_default_hp(ruleset)
    if hp is not None:
        default_hp.update(hp)
    hp = default_hp
    hp['seed'] = seed
    hp['rng'] = np.random.RandomState(seed)

    # Rules to train and test. Rules in a set are trained together
    if rule_trains is None:
        # By default, training all rules available to this ruleset
        hp['rule_trains'] = task.rules_dict[ruleset]
    else:
        hp['rule_trains'] = rule_trains
    hp['rules'] = hp['rule_trains']

    # Assign probabilities for rule_trains.
    if rule_prob_map is None:
        rule_prob_map = dict()

    # Turn into rule_trains format
    hp['rule_probs'] = None
    if hasattr(hp['rule_trains'], '__iter__'):
        # Set default as 1.
        rule_prob = np.array(
            [rule_prob_map.get(r, 1.) for r in hp['rule_trains']])
        hp['rule_probs'] = list(rule_prob / np.sum(rule_prob))
    tools.save_hp(hp, model_dir)

    # Build the model
    model = Model(model_dir, hp=hp)

    # Display hp
    for key, val in hp.items():
        print('{:20s} = '.format(key) + str(val))

    # Store results
    log = defaultdict(list)
    log['model_dir'] = model_dir

    # Record time
    t_start = time.time()

    with tf.Session() as sess:
        if load_dir is not None:
            model.restore(load_dir)  # complete restore
        else:
            # Initialize all variables from scratch
            sess.run(tf.global_variables_initializer())

        # Set trainable parameters
        if trainables is None or trainables == 'all':
            var_list = model.var_list  # train everything
        elif trainables == 'input':
            # train all inputs
            var_list = [
                v for v in model.var_list
                if ('input' in v.name) and ('rnn' not in v.name)
            ]
        elif trainables == 'rule':
            # train rule inputs only
            var_list = [v for v in model.var_list if 'rule_input' in v.name]
        else:
            raise ValueError('Unknown trainables')
        model.set_optimizer(var_list=var_list)

        # penalty on deviation from initial weight
        if hp['l2_weight_init'] > 0:
            anchor_ws = sess.run(model.weight_list)
            for w, w_val in zip(model.weight_list, anchor_ws):
                model.cost_reg += (hp['l2_weight_init'] *
                                   tf.nn.l2_loss(w - w_val))

            model.set_optimizer(var_list=var_list)

        # partial weight training
        if ('p_weight_train' in hp and (hp['p_weight_train'] is not None)
                and hp['p_weight_train'] < 1.0):
            for w in model.weight_list:
                w_val = sess.run(w)
                w_size = sess.run(tf.size(w))
                w_mask_tmp = np.linspace(0, 1, w_size)
                hp['rng'].shuffle(w_mask_tmp)
                ind_fix = w_mask_tmp > hp['p_weight_train']
                w_mask = np.zeros(w_size, dtype=np.float32)
                w_mask[ind_fix] = 1e-1  # will be squared in l2_loss
                w_mask = tf.constant(w_mask)
                w_mask = tf.reshape(w_mask, w.shape)
                model.cost_reg += tf.nn.l2_loss((w - w_val) * w_mask)
            model.set_optimizer(var_list=var_list)

        step = 0
        while step * hp['batch_size_train'] <= max_steps:
            try:
                # Validation
                if step % display_step == 0:
                    log['trials'].append(step * hp['batch_size_train'])
                    log['times'].append(time.time() - t_start)
                    log = do_eval(sess, model, log, hp['rule_trains'])
                    # Check whether the minimum performance is above target
                    if log['perf_min'][-1] > model.hp['target_perf']:
                        print('Perf reached the target: {:0.2f}'.format(
                            hp['target_perf']))
                        break

                    if rich_output:
                        display_rich_output(model, sess, step, log, model_dir)

                # Training
                rule_train_now = hp['rng'].choice(hp['rule_trains'],
                                                  p=hp['rule_probs'])
                # Generate a random batch of trials.
                # Each batch has the same trial length
                trial = generate_trials(rule_train_now,
                                        hp,
                                        'random',
                                        batch_size=hp['batch_size_train'])

                # Generating feed_dict.
                feed_dict = tools.gen_feed_dict(model, trial, hp)
                sess.run(model.train_step, feed_dict=feed_dict)

                step += 1

            except KeyboardInterrupt:
                print("Optimization interrupted by user")
                break

        print("Optimization finished!")
Code Example #14
def train_sequential_orthogonalized(model_dir,
                                    rule_trains,
                                    hp=None,
                                    max_steps=1e7,
                                    display_step=500,
                                    rich_output=False,
                                    ruleset='mante',
                                    applyProj='both',
                                    seed=0,
                                    nEpisodeBatches=100,
                                    projGrad=True,
                                    alpha=0.001,
                                    fixReadout=False):
    '''Train the network sequentially.

    Args:
        model_dir: str, training directory
        rule_trains: a list of tasks to train sequentially
        hp: dictionary of hyperparameters
        max_steps: int, maximum number of training steps for each task
        display_step: int, display steps
        ruleset: the set of rules to train
        applyProj: str, which projections to apply ('both' by default)
        seed: int, random seed to be used
        projGrad: bool, whether to project the gradients
        alpha: float, regularization constant for the projection matrices
        fixReadout: bool, if True, the readout weights are not trained

    Returns:
        model is stored at model_dir/model.ckpt
        training configuration is stored at model_dir/hp.json
    '''

    tools.mkdir_p(model_dir)

    # Network parameters
    default_hp = get_default_hp(ruleset)
    if hp is not None:
        default_hp.update(hp)
    hp = default_hp
    hp['seed'] = seed
    hp['rng'] = np.random.RandomState(seed)
    hp['rule_trains'] = rule_trains
    # Here each entry of rule_trains is a single task
    hp['rules'] = rule_trains

    # save some other parameters
    hp['alpha_projection'] = alpha
    hp['max_steps'] = max_steps

    # Number of training iterations for each rule
    rule_train_iters = [max_steps for _ in rule_trains]

    tools.save_hp(hp, model_dir)
    # Display hp
    for key, val in hp.items():
        print('{:20s} = '.format(key) + str(val))

    # Build the model
    model = Sequential_Model(model_dir,
                             projGrad=projGrad,
                             applyProj=applyProj,
                             hp=hp)

    # Store results
    log = defaultdict(list)
    log['model_dir'] = model_dir

    # Record time
    t_start = time.time()

    def relu(x):
        return x * (x > 0.)

    # -------------------------------------------------------

    # Use customized session that launches the graph as well
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        # penalty on deviation from initial weight
        if hp['l2_weight_init'] > 0:
            raise NotImplementedError()

        # Looping
        step_total = 0
        taskNumber = 0

        if fixReadout is True:
            my_var_list = [
                var for var in model.var_list
                if 'rnn/leaky_rnn_cell/kernel:0' in var.name
            ]
        else:
            my_var_list = [
                var for var in model.var_list
                if 'rnn/leaky_rnn_cell/kernel:0' in var.name
                or 'output/weights:0' in var.name
            ]

        # initialise projection matrices
        input_proj = tf.zeros(
            (hp['n_rnn'] + hp['n_input'], hp['n_rnn'] + hp['n_input']))
        activity_proj = tf.zeros((hp['n_rnn'], hp['n_rnn']))
        output_proj = tf.zeros((hp['n_output'], hp['n_output']))
        recurrent_proj = tf.zeros((hp['n_rnn'], hp['n_rnn']))

        for i_rule_train, rule_train in enumerate(hp['rule_trains']):

            step = 0

            model.set_optimizer(activity_proj=activity_proj,
                                input_proj=input_proj,
                                output_proj=output_proj,
                                recurrent_proj=recurrent_proj,
                                taskNumber=taskNumber,
                                var_list=my_var_list,
                                alpha=alpha)

            # Keep training until reach max iterations
            while (step * hp['batch_size_train'] <=
                   rule_train_iters[i_rule_train]):
                # Validation
                if step % display_step == 0:
                    trial = step_total * hp['batch_size_train']
                    log['trials'].append(trial)
                    log['times'].append(time.time() - t_start)
                    log['rule_now'].append(rule_train)
                    log = do_eval(sess, model, log, rule_train)
                    if log['perf_avg'][-1] > model.hp['target_perf']:
                        print('Perf reached the target: {:0.2f}'.format(
                            hp['target_perf']))
                        break

                # Training
                # rule_train_now = hp['rng'].choice(rule_train)

                # Generate a random batch of trials.
                # Each batch has the same trial length
                trial = generate_trials(rule_train,
                                        hp,
                                        'random',
                                        batch_size=hp['batch_size_train'],
                                        delay_fac=hp['delay_fac'])

                # Generating feed_dict.
                feed_dict = tools.gen_feed_dict(model, trial, hp)

                # update model
                sess.run(model.train_step, feed_dict=feed_dict)


                step += 1
                step_total += 1

                if step % display_step == 0:
                    model.save_ckpt(step_total)

            # ---------- save model after its completed training the current task ----------
            model.save_after_task(taskNumber)

            # ---------- generate task activity for continual learning -------
            trial = generate_trials(rule_train,
                                    hp,
                                    'random',
                                    batch_size=hp['batch_size_test'],
                                    delay_fac=hp['delay_fac'])

            # Generating feed_dict.
            feed_dict = tools.gen_feed_dict(model, trial, hp)
            eval_h, eval_x, eval_y, Wrec, Win = sess.run(
                [model.h, model.x, model.y, model.w_rec, model.w_in],
                feed_dict=feed_dict)
            full_state = np.concatenate([eval_x, eval_h], -1)

            # get weight matrix after current task
            Wfull = np.concatenate([Win, Wrec], 0)

            # joint covariance matrix of input and activity
            Shx_task = compute_covariance(
                np.reshape(full_state, (-1, hp['n_rnn'] + hp['n_input'])).T)

            # covariance matrix of output
            Sy_task = compute_covariance(
                np.reshape(eval_y, (-1, hp['n_output'])).T)

            # get block matrices from Shx_task
            # Sh_task = Shx_task[-hp['n_rnn']:, -hp['n_rnn']:]
            Sh_task = np.matmul(np.matmul(Wfull.T, Shx_task), Wfull)

            # ---------- update stored covariance matrices for continual learning -------
            if taskNumber == 0:
                input_cov = Shx_task
                activity_cov = Sh_task
                output_cov = Sy_task
            else:
                input_cov = taskNumber / (
                    taskNumber + 1) * input_cov + Shx_task / (taskNumber + 1)
                activity_cov = taskNumber / (
                    taskNumber + 1) * activity_cov + Sh_task / (taskNumber + 1)
                output_cov = taskNumber / (
                    taskNumber + 1) * output_cov + Sy_task / (taskNumber + 1)

            # ---------- update projection matrices for continual learning ----------
            activity_proj, input_proj, output_proj, recurrent_proj = compute_projection_matrices(
                activity_cov, input_cov, output_cov,
                input_cov[-hp['n_rnn']:, -hp['n_rnn']:], alpha)

            # update task number
            taskNumber += 1

        print("Optimization Finished!")
Code Example #15
    # Fragment: assumes `model = Model(model_dir)` and a task name `rule`
    # are defined in the enclosing scope
    with tf.Session() as sess:
        model.restore()
        model._sigma = 0
        # get all connection weights and biases as tensorflow variables
        var_list = model.var_list
        # evaluate the parameters after training
        params = [sess.run(var) for var in var_list]
        # get hparams
        hparams = model.hp
        # create a trial
        trial = generate_trials(rule,
                                hparams,
                                mode='test',
                                noise_on=False,
                                batch_size=40)
        # get feed_dict
        feed_dict = tools.gen_feed_dict(model, trial, hparams)
        # run model
        h_tf, y_hat_tf = sess.run(
            [model.h, model.y_hat],
            feed_dict=feed_dict)  # (n_time, n_condition, n_neuron)

        ##################################################################
        # get shapes
        n_steps, n_trials, n_input_dim = np.shape(trial.x)
        n_rnn = np.shape(h_tf)[2]
        n_output = np.shape(y_hat_tf)[2]

        # Fixed point finder hyperparameters
        # See FixedPointFinder.py for detailed descriptions of available
        # hyperparameters.
        fpf_hps = {}
Code Example #16
def pretty_inputoutput_plot(model_dir, rule, save=False, plot_ylabel=False):
    """Plot the input and output activity for a sample trial from one task.

    Args:
        model_dir: model directory
        rule: string, the rule
        save: bool, whether to save plots
        plot_ylabel: bool, whether to plot ylabel
    """

    fs = 7

    model = Model(model_dir)
    hp = model.hp

    with tf.Session() as sess:
        model.restore()

        trial = generate_trials(rule, hp, mode='test')
        x, y = trial.x, trial.y
        feed_dict = tools.gen_feed_dict(model, trial, hp)
        h, y_hat = sess.run([model.h, model.y_hat], feed_dict=feed_dict)

        t_plot = np.arange(x.shape[0]) * hp['dt'] / 1000

        assert hp['num_ring'] == 2

        n_eachring = hp['n_eachring']

        fig = plt.figure(figsize=(1.3, 2))
        ylabels = ['fix. in', 'stim. mod1', 'stim. mod2', 'fix. out', 'out']
        heights = np.array([0.03, 0.2, 0.2, 0.03, 0.2]) + 0.01
        for i in range(5):
            ax = fig.add_axes(
                [0.15,
                 sum(heights[i + 1:] + 0.02) + 0.1, 0.8, heights[i]])
            cmap = 'Purples'
            plt.xticks([])
            ax.tick_params(axis='both',
                           which='major',
                           labelsize=fs,
                           width=0.5,
                           length=2,
                           pad=3)

            if plot_ylabel:
                ax.spines["right"].set_visible(False)
                ax.spines["bottom"].set_visible(False)
                ax.spines["top"].set_visible(False)
                ax.xaxis.set_ticks_position('bottom')
                ax.yaxis.set_ticks_position('left')

            else:
                ax.spines["left"].set_visible(False)
                ax.spines["right"].set_visible(False)
                ax.spines["bottom"].set_visible(False)
                ax.spines["top"].set_visible(False)
                ax.xaxis.set_ticks_position('none')

            if i == 0:
                plt.plot(t_plot, x[:, 0, 0], color='xkcd:blue')
                if plot_ylabel:
                    plt.yticks([0, 1], ['', ''], rotation='vertical')
                plt.ylim([-0.1, 1.5])
                plt.title(rule_name[rule], fontsize=fs)
            elif i == 1:
                plt.imshow(x[:, 0, 1:1 + n_eachring].T,
                           aspect='auto',
                           cmap=cmap,
                           vmin=0,
                           vmax=1,
                           interpolation='none',
                           origin='lower')
                if plot_ylabel:
                    plt.yticks(
                        [0, (n_eachring - 1) / 2, n_eachring - 1],
                        [r'0$\degree$', r'180$\degree$', r'360$\degree$'],
                        rotation='vertical')
            elif i == 2:
                plt.imshow(x[:, 0, 1 + n_eachring:1 + 2 * n_eachring].T,
                           aspect='auto',
                           cmap=cmap,
                           vmin=0,
                           vmax=1,
                           interpolation='none',
                           origin='lower')

                if plot_ylabel:
                    plt.yticks(
                        [0, (n_eachring - 1) / 2, n_eachring - 1],
                        [r'0$\degree$', r'180$\degree$', r'360$\degree$'],
                        rotation='vertical')
            elif i == 3:
                plt.plot(t_plot, y[:, 0, 0], color='xkcd:green')
                plt.plot(t_plot, y_hat[:, 0, 0], color='xkcd:blue')
                if plot_ylabel:
                    plt.yticks([0.05, 0.8], ['', ''], rotation='vertical')
                plt.ylim([-0.1, 1.1])
            elif i == 4:
                plt.imshow(y_hat[:, 0, 1:].T,
                           aspect='auto',
                           cmap=cmap,
                           vmin=0,
                           vmax=1,
                           interpolation='none',
                           origin='lower')
                if plot_ylabel:
                    plt.yticks(
                        [0, (n_eachring - 1) / 2, n_eachring - 1],
                        [r'0$\degree$', r'180$\degree$', r'360$\degree$'],
                        rotation='vertical')
                plt.xticks([0, y_hat.shape[0]], ['0', '2'])
                plt.xlabel('Time (s)', fontsize=fs, labelpad=-3)
                ax.spines["bottom"].set_visible(True)

            if plot_ylabel:
                plt.ylabel(ylabels[i], fontsize=fs)
            else:
                plt.yticks([])
            ax.get_yaxis().set_label_coords(-0.12, 0.5)

        if save:
            save_name = 'figure/sample_' + rule_name[rule].replace(' ',
                                                                   '') + '.pdf'
            plt.savefig(save_name, transparent=True)
        plt.show()
Code Example #17
def schematic_plot(model_dir, rule=None):
    """Plot schematics of the inputs, recurrent units, and outputs for one task."""
    fontsize = 6

    rule = rule or 'dm1'

    model = Model(model_dir, dt=1)
    hp = model.hp

    with tf.Session() as sess:
        model.restore()
        trial = generate_trials(rule, hp, mode='test')
        feed_dict = tools.gen_feed_dict(model, trial, hp)
        x = trial.x
        h, y_hat = sess.run([model.h, model.y_hat], feed_dict=feed_dict)

    n_eachring = hp['n_eachring']
    n_hidden = hp['n_rnn']

    # Plot Stimulus
    fig = plt.figure(figsize=(1.0, 1.2))
    heights = np.array([0.06, 0.25, 0.25])
    for i in range(3):
        ax = fig.add_axes(
            [0.2, sum(heights[i + 1:] + 0.1) + 0.05, 0.7, heights[i]])
        cmap = 'Purples'
        plt.xticks([])

        # Fixed style for these plots
        ax.tick_params(axis='both',
                       which='major',
                       labelsize=fontsize,
                       width=0.5,
                       length=2,
                       pad=3)
        ax.spines["left"].set_linewidth(0.5)
        ax.spines["right"].set_visible(False)
        ax.spines["bottom"].set_visible(False)
        ax.spines["top"].set_visible(False)
        ax.xaxis.set_ticks_position('bottom')
        ax.yaxis.set_ticks_position('left')

        if i == 0:
            plt.plot(x[:, 0, 0], color='xkcd:blue')
            plt.yticks([0, 1], ['', ''], rotation='vertical')
            plt.ylim([-0.1, 1.5])
            plt.title('Fixation input', fontsize=fontsize, y=0.9)
        elif i == 1:
            plt.imshow(x[:, 0, 1:1 + n_eachring].T,
                       aspect='auto',
                       cmap=cmap,
                       vmin=0,
                       vmax=1,
                       interpolation='none',
                       origin='lower')
            plt.yticks([0, (n_eachring - 1) / 2, n_eachring - 1],
                       [r'0$\degree$', '', r'360$\degree$'],
                       rotation='vertical')
            plt.title('Stimulus mod 1', fontsize=fontsize, y=0.9)
        elif i == 2:
            plt.imshow(x[:, 0, 1 + n_eachring:1 + 2 * n_eachring].T,
                       aspect='auto',
                       cmap=cmap,
                       vmin=0,
                       vmax=1,
                       interpolation='none',
                       origin='lower')
            plt.yticks([0, (n_eachring - 1) / 2, n_eachring - 1], ['', '', ''],
                       rotation='vertical')
            plt.title('Stimulus mod 2', fontsize=fontsize, y=0.9)
        ax.get_yaxis().set_label_coords(-0.12, 0.5)
    plt.savefig('figure/schematic_input.pdf', transparent=True)
    plt.show()

    # Plot Rule Inputs
    fig = plt.figure(figsize=(1.0, 0.5))
    ax = fig.add_axes([0.2, 0.3, 0.7, 0.45])
    cmap = 'Purples'
    X = x[:, 0, 1 + 2 * n_eachring:]
    plt.imshow(X.T,
               aspect='auto',
               vmin=0,
               vmax=1,
               cmap=cmap,
               interpolation='none',
               origin='lower')

    plt.xticks([0, X.shape[0]])
    ax.set_xlabel('Time (ms)', fontsize=fontsize, labelpad=-5)

    # Fixed style for these plots
    ax.tick_params(axis='both',
                   which='major',
                   labelsize=fontsize,
                   width=0.5,
                   length=2,
                   pad=3)
    ax.spines["left"].set_linewidth(0.5)
    ax.spines["right"].set_visible(False)
    ax.spines["bottom"].set_linewidth(0.5)
    ax.spines["top"].set_visible(False)
    ax.xaxis.set_ticks_position('bottom')
    ax.yaxis.set_ticks_position('left')

    plt.yticks([0, X.shape[-1] - 1], ['1', str(X.shape[-1])],
               rotation='vertical')
    plt.title('Rule inputs', fontsize=fontsize, y=0.9)
    ax.get_yaxis().set_label_coords(-0.12, 0.5)

    plt.savefig('figure/schematic_rule.pdf', transparent=True)
    plt.show()

    # Plot Units
    fig = plt.figure(figsize=(1.0, 0.8))
    ax = fig.add_axes([0.2, 0.1, 0.7, 0.75])
    cmap = 'Purples'
    plt.xticks([])
    # Fixed style for these plots
    ax.tick_params(axis='both',
                   which='major',
                   labelsize=fontsize,
                   width=0.5,
                   length=2,
                   pad=3)
    ax.spines["left"].set_linewidth(0.5)
    ax.spines["right"].set_visible(False)
    ax.spines["bottom"].set_visible(False)
    ax.spines["top"].set_visible(False)
    ax.xaxis.set_ticks_position('bottom')
    ax.yaxis.set_ticks_position('left')

    plt.imshow(h[:, 0, :].T,
               aspect='auto',
               cmap=cmap,
               vmin=0,
               vmax=1,
               interpolation='none',
               origin='lower')
    plt.yticks([0, n_hidden - 1], ['1', str(n_hidden)], rotation='vertical')
    plt.title('Recurrent units', fontsize=fontsize, y=0.95)
    ax.get_yaxis().set_label_coords(-0.12, 0.5)
    plt.savefig('figure/schematic_units.pdf', transparent=True)
    plt.show()

    # Plot Outputs
    fig = plt.figure(figsize=(1.0, 0.8))
    heights = np.array([0.1, 0.45]) + 0.01
    for i in range(2):
        ax = fig.add_axes(
            [0.2, sum(heights[i + 1:] + 0.15) + 0.1, 0.7, heights[i]])
        cmap = 'Purples'
        plt.xticks([])

        # Fixed style for these plots
        ax.tick_params(axis='both',
                       which='major',
                       labelsize=fontsize,
                       width=0.5,
                       length=2,
                       pad=3)
        ax.spines["left"].set_linewidth(0.5)
        ax.spines["right"].set_visible(False)
        ax.spines["bottom"].set_visible(False)
        ax.spines["top"].set_visible(False)
        ax.xaxis.set_ticks_position('bottom')
        ax.yaxis.set_ticks_position('left')

        if i == 0:
            plt.plot(y_hat[:, 0, 0], color='xkcd:blue')
            plt.yticks([0.05, 0.8], ['', ''], rotation='vertical')
            plt.ylim([-0.1, 1.1])
            plt.title('Fixation output', fontsize=fontsize, y=0.9)

        elif i == 1:
            plt.imshow(y_hat[:, 0, 1:].T,
                       aspect='auto',
                       cmap=cmap,
                       vmin=0,
                       vmax=1,
                       interpolation='none',
                       origin='lower')
            plt.yticks([0, (n_eachring - 1) / 2, n_eachring - 1],
                       [r'0$\degree$', '', r'360$\degree$'],
                       rotation='vertical')
            plt.xticks([])
            plt.title('Response', fontsize=fontsize, y=0.9)

        ax.get_yaxis().set_label_coords(-0.12, 0.5)

    plt.savefig('figure/schematic_outputs.pdf', transparent=True)
    plt.show()
Code Example #18
def pretty_singleneuron_plot(model_dir,
                             rules,
                             neurons,
                             epoch=None,
                             save=False,
                             ylabel_firstonly=True,
                             trace_only=False,
                             plot_stim_avg=False,
                             save_name=''):
    """Plot the activity of a single neuron in time across many trials

    Args:
        model_dir: model directory
        rules: rules to plot
        neurons: indices of neurons to plot
        epoch: epoch to plot
        save: bool, whether to save the figure
        ylabel_firstonly: if True, only plot ylabel for the first rule in rules
        trace_only: if True, plot only the traces, without axes and labels
        plot_stim_avg: if True, also plot the stimulus-averaged trace
        save_name: str, suffix appended to the saved figure name
    """

    if isinstance(rules, str):
        rules = [rules]

    try:
        _ = iter(neurons)
    except TypeError:
        neurons = [neurons]

    h_tests = dict()
    model = Model(model_dir)
    hp = model.hp
    with tf.Session() as sess:
        model.restore()

        t_start = int(500 / hp['dt'])

        for rule in rules:
            # Generate a batch of trials in test mode
            trial = generate_trials(rule, hp, mode='test')
            feed_dict = tools.gen_feed_dict(model, trial, hp)
            h = sess.run(model.h, feed_dict=feed_dict)
            h_tests[rule] = h

    for neuron in neurons:
        h_max = np.max([h_tests[r][t_start:, :, neuron].max() for r in rules])
        for j, rule in enumerate(rules):
            fs = 6
            fig = plt.figure(figsize=(1.0, 0.8))
            ax = fig.add_axes([0.35, 0.25, 0.55, 0.55])
            t_plot = np.arange(
                h_tests[rule][t_start:].shape[0]) * hp['dt'] / 1000
            _ = ax.plot(t_plot,
                        h_tests[rule][t_start:, :, neuron],
                        lw=0.5,
                        color='gray')

            if plot_stim_avg:
                # Plot stimulus averaged trace
                _ = ax.plot(np.arange(h_tests[rule][t_start:].shape[0]) *
                            hp['dt'] / 1000,
                            h_tests[rule][t_start:, :, neuron].mean(axis=1),
                            lw=1,
                            color='black')

            if epoch is not None:
                e0, e1 = trial.epochs[epoch]
                e0 = e0 if e0 is not None else 0
                e1 = e1 if e1 is not None else h_tests[rule].shape[0]
                ax.plot([e0, e1], [h_max * 1.15] * 2,
                        color='black',
                        linewidth=1.5)
                figname = 'figure/trace_' + rule_name[
                    rule] + epoch + save_name + '.pdf'
            else:
                figname = 'figure/trace_unit' + str(
                    neuron) + rule_name[rule] + save_name + '.pdf'

            plt.ylim(np.array([-0.1, 1.2]) * h_max)
            plt.xticks([0, np.floor(np.max(t_plot) + 0.01)])
            plt.xlabel('Time (s)', fontsize=fs, labelpad=-5)
            plt.locator_params(axis='y', nbins=4)
            if j > 0 and ylabel_firstonly:
                ax.set_yticklabels([])
            else:
                plt.ylabel('Activity (a.u.)', fontsize=fs, labelpad=2)
            plt.title('Unit {:d} '.format(neuron) + rule_name[rule],
                      fontsize=5)
            ax.tick_params(axis='both', which='major', labelsize=fs)
            ax.spines["right"].set_visible(False)
            ax.spines["top"].set_visible(False)
            ax.xaxis.set_ticks_position('bottom')
            ax.yaxis.set_ticks_position('left')
            if trace_only:
                ax.spines["left"].set_visible(False)
                ax.spines["bottom"].set_visible(False)
                ax.xaxis.set_ticks_position('none')
                ax.set_xlabel('')
                ax.set_ylabel('')
                ax.set_xticks([])
                ax.set_yticks([])
                ax.set_title('')

            if save:
                plt.savefig(figname, transparent=True)
            plt.show()
Code Example #19
File: variance.py  Project: liuyuue/RNN_multitask
def _compute_variance_bymodel(model, sess, rules=None, random_rotation=False):
    """Compute variance for all tasks.

        Args:
            model: network.Model instance
            sess: tensorflow session
            rules: list of rules to compute variance, list of strings
            random_rotation: boolean. If True, rotate the neural activity.
        """
    h_all_byrule = OrderedDict()
    h_all_byepoch = OrderedDict()
    hp = model.hp

    if rules is None:
        rules = hp['rules']
    print(rules)

    n_hidden = hp['n_rnn']

    if random_rotation:
        # Generate random orthogonal matrix
        from scipy.stats import ortho_group
        random_ortho_matrix = ortho_group.rvs(dim=n_hidden)

    for rule in rules:
        trial = generate_trials(rule, hp, 'test', noise_on=False)
        feed_dict = tools.gen_feed_dict(model, trial, hp)
        h = sess.run(model.h, feed_dict=feed_dict)
        if random_rotation:
            h = np.dot(h, random_ortho_matrix)  # randomly rotate

        for e_name, e_time in trial.epochs.items():
            if 'fix' not in e_name:  # Ignore fixation period
                h_all_byepoch[(rule, e_name)] = h[e_time[0]:e_time[1], :, :]

        # Ignore fixation period
        h_all_byrule[rule] = h[trial.epochs['fix1'][1]:, :, :]

    # Reorder h_all_byepoch by epoch-first
    keys = list(h_all_byepoch.keys())
    # Using mergesort because it is stable
    ind_key_sort = np.argsort(list(zip(*keys))[1], kind='mergesort')
    h_all_byepoch = OrderedDict([(keys[i], h_all_byepoch[keys[i]])
                                 for i in ind_key_sort])

    for data_type in ['rule', 'epoch']:
        if data_type == 'rule':
            h_all = h_all_byrule
        elif data_type == 'epoch':
            h_all = h_all_byepoch
        else:
            raise ValueError

        h_var_all = np.zeros((n_hidden, len(h_all.keys())))
        for i, val in enumerate(h_all.values()):
            # val has shape (Time, Batch, Units)
            # Variance across stimulus conditions, then averaged across time
            h_var_all[:, i] = val.var(axis=1).mean(axis=0)

        result = {'h_var_all': h_var_all, 'keys': list(h_all.keys())}
        save_name = 'variance_' + data_type
        if random_rotation:
            save_name += '_rr'

        fname = os.path.join(model.model_dir, save_name + '.pkl')
        print('Variance saved at {:s}'.format(fname))
        with open(fname, 'wb') as f:
            pickle.dump(result, f)
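The variance measure used above (variance across stimulus conditions, averaged over time) can be checked on a toy array:

import numpy as np

n_time, n_batch, n_units = 30, 8, 5
val = np.random.rand(n_time, n_batch, n_units)

# Variance across stimulus conditions at each time point, then time average
h_var = val.var(axis=1).mean(axis=0)
assert h_var.shape == (n_units,)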
Code Example #20
def psychometric_choicefamily_2D(model_dir,
                                 rule,
                                 lesion_units=None,
                                 n_coh=8,
                                 n_stim_loc=20,
                                 coh_range=0.1):
    # Generate task parameters for choice tasks
    cohs = np.linspace(-coh_range, coh_range, n_coh)

    batch_size = n_stim_loc * n_coh**2
    batch_shape = (n_stim_loc, n_coh, n_coh)
    ind_stim_loc, ind_stim_mod1, ind_stim_mod2 = np.unravel_index(
        range(batch_size), batch_shape)

    # Looping target location
    stim1_locs = 2 * np.pi * ind_stim_loc / n_stim_loc
    stim2_locs = (stim1_locs + np.pi) % (2 * np.pi)

    stim_mod1_cohs = cohs[ind_stim_mod1]
    stim_mod2_cohs = cohs[ind_stim_mod2]

    params_dict = dict()
    params_dict['dm1'] = \
         {'stim1_locs' : stim1_locs,
          'stim2_locs' : stim2_locs,
          'stim1_strengths' : 1 + stim_mod1_cohs, # Just use mod 1 value
          'stim2_strengths' : 1 - stim_mod1_cohs,
          'stim_time'    : 800
          }
    params_dict['dm2'] = params_dict['dm1']

    params_dict['contextdm1'] = \
         {'stim1_locs' : stim1_locs,
          'stim2_locs' : stim2_locs,
          'stim1_mod1_strengths' : 1 + stim_mod1_cohs,
          'stim2_mod1_strengths' : 1 - stim_mod1_cohs,
          'stim1_mod2_strengths' : 1 + stim_mod2_cohs,
          'stim2_mod2_strengths' : 1 - stim_mod2_cohs,
          'stim_time'    : 800
          }

    params_dict['contextdm2'] = params_dict['contextdm1']

    params_dict['multidm'] = \
         {'stim1_locs' : stim1_locs,
          'stim2_locs' : stim2_locs,
          'stim1_mod1_strengths' : 1 + stim_mod1_cohs,
          'stim2_mod1_strengths' : 1 - stim_mod1_cohs,
          'stim1_mod2_strengths' : 1 + stim_mod1_cohs, # Same as Mod 1
          'stim2_mod2_strengths' : 1 - stim_mod1_cohs,
          'stim_time'    : 800
          }

    params_dict['contextdelaydm1'] = \
         {'stim1_locs' : stim1_locs,
          'stim2_locs' : stim2_locs,
          'stim1_mod1_strengths' : 1 + stim_mod1_cohs,
          'stim2_mod1_strengths' : 1 - stim_mod1_cohs,
          'stim1_mod2_strengths' : 1 + stim_mod2_cohs,
          'stim2_mod2_strengths' : 1 - stim_mod2_cohs,
          'stim_time'    : 800
          }
    params_dict['contextdelaydm2'] = params_dict['contextdelaydm1']

    model = Model(model_dir)
    hp = model.hp
    with tf.Session() as sess:
        model.restore()
        model.lesion_units(sess, lesion_units)

        params = params_dict[rule]
        trial = generate_trials(rule, hp, 'psychometric', params=params)
        feed_dict = tools.gen_feed_dict(model, trial, hp)
        y_sample, y_loc_sample = sess.run([model.y_hat, model.y_hat_loc],
                                          feed_dict=feed_dict)

    # Compute the overall performance.
    # Importantly, discard trials where no decision was made
    loc_cor = trial.y_loc[-1]  # last time point, correct locations
    loc_err = (loc_cor + np.pi) % (2 * np.pi)
    choose_cor = (get_dist(y_loc_sample[-1] - loc_cor) < THETA).sum()
    choose_err = (get_dist(y_loc_sample[-1] - loc_err) < THETA).sum()
    perf = choose_cor / (choose_cor + choose_err)

    # Compute the proportion of choosing choice 1 and maintain the batch_shape
    stim1_locs_ = np.reshape(stim1_locs, batch_shape)
    stim2_locs_ = np.reshape(stim2_locs, batch_shape)

    y_loc_sample = np.reshape(y_loc_sample[-1], batch_shape)
    choose1 = (get_dist(y_loc_sample - stim1_locs_) < THETA).sum(axis=0)
    choose2 = (get_dist(y_loc_sample - stim2_locs_) < THETA).sum(axis=0)
    prop1s = choose1 / (choose1 + choose2)

    return perf, prop1s, cohs
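The batch construction at the top of this function uses `np.unravel_index` to sweep every (location, coherence, coherence) combination exactly once; a small sketch of that indexing:

import numpy as np

n_stim_loc, n_coh = 3, 2
batch_shape = (n_stim_loc, n_coh, n_coh)
batch_size = n_stim_loc * n_coh ** 2

ind_loc, ind_mod1, ind_mod2 = np.unravel_index(range(batch_size), batch_shape)
# Trial i gets its own (location, coh1, coh2) combination:
# (ind_loc[i], ind_mod1[i], ind_mod2[i]) index into the respective grids
assert len(ind_loc) == batch_size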