def _plot_scalar_grid(grid: Grid, title, colorbar, cmap, figsize, same_scale):
    """
    Draw one `imshow` subplot per batch entry of a grid, using its first channel component.

    Args:
        grid: Grid to draw. Its channel dimensions are joined and component 0 is plotted.
        title: A `str` used for every subplot, `True` for generated
            "<index> of <batch shape>" titles, or a tuple/list holding one title per
            batch entry. Falsy values, or values of any other type, add no titles.
        colorbar: Whether to attach a colorbar to each subplot.
        cmap: Colormap passed to `imshow`.
        figsize: Figure size passed to `plt.subplots`.
        same_scale: If True, every subplot uses `vmin`/`vmax` taken from the
            minimum/maximum over all plotted values.

    Returns:
        `(fig, axes)` where `axes` is a NumPy array of axes or a one-element list.
    """
    n_batch = grid.shape.batch.volume
    scalar = math.join_dimensions(grid.values, grid.shape.channel, 'channel').channel[0]
    if same_scale:
        imshow_kwargs = dict(vmin=math.min(scalar).native(), vmax=math.max(scalar).native())
    else:
        imshow_kwargs = {}
    per_batch = math.join_dimensions(scalar, grid.shape.batch, 'batch')
    fig, axes = plt.subplots(1, n_batch, figsize=figsize)
    # A single subplot comes back as a bare Axes object; wrap it so indexing works uniformly.
    if not isinstance(axes, np.ndarray):
        axes = [axes]
    for index in range(n_batch):
        ax = axes[index]
        image = ax.imshow(per_batch.batch[index].numpy('y,x'), origin='lower', cmap=cmap, **imshow_kwargs)
        label = None
        if title:
            if isinstance(title, str):
                label = title
            elif title is True:
                label = f"{index} of {grid.shape.batch}"
            elif isinstance(title, (tuple, list)):
                label = title[index]
        if label is not None:
            ax.set_title(label)
        if colorbar:
            plt.colorbar(image, ax=ax)
    plt.tight_layout()
    return fig, axes
def plot(field: SampledField, title=False, colorbar=False, cmap='magma', figsize=(12, 5), same_scale=True, **plt_args):
    """
    Plot each batch entry of `field` in its own subplot, arranged in one row.

    Grids with a single channel component are drawn with `imshow`. Other grids are
    drawn as quiver (arrow) plots of their 'x'/'y' vector components; staggered
    grids are first resampled with `at_centers()`.

    Args:
        field: Field to plot. Only `Grid` instances are supported.
        title: A `str` used for every subplot, `True` for generated
            "<index> of <batch shape>" titles, or a tuple/list holding one title per
            batch entry. Falsy values, or values of any other type, add no titles.
        colorbar: If True, add a colorbar to each scalar subplot.
        cmap: Colormap for scalar plots.
        figsize: Figure size passed to `plt.subplots`.
        same_scale: If True, scalar subplots share one color range spanning the
            minimum and maximum of all plotted values. Explicit `vmin`/`vmax`
            given in `plt_args` take precedence over this shared range.
        **plt_args: Extra keyword arguments forwarded to `imshow` for scalar plots.

    Returns:
        `(fig, axes)` where `axes` is a NumPy array of axes or a one-element list.

    Raises:
        NotImplementedError: If `field` is not a `Grid`.
    """
    if not isinstance(field, Grid):
        # Reject unsupported fields before plt.subplots() so no empty figure is left open.
        raise NotImplementedError(f"No figure recipe for {field}")
    batch_size = field.shape.batch.volume
    fig, axes = plt.subplots(1, batch_size, figsize=figsize)
    # A single subplot comes back as a bare Axes object; wrap it so indexing works uniformly.
    axes = axes if isinstance(axes, np.ndarray) else [axes]
    if title:
        for b in range(batch_size):
            if isinstance(title, str):
                sub_title = title
            elif title is True:
                sub_title = f"{b} of {field.shape.batch}"
            elif isinstance(title, (tuple, list)):
                sub_title = title[b]
            else:
                sub_title = None
            if sub_title is not None:
                axes[b].set_title(sub_title)
    # Individual plots
    if field.shape.channel.volume == 1:
        values = math.join_dimensions(field.values, field.shape.channel, 'channel').channel[0]
        if same_scale:
            # Caller-supplied vmin/vmax must not be overwritten by the shared range.
            plt_args.setdefault('vmin', math.min(values).native())
            plt_args.setdefault('vmax', math.max(values).native())
        b_values = math.join_dimensions(values, field.shape.batch, 'batch')
        for b in range(batch_size):
            im = axes[b].imshow(b_values.batch[b].numpy('y,x'), origin='lower', cmap=cmap, **plt_args)
            if colorbar:
                plt.colorbar(im, ax=axes[b])
    else:
        if isinstance(field, StaggeredGrid):
            field = field.at_centers()
        # Sample positions and the batch-joined values do not depend on the batch index; compute once.
        x, y = field.points.vector.unstack_spatial('x,y', to_numpy=True)
        b_vectors = math.join_dimensions(field.values, field.shape.batch, 'batch')
        for b in range(batch_size):
            u, v = b_vectors.batch[b].vector.unstack_spatial('x,y', to_numpy=True)
            # Shift arrow tails back by half a vector so each arrow is centered on its sample point.
            axes[b].quiver(x - u / 2, y - v / 2, u, v)
    plt.tight_layout()
    return fig, axes
def test_split_dimension(self):
    """Joining the spatial dimensions into 'points' and splitting them again restores the tensor."""
    original = math.random_normal(batch=10, x=4, y=3, vector=2)
    spatial = original.shape.spatial
    flattened = math.join_dimensions(original, spatial, 'points')
    restored = flattened.points.split(spatial)
    self.assertEqual(original.shape, restored.shape)
    math.assert_close(original, restored)
def test_join_dimensions(self):
    """Joining the spatial dimensions into 'points' keeps the volume and the non-spatial dimensions."""
    tensor = math.random_normal(batch=10, x=4, y=3, vector=2)
    joined = math.join_dimensions(tensor, tensor.shape.spatial, 'points')
    joined_shape = joined.shape
    self.assertEqual(('batch', 'points', 'vector'), joined_shape.names)
    self.assertEqual(tensor.shape.volume, joined_shape.volume)
    self.assertEqual(tensor.shape.non_spatial, joined_shape.non_spatial)
def list_cells(self, dim_name):
    """
    Return the cells as a `Cuboid` whose centers are listed along a single dimension.

    All spatial dimensions of `self.center` are joined into one dimension named
    `dim_name`; `self.half_size` is passed through unchanged.
    """
    return Cuboid(math.join_dimensions(self.center, self._shape.spatial.names, dim_name), self.half_size)