def test_get_timeseries_tt_negativeistop():
    """With istop=-1, a 10-sample, 1 Hz series yields the times 0..7 s."""
    series = TimeSeries(
        name='test_timeseries',
        data=list(range(100, 200, 10)),
        unit='m',
        starting_time=0.,
        rate=1.0,
    )

    expected = np.arange(8, dtype=float)
    np.testing.assert_array_equal(get_timeseries_tt(series, istop=-1), expected)
# Example #2
0
def test_get_timeseries_tt_negativeistop():
    """A negative istop (-1) returns the first eight of ten sample times."""
    samples = [value for value in range(100, 200, 10)]
    series = TimeSeries(name="test_timeseries",
                        data=samples,
                        unit="m",
                        starting_time=0.0,
                        rate=1.0)

    times = get_timeseries_tt(series, istop=-1)
    np.testing.assert_array_equal(times, [float(second) for second in range(8)])


def test_get_timeseries_tt_infstarting_time():
    """An infinite starting_time is expected to produce times starting at 0 s."""
    series = TimeSeries(
        name='test_timeseries',
        data=list(range(100, 200, 10)),
        unit='m',
        starting_time=np.inf,
        rate=1.0,
    )

    np.testing.assert_array_equal(get_timeseries_tt(series), np.arange(10.0))
# Example #4
0
def test_get_timeseries_tt_infstarting_time():
    """With starting_time=np.inf the ten sample times are 0..9 s."""
    samples = list(range(100, 200, 10))
    series = TimeSeries(name="test_timeseries",
                        data=samples,
                        unit="m",
                        starting_time=np.inf,
                        rate=1.0)

    expected = [float(index) for index in range(10)]
    times = get_timeseries_tt(series)
    np.testing.assert_array_equal(times, expected)
# Example #5
0
    def trials_psth(self,
                    before=1.5,
                    after=1.5,
                    figsize=(6, 6)):
        """
        Align the reach-arm position to event times and plot its displacement.

        Parameters
        ----------
        before: float
            Time before each event (should be positive), in the same time
            units as ``self.events``.
        after: float
            Time after each event.
        figsize: tuple, optional
            Figure size passed to ``plt.subplots``.

        Returns
        -------
        matplotlib.Figure or None
            None (after printing a message) when no trials could be aligned.

        Raises
        ------
        ValueError
            If the aligned trials are neither 2-D nor 3-D.
        """
        starts = self.events - before
        stops = self.events + after

        trials = align_by_times_with_rate(self.spatial_series, starts, stops)

        if trials.size == 0:
            print("No trials present")
            return None

        # Only the sampling interval is needed, so fetch just the first two
        # sample times. istart/istop are sample indices, not times.
        tt = get_timeseries_tt(self.spatial_series, 0, 2)
        # Round rather than truncate: float error in the sampling period
        # (e.g. 1/30 s) must not shift the event sample back by one.
        zero_ind = int(round(before / (tt[1] - tt[0])))

        # Indexing below treats trials as (n_trials, n_times[, n_axes]).
        trials = np.asarray(trials)
        if trials.ndim == 2:
            # One value per sample: treat it as a single spatial column.
            trials = trials[:, :, np.newaxis]
        elif trials.ndim != 3:
            raise ValueError(
                f"Expected 2-D or 3-D aligned trials, got shape {trials.shape}")

        # Displacement of the first two spatial columns (x, y) relative to
        # the position at the event sample; result is (n_times, n_trials).
        planar = trials[:, :, :2]
        diffs = planar - planar[:, zero_ind:zero_ind + 1, :]
        distance = np.linalg.norm(diffs, axis=2).T

        fig, axs = plt.subplots(1, 1, figsize=figsize)
        axs.set_title("Event-triggered Wrist Displacement")

        self.show_psth(distance, axs, before, after)
        return fig
# Example #6
0
        def on_change(change):
            # Observer callback: re-window the plotted traces to the current
            # "time_window" control value. `timeseries`, `positions` and
            # `data_dim` are free variables from the enclosing scope; the
            # `change` payload itself is not used.
            window = self.controls["time_window"].value
            first = timeseries_time_to_ind(timeseries, window[0])
            last = timeseries_time_to_ind(timeseries, window[1])

            times = get_timeseries_tt(timeseries, first, last)
            values, _units = get_timeseries_in_units(timeseries, first, last)

            with self.out_fig.batch_update():
                if len(values.shape) != 1:
                    # Multi-dimensional: traces are laid out data_dim per key.
                    for row, key in enumerate(POSITION_KEYS):
                        series = positions[key]
                        values, _units = get_timeseries_in_units(
                            series, first, last)
                        for col, column in enumerate(values.T):
                            trace = self.out_fig.data[row * data_dim + col]
                            trace.x = times
                            trace.y = column
                else:
                    self.out_fig.data[0].x = times
                    self.out_fig.data[0].y = values
 def test_get_timeseries_tt_timestamp(self):
     """Sample times of ``self.ts`` (set up elsewhere in the class) are 0..9."""
     expected = [float(second) for second in range(10)]
     np.testing.assert_array_equal(get_timeseries_tt(self.ts), expected)
# Example #8
0
 def test_get_timeseries_tt_timestamp(self):
     """Sample times of ``self.ts_rate`` (set up elsewhere) are 0..9."""
     times = get_timeseries_tt(self.ts_rate)
     np.testing.assert_array_equal(times, np.arange(10.0))
# Example #9
0
    def tab1(self, nwb_file):
        """Build the first tab: tracked joints, movement traces, electrodes, ECoG.

        Every time-based widget shares one StartAndDurationController that
        spans the full time range of the first spatial series in the
        ``behavior`` -> ``Position`` interface.

        Parameters
        ----------
        nwb_file
            NWB file providing ``processing["behavior"]`` (with ``"Position"``
            and ``"ReachEvents"`` interfaces), ``acquisition["ElectricalSeries"]``
            and ``electrodes``.

        Returns
        -------
        widgets.VBox
            A header row holding the time controller, followed by two rows of
            captioned widgets.
        """
        behavior = nwb_file.processing["behavior"]
        position = behavior.data_interfaces["Position"]

        def labeled_box(text, widget):
            # Stack a bold, size-4 caption above the widget.
            label = widgets.HTML(value=f"<b><font size=4>{text}</b>")
            return widgets.VBox([label, widget], layout=self.box_layout)

        # istart/istop of get_timeseries_tt are sample indices, so the
        # default (whole series) is used instead of passing starting_time,
        # which is a time and would offset tmin/tmax when non-zero.
        first_key = list(position.spatial_series.keys())[0]
        tt = get_timeseries_tt(position[first_key])
        time_trace_window_controller = StartAndDurationController(tmax=tt[-1],
                                                                  tmin=tt[0],
                                                                  start=0,
                                                                  duration=5)

        # Capitalize each "_"-separated word of the ReachEvents description
        # (e.g. "a_b" -> "A_B") to get the Position key of the reaching arm.
        description = behavior.data_interfaces["ReachEvents"].description
        reach_arm = "_".join(word.capitalize() for word in description.split("_"))

        jointpos = labeled_box(
            "(b) Movement segments",
            AllPositionTracesPlotlyWidget(
                position[reach_arm],
                foreign_time_window_controller=time_trace_window_controller,
            ),
        )
        skeleton = labeled_box(
            "(a) Tracked joints",
            SkeletonPlot(
                position,
                foreign_time_window_controller=time_trace_window_controller,
            ),
        )
        ecog = labeled_box(
            "(d) Raw ECoG",
            ElectricalSeriesWidget(
                nwb_file.acquisition["ElectricalSeries"],
                foreign_time_window_controller=time_trace_window_controller,
            ),
        )
        brain = labeled_box(
            "(c) Subject electrode locations",
            HumanElectrodesPlotlyWidget(nwb_file.electrodes),
        )

        header = widgets.HBox([time_trace_window_controller])
        row1 = widgets.HBox([skeleton, jointpos], layout=self.row_layout)
        row2 = widgets.HBox([brain, ecog], layout=self.row_layout)
        return widgets.VBox([header, row1, row2], layout=self.box_layout)
# Example #10
0
    def __init__(
        self,
        events: Events,
        position: Position,
        acquisition: ElectricalSeries = None,
        foreign_time_window_controller: StartAndDurationController = None,
    ):
        """Build the event-triggered displacement widget.

        Parameters
        ----------
        events : Events
            ``description`` names the reaching arm (its "_"-separated words
            are capitalized to form the Position key); ``timestamps`` holds
            the event times.
        position : Position
            Container whose ``spatial_series`` is indexed by the arm key.
        acquisition : ElectricalSeries, optional
            Currently unused.
        foreign_time_window_controller : StartAndDurationController, optional
            Currently unused.
        """
        super().__init__()

        before_ft = widgets.FloatText(
            1.5, min=0, description="before (s)", layout=Layout(width="200px"))
        after_ft = widgets.FloatText(
            1.5, min=0, description="after (s)", layout=Layout(width="200px"))

        # Capitalize each "_"-separated word of the description
        # (e.g. "a_b" -> "A_B") to match the key in Position.spatial_series.
        reach_arm = "_".join(
            word.capitalize() for word in events.description.split("_"))
        self.spatial_series = position.spatial_series[reach_arm]

        # Full sample-time vector. istart/istop are sample indices, so the
        # defaults are used rather than passing starting_time (a time).
        self.tt = get_timeseries_tt(self.spatial_series)

        # Event times consumed by trials_psth.
        self.events = events.timestamps[:]

        # TODO: acquisition (ECoG PSTH) and foreign_time_window_controller
        # (shared time window) are accepted but not wired up yet.
        self.controls = dict(after=after_ft, before=before_ft)
        # interactive_output renders trials_psth right away, so
        # spatial_series and events must already be set above.
        out_fig = interactive_output(self.trials_psth, self.controls)

        header_row = widgets.HBox([before_ft, after_ft])
        self.children = [header_row, out_fig]
# Example #11
0
    def set_out_fig(self):
        """Create ``self.out_fig`` for the selected series and time window.

        Reads the ``"timeseries"`` and ``"time_window"`` controls, plots the
        samples inside the window, and registers an observer on the
        ``"time_window"`` control that refreshes the traces in place.

        If the selected series is multi-dimensional, or ``POSITION_KEYS`` has
        more than one entry, a grid is drawn with one row per position key and
        one column per spatial axis (x, y, z); otherwise one trace is drawn.
        """
        timeseries = self.controls["timeseries"].value
        time_window = self.controls["time_window"].value

        istart = timeseries_time_to_ind(timeseries, time_window[0])
        istop = timeseries_time_to_ind(timeseries, time_window[1])

        data, units = get_timeseries_in_units(timeseries, istart, istop)
        tt = get_timeseries_tt(timeseries, istart, istop)

        positions = self.timeseries.get_ancestor("Position")
        position_colors = dict(zip(POSITION_KEYS, DEFAULT_PLOTLY_COLORS))
        axis_names = ("x", "y", "z")

        # 1-D data has no second axis; data.shape[1] would raise IndexError.
        data_dim = data.shape[1] if len(data.shape) > 1 else 1
        # One subplot column (and one trace per key) per plotted axis; only
        # the first three (x, y, z) columns are drawn.
        n_cols = min(data_dim, len(axis_names))
        # `or`, not `|`: `|` binds tighter than `>`, so `(a) | len(K) > 1`
        # would compare `(a) | len(K)` against 1.
        use_subplots = len(data.shape) > 1 or len(POSITION_KEYS) > 1

        def columns(values):
            """Split ``values`` into at most ``n_cols`` per-axis 1-D columns."""
            values = np.asarray(values)
            if values.ndim == 1:
                values = values[:, np.newaxis]
            return list(values.T[:n_cols])

        if use_subplots:
            n_rows = len(POSITION_KEYS)
            # Every cell in a row is titled with that row's position key.
            subplot_titles = [key for key in POSITION_KEYS for _ in range(n_cols)]
            self.out_fig = go.FigureWidget(
                make_subplots(rows=n_rows,
                              cols=n_cols,
                              subplot_titles=subplot_titles))
            self.out_fig["layout"].update(width=800, height=700)
            for row, key in enumerate(POSITION_KEYS, start=1):
                # Window and unit-convert each key's data exactly like the
                # observer below does, so every trace lines up with `tt`.
                key_data, _ = get_timeseries_in_units(positions[key], istart, istop)
                color = position_colors[key]
                for col, (yy, axis_name) in enumerate(
                        zip(columns(key_data), axis_names), start=1):
                    self.out_fig.add_trace(
                        go.Scattergl(x=tt,
                                     y=yy,
                                     marker_color=color,
                                     showlegend=False),
                        row=row,
                        col=col,
                    )
                    yaxes_label = f"{axis_name} ({units})" if units else axis_name
                    self.out_fig.update_yaxes(title_text=yaxes_label,
                                              row=row,
                                              col=col)
                    # Time tick labels are re-enabled on the bottom row below.
                    self.out_fig.update_xaxes(showticklabels=False,
                                              row=row,
                                              col=col)
            for col in range(1, n_cols + 1):
                self.out_fig.update_xaxes(showticklabels=True,
                                          title_text="time (s)",
                                          row=n_rows,
                                          col=col)
        else:
            self.out_fig = go.FigureWidget()
            self.out_fig.add_trace(go.Scatter(x=tt, y=data))
            self.out_fig.update_xaxes(title_text="time (s)")

        def on_change(change):
            """Refresh the plotted traces for the newly selected time window."""
            window = self.controls["time_window"].value
            new_start = timeseries_time_to_ind(timeseries, window[0])
            new_stop = timeseries_time_to_ind(timeseries, window[1])
            new_tt = get_timeseries_tt(timeseries, new_start, new_stop)

            with self.out_fig.batch_update():
                # Branch on the layout actually built, not on the data shape.
                if use_subplots:
                    # Traces were added key by key, n_cols traces per key.
                    for k, key in enumerate(POSITION_KEYS):
                        key_data, _ = get_timeseries_in_units(
                            positions[key], new_start, new_stop)
                        for i, yy in enumerate(columns(key_data)):
                            trace = self.out_fig.data[k * n_cols + i]
                            trace.x = new_tt
                            trace.y = yy
                else:
                    yy, _ = get_timeseries_in_units(timeseries, new_start, new_stop)
                    self.out_fig.data[0].x = new_tt
                    self.out_fig.data[0].y = yy

        self.controls["time_window"].observe(on_change)
# Example #12
0
        def update_traces(select_start_time, select_duration):
            # Dash callback: re-window the ephys, ophys and spike traces to
            # [select_start_time, select_start_time + select_duration], pin
            # all x-ranges to the ephys window and draw a marker line at its
            # midpoint.
            triggered = dash.callback_context.triggered[0]
            if not triggered['prop_id'].split('.')[1]:
                # Empty property name in prop_id: nothing to update.
                raise dash.exceptions.PreventUpdate

            window = [select_start_time, select_start_time + select_duration]

            def windowed(series):
                # Sample times and values of `series` inside `window`.
                first = timeseries_time_to_ind(series, window[0])
                last = timeseries_time_to_ind(series, window[1])
                values, _units = get_timeseries_in_units(series, first, last)
                times = get_timeseries_tt(series, first, last)
                return times, values

            def pinned(low, high):
                # Axis settings that fix the range and disable autorange.
                return {"range": [low, high], "autorange": False}

            # Electrophysiology trace; its time extent sets every x-range.
            ephys_times, ephys_values = windowed(self.ecephys_trace)
            x_low, x_high = min(ephys_times), max(ephys_times)
            self.traces.data[0].x = ephys_times
            self.traces.data[0].y = list(ephys_values)
            self.traces.update_layout(
                yaxis=pinned(min(ephys_values), max(ephys_values)),
                xaxis=pinned(x_low, x_high),
            )

            # Optical physiology trace, aligned to the ephys x-range.
            ophys_times, ophys_values = windowed(self.ophys_trace)
            self.traces.data[1].x = ophys_times
            self.traces.data[1].y = list(ophys_values)
            self.traces.update_layout(
                yaxis3=pinned(min(ophys_values), max(ophys_values)),
                xaxis3=pinned(x_low, x_high),
            )

            # Spike traces.
            self.update_spike_traces(time_window=window)
            self.traces.update_layout(xaxis2=pinned(x_low, x_high))

            # Vertical frame marker at the centre of the window.
            midpoint = (x_high + x_low) / 2
            self.start_frame_x = midpoint
            marker = {
                'type': 'line',
                'x0': midpoint,
                'x1': midpoint,
                'xref': 'x',
                'y0': -1000,
                'y1': 1000,
                'yref': 'paper',
                'line': {
                    'width': 4,
                    'color': 'rgb(30, 30, 30)'
                }
            }
            self.traces.update_layout(shapes=[marker])

            return {'display': "inline-block"}, self.traces