Beispiel #1
0
def setup_play_level(level, on_update, interval=1000, min=1, max=8):
    """Create the playback controls for ``level``.

    Args:
        level: object providing ``get_canvas_dimensions()``; element 0 of the
            result is used as the layout width in pixels.
        on_update: callback registered for changes of the player's ``value``.
        interval: delay passed to ``Play`` between animation steps.
        min, max: value range shared by the player and the progress bar
            (both also use ``step=1``).

    Returns:
        ``(play, progress, layout)`` where ``progress`` is linked to
        ``play`` and ``layout`` is a ``Layout`` as wide as the level's canvas.
    """
    shared_range = {'min': min, 'max': max, 'step': 1}
    play = Play(interval=interval, **shared_range)
    progress = IntProgress(**shared_range)

    # Link the two values so the progress bar follows the player.
    link((play, 'value'), (progress, 'value'))
    play.observe(on_update, 'value')

    canvas_width = level.get_canvas_dimensions()[0]
    return play, progress, Layout(width=f'{canvas_width}px')
Beispiel #2
0
class DemoViewer:
    """Interactive viewer for products stored by ``storage``.

    2-D products are displayed one at a time (picked with ``product_select``)
    as an image with red marker lines at the bounds of the ``x``/``y`` range
    sliders.  3-D products are displayed as a step plot of their last axis,
    averaged over the x/y window selected by those same sliders.
    """

    # 3-D product shown under the 2-D map when it is present in ``products``.
    SPECTRUM_KEY = 'Particles Size Spectrum'

    def __init__(self, storage, setup):
        self.storage = storage
        self.setup = setup

        self.nans = None

        self.play = Play()
        self.step_slider = IntSlider()
        self.fps_slider = IntSlider(min=100, max=1000, description="1000/fps")
        self.product_select = Select()
        self.plots_box = Box()

        self.slider = {}
        # self.lines[xy][edge][key]: marker line for the lower (0) / upper (1)
        # bound of slider ``xy`` drawn on the 2-D plot of product ``key``.
        self.lines = {'x': [{}, {}], 'y': [{}, {}]}
        for xy in ('x', 'y'):
            self.slider[xy] = IntRangeSlider(min=0, max=1, description=f'spectrum_{xy}',
                                             orientation='horizontal' if xy == 'x' else 'vertical')

        self.reinit({})

    def clear(self):
        """Remove all plots from the displayed container."""
        self.plots_box.children = ()

    def reinit(self, products):
        """Rebuild every figure for ``products`` and rewind to step 0.

        ``products`` maps a name to a descriptor exposing ``shape``; 2-D
        descriptors must also expose ``description``, ``unit``, ``scale`` and
        ``range``.  Products of any other rank raise ``NotImplementedError``.
        """
        self.products = products
        self.product_select.options = [key for key, val in products.items() if len(val.shape) == 2]
        self.plots = {var: Output() for var in products}
        self.ims = {}
        self.axs = {}
        self.figs = {}
        for j, xy in enumerate(('x', 'y')):
            self.slider[xy].max = self.setup.grid[j]

        # Placeholder shown when a step cannot be loaded from storage.
        self.nans = np.full((self.setup.grid[0], self.setup.grid[1]), np.nan)  # TODO: np.nan

        for key, output in self.plots.items():
            with output:
                clear_output()
                product = self.products[key]
                ndim = len(product.shape)
                if ndim == 2:
                    fig, ax, im = self._create_map_plot(key, product)
                elif ndim == 3:
                    fig, ax, im = self._create_spectrum_plot()
                else:
                    raise NotImplementedError()
                self.figs[key], self.ims[key], self.axs[key] = fig, im, ax
                plt.show()

        self.plot_box = Box()
        if products:
            children = [HBox(children=(self.slider['y'], VBox((self.slider['x'], self.plot_box))))]
            # Only show the spectrum panel when that product actually exists;
            # indexing it unconditionally raised KeyError for other product sets.
            if self.SPECTRUM_KEY in self.plots:
                children.append(self.plots[self.SPECTRUM_KEY])
            self.plots_box.children = tuple(children)

        n_steps = len(self.setup.steps)
        self.step_slider.max = n_steps - 1
        self.play.max = n_steps - 1
        self.play.value = 0
        self.step_slider.value = 0
        self.replot()

    def _slider_position(self, xy, edge):
        """Return bound ``edge`` (0 lower, 1 upper) of slider ``xy`` scaled from grid cells to domain size."""
        axis = 0 if xy == 'x' else 1
        return self.slider[xy].value[edge] * self.setup.size[axis] / self.setup.grid[axis]

    def _create_map_plot(self, key, product):
        """Create the image plot for 2-D product ``key`` and its range-marker lines."""
        domain_size_in_metres = self.setup.size
        data = self.nans
        fig, ax = plt.subplots(1, 1)
        ax.set_xlabel('X [m]')
        ax.set_ylabel('Z [m]')
        # NOTE(review): ``data`` is the all-NaN placeholder here, so for a non-empty
        # grid this is False and LogNorm is never applied — confirm intent.
        use_log_norm = product.scale == 'log' and np.isfinite(data).all()
        im = ax.imshow(_transform(data),
                       origin='lower',
                       extent=(0, domain_size_in_metres[0], 0, domain_size_in_metres[1]),
                       cmap='YlGnBu',
                       norm=matplotlib.colors.LogNorm() if use_log_norm else None
                       )
        plt.colorbar(im, ax=ax).set_label(f"{product.description} [{product.unit}]")
        im.set_clim(vmin=product.range[0], vmax=product.range[1])

        for edge in (0, 1):
            self.lines['x'][edge][key] = ax.axvline(x=self._slider_position('x', edge), color='red')
            self.lines['y'][edge][key] = ax.axhline(y=self._slider_position('y', edge), color='red')
        return fig, ax, im

    def _create_spectrum_plot(self):
        """Create an empty step plot over ``setup.v_bins`` for a 3-D product."""
        v_bins = self.setup.v_bins
        fig, ax = plt.subplots(1, 1)
        ax.set_xlim(np.amin(v_bins), np.amax(v_bins))
        ax.set_ylim(0, 10)
        ax.set_xlabel("TODO [TODO]")
        ax.set_ylabel("TODO [TODO]")
        ax.set_xscale('log')
        ax.grid(True)
        im = ax.step(v_bins[:-1], np.full_like(v_bins[:-1], np.nan))[0]
        return fig, ax, im

    @staticmethod
    def _mean_spectrum(data, x_bounds, y_bounds, n_bins):
        """Average ``data[x0:x1, y0:y1, :]`` over its first two axes.

        When ``data`` is not 3-D (e.g. the 2-D NaN placeholder used after a
        failed load) an all-NaN profile of length ``n_bins`` is returned
        instead of failing with an indexing error.
        """
        if np.ndim(data) != 3:
            return np.full(n_bins, np.nan)
        window = data[slice(*x_bounds), slice(*y_bounds), :]
        return np.mean(np.mean(window, axis=0), axis=0)

    def replot(self, _=None):
        """Redraw all plots for the current step and slider positions."""
        if self.product_select.value in self.plots:
            self.plot_box.children = [self.plots[self.product_select.value]]

        # Read the step from ``play``: this callback observes ``play``, whereas
        # ``step_slider`` is only jslink-ed to it in the browser, so its
        # kernel-side value may not be synced yet when this runs.
        step = self.play.value
        for key in self.plots:
            try:
                data = self.storage.load(self.setup.steps[step], key)
            except self.storage.Exception:
                data = self.nans
            ndim = len(self.products[key].shape)
            if ndim == 2:
                self.ims[key].set_data(_transform(data))
                self.axs[key].set_title(f"min:{np.amin(data):.4g}    max:{np.amax(data):.4g}    std:{np.std(data):.4g}")
                for edge in (0, 1):
                    # axvline/axhline lines have two points; set_xdata/set_ydata
                    # expect sequences (scalars are deprecated in matplotlib).
                    x = self._slider_position('x', edge)
                    y = self._slider_position('y', edge)
                    self.lines['x'][edge][key].set_xdata([x, x])
                    self.lines['y'][edge][key].set_ydata([y, y])
            elif ndim == 3:
                spectrum = self._mean_spectrum(data, self.slider['x'].value, self.slider['y'].value,
                                               len(self.setup.v_bins) - 1)
                self.ims[key].set_ydata(spectrum)
                amax = np.amax(spectrum)
                if np.isfinite(amax):
                    self.axs[key].set_ylim((0, amax))
            else:
                raise NotImplementedError()

        for key, output in self.plots.items():
            with output:
                clear_output(wait=True)
                display(self.figs[key])

    def box(self):
        """Wire the controls to ``replot`` and return the complete widget.

        NOTE: each call registers another set of observers.
        """
        jslink((self.play, 'value'), (self.step_slider, 'value'))
        jslink((self.play, 'interval'), (self.fps_slider, 'value'))
        self.play.observe(self.replot, 'value')
        self.product_select.observe(self.replot, 'value')
        for xy in ('x', 'y'):
            self.slider[xy].observe(self.replot, 'value')
        return VBox([
            Box([self.play, self.step_slider, self.fps_slider]),
            self.product_select,
            self.plots_box
        ])
Beispiel #3
0
class DemoViewer:
    """Animated viewer showing one image per variable in ``setup.output_vars``.

    The ``Play`` widget drives ``replot``, which loads each variable for the
    selected step from ``storage`` (using an all-NaN field when the load
    raises ``storage.Exception``) and redraws its figure.
    """

    def __init__(self, storage, setup):
        self.storage = storage
        self.setup = setup
        self.nans = None  # replaced by an all-NaN grid in reinit()

        self.play = Play()
        self.step_slider = IntSlider()
        self.fps_slider = IntSlider(min=100, max=1000, description="1000/fps")
        self.plots = {var: Output() for var in setup.output_vars}
        self.ims = {}

        self.reinit()

    def reinit(self):
        """Rewind the step controls and draw an empty image for every variable."""
        last = len(self.setup.steps) - 1
        self.step_slider.max = last
        self.play.max = last
        self.play.value = 0
        self.step_slider.value = 0
        # variable -> (colour-scale minimum, colour-scale maximum, colormap)
        self.clims = {  # TODO : not here
            "m0": (0, 1e8, 'YlGnBu'),
            "th": (288, 295, 'Reds'),
            "qv": (0.005, .0075, 'Greens'),
            "RH": (.5, 1.1, 'GnBu'),
            "volume_m1": (1e-20, 1e-19, 'Reds'),
        }

        grid = self.setup.grid
        self.nans = np.full((grid[0], grid[1]), np.nan)  # TODO: np.nan
        for var, output in self.plots.items():
            with output:
                clear_output()
                _, ax = plt.subplots(1, 1)
                lo, hi, cmap = self.clims[var]
                self.ims[var] = plotter.image(ax, self.nans, self.setup.size, label=var, cmap=cmap)
                self.ims[var].set_clim(vmin=lo, vmax=hi)
                plt.show()

    def replot(self, bunch):
        """Redraw every image for step ``bunch.new`` (a change notification from ``play``)."""
        step = bunch.new

        for var in self.plots:
            try:
                field = self.storage.load(self.setup.steps[step], var)
            except self.storage.Exception:
                field = self.nans
            plotter.image_update(self.ims[var], field)

        for var, output in self.plots.items():
            with output:
                clear_output(wait=True)
                display(self.ims[var].figure)

    def box(self):
        """Wire the controls together and return the complete widget."""
        jslink((self.play, 'value'), (self.step_slider, 'value'))
        jslink((self.play, 'interval'), (self.fps_slider, 'value'))
        self.play.observe(self.replot, 'value')
        controls = Box([self.play, self.step_slider, self.fps_slider])
        figures = Box(
            children=tuple(self.plots.values()),
            layout=Layout(display='flex', flex_flow='column'),
        )
        return VBox([controls, figures])
Beispiel #4
0
class DemoViewer:
    """Animated viewer with one image per product passed to ``reinit``.

    ``replot(step)`` loads every product for ``setup.steps[step]`` from
    ``storage`` (using an all-NaN field when the load raises
    ``storage.Exception``) and refreshes the corresponding figures.
    """

    def __init__(self, storage, setup):
        self.storage = storage
        self.setup = setup
        self.nans = None  # replaced by an all-NaN grid in reinit()

        self.play = Play()
        self.step_slider = IntSlider()
        self.fps_slider = IntSlider(min=100, max=1000, description="1000/fps")

        # product name -> Output widget; reinit() refills this same dict object
        self.plots = {}
        self.plots_box = Box(
            children=tuple(self.plots.values()),
            layout=Layout(display='flex', flex_flow='column'),
        )

        self.reinit({})

    def clear(self):
        """Remove all plots from the displayed container."""
        self.plots_box.children = ()

    def reinit(self, products):
        """Create one empty image per entry of ``products`` and show step 0."""
        self.products = products

        self.plots.clear()
        self.plots.update((name, Output()) for name in products)
        self.ims = {}
        self.axs = {}

        grid = self.setup.grid
        self.nans = np.full((grid[0], grid[1]), np.nan)  # TODO: np.nan
        for name, output in self.plots.items():
            with output:
                clear_output()
                _, ax = plt.subplots(1, 1)
                product = self.products[name]
                image, axes = plotter.image(
                    ax, self.nans, self.setup.size,
                    label=f"{product.description} [{product.unit}]",
                    # TODO: per-product cmap (Reds, Blues, YlGnBu...)
                    scale=product.scale,
                )
                self.ims[name], self.axs[name] = image, axes
                image.set_clim(vmin=product.range[0], vmax=product.range[1])
                plt.show()

        self.plots_box.children = tuple(self.plots.values())
        last = len(self.setup.steps) - 1
        self.step_slider.max = last
        self.play.max = last
        self.play.value = 0
        self.step_slider.value = 0
        self.replot(step=0)

    def handle_replot(self, bunch):
        """Observer adapter: redraw for the new ``play`` value."""
        self.replot(bunch.new)

    def replot(self, step):
        """Load every product for ``setup.steps[step]`` and refresh its figure."""
        for name in self.plots:
            try:
                field = self.storage.load(self.setup.steps[step], name)
            except self.storage.Exception:
                field = self.nans
            plotter.image_update(self.ims[name], self.axs[name], field)

        for name, output in self.plots.items():
            with output:
                clear_output(wait=True)
                display(self.ims[name].figure)

    def box(self):
        """Wire the controls together and return the complete widget."""
        jslink((self.play, 'value'), (self.step_slider, 'value'))
        jslink((self.play, 'interval'), (self.fps_slider, 'value'))
        self.play.observe(self.handle_replot, 'value')
        controls = Box([self.play, self.step_slider, self.fps_slider])
        return VBox([controls, self.plots_box])
Beispiel #5
0
class SkeletonPlot(widgets.VBox):
    """Animated stick-figure view of the joint positions stored in ``position``.

    A plotly ``FigureWidget`` shows one frame at a time.  A ``Play`` widget
    steps through the frames of the current time window, which comes from an
    internal ``StartAndDurationController`` (displayed above the figure) or
    from ``foreign_time_window_controller`` when one is supplied.
    """

    def __init__(
        self,
        position: Position,
        foreign_time_window_controller: StartAndDurationController = None,
    ):
        super().__init__()

        self.position = position
        joint_keys = list(position.spatial_series.keys())
        # The first spatial series supplies the time base (start/end times, rate).
        self.spatial_series = position.spatial_series[joint_keys[0]]
        if foreign_time_window_controller is None:
            self.time_window_controller = StartAndDurationController(
                tmax=get_timeseries_maxt(self.spatial_series),
                tmin=get_timeseries_mint(self.spatial_series),
                start=0,
                duration=5,
            )
            show_time_controller = True
        else:
            show_time_controller = False
            self.time_window_controller = foreign_time_window_controller
        frame_ind = timeseries_time_to_ind(
            self.spatial_series, self.time_window_controller.value[0])

        self.sample_period = 1 / self.spatial_series.rate
        self.play = Play(
            value=0,
            min=0,
            max=int(self.time_window_controller.duration.value /
                    self.sample_period),
            step=1,
            interval=self.sample_period * 1000,  # Play.interval is in milliseconds
        )

        palette = [
            to_hex(np.array(unlabel_rgb(x)) / 255)
            for x in DEFAULT_PLOTLY_COLORS
        ]
        self.joint_keys = POSITION_KEYS
        # One colour per point of the drawn polyline (same order as skeleton_labels).
        self.joint_colors = [
            palette[0],  # l_wrist
            palette[1],  # l_elbow
            palette[2],  # l_shoulder
            palette[9],  # neck
            palette[4],  # nose
            palette[3],  # l_ear
            palette[5],  # r_ear
            palette[4],  # nose
            palette[9],  # neck
            palette[6],  # r_shoulder
            palette[7],  # r_elbow
            palette[8],  # r_wrist
        ]
        self.skeleton_labels = [
            "L_Wrist",
            "L_Elbow",
            "L_Shoulder",
            "Neck",
            "Nose",
            "L_Ear",
            "R_Ear",
            "Nose",
            "Neck",
            "R_Shoulder",
            "R_Elbow",
            "R_Wrist",
        ]

        self.fig = go.FigureWidget()  # animation_duration=int(1/spatial_series.rate*1000)
        self.plot_skeleton(frame_ind)
        self.updated_time_range({"new": None})

        # Re-sync the Play range and the displayed frame when the time window changes.
        # NOTE(review): observe() without ``names`` fires for every trait change of
        # the controller, not only its ``value`` — confirm this is intended.
        self.time_window_controller.observe(self.updated_time_range)
        self.play.observe(self.animate_scatter_chart)

        if show_time_controller:
            self.children = [self.time_window_controller, self.fig, self.play]
        else:
            self.children = [self.fig, self.play]

    def _skeleton_at(self, frame_ind):
        """Return the drawn polyline for ``frame_ind``: one row per point, columns are coordinates."""
        skeleton_vector = np.vstack(
            [self.position[joint].data[frame_ind] for joint in self.joint_keys])
        return self.calc_centroid(skeleton_vector)

    def _update_skeleton(self, frame_ind):
        """Move the existing skeleton trace to the positions of ``frame_ind``."""
        skeleton_vector = self._skeleton_at(frame_ind)
        with self.fig.batch_update():
            self.fig.update_traces(x=skeleton_vector[:, 0],
                                   y=-skeleton_vector[:, 1])

    def updated_time_range(self, change=None):
        """Operations to run whenever the time range gets updated.

        ``change`` is a traitlets change dict; changes without a ``"new"`` entry
        are ignored.  Calling with no argument (``None``) forces a refresh, just
        like ``{"new": None}``.
        """
        if change is not None and "new" not in change:
            return

        self.frame_ind_start = timeseries_time_to_ind(
            self.spatial_series,
            self.time_window_controller.value[0],
        )

        # Value test instead of identity (`is np.nan`), which misses NaNs that
        # are not the np.nan singleton object.
        if isinstance(self.frame_ind_start, float) and np.isnan(self.frame_ind_start):
            print("No data present")
            return

        self.frame_ind_end = timeseries_time_to_ind(
            self.spatial_series, self.time_window_controller.value[1])
        # Assign bounds in an order that keeps min <= max at every step;
        # the widget rejects an intermediate state with min > max.
        if self.frame_ind_start > self.play.max:
            self.play.max, self.play.min = self.frame_ind_end, self.frame_ind_start
        else:
            self.play.min, self.play.max = self.frame_ind_start, self.frame_ind_end
        self.play.value = self.frame_ind_start

        # Redraw explicitly: if play.value did not change, no observer fires.
        self._update_skeleton(self.frame_ind_start)

    def animate_scatter_chart(self, change=None):
        """Redraw the skeleton for the frame selected by ``play``.

        Only ``value`` changes are acted on; with no argument the frame at the
        current ``play.value`` is redrawn.
        """
        if change is None:
            frame_ind = self.play.value
        elif change["name"] == "value":
            frame_ind = change["new"]
        else:
            return
        self._update_skeleton(frame_ind)

    def calc_centroid(self, skeleton_vector):
        """Expand per-joint positions into the polyline that is drawn.

        Despite its name this computes no centroid.  It inserts a "neck" point
        at the midpoint of input rows 2 and 6 (the shoulders, judging by
        ``skeleton_labels``).  It repeats the nose and neck points so the line
        runs wrist-elbow-shoulder-neck-nose-ear-ear-nose-neck-shoulder-elbow-wrist.
        Rows 3, 4 and 5 are taken as left ear, nose and right ear.  The input
        row order presumably matches ``POSITION_KEYS``; TODO confirm.
        """
        base_of_neck = (skeleton_vector[2, :] + skeleton_vector[6, :]) / 2
        new_skeleton_vector = np.vstack([
            skeleton_vector[0:3, :],
            base_of_neck,
            skeleton_vector[4, :],  # nose
            skeleton_vector[3, :],  # left ear
            skeleton_vector[5, :],  # right ear
            skeleton_vector[4, :],  # nose
            base_of_neck,
            skeleton_vector[6:],
        ])

        return new_skeleton_vector

    def plot_skeleton(self, frame_ind):
        """Add the skeleton trace for ``frame_ind`` to ``self.fig`` and hide the axes."""
        skeleton_vector = self._skeleton_at(frame_ind)

        self.fig.add_trace(
            go.Scatter(
                x=skeleton_vector[:, 0],
                # y is negated (presumably the data uses a downward y axis) — TODO confirm
                y=-skeleton_vector[:, 1],
                mode="lines+markers+text",
                marker_color=self.joint_colors,
                marker_size=12,
                text=self.skeleton_labels,
                hoverinfo="text",
                textposition="bottom center",
            ))

        self.fig.update_layout(
            height=500,
            width=600,
            xaxis=dict(
                showgrid=False,  # no background grid lines
                zeroline=False,  # no line at x=0
                visible=False,  # hide the axis (ticks and labels)
            ),
            yaxis=dict(
                showgrid=False,  # no background grid lines
                zeroline=False,  # no line at y=0
                visible=False,  # hide the axis (ticks and labels)
            ),
        )