示例#1
0
 def removewave(self, axes: Axes, lines: "Optional[List[Line2D]]"):
     """Remove the first line artist in *lines*, reset *axes* and redraw.

     :param axes: The axes to reset; it is cleared entirely.
     :param lines: A list of ``Line2D`` artists (for example the list
         returned by ``Axes.plot``), or ``None``. Its first element is popped
         and removed from the figure, so the caller's list is mutated.
     """
     # ``lines`` is a list (``pop`` is called on it), not a single Line2D.
     # Checking truthiness also skips an empty list, where pop(0) would raise
     # IndexError.
     if lines:
         lines.pop(0).remove()
     # NOTE(review): clear() below resets the axes (limits included), so this
     # relim/autoscale_view pair looks redundant; kept to preserve behavior.
     axes.relim()
     axes.autoscale_view(True, True, True)
     axes.clear()
     self.canvas.draw()
示例#2
0
def _update_window(
    render: Callable[[Axes, Any], None],
    queue: Queue,
    ax: Axes,
    delay: int,
    auto_play: bool,
    clear_axes: bool,
) -> None:
    """Render values taken from *queue* onto *ax* until a ``None`` arrives.

    While waiting for the next value the GUI event loop is kept alive with
    short ``plt.pause`` calls instead of a blocking ``get``.

    :param render: Callback drawing a single value onto the axes.
    :param queue: Source of values; ``None`` terminates the loop.
    :param ax: Axes every value is rendered onto.
    :param delay: Pause after each frame in milliseconds; ``<= 0`` means a
        1 ms pause.
    :param auto_play: When false, wait for ENTER after each frame.
    :param clear_axes: When true, clear *ax* before rendering each value.
    """

    def next_value():
        # Poll rather than block so the window stays responsive.
        while queue.empty():
            plt.pause(0.005)
        return queue.get_nowait()

    frame_pause = delay / 1000 if delay > 0 else 0.001

    value = next_value()
    while value is not None:
        if clear_axes:
            ax.clear()
        render(ax, value)
        # Full canvas redraw each frame: inefficient, but callers need not
        # manage drawing themselves.
        ax.figure.canvas.draw()
        if not auto_play:
            input('Press ENTER to continue...')
        plt.pause(frame_pause)
        value = next_value()
示例#3
0
def _replot_ax(ax: Axes, freq, kwargs):
    """Clear *ax* and redraw its stored series with their index converted to *freq*.

    Series previously recorded in ``ax._plot_data`` as ``(series, plotf, kwds)``
    tuples are copied, their index is converted with
    ``asfreq(freq, how="S")``, they are recorded again in a fresh
    ``ax._plot_data`` list, and plotted.

    Returns a ``(lines, labels)`` pair: the first artist returned by each plot
    call and the printable name of each series.
    """
    previous = getattr(ax, "_plot_data", None)

    # Reset the recorded data, then clear and re-decorate the axes.
    ax._plot_data = []
    ax.clear()
    decorate_axes(ax, freq, kwargs)

    lines = []
    labels = []
    if previous is None:
        return lines, labels

    for old_series, plot_func, plot_kwds in previous:
        converted = old_series.copy()
        converted.index = converted.index.asfreq(freq, how="S")
        ax._plot_data.append((converted, plot_func, plot_kwds))

        # A plot kind stored by name (tsplot) is resolved to its plotting
        # class's ``_plot``; the import stays local, as in the original.
        if isinstance(plot_func, str):
            from pandas.plotting._matplotlib import PLOT_CLASSES

            plot_func = PLOT_CLASSES[plot_func]._plot

        artists = plot_func(ax, converted.index._mpl_repr(), converted.values, **plot_kwds)
        lines.append(artists[0])
        labels.append(pprint_thing(converted.name))

    return lines, labels
示例#4
0
def build_wind_barbs(date: xarray.DataArray, ds: xarray.Dataset, ax: Axes,
                     manifest: dict):
    """Plot wind barbs for one time step and export the same points as GeoJSON.

    Only every 100th entry of the wind and face-coordinate variables is used.
    The GeoJSON (one Point feature per sampled face, with ``speed`` and
    ``direction`` properties) is written to ``/tmp/wind/<iso-date>.json`` and
    a ``{'date', 'path'}`` entry is appended to ``manifest['wind']['geojson']``.

    :param date: Time value selected from ``ds`` via ``ds.sel(time=date)``.
    :param ds: Dataset providing ``mesh2d_windx``/``mesh2d_windy`` and
        ``mesh2d_face_x``/``mesh2d_face_y``.
    :param ax: Axes to draw on; it is cleared first.
    :param manifest: Mapping updated in place with the written file.
    """
    ax.clear()

    dt = datetime64_to_datetime(date)
    timestamp = dt.isoformat()
    ax.set_title(timestamp)

    # Subsample so a barb is not drawn at every point.
    stride = 100
    snapshot = ds.sel(time=date)  # select the time slice once and reuse it
    windx_values = snapshot['mesh2d_windx'][::stride].values
    windy_values = snapshot['mesh2d_windy'][::stride].values
    facex_values = ds.mesh2d_face_x[::stride].values
    facey_values = ds.mesh2d_face_y[::stride].values

    ax.barbs(facex_values, facey_values, windx_values, windy_values)

    # hypot() is already non-negative; no abs() needed.
    wind_speeds = np.hypot(windx_values, windy_values)
    wind_directions = np.arctan2(windx_values, windy_values)
    coords = np.column_stack([facex_values, facey_values])
    features = [
        Feature(geometry=Point(coord.tolist()),
                properties={
                    # Cast NumPy scalars to built-in floats: json cannot
                    # serialize types such as numpy.float32.
                    'speed': float(speed),
                    'direction': float(direction),
                })
        for coord, speed, direction in zip(coords, wind_speeds, wind_directions)
    ]
    wind_geojson = FeatureCollection(features=features)

    # exist_ok avoids the race between an exists() check and makedirs().
    output_path = '/tmp/wind'
    os.makedirs(output_path, exist_ok=True)

    file_name = '{}.json'.format(timestamp)

    # Context manager guarantees the file is flushed and closed.
    with open(os.path.join(output_path, file_name), 'w') as output_file:
        json.dump(wind_geojson, output_file)

    manifest.setdefault('wind', {'geojson': []})['geojson'].append({
        'date': timestamp,
        'path': os.path.join('wind', file_name),
    })
示例#5
0
    def plot(self, subplot: Axes) -> None:
        """
        Render this depiction into the given subplot.

        The subplot is cleared and the scatter plot drawn; the badge is added
        on top only when ``show_badge`` is set.

        :param subplot: The subplot where the content should be drawn.
        """
        subplot.clear()
        self.scatter_plot(subplot)
        if not self.show_badge:
            return
        self.plot_badge(subplot)
 def prev(self, event):
     """Step back to the previous combination (wrapping to the last) and redraw.

     :param event: The triggering event; not used by this callback.
     """
     self.ind = self.ind - 1 if self.ind > 0 else len(combinations) - 1
     coords = calc_cur_combination_coord(self.ind)
     # NOTE(review): presumably new_scatters() creates a fresh legend after
     # this one is removed — confirm.
     fig_legend.remove()
     Axes.clear(ax)
     new_scatters(coords)
     plt.draw()
 def next(self, event):
     """Advance to the next combination (wrapping to the first) and redraw.

     :param event: The triggering event; not used by this callback.
     """
     # Strict ``<``: at the last valid index (len - 1) we must wrap to 0.
     # With ``<=`` the index would be incremented to len(combinations),
     # one past the end.
     if self.ind < len(combinations) - 1:
         self.ind += 1
     else:
         self.ind = 0
     cur_scatter_data = calc_cur_combination_coord(self.ind)
     fig_legend.remove()
     Axes.clear(ax)
     new_scatters(cur_scatter_data)
     plt.draw()
 def next_col(self, event):
     """Move to the next column (wrapping to column 0) and redraw.

     :param event: The triggering event; not used by this callback.
     """
     last_col = len(column_id) - 1
     if self.ind_col < last_col:
         self.ind_col += 1
     else:
         self.ind_col = 0
     # Presumably column_id holds the first combination index of each
     # column — TODO confirm.
     self.ind = column_id[self.ind_col]
     coords = calc_cur_combination_coord(self.ind)
     fig_legend.remove()
     Axes.clear(ax)
     new_scatters(coords)
     plt.draw()
 def next(self, event):
     """Advance to the next combination (wrapping to the first), update the
     current column, and redraw.

     :param event: The triggering event; not used by this callback.
     """
     at_last = self.ind >= len(combinations) - 1
     self.ind = 0 if at_last else self.ind + 1
     # Bin index of self.ind within column_id (assumes column_id is sorted
     # ascending — TODO confirm).
     self.ind_col = np.digitize(self.ind, column_id) - 1
     coords = calc_cur_combination_coord(self.ind)
     fig_legend.remove()
     Axes.clear(ax)
     new_scatters(coords)
     plt.draw()
示例#10
0
def plot(ax: Axes,
         gridmap: OccupancyGrid,
         planned: Optional[Path] = None,
         executed: Optional[Path] = None):
    """Draw an occupancy grid and optional planned/executed paths on *ax*.

    :param ax: Axes to draw on; it is cleared first.
    :param gridmap: Occupancy grid; drawn only when its ``data`` is not None.
    :param planned: Planned path; drawn when given and it has poses.
    :param executed: Executed path; drawn as a line when given and it has poses.
    """
    ax.clear()
    if gridmap.data is not None:
        gridmap.plot(ax)
    if planned and planned.poses:
        planned.plot(ax)
    if executed and executed.poses:
        executed.plot(ax, line=True)
    # Label through ``ax`` only: pyplot-level ``plt.ylabel`` targets the
    # *current* axes, which need not be ``ax``.
    ax.set(xlabel="x[m]", ylabel="y[m]")
    ax.set_aspect('equal', 'box')
示例#11
0
def test_step(ode_model: VanillaODEFunc, loss_function: Callable, y_0: torch.Tensor, y_true: torch.Tensor,
              t: torch.Tensor, epoch: int, output_ax: Axes, output_dir: Path) -> None:
    """
    Plot ODE output based on a given initial point (y_0) and known system dynamics (y_true). Save figure.

    The predicted trajectory is drawn in red and the true one in blue, both using
    components 0 and 1 of the first batch entry. The figure that owns ``output_ax``
    is saved as ``output_dir / "iter_<epoch>.png"`` and the loss is printed.

    @param ode_model: Torch Module used to parameterise the ODE system.
    @param loss_function: Torch loss function to use.
    @param y_0: Initial condition of system.
    @param y_true: True state of system at each timestep.
    @param t: Tensor of timesteps.
    @param epoch: Number of current epoch.
    @param output_ax: Axes instance to plot trajectory on.
    @param output_dir: Directory to save figures to.
    """
    with torch.no_grad():
        y_pred = odeint(ode_model, y_0, t)
        loss = loss_function(y_pred, y_true)
        output_ax.clear()
        output_ax.plot(y_pred[:, 0, 0], y_pred[:, 0, 1], 'r', y_true[:, 0, 0], y_true[:, 0, 1], 'b')
        # Draw and save the figure that owns output_ax; plt.draw/plt.savefig
        # would act on pyplot's *current* figure, which may be a different one.
        figure = output_ax.figure
        figure.canvas.draw_idle()
        figure.savefig(output_dir / f"iter_{epoch}.png")
        plt.pause(0.001)

        print(f"epoch: {epoch}, Loss: {loss}")
示例#12
0
    def plot(self, subplot: Axes) -> None:
        """
        Draws the content of the depiction in the subplot.

        The subplot is cleared, the input data's scatter plot and this
        depiction's contour plot are drawn, and text labels are added for
        whichever of score, loss and learning rate is set. The input data's
        badge is drawn last when enabled.

        :param subplot: The subplot where the content should be drawn.
        """
        subplot.clear()
        self.input_data.scatter_plot(subplot)
        self.contour_plot(subplot)
        # Strip the leading zero from the formatted *number* ("0.95" -> ".95").
        # Applying lstrip('0') to the whole label would do nothing, because the
        # label starts with a letter.
        if self.score is not None:
            score_txt = 'Acc: ' + ('%.2f' % self.score).lstrip('0')
            self.draw_text(subplot, score_txt,
                           position=(0.96, 0.03), ha='right', va='bottom')
        if self.loss is not None:
            loss_txt = 'Loss: ' + ('%.2f' % self.loss).lstrip('0')
            self.draw_text(subplot, loss_txt,
                           position=(0.04, 0.03), ha='left', va='bottom')
        if self.lr is not None:
            lr_txt = 'LR: ' + ('%.4f' % self.lr).lstrip('0')
            self.draw_text(subplot, lr_txt,
                           position=(0.96, 0.96), ha='right', va='top')
        if self.input_data.show_badge:
            self.input_data.plot_badge(subplot)
示例#13
0
def build_geojson_contours(data, ax: Axes, manifest: dict):
    """Draw filled contours for one time step of a mesh variable and export them as GeoJSON.

    Face values are Delaunay-triangulated, triangles whose circumradius exceeds
    ``MAX_CIRCUM_RADIUS`` are masked out, and the values are linearly
    interpolated onto a ``GRID_SIZE`` x ``GRID_SIZE`` grid before contouring
    with ``LEVELS``. The GeoJSON is written to
    ``/tmp/<variable name>/<iso-date>.json`` and an entry is appended to
    ``manifest[<variable name>]['geojson']``.

    :param data: Values for one time step; expected (as used below) to expose
        ``name``, ``time``, ``values``, ``mesh2d_face_x`` and ``mesh2d_face_y``.
    :param ax: Axes to draw on; it is cleared first.
    :param manifest: Mapping updated in place with the written file.
    :returns: The object returned by ``ax.contourf``.
    """
    ax.clear()

    x = data.mesh2d_face_x[:len(data)]
    y = data.mesh2d_face_y[:len(data)]
    variable_name = data.name

    dt = datetime64_to_datetime(data.time)
    timestamp = dt.isoformat()
    ax.set_title(timestamp)
    file_name = '{}.json'.format(timestamp)

    # Work with plain NumPy arrays from here on.
    z = data.values
    x = x.values
    y = y.values

    # Regular interpolation grid covering the data extent.
    xi = np.linspace(np.floor(x.min()), np.ceil(x.max()), GRID_SIZE)
    yi = np.linspace(np.floor(y.min()), np.ceil(y.max()), GRID_SIZE)

    triang = tri.Triangulation(x, y)

    # Mask triangles whose circumradius exceeds MAX_CIRCUM_RADIUS. The flag is
    # computed per triangle in a single pass; collecting indices into a list
    # and testing ``i in list`` for every triangle would be quadratic.
    mask = [
        bool(circum_radius(*zip(x[vertices], y[vertices])) > MAX_CIRCUM_RADIUS)
        for vertices in triang.triangles
    ]
    triang.set_mask(mask)

    # Interpolate the triangle data onto the regular grid.
    interpolator = tri.LinearTriInterpolator(triang, z)
    Xi, Yi = np.meshgrid(xi, yi)
    zi = interpolator(Xi, Yi)

    contourf = ax.contourf(xi, yi, zi, LEVELS, cmap=plt.cm.jet)

    # exist_ok avoids the race between an exists() check and makedirs().
    output_path = '/tmp/{}'.format(variable_name)
    os.makedirs(output_path, exist_ok=True)

    geojsoncontour.contourf_to_geojson(
        contourf=contourf,
        min_angle_deg=3.0,
        ndigits=5,
        stroke_width=2,
        fill_opacity=0.5,
        geojson_filepath=os.path.join(output_path, file_name),
    )

    manifest.setdefault(variable_name, {'geojson': []})['geojson'].append({
        'date': timestamp,
        'path': os.path.join(variable_name, file_name),
    })

    return contourf