コード例 #1
0
def plot_timecourse(file_path: Union[Path, str]):
    """Draw every trial of one log as a stepped line coloured by movement mode.

    The log is unpacked from ``file_path``; its events are flattened once into
    movement-mode switches and once into trial-phase switches, and
    ``events2trials`` combines both into ``(intertrial, trial, reward, mode)``
    tuples. Each tuple is drawn on the first axis: level 0 for the inter-trial
    stretch, a step to level 1 at trial start, and a step to level 2 at the
    reward when ``reward > 0``.

    Args:
        file_path: location of the log handed to ``unpack``
    """
    palette = {'passive': 'gray', 'wait_push': 'blue', 'push': 'red', 'pull': 'green'}
    log = unpack(file_path)
    mode_switches = flatten_events(log['event'], ('passive', 'wait_push', 'push', 'pull'))
    trial_switches = flatten_events(log['event'], ('trial', 'reward', 'intertrial'))
    trials = events2trials(trial_switches, mode_switches)
    current_mode = "passive"
    with Figure() as axes:
        for intertrial, trial, reward, mode in trials:
            # the inter-trial stretch is still drawn in the colour of the previous mode
            axes[0].plot((intertrial, trial), (0, 0), color=palette[current_mode])
            current_mode = mode
            shade = palette[current_mode]
            axes[0].plot((trial, trial), (0, 1), color=shade)
            if not reward > 0:
                # no reward: a fixed 5-unit stretch at level 1
                axes[0].plot((trial, trial + 5), (1, 1), color=shade)
                continue
            rewarded_steps = (
                ((trial, reward), (1, 1)),
                ((reward, reward), (1, 2)),
                ((reward, reward + 5), (2, 2)),
            )
            for xs, ys in rewarded_steps:
                axes[0].plot(xs, ys, color=shade)
コード例 #2
0
def draw_pca():
    """Plot an SVC decision surface over two principal components of one example session.

    The session is the first entry of ``mice`` with id "14032" and fov 1. Its
    activity array is reshaped to one row per entry of axis 1 (presumably
    trials -- confirm against ``res_trial_neuron``), scaled, reduced to 20
    principal components and used to fit an ``SVC`` against ``quantize(cluster)``.
    The decision function is then evaluated on a 100 x 100 grid spanning
    components ``idx0`` and ``idx1`` with every other component held at zero,
    and drawn as a colour mesh with the 0.25 contour and the samples on top.
    """
    case_idx = [idx for idx, case in enumerate(mice) if case.id == "14032" and case.fov == 1][0]
    case = mice[case_idx: case_idx + 1]
    spikes, clusters = get_result([x.name for x in case], [res_trial_neuron, res_cluster])
    spike, cluster = spikes[0], clusters[0]
    neuron = scale(np.swapaxes(spike.values, 0, 1).reshape(spike.shape[1], -1), axis=0)
    y = quantize(cluster)
    neuron = PCA(20).fit_transform(neuron)
    svc = SVC()
    svc.fit(neuron, y)

    idx0, idx1 = 0, 5
    x_max, x_min = neuron[:, idx0].max(), neuron[:, idx0].min()
    y_max, y_min = neuron[:, idx1].max(), neuron[:, idx1].min()
    XX, YY = np.mgrid[x_min:x_max:100j, y_min:y_max:100j]  # type: ignore
    # The `np.float` alias was removed from NumPy (1.24); the builtin `float` is the
    # same dtype. Sizes are derived from the grid and the PCA output so they
    # cannot drift apart from the values chosen above.
    samples = np.zeros((XX.size, neuron.shape[1]), dtype=float)
    samples[:, idx0] = XX.ravel()
    samples[:, idx1] = YY.ravel()
    Z = svc.decision_function(samples).reshape(XX.shape)
    primary_color, secondary_color = "#F8766D", "#619CFF"
    with Figure(fig_folder.joinpath("classifier-decision.png"), figsize=(12, 12),
                despine={'bottom': True, 'left': True}) as axes:
        ax = axes[0]
        mask = y > 0  # samples with a positive label are drawn in the primary colour
        ax.pcolormesh(XX, YY, Z, cmap=get_gradient_cmap(secondary_color, primary_color))
        ax.contour(XX, YY, Z, colors=['k'], linestyles=['-'], levels=[.25])
        ax.scatter(neuron[mask, idx0], neuron[mask, idx1], color=primary_color, s=75, edgecolors='k')
        ax.scatter(neuron[~mask, idx0], neuron[~mask, idx1], color=secondary_color, s=75, edgecolors='k')
コード例 #3
0
ファイル: decoding.py プロジェクト: Palpatineli/lever
def example_curve():
    """Overlay a stored trace (blue) and the stored prediction (orange) for one case.

    The case is ``mice[1:2]``; results are read through ``get_result`` with the
    "astrocyte" tag. ``xy[0][1]`` is presumably the recorded trajectory -- confirm
    against ``res_align_xy``.
    """
    case_names = [case.name for case in mice[1:2]]
    aligned, predicted = get_result(case_names, [res_align_xy, res_predict], "astrocyte")
    with Figure() as axes:
        panel = axes[0]
        panel.plot(aligned[0][1].values, color='blue')
        panel.plot(predicted[0], color='orange')
コード例 #4
0
ファイル: decoding.py プロジェクト: Palpatineli/lever
def plot_cno_saline():
    """Compare paired 'mutual_info' values between the saline and cno treatments.

    Reads ``decoder_cno.csv``, runs a paired Wilcoxon test on the two
    treatments, reports group medians and the p-value through ``print_stats``,
    and draws box plots with the individual values and a gray line joining
    each pair.
    """
    table: pd.DataFrame = pd.read_csv(
        analysis_folder.joinpath("decoder_cno.csv"),
        usecols=[1, 2, 3],
        index_col=[2, 1]).sort_index()  # type: ignore
    treatments = ('saline', 'cno')
    paired_data = [table.loc[treat, 'mutual_info'] for treat in treatments]
    p_value = wilcoxon(*paired_data).pvalue
    median_lines = str(table.groupby("treat").median()).split('\n')[2:]
    report = ["median: "] + median_lines
    print_stats("saline vs. cno in Gq", report + [f"paired wilcox: p={p_value}"])
    out_path = fig_folder.joinpath("decoder-pair-cno.svg")
    with Figure(out_path, figsize=(6, 9)) as axes:
        panel = axes[0]
        boxes = plots.boxplot(panel, paired_data, whis=(10., 90.), zorder=1,
                              showfliers=False, colors=colors[2:4], widths=0.65)
        plots.dots(panel, paired_data, zorder=3, s=24)
        panel.set_xticklabels(["Gq Saline", "Gq CNO"])
        plots.annotate_boxplot(panel, boxes, 24, 1.2, [((0, 1), p_value)])
        # one connecting line per pair, between box positions 1 and 2
        for pair in np.array(paired_data).T:
            panel.plot([1, 2], pair, color='gray')
コード例 #5
0
ファイル: decoding.py プロジェクト: Palpatineli/lever
def single_dist():
    """Draw the mean normalized histogram of 'mi' per group as interleaved bars.

    Values come from ``single_power.csv``. For each of the groups wt / glt1 /
    dredd, a histogram over 50 shared bin edges is computed per ``case_id``,
    normalized to sum to one, averaged over cases, and plotted with the first
    bin dropped.
    """
    data: pd.DataFrame = pd.read_csv(
        analysis_folder.joinpath("single_power.csv"),
        usecols=[1, 2, 3, 4],
        index_col=[2, 1, 3]).sort_index()  # type: ignore
    bins = np.linspace(data['mi'].min(), data['mi'].max(), 50)  # type: ignore
    group_keys = ("wt", "glt1", "dredd")

    def scale_hist(x):
        counts, _ = np.histogram(x, bins=bins)
        return counts / counts.sum()

    with Figure(fig_folder.joinpath("single-dist.svg"),
                figsize=(9, 9)) as axes:
        lines = []
        for key in group_keys:
            per_case = data.loc[(key,)].groupby("case_id").apply(scale_hist)
            lines.append(per_case.mean()[1:])
        bin_size = bins[1] - bins[0]
        n_groups = len(lines)
        bar_width = bin_size / n_groups
        handles = []
        for idx, (color, line) in enumerate(zip(colors, lines)):
            # shift each group by a fraction of a bin so the bars sit side by side
            left_edges = (np.arange(len(line)) + idx / n_groups) * bin_size
            handles.append(axes[0].bar(left_edges, line, facecolor=color, edgecolor=color,
                                       alpha=0.7, align='edge', width=bar_width))
        axes[0].legend(handles, ["WT", "GLT1", "Gq"])
        axes[0].set_xlabel("Mutual Information (bit/sample)")
        axes[0].set_ylabel("Mean Density")
コード例 #6
0
def draw_stacked_bar(cluster_file: ClusterFile):
    """Draw, for four fixed days, a stacked bar of the member count of clusters 1-14.

    Args:
        cluster_file: mapping indexed by day string; each entry supports
            ``.get(cluster_id_str, default)`` and returns a sized collection.
            A missing cluster id counts as zero members.
    """
    days = ('5', '10', '13', '14')
    res = [[len(cluster_file[day].get(str(cluster_id), []))
            for cluster_id in range(1, 15)] for day in days]
    with Figure() as (ax,):
        stacked_bar(ax, res, COLORS)
        # Positions and labels are set separately: before matplotlib 3.5 the second
        # positional argument of set_xticks is `minor`, so passing the labels there
        # would create unlabeled minor ticks instead of labeling the bars.
        ax.set_xticks(range(len(days)))
        ax.set_xticklabels(days)
コード例 #7
0
def draw_noise(data_files: Dict[int, File], neuron_id: int, params: MotionParams):
    """Plot noise correlations per session, plus the response amplitude of one neuron.

    Cells are classified from the session with the largest day key. For every
    other session, the upper-triangle noise correlations are collected within
    the 'good' cells, within the remaining (bad or anti) cells and between the
    two sets, together with the per-trial peak of ``neuron_id`` in the
    motion-centred folded traces. Means with SEM error bars are drawn against
    day; the amplitude uses a twin y axis.

    NOTE(review): ``neurons[-1]`` is taken to belong to the last day, which
    holds only if ``data_files`` is ordered by day -- confirm with callers.

    Args:
        data_files: {day_id: data_file}
        neuron_id: id of the neuron whose amplitude is tracked; also used in the
            output file name and the title
        params: passed to motion_corr and as keyword arguments to lever.center_on
    """
    last_day = max(data_files.keys())
    last_file = data_files[last_day]
    lever = load_mat(last_file['response'])
    neuron_rate = last_file.attrs['frame_rate']
    neurons = common_axis([DataFrame.load(item['spike']) for item in data_files.values()])
    last_day_corr = motion_corr(lever, neurons[-1], neuron_rate, 16000, params)
    good, bad, anti = classify_cells(last_day_corr, 0.001)
    amplitudes = []
    corrs: Dict[str, List[List[float]]] = {'good': [], 'unrelated': [], 'between': []}
    for (day_id, data_file), neuron in zip(data_files.items(), neurons):
        if day_id == last_day:
            continue
        lever = load_mat(data_file['response'])
        within_good = noise_autocorrelation(lever, neuron[good], neuron_rate)
        corrs['good'].append(_take_triu(within_good))
        within_rest = noise_autocorrelation(lever, neuron[bad | anti], neuron_rate)
        corrs['unrelated'].append(_take_triu(within_rest))
        across = noise_correlation(lever, neuron[good], neuron[bad | anti], neuron_rate)
        corrs['between'].append(_take_triu(across))
        # centring happens only after the correlations above were computed
        lever.center_on("motion", **params)
        folded = fold_by(neuron, lever, neuron_rate, True)
        row = np.argwhere(neuron.axes[0] == neuron_id)[0, 0]
        amplitudes.append(folded.values[row, :, :].max(axis=1))
    out_path = join(project_folder, 'report', 'img', f'noise_corr_{neuron_id}.svg')
    with Figure(out_path) as (ax,):
        day_ids = [day for day in data_files.keys() if day != last_day]
        for idx, (label, series) in enumerate(corrs.items()):
            centers = [np.mean(item) for item in series]
            spread = [_sem(item) for item in series]
            ax.errorbar(day_ids, centers, yerr=spread, color=COLORS[idx], label=label)
        ax2 = ax.twinx()
        ax2.errorbar(day_ids, [np.mean(item) for item in amplitudes],
                     [_sem(item) for item in amplitudes], color=COLORS[-1])
        ax.set_title(str(neuron_id))
        ax.legend()
コード例 #8
0
ファイル: run_decoding.py プロジェクト: Palpatineli/lever
def svr_parameters(data_file: File, info: Dict[str, str]):
    """Scan a 12 x 12 grid of SVR ``gamma`` / ``C`` values and save and plot the scores.

    The lever trace is devibrated and interpolated at ``data_file['spike']['y']``
    (first sample dropped) to form the target; the predictors are
    ``data_file['spike']['data']`` without its first column. For each grid
    point, the mutual information between the target and the cross-validated
    prediction is computed. ``gamma`` is scanned as an exponent of ten.

    Args:
        data_file: session file with 'response' and 'spike' entries
        info: needs the keys 'id' and 'session', used in the output file name
    """
    lever = load_mat(data_file['response'])
    smoothed = devibrate(lever.values[0], sample_rate=lever.sample_rate)
    spline = InterpolatedUnivariateSpline(lever.axes[0], smoothed)
    y = spline(data_file['spike']['y'])[1:]
    X = data_file['spike']['data'][:, 1:]
    gammas = np.linspace(-8, -5, 12, endpoint=False)
    Cs = np.linspace(3, 15, 12, endpoint=False)

    def score(gamma, C):
        factory = svr.predictor_factory(y, gamma=10**gamma, C=C, epsilon=1E-3)
        hat = cross_predict(X, y, factory, section_mi=False)
        return mutual_info(y, hat)

    res = map_table(score, gammas, Cs)
    file_name = f"svr_params_test_{info['id']}_{info['session']}.npz"
    np.savez_compressed(join(res_folder, file_name), values=np.asarray(res), axes=[gammas, Cs])
    res_df = DataFrame(np.asarray(res), [gammas, Cs])
    with Figure() as (ax, ):
        labeled_heatmap(ax, res_df.values, res_df.axes[1], res_df.axes[0])
    print('done')
コード例 #9
0
def draw_hierarchy(data_files: Dict[int, File]):
    """Draw one average-linkage dendrogram of the noise correlation matrix per session.

    Args:
        data_files: {day_id: data_file}; each file provides 'spike', 'response'
            and the 'frame_rate' attribute
    """
    # The original body read an undefined name `files`; the sessions to draw are
    # the ones passed in through `data_files`.
    neurons = common_axis([DataFrame.load(x['spike']) for x in data_files.values()])
    for (day_id, data_file), neuron in zip(data_files.items(), neurons):
        lever = load_mat(data_file['response'])
        corr_mat = noise_autocorrelation(lever, neuron, data_file.attrs['frame_rate'])
        with Figure() as (ax,):
            ax.set_title(f"day-{day_id:02d}")
            fancy_dendrogram(linkage(corr_mat, 'average'), ax=ax)
コード例 #10
0
def draw_classify_neurons(data_file: File, neuron_ids: Optional[np.ndarray] = None):
    """Bar-plot the mean motion correlation of 'good' cells against all other cells.

    Cells are classified with ``classify_cells``; the first column of the
    motion correlation is averaged for the 'good' class and for the
    concatenated 'bad' and 'anti' classes, with SEM error bars.

    Args:
        data_file: session file with 'response', 'spike' and the 'frame_rate' attribute
        neuron_ids: when given, only these neurons (located with ``search_ar``) are used
    """
    lever = load_mat(data_file['response'])
    neuron = DataFrame.load(data_file['spike'])
    if neuron_ids is not None:
        rows = search_ar(neuron_ids, neuron.axes[0])
        neuron = neuron[rows, :]
    neuron_rate = data_file.attrs['frame_rate']
    corr = motion_corr(lever, neuron, neuron_rate, 16000, motion_params)
    good, bad, anti = [corr[members, 0] for members in classify_cells(corr, 0.001)]
    with Figure(join(img_folder, "good_unrelated_cmp.svg"), (4, 6)) as ax:
        heights = [good.mean(), np.r_[bad, anti].mean()]
        errors = [_sem(good), _sem(np.r_[bad, anti])]
        ax[0].bar((0, 1), heights, yerr=errors)
コード例 #11
0
def draw_threshold():
    """Draw the example session's dendrogram with its cluster threshold marked.

    The session is the first entry of ``mice`` with id "14032" and fov 1. The
    stored linkage is coloured at the value returned by ``get_threshold`` and a
    horizontal line is drawn at that height.
    """
    matches = [pos for pos, mouse in enumerate(mice) if mouse.id == "14032" and mouse.fov == 1]
    first = matches[0]
    selected = mice[first: first + 1]
    link_mat = get_result([mouse.name for mouse in selected], [res_linkage])[0][0]
    cutoff = get_threshold(link_mat)
    out_path = proj_folder.joinpath("report", "fig", "threshold-sample.svg")
    with Figure(out_path, (6, 6)) as axes:
        panel = axes[0]
        dendrogram(link_mat, color_threshold=cutoff, ax=panel)
        panel.axhline(cutoff)
        panel.set_xlabel("Trials")
        panel.set_ylabel("Warp path length (a.u.)")
コード例 #12
0
def draw_boxplot():
    """Box-plot classifier precision by group and type, with per-session mean pairs overlaid.

    Rows whose ``type`` is "none" are discarded. On top of the seaborn box plot,
    each (id, session) gets one line joining its mean precision per type,
    drawn in reversed column order around the group's x position.
    """
    table = pd.read_csv(proj_folder.joinpath("data", "analysis", "classifier_power_validated.csv"))
    table = table[table.type != "none"]
    means = table.groupby(["id", "session", "group", "type"]).mean().reset_index()
    width = 0.6
    group_order = ('wt', 'gcamp6f', 'glt1', 'dredd')
    with Figure(fig_folder.joinpath("classifier-compare.svg"), (10, 6)) as axes:
        sns.boxplot(x="group", y="precision", hue="type", data=table, notch=True,
                    width=width, whis=1.0, ax=axes[0])
        for position, group in enumerate(group_order):
            in_group = means[means.group == group]
            wide = pd.pivot_table(in_group, index=['id', 'session'], columns='type', values='precision')
            # line ends sit a quarter of the box width either side of the group position
            ends = [position - width / 4, position + width / 4]
            for pair in np.fliplr(wide.values):
                axes[0].plot(ends, pair, color="#555753")
コード例 #13
0
def draw_rasterplot(day: int, data_file: File, neuron_id: int, params: MotionParams):
    """Draw a baseline-subtracted heat map of one neuron's motion-centred trials.

    Only trials whose values are all positive are kept. The baseline of each
    trial is the mean of its first ``pre_time * frame_rate`` samples.

    Args:
        day: day index, used in the output file name and the title
        data_file: session file with 'response', 'spike' and the 'frame_rate' attribute
        neuron_id: id looked up in the first axis of the spike data
        params: keyword arguments for lever.center_on; 'pre_time' sets the baseline window
    """
    lever = load_mat(data_file['response'])
    lever.center_on('motion', **params)
    neurons = DataFrame.load(data_file['spike'])
    neuron_rate = data_file.attrs['frame_rate']
    folded = fold_by(neurons, lever, neuron_rate, True)
    row = np.flatnonzero(neurons.axes[0] == neuron_id)[0]
    traces = folded[row, :, :]
    keep = np.all(traces.values > 0, axis=1)
    onset = int(round(params['pre_time'] * neuron_rate))
    out_path = join(img_folder, 'neuron-trace', f"raster-day-{day}.svg")
    with Figure(out_path, (2, 4)) as (ax,):
        baseline = traces[keep, 0: onset].mean(axis=1, keepdims=True)
        labeled_heatmap(ax, traces[keep, :] - baseline, cmap="coolwarm")
        ax.set_title(f"day-{day}")
コード例 #14
0
def main():
    """Plot example traces: the first 4 trials of up to 20 neurons of the first case.

    Each trace is divided by its own maximum and scaled to a height of 5;
    neurons are stacked 6 units apart and trials are placed 11 x-units apart,
    each spanning 10 x-positions.
    """
    case_names = [mouse.name for mouse in mice][0:1]
    trial_neurons = get_result(case_names, [res_trial_neuron], 'trial-neuron-2s-run')
    values = trial_neurons[0][0].values
    out_path = fig_folder.joinpath("classifier", "example-neurons.svg")
    with Figure(out_path, show=True) as axes:
        for neuron_no, trials in enumerate(values[:20, 0:4, :]):
            for trial_no, trace in enumerate(trials):
                start = trial_no * 11
                axes[0].plot(range(start, start + 10),
                             trace / trace.max() * 5 + neuron_no * 6,
                             color='red')
コード例 #15
0
ファイル: decoding.py プロジェクト: Palpatineli/lever
def save_example():
    """Write one case's trajectory and prediction to ``decoder_example.csv`` and plot both.

    The time column is the sample index divided by 5.0. The trajectory is drawn
    in blue and the prediction in orange.
    """
    xy, predicts = get_result([case.name for case in mice[1:2]],
                              [res_align_xy, res_predict], "astrocyte")
    rows = [(idx / 5.0, trace, predict)
            for idx, (trace, predict) in enumerate(zip(xy[0][1].values, predicts[0]))]
    df = pd.DataFrame(rows, columns=["time", "trajectory", "predicted"])
    df.to_csv(proj_folder.joinpath("data", "analysis", "decoder_example.csv"))
    with Figure() as axes:
        panel = axes[0]
        for column, shade in (("trajectory", "blue"), ("predicted", "orange")):
            panel.plot(df["time"], df[column], color=shade)
コード例 #16
0
def draw_neuron(day: int, data_file: File, neuron_id: int, params: MotionParams):
    """Draw one neuron in trial for one session, with bootstrapped spread as shadow.

    Only trials whose values are all positive are kept; each kept trial has the
    mean of its pre-motion window subtracted before plotting.

    Args:
        day: day index, used in the output file name and the title
        data_file: session file with 'response', 'spike' and the 'frame_rate' attribute
        neuron_id: id looked up in the first axis of the folded traces
        params: keyword arguments for lever.center_on; 'pre_time' also sets the
            length of the baseline window
    """
    lever = load_mat(data_file['response'])
    lever.center_on("motion", **params)
    neuron = DataFrame.load(data_file['spike'])
    neuron_rate = data_file.attrs['frame_rate']
    traces = fold_by(neuron, lever, neuron_rate)
    traces = traces[np.flatnonzero(traces.axes[0] == neuron_id)[0], :, :]
    mask = np.all(traces.values > 0, axis=1)
    # The traces were folded at the imaging frame rate, so the baseline window must
    # be counted in imaging frames. The lever sample rate used before would give a
    # window of the wrong length whenever the two rates differ.
    onset = int(round(params['pre_time'] * neuron_rate))
    pre_value = traces.values[mask, 0: onset].mean(axis=1, keepdims=True)
    trace_values = traces.values[mask, :] - pre_value
    with Figure(join(img_folder, "neuron-trace", f"day-{day:02d}.svg"), (1, 4)) as (ax,):
        tsplot(ax, trace_values, time=traces.axes[2], color=COLORS[4])
        ax.set_title(f"day_{day:02d}")
コード例 #17
0
def compare_neurons(day: int, data_file_0: File, data_file_1: File, params: MotionParams):
    """Overlay the baseline-subtracted, motion-centred lever trials of two files.

    For each file, the lever is centred on motion and folded into trials, and
    the mean of the first ``pre_time * lever.sample_rate`` samples of every
    trial is subtracted. The time axis of both curves is taken from the second
    file's lever.

    Args:
        day: day index, used in the output file name and the title
        data_file_0: first session file, drawn with COLORS[0]
        data_file_1: second session file, drawn with COLORS[1]
        params: keyword arguments for lever.center_on; 'pre_time' sets the baseline window
    """
    centred = []
    for source in (data_file_0, data_file_1):
        lever = load_mat(source['response'])
        lever.center_on('motion', **params)
        lever.fold_trials()
        onset = int(round(params['pre_time'] * lever.sample_rate))
        baseline = lever.values[0, :, 0: onset].mean(axis=1, keepdims=True)
        centred.append(lever.values[0, :, :] - baseline)
    out_path = join(img_folder, 'neuron-trace', f"comp-day-{day:02d}.svg")
    with Figure(out_path, (1, 4)) as (ax,):
        for color_idx, trials in enumerate(centred):
            # `lever` is still the one loaded last in the loop above
            tsplot(ax, trials, time=lever.axes[2], color=COLORS[color_idx])
        ax.set_title(f"day-{day:02d}")
コード例 #18
0
def draw_neuron_corr(data_files: Dict[int, File], params: MotionParams, fov_id: str = None):
    """Plot, across sessions, the response reliability of the neurons classified as 'good'.

    Cells are classified once, from the session with the largest day key. Each
    session's neurons are then folded around that same session's motion-centred
    lever trials and ``reliability`` is computed for every neuron.

    NOTE(review): the frame rate of the last session is used for every session,
    and ``neurons[-1]`` is taken to be the last day -- both hold only if all
    sessions share a frame rate and ``data_files`` is ordered by day.

    Args:
        data_files: {day_id: data_file}
        params: passed to motion_corr and as keyword arguments to lever.center_on
        fov_id: stem of the output file name; "neuron_corr.svg" is used when None
    """
    neurons = common_axis([DataFrame.load(x['spike']) for x in data_files.values()])
    last_day = max(data_files.keys())
    lever = load_mat(data_files[last_day]['response'])
    neuron_rate = data_files[last_day].attrs['frame_rate']
    good, bad, anti = classify_cells(motion_corr(
        lever, neurons[-1], neuron_rate, 16000, params), 0.001)
    result_list = list()
    for data_file, neuron in zip(data_files.values(), neurons):
        # Fold each session by its own lever recording. The previous code reused the
        # last day's lever for every session and centred it without `params`, so
        # trial times from one day were applied to neuron data from another.
        day_lever = load_mat(data_file['response'])
        day_lever.center_on('motion', **params)
        motion_neurons = fold_by(neuron, day_lever, neuron_rate, True)
        result_list.append([reliability(motion_neuron) for motion_neuron in motion_neurons.values])
    result = np.array(result_list)

    with Figure(join(img_folder, ("neuron_corr.svg" if fov_id is None else f"{fov_id}.svg"))) as ax:
        ax[0].plot(list(data_files.keys()), result[:, good])
コード例 #19
0
def draw_network_graph(data_files: Dict[int, File], params: MotionParams, threshold: int = 16000):
    """Draw neuron functional connection for each session, with neurons colored by the last session.

    The graph layout is computed once from the last session's noise correlation
    matrix and reused for every session, so node positions stay comparable.

    NOTE(review): the frame rate of the last session is used for every session.

    Args:
        data_files: {day_id: int, data_file: File}
        params: classify_cells need ["quiet_var", "window_size", "event_thres", "pre_time"]
        threshold: threshold for motion_corr, single linked cluster distance
    """
    last_day = data_files[max(data_files.keys())]
    neurons = common_axis([DataFrame.load(x['spike']) for x in data_files.values()])
    neuron_rate = last_day.attrs['frame_rate']
    final_corr_mat = noise_autocorrelation(load_mat(last_day['response']), neurons[-1], neuron_rate)
    # motion_corr takes the loaded lever recording, as at its other call sites; the
    # raw File was passed here before. A fresh load keeps this call independent of
    # whatever noise_autocorrelation did with its own copy.
    last_lever = load_mat(last_day['response'])
    categories = classify_cells(motion_corr(last_lever, neurons[-1], neuron_rate, threshold, params), 0.001)
    layout = corr_graph.get_layout(final_corr_mat, neurons[-1].axes[0])
    for (day_id, data_file), neuron in zip(data_files.items(), neurons):
        corr_mat = noise_autocorrelation(load_mat(data_file['response']), neuron, neuron_rate)
        with Figure(join(img_folder, f"network-day-{day_id:02d}.svg")) as ax:
            corr_graph.corr_plot(ax[0], corr_mat, categories, neuron.axes[0], layout=layout)
    print('done')
コード例 #20
0
ファイル: decoding.py プロジェクト: Palpatineli/lever
def pop_decoder_power():
    """Compare 'mutual_info' between groups in two panels that share a y axis.

    Data come from ``decoder_power.csv``. The left panel uses every
    (group, id, session) row and the right panel uses the mean per (group, id).
    For each panel, group medians and all pairwise two-sided Mann-Whitney U
    p-values are reported through ``print_stats``; pairs with p < 0.05 are
    annotated on the box plot.
    """
    def temp(data, ax, name):
        """Report statistics for ``data`` under ``name`` and draw its box plot on ``ax``."""
        res = ["median: "] + str(
            data.groupby('group').median()).split('\n')[2:]
        group_names = ('wt', 'glt1', 'dredd', "gcamp6f")
        group_strs = ["WT", "GLT1", "Gq", "gcamp6f"]
        annotation = list()
        for (idx, x), (idy, y) in combinations(enumerate(group_names), 2):
            # keyword arguments make the two flags self-describing
            p = mannwhitneyu(data.loc[x, ], data.loc[y, ],
                             use_continuity=False,
                             alternative='two-sided').pvalue
            res.append(f"mann u, {x} v. {y} p={p}")
            if p < 0.05:
                annotation.append(((idx, idy), p))
        print_stats(name, res)
        values = [data.loc[x, "mutual_info"] for x in group_names]
        boxplots = plots.boxplot(ax,
                                 values,
                                 whis=(10., 90.),
                                 zorder=1,
                                 showfliers=False,
                                 colors=colors)
        plots.dots(ax, values, zorder=3, s=24, jitter=0.02)
        ax.set_xticklabels(group_strs)
        plots.annotate_boxplot(ax, boxplots, 24, 1.2, annotation)

    with Figure(fig_folder.joinpath("decoder-comp.svg"),
                figsize=(8, 9),
                grid=(1, 2)) as axes:
        # get_shared_y_axes().join() was deprecated in matplotlib 3.6 and removed in
        # 3.8; Axes.sharey() replaces it where available (matplotlib >= 3.3).
        if hasattr(axes[0], "sharey"):
            for other in list(axes)[1:]:
                other.sharey(axes[0])
        else:
            axes[0].get_shared_y_axes().join(*axes)
        data: pd.DataFrame = pd.read_csv(
            analysis_folder.joinpath("decoder_power.csv"),
            index_col=[0])  # type: ignore
        data = data.set_index(["group", "session"],
                              append=True).reorder_levels(
                                  ["group", "id",
                                   "session"]).sort_index()  # type: ignore
        # one session is excluded by hand; the reason is not recorded here
        data = data.drop([('gcamp6f', 51551, 4)], axis='index')
        temp(data, axes[0], "pop-power by fov")
        temp(
            data.groupby(["group", "id"]).mean(), axes[1], "pop-power by case")
コード例 #21
0
ファイル: run_decoding.py プロジェクト: Palpatineli/lever
def single_power():
    """Requires data file from run_svr_power.

    Plots, per group, the cumulative share of the descending individual scores
    of every case, then fits a line to the log of each group's top-50
    normalized scores against their rank and forms ``slope +/- 2.58 * stderr``
    per group. The intervals are computed but neither returned nor printed.
    """
    # NOTE(review): pickle executes arbitrary code on load; only open trusted files.
    with open(join(res_folder, "svr_power.pkl"), 'rb') as fp:
        result = pkl.load(fp)
    ind_scores = {group: [entry[1] for entry in cases] for group, cases in result.items()}
    group_keys = ('wt', 'glt1', 'dredd')

    def fn(x):
        return np.cumsum(x / sum(x))

    with Figure(grid=(1, 3)) as (ax1, ax2, ax3):
        for panel, key, shade in zip((ax1, ax2, ax3), group_keys, ('blue', 'red', 'green')):
            descending = [sorted(scores, reverse=True) for scores in ind_scores[key]]
            for scores in descending:
                panel.plot(fn(scores), color=shade, alpha=0.5)

    shares = {}
    ranks = {}
    for key in group_keys:
        # the first 50 scores of each case, sorted descending and divided by their sum
        shares[key] = [
            np.divide(sorted(scores[0:50], reverse=True), sum(scores[0:50]))
            for scores in ind_scores[key]
        ]
        ranks[key] = np.hstack([np.arange(len(share)) for share in shares[key]])
    plt.plot(ranks['wt'])
    plt.show()
    from scipy.stats import linregress
    intervals = {}
    for key in group_keys:
        slope, _, _, _, std = linregress(ranks[key], np.log(np.hstack(shares[key])))
        intervals[key] = (slope - std * 2.58, slope + std * 2.58)
コード例 #22
0
ファイル: decoding.py プロジェクト: Palpatineli/lever
def single_dist_moments():
    """Compare the per-case median and skewness of 'mi' between groups.

    Values come from ``single_power.csv``. Two stacked panels are drawn, one for
    the per-case median and one for the per-case skewness. For each, group
    medians and the pairwise Mann-Whitney U results with p < 0.05 are reported
    through ``print_stats`` and annotated on the box plot.
    """
    data = pd.read_csv(analysis_folder.joinpath("single_power.csv"),
                       usecols=[1, 2, 3, 4],
                       index_col=[2, 1, 3]).sort_index()
    groups = ("wt", "glt1", "dredd")
    group_strs = ("WT", "GLT1", "Gq")

    def draw_panel(series, ax, name, jitter=0.1):
        """Report statistics for ``series`` under ``name`` and draw its box plot on ``ax``."""
        median_lines = str(series.groupby("group").median()).split("\n")[1:-1]
        report = ["median: ", *median_lines]
        significant = []
        for (left_pos, left), (right_pos, right) in combinations(enumerate(groups), 2):
            p = mannwhitneyu(series.loc[(left,)], series.loc[(right,)]).pvalue
            if not p < 0.05:
                continue
            report.append(f"mann u {left} v. {right}, p={p}")
            significant.append(((left_pos, right_pos), p))
        print_stats(name, report)
        values = [series.loc[(group,)].tolist() for group in groups]
        boxes = plots.boxplot(ax, values, zorder=1, whis=(10., 90.),
                              showfliers=False, colors=colors)
        plots.dots(ax, values, zorder=3, s=24, jitter=jitter)
        plots.annotate_boxplot(ax, boxes, 24, 1.2, significant)
        ax.set_xticklabels(group_strs)
        ax.set_ylabel(name)

    out_path = fig_folder.joinpath("single-dist-moments.svg")
    with Figure(out_path, figsize=(5, 9), grid=(2, 1)) as axes:
        by_group = data.groupby("group")
        medians = by_group.apply(lambda x: x.groupby("case_id").median()["mi"])
        draw_panel(medians, axes[0], "Median", 0.001)
        axes[0].set_ylim(-0.001, 0.036)
        skewness = data.groupby("group").apply(
            lambda x: x.groupby("case_id").apply(lambda y: skew(y)[0]))
        draw_panel(skewness, axes[1], "Skewness", 0.1)
コード例 #23
0
ファイル: decoding.py プロジェクト: Palpatineli/lever
def example_validation(xy=None, predicts=None):
    """Illustrate a validation split on one example case.

    Samples 0-3000 of the trace ``xy[0][1]`` are drawn in blue, with the
    1800-2100 window shifted down by 5. The prediction for that window is drawn
    in orange, shifted down by 10. Above them, up to 20 rows of ``xy[0][0]`` are
    drawn in red, each scaled to a peak of 2 and stacked one unit apart, with
    the same window shifted up by a further 5.

    Args:
        xy: aligned data as returned by ``get_result`` for ``res_align_xy``.
            Loaded for ``mice[1:2]`` with the "astrocyte" tag when None.
        predicts: predictions as returned by ``get_result`` for ``res_predict``.
            Loaded the same way when None.
    """
    # The arguments used to be overwritten unconditionally, so whatever a caller
    # passed was ignored. Now the stored results are only a fallback.
    if xy is None or predicts is None:
        stored_xy, stored_predicts = get_result([x.name for x in mice[1:2]],
                                                [res_align_xy, res_predict], "astrocyte")
        xy = stored_xy if xy is None else xy
        predicts = stored_predicts if predicts is None else predicts
    # (start, stop, extra offset) of the three stretches; the middle one is the
    # window that is drawn displaced
    stretches = ((0, 1800, 0.0), (1800, 2100, 5.0), (2100, 3000, 0.0))
    with Figure(fig_folder.joinpath("decoder_validation.svg"), (9, 6)) as ax:
        ax = ax[0]
        trajectory = xy[0][1].values
        for start, stop, shift in stretches:
            ax.plot(np.arange(start, stop), trajectory[start:stop] - shift, color='blue')
        ax.plot(np.arange(1800, 2100),
                predicts[0][1800:2100] - 10.0,
                color='orange')
        for idx, neuron in enumerate(xy[0][0].values[:20, :]):
            scaled = neuron * 2 / neuron.max()
            for start, stop, shift in stretches:
                ax.plot(np.arange(start, stop),
                        scaled[start:stop] + 5.0 + shift + idx,
                        color='red')