Exemple #1
0
def _plot_scalar_grid(grid: Grid, title, colorbar, cmap, figsize, same_scale):
    """
    Draw the first channel component of `grid` with `imshow`, one subplot per batch entry.

    Args:
        grid: Grid whose values are plotted.
        title: Falsy for no titles, a `str` used for every subplot, `True`
            for an automatic "<index> of <batch shape>" title, or a
            tuple/list holding one title per batch entry. Any other truthy
            value results in no title.
        colorbar: Whether to attach a colorbar to each subplot.
        cmap: Colormap passed to `imshow`.
        figsize: Figure size passed to `plt.subplots`.
        same_scale: If true, every subplot receives the same `vmin`/`vmax`,
            taken from `math.min(values)` and `math.max(values)`.

    Returns:
        `(fig, axes)` where `axes` is the array returned by `plt.subplots`,
        or a one-element list if a single axes object was returned.
    """
    def subplot_title(b):
        # Resolve the title for subplot `b`; only called when `title` is truthy.
        if isinstance(title, str):
            return title
        if title is True:
            return f"{b} of {grid.shape.batch}"
        if isinstance(title, (tuple, list)):
            return title[b]
        return None

    n_plots = grid.shape.batch.volume
    joined_channels = math.join_dimensions(grid.values, grid.shape.channel, 'channel')
    scalar = joined_channels.channel[0]
    # Identical limits are handed to every imshow call so the subplots are comparable.
    scale_kwargs = {
        'vmin': math.min(scalar).native(),
        'vmax': math.max(scalar).native(),
    } if same_scale else {}
    per_batch = math.join_dimensions(scalar, grid.shape.batch, 'batch')
    fig, axes = plt.subplots(1, n_plots, figsize=figsize)
    if not isinstance(axes, np.ndarray):
        # plt.subplots handed back a bare axes object; wrap it so indexing works.
        axes = [axes]
    for b in range(n_plots):
        ax = axes[b]
        image = ax.imshow(per_batch.batch[b].numpy('y,x'), origin='lower', cmap=cmap, **scale_kwargs)
        label = subplot_title(b) if title else None
        if label is not None:
            ax.set_title(label)
        if colorbar:
            plt.colorbar(image, ax=ax)
    plt.tight_layout()
    return fig, axes
Exemple #2
0
def plot(field: SampledField, title=False, colorbar=False, cmap='magma', figsize=(12, 5), same_scale=True, **plt_args):
    """
    Draw one matplotlib subplot per batch entry of `field`.

    Grids with a channel volume of 1 are rendered with `imshow`; all other
    grids are rendered as `quiver` arrows. A `StaggeredGrid` is first
    converted with `at_centers()`.

    Args:
        field: Field to visualize. Only `Grid` instances are supported.
        title: Falsy for no titles, a `str` used for every subplot, `True`
            for an automatic "<index> of <batch shape>" title, or a
            tuple/list holding one title per batch entry. Any other truthy
            value results in no title.
        colorbar: Whether to attach a colorbar to each `imshow` subplot.
            Not used for arrow plots.
        cmap: Colormap passed to `imshow`. Not used for arrow plots.
        figsize: Figure size passed to `plt.subplots`.
        same_scale: If true, `vmin`/`vmax` of every `imshow` subplot are set
            from `math.min(values)` and `math.max(values)`, overriding any
            values given in `plt_args`.
        **plt_args: Extra keyword arguments forwarded to `imshow`.
            Not used for arrow plots.

    Returns:
        `(fig, axes)` where `axes` is the array returned by `plt.subplots`,
        or a one-element list if a single axes object was returned.

    Raises:
        NotImplementedError: If `field` is not a `Grid`.
    """
    batch_size = field.shape.batch.volume
    values = math.join_dimensions(field.values, field.shape.channel, 'channel').channel[0]
    if not isinstance(field, Grid):
        # Fail before plt.subplots so an unsupported field does not leave an
        # empty figure behind in pyplot's figure registry.
        raise NotImplementedError(f"No figure recipe for {field}")
    fig, axes = plt.subplots(1, batch_size, figsize=figsize)
    # plt.subplots may hand back a bare axes object; wrap it so indexing works.
    axes = axes if isinstance(axes, np.ndarray) else [axes]
    b_values = math.join_dimensions(values, field.shape.batch, 'batch')
    if title:
        for b in range(batch_size):
            if isinstance(title, str):
                sub_title = title
            elif title is True:
                sub_title = f"{b} of {field.shape.batch}"
            elif isinstance(title, (tuple, list)):
                sub_title = title[b]
            else:
                sub_title = None
            if sub_title is not None:
                axes[b].set_title(sub_title)
    # Individual plots
    if field.shape.channel.volume == 1:
        if same_scale:
            plt_args['vmin'] = math.min(values).native()
            plt_args['vmax'] = math.max(values).native()
        for b in range(batch_size):
            im = axes[b].imshow(b_values.batch[b].numpy('y,x'), origin='lower', cmap=cmap, **plt_args)
            if colorbar:
                plt.colorbar(im, ax=axes[b])
    else:
        if isinstance(field, StaggeredGrid):
            field = field.at_centers()
        # Sample positions and the batch-joined values do not depend on the
        # loop index, so they are computed once instead of once per subplot.
        x, y = field.points.vector.unstack_spatial('x,y', to_numpy=True)
        batched_values = math.join_dimensions(field.values, field.shape.batch, 'batch')
        for b in range(batch_size):
            u, v = batched_values.batch[b].vector.unstack_spatial('x,y', to_numpy=True)
            # Shift the arrow origin back by half the vector so each arrow is
            # centered on its sample point.
            axes[b].quiver(x-u/2, y-v/2, u, v)
    plt.tight_layout()
    return fig, axes
Exemple #3
0
 def test_split_dimension(self):
     """Joining the spatial dimensions and splitting them again restores the tensor."""
     original = math.random_normal(batch=10, x=4, y=3, vector=2)
     spatial_dims = original.shape.spatial
     # Collapse x and y into a single 'points' dimension ...
     flattened = math.join_dimensions(original, spatial_dims, 'points')
     # ... then expand 'points' back into the spatial dimensions.
     restored = flattened.points.split(original.shape.spatial)
     self.assertEqual(original.shape, restored.shape)
     math.assert_close(original, restored)
Exemple #4
0
 def test_join_dimensions(self):
     """Joining the spatial dimensions yields one 'points' dimension and keeps everything else."""
     original = math.random_normal(batch=10, x=4, y=3, vector=2)
     flattened = math.join_dimensions(original, original.shape.spatial, 'points')
     expected_names = ('batch', 'points', 'vector')
     self.assertEqual(expected_names, flattened.shape.names)
     # The number of elements and the non-spatial part of the shape must be unchanged.
     self.assertEqual(original.shape.volume, flattened.shape.volume)
     self.assertEqual(original.shape.non_spatial, flattened.shape.non_spatial)
Exemple #5
0
 def list_cells(self, dim_name):
     """
     Build a `Cuboid` whose centers have this object's spatial dimensions joined into one.

     Args:
         dim_name: Name of the dimension that replaces the spatial dimensions of `self.center`.

     Returns:
         `Cuboid` created from the joined centers and `self.half_size`;
         `half_size` is passed on unchanged.
     """
     joined_centers = math.join_dimensions(
         self.center,
         self._shape.spatial.names,
         dim_name,
     )
     return Cuboid(joined_centers, self.half_size)