def get_link(pkl_p='', selected_p=''):
    if pkl_p:
        try:
            link = load_pickle(pkl_p)
        except FileNotFoundError:
            print(f'No {pkl_p}')
        else:
            return link
    if selected_p:
        return img_search(selected_p)
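Every snippet on this page calls load_pickle (and several call dump2pickle) without defining them. A minimal sketch of what such helpers usually look like; the bodies below are a plausible reconstruction, not any project's actual code:

import pickle

def load_pickle(path):
    # read a single pickled object from disk
    with open(path, 'rb') as f:
        return pickle.load(f)

def dump2pickle(obj, path):
    # write a single object to disk as a pickle
    with open(path, 'wb') as f:
        pickle.dump(obj, f)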
Example #2
def compute_prediction(fn, model):
    border = int((h_neigh-1)/2)

    img = cv2.imread(fn, cv2.IMREAD_COLOR)
    cached_fn = os.path.join(cache_dir, fn.split(
        os.path.sep)[-1].split('.')[0]+".pkl")
    if os.path.isfile(cached_fn):
        features = load_pickle(cached_fn)
    else:
        features = train.create_features(img)
        dump2pickle(features, cached_fn)
    scaler = load_pickle(os.path.join(model_dir, scaler_fn))
    features = features.reshape(-1, features.shape[1])
    features = scaler.transform(features)
    model_predictions = model.predict_proba(features)
    model_predictions = prob2class(model_predictions)
    predictions_image = model_predictions.reshape(
        [img.shape[0]-2*border, img.shape[1]-2*border, -1])
    return predictions_image
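The load-from-cache-or-compute-and-dump step above reappears in several later examples. A generic sketch of that pattern as a reusable helper (hypothetical, not part of the project):

import os
import pickle

def cached(path, compute):
    # return the pickled result at `path` if it exists;
    # otherwise compute it, cache it, and return it
    if os.path.isfile(path):
        with open(path, 'rb') as f:
            return pickle.load(f)
    result = compute()
    with open(path, 'wb') as f:
        pickle.dump(result, f)
    return result

# usage sketch: features = cached(cached_fn, lambda: train.create_features(img))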
Example #3
def weird_load(fp):
    # Reassemble a dict from a file of fixed-size (3408-byte) pickle
    # records, bouncing each record through a scratch file named '_'
    container = {}
    with open(fp, 'rb') as file:
        while True:
            chunk = file.read(3408)
            if not chunk:
                break
            with open('_', 'wb') as temp:
                temp.write(chunk)
            data = load_pickle('_')
            container.update(data)
    return container
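If each record really is a self-contained 3408-byte pickle, as weird_load assumes, the scratch-file round trip can be avoided by unpickling each chunk in memory. A sketch under that same assumption:

import pickle

def chunked_load(fp, record_size=3408):
    # same behavior as weird_load, minus the temp file:
    # unpickle each fixed-size record directly from the bytes read
    container = {}
    with open(fp, 'rb') as f:
        while True:
            chunk = f.read(record_size)
            if not chunk:
                break
            container.update(pickle.loads(chunk))
    return container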
Example #4
    def __init__(self):

        #SpaCy Initializers
        self.nlp = spacy.load('de_core_news_sm')
        self.doc = self.nlp('Ich schicke meiner Mutter ein Geschenk')
        self.matcher = Matcher(self.nlp.vocab)

        #Dictionaries
        self.worter = load_pickle("Worter.p")

        #Patterns
        self.patterns = get_patterns()
        # Article + optional adjectives + accusative object ('oa'),
        # subject ('sb'), or dative object ('da'), using the TIGER
        # dependency labels of the German spaCy models
        self.pattern_akk = [[{'TAG': 'ART'}, {'TAG': 'ADJA', 'OP': '*'}, {'DEP': 'oa'}]]
        self.pattern_nom = [[{'TAG': 'ART'}, {'TAG': 'ADJA', 'OP': '*'}, {'DEP': 'sb'}]]
        self.pattern_dat = [[{'TAG': 'ART'}, {'TAG': 'ADJA', 'OP': '*'}, {'DEP': 'da'}]]

        #Add patterns

        self.matcher.add('Akk', self.pattern_akk)
        self.matcher.add('Nom', self.pattern_nom)
        self.matcher.add('Dat', self.pattern_dat)

        #Variables
        self.found_matches = None
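A minimal sketch of how one of these Matcher patterns fires on the example sentence (assumes de_core_news_sm is installed; exact matches depend on the parser's output):

import spacy
from spacy.matcher import Matcher

nlp = spacy.load('de_core_news_sm')
matcher = Matcher(nlp.vocab)
matcher.add('Akk', [[{'TAG': 'ART'}, {'TAG': 'ADJA', 'OP': '*'}, {'DEP': 'oa'}]])
doc = nlp('Ich schicke meiner Mutter ein Geschenk')
for match_id, start, end in matcher(doc):
    # expected: Akk 'ein Geschenk' (the accusative object)
    print(nlp.vocab.strings[match_id], doc[start:end].text)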
Example #5
    def __init__(self, model_dir):
        """Analyze based on units."""
        data_type  = 'rule'
        fname = os.path.join(model_dir, 'variance_' + data_type + '.pkl')
        res = tools.load_pickle(fname)
        h_var_all = res['h_var_all']
        keys      = res['keys']

        rules = ['contextdm1', 'contextdm2']
        ind_rules = [keys.index(rule) for rule in rules]
        h_var_all = h_var_all[:, ind_rules]

        # Keep only active units: total variance across tasks must exceed 1e-3
        ind_active = np.where(h_var_all.sum(axis=1) > 1e-3)[0]
        # ind_active = np.where(h_var_all.sum(axis=1) > 1e-1)[0] # TEMPORARY
        h_var_all  = h_var_all[ind_active, :]

        # Normalize by the total variance across tasks
        h_normvar_all = (h_var_all.T/np.sum(h_var_all, axis=1)).T

        group_ind = dict()

        group_ind['1'] = np.where(h_normvar_all[:, 0] > 0.9)[0]
        group_ind['2'] = np.where(h_normvar_all[:, 0] < 0.1)[0]
        group_ind['12'] = np.where(np.logical_and(h_normvar_all[:, 0] > 0.4,
                                                  h_normvar_all[:, 0] < 0.6))[0]
        group_ind['1+2'] = np.concatenate((group_ind['1'], group_ind['2']))

        group_ind_orig = {key: ind_active[val] for key, val in group_ind.items()}

        self.model_dir = model_dir
        self.group_ind = group_ind
        self.group_ind_orig = group_ind_orig
        self.h_normvar_all = h_normvar_all
        self.rules = rules
        self.ind_active = ind_active
        colors = ['xkcd:'+c for c in ['orange', 'green', 'pink', 'sky blue']]
        self.colors = dict(zip([None, '1', '2', '12'], colors))
        self.lesion_group_names = {None : 'Intact',
                                   '1'  : 'Lesion group 1',
                                   '2'  : 'Lesion group 2',
                                   '12' : 'Lesion group 12',
                                   '1+2': 'Lesion group 1 & 2'}
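A toy check of the grouping logic above, with two tasks and four units; the numbers are synthetic and chosen only to exercise each branch:

import numpy as np

h_var_all = np.array([[0.95, 0.05],   # prefers task 1
                      [0.02, 0.98],   # prefers task 2
                      [0.50, 0.50],   # mixed
                      [1e-5, 1e-5]])  # inactive, filtered out
ind_active = np.where(h_var_all.sum(axis=1) > 1e-3)[0]
h = h_var_all[ind_active]
h_norm = (h.T / h.sum(axis=1)).T
print(np.where(h_norm[:, 0] > 0.9)[0])  # group '1'  -> [0]
print(np.where(h_norm[:, 0] < 0.1)[0])  # group '2'  -> [1]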
Example #6
def _compute_hist_varprop(model_dir, rule_pair, random_rotation=False):
    data_type = 'rule'
    assert len(rule_pair) == 2
    assert data_type == 'rule'

    fname = os.path.join(model_dir, 'variance_' + data_type)
    if random_rotation:
        fname += '_rr'
    fname += '.pkl'
    if not os.path.isfile(fname):
        # If not computed, compute now
        compute_variance(model_dir, random_rotation=random_rotation)

    res = tools.load_pickle(fname)
    h_var_all = res['h_var_all']
    keys = res['keys']

    ind_rules = [keys.index(rule) for rule in rule_pair]
    h_var_all = h_var_all[:, ind_rules]

    # Keep only active units: total variance across tasks must exceed 1e-3
    ind_active = np.where(h_var_all.sum(axis=1) > 1e-3)[0]

    # Temporary: Mimicking biased sampling. Notice the free parameter though.
    # print('Mimicking selective sampling')
    # ind_active = np.where((h_var_all.sum(axis=1) > 1e-3)*(h_var_all[:,0]>1*1e-2))[0]

    h_var_all = h_var_all[ind_active, :]

    # Normalize by the total variance across tasks
    h_normvar_all = (h_var_all.T / np.sum(h_var_all, axis=1)).T

    # Plot a selectivity index between the two rules (in [-1, 1]) rather
    # than the raw proportion of variance, h_normvar_all[:, 0]
    data_plot = (h_var_all[:, 0] - h_var_all[:, 1]) / (
        h_var_all[:, 0] + h_var_all[:, 1])
    hist, bins_edge = np.histogram(data_plot, bins=20, range=(-1, 1))

    # # Plot the percentage instead of the total count
    # hist = hist/np.sum(hist)

    return hist, bins_edge
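The quantity being histogrammed is a standard selectivity index: (v1 - v2) / (v1 + v2) lies in [-1, 1] and hits the endpoints when all variance belongs to one rule. A tiny illustration with made-up variances:

import numpy as np

v = np.array([[0.9, 0.1],   # prefers rule 1
              [0.2, 0.8],   # prefers rule 2
              [0.5, 0.5]])  # indifferent
idx = (v[:, 0] - v[:, 1]) / (v[:, 0] + v[:, 1])
print(idx)  # [ 0.8 -0.6  0. ]
hist, bins_edge = np.histogram(idx, bins=20, range=(-1, 1))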
Example #7
def compute_replacerule_performance(model_dir, setup, restore=False):
    """Compute the performance of one task given a replaced rule input."""

    if setup == 1:
        rule = 'delayanti'
        replace_rule = np.array(['delayanti', 'fdanti', 'delaygo', 'fdgo'])

        rule_strengths = \
            [[1,0,0,0],
             [0,1,0,0],
             [0,1,1,0],
             [0,1,1,-1]]

    elif setup == 2:
        rule = 'contextdelaydm1'
        replace_rule = np.array(
            ['contextdelaydm1', 'contextdelaydm2', 'contextdm1', 'contextdm2'])

        rule_strengths = \
            [[1,0,0,0],
             [0,1,0,0],
             [0,1,1,0],
             [0,0,1,0],
             [0,1,1,-1]]

    elif setup == 3:
        rule = 'dmsgo'
        replace_rule = np.array(['dmsgo', 'dmcgo', 'dmsnogo', 'dmcnogo'])
        rule_strengths = \
            [[1,0,0,0],
             [0,1,0,0],
             [0,1,1,0],
             [0,1,1,-1]]

    else:
        raise ValueError('Unknown setup value')

    fname = 'taskset{:d}_perf'.format(setup) + '.pkl'
    fname = os.path.join(model_dir, fname)

    if restore and os.path.isfile(fname):
        print('Reloading results from ' + fname)
        r = tools.load_pickle(fname)
        perfs, rule, names = r['perfs'], r['rule'], r['names']

    else:
        perfs = list()
        names = list()
        for rule_strength in rule_strengths:
            perf, _ = run_network_replacerule(model_dir, rule, replace_rule,
                                              rule_strength)
            perfs.append(perf)
            names.append(replace_rule_name(replace_rule, rule_strength))

        perfs = np.array(perfs)
        print(perfs)

        results = {'perfs': perfs, 'rule': rule, 'names': names}
        with open(fname, 'wb') as f:
            pickle.dump(results, f)
        print('Results stored at : ' + fname)

    return perfs, rule, names
Example #8
    def generate_neuron_info(self,
                             epochs,
                             trial_list=None,
                             rules=None,
                             norm=True,
                             p_value=0.05,
                             abs_active_thresh=1e-3):
    
        self.neuron_info = OrderedDict()

        if trial_list is None:
            trial_list = self.trial_list
        if rules is None:
            rules = self.rules

        for rule in rules:
            for epoch in epochs:
                if epoch not in self.epoch_info[rule].keys():
                    raise KeyError(f'Rule {rule} does not have epoch {epoch}!')
    
        for trial_num in trial_list:
            self.neuron_info[trial_num] = OrderedDict()

            for rule in rules:
                H = tools.load_pickle(self.model_dir+'/'+str(trial_num)+'/'+'H_'+rule+'.pkl')
                self.neuron_info[trial_num][rule] = OrderedDict()

                for epoch in epochs:
                    self.neuron_info[trial_num][rule][epoch] = OrderedDict()

                    for info_type in ['selective_neurons','active_neurons','exh_neurons','inh_neurons','mix_neurons',\
                        'firerate_loc_order','firerate_max_central']:
                        self.neuron_info[trial_num][rule][epoch][info_type] = list()

                    for neuron in range(self.neurons):
                        neuron_data_abs = OrderedDict()
                        neuron_data_norm = OrderedDict()
                        neuron_data = OrderedDict()
                        firerate_abs = list()
                        firerate_norm = list()
                        firerate = list()

                        for loc in self.in_loc_set[rule]:
                            fix_level = H[self.epoch_info[rule]['fix1'][0]:self.epoch_info[rule]['fix1'][1], \
                                self.in_loc[rule] == loc, neuron].mean(axis=1).mean(axis=0)
                            neuron_data_abs[loc] = H[self.epoch_info[rule][epoch][0]:self.epoch_info[rule][epoch][1],\
                                self.in_loc[rule] == loc, neuron].mean(axis=0) #axis = 1 for trial-wise mean, 0 for time-wise mean
                            neuron_data_norm[loc] = neuron_data_abs[loc]/fix_level-1
                            if norm:
                                neuron_data[loc] = neuron_data_norm[loc]
                            else:
                                neuron_data[loc] = neuron_data_abs[loc]
                            firerate_abs.append(neuron_data_abs[loc].mean())
                            firerate_norm.append(neuron_data_norm[loc].mean())
                            firerate.append(neuron_data[loc].mean())

                        data_frame = pd.DataFrame(neuron_data)
                        data_frame_melt = data_frame.melt()
                        data_frame_melt.columns = ['Location','Fire_rate']
                        model = ols('Fire_rate~C(Location)',data=data_frame_melt).fit()
                        anova_table = anova_lm(model, typ = 2)

                        if max(firerate_abs) > abs_active_thresh:
                            self.neuron_info[trial_num][rule][epoch]['active_neurons'].append(neuron)

                            if anova_table['PR(>F)'][0] <= p_value:
                                self.neuron_info[trial_num][rule][epoch]['selective_neurons'].append(neuron)

                                if max(firerate_norm) < 0:
                                    self.neuron_info[trial_num][rule][epoch]['inh_neurons'].append(neuron)
                                elif min(firerate_norm) >= 0:
                                #else:
                                    self.neuron_info[trial_num][rule][epoch]['exh_neurons'].append(neuron)
                                else:
                                    self.neuron_info[trial_num][rule][epoch]['mix_neurons'].append(neuron)

                        max_index = firerate.index(max(firerate))
                        temp_len = len(firerate)
                        if temp_len%2 == 0:
                            mc_len = temp_len + 1
                        else:
                            mc_len = temp_len

                        firerate_max_central = np.zeros(mc_len)
                        for i in range(temp_len):
                            new_index = (i-max_index+temp_len//2)%temp_len
                            firerate_max_central[new_index] = firerate[i]
                        if temp_len%2 == 0:
                            firerate_max_central[-1] = firerate_max_central[0]

                        self.neuron_info[trial_num][rule][epoch]['firerate_loc_order'].append(firerate)
                        self.neuron_info[trial_num][rule][epoch]['firerate_max_central'].append(firerate_max_central)

                    self.neuron_info[trial_num][rule][epoch]['firerate_loc_order'] = \
                        np.array(self.neuron_info[trial_num][rule][epoch]['firerate_loc_order'])
                    self.neuron_info[trial_num][rule][epoch]['firerate_max_central'] = \
                        np.array(self.neuron_info[trial_num][rule][epoch]['firerate_max_central'])
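The selectivity test inside the loop above is a one-way ANOVA of firing rate against stimulus location. A self-contained toy version using the same statsmodels calls, on synthetic rates where one location clearly fires higher:

import numpy as np
import pandas as pd
from statsmodels.formula.api import ols
from statsmodels.stats.anova import anova_lm

rng = np.random.default_rng(0)
neuron_data = {0: rng.normal(1.0, 0.1, 20),   # 20 trials per location
               1: rng.normal(1.5, 0.1, 20),   # elevated rate here
               2: rng.normal(1.0, 0.1, 20)}
df = pd.DataFrame(neuron_data).melt()
df.columns = ['Location', 'Fire_rate']
model = ols('Fire_rate~C(Location)', data=df).fit()
anova_table = anova_lm(model, typ=2)
print(anova_table['PR(>F)'][0])  # tiny p-value: location-selective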
Example #9
                 926, 928, 1455, 1811, 1816, 1819, 2372, 2765, 3726, 7086,
                 7358, 7504, 7516, 8231, 8235, 8381, 8393, 9108
             ]]
    input_dims = [
        538, 529, 358, 13, 7, 2, 306, 7, 876, 7, 175, 7, 272, 20, 926, 2, 527,
        356, 5, 3, 553, 393, 961, 3360, 272, 146, 12, 715, 4, 146, 12, 715, 4
    ]

    batch_size = 256
    data_path = 'feature/'
    samp = 'new_hard'
    nb_epoch = 20

    job_train = load_npz(data_path + samp + '_train_job.npz')
    res_train = load_npz(data_path + samp + '_train_res.npz')
    gt_train = np.asarray(load_pickle(data_path + samp + '_train_label.pkl'))
    job_val = load_npz(data_path + samp + '_val_job.npz')
    res_val = load_npz(data_path + samp + '_val_res.npz')
    gt_val = np.asarray(load_pickle(data_path + samp + '_val_label.pkl'))

    print('[INFO] Compiling model...')
    model = iPNN(input_dims=input_dims,
                 embedding_dim=128,
                 prod_dim=128,
                 hidden_dims=[128, 64, 64, 32])
    model.compile(loss='binary_crossentropy',
                  optimizer='adam',
                  metrics=['acc', 'mae'])
    model.summary()

    print('[INFO] Initiating training')
Example #10
def compute_taskspace(model_dir, setup, restore=False, representation='rate'):
    if setup == 1:
        rules = ['fdgo', 'fdanti', 'delaygo', 'delayanti']
    elif setup == 2:
        rules = [
            'contextdelaydm1', 'contextdelaydm2', 'contextdm1', 'contextdm2'
        ]
    elif setup == 3:
        rules = ['dmsgo', 'dmcgo', 'dmsnogo', 'dmcnogo']
    elif setup == 4:
        rules = [
            'contextdelaydm1', 'contextdelaydm2', 'multidelaydm', 'contextdm1',
            'contextdm2', 'multidm'
        ]
    elif setup == 5:
        rules = [
            'contextdelaydm1',
            'contextdelaydm2',
            'multidelaydm',
            'delaydm1',
            'delaydm2',
            'contextdm1',
            'contextdm2',
            'multidm',
            'dm1',
            'dm2',
        ]
    elif setup == 6:
        rules = ['fdgo', 'delaygo', 'contextdm1', 'contextdelaydm1']

    if representation == 'rate':
        fname = 'taskset{:d}_space'.format(setup) + '.pkl'
        fname = os.path.join(model_dir, fname)

        if restore and os.path.isfile(fname):
            print('Reloading results from ' + fname)
            h_trans = tools.load_pickle(fname)
        else:
            tsa = TaskSetAnalysis(model_dir, rules=rules)
            h_trans = tsa.compute_taskspace(rules=rules,
                                            epochs=['stim1'],
                                            dim_reduction_type='PCA',
                                            setup=setup)
            with open(fname, 'wb') as f:
                pickle.dump(h_trans, f)
            print('Results stored at : ' + fname)

    elif representation == 'weight':
        from task import get_rule_index

        model = Model(model_dir)
        hp = model.hp
        n_hidden = hp['n_rnn']
        n_output = hp['n_output']
        with tf.Session() as sess:
            model.restore()
            w_in = sess.run(model.w_in).T

        rule_indices = [get_rule_index(r, hp) for r in rules]
        w_rules = w_in[:, rule_indices]

        from sklearn.decomposition import PCA
        model = PCA(n_components=2)

        # Transform data
        data_trans = model.fit_transform(w_rules.T)

        # Turn into dictionary, and consistent with previous code
        h_trans = OrderedDict()
        for i, r in enumerate(rules):
            # shape will be (1, 2); the 'stim1' epoch key is included only
            # for consistency with the rate-based branch above
            h_trans[(r, 'stim1')] = np.array([data_trans[i]])

    else:
        raise ValueError('Unknown representation: ' + representation)

    return h_trans
Example #11
def process_trajectories(file_name, data_dir):
    # use the data_dir argument; the original referenced an undefined
    # global DATA_DIR here
    data = load_pickle(data_dir, file_name)
    import ipdb
    ipdb.set_trace()
Example #12
    def __init__(self, model_dir, data_type, normalization_method='max'):
        hp = tools.load_hp(model_dir)

        # If this file has not been computed yet, create it with variance.py
        fname = os.path.join(model_dir, 'variance_' + data_type + '.pkl')
        res = tools.load_pickle(fname)
        h_var_all_ = res['h_var_all']
        self.keys = res['keys']

        # Keep only active units: total variance across tasks must exceed 1e-3
        # ind_active = np.where(h_var_all_.sum(axis=1) > 1e-2)[0]
        ind_active = np.where(h_var_all_.sum(axis=1) > 1e-3)[0]
        h_var_all = h_var_all_[ind_active, :]

        # Normalize by the total variance across tasks
        if normalization_method == 'sum':
            h_normvar_all = (h_var_all.T / np.sum(h_var_all, axis=1)).T
        elif normalization_method == 'max':
            h_normvar_all = (h_var_all.T / np.max(h_var_all, axis=1)).T
        elif normalization_method == 'none':
            h_normvar_all = h_var_all
        else:
            raise NotImplementedError()

        ################################## Clustering ################################
        from sklearn import metrics
        X = h_normvar_all

        # Clustering
        from sklearn.cluster import AgglomerativeClustering, KMeans

        # Choose number of clusters that maximize silhouette score
        n_clusters = range(2, 30)
        scores = list()
        labels_list = list()
        for n_cluster in n_clusters:
            # clustering = AgglomerativeClustering(n_cluster, affinity='cosine', linkage='average')
            clustering = KMeans(n_cluster,
                                algorithm='full',
                                n_init=20,
                                random_state=0)
            clustering.fit(X)  # n_samples, n_features = n_units, n_rules/n_epochs
            labels = clustering.labels_  # cluster labels

            score = metrics.silhouette_score(X, labels)

            scores.append(score)
            labels_list.append(labels)

        scores = np.array(scores)

        # Choose the number of clusters: originally an elbow heuristic
        # (stop when the silhouette score first falls, see the commented
        # line below); now simply take the maximum silhouette score
        if data_type == 'rule':
            #i = np.where((scores[1:]-scores[:-1])<0)[0][0]
            i = np.argmax(scores)
        else:
            # The more rigorous method doesn't work well in this case
            i = n_clusters.index(10)

        labels = labels_list[i]
        n_cluster = n_clusters[i]
        print('Choosing {:d} clusters'.format(n_cluster))

        # Sort clusters by their task preference (important for consistency across nets)
        if data_type == 'rule':
            label_prefs = [
                np.argmax(h_normvar_all[labels == l].sum(axis=0))
                for l in set(labels)
            ]
        elif data_type == 'epoch':
            ## TODO: this may no longer work!
            label_prefs = [
                self.keys[np.argmax(h_normvar_all[labels == l].sum(axis=0))][0]
                for l in set(labels)
            ]

        ind_label_sort = np.argsort(label_prefs)
        label_prefs = np.array(label_prefs)[ind_label_sort]
        # Relabel
        labels2 = np.zeros_like(labels)
        for i, ind in enumerate(ind_label_sort):
            labels2[labels == ind] = i
        labels = labels2

        # # Sort data by labels and by input connectivity
        # model = Model(save_name)
        # hp = model.hp
        # with tf.Session() as sess:
        #     model.restore(sess)
        #     var_list = sess.run(model.var_list)
        #
        # # Get connectivity
        # w_out = var_list[0].T
        # b_out = var_list[1]
        # w_in  = var_list[2][:n_input, :].T
        # w_rec = var_list[2][n_input:, :].T
        # b_rec = var_list[3]
        #
        # # nx, nh, ny = hp['shape']
        # nr = hp['n_eachring']
        #
        # sort_by = 'w_in'
        # if sort_by == 'w_in':
        #     w_in_mod1 = w_in[ind_active, :][:, 1:nr+1]
        #     w_in_mod2 = w_in[ind_active, :][:, nr+1:2*nr+1]
        #     w_in_modboth = w_in_mod1 + w_in_mod2
        #     w_prefs = np.argmax(w_in_modboth, axis=1)
        # elif sort_by == 'w_out':
        #     w_prefs = np.argmax(w_out[1:, ind_active], axis=0)
        #
        # ind_sort        = np.lexsort((w_prefs, labels)) # sort by labels then by prefs

        ind_sort = np.argsort(labels)

        labels = labels[ind_sort]
        self.h_normvar_all = h_normvar_all[ind_sort, :]
        self.ind_active = ind_active[ind_sort]

        self.n_clusters = n_clusters
        self.scores = scores
        self.n_cluster = n_cluster

        self.h_var_all = h_var_all
        self.normalization_method = normalization_method
        self.labels = labels
        self.unique_labels = np.unique(labels)

        self.model_dir = model_dir
        self.hp = hp
        self.data_type = data_type
        self.rules = hp['rules']
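The cluster-count selection above is ordinary silhouette-score model selection. A compact sketch of the same idea on synthetic blobs (toy data, same sklearn calls):

import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs
from sklearn.metrics import silhouette_score

X, _ = make_blobs(n_samples=200, centers=4, cluster_std=0.6, random_state=0)
n_clusters = range(2, 10)
scores = [silhouette_score(X, KMeans(k, n_init=20, random_state=0).fit_predict(X))
          for k in n_clusters]
print(n_clusters[int(np.argmax(scores))])  # should recover 4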
Example #13
def plot_choicefamily_varytime(model_dir, rule):
    import seaborn as sns
    assert rule in ['dm1', 'dm2', 'contextdm1', 'contextdm2', 'multidm']
    savename = os.path.join(model_dir, 'varytime_' + rule + '.pkl')

    try:
        result = tools.load_pickle(savename)
    except FileNotFoundError:
        raise FileNotFoundError(
            'Run performance.compute_choicefamily_varytime first.')

    xdatas = result['xdatas']
    ydatas = result['ydatas']
    cohs = result['cohs']
    stim_times = xdatas[0]
    n_coh = len(xdatas)

    # Plot how the threshold varies with stimulus duration
    weibull = lambda x, a, b: 1 - 0.5 * np.exp(-(x / a)**b)
    xdata = cohs

    alpha_fits = list()
    for i in range(len(stim_times)):
        ydata = ydatas[:, i]
        res = minimize(lambda param: np.sum(
            (weibull(xdata, param[0], param[1]) - ydata)**2), [0.1, 1],
                       bounds=([1e-3, 1], [1e-5, 10]),
                       method='L-BFGS-B')
        alpha, beta = res.x
        alpha_fits.append(alpha)

    perfect_int = lambda x, b: -0.5 * x + b
    b, _ = curve_fit(perfect_int, np.log10(stim_times), np.log10(alpha_fits))

    fs = 7
    fig = plt.figure(figsize=(2.5, 1.5))
    ax = fig.add_axes([0.2, 0.25, 0.4, 0.6])
    ax.plot(np.log10(stim_times),
            np.log10(alpha_fits),
            'o-',
            color='black',
            label='model',
            markersize=3)
    ax.plot(np.log10(stim_times),
            -0.5 * np.log10(stim_times) + b,
            color='red',
            label='perfect int.')
    ax.set_xlabel('Stimulus duration (ms)', fontsize=fs)
    ax.set_ylabel('Discrim. thr. (x0.01)', fontsize=fs)
    ax.set_xticks(np.log10(np.array([200, 400, 800, 1600])))
    ax.set_xticklabels(['200', '400', '800', '1600'])
    ax.set_yticks(np.log10(np.array([0.005, 0.01, 0.02, 0.04])))
    ax.set_yticklabels(['0.5', '1', '2', '4'])
    ax.spines["right"].set_visible(False)
    ax.spines["top"].set_visible(False)
    ax.xaxis.set_ticks_position('bottom')
    ax.yaxis.set_ticks_position('left')
    ax.tick_params(axis='both', which='major', labelsize=fs)
    ax.set_title(rule_name[rule], fontsize=fs)
    leg = plt.legend(fontsize=fs, frameon=False, bbox_to_anchor=[1, 1], loc=2)
    # plt.locator_params(axis='y', nbins=5)
    figname = 'varytime2_' + rule_name[rule].replace(' ', '')
    # figname = figname + model_dir
    if save:
        plt.savefig('figure/' + figname + '.pdf', transparent=True)

    # Chronometric curve
    figname = 'varytime_' + rule_name[rule].replace(' ', '')
    # figname = figname + model_dir
    plot_psychometric_varytime(xdatas,
                               ydatas,
                               figname,
                               labels=['{:0.3f}'.format(t) for t in 2 * cohs],
                               colors=sns.dark_palette("light blue",
                                                       n_coh,
                                                       input="xkcd"),
                               legtitle='Stim. 1 - Stim. 2',
                               rule=rule)
Example #14
    def plot_PSTH(self,
                  epochs,
                  rules=None,
                  trial_list=None,
                  neuron_types=[('exh_neurons', 'mix_neurons')],
                  norm=True,
                  separate_plot=False,
                  fuse_rules=False):
        
        if trial_list is None:
            trial_list = self.trial_list
        if rules is None:
            rules = self.rules

        psth_to_plot = OrderedDict()

        for trial_num in trial_list:
            psth_to_plot[trial_num] = OrderedDict()
            for rule in rules:
                H = tools.load_pickle(self.model_dir+'/'+str(trial_num)+'/'+'H_'+rule+'.pkl')
                psth_to_plot[trial_num][rule] = OrderedDict()
                for epoch in epochs:
                    psth_to_plot[trial_num][rule][epoch] = OrderedDict()
                    for type_pair in neuron_types:
                        psth_to_plot[trial_num][rule][epoch][type_pair] = OrderedDict()
                        psth_neuron = list()
                        anti_dir_psth = list()

                        for n_type in type_pair:
                            
                            for neuron in self.neuron_info[trial_num][rule][epoch][n_type]:

                                sel_loc = np.argmax(self.neuron_info[trial_num][rule][epoch]['firerate_loc_order'][neuron])
                                anti_loc = (sel_loc+len(self.in_loc_set[rule])//2)%len(self.in_loc_set[rule])

                                psth_temp = H[:,self.in_loc[rule] == sel_loc, neuron].mean(axis=1)
                                fix_level = H[self.epoch_info[rule]['fix1'][0]:self.epoch_info[rule]['fix1'][1], \
                                    self.in_loc[rule] == sel_loc, neuron].mean(axis=1).mean(axis=0)
                                if len(self.in_loc_set[rule])%2:
                                    # odd number of locations: average the two
                                    # locations straddling the anti direction,
                                    # wrapping the second index with a modulo
                                    n_loc = len(self.in_loc_set[rule])
                                    anti_dir_psth_temp = (H[:,self.in_loc[rule] == anti_loc, neuron].mean(axis=1)+\
                                        H[:,self.in_loc[rule] == (anti_loc+1)%n_loc, neuron].mean(axis=1))/2.0
                                else:
                                    anti_dir_psth_temp = H[:,self.in_loc[rule] == anti_loc, neuron].mean(axis=1)

                                anti_dir_psth_norm = anti_dir_psth_temp/fix_level-1
                                psth_norm = psth_temp/fix_level-1
                                if norm:
                                    psth_neuron.append(psth_norm)
                                    anti_dir_psth.append(anti_dir_psth_norm)
                                else:
                                    psth_neuron.append(psth_temp)
                                    anti_dir_psth.append(anti_dir_psth_temp)

                        # average only when neurons of this type were found;
                        # otherwise leave the key unset (later plots skip it)
                        if psth_neuron:
                            psth_to_plot[trial_num][rule][epoch][type_pair]['sel_dir'] = np.array(psth_neuron).mean(axis=0)
                        if anti_dir_psth:
                            psth_to_plot[trial_num][rule][epoch][type_pair]['anti_sel_dir'] = np.array(anti_dir_psth).mean(axis=0)

        for rule in rules:
            for epoch in epochs:
                for type_pair in neuron_types:
                    if not separate_plot:
                        fig_psth = plt.figure()
                    for trial_num in trial_list:

                        if separate_plot:
                            fig_psth = plt.figure()
                            color = None
                        else:
                            color = kelly_colors[(trial_list.index(trial_num)+1)%len(kelly_colors)]

                        try:
                            plt.plot(np.arange(len(psth_to_plot[trial_num][rule][epoch][type_pair]['sel_dir']))*self.hp['dt']/1000,
                            psth_to_plot[trial_num][rule][epoch][type_pair]['sel_dir'],label=str(trial_num)+'sel',color=color)
                        except:
                            pass

                        try:
                            plt.plot(np.arange(len(psth_to_plot[trial_num][rule][epoch][type_pair]['anti_sel_dir']))*self.hp['dt']/1000,
                            psth_to_plot[trial_num][rule][epoch][type_pair]['anti_sel_dir'],linestyle = '--',label=str(trial_num)+'anti',color=color)
                        except:
                            pass
                        if separate_plot:
                            plt.title("Rule:"+rule+" Epoch:"+epoch+" Neuron_type:"+"_".join(type_pair))
                            plt.legend(bbox_to_anchor=(1.05, 0), loc=3, borderaxespad=0)
                            type_pair_folder = 'figure/figure_'+self.model_dir+'/'+rule+'/'+epoch+'/'+'_'.join(type_pair)+'/'
                            tools.mkdir_p(type_pair_folder)
                            plt.tight_layout()
                            plt.savefig(type_pair_folder+'PSTH-'+str(trial_num)+'.png',transparent=False,bbox_inches='tight')
                            plt.close(fig_psth)

                    if not separate_plot:
                        plt.title("Rule:"+rule+" Epoch:"+epoch+" Neuron_type:"+"_".join(type_pair))
                        plt.legend(bbox_to_anchor=(1.05, 0), loc=3, borderaxespad=0)
                        type_pair_folder = 'figure/figure_'+self.model_dir+'/'+rule+'/'+epoch+'/'+'_'.join(type_pair)+'/'
                        tools.mkdir_p(type_pair_folder)
                        plt.tight_layout()
                        plt.savefig(type_pair_folder+'PSTH_all_'+str(trial_list[0])+'to'+str(trial_list[-1])+'.png',
                        transparent=False,bbox_inches='tight')
                        plt.close(fig_psth)

        plot_by_growth = dict()
        for rule in rules:
            plot_by_growth[rule] = dict()

            for epoch in epochs:
                plot_by_growth[rule][epoch] = dict()

                for type_pair in neuron_types:
                    plot_by_growth[rule][epoch][type_pair] = dict()
                    plot_by_growth[rule][epoch][type_pair]['less_than_I'] = dict()
                    plot_by_growth[rule][epoch][type_pair]['I_to_Y'] = dict()
                    plot_by_growth[rule][epoch][type_pair]['elder_than_Y'] = dict()

                    for key in plot_by_growth[rule][epoch][type_pair].keys():
                        plot_by_growth[rule][epoch][type_pair][key]['sel'] = list()
                        plot_by_growth[rule][epoch][type_pair][key]['anti'] = list()
                        plot_by_growth[rule][epoch][type_pair][key]['growth'] = list()

                    for trial_num in trial_list:
                        growth = self.log['perf_'+rule][trial_num//self.log['trials'][1]]
                        if growth <= self.hp['infancy_target_perf']:
                            plot_by_growth[rule][epoch][type_pair]['less_than_I']['sel'].append(\
                                psth_to_plot[trial_num][rule][epoch][type_pair]['sel_dir'])
                            plot_by_growth[rule][epoch][type_pair]['less_than_I']['anti'].append(\
                                psth_to_plot[trial_num][rule][epoch][type_pair]['anti_sel_dir'])
                            plot_by_growth[rule][epoch][type_pair]['less_than_I']['growth'].append(growth)
                        elif growth <= self.hp['young_target_perf']:
                            plot_by_growth[rule][epoch][type_pair]['I_to_Y']['sel'].append(\
                                psth_to_plot[trial_num][rule][epoch][type_pair]['sel_dir'])
                            plot_by_growth[rule][epoch][type_pair]['I_to_Y']['anti'].append(\
                                psth_to_plot[trial_num][rule][epoch][type_pair]['anti_sel_dir'])
                            plot_by_growth[rule][epoch][type_pair]['I_to_Y']['growth'].append(growth)
                        else:
                            plot_by_growth[rule][epoch][type_pair]['elder_than_Y']['sel'].append(\
                                psth_to_plot[trial_num][rule][epoch][type_pair]['sel_dir'])
                            plot_by_growth[rule][epoch][type_pair]['elder_than_Y']['anti'].append(\
                                psth_to_plot[trial_num][rule][epoch][type_pair]['anti_sel_dir'])
                            plot_by_growth[rule][epoch][type_pair]['elder_than_Y']['growth'].append(growth)

                    for growth_key in plot_by_growth[rule][epoch][type_pair].keys():
                        for type_key, value in plot_by_growth[rule][epoch][type_pair][growth_key].items():
                            try:
                                plot_by_growth[rule][epoch][type_pair][growth_key][type_key] = np.array(value).mean(axis=0)
                            except:
                                pass

        for rule in rules:
            for epoch in epochs:
                for type_pair in neuron_types:
                    fig_psth = plt.figure()
                    for growth_key,value in plot_by_growth[rule][epoch][type_pair].items():
                        if growth_key == 'less_than_I':
                            color = 'green'
                        elif growth_key == 'I_to_Y':
                            color = 'blue'
                        else:
                            color = 'red'
                        
                        try:
                            plt.plot(np.arange(len(value['sel']))*self.hp['dt']/1000,value['sel'],
                            label = growth_key+'_'+str(value['growth'])[:4], color = color)
                        except:
                            pass

                        try:
                            plt.plot(np.arange(len(value['anti']))*self.hp['dt']/1000,value['anti'],
                            linestyle = '--',label = growth_key+'_'+str(value['growth'])[:4], color = color)
                        except:
                            pass
                    plt.title("Rule:"+rule+" Epoch:"+epoch+" Neuron_type:"+"_".join(type_pair))
                    plt.legend(bbox_to_anchor=(1.05, 0), loc=3, borderaxespad=0)
                    type_pair_folder = 'figure/figure_'+self.model_dir+'/'+rule+'/'+epoch+'/'+'_'.join(type_pair)+'/'
                    tools.mkdir_p(type_pair_folder)
                    plt.tight_layout()
                    plt.savefig(type_pair_folder+'PSTH_bygrowth_'+str(trial_list[0])+'to'+str(trial_list[-1])+'.png',
                    transparent=False,bbox_inches='tight')
                    plt.close(fig_psth)

        if fuse_rules:
            ls_list = ['-','--','-.',':']
            for type_pair in neuron_types:
                for epoch in epochs:
                    fig_psth_fr = plt.figure()
                    for rule, linestyle in zip(rules,ls_list[0:len(rules)]):
                        for key,value in plot_by_growth[rule][epoch][type_pair].items():
                            if key == 'less_than_I':
                                color = 'green'
                            elif key == 'I_to_Y':
                                color = 'blue'
                            else:
                                color = 'red' 

                            try:
                                plt.plot(np.arange(len(value['sel']))*self.hp['dt']/1000,value['sel'],
                                label = rule+'_'+key+'_'+str(value['growth'])[:4], color = color, linestyle=linestyle)
                            except:
                                pass

                    plt.legend()
                    plt.title("Rule:"+rule+" Epoch:"+epoch+" Neuron_type:"+"_".join(type_pair))
                    plt.legend(bbox_to_anchor=(1.05, 0), loc=3, borderaxespad=0)
                    plt.tight_layout()
                    plt.savefig('figure/figure_'+self.model_dir+'/'+'-'.join(rules)+'_'+epoch+'_'+'-'.join(type_pair)\
                        +str(trial_list[0])+'to'+str(trial_list[-1])+'.png',
                    transparent=False,bbox_inches='tight')
                    plt.close(fig_psth_fr)
Example #15
File: WAE.py Project: enix45/WAE
    (X_train, y_train), (X_test, y_test) = mnist.load_data()
    X_train = X_train.astype('float32')
    X_train = np.reshape(X_train, (X_train.shape[0], 784))
    X_train /= 255.
    model = WAE_MMD(dims=[784, 512, 64],
                    loss='L2',
                    activation='lrelu',
                    z_dim=8,
                    phase='train',
                    scale=1,
                    batch_size=50)
    print('Pre-train')
    model.fit(data=X_train, nb_epoch=20, w=[0, 0.05])
    #print('Pre-train with regularizer')
    #model.fit(data = X_train, nb_epoch = 30, w = [5, 0])
    model.save_weight('pre_weight.pkl')

    model = WAE_MMD(dims=[784, 512, 64],
                    loss='L2',
                    activation='lrelu',
                    z_dim=8,
                    phase='train',
                    scale=1,
                    batch_size=50,
                    params=load_pickle('pre_weight.pkl'))
    model.fit(data=X_train,
              nb_epoch=300,
              w=[5000, 0],
              sample_path='samples/mnist_')
Example #16
def create_dataset(image_dict, label_dict):
    print('[INFO] Creating training dataset on %d image(s).' %
          len(image_dict.keys()))

    X = []
    y = []

    for fn in image_dict.keys():
        img = image_dict[fn]
        label = label_dict[fn]

        cached_fn = os.path.join(
            cache_dir,
            fn.split(os.path.sep)[-1].split('.')[0] + ".pkl")
        if os.path.isfile(cached_fn):
            features = load_pickle(cached_fn)
        else:
            features = create_features(img)
            dump2pickle(features, cached_fn)

        label = label[h_ind:-h_ind, h_ind:-h_ind]
        labels = label.reshape(label.shape[0] * label.shape[1], -1)
        X.append(features)
        y.append(labels)
    X = np.array(X)
    X = X.reshape(X.shape[0] * X.shape[1], -1)

    y = np.array(y)
    y = (y > 0) * 1
    y = y.astype(dtype=np.uint8)
    y = y.reshape(y.shape[0] * y.shape[1], -1)

    # drop ambiguous samples whose one-hot label does not sum to exactly 1
    indices = np.where(y.sum(axis=1) != 1)
    y = np.delete(y, indices, axis=0)
    X = np.delete(X, indices, axis=0)

    # unused
    def add_background_label(y):
        expanded = np.c_[y, np.zeros(y.shape[0])].astype(np.uint8)
        bg_label = np.zeros(len(classes) + 1, dtype=np.uint8)
        bg_label[-1] = 1
        expanded[np.where(~expanded.any(axis=1))] = bg_label
        return expanded

    y_int = onehot2int(y)

    from collections import Counter
    print("Original dataset shape {}".format(Counter(y_int)))
    if oversample:
        from imblearn.over_sampling import RandomOverSampler
        ros = RandomOverSampler(random_state=random_state)
        X, y_int = ros.fit_resample(X, y_int)
        print("Resampled dataset shape {}".format(Counter(y_int)))
        y = int2onehot(y_int)

    scaler = MinMaxScaler()
    scaler.fit(X)
    X = scaler.transform(X)

    dump2pickle(scaler, os.path.join(model_dir, scaler_fn))

    return X, y
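The balancing step relies on imblearn's RandomOverSampler. A minimal demonstration of what fit_resample does to an imbalanced label vector (toy data):

from collections import Counter
import numpy as np
from imblearn.over_sampling import RandomOverSampler

X = np.arange(10).reshape(-1, 1)      # 10 samples, 1 feature
y = np.array([0] * 8 + [1] * 2)       # imbalanced labels
print(Counter(y))                     # Counter({0: 8, 1: 2})
X_res, y_res = RandomOverSampler(random_state=0).fit_resample(X, y)
print(Counter(y_res))                 # Counter({0: 8, 1: 8})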
Example #17
    cached_fn = os.path.join(cache_dir, fn.split(
        os.path.sep)[-1].split('.')[0]+".pkl")
    if os.path.isfile(cached_fn):
        features = load_pickle(cached_fn)
    else:
        features = train.create_features(img)
        dump2pickle(features, cached_fn)
    scaler = load_pickle(os.path.join(model_dir, scaler_fn))
    features = features.reshape(-1, features.shape[1])
    features = scaler.transform(features)
    model_predictions = model.predict_proba(features)
    model_predictions = prob2class(model_predictions)
    predictions_image = model_predictions.reshape(
        [img.shape[0]-2*border, img.shape[1]-2*border, -1])
    return predictions_image


if __name__ == '__main__':
    filelist = glob(os.path.join(infere_dir, '*.jpg'))

    print('[INFO] Running inference on %d test images' % len(filelist))
    model = load_pickle(os.path.join(model_dir, model_fn))

    for fn in filelist:
        print('[INFO] Processing image:', fn.split(os.path.sep)[-1])
        inference_img = compute_prediction(fn, model)
        mapped_img = class2bgr(inference_img, palette_bgr)
        output_fn = os.path.join(output_dir, fn.split(
            os.path.sep)[-1].split('.')[0]+".png")
        cv2.imwrite(output_fn, mapped_img)