예제 #1
0
def _annotate_points(axis, points: math.Tensor, labelled_dim: math.Shape):
    """
    Write item-name labels next to 2D points on a matplotlib axis.

    Nothing happens unless `points` has a 'vector' dimension of size 2 and the first
    dimension of `labelled_dim` has item names. Each label is shifted by 1% of the
    current x / y view span so that it does not sit directly on its point.
    """
    if points.shape['vector'].size != 2:
        return
    xs, ys = math.reshaped_native(points, ['vector', points.shape.without('vector')], to_numpy=True, force_expand=True)
    names = labelled_dim.item_names[0]
    if not names:
        return
    x_lo, x_hi = axis.get_xlim()
    y_lo, y_hi = axis.get_ylim()
    dx = .01 * (x_hi - x_lo)
    dy = .01 * (y_hi - y_lo)
    for px, py, name in zip(xs, ys, names):
        axis.annotate(name, (px + dx, py + dy))
예제 #2
0
def discretize(grid: Grid, filled_fraction=0.25):
    """
    Binarize `grid` by setting its highest-valued cells to 1 and all other cells to 0.

    Treats channel dimensions as batch dimensions: the ranking is done independently for
    every combination of non-spatial indices, over all spatial cells.

    Args:
        grid: Grid whose values are ranked.
        filled_fraction: Approximate share of spatial cells set to 1. The number of empty
            cells is `round(volume * (1 - filled_fraction))`, clamped to `[0, volume]`.
            Values >= 1 therefore fill every cell and values <= 0 fill none.

    Returns:
        `grid` with the 0/1 values in place of its original values.
    """
    import numpy as np
    # Request NumPy explicitly: all following operations are NumPy routines and would
    # otherwise receive whatever native tensor type backs the grid.
    data = math.reshaped_native(grid.values,
                                [grid.shape.non_spatial, grid.shape.spatial],
                                to_numpy=True)
    volume = grid.shape.spatial.volume
    # Clamp so that filled_fraction > 1 cannot produce a negative slice start,
    # which would fill only part of the cells instead of all of them.
    n_empty = min(max(int(round(volume * (1 - filled_fraction))), 0), volume)
    ranked_idx = np.argsort(data, axis=-1)  # ascending order per non-spatial row
    filled_idx = ranked_idx[:, n_empty:]
    filled = np.zeros_like(data)
    np.put_along_axis(filled, filled_idx, 1, axis=-1)
    filled_t = math.reshaped_tensor(
        filled, [grid.shape.non_spatial, grid.shape.spatial])
    return grid.with_values(filled_t)
예제 #3
0
 def test_reshaped_native(self):
     """
     Check the native shapes returned by `math.reshaped_native` for several groupings and `force_expand` modes.

     `a` has no batch dimension, so a requested batch group keeps size 1 unless expansion of
     that group is forced (`force_expand=True` or a list naming 'batch').
     """
     a = math.random_uniform(channel(vector=2) & spatial(x=4, y=3))
     # A 'batch' group that `a` does not have becomes a singleton axis.
     nat = math.reshaped_native(a, ['batch', a.shape.spatial, 'vector'], force_expand=False)
     self.assertEqual((1, 12, 2), nat.shape)
     nat = math.reshaped_native(a, [batch(batch=10), a.shape.spatial, channel(vector=2)], force_expand=False)
     self.assertEqual((1, 12, 2), nat.shape)
     # Forcing expansion of 'x' only does not expand the batch group.
     nat = math.reshaped_native(a, [batch(batch=10), a.shape.spatial, channel(vector=2)], force_expand=['x'])
     self.assertEqual((1, 12, 2), nat.shape)
     nat = math.reshaped_native(a, [batch(batch=10), a.shape.spatial, channel(vector=2)], force_expand=True)
     self.assertEqual((10, 12, 2), nat.shape)
     nat = math.reshaped_native(a, [batch(batch=10), a.shape.spatial, channel(vector=2)], force_expand=['batch'])
     self.assertEqual((10, 12, 2), nat.shape)
     # The group (vector=2, v2=2) yields one flattened axis of size 4, although `a` has no v2 dim.
     nat = math.reshaped_native(a, [a.shape.spatial, channel(vector=2, v2=2)], force_expand=False)
     self.assertEqual((12, 4), nat.shape)
     # Groups that leave out dims of `a` (x and y here) are expected to raise.
     # assertRaises makes the test fail if nothing is raised.
     with self.assertRaises(AssertionError):
         math.reshaped_native(a, [channel(vector=2, v2=2)], force_expand=False)
예제 #4
0
def _plot(data: SampledField,
          fig: graph_objects.Figure,
          size: tuple,
          colormap: str or None,
          show_color_bar: bool,
          row: int = None, col: int = None,
          ):
    """
    Add traces for one field to a plotly figure, choosing a recipe by field type and spatial rank.

    Recipes, checked in this order:
      * 1D Grid: line plots, one trace per channel item name (or per channel index combination).
      * 2D Grid without a 'vector' dimension: heatmap.
      * other 2D Grid (StaggeredGrid resampled to centers first): one line segment per sample
        point, centered on it, per extra channel component.
      * 3D Grid whose channel volume is 1: volume rendering.
      * other 3D Grid (StaggeredGrid resampled to centers first): cone plot.
      * 2D PointCloud with a 'vector' channel: quiver plot built with `figure_factory.create_quiver`.
      * 2D / 3D PointCloud: markers whose symbol and size depend on the element geometry type.
        If the points have more than one non-channel dimension, the cloud is unstacked along
        the first one and each slice is plotted recursively.

    Args:
        data: Field to plot.
        fig: Figure that receives the traces.
        size: Figure size. Only `size[1]` is read, to convert marker sizes from data units to
            the subplot height (presumably in pixels — TODO confirm).
        colormap: Forwarded to `get_div_map` for the heatmap / volume color scales.
        show_color_bar: Whether heatmap / volume traces show their color scale.
        row, col: Subplot position, passed to `fig.get_subplot` and the trace-adding calls.

    Raises:
        NotImplementedError: if no recipe matches `data`.
    """
    subplot = fig.get_subplot(row, col)
    # The 'vector' dim of the sample points; reused as a reshape group in the vector branches below.
    vector = data.points.shape['vector']
    if data.spatial_rank == 1 and isinstance(data, Grid):
        x = data.points.vector[0].numpy().flatten()
        channels = data.values.shape.channel
        if channels.rank == 1 and channels.get_item_names(0) is not None:
            # Named channel items: one trace per item and show the legend.
            for i, name in enumerate(channels.get_item_names(0)):
                y = math.reshaped_native(real_values(data[{channels.name: i}]), [data.shape.spatial], to_numpy=True)
                fig.add_trace(graph_objects.Scatter(x=x, y=y, mode='lines+markers', name=name), row=row, col=col)
            fig.update_layout(showlegend=True)
        else:
            # Unnamed channels: one trace per index combination, legend hidden.
            for ch_idx in channels.meshgrid():
                y = math.reshaped_native(real_values(data[ch_idx]), [data.shape.spatial], to_numpy=True)
                fig.add_trace(graph_objects.Scatter(x=x, y=y, mode='lines+markers', name='Multi-channel'), row=row, col=col)
            fig.update_layout(showlegend=False)
    elif data.spatial_rank == 2 and isinstance(data, Grid) and 'vector' not in data.shape:  # heatmap
        dims = spatial(data)
        # Reversed dim order: rows follow dims[1] (y), columns follow dims[0] (x).
        values = real_values(data).numpy(dims.reversed)
        x = data.points.vector[dims[0].name].dimension(dims[1].name)[0].numpy()
        y = data.points.vector[dims[1].name].dimension(dims[0].name)[0].numpy()
        min_val, max_val = numpy.nanmin(values), numpy.nanmax(values)
        # Fall back to 0 when the range is not finite (e.g. all-NaN or infinite data).
        min_val, max_val = min_val if numpy.isfinite(min_val) else 0, max_val if numpy.isfinite(max_val) else 0
        color_scale = get_div_map(min_val, max_val, equal_scale=True, colormap=colormap)
        # color_bar = graph_objects.heatmap.ColorBar(x=1.15)   , colorbar=color_bar
        fig.add_heatmap(row=row, col=col, x=x, y=y, z=values, zauto=False, zmin=min_val, zmax=max_val, colorscale=color_scale, showscale=show_color_bar)
        # Lock the x/y aspect ratio to 1. plotly_name[5:] strips the 'yaxis' prefix,
        # presumably leaving the subplot suffix such as '2' for 'yaxis2'.
        subplot.xaxis.update(scaleanchor=f'y{subplot.yaxis.plotly_name[5:]}', scaleratio=1, constrain='domain', title=dims.names[0])
        subplot.yaxis.update(constrain='domain', title=dims.names[1])
    elif data.spatial_rank == 2 and isinstance(data, Grid):  # vector field
        if isinstance(data, StaggeredGrid):
            data = data.at_centers()
        x, y = math.reshaped_native(data.points.vector[spatial(data)], [vector, data.shape.without(vector)], to_numpy=True, force_expand=True)
        extra_channels = data.shape.channel.without('vector')
        data_x, data_y = math.reshaped_native(data.values, [vector, extra_channels, spatial(data)], to_numpy=True, force_expand=True)
        lower_x, lower_y = [float(l) for l in data.bounds.lower.vector.unstack_spatial('x,y')]
        upper_x, upper_y = [float(u) for u in data.bounds.upper.vector.unstack_spatial('x,y')]
        x_range = [lower_x, upper_x]
        y_range = [lower_y, upper_y]
        # One scatter trace per extra (non-'vector') channel component.
        for ch in range(data_x.shape[0]):
            # quiver = figure_factory.create_quiver(x, y, data_x[ch], data_y[ch], scale=1.0)  # 7 points per arrow
            # fig.add_trace(quiver, row=row, col=col)
            data_y_flat = data_y[ch].flatten()
            data_x_flat = data_x[ch].flatten()
            # lines_y = numpy.stack([y, y + data_y_flat, [None] * len(x)], -1).flatten()  # 3 points per arrow
            # lines_x = numpy.stack([x, x + data_x_flat, [None] * len(x)], -1).flatten()
            # Each vector becomes a segment centered on its sample point (start, end, None);
            # the None entries break the line between consecutive segments.
            lines_y = numpy.stack([y - data_y_flat / 2, y + data_y_flat / 2, [None] * len(x)], -1).flatten()  # 3 points per arrow
            lines_x = numpy.stack([x - data_x_flat / 2, x + data_x_flat / 2, [None] * len(x)], -1).flatten()
            name = extra_channels.get_item_names(0)[ch] if extra_channels.rank == 1 and extra_channels.get_item_names(0) is not None else None
            fig.add_scatter(x=lines_x, y=lines_y, mode='lines', row=row, col=col, name=name)
        if data_x.shape[0] == 1:
            fig.update_layout(showlegend=False)
        subplot.xaxis.update(range=x_range)
        subplot.yaxis.update(range=y_range)
        subplot.xaxis.update(scaleanchor=f'y{subplot.yaxis.plotly_name[5:]}', scaleratio=1, constrain='domain')
        subplot.yaxis.update(constrain='domain')
    elif data.spatial_rank == 3 and isinstance(data, Grid) and data.shape.channel.volume == 1:  # 3D heatmap
        values = real_values(data).numpy('z,y,x')
        x = data.points.vector['x'].numpy('z,y,x')
        y = data.points.vector['y'].numpy('z,y,x')
        z = data.points.vector['z'].numpy('z,y,x')
        min_val, max_val = numpy.nanmin(values), numpy.nanmax(values)
        min_val, max_val = min_val if numpy.isfinite(min_val) else 0, max_val if numpy.isfinite(max_val) else 0
        color_scale = get_div_map(min_val, max_val, equal_scale=True, colormap=colormap)
        fig.add_volume(x=x.flatten(), y=y.flatten(), z=z.flatten(), value=values.flatten(),
                       showscale=show_color_bar, colorscale=color_scale, cmin=min_val, cmax=max_val, cauto=False,
                       isomin=0.1, isomax=0.8,
                       opacity=0.1,  # needs to be small to see through all surfaces
                       surface_count=17,  # needs to be a large number for good volume rendering
                       row=row, col=col)
        fig.update_layout(uirevision=True)
    elif data.spatial_rank == 3 and isinstance(data, Grid):  # 3D vector field
        if isinstance(data, StaggeredGrid):
            data = data.at_centers()
        u = real_values(data).vector['x'].numpy('z,y,x')
        v = real_values(data).vector['y'].numpy('z,y,x')
        w = real_values(data).vector['z'].numpy('z,y,x')
        x = data.points.vector['x'].numpy('z,y,x')
        y = data.points.vector['y'].numpy('z,y,x')
        z = data.points.vector['z'].numpy('z,y,x')
        fig.add_cone(x=x.flatten(), y=y.flatten(), z=z.flatten(), u=u.flatten(), v=v.flatten(), w=w.flatten(),
                     colorscale='Blues',
                     sizemode="absolute", sizeref=1,
                     row=row, col=col)
    elif isinstance(data, PointCloud) and data.spatial_rank == 2 and 'vector' in channel(data):
        x, y = math.reshaped_native(data.points, [vector, data.shape.without('vector')], to_numpy=True, force_expand=True)
        u, v = math.reshaped_native(data.values, [vector, data.shape.without('vector')], to_numpy=True, force_expand=True)
        lower_x, lower_y = [float(d) for d in data.bounds.lower.vector]
        upper_x, upper_y = [float(d) for d in data.bounds.upper.vector]
        subplot.xaxis.update(range=[lower_x, upper_x])
        subplot.yaxis.update(range=[lower_y, upper_y])
        quiver = figure_factory.create_quiver(x, y, u, v, scale=1.0).data[0]  # 7 points per arrow
        if data.color.shape:
            # Per-point colors are not applied: only a warning is emitted and the default line color is kept.
            # color = data.color.numpy(data.shape.non_channel).reshape(-1)
            warnings.warn("Multi-colored vector plots not yet supported")
        else:
            color = data.color.native()
            quiver.line.update(color=color)
        fig.add_trace(quiver, row=row, col=col)
        if data.points.vector.item_names:
            subplot.xaxis.update(title=data.points.vector.item_names[0])
            subplot.yaxis.update(title=data.points.vector.item_names[1])
        subplot.xaxis.update(scaleanchor=f'y{subplot.yaxis.plotly_name[5:]}', scaleratio=1, constrain='domain')
        subplot.yaxis.update(constrain='domain')
    elif isinstance(data, PointCloud) and data.spatial_rank == 2:
        lower_x, lower_y = [float(d) for d in data.bounds.lower.vector]
        upper_x, upper_y = [float(d) for d in data.bounds.upper.vector]
        if data.points.shape.non_channel.rank > 1:
            # Several non-channel dims: unstack the first one and plot each slice recursively.
            data_list = field.unstack(data, data.points.shape.non_channel[0].name)
            for d in data_list:
                _plot(d, fig, size, colormap, show_color_bar, row, col)
        else:
            x, y = [d.numpy() for d in data.points.vector.unstack_spatial('x,y')]
            color = data.color.native()
            subplot_height = (subplot.yaxis.domain[1] - subplot.yaxis.domain[0]) * size[1]
            if isinstance(data.elements, Sphere):
                symbol = 'circle'
                marker_size = data.elements.bounding_radius().numpy() * 1.9
            elif isinstance(data.elements, BaseBox):
                symbol = 'square'
                marker_size = math.mean(data.elements.bounding_half_extent(), 'vector').numpy() * 1
            elif isinstance(data.elements, Point):
                symbol = 'x'
                # Pre-divided so that the scaling below cancels out: final marker size is 12.
                marker_size = 12 / (subplot_height / (upper_y - lower_y))
            else:
                symbol = 'asterisk'
                marker_size = data.elements.bounding_radius().numpy()
            # Convert the marker size from data units to subplot-height units.
            marker_size *= subplot_height / (upper_y - lower_y)
            marker = graph_objects.scatter.Marker(size=marker_size, color=color, sizemode='diameter', symbol=symbol)
            fig.add_scatter(mode='markers', x=x, y=y, marker=marker, row=row, col=col)
        subplot.xaxis.update(range=[lower_x, upper_x])
        subplot.yaxis.update(range=[lower_y, upper_y])
        fig.update_layout(showlegend=False)
        subplot.xaxis.update(scaleanchor=f'y{subplot.yaxis.plotly_name[5:]}', scaleratio=1, constrain='domain')
        subplot.yaxis.update(constrain='domain')
    elif isinstance(data, PointCloud) and data.spatial_rank == 3:
        lower_x, lower_y, lower_z = [float(d) for d in data.bounds.lower.vector.unstack_spatial('x,y,z')]
        upper_x, upper_y, upper_z = [float(d) for d in data.bounds.upper.vector.unstack_spatial('x,y,z')]
        if data.points.shape.non_channel.rank > 1:
            # Several non-channel dims: unstack the first one and plot each slice recursively.
            data_list = field.unstack(data, data.points.shape.non_channel[0].name)
            for d in data_list:
                _plot(d, fig, size, colormap, show_color_bar, row, col)
        else:
            x, y, z = [d.numpy() for d in data.points.vector.unstack_spatial('x,y,z')]
            color = data.color.native()
            # if data.color.shape.instance_rank == 0:
            #     color = str(data.color)
            # else:
            #     color = [str(d) for d in math.unstack(data.color, instance)]
            domain_y = fig.layout[subplot.plotly_name].domain.y
            if isinstance(data.elements, Sphere):
                symbol = 'circle'
                marker_size = data.elements.bounding_radius().numpy() * 2
            elif isinstance(data.elements, BaseBox):
                symbol = 'square'
                marker_size = math.mean(data.elements.bounding_half_extent(), 'vector').numpy() * 1
            elif isinstance(data.elements, Point):
                symbol = 'x'
                # Pre-divided so that the scaling below cancels out: final marker size is 4.
                marker_size = 4 / (size[1] * (domain_y[1] - domain_y[0]) / (upper_y - lower_y) * 0.5)
            else:
                symbol = 'asterisk'
                marker_size = data.elements.bounding_radius().numpy()
            # Convert from data units using half the scene height per y-range unit.
            marker_size *= size[1] * (domain_y[1] - domain_y[0]) / (upper_y - lower_y) * 0.5
            marker = graph_objects.scatter3d.Marker(size=marker_size, color=color, sizemode='diameter', symbol=symbol)
            fig.add_scatter3d(mode='markers', x=x, y=y, z=z, marker=marker, row=row, col=col)
        subplot.xaxis.update(range=[lower_x, upper_x])
        subplot.yaxis.update(range=[lower_y, upper_y])
        subplot.zaxis.update(range=[lower_z, upper_z])
        fig.update_layout(showlegend=False)
    else:
        raise NotImplementedError(f"No figure recipe for {data}")
예제 #5
0
def _plot(axis, data, show_color_bar, vmin, vmax, **plt_args):
    """
    Draw one field onto a matplotlib axis, choosing a recipe by field type and spatial rank.

    Recipes, checked in this order:
      * 1D Grid: one line per channel component. Complex values are drawn as two lines
        labelled real(...) and imag(...).
      * 2D Grid with a single channel component: `imshow` heatmap spanning the grid bounds.
      * other 2D Grid (StaggeredGrid resampled to centers first): quiver plot.
      * 3D Grid with several channel components: 3D quiver plot.
      * 3D Grid with one channel component: voxel plot colored with the 'viridis' colormap.
      * 2D PointCloud with a 'vector' channel: quiver plot.
      * 2D PointCloud: delegated to `_plot_points`, once per slice of the first non-channel
        dim if the points have more than one non-channel dim.
      * 3D PointCloud: scatter plot with a marker per element geometry type. Extra
        non-channel dims are unstacked and plotted recursively.

    Args:
        axis: matplotlib axis to draw on (the 3D recipes presumably need a 3D axis — not checked here).
        data: Field to plot.
        show_color_bar: Add a color bar next to the 2D heatmap.
        vmin, vmax: Color limits of the 2D heatmap.
        **plt_args: Extra keyword arguments for `imshow` (2D heatmap) and `_plot_points` (2D point clouds).

    Raises:
        NotImplementedError: if no recipe matches `data`.
    """
    if isinstance(data, Grid) and data.spatial_rank == 1:
        x = data.points.staggered_direction[0].vector[0].numpy()
        requires_legend = False
        for c in channel(data).meshgrid(names=True):
            # Legend label from string item names only; integer indices are skipped.
            label = ", ".join([i for dim, i in c.items() if isinstance(i, str)])
            values = data.values[c].numpy()
            if np.iscomplexobj(values):
                # Complex values: draw real and imaginary parts as two separate lines.
                axis.plot(x, values.real, label=f"real({label})" if label else "real")
                axis.plot(x, values.imag, label=f"imag({label})" if label else "imag")
                requires_legend = True
            else:
                axis.plot(x, values, label=label)
                requires_legend = requires_legend or bool(label)
        if requires_legend:
            axis.legend()
    elif isinstance(data, Grid) and channel(data).volume == 1 and data.spatial_rank == 2:
        dims = spatial(data)
        # Index the bound vectors with the spatial dims directly when they carry item names,
        # otherwise via the positional indices of those dims in the resolution.
        if data.bounds.upper.vector.item_names is not None:
            left, bottom = data.bounds.lower.vector[dims]
            right, top = data.bounds.upper.vector[dims]
        else:
            dim_indices = data.resolution.indices(dims)
            left, bottom = data.bounds.lower.vector[dim_indices]
            right, top = data.bounds.upper.vector[dim_indices]
        extent = (float(left), float(right), float(bottom), float(top))
        # Reversed dim order: rows follow dims[1], columns follow dims[0], matching the axis labels below.
        im = axis.imshow(data.values.numpy(dims.reversed), origin='lower', extent=extent, vmin=vmin, vmax=vmax, **plt_args)
        if show_color_bar:
            axis.figure.colorbar(im, ax=axis)  # adds a new Axis to the figure
        axis.set_xlabel(dims.names[0])
        axis.set_ylabel(dims.names[1])
    elif isinstance(data, Grid) and data.spatial_rank == 2:  # vector field
        if isinstance(data, StaggeredGrid):
            data = data.at_centers()
        x, y = [d.numpy('x,y') for d in data.points.vector.unstack_spatial('x,y')]
        u, v = [d.numpy('x,y') for d in data.values.vector.unstack_spatial('x,y')]
        # Draw arrows in the same color as the axis labels.
        color = axis.xaxis.label.get_color()
        axis.quiver(x, y, u, v, color=color, units='xy', scale=1)
        axis.set_aspect('equal', adjustable='box')
    elif isinstance(data, Grid) and channel(data).volume > 1 and data.spatial_rank == 3:
        x, y, z = [d.numpy('x,y,z') for d in data.points.vector.unstack_spatial('x,y,z')]
        u, v, w = [d.numpy('x,y,z') for d in data.values.vector.unstack_spatial('x,y,z')]
        axis.quiver(x, y, z, u, v, w)
        axis.set_xlabel('x')
        axis.set_ylabel('y')
        axis.set_zlabel('z')
    elif isinstance(data, Grid) and channel(data).volume == 1 and data.spatial_rank == 3:
        # NOTE(review): x, y, z are not passed to voxels() below, so voxels are drawn in index coordinates.
        x, y, z = [d.numpy('x,y,z') for d in data.points.vector.unstack_spatial('x,y,z')]
        values = data.values.numpy('x,y,z')
        cmap = plt.get_cmap('viridis')
        norm = matplotlib.colors.Normalize(vmin=np.min(values), vmax=np.max(values))
        colors = cmap(norm(values))
        axis.voxels(values, facecolors=colors, edgecolor='k')
    elif isinstance(data, PointCloud) and data.spatial_rank == 2 and 'vector' in channel(data):
        axis.set_aspect('equal', adjustable='box')
        vector = data.points.shape['vector']
        x, y = math.reshaped_native(data.points, [vector, data.shape.without('vector')], to_numpy=True, force_expand=True)
        u, v = math.reshaped_native(data.values, [vector, data.shape.without('vector')], to_numpy=True, force_expand=True)
        lower_x, lower_y = [float(d) for d in data.bounds.lower.vector]
        upper_x, upper_y = [float(d) for d in data.bounds.upper.vector]
        axis.set_xlim((lower_x, upper_x))
        axis.set_ylim((lower_y, upper_y))
        if data.color.shape:
            # Per-point colors, flattened in non-channel order.
            color = data.color.numpy(data.shape.non_channel).reshape(-1)
        else:
            color = data.color.native()
        axis.quiver(x, y, u, v, color=color, units='xy', scale=1)
        if data.points.vector.item_names:
            axis.set_xlabel(data.points.vector.item_names[0])
            axis.set_ylabel(data.points.vector.item_names[1])
    elif isinstance(data, PointCloud) and data.spatial_rank == 2:
        axis.set_aspect('equal', adjustable='box')
        lower_x, lower_y = [float(d) for d in data.bounds.lower.vector]
        upper_x, upper_y = [float(d) for d in data.bounds.upper.vector]
        axis.set_xlim((lower_x, upper_x))
        axis.set_ylim((lower_y, upper_y))
        if data.points.shape.non_channel.rank > 1:  # multiple instance / spatial dimensions
            data_list = field.unstack(data, data.points.shape.non_channel[0].name)
            for d in data_list:
                _plot_points(axis, d, **plt_args)
        else:
            _plot_points(axis, data, **plt_args)
    elif isinstance(data, PointCloud) and data.spatial_rank == 3:
        if data.points.shape.non_channel.rank > 1:
            # Several non-channel dims: unstack the first one and plot each slice recursively.
            data_list = field.unstack(data, data.points.shape.non_channel[0].name)
            for d in data_list:
                _plot(axis, d, show_color_bar, vmin, vmax, **plt_args)
        else:
            x, y, z = [d.numpy() for d in data.points.vector.unstack_spatial('x,y,z')]
            color = [d.native() for d in data.color.points.unstack(len(x))]
            # NOTE(review): transData.get_matrix() is presumably the 3x3 affine of the 2D display
            # transform, in which case M[2, 2] is always 1 and z_scale carries no information — confirm.
            M = axis.transData.get_matrix()
            x_scale, y_scale, z_scale = M[0, 0], M[1, 1], M[2, 2]
            if isinstance(data.elements, Sphere):
                symbol = 'o'
                size = data.elements.bounding_radius().numpy() * 0.4
            elif isinstance(data.elements, BaseBox):
                symbol = 's'
                size = math.mean(data.elements.bounding_half_extent(), 'vector').numpy() * 0.35
            elif isinstance(data.elements, Point):
                symbol = 'x'
                # Pre-divided so that the scaling in scatter() cancels out: marker area is 6 ** 2.
                size = 6 / (0.5 * (x_scale+y_scale+z_scale)/3)
            else:
                symbol = 'X'
                size = data.elements.bounding_radius().numpy()
            # scatter's `s` is a marker area, hence the square of the scaled size.
            axis.scatter(x, y, z, marker=symbol, color=color, s=(size * 0.5 * (x_scale+y_scale+z_scale)/3) ** 2)
        lower_x, lower_y, lower_z = [float(d) for d in data.bounds.lower.vector.unstack_spatial('x,y,z')]
        upper_x, upper_y, upper_z = [float(d) for d in data.bounds.upper.vector.unstack_spatial('x,y,z')]
        axis.set_xlim((lower_x, upper_x))
        axis.set_ylim((lower_y, upper_y))
        axis.set_zlim((lower_z, upper_z))
    else:
        raise NotImplementedError(f"No figure recipe for {data}")