def _annotate_points(axis, points: math.Tensor, labelled_dim: math.Shape):
    """
    Write the item names of `labelled_dim` next to 2D points on `axis`.

    Nothing is drawn unless `points` has exactly two `vector` components and the first
    dimension of `labelled_dim` carries (non-empty) item names. Each label is placed at
    its point, shifted by 1% of the current x and y view ranges.

    Args:
        axis: Axis providing `get_xlim`, `get_ylim` and `annotate` (matplotlib-style).
        points: Point coordinates with a `vector` dimension.
        labelled_dim: Shape whose first dimension's item names are used as the labels,
            paired with the points in order.
    """
    if points.shape['vector'].size != 2:
        return
    xs, ys = math.reshaped_native(points, ['vector', points.shape.without('vector')], to_numpy=True, force_expand=True)
    if not labelled_dim.item_names[0]:
        return
    x_lo, x_hi = axis.get_xlim()
    y_lo, y_hi = axis.get_ylim()
    # Offset labels slightly so they do not sit exactly on top of the markers.
    dx = .01 * (x_hi - x_lo)
    dy = .01 * (y_hi - y_lo)
    for px, py, text in zip(xs, ys, labelled_dim.item_names[0]):
        axis.annotate(text, (px + dx, py + dy))
def discretize(grid: Grid, filled_fraction=0.25):
    """
    Binarize `grid`: its largest values become 1, all other cells become 0.

    Treats channel dimensions as batch dimensions: every non-spatial slice of the grid
    is ranked independently over its spatial cells.

    Args:
        grid: Grid whose values are ranked.
        filled_fraction: Fraction of spatial cells per slice to set to 1, within [0, 1].
            Exactly `volume - round(volume * (1 - filled_fraction))` cells are filled per slice.

    Returns:
        `grid` with its values replaced by the 0/1 mask.

    Raises:
        ValueError: If `filled_fraction` lies outside [0, 1].
    """
    import numpy as np
    # Out-of-range fractions would give a negative / oversized slice start below and
    # silently produce a meaningless mask.
    if not 0 <= filled_fraction <= 1:
        raise ValueError(f"filled_fraction must lie in [0, 1] but got {filled_fraction}")
    non_spatial_dims, spatial_dims = grid.shape.non_spatial, grid.shape.spatial
    # The ranking uses NumPy routines, so request a NumPy array regardless of the grid's backend.
    data = math.reshaped_native(grid.values, [non_spatial_dims, spatial_dims], to_numpy=True)
    n_empty = int(round(spatial_dims.volume * (1 - filled_fraction)))
    ranked_idx = np.argsort(data, axis=-1)  # ascending order: largest values come last
    filled_idx = ranked_idx[:, n_empty:]
    filled = np.zeros_like(data)
    np.put_along_axis(filled, filled_idx, 1, axis=-1)
    filled_t = math.reshaped_tensor(filled, [non_spatial_dims, spatial_dims])
    return grid.with_values(filled_t)
def test_reshaped_native(self):
    """Check native shapes returned by `math.reshaped_native` for various groups and `force_expand` settings."""
    a = math.random_uniform(channel(vector=2) & spatial(x=4, y=3))
    # Groups given as names: 'batch' is absent from `a` and therefore has size 1.
    nat = math.reshaped_native(a, ['batch', a.shape.spatial, 'vector'], force_expand=False)
    self.assertEqual((1, 12, 2), nat.shape)
    # A missing dim given with a size (batch=10) only materializes when force_expand selects it.
    nat = math.reshaped_native(a, [batch(batch=10), a.shape.spatial, channel(vector=2)], force_expand=False)
    self.assertEqual((1, 12, 2), nat.shape)
    nat = math.reshaped_native(a, [batch(batch=10), a.shape.spatial, channel(vector=2)], force_expand=['x'])
    self.assertEqual((1, 12, 2), nat.shape)
    nat = math.reshaped_native(a, [batch(batch=10), a.shape.spatial, channel(vector=2)], force_expand=True)
    self.assertEqual((10, 12, 2), nat.shape)
    nat = math.reshaped_native(a, [batch(batch=10), a.shape.spatial, channel(vector=2)], force_expand=['batch'])
    self.assertEqual((10, 12, 2), nat.shape)
    # One group may combine an existing dim (vector) with a new one (v2).
    nat = math.reshaped_native(a, [a.shape.spatial, channel(vector=2, v2=2)], force_expand=False)
    self.assertEqual((12, 4), nat.shape)
    # Omitting the spatial dims x, y must be rejected. The previous try/except swallowed the
    # error and passed even when nothing was raised, so the check could never fail.
    with self.assertRaises(AssertionError):
        math.reshaped_native(a, [channel(vector=2, v2=2)], force_expand=False)
def _plot(data: SampledField,
          fig: graph_objects.Figure,
          size: tuple,
          colormap: str or None,
          show_color_bar: bool,
          row: int = None,
          col: int = None,
          ):
    """
    Add plotly traces visualizing `data` to the subplot of `fig` at (`row`, `col`).

    The recipe depends on the field type and spatial rank:
    line plots for 1D grids, heatmaps or segment-based vector plots for 2D grids,
    volume or cone plots for 3D grids, and quiver or marker plots for 2D / 3D point clouds.
    Point clouds with more than one non-channel dimension are unstacked along the first
    one and each part is plotted recursively.

    Args:
        data: Field to plot.
        fig: Figure receiving the traces.
        size: Figure size. `size[1]` is multiplied by the subplot's vertical domain fraction
            to convert world-space marker sizes of point clouds into plot units.
        colormap: Color map name passed to `get_div_map` for scalar heatmaps / volumes.
        show_color_bar: Whether scalar heatmaps / volumes display their color scale.
        row: Subplot row.
        col: Subplot column.

    Raises:
        NotImplementedError: If there is no recipe for `data`.
    """
    subplot = fig.get_subplot(row, col)
    vector = data.points.shape['vector']
    if data.spatial_rank == 1 and isinstance(data, Grid):
        x = data.points.vector[0].numpy().flatten()
        channels = data.values.shape.channel
        if channels.rank == 1 and channels.get_item_names(0) is not None:
            # One named line per channel item, with a legend showing the names.
            for i, name in enumerate(channels.get_item_names(0)):
                y = math.reshaped_native(real_values(data[{channels.name: i}]), [data.shape.spatial], to_numpy=True)
                fig.add_trace(graph_objects.Scatter(x=x, y=y, mode='lines+markers', name=name), row=row, col=col)
            fig.update_layout(showlegend=True)
        else:
            for ch_idx in channels.meshgrid():
                y = math.reshaped_native(real_values(data[ch_idx]), [data.shape.spatial], to_numpy=True)
                fig.add_trace(graph_objects.Scatter(x=x, y=y, mode='lines+markers', name='Multi-channel'), row=row, col=col)
            fig.update_layout(showlegend=False)
    elif data.spatial_rank == 2 and isinstance(data, Grid) and 'vector' not in data.shape:  # heatmap
        dims = spatial(data)
        values = real_values(data).numpy(dims.reversed)
        x = data.points.vector[dims[0].name].dimension(dims[1].name)[0].numpy()
        y = data.points.vector[dims[1].name].dimension(dims[0].name)[0].numpy()
        min_val, max_val = numpy.nanmin(values), numpy.nanmax(values)
        # Fall back to 0 when the data is all-NaN or contains infinities.
        min_val, max_val = min_val if numpy.isfinite(min_val) else 0, max_val if numpy.isfinite(max_val) else 0
        color_scale = get_div_map(min_val, max_val, equal_scale=True, colormap=colormap)
        fig.add_heatmap(row=row, col=col, x=x, y=y, z=values, zauto=False, zmin=min_val, zmax=max_val, colorscale=color_scale, showscale=show_color_bar)
        subplot.xaxis.update(scaleanchor=f'y{subplot.yaxis.plotly_name[5:]}', scaleratio=1, constrain='domain', title=dims.names[0])
        subplot.yaxis.update(constrain='domain', title=dims.names[1])
    elif data.spatial_rank == 2 and isinstance(data, Grid):  # vector field
        if isinstance(data, StaggeredGrid):
            data = data.at_centers()
        # Group the sample points by the spatial dims only, exactly like the values below.
        # Grouping by `data.shape.without(vector)` also folded extra channel dims into the points,
        # so x, y no longer matched data_x[ch], data_y[ch] in length when extra channels exist.
        x, y = math.reshaped_native(data.points.vector[spatial(data)], [vector, spatial(data)], to_numpy=True, force_expand=True)
        extra_channels = data.shape.channel.without('vector')
        data_x, data_y = math.reshaped_native(data.values, [vector, extra_channels, spatial(data)], to_numpy=True, force_expand=True)
        lower_x, lower_y = [float(l) for l in data.bounds.lower.vector.unstack_spatial('x,y')]
        upper_x, upper_y = [float(u) for u in data.bounds.upper.vector.unstack_spatial('x,y')]
        item_names = extra_channels.get_item_names(0) if extra_channels.rank == 1 else None
        for ch in range(data_x.shape[0]):
            u = data_x[ch].flatten()
            v = data_y[ch].flatten()
            # Each arrow becomes a segment centred on its sample point (3 points per arrow);
            # the None entries break the line between consecutive arrows.
            gaps = [None] * len(x)
            lines_x = numpy.stack([x - u / 2, x + u / 2, gaps], -1).flatten()
            lines_y = numpy.stack([y - v / 2, y + v / 2, gaps], -1).flatten()
            name = item_names[ch] if item_names is not None else None
            fig.add_scatter(x=lines_x, y=lines_y, mode='lines', row=row, col=col, name=name)
        if data_x.shape[0] == 1:
            fig.update_layout(showlegend=False)
        subplot.xaxis.update(range=[lower_x, upper_x])
        subplot.yaxis.update(range=[lower_y, upper_y])
        subplot.xaxis.update(scaleanchor=f'y{subplot.yaxis.plotly_name[5:]}', scaleratio=1, constrain='domain')
        subplot.yaxis.update(constrain='domain')
    elif data.spatial_rank == 3 and isinstance(data, Grid) and data.shape.channel.volume == 1:  # 3D heatmap
        values = real_values(data).numpy('z,y,x')
        x = data.points.vector['x'].numpy('z,y,x')
        y = data.points.vector['y'].numpy('z,y,x')
        z = data.points.vector['z'].numpy('z,y,x')
        min_val, max_val = numpy.nanmin(values), numpy.nanmax(values)
        min_val, max_val = min_val if numpy.isfinite(min_val) else 0, max_val if numpy.isfinite(max_val) else 0
        color_scale = get_div_map(min_val, max_val, equal_scale=True, colormap=colormap)
        fig.add_volume(x=x.flatten(), y=y.flatten(), z=z.flatten(), value=values.flatten(),
                       showscale=show_color_bar, colorscale=color_scale, cmin=min_val, cmax=max_val, cauto=False,
                       isomin=0.1, isomax=0.8,
                       opacity=0.1,  # needs to be small to see through all surfaces
                       surface_count=17,  # needs to be a large number for good volume rendering
                       row=row, col=col)
        fig.update_layout(uirevision=True)
    elif data.spatial_rank == 3 and isinstance(data, Grid):  # 3D vector field
        if isinstance(data, StaggeredGrid):
            data = data.at_centers()
        u = real_values(data).vector['x'].numpy('z,y,x')
        v = real_values(data).vector['y'].numpy('z,y,x')
        w = real_values(data).vector['z'].numpy('z,y,x')
        x = data.points.vector['x'].numpy('z,y,x')
        y = data.points.vector['y'].numpy('z,y,x')
        z = data.points.vector['z'].numpy('z,y,x')
        fig.add_cone(x=x.flatten(), y=y.flatten(), z=z.flatten(), u=u.flatten(), v=v.flatten(), w=w.flatten(),
                     colorscale='Blues',
                     sizemode="absolute", sizeref=1,
                     row=row, col=col)
    elif isinstance(data, PointCloud) and data.spatial_rank == 2 and 'vector' in channel(data):
        x, y = math.reshaped_native(data.points, [vector, data.shape.without('vector')], to_numpy=True, force_expand=True)
        u, v = math.reshaped_native(data.values, [vector, data.shape.without('vector')], to_numpy=True, force_expand=True)
        lower_x, lower_y = [float(d) for d in data.bounds.lower.vector]
        upper_x, upper_y = [float(d) for d in data.bounds.upper.vector]
        subplot.xaxis.update(range=[lower_x, upper_x])
        subplot.yaxis.update(range=[lower_y, upper_y])
        quiver = figure_factory.create_quiver(x, y, u, v, scale=1.0).data[0]  # 7 points per arrow
        if data.color.shape:
            warnings.warn("Multi-colored vector plots not yet supported")
        else:
            color = data.color.native()
            quiver.line.update(color=color)
        fig.add_trace(quiver, row=row, col=col)
        if data.points.vector.item_names:
            subplot.xaxis.update(title=data.points.vector.item_names[0])
            subplot.yaxis.update(title=data.points.vector.item_names[1])
        subplot.xaxis.update(scaleanchor=f'y{subplot.yaxis.plotly_name[5:]}', scaleratio=1, constrain='domain')
        subplot.yaxis.update(constrain='domain')
    elif isinstance(data, PointCloud) and data.spatial_rank == 2:
        lower_x, lower_y = [float(d) for d in data.bounds.lower.vector]
        upper_x, upper_y = [float(d) for d in data.bounds.upper.vector]
        if data.points.shape.non_channel.rank > 1:
            # Several instance / spatial dims: plot each slice along the first one separately.
            data_list = field.unstack(data, data.points.shape.non_channel[0].name)
            for d in data_list:
                _plot(d, fig, size, colormap, show_color_bar, row, col)
        else:
            x, y = [d.numpy() for d in data.points.vector.unstack_spatial('x,y')]
            color = data.color.native()
            subplot_height = (subplot.yaxis.domain[1] - subplot.yaxis.domain[0]) * size[1]
            # Converts world-space lengths into marker diameters.
            world_to_plot = subplot_height / (upper_y - lower_y)
            if isinstance(data.elements, Sphere):
                symbol = 'circle'
                marker_size = data.elements.bounding_radius().numpy() * 1.9
            elif isinstance(data.elements, BaseBox):
                symbol = 'square'
                marker_size = math.mean(data.elements.bounding_half_extent(), 'vector').numpy() * 1
            elif isinstance(data.elements, Point):
                symbol = 'x'
                marker_size = 12 / world_to_plot  # fixed on-screen size after the scaling below
            else:
                symbol = 'asterisk'
                marker_size = data.elements.bounding_radius().numpy()
            marker_size *= world_to_plot
            marker = graph_objects.scatter.Marker(size=marker_size, color=color, sizemode='diameter', symbol=symbol)
            fig.add_scatter(mode='markers', x=x, y=y, marker=marker, row=row, col=col)
        subplot.xaxis.update(range=[lower_x, upper_x])
        subplot.yaxis.update(range=[lower_y, upper_y])
        fig.update_layout(showlegend=False)
        subplot.xaxis.update(scaleanchor=f'y{subplot.yaxis.plotly_name[5:]}', scaleratio=1, constrain='domain')
        subplot.yaxis.update(constrain='domain')
    elif isinstance(data, PointCloud) and data.spatial_rank == 3:
        lower_x, lower_y, lower_z = [float(d) for d in data.bounds.lower.vector.unstack_spatial('x,y,z')]
        upper_x, upper_y, upper_z = [float(d) for d in data.bounds.upper.vector.unstack_spatial('x,y,z')]
        if data.points.shape.non_channel.rank > 1:
            data_list = field.unstack(data, data.points.shape.non_channel[0].name)
            for d in data_list:
                _plot(d, fig, size, colormap, show_color_bar, row, col)
        else:
            x, y, z = [d.numpy() for d in data.points.vector.unstack_spatial('x,y,z')]
            color = data.color.native()
            domain_y = fig.layout[subplot.plotly_name].domain.y
            world_to_plot = size[1] * (domain_y[1] - domain_y[0]) / (upper_y - lower_y) * 0.5
            if isinstance(data.elements, Sphere):
                symbol = 'circle'
                marker_size = data.elements.bounding_radius().numpy() * 2
            elif isinstance(data.elements, BaseBox):
                symbol = 'square'
                marker_size = math.mean(data.elements.bounding_half_extent(), 'vector').numpy() * 1
            elif isinstance(data.elements, Point):
                symbol = 'x'
                marker_size = 4 / world_to_plot  # fixed on-screen size after the scaling below
            else:
                # scatter3d markers only accept circle, cross, diamond, square, x (plus some open variants);
                # 'asterisk' is rejected by plotly's validator, so use 'cross' as the generic fallback.
                symbol = 'cross'
                marker_size = data.elements.bounding_radius().numpy()
            marker_size *= world_to_plot
            marker = graph_objects.scatter3d.Marker(size=marker_size, color=color, sizemode='diameter', symbol=symbol)
            fig.add_scatter3d(mode='markers', x=x, y=y, z=z, marker=marker, row=row, col=col)
        subplot.xaxis.update(range=[lower_x, upper_x])
        subplot.yaxis.update(range=[lower_y, upper_y])
        subplot.zaxis.update(range=[lower_z, upper_z])
        fig.update_layout(showlegend=False)
    else:
        raise NotImplementedError(f"No figure recipe for {data}")
def _plot(axis, data, show_color_bar, vmin, vmax, **plt_args):
    """
    Draw `data` onto the matplotlib axis `axis`.

    Supported recipes:
    - 1D grids: line plots, with real and imaginary parts drawn separately for complex data.
    - 2D scalar grids: `imshow` heatmaps.
    - 2D vector grids: quiver plots.
    - 3D grids: quiver plots for vector data, voxels for scalar data.
    - 2D / 3D point clouds: quiver or marker plots.

    Rank-3 recipes call `set_zlabel` / `voxels` / `set_zlim`, so `axis` must be a 3D axis for them.

    Args:
        axis: Target matplotlib axis.
        data: Field to plot.
        show_color_bar: Whether to add a color bar to 2D scalar heatmaps.
        vmin: Lower color limit for 2D scalar heatmaps.
        vmax: Upper color limit for 2D scalar heatmaps.
        **plt_args: Forwarded to `imshow` (2D scalar grids) and `_plot_points` (2D point clouds).

    Raises:
        NotImplementedError: If there is no recipe for `data`.
    """
    if isinstance(data, Grid) and data.spatial_rank == 1:
        x = data.points.staggered_direction[0].vector[0].numpy()
        requires_legend = False
        for c in channel(data).meshgrid(names=True):
            # Only string item names contribute to the label; integer indices are skipped.
            label = ", ".join([i for i in c.values() if isinstance(i, str)])
            values = data.values[c].numpy()
            if values.dtype in (np.complex64, np.complex128):
                axis.plot(x, values.real, label=f"real({label})" if label else "real")
                # Previously fell back to "real" here, mislabelling the imaginary part.
                axis.plot(x, values.imag, label=f"imag({label})" if label else "imag")
                requires_legend = True
            else:
                axis.plot(x, values, label=label)
                requires_legend = requires_legend or label
        if requires_legend:
            axis.legend()
    elif isinstance(data, Grid) and channel(data).volume == 1 and data.spatial_rank == 2:
        dims = spatial(data)
        if data.bounds.upper.vector.item_names is not None:
            left, bottom = data.bounds.lower.vector[dims]
            right, top = data.bounds.upper.vector[dims]
        else:
            # Bounds without component names: select the vector components by index instead.
            dim_indices = data.resolution.indices(dims)
            left, bottom = data.bounds.lower.vector[dim_indices]
            right, top = data.bounds.upper.vector[dim_indices]
        extent = (float(left), float(right), float(bottom), float(top))
        im = axis.imshow(data.values.numpy(dims.reversed), origin='lower', extent=extent, vmin=vmin, vmax=vmax, **plt_args)
        if show_color_bar:
            axis.figure.colorbar(im, ax=axis)  # adds a new Axis to the figure
        axis.set_xlabel(dims.names[0])
        axis.set_ylabel(dims.names[1])
    elif isinstance(data, Grid) and data.spatial_rank == 2:  # vector field
        if isinstance(data, StaggeredGrid):
            data = data.at_centers()
        x, y = [d.numpy('x,y') for d in data.points.vector.unstack_spatial('x,y')]
        u, v = [d.numpy('x,y') for d in data.values.vector.unstack_spatial('x,y')]
        color = axis.xaxis.label.get_color()
        axis.quiver(x, y, u, v, color=color, units='xy', scale=1)
    elif isinstance(data, Grid) and channel(data).volume > 1 and data.spatial_rank == 3:  # 3D vector field
        # Like the 2D case, sample staggered components at common cell centers so that
        # positions and all three components share one grid.
        if isinstance(data, StaggeredGrid):
            data = data.at_centers()
        x, y, z = [d.numpy('x,y,z') for d in data.points.vector.unstack_spatial('x,y,z')]
        u, v, w = [d.numpy('x,y,z') for d in data.values.vector.unstack_spatial('x,y,z')]
        axis.quiver(x, y, z, u, v, w)
        axis.set_xlabel('x')
        axis.set_ylabel('y')
        axis.set_zlabel('z')
    elif isinstance(data, Grid) and channel(data).volume == 1 and data.spatial_rank == 3:
        x, y, z = [d.numpy('x,y,z') for d in data.points.vector.unstack_spatial('x,y,z')]
        values = data.values.numpy('x,y,z')
        cmap = plt.get_cmap('viridis')
        norm = matplotlib.colors.Normalize(vmin=np.min(values), vmax=np.max(values))
        colors = cmap(norm(values))
        axis.voxels(values, facecolors=colors, edgecolor='k')
    elif isinstance(data, PointCloud) and data.spatial_rank == 2 and 'vector' in channel(data):
        axis.set_aspect('equal', adjustable='box')
        vector = data.points.shape['vector']
        x, y = math.reshaped_native(data.points, [vector, data.shape.without('vector')], to_numpy=True, force_expand=True)
        u, v = math.reshaped_native(data.values, [vector, data.shape.without('vector')], to_numpy=True, force_expand=True)
        lower_x, lower_y = [float(d) for d in data.bounds.lower.vector]
        upper_x, upper_y = [float(d) for d in data.bounds.upper.vector]
        axis.set_xlim((lower_x, upper_x))
        axis.set_ylim((lower_y, upper_y))
        if data.color.shape:
            color = data.color.numpy(data.shape.non_channel).reshape(-1)
        else:
            color = data.color.native()
        axis.quiver(x, y, u, v, color=color, units='xy', scale=1)
        if data.points.vector.item_names:
            axis.set_xlabel(data.points.vector.item_names[0])
            axis.set_ylabel(data.points.vector.item_names[1])
    elif isinstance(data, PointCloud) and data.spatial_rank == 2:
        axis.set_aspect('equal', adjustable='box')
        lower_x, lower_y = [float(d) for d in data.bounds.lower.vector]
        upper_x, upper_y = [float(d) for d in data.bounds.upper.vector]
        axis.set_xlim((lower_x, upper_x))
        axis.set_ylim((lower_y, upper_y))
        if data.points.shape.non_channel.rank > 1:  # multiple instance / spatial dimensions
            data_list = field.unstack(data, data.points.shape.non_channel[0].name)
            for d in data_list:
                _plot_points(axis, d, **plt_args)
        else:
            _plot_points(axis, data, **plt_args)
    elif isinstance(data, PointCloud) and data.spatial_rank == 3:
        if data.points.shape.non_channel.rank > 1:
            data_list = field.unstack(data, data.points.shape.non_channel[0].name)
            for d in data_list:
                _plot(axis, d, show_color_bar, vmin, vmax, **plt_args)
        else:
            x, y, z = [d.numpy() for d in data.points.vector.unstack_spatial('x,y,z')]
            color = [d.native() for d in data.color.points.unstack(len(x))]
            # Diagonal of the data transform, used to convert world-space sizes into marker sizes.
            M = axis.transData.get_matrix()
            x_scale, y_scale, z_scale = M[0, 0], M[1, 1], M[2, 2]
            if isinstance(data.elements, Sphere):
                symbol = 'o'
                size = data.elements.bounding_radius().numpy() * 0.4
            elif isinstance(data.elements, BaseBox):
                symbol = 's'
                size = math.mean(data.elements.bounding_half_extent(), 'vector').numpy() * 0.35
            elif isinstance(data.elements, Point):
                symbol = 'x'
                size = 6 / (0.5 * (x_scale+y_scale+z_scale)/3)
            else:
                symbol = 'X'
                size = data.elements.bounding_radius().numpy()
            axis.scatter(x, y, z, marker=symbol, color=color, s=(size * 0.5 * (x_scale+y_scale+z_scale)/3) ** 2)
        lower_x, lower_y, lower_z = [float(d) for d in data.bounds.lower.vector.unstack_spatial('x,y,z')]
        upper_x, upper_y, upper_z = [float(d) for d in data.bounds.upper.vector.unstack_spatial('x,y,z')]
        axis.set_xlim((lower_x, upper_x))
        axis.set_ylim((lower_y, upper_y))
        axis.set_zlim((lower_z, upper_z))
    else:
        raise NotImplementedError(f"No figure recipe for {data}")