Esempio n. 1
0
def easy_activity_plot(model_dir, rule):
    """Show heat maps of input, recurrent and output activity for one trial.

    The trained network in ``model_dir`` is restored and run on a test-mode
    batch of ``rule``. Three separate figures are shown, each for the first
    trial of the batch.

    Args:
        model_dir: directory where model file is saved
        rule: string, the rule to plot
    """
    model = Model(model_dir)
    hp = model.hp

    with tf.Session() as sess:
        model.restore()
        trial = generate_trials(rule, hp, mode='test')
        fetched = sess.run([model.h, model.y_hat],
                           feed_dict=tools.gen_feed_dict(model, trial, hp))
    h, y_hat = fetched

    # Per the original comment, all arrays are (n_time, n_condition, n_neuron);
    # only the first condition (trial) is displayed.
    example = 0
    panels = (('input', trial.x), ('recurrent', h), ('output', y_hat))
    for label, activity in panels:
        plt.figure()
        plt.imshow(activity[:, example, :].T, aspect='auto', cmap='hot',
                   interpolation='none', origin='lower')
        plt.title(label)
        plt.colorbar()
        plt.show()
Esempio n. 2
0
def quick_statespace(model_dir):
    """Quick state space analysis of the two context-DM tasks using PCA.

    The model is run on a test batch of ``contextdm1`` and ``contextdm2``.
    For each rule, hidden activity is taken at index
    ``trial.epochs['stim1'][-1]``. A 5-component PCA is fit on the pooled
    activity of both rules, and each rule is scattered in the (PC1, PC2)
    plane.

    Args:
        model_dir: str, directory of the trained model.
    """
    from sklearn.decomposition import PCA

    rules = ['contextdm1', 'contextdm2']
    h_lastts = dict()
    model = Model(model_dir)
    hp = model.hp
    with tf.Session() as sess:
        model.restore()
        for rule in rules:
            # Generate a batch of trial from the test mode
            trial = generate_trials(rule, hp, mode='test')
            feed_dict = tools.gen_feed_dict(model, trial, hp)
            h = sess.run(model.h, feed_dict=feed_dict)
            # NOTE(review): if the stim1 epoch end is None (open-ended epoch),
            # h[None] would add an axis instead of selecting a time point.
            lastt = trial.epochs['stim1'][-1]
            h_lastts[rule] = h[lastt, :, :]

    # np.concatenate requires a real sequence; passing dict.values() raises
    # TypeError on Python 3, so build an explicit list in rule order.
    pca = PCA(n_components=5)
    pca.fit(np.concatenate([h_lastts[rule] for rule in rules], axis=0))

    fig = plt.figure(figsize=(2, 2))
    ax = fig.add_axes([.3, .3, .6, .6])
    for rule, color in zip(rules, ['red', 'blue']):
        data_trans = pca.transform(h_lastts[rule])
        ax.scatter(data_trans[:, 0], data_trans[:, 1], s=1,
                   label=rule_name[rule], color=color)
    plt.tick_params(axis='both', which='major', labelsize=7)
    ax.set_xlabel('PC 1', fontsize=7)
    ax.set_ylabel('PC 2', fontsize=7)
    ax.legend(fontsize=7, ncol=1, bbox_to_anchor=(1, 0.3),
              loc=1, frameon=False)
    # NOTE(review): `save` is not defined in this function; presumably a
    # module-level flag — confirm it exists in the enclosing module.
    if save:
        plt.savefig('figure/choiceatt_quickstatespace.pdf', transparent=True)
Esempio n. 3
0
def run_network_replacerule(model_dir, rule, replace_rule, rule_strength):
    """Evaluate the network when its rule input is replaced by other rule units.

    Performance is measured on 20 random batches of int(1000 / 20) trials
    each and then averaged.

    Args:
        model_dir: model directory
        rule: the rule to test on
        replace_rule: a list of rule input units to use
        rule_strength: the relative strength of each replace rule unit

    Returns:
        (mean performance over the batches, rule_strength)
    """
    model = Model(model_dir)
    hp = model.hp
    n_rep = 20
    per_batch = int(1000 / n_rep)

    with tf.Session() as sess:
        model.restore()
        batch_perfs = []
        for _ in range(n_rep):
            trial = generate_trials(rule,
                                    hp,
                                    'random',
                                    batch_size=per_batch,
                                    replace_rule=replace_rule,
                                    rule_strength=rule_strength)
            y_hat_test = sess.run(
                model.y_hat, feed_dict=tools.gen_feed_dict(model, trial, hp))
            batch_perfs.append(np.mean(get_perf(y_hat_test, trial.y_loc)))

    return np.mean(batch_perfs), rule_strength
Esempio n. 4
0
def easy_connectivity_plot(model_dir):
    """Show a heat map of every 2-D trainable parameter of a saved model.

    Parameters that are not 2-D (e.g. biases) are skipped. Each figure uses
    a symmetric colour range of +/- 70% of the parameter's largest absolute
    value.
    """
    model = Model(model_dir)
    with tf.Session() as sess:
        model.restore()
        # All connection weights and biases as TensorFlow variables,
        # evaluated after training, together with their names.
        variables = model.var_list
        values = [sess.run(v) for v in variables]
        labels = [v.name for v in variables]

    for value, label in zip(values, labels):
        if len(value.shape) == 2:
            bound = np.max(abs(value)) * 0.7
            plt.figure()
            # Transposed so rows are the "To" side and columns the "From" side
            plt.imshow(value.T, aspect='auto', cmap='bwr',
                       vmin=-bound, vmax=bound,
                       interpolation='none', origin='lower')
            plt.title(label)
            plt.colorbar()
            plt.xlabel('From')
            plt.ylabel('To')
            plt.show()
Esempio n. 5
0
    def __init__(self, model_dir, rules=None):
        """Simulate test trials for each rule and store activity summaries.

        Args:
            model_dir: str, model directory
            rules: None or a list of rules; None falls back to hp['rules']

        Attributes set:
            rules: the rules that were analysed
            h_stimavg_byrule: OrderedDict rule -> activity averaged over
                stimulus conditions, with the first int(500 / hp['dt'])
                steps dropped
            h_stimavg_byepoch: OrderedDict (rule, epoch) -> condition-averaged
                activity over that epoch (starting one step early), for
                epochs whose name does not contain 'fix'
            h_lastt_byepoch: OrderedDict (rule, epoch) -> full activity h at
                time index e_time[1] of that epoch
            model_dir: the model directory
        """
        model = Model(model_dir)
        hp = model.hp
        if rules is None:
            rules = hp['rules']

        stimavg_by_rule = OrderedDict()
        stimavg_by_epoch = OrderedDict()
        lastt_by_epoch = OrderedDict()

        with tf.Session() as sess:
            model.restore()
            for rule in rules:
                trial = generate_trials(rule=rule, hp=hp, mode='test')
                h = sess.run(model.h,
                             feed_dict=tools.gen_feed_dict(model, trial, hp))
                # Average across stimulus conditions (axis 1)
                h_avg = h.mean(axis=1)

                # Ignore the initial transition (presumably 500 ms, in units
                # of hp['dt']).
                first_step = int(500 / hp['dt'])
                stimavg_by_rule[rule] = h_avg[first_step:, :]

                for e_name, e_time in trial.epochs.items():
                    if 'fix' in e_name:
                        continue
                    # Begin one step before the epoch onset, clamped at 0
                    start = 0 if e_time[0] <= 0 else e_time[0] - 1
                    stimavg_by_epoch[(rule, e_name)] = \
                        h_avg[start:e_time[1], :]
                    # NOTE(review): intended as the epoch's last time point,
                    # but e_time[1] is also the exclusive end of the slice
                    # above, so this is the first step after the epoch. If
                    # e_time[1] can be None, h[None] adds an axis. Confirm.
                    lastt_by_epoch[(rule, e_name)] = h[e_time[1], :, :]

        self.rules = rules
        self.h_stimavg_byrule = stimavg_by_rule
        self.h_stimavg_byepoch = stimavg_by_epoch
        self.h_lastt_byepoch = lastt_by_epoch
        self.model_dir = model_dir
Esempio n. 6
0
def run_simulation(save_name, setting):
    '''Generate simulation data for all trials.

    The default TensorFlow graph is cleared before the model is built. The
    model under ``save_name`` is created with sigma_rec=setting['sigma_rec']
    and dt=10, its weights are restored, and the work is delegated to
    ``_run_simulation``, whose result is returned.
    '''
    tf.reset_default_graph()
    net = Model(save_name, sigma_rec=setting['sigma_rec'], dt=10)
    with tf.Session() as session:
        net.restore(session)
        return _run_simulation(net, setting)
Esempio n. 7
0
def _compute_variance(model_dir, rules=None, random_rotation=False):
    """Compute variance for all tasks.

    Builds the model with sigma_rec=0, restores it and hands the model and
    session to ``_compute_variance_bymodel``. Nothing is returned.

    Args:
        model_dir: str, the path of the model directory
        rules: list of rules to compute variance, list of strings
        random_rotation: boolean. If True, rotate the neural activity.
    """
    net = Model(model_dir, sigma_rec=0)
    with tf.Session() as session:
        net.restore()
        _compute_variance_bymodel(net, session, rules, random_rotation)
    def compute_H(self, rules=None, trial_list=None, recompute=False):
        """Generate test trials per rule and compute hidden activity per checkpoint.

        Sets self.rules (default self.hp['rule_trains']), self.trial_list
        (default self.log['trials']), self.in_loc, self.in_loc_set and
        self.epoch_info, and prints each rule's epoch timing. For every
        checkpoint directory '<model_dir>/<trial_num>/' and every rule,
        self._compute_H is called unless 'H_<rule>.pkl' already exists in
        that directory and ``recompute`` is False.

        Args:
            rules: None or list of rule names
            trial_list: None or iterable of checkpoint identifiers
            recompute: bool, recompute even if the cached file exists
        """
        self.rules = self.hp['rule_trains'] if rules is None else rules
        self.trial_list = (self.log['trials'] if trial_list is None
                           else trial_list)

        self.in_loc = dict()
        self.in_loc_set = dict()
        self.epoch_info = dict()

        # Noise-free test trials, shared by every checkpoint below
        trials_by_rule = dict()
        print("Epoch information:")
        for rule in self.rules:
            trial = generate_trials(rule, self.hp, 'test', noise_on=False)
            trials_by_rule[rule] = trial
            locs = np.array([np.argmax(loc) for loc in trial.input_loc])
            self.in_loc[rule] = locs
            self.in_loc_set[rule] = sorted(set(locs))
            self.epoch_info[rule] = trial.epochs
            print('\t' + rule + ':')
            for e_name, e_time in trial.epochs.items():
                print('\t\t' + e_name + ':', e_time)

        for trial_num in self.trial_list:
            sub_dir = self.model_dir + '/' + str(trial_num) + '/'
            for rule in self.rules:
                cache_file = sub_dir + 'H_' + rule + '.pkl'
                if not recompute and os.path.exists(cache_file):
                    continue
                model = Model(sub_dir, hp=self.hp)
                with tf.Session() as sess:
                    model.restore()
                    self._compute_H(model, rule, trials_by_rule[rule], sess,)
Esempio n. 9
0
def _psychometric_dm(model_dir, rule, params_list, batch_shape):
    """Base function for computing psychometric performance in 2AFC tasks.

    For each parameter dictionary, the network is run in 'psychometric'
    mode. The decoded output location at the last time step is compared
    with both stimulus locations. A trial counts as choosing a stimulus
    when get_dist of the difference is below THETA.

    Args:
        model_dir: model name
        rule: task to analyze
        params_list: a list of parameter dictionaries used for the
            psychometric mode; each provides 'stim1_locs' and 'stim2_locs'
        batch_shape: shape of each batch, (n_rep, ...). Choice counts are
            summed over the first (n_rep) axis.

    Return:
        ydatas: list with one entry per params dict, the fraction
            choose1 / (choose1 + choose2)
    """
    print('Starting psychometric analysis of the {:s} task...'.format(
        rule_name[rule]))

    model = Model(model_dir)
    hp = model.hp
    with tf.Session() as sess:
        model.restore()

        ydatas = []
        for params in params_list:
            trial = generate_trials(rule, hp, 'psychometric', params=params)
            y_loc = sess.run(model.y_hat_loc,
                             feed_dict=tools.gen_feed_dict(model, trial, hp))
            # Decoded location at the final time step, laid out as batch_shape
            final_loc = np.reshape(y_loc[-1], batch_shape)

            counts = []
            for key in ('stim1_locs', 'stim2_locs'):
                target = np.reshape(params[key], batch_shape)
                counts.append(
                    (get_dist(final_loc - target) < THETA).sum(axis=0))
            n_choose1, n_choose2 = counts
            ydatas.append(n_choose1 / (n_choose1 + n_choose2))

    return ydatas
def activity_histogram(model_dir,
                       rules,
                       title=None,
                       save_name=None):
    """Plot the histogram of hidden activity pooled over one or more tasks.

    Each rule is run on a test batch. The first int(500 / hp['dt']) time
    steps are dropped, and all remaining activity values are pooled into a
    20-bin density histogram.

    Args:
        model_dir: str, model directory
        rules: str or list of str, the rule(s) whose activity is pooled
        title: optional str, drawn as the axes title
        save_name: optional path passed to plt.savefig (format inferred by
            matplotlib from the extension)

    Raises:
        ValueError: if ``rules`` contains no rule.
    """
    if isinstance(rules, str):
        rules = [rules]

    model = Model(model_dir)
    hp = model.hp
    h_list = []
    with tf.Session() as sess:
        model.restore()

        # Skip the initial transition at the start of each trial
        t_start = int(500/hp['dt'])

        for rule in rules:
            # Generate a batch of trial from the test mode
            trial = generate_trials(rule, hp, mode='test')
            feed_dict = tools.gen_feed_dict(model, trial, hp)
            h = sess.run(model.h, feed_dict=feed_dict)
            h_list.append(h[t_start:, :, :])

    if not h_list:
        # Previously this fell through to ``None.flatten()``.
        raise ValueError('activity_histogram requires at least one rule')

    # Concatenate once along the trial axis (axis 1) instead of growing the
    # array inside the loop.
    h_plot = np.concatenate(h_list, axis=1).flatten()

    fig = plt.figure(figsize=(1.5, 1.2))
    ax = fig.add_axes([0.2, 0.2, 0.7, 0.6])
    ax.hist(h_plot, bins=20, density=True)
    ax.set_xlabel('Activity', fontsize=7)
    for side in ['left', 'top', 'right']:
        ax.spines[side].set_visible(False)
    ax.set_yticks([])
    # title and save_name were accepted but silently ignored before.
    if title is not None:
        ax.set_title(title, fontsize=7)
    if save_name is not None:
        plt.savefig(save_name, transparent=True)
Esempio n. 11
0
def networkx_illustration(model_dir):
    """Draw the strongest recurrent connections as a circular directed graph.

    Recurrent weights are ranked by their absolute deviation from the mean
    weight. The top 1% (at least one weight) are kept and all others are
    zeroed. The graph is built from the absolute kept weights, and each edge
    is coloured by its signed weight (RdBu_r, clipped to [-3, 3]). The
    figure is saved to figure/illustration_networkx.pdf.

    Args:
        model_dir: str, directory of the trained model.
    """
    import networkx as nx

    model = Model(model_dir)
    with tf.Session() as sess:
        model.restore()
        w_rec = sess.run(model.w_rec)

    w_rec_flat = w_rec.flatten()
    ind_sort = np.argsort(abs(w_rec_flat - np.mean(w_rec_flat)))
    # int(0.01 * n) is 0 for small networks. ind_sort[-0:] is then the whole
    # array and ind_sort[:-0] is empty, which would keep every weight.
    n_show = max(1, int(0.01 * len(w_rec_flat)))
    w_rec_flat[ind_sort[:-n_show]] = 0
    w_rec2 = np.reshape(w_rec_flat, w_rec.shape)
    G = nx.from_numpy_array(abs(w_rec2), create_using=nx.DiGraph())

    # nx.draw pairs edge_color entries with G.edges() in iteration order, so
    # look up each edge's signed weight. The magnitude-sorted order used
    # previously assigned colours to the wrong edges.
    color = np.array([w_rec2[u, v] for u, v in G.edges()])

    fig = plt.figure(figsize=(4, 4))
    ax = fig.add_axes([0.1, 0.1, 0.8, 0.8])
    nx.draw(G,
            linewidths=0,
            width=0.1,
            alpha=1.0,
            edge_vmin=-3,
            edge_vmax=3,
            arrows=False,
            pos=nx.circular_layout(G),
            node_color=np.array([99. / 255] * 3),
            node_size=10,
            edge_color=color,
            edge_cmap=plt.cm.RdBu_r,
            ax=ax)
    plt.savefig('figure/illustration_networkx.pdf', transparent=True)
Esempio n. 12
0
    def _plot_inout_connections(self, conn_type):
        """Plot group-averaged, peak-aligned input or output weight profiles.

        For each unit group ('1', '2', '12' in self.group_ind_orig), each
        unit's weight profile over the n_eachring ring units is circularly
        shifted so its maximum lands at index int(n_ring / 2). The profiles
        are then averaged within the group. The figure is saved to
        figure/conn_<conn_type>_contextdm.pdf.

        Args:
            conn_type: str, 'input' or 'output'.

        Raises:
            ValueError: for any other conn_type.
        """
        model = Model(self.model_dir)
        hp = model.hp
        with tf.Session() as sess:
            model.restore()
            w_in, w_out = sess.run([model.w_in, model.w_out])

        n_ring = hp['n_eachring']
        groups = ['1', '2', '12']
        center = int(n_ring / 2)

        if conn_type == 'input':
            # Rows 1..n_ring of w_in, transposed so rows index hidden units
            w_conn = w_in[1:n_ring+1, :].T
            xlabel, ylabel, lgtitle = ('Preferred mod 1 input dir.',
                                       'Conn. weight\n from mod 1',
                                       'To group')
        elif conn_type == 'output':
            # Output columns from index 1 onward; rows index hidden units
            w_conn = w_out[:, 1:]
            xlabel, ylabel, lgtitle = ('Preferred output dir.',
                                       'Conn. weight to output',
                                       'From group')
        else:
            raise ValueError('Unknown conn type')

        w_aves = dict()
        for group in groups:
            members = self.group_ind_orig[group]
            aligned = np.zeros((len(members), n_ring))
            for row, unit in enumerate(members):
                profile = w_conn[unit, :]
                aligned[row, :] = np.roll(profile, center - np.argmax(profile))
            w_aves[group] = aligned.mean(axis=0)

        fs = 6
        fig = plt.figure(figsize=(1.5, 1.0))
        ax = fig.add_axes([.35, .25, .55, .6])
        for group in groups:
            ax.plot(w_aves[group], color=self.colors[group], label=group, lw=1)
        ax.set_xticks([center])
        ax.set_xticklabels([xlabel])
        ax.set_ylabel(ylabel, fontsize=fs)
        legend = ax.legend(title=lgtitle, fontsize=fs,
                           bbox_to_anchor=(1.2, 1.2),
                           labelspacing=0.2, loc=1, frameon=False)
        plt.setp(legend.get_title(), fontsize=fs)
        ax.tick_params(axis='both', which='major', labelsize=fs)
        plt.locator_params(axis='y', nbins=3)
        for side in ('right', 'top'):
            ax.spines[side].set_visible(False)
        ax.xaxis.set_ticks_position('bottom')
        ax.yaxis.set_ticks_position('left')
        plt.savefig('figure/conn_' + conn_type + '_contextdm.pdf',
                    transparent=True)
Esempio n. 13
0
    def plot_rec_connections(self):
        """Plot the mean recurrent weight between unit groups for similarly tuned pairs.

        Each unit's preferred direction is the argmax of its input weights,
        averaged over the two stimulus rings (columns 1..n_ring and
        n_ring+1..2*n_ring of the transposed w_in), converted to radians.
        For every ordered pair of groups ('1', '2', '12'), the recurrent
        weights are averaged over unit pairs whose preferred directions
        differ (by get_dist) by less than pi/6. Self-connections are
        excluded. The 3x3 matrix is drawn with 'From' groups on the x-axis
        and 'To' groups on the y-axis, and saved to
        figure/conn_rec_contextdm.pdf. A group pair with no qualifying unit
        pairs is shown as nan.

        Takes no arguments besides self.
        """
        # Sort data by labels and by input connectivity
        model = Model(self.model_dir)
        hp = model.hp
        with tf.Session() as sess:
            model.restore()
            w_in, w_rec = sess.run([model.w_in, model.w_rec])
        w_in, w_rec = w_in.T, w_rec.T

        n_ring = hp['n_eachring']
        groups = ['1', '2', '12']
        n_groups = len(groups)

        # Input weights averaged over the two stimulus rings
        w_in_ = (w_in[:, 1:n_ring + 1] + w_in[:, 1+n_ring:2*n_ring+1]) / 2.

        inds = [self.group_ind_orig[g] for g in groups]
        w_rec_bygroup = np.zeros((n_groups, n_groups))
        for i_from in range(n_groups):
            for i_to in range(n_groups):
                ind1, ind2 = inds[i_from], inds[i_to]
                # Preferred direction (radians) from the strongest input weight
                prefs1 = np.argmax(w_in_[ind1, :], axis=1)*2.*np.pi/n_ring
                prefs2 = np.argmax(w_in_[ind2, :], axis=1)*2.*np.pi/n_ring

                # Pairwise preference distance and the connection weight
                # between each pair
                pref_diffs = list()
                w_recs = list()
                for i, ind_i in enumerate(ind1):
                    for j, ind_j in enumerate(ind2):
                        if ind_j == ind_i:
                            # Excluding self connections, which tend to be positive
                            continue
                        pref_diffs.append(get_dist(prefs1[i]-prefs2[j]))
                        w_recs.append(w_rec[ind_j, ind_i])
                pref_diffs, w_recs = np.array(pref_diffs), np.array(w_recs)

                close = w_recs[pref_diffs < np.pi/6.]
                # An empty selection made np.mean warn and return nan; make
                # the nan explicit instead.
                w_rec_bygroup[i_to, i_from] = (close.mean() if close.size
                                               else np.nan)

        fs = 6
        # nanmax/nanmin keep one empty group pair from turning the whole
        # colour range into nan.
        vmax = np.ceil(np.nanmax(w_rec_bygroup) * 100) / 100.
        vmin = np.floor(np.nanmin(w_rec_bygroup) * 100) / 100.
        fig = plt.figure(figsize=(1.5, 1.5))
        ax = fig.add_axes([0.2, 0.1, 0.5, 0.5])
        im = ax.imshow(w_rec_bygroup, interpolation='nearest',
                       cmap='coolwarm', aspect='auto', vmin=vmin, vmax=vmax)
        ax.xaxis.set_label_position("top")
        ax.xaxis.set_ticks_position("top")
        plt.xticks([0, 1, 2], groups, fontsize=6)
        plt.yticks([0, 1, 2], groups, fontsize=6)
        ax.tick_params('both', length=0)
        ax.set_xlabel('From', fontsize=fs, labelpad=2)
        ax.set_ylabel('To', fontsize=fs, labelpad=2)
        for s in ['right', 'left', 'top', 'bottom']:
            ax.spines[s].set_visible(False)

        ax = fig.add_axes([0.72, 0.1, 0.03, 0.5])
        cb = plt.colorbar(im, cax=ax, ticks=[vmin, vmax])
        cb.outline.set_linewidth(0.5)
        cb.set_label(r'Rec. weight', fontsize=fs, labelpad=-7)
        plt.tick_params(axis='both', which='major', labelsize=fs)
        plt.locator_params(nbins=3)

        plt.savefig('figure/conn_rec_contextdm.pdf', transparent=True)
Esempio n. 14
0
def compute_taskspace(model_dir, setup, restore=False, representation='rate'):
    """Compute a 2-D task-space embedding for one of the predefined rule sets.

    Args:
        model_dir: str, model directory (also where 'rate' results are cached)
        setup: int in 1..6, selects the rule set to embed
        restore: bool. For 'rate', reload 'taskset{setup}_space.pkl' from
            model_dir when it exists instead of recomputing.
        representation: 'rate' uses TaskSetAnalysis.compute_taskspace on the
            'stim1' epoch with PCA and pickles the result to model_dir.
            'weight' uses a 2-component PCA of the rule-input weight columns.

    Returns:
        For 'weight', an OrderedDict mapping (rule, 'stim1') to a (1, 2)
        array. For 'rate', whatever TaskSetAnalysis.compute_taskspace
        returns (or the cached copy).

    Raises:
        ValueError: unknown ``setup`` or ``representation``.
    """
    rules_by_setup = {
        1: ['fdgo', 'fdanti', 'delaygo', 'delayanti'],
        2: ['contextdelaydm1', 'contextdelaydm2', 'contextdm1', 'contextdm2'],
        3: ['dmsgo', 'dmcgo', 'dmsnogo', 'dmcnogo'],
        4: ['contextdelaydm1', 'contextdelaydm2', 'multidelaydm',
            'contextdm1', 'contextdm2', 'multidm'],
        5: ['contextdelaydm1', 'contextdelaydm2', 'multidelaydm',
            'delaydm1', 'delaydm2', 'contextdm1', 'contextdm2', 'multidm',
            'dm1', 'dm2'],
        6: ['fdgo', 'delaygo', 'contextdm1', 'contextdelaydm1'],
    }
    # An unknown setup used to leave `rules` unbound and fail later with
    # UnboundLocalError.
    if setup not in rules_by_setup:
        raise ValueError('Unknown setup {!r}; expected one of {}'.format(
            setup, sorted(rules_by_setup)))
    rules = rules_by_setup[setup]

    if representation == 'rate':
        fname = 'taskset{:d}_space'.format(setup) + '.pkl'
        fname = os.path.join(model_dir, fname)

        if restore and os.path.isfile(fname):
            print('Reloading results from ' + fname)
            h_trans = tools.load_pickle(fname)
        else:
            tsa = TaskSetAnalysis(model_dir, rules=rules)
            h_trans = tsa.compute_taskspace(rules=rules,
                                            epochs=['stim1'],
                                            dim_reduction_type='PCA',
                                            setup=setup)
            with open(fname, 'wb') as f:
                pickle.dump(h_trans, f)
            print('Results stored at : ' + fname)

    elif representation == 'weight':
        from sklearn.decomposition import PCA
        from task import get_rule_index

        model = Model(model_dir)
        hp = model.hp
        with tf.Session() as sess:
            model.restore()
            w_in = sess.run(model.w_in).T

        rule_indices = [get_rule_index(r, hp) for r in rules]
        w_rules = w_in[:, rule_indices]

        # One sample per rule: its rule-input weight vector
        pca = PCA(n_components=2)
        data_trans = pca.fit_transform(w_rules.T)

        # Turn into dictionary, consistent with the 'rate' output. Each value
        # has shape (1, 2); the 'stim1' epoch key is only for consistency.
        h_trans = OrderedDict()
        for i, r in enumerate(rules):
            h_trans[(r, 'stim1')] = np.array([data_trans[i]])

    else:
        raise ValueError(
            "Unknown representation {!r}; expected 'rate' or 'weight'".format(
                representation))

    return h_trans
Esempio n. 15
0
def load_data(model_dir=None, sigma_rec=0, lesion_units=None, n_rep=1):
    """Generate model data into standard format.

    For each of 12 evenly spaced stim1 locations, noisy psychometric trials
    of 'contextdm1' and 'contextdm2' are generated with the same params and
    stacked along the trial axis, and the network is run on them. Hidden
    activity in the 'stim1' epoch is kept and down-sampled to one sample per
    int(50 / hp['dt']) steps.

    Args:
        model_dir: model directory; defaults to './mantetemp'
        sigma_rec: passed to Model(..., sigma_rec=sigma_rec)
        lesion_units: None, or units passed to model.lesion_units
        n_rep: forwarded to _gen_taskparams

    Returns:
        data: standard format, list of dict of arrays/dict
            list is over neurons (one entry per unit per stim1 location)
            dict is for response array and task variable dict
            response array has shape (n_trial, n_time)
    """
    if model_dir is None:
        model_dir = './mantetemp'  # TEMPORARY SETTING

    # Get rules and regressors
    rules = ['contextdm1', 'contextdm2']
    n_rule = len(rules)

    data = list()

    model = Model(model_dir, sigma_rec=sigma_rec)
    hp = model.hp
    with tf.Session() as sess:
        model.restore()
        if lesion_units is not None:
            model.lesion_units(sess, lesion_units)

        # Target location
        stim1_loc_list = np.arange(0, 2*np.pi, 2*np.pi/12)
        for stim1_loc in stim1_loc_list:
            params, batch_size = _gen_taskparams(stim1_loc=stim1_loc,
                                                 n_rep=n_rep)

            # Generate every rule with the same params, then stack along the
            # trial axis (axis 1).
            trials = [generate_trials(rule, hp, 'psychometric',
                                      params=params, noise_on=True)
                      for rule in rules]
            x = np.concatenate([t.x for t in trials], axis=1)
            y_loc = np.concatenate([t.y_loc for t in trials], axis=1)
            # NOTE(review): epoch timing is read from the last rule's trial,
            # which assumes both rules share the same epochs — confirm.
            trial = trials[-1]

            # Coherences, normalised by their maximum. Division is
            # out-of-place because in-place `/=` raises on integer arrays.
            stim_mod1_cohs = (params['stim1_mod1_strengths']
                              - params['stim2_mod1_strengths'])
            stim_mod2_cohs = (params['stim1_mod2_strengths']
                              - params['stim2_mod2_strengths'])
            stim_mod1_cohs = stim_mod1_cohs / stim_mod1_cohs.max()
            stim_mod2_cohs = stim_mod2_cohs / stim_mod2_cohs.max()

            # Get neural activity
            fetches = [model.h, model.y_hat, model.y_hat_loc]
            H, _, y_sample_loc = sess.run(fetches, feed_dict={model.x: x})

            # Downsample in time, analysing only the stim1 epoch
            dt_new = 50
            every_t = int(dt_new / hp['dt'])
            epoch = trial.epochs['stim1']
            H = H[epoch[0]:epoch[1], ...][int(every_t / 2)::every_t, ...]

            # Choice is 1 for choosing stim1_loc, otherwise -1
            y_actual_choice = 2*(get_dist(y_sample_loc[-1]-stim1_loc)<np.pi/2)-1
            y_target_choice = 2*(get_dist(y_loc[-1]-stim1_loc)<np.pi/2)-1

            # Task variables (key order preserved for downstream consumers)
            task_var = dict()
            task_var['targ_dir'] = y_actual_choice
            task_var['stim_dir'] = np.tile(stim_mod1_cohs, n_rule)
            task_var['stim_col2dir'] = np.tile(stim_mod2_cohs, n_rule)
            task_var['context'] = np.repeat([1, -1], batch_size)
            task_var['correct'] = (y_actual_choice == y_target_choice).astype(int)
            task_var['stim_dir_sign'] = (task_var['stim_dir']>0).astype(int)*2-1
            task_var['stim_col2dir_sign'] = (task_var['stim_col2dir']>0).astype(int)*2-1

            for i_unit in range(H.shape[-1]):
                data.append({
                    'rate': H[:, :, i_unit].T,  # standard format (n_trial, n_time)
                    'task_var': copy.deepcopy(task_var),
                })
    return data
    def lesions(self):
        """Measure how lesioning each unit cluster changes performance and cost.

        The intact network (no lesion) is evaluated first. Then one network
        is evaluated per label in self.unique_labels, with that cluster's
        units (mapped through self.ind_active) lesioned. Each rule in
        self.rules is scored on 16 random batches of int(256 / 16) trials.

        Returns:
            perfs_changes: array, one row per cluster and one column per
                rule: lesioned performance minus intact performance
            cost_changes: array, same layout, for model.cost_lsq
        """
        from network import get_perf
        from task import generate_trials

        # Entry 0 is the intact network; then original unit indices per cluster
        lesion_units_list = [None] + [
            self.ind_active[np.where(self.labels == label)[0]]
            for label in self.unique_labels
        ]

        n_rep = 16
        batch_size_test_rep = int(256 / n_rep)

        perf_rows = []
        cost_rows = []
        for lesion_units in lesion_units_list:
            model = Model(self.model_dir)
            hp = model.hp
            with tf.Session() as sess:
                model.restore()
                model.lesion_units(sess, lesion_units)

                rule_perfs = []
                rule_costs = []
                for rule in self.rules:
                    perf_tmp = []
                    clsq_tmp = []
                    for _ in range(n_rep):
                        trial = generate_trials(rule,
                                                hp,
                                                'random',
                                                batch_size=batch_size_test_rep)
                        feed_dict = tools.gen_feed_dict(model, trial, hp)
                        y_hat_test, c_lsq = sess.run(
                            [model.y_hat, model.cost_lsq], feed_dict=feed_dict)
                        # Overall mean performance across the batch
                        perf_tmp.append(np.mean(get_perf(y_hat_test, trial.y_loc)))
                        clsq_tmp.append(c_lsq)
                    rule_perfs.append(np.mean(perf_tmp))
                    rule_costs.append(np.mean(clsq_tmp))

            perf_rows.append(np.array(rule_perfs))
            cost_rows.append(np.array(rule_costs))

        # Differences relative to the intact baseline in row 0
        perfs_changes = np.array([row - perf_rows[0] for row in perf_rows[1:]])
        cost_changes = np.array([row - cost_rows[0] for row in cost_rows[1:]])

        return perfs_changes, cost_changes
    def plot_connectivity_byclusters(self):
        """Plot connectivity of the model, with units sorted by cluster.

        Only units in self.ind_active are shown. They are ordered by
        self.labels and then by preferred input direction (the argmax of
        their summed modality-1 and modality-2 ring input weights). The
        figure shows, in separate axes: the recurrent weights, the fixation
        / mod 1 / mod 2 / rule input weights, the output weights, and the
        recurrent and output biases. Coloured bars mark each cluster along
        the top and left of the recurrent matrix.
        """

        ind_active = self.ind_active

        # Sort data by labels and by input connectivity
        model = Model(self.model_dir)
        hp = model.hp
        with tf.Session() as sess:
            model.restore()
            w_in = sess.run(model.w_in).T
            w_rec = sess.run(model.w_rec).T
            w_out = sess.run(model.w_out).T
            b_rec = sess.run(model.b_rec)
            b_out = sess.run(model.b_out)

        # Restrict every matrix to the active units
        w_rec = w_rec[ind_active, :][:, ind_active]
        w_in = w_in[ind_active, :]
        w_out = w_out[:, ind_active]
        b_rec = b_rec[ind_active]

        # nx, nh, ny = hp['shape']
        nr = hp['n_eachring']

        # NOTE(review): sort_by is hard-coded, so the 'w_out' branch below is
        # unreachable as written.
        sort_by = 'w_in'
        if sort_by == 'w_in':
            w_in_mod1 = w_in[:, 1:nr + 1]
            w_in_mod2 = w_in[:, nr + 1:2 * nr + 1]
            w_in_modboth = w_in_mod1 + w_in_mod2
            w_prefs = np.argmax(w_in_modboth, axis=1)
        elif sort_by == 'w_out':
            w_prefs = np.argmax(w_out[1:], axis=0)

        # sort by labels then by prefs
        ind_sort = np.lexsort((w_prefs, self.labels))

        ######################### Plotting Connectivity ###############################
        # nx here is the number of network inputs (self.hp['n_input']), not
        # the networkx module. Sizes below are read from self.hp, whereas nr
        # above came from model.hp.
        nx = self.hp['n_input']
        ny = self.hp['n_output']
        nh = len(self.ind_active)
        nr = self.hp['n_eachring']
        nrule = len(self.hp['rules'])

        # Plot active units
        _w_rec = w_rec[ind_sort, :][:, ind_sort]
        _w_in = w_in[ind_sort, :]
        _w_out = w_out[:, ind_sort]
        _b_rec = b_rec[ind_sort, np.newaxis]
        _b_out = b_out[:, np.newaxis]
        labels = self.labels[ind_sort]

        # Layout: l is the lower-left origin of the recurrent matrix (figure
        # fraction), l0 is the size of one unit cell.
        l = 0.3
        l0 = (1 - 1.5 * l) / nh

        # (matrix, [left, bottom, width, height]) for each panel
        plot_infos = [
            (_w_rec, [l, l, nh * l0, nh * l0]),
            (_w_in[:, [0]], [l - (nx + 15) * l0, l, 1 * l0,
                             nh * l0]),  # Fixation input
            (_w_in[:, 1:nr + 1], [l - (nx + 11) * l0, l, nr * l0,
                                  nh * l0]),  # Mod 1 stimulus
            (_w_in[:, nr + 1:2 * nr + 1],
             [l - (nx - nr + 8) * l0, l, nr * l0, nh * l0]),  # Mod 2 stimulus
            (_w_in[:, 2 * nr + 1:],
             [l - (nx - 2 * nr + 5) * l0, l, nrule * l0,
              nh * l0]),  # Rule inputs
            (_w_out[[0], :], [l, l - (4) * l0, nh * l0, 1 * l0]),
            (_w_out[1:, :], [l, l - (ny + 6) * l0, nh * l0, (ny - 1) * l0]),
            (_b_rec, [l + (nh + 6) * l0, l, l0, nh * l0]),
            (_b_out, [l + (nh + 6) * l0, l - (ny + 6) * l0, l0, ny * l0])
        ]

        # cmap = sns.diverging_palette(220, 10, sep=80, as_cmap=True)
        cmap = 'coolwarm'
        fig = plt.figure(figsize=(6, 6))
        for plot_info in plot_infos:
            ax = fig.add_axes(plot_info[1])
            # Colour range is centred on the median, with a width equal to
            # the 5th-95th percentile spread of the panel's values.
            vmin, vmid, vmax = np.percentile(plot_info[0].flatten(),
                                             [5, 50, 95])
            _ = ax.imshow(plot_info[0],
                          interpolation='nearest',
                          cmap=cmap,
                          aspect='auto',
                          vmin=vmid - (vmax - vmin) / 2,
                          vmax=vmid + (vmax - vmin) / 2)
            ax.axis('off')

        # Cluster-marker strips along the top and left of the recurrent matrix
        ax1 = fig.add_axes([l, l + nh * l0, nh * l0, 6 * l0])
        ax2 = fig.add_axes([l - 6 * l0, l, 6 * l0, nh * l0])
        # NOTE(review): this loop rebinds l (the layout origin) to each
        # cluster label. l is not read again after the loop, so the axes are
        # unaffected, but the name reuse is fragile.
        for il, l in enumerate(self.unique_labels):
            # First and one-past-last sorted positions of this cluster
            ind_l = np.where(labels == l)[0][[0, -1]] + np.array([0, 1])
            ax1.plot(ind_l, [0, 0],
                     linewidth=2,
                     solid_capstyle='butt',
                     color=kelly_colors[il + 1])
            ax2.plot([0, 0],
                     len(labels) - ind_l,
                     linewidth=2,
                     solid_capstyle='butt',
                     color=kelly_colors[il + 1])
        ax1.set_xlim([0, len(labels)])
        ax2.set_ylim([0, len(labels)])
        ax1.axis('off')
        ax2.axis('off')
        # NOTE(review): `save` is not defined in this method; presumably a
        # module-level flag — confirm.
        if save:
            plt.savefig('figure/connectivity_by' + self.data_type + '.pdf',
                        transparent=True)
        plt.show()
Esempio n. 18
0
File: train.py Progetto: eiroW/FDM
def train(
    model_dir,
    hp=None,
    max_steps=1e7,
    display_step=500,
    ruleset='mante',
    rule_trains=None,
    rule_prob_map=None,
    seed=0,
    rich_output=False,
    load_dir=None,
    trainables=None,
):
    """Train the network.

    Args:
        model_dir: str, training directory
        hp: dictionary of hyperparameters; entries override the defaults
            returned by get_default_hp(ruleset)
        max_steps: int, maximum number of training trials (training stops
            once step * hp['batch_size_train'] exceeds this value)
        display_step: int, number of training steps between evaluations
        ruleset: the set of rules to train
        rule_trains: list of rules to train, if None then all rules possible.
            A single rule name given as a string is treated as a
            one-element list.
        rule_prob_map: None or dictionary of relative rule probability
        seed: int, random seed to be used
        rich_output: bool, if True call display_rich_output after every
            evaluation
        load_dir: str or None, if given restore the model from this directory
            instead of initializing all variables
        trainables: None, 'all', 'input' or 'rule'; selects which variables
            are passed to the optimizer

    Returns:
        model is stored at model_dir/model.ckpt
        training configuration is stored at model_dir/hp.json

    Raises:
        ValueError: if trainables is not one of the supported values.
    """

    tools.mkdir_p(model_dir)

    # Network parameters
    default_hp = get_default_hp(ruleset)
    if hp is not None:
        default_hp.update(hp)
    hp = default_hp
    hp['seed'] = seed
    hp['rng'] = np.random.RandomState(seed)

    # Rules to train and test. Rules in a set are trained together
    if rule_trains is None:
        # By default, training all rules available to this ruleset
        hp['rule_trains'] = task.rules_dict[ruleset]
    elif isinstance(rule_trains, str):
        # A lone string is iterable in Python 3; wrap it so it is not
        # treated as a sequence of one-character rule names below.
        hp['rule_trains'] = [rule_trains]
    else:
        hp['rule_trains'] = rule_trains
    hp['rules'] = hp['rule_trains']

    # Assign probabilities for rule_trains.
    if rule_prob_map is None:
        rule_prob_map = dict()

    # Turn into rule_trains format
    hp['rule_probs'] = None
    if hasattr(hp['rule_trains'], '__iter__'):
        # Rules absent from rule_prob_map get relative weight 1.
        rule_prob = np.array(
            [rule_prob_map.get(r, 1.) for r in hp['rule_trains']])
        hp['rule_probs'] = list(rule_prob / np.sum(rule_prob))
    tools.save_hp(hp, model_dir)

    # Build the model
    model = Model(model_dir, hp=hp)

    # Display hp
    for key, val in hp.items():
        print('{:20s} = '.format(key) + str(val))

    # Store results
    log = defaultdict(list)
    log['model_dir'] = model_dir

    # Record time
    t_start = time.time()

    with tf.Session() as sess:
        if load_dir is not None:
            model.restore(load_dir)  # complete restore
        else:
            # Fresh start: initialize every variable
            sess.run(tf.global_variables_initializer())

        # Set trainable parameters
        if trainables is None or trainables == 'all':
            var_list = model.var_list  # train everything
        elif trainables == 'input':
            # train all inputs
            var_list = [
                v for v in model.var_list
                if ('input' in v.name) and ('rnn' not in v.name)
            ]
        elif trainables == 'rule':
            # train rule inputs only
            var_list = [v for v in model.var_list if 'rule_input' in v.name]
        else:
            raise ValueError('Unknown trainables')
        model.set_optimizer(var_list=var_list)

        # penalty on deviation from initial weight
        if hp['l2_weight_init'] > 0:
            anchor_ws = sess.run(model.weight_list)
            for w, w_val in zip(model.weight_list, anchor_ws):
                model.cost_reg += (hp['l2_weight_init'] *
                                   tf.nn.l2_loss(w - w_val))

            # Presumably rebuilt so the optimizer picks up the new cost_reg
            # terms — TODO confirm against Model.set_optimizer.
            model.set_optimizer(var_list=var_list)

        # partial weight training: softly pin a random ~(1 - p_weight_train)
        # fraction of every weight matrix to its current value.
        if ('p_weight_train' in hp and (hp['p_weight_train'] is not None)
                and hp['p_weight_train'] < 1.0):
            for w in model.weight_list:
                w_val = sess.run(w)
                w_mask_tmp = np.linspace(0, 1, w_val.size)
                hp['rng'].shuffle(w_mask_tmp)
                ind_fix = w_mask_tmp > hp['p_weight_train']
                w_mask = np.zeros(w_val.size, dtype=np.float32)
                w_mask[ind_fix] = 1e-1  # will be squared in l2_loss
                # Reshape in numpy rather than adding tf.size / tf.reshape
                # ops to the graph for every weight matrix.
                w_mask = tf.constant(w_mask.reshape(w_val.shape))
                model.cost_reg += tf.nn.l2_loss((w - w_val) * w_mask)
            model.set_optimizer(var_list=var_list)

        step = 0
        while step * hp['batch_size_train'] <= max_steps:
            try:
                # Validation
                if step % display_step == 0:
                    log['trials'].append(step * hp['batch_size_train'])
                    log['times'].append(time.time() - t_start)
                    log = do_eval(sess, model, log, hp['rule_trains'])
                    # Stop once the minimum performance is above target
                    if log['perf_min'][-1] > model.hp['target_perf']:
                        print('Perf reached the target: {:0.2f}'.format(
                            hp['target_perf']))
                        break

                    if rich_output:
                        display_rich_output(model, sess, step, log, model_dir)

                # Training
                rule_train_now = hp['rng'].choice(hp['rule_trains'],
                                                  p=hp['rule_probs'])
                # Generate a random batch of trials.
                # Each batch has the same trial length
                trial = generate_trials(rule_train_now,
                                        hp,
                                        'random',
                                        batch_size=hp['batch_size_train'])

                # Generating feed_dict.
                feed_dict = tools.gen_feed_dict(model, trial, hp)
                sess.run(model.train_step, feed_dict=feed_dict)

                step += 1

            except KeyboardInterrupt:
                print("Optimization interrupted by user")
                break

        print("Optimization finished!")
Esempio n. 19
0
def psychometric_choicefamily_2D(model_dir,
                                 rule,
                                 lesion_units=None,
                                 n_coh=8,
                                 n_stim_loc=20,
                                 coh_range=0.1):
    """Run a 2D psychometric sweep over the coherences of both modalities.

    One trial is generated for every combination of stimulus-1 location
    (n_stim_loc values evenly spaced on the ring), modality-1 coherence and
    modality-2 coherence (n_coh values each, in [-coh_range, coh_range]).
    Stimulus 2 is always placed opposite stimulus 1.

    Args:
        model_dir: model directory
        rule: str, one of 'dm1', 'dm2', 'contextdm1', 'contextdm2',
            'contextdelaydm1', 'contextdelaydm2', 'multidm'
        lesion_units: passed through to model.lesion_units
        n_coh: int, number of coherence values per modality
        n_stim_loc: int, number of stimulus-1 locations
        coh_range: float, maximum absolute coherence

    Returns:
        perf: overall fraction correct among trials where one of the two
            stimuli was chosen
        prop1s: array (n_coh, n_coh), proportion of choosing stimulus 1,
            indexed by (mod1 coherence, mod2 coherence), pooled over
            stimulus locations
        cohs: array (n_coh,), the coherence values

    Raises:
        KeyError: if rule is not one of the supported rules.
    """
    cohs = np.linspace(-coh_range, coh_range, n_coh)

    batch_size = n_stim_loc * n_coh**2
    batch_shape = (n_stim_loc, n_coh, n_coh)
    ind_stim_loc, ind_stim_mod1, ind_stim_mod2 = np.unravel_index(
        range(batch_size), batch_shape)

    # Looping target location
    stim1_locs = 2 * np.pi * ind_stim_loc / n_stim_loc
    stim2_locs = (stim1_locs + np.pi) % (2 * np.pi)

    stim_mod1_cohs = cohs[ind_stim_mod1]
    stim_mod2_cohs = cohs[ind_stim_mod2]

    def _two_mod_params(mod2_cohs):
        """Return a fresh params dict for tasks with two stimulus modalities."""
        return {
            'stim1_locs': stim1_locs,
            'stim2_locs': stim2_locs,
            'stim1_mod1_strengths': 1 + stim_mod1_cohs,
            'stim2_mod1_strengths': 1 - stim_mod1_cohs,
            'stim1_mod2_strengths': 1 + mod2_cohs,
            'stim2_mod2_strengths': 1 - mod2_cohs,
            'stim_time': 800,
        }

    # Single-modality tasks: just use the modality-1 coherence.
    dm_params = {
        'stim1_locs': stim1_locs,
        'stim2_locs': stim2_locs,
        'stim1_strengths': 1 + stim_mod1_cohs,
        'stim2_strengths': 1 - stim_mod1_cohs,
        'stim_time': 800,
    }

    # Each rule gets its own dict so that modifying one rule's parameters
    # can never silently change another rule's.
    params_dict = {
        'dm1': dm_params,
        'dm2': dict(dm_params),
        'contextdm1': _two_mod_params(stim_mod2_cohs),
        'contextdm2': _two_mod_params(stim_mod2_cohs),
        'contextdelaydm1': _two_mod_params(stim_mod2_cohs),
        'contextdelaydm2': _two_mod_params(stim_mod2_cohs),
        # Same coherence in both modalities
        'multidm': _two_mod_params(stim_mod1_cohs),
    }
    # Look the rule up before loading the model so a bad rule fails fast.
    params = params_dict[rule]

    model = Model(model_dir)
    hp = model.hp
    with tf.Session() as sess:
        model.restore()
        model.lesion_units(sess, lesion_units)

        trial = generate_trials(rule, hp, 'psychometric', params=params)
        feed_dict = tools.gen_feed_dict(model, trial, hp)
        y_loc_sample = sess.run(model.y_hat_loc, feed_dict=feed_dict)

    # Compute the overall performance.
    # Importantly, discard trials where no decision was made
    loc_cor = trial.y_loc[-1]  # last time point, correct locations
    loc_err = (loc_cor + np.pi) % (2 * np.pi)
    choose_cor = (get_dist(y_loc_sample[-1] - loc_cor) < THETA).sum()
    choose_err = (get_dist(y_loc_sample[-1] - loc_err) < THETA).sum()
    # Explicit true division: the counts are integers.
    perf = np.true_divide(choose_cor, choose_cor + choose_err)

    # Compute the proportion of choosing choice 1 and maintain the batch_shape
    stim1_locs_ = np.reshape(stim1_locs, batch_shape)
    stim2_locs_ = np.reshape(stim2_locs, batch_shape)

    y_loc_last = np.reshape(y_loc_sample[-1], batch_shape)
    # Sum over axis 0 pools the stimulus locations.
    choose1 = (get_dist(y_loc_last - stim1_locs_) < THETA).sum(axis=0)
    choose2 = (get_dist(y_loc_last - stim2_locs_) < THETA).sum(axis=0)
    prop1s = np.true_divide(choose1, choose1 + choose2)

    return perf, prop1s, cohs
Esempio n. 20
0
def pretty_inputoutput_plot(model_dir, rule, save=False, plot_ylabel=False):
    """Plot the input and output activity for a sample trial from one task.

    Draws five stacked panels for the first trial of a test batch:
    1. fixation input;
    2. stimulus modality 1;
    3. stimulus modality 2;
    4. fixation output (target in green, network output in blue);
    5. the network's ring output.

    Args:
        model_dir: model directory
        rule: string, the rule
        save: bool, whether to save plots
        plot_ylabel: bool, whether to plot ylabel
    """

    fs = 7

    model = Model(model_dir)
    hp = model.hp

    with tf.Session() as sess:
        model.restore()

        trial = generate_trials(rule, hp, mode='test')
        x, y = trial.x, trial.y
        feed_dict = tools.gen_feed_dict(model, trial, hp)
        y_hat = sess.run(model.y_hat, feed_dict=feed_dict)

    assert hp['num_ring'] == 2

    t_plot = np.arange(x.shape[0]) * hp['dt'] / 1000.
    n_t = y_hat.shape[0]
    # Trial duration in seconds, used to label the end of the time axis.
    duration = n_t * hp['dt'] / 1000.

    n_eachring = hp['n_eachring']
    ring_ticks = [0, (n_eachring - 1) / 2, n_eachring - 1]
    ring_ticklabels = [r'0$\degree$', r'180$\degree$', r'360$\degree$']
    cmap = 'Purples'

    def _ring_imshow(data):
        """Show a (time, ring unit) array as a heatmap with ring y-ticks."""
        plt.imshow(data.T,
                   aspect='auto',
                   cmap=cmap,
                   vmin=0,
                   vmax=1,
                   interpolation='none',
                   origin='lower')
        if plot_ylabel:
            plt.yticks(ring_ticks, ring_ticklabels, rotation='vertical')

    fig = plt.figure(figsize=(1.3, 2))
    ylabels = ['fix. in', 'stim. mod1', 'stim. mod2', 'fix. out', 'out']
    heights = np.array([0.03, 0.2, 0.2, 0.03, 0.2]) + 0.01
    for i in range(5):
        ax = fig.add_axes(
            [0.15, sum(heights[i + 1:] + 0.02) + 0.1, 0.8, heights[i]])
        plt.xticks([])
        ax.tick_params(axis='both',
                       which='major',
                       labelsize=fs,
                       width=0.5,
                       length=2,
                       pad=3)

        if plot_ylabel:
            ax.spines["right"].set_visible(False)
            ax.spines["bottom"].set_visible(False)
            ax.spines["top"].set_visible(False)
            ax.xaxis.set_ticks_position('bottom')
            ax.yaxis.set_ticks_position('left')
        else:
            for side in ("left", "right", "bottom", "top"):
                ax.spines[side].set_visible(False)
            ax.xaxis.set_ticks_position('none')

        if i == 0:
            plt.plot(t_plot, x[:, 0, 0], color='xkcd:blue')
            if plot_ylabel:
                plt.yticks([0, 1], ['', ''], rotation='vertical')
            plt.ylim([-0.1, 1.5])
            plt.title(rule_name[rule], fontsize=fs)
        elif i == 1:
            _ring_imshow(x[:, 0, 1:1 + n_eachring])
        elif i == 2:
            _ring_imshow(x[:, 0, 1 + n_eachring:1 + 2 * n_eachring])
        elif i == 3:
            plt.plot(t_plot, y[:, 0, 0], color='xkcd:green')
            plt.plot(t_plot, y_hat[:, 0, 0], color='xkcd:blue')
            if plot_ylabel:
                plt.yticks([0.05, 0.8], ['', ''], rotation='vertical')
            plt.ylim([-0.1, 1.1])
        elif i == 4:
            _ring_imshow(y_hat[:, 0, 1:])
            # imshow spans x in [-0.5, n_t - 0.5]; a tick at n_t would lie
            # outside the view and never be drawn, so use the image edges
            # and label the end with the actual trial duration.
            plt.xticks([-0.5, n_t - 0.5], ['0', '{:g}'.format(duration)])
            plt.xlabel('Time (s)', fontsize=fs, labelpad=-3)
            ax.spines["bottom"].set_visible(True)

        if plot_ylabel:
            plt.ylabel(ylabels[i], fontsize=fs)
        else:
            plt.yticks([])
        ax.get_yaxis().set_label_coords(-0.12, 0.5)

    if save:
        save_name = ('figure/sample_' + rule_name[rule].replace(' ', '') +
                     '.pdf')
        plt.savefig(save_name, transparent=True)
    plt.show()
Esempio n. 21
0
    def plot_rule_connections(self):
        """Plot input weights from rule units onto each unit group.

        For each group in ('1', '2', '12') and each rule in a fixed list of
        context / decision-making rules, draw a notched box plot of the
        weights from that rule's input unit onto the group's units. Boxes for
        one rule are centred on that rule's x tick. The figure is saved to
        figure/conn_rule_contextdm.pdf.
        """
        # Load the trained input weights
        model = Model(self.model_dir)
        hp = model.hp
        with tf.Session() as sess:
            model.restore()
            w_in = sess.run(model.w_in)
        w_in = w_in.T

        groups = ['1', '2', '12']

        # Plot input rule connectivity
        rules = ['contextdm1', 'contextdm2', 'dm1', 'dm2', 'multidm']

        w_all_stores = OrderedDict()
        pos = list()
        colors = list()
        width = 0.15
        # Centre index so each rule's cluster of boxes sits on its x tick.
        center = (len(groups) - 1) / 2.
        for i_group, group in enumerate(groups):
            ind = self.group_ind_orig[group]
            for i_rule, rule in enumerate(rules):
                ind_rule = get_rule_index(rule, hp)
                w_all_stores[(group, rule)] = w_in[ind, ind_rule].flatten()
                pos.append(i_rule + (i_group - center) * width)
                colors.append(self.colors[group])

        fs = 6
        fig = plt.figure(figsize=(2.5, 1.2))
        ax = fig.add_axes([0.17, 0.45, 0.8, 0.4])

        bp = ax.boxplot(list(w_all_stores.values()), notch=True, sym='',
                        bootstrap=10000,
                        showcaps=False, patch_artist=True, widths=width,
                        positions=pos,
                        whiskerprops={'linewidth': 1.0})

        for patch, c in zip(bp['boxes'], colors):
            plt.setp(patch, color=c)
        # Two whiskers (lower, upper) per box, in box order.
        for i_whisker, patch in enumerate(bp['whiskers']):
            plt.setp(patch, color=colors[i_whisker // 2])
        for element in ['means', 'medians']:
            plt.setp(bp[element], color='white')

        ax.set_xticks(np.arange(len(rules)))
        ax.set_xticklabels([rule_name[r] for r in rules], rotation=25)
        ax.set_xlabel('Input from rule units', fontsize=fs, labelpad=3)
        ax.set_ylabel('Conn. weight', fontsize=fs)
        ax.tick_params(axis='both', which='major', labelsize=fs)
        plt.locator_params(axis='y', nbins=2)
        ax.spines["right"].set_visible(False)
        ax.spines["top"].set_visible(False)
        ax.xaxis.set_ticks_position('bottom')
        ax.yaxis.set_ticks_position('left')
        ax.set_xlim([-0.8, len(rules) - 0.2])
        ax.plot([-0.5, len(rules) - 0.5], [0, 0], color='gray', linewidth=0.5)
        plt.savefig('figure/conn_rule_contextdm.pdf', transparent=True)
        plt.show()
Esempio n. 22
0
def _style_schematic_axes(ax, fontsize, show_bottom=False):
    """Apply the shared minimal axis style used by schematic_plot panels."""
    ax.tick_params(axis='both',
                   which='major',
                   labelsize=fontsize,
                   width=0.5,
                   length=2,
                   pad=3)
    ax.spines["left"].set_linewidth(0.5)
    ax.spines["right"].set_visible(False)
    if show_bottom:
        ax.spines["bottom"].set_linewidth(0.5)
    else:
        ax.spines["bottom"].set_visible(False)
    ax.spines["top"].set_visible(False)
    ax.xaxis.set_ticks_position('bottom')
    ax.yaxis.set_ticks_position('left')


def _schematic_imshow(data):
    """Show a (time, unit) array as a purple 0-1 heatmap, time along x."""
    plt.imshow(data.T,
               aspect='auto',
               cmap='Purples',
               vmin=0,
               vmax=1,
               interpolation='none',
               origin='lower')


def schematic_plot(model_dir, rule=None):
    """Plot the panels of a network schematic for one example trial.

    Runs one test trial of `rule` through the model loaded with dt=1. Saves
    four figures under figure/:
    - fixation and stimulus inputs;
    - rule inputs;
    - recurrent unit activity;
    - outputs.

    Args:
        model_dir: model directory
        rule: str or None, the rule to plot; 'dm1' is used when falsy
    """
    fontsize = 6

    rule = rule or 'dm1'

    model = Model(model_dir, dt=1)
    hp = model.hp

    with tf.Session() as sess:
        model.restore()
        trial = generate_trials(rule, hp, mode='test')
        feed_dict = tools.gen_feed_dict(model, trial, hp)
        x = trial.x
        h, y_hat = sess.run([model.h, model.y_hat], feed_dict=feed_dict)

    n_eachring = hp['n_eachring']
    n_hidden = hp['n_rnn']
    ring_ticks = [0, (n_eachring - 1) / 2, n_eachring - 1]

    # Plot Stimulus
    fig = plt.figure(figsize=(1.0, 1.2))
    heights = np.array([0.06, 0.25, 0.25])
    for i in range(3):
        ax = fig.add_axes(
            [0.2, sum(heights[i + 1:] + 0.1) + 0.05, 0.7, heights[i]])
        plt.xticks([])
        _style_schematic_axes(ax, fontsize)

        if i == 0:
            plt.plot(x[:, 0, 0], color='xkcd:blue')
            plt.yticks([0, 1], ['', ''], rotation='vertical')
            plt.ylim([-0.1, 1.5])
            plt.title('Fixation input', fontsize=fontsize, y=0.9)
        elif i == 1:
            _schematic_imshow(x[:, 0, 1:1 + n_eachring])
            plt.yticks(ring_ticks, [r'0$\degree$', '', r'360$\degree$'],
                       rotation='vertical')
            plt.title('Stimulus mod 1', fontsize=fontsize, y=0.9)
        elif i == 2:
            _schematic_imshow(x[:, 0, 1 + n_eachring:1 + 2 * n_eachring])
            plt.yticks(ring_ticks, ['', '', ''], rotation='vertical')
            plt.title('Stimulus mod 2', fontsize=fontsize, y=0.9)
        ax.get_yaxis().set_label_coords(-0.12, 0.5)
    plt.savefig('figure/schematic_input.pdf', transparent=True)
    plt.show()

    # Plot Rule Inputs
    fig = plt.figure(figsize=(1.0, 0.5))
    ax = fig.add_axes([0.2, 0.3, 0.7, 0.45])
    X = x[:, 0, 1 + 2 * n_eachring:]
    _schematic_imshow(X)

    # imshow spans x in [-0.5, n_t - 0.5]; a tick at n_t would fall outside
    # the view and never be drawn, so place the end ticks on the image edges.
    # Label in ms, assuming hp['dt'] is the step size in ms (dt=1 requested).
    n_t = X.shape[0]
    plt.xticks([-0.5, n_t - 0.5], ['0', '{:g}'.format(n_t * hp['dt'])])
    ax.set_xlabel('Time (ms)', fontsize=fontsize, labelpad=-5)
    _style_schematic_axes(ax, fontsize, show_bottom=True)

    plt.yticks([0, X.shape[-1] - 1], ['1', str(X.shape[-1])],
               rotation='vertical')
    plt.title('Rule inputs', fontsize=fontsize, y=0.9)
    ax.get_yaxis().set_label_coords(-0.12, 0.5)

    plt.savefig('figure/schematic_rule.pdf', transparent=True)
    plt.show()

    # Plot Units
    fig = plt.figure(figsize=(1.0, 0.8))
    ax = fig.add_axes([0.2, 0.1, 0.7, 0.75])
    plt.xticks([])
    _style_schematic_axes(ax, fontsize)

    _schematic_imshow(h[:, 0, :])
    plt.yticks([0, n_hidden - 1], ['1', str(n_hidden)], rotation='vertical')
    plt.title('Recurrent units', fontsize=fontsize, y=0.95)
    ax.get_yaxis().set_label_coords(-0.12, 0.5)
    plt.savefig('figure/schematic_units.pdf', transparent=True)
    plt.show()

    # Plot Outputs
    fig = plt.figure(figsize=(1.0, 0.8))
    heights = np.array([0.1, 0.45]) + 0.01
    for i in range(2):
        ax = fig.add_axes(
            [0.2, sum(heights[i + 1:] + 0.15) + 0.1, 0.7, heights[i]])
        plt.xticks([])
        _style_schematic_axes(ax, fontsize)

        if i == 0:
            plt.plot(y_hat[:, 0, 0], color='xkcd:blue')
            plt.yticks([0.05, 0.8], ['', ''], rotation='vertical')
            plt.ylim([-0.1, 1.1])
            plt.title('Fixation output', fontsize=fontsize, y=0.9)
        elif i == 1:
            _schematic_imshow(y_hat[:, 0, 1:])
            plt.yticks(ring_ticks, [r'0$\degree$', '', r'360$\degree$'],
                       rotation='vertical')
            plt.title('Response', fontsize=fontsize, y=0.9)

        ax.get_yaxis().set_label_coords(-0.12, 0.5)

    plt.savefig('figure/schematic_outputs.pdf', transparent=True)
    plt.show()
Esempio n. 23
0
def pretty_singleneuron_plot(model_dir,
                             rules,
                             neurons,
                             epoch=None,
                             save=False,
                             ylabel_firstonly=True,
                             trace_only=False,
                             plot_stim_avg=False,
                             save_name=''):
    """Plot the activity of a single neuron in time across many trials

    One figure is drawn per (neuron, rule) pair. It shows every test-mode
    trial's activity of that neuron, starting 500 ms into the trial.

    Args:
        model_dir: model directory
        rules: a rule name, or a list of rule names, to plot
        neurons: an index, or an iterable of indices, of neurons to plot
        epoch: str or None, name of an entry of trial.epochs to mark with a
            horizontal bar above the traces
        save: save figure?
        ylabel_firstonly: if True, only plot ylabel for the first rule in rules
        trace_only: bool, if True hide spines, ticks, labels and title
        plot_stim_avg: bool, if True overlay the trial-averaged trace in black
        save_name: str, suffix appended to the saved file names
    """
    fs = 6

    if isinstance(rules, str):
        rules = [rules]

    try:
        _ = iter(neurons)
    except TypeError:
        neurons = [neurons]

    h_tests = dict()
    trials = dict()
    model = Model(model_dir)
    hp = model.hp
    # Time steps skipped at the beginning of each trial (first 500 ms)
    t_start = int(500 / hp['dt'])
    with tf.Session() as sess:
        model.restore()

        for rule in rules:
            # Generate a batch of trial from the test mode
            trial = generate_trials(rule, hp, mode='test')
            feed_dict = tools.gen_feed_dict(model, trial, hp)
            h_tests[rule] = sess.run(model.h, feed_dict=feed_dict)
            # Keep each rule's trial: epoch timing differs between rules.
            trials[rule] = trial

    for neuron in neurons:
        h_max = np.max([h_tests[r][t_start:, :, neuron].max() for r in rules])
        for j, rule in enumerate(rules):
            h_plot = h_tests[rule][t_start:, :, neuron]
            t_plot = np.arange(h_plot.shape[0]) * hp['dt'] / 1000

            fig = plt.figure(figsize=(1.0, 0.8))
            ax = fig.add_axes([0.35, 0.25, 0.55, 0.55])
            ax.plot(t_plot, h_plot, lw=0.5, color='gray')

            if plot_stim_avg:
                # Plot stimulus averaged trace
                ax.plot(t_plot, h_plot.mean(axis=1), lw=1, color='black')

            if epoch is not None:
                e0, e1 = trials[rule].epochs[epoch]
                e0 = e0 if e0 is not None else 0
                e1 = e1 if e1 is not None else h_tests[rule].shape[0]
                # Epoch bounds are time-step indices; convert them to the
                # plotted time axis (seconds, starting at t_start).
                bar_t = [max(e - t_start, 0) * hp['dt'] / 1000.
                         for e in (e0, e1)]
                ax.plot(bar_t, [h_max * 1.15] * 2,
                        color='black',
                        linewidth=1.5)
                figname = 'figure/trace_' + rule_name[
                    rule] + epoch + save_name + '.pdf'
            else:
                figname = 'figure/trace_unit' + str(
                    neuron) + rule_name[rule] + save_name + '.pdf'

            plt.ylim(np.array([-0.1, 1.2]) * h_max)
            plt.xticks([0, np.floor(np.max(t_plot) + 0.01)])
            plt.xlabel('Time (s)', fontsize=fs, labelpad=-5)
            plt.locator_params(axis='y', nbins=4)
            if j > 0 and ylabel_firstonly:
                ax.set_yticklabels([])
            else:
                plt.ylabel('Activity (a.u.)', fontsize=fs, labelpad=2)
            plt.title('Unit {:d} '.format(neuron) + rule_name[rule],
                      fontsize=5)
            ax.tick_params(axis='both', which='major', labelsize=fs)
            ax.spines["right"].set_visible(False)
            ax.spines["top"].set_visible(False)
            ax.xaxis.set_ticks_position('bottom')
            ax.yaxis.set_ticks_position('left')
            if trace_only:
                ax.spines["left"].set_visible(False)
                ax.spines["bottom"].set_visible(False)
                ax.xaxis.set_ticks_position('none')
                ax.set_xlabel('')
                ax.set_ylabel('')
                ax.set_xticks([])
                ax.set_yticks([])
                ax.set_title('')

            if save:
                plt.savefig(figname, transparent=True)
            plt.show()