예제 #1
0
def plotfield(field,
              show_time=None,
              domain=None,
              depth_level=0,
              projection=None,
              land=True,
              vmin=None,
              vmax=None,
              savefile=None,
              **kwargs):
    """Function to plot a Parcels Field

    A Field is drawn as a colour mesh; a VectorField is drawn as unit-length
    arrows coloured by speed.

    :param field: Field or VectorField object to plot
    :param show_time: Time at which to show the Field (default: first time of the grid)
    :param domain: dictionary (with keys 'N', 'S', 'E', 'W') defining domain to show
    :param depth_level: depth level to be plotted (default 0)
    :param projection: type of cartopy projection to use (default PlateCarree)
    :param land: Boolean whether to show land. This is ignored for flat meshes
    :param vmin: minimum colour scale (default: minimum of the plotted data/speed)
    :param vmax: maximum colour scale (default: maximum of the plotted data/speed)
    :param savefile: Name of a file to save the plot to
    :param kwargs: only ``titlestr`` is read; it is prepended to the plot title
    :returns: tuple ``(plt, fig, ax, cartopy)``, or four ``None`` values when
              the figure axes could not be created
    :raises RuntimeError: if field is neither a Field nor a VectorField, or if
                          the components of a VectorField have different longitudes
    :raises TimeExtrapolationError: if show_time lies outside the grid's time range
    """

    if type(field) is VectorField:
        spherical = field.U.grid.mesh == 'spherical'
        field = [field.U, field.V]
        plottype = 'vector'
    elif type(field) is Field:
        spherical = field.grid.mesh == 'spherical'
        field = [field]
        plottype = 'scalar'
    else:
        raise RuntimeError('field needs to be a Field or VectorField object')

    if field[0].grid.gtype in [
            GridCode.CurvilinearZGrid, GridCode.CurvilinearSGrid
    ]:
        logger.warning(
            'Field.show() does not always correctly determine the domain for curvilinear grids. '
            'Use plotting with caution and perhaps use domain argument as in the NEMO 3D tutorial'
        )

    plt, fig, ax, cartopy = create_parcelsfig_axis(spherical,
                                                   land,
                                                   projection=projection)
    if plt is None:
        return None, None, None, None  # creating axes was not possible

    # Per-component (index 0 = scalar field or U, index 1 = V) data and coordinates
    data = {}
    plotlon = {}
    plotlat = {}
    for i, fld in enumerate(field):
        show_time = fld.grid.time[0] if show_time is None else show_time
        if fld.grid.defer_load:
            fld.fieldset.computeTimeChunk(show_time, 1)
        (idx, periods) = fld.time_index(show_time)
        # Shift show_time back by the number of whole periods reported by time_index
        show_time -= periods * (fld.grid.time_full[-1] - fld.grid.time_full[0])
        if show_time > fld.grid.time[-1] or show_time < fld.grid.time[0]:
            raise TimeExtrapolationError(show_time, field=fld, msg='show_time')

        latN, latS, lonE, lonW = parsedomain(domain, fld)
        if isinstance(fld.grid, CurvilinearGrid):
            plotlon[i] = fld.grid.lon[latS:latN, lonW:lonE]
            plotlat[i] = fld.grid.lat[latS:latN, lonW:lonE]
        else:
            plotlon[i] = fld.grid.lon[lonW:lonE]
            plotlat[i] = fld.grid.lat[latS:latN]
        if i > 0 and not np.allclose(plotlon[i], plotlon[0]):
            raise RuntimeError(
                'VectorField needs to be on an A-grid for plotting')

        # Interpolate in time only when the grid holds more than one time step
        if fld.grid.time.size > 1:
            fulldata = np.squeeze(
                fld.temporal_interpolate_fullfield(idx, show_time))
        else:
            fulldata = np.squeeze(fld.data)
        if fld.grid.zdim > 1:
            data[i] = fulldata[depth_level, latS:latN, lonW:lonE]
        else:
            data[i] = fulldata[latS:latN, lonW:lonE]

    if plottype == 'vector':
        if field[0].interp_method == 'cgrid_velocity':
            logger.warning_once(
                'Plotting a C-grid velocity field is achieved via an A-grid projection, reducing the plot accuracy'
            )
            # Average U along the first axis; the last row has no neighbour and is copied
            u_avg = np.empty_like(data[0])
            u_avg[:-1, :] = (data[0][:-1, :] + data[0][1:, :]) / 2.
            u_avg[-1, :] = data[0][-1, :]
            # Average V along the second axis. This must read data[1]: deriving it
            # from data[0] would replace the V component by a smoothed copy of U.
            v_avg = np.empty_like(data[1])
            v_avg[:, :-1] = (data[1][:, :-1] + data[1][:, 1:]) / 2.
            v_avg[:, -1] = data[1][:, -1]
            data[0] = u_avg
            data[1] = v_avg

        spd = data[0]**2 + data[1]**2
        speed = np.where(spd > 0, np.sqrt(spd), 0)
        vmin = speed.min() if vmin is None else vmin
        vmax = speed.max() if vmax is None else vmax
        if isinstance(field[0].grid, CurvilinearGrid):
            x, y = plotlon[0], plotlat[0]
        else:
            x, y = np.meshgrid(plotlon[0], plotlat[0])
        # Arrows are normalised to unit length; magnitude is conveyed by colour
        u = np.where(speed > 0., data[0] / speed, 0)
        v = np.where(speed > 0., data[1] / speed, 0)
        if cartopy:
            cs = ax.quiver(np.asarray(x),
                           np.asarray(y),
                           np.asarray(u),
                           np.asarray(v),
                           speed,
                           cmap=plt.cm.gist_ncar,
                           clim=[vmin, vmax],
                           scale=50,
                           transform=cartopy.crs.PlateCarree())
        else:
            cs = ax.quiver(x,
                           y,
                           u,
                           v,
                           speed,
                           cmap=plt.cm.gist_ncar,
                           clim=[vmin, vmax],
                           scale=50)
    else:
        vmin = data[0].min() if vmin is None else vmin
        vmax = data[0].max() if vmax is None else vmax
        assert len(data[0].shape) == 2
        # Every branch yields an (M-1, N-1) array of cell values, so that the
        # M x N coordinates act as cell corners for pcolormesh.
        if field[0].interp_method == 'cgrid_tracer':
            d = data[0][1:, 1:]
        elif field[0].interp_method == 'cgrid_velocity':
            if field[0].fieldtype == 'U':
                # No padding to the full shape: the extra row/column would hold
                # uninitialised values.
                d = (data[0][1:, :-1] + data[0][1:, 1:]) / 2.
            elif field[0].fieldtype == 'V':
                d = (data[0][:-1, 1:] + data[0][1:, 1:]) / 2.
            else:  # W
                d = data[0][1:, 1:]
        else:  # if A-grid
            d = (data[0][:-1, :-1] + data[0][1:, :-1] + data[0][:-1, 1:] +
                 data[0][1:, 1:]) / 4.
            # A cell is set to zero as soon as one of its four corners is zero
            d = np.where(data[0][:-1, :-1] == 0, 0, d)
            d = np.where(data[0][1:, :-1] == 0, 0, d)
            d = np.where(data[0][1:, 1:] == 0, 0, d)
            d = np.where(data[0][:-1, 1:] == 0, 0, d)
        if cartopy:
            cs = ax.pcolormesh(plotlon[0],
                               plotlat[0],
                               d,
                               transform=cartopy.crs.PlateCarree())
        else:
            cs = ax.pcolormesh(plotlon[0], plotlat[0], d)

    if cartopy is None:
        ax.set_xlim(np.nanmin(plotlon[0]), np.nanmax(plotlon[0]))
        ax.set_ylim(np.nanmin(plotlat[0]), np.nanmax(plotlat[0]))
    elif domain is not None:
        ax.set_extent([
            np.nanmin(plotlon[0]),
            np.nanmax(plotlon[0]),
            np.nanmin(plotlat[0]),
            np.nanmax(plotlat[0])
        ],
                      crs=cartopy.crs.PlateCarree())
    # Values outside [vmin, vmax] are drawn black (above) or white (below)
    cs.cmap.set_over('k')
    cs.cmap.set_under('w')
    cs.set_clim(vmin, vmax)

    cartopy_colorbar(cs, plt, fig, ax)

    timestr = parsetimestr(field[0].grid.time_origin, show_time)
    titlestr = kwargs.pop('titlestr', '')
    if field[0].grid.zdim > 1:
        if field[0].grid.gtype in [
                GridCode.CurvilinearZGrid, GridCode.RectilinearZGrid
        ]:
            gphrase = 'depth'
            depth_or_level = field[0].grid.depth[depth_level]
        else:
            gphrase = 'level'
            depth_or_level = depth_level
        depthstr = ' at %s %g ' % (gphrase, depth_or_level)
    else:
        depthstr = ''
    if plottype == 'vector':
        ax.set_title(titlestr + 'Velocity field' + depthstr + timestr)
    else:
        ax.set_title(titlestr + field[0].name + depthstr + timestr)

    if not spherical:
        ax.set_xlabel('Zonal distance [m]')
        ax.set_ylabel('Meridional distance [m]')

    plt.draw()

    if savefile:
        plt.savefig(savefile)
        # NOTE(review): '.png' is appended to the logged name unconditionally,
        # even when savefile already carries an extension.
        logger.info('Plot saved to ' + savefile + '.png')
        plt.close()

    return plt, fig, ax, cartopy
예제 #2
0
def plotfield(field, show_time=None, domain=None, depth_level=0, projection=None, land=True,
              vmin=None, vmax=None, savefile=None, **kwargs):
    """Function to plot a Parcels Field

    A Field is drawn as a colour mesh; a VectorField is drawn as unit-length
    arrows coloured by speed.

    :param field: Field or VectorField object to plot
    :param show_time: Time at which to show the Field (default: first time of the grid)
    :param domain: Four-vector (latN, latS, lonE, lonW) defining domain to show
    :param depth_level: depth level to be plotted (default 0)
    :param projection: type of cartopy projection to use (default PlateCarree)
    :param land: Boolean whether to show land. This is ignored for flat meshes
    :param vmin: minimum colour scale (default: minimum of the plotted data/speed)
    :param vmax: maximum colour scale (default: maximum of the plotted data/speed)
    :param savefile: Name of a file to save the plot to
    :param kwargs: only ``titlestr`` is read; it is prepended to the plot title
    :returns: tuple ``(plt, fig, ax, cartopy)``, or four ``None`` values when
              the figure axes could not be created
    :raises RuntimeError: if field is neither a Field nor a VectorField, or if
                          the components of a VectorField have different longitudes
    :raises TimeExtrapolationError: if show_time lies outside the grid's time range
    """

    if type(field) is VectorField:
        spherical = field.U.grid.mesh == 'spherical'
        field = [field.U, field.V]
        plottype = 'vector'
    elif type(field) is Field:
        spherical = field.grid.mesh == 'spherical'
        field = [field]
        plottype = 'scalar'
    else:
        raise RuntimeError('field needs to be a Field or VectorField object')

    plt, fig, ax, cartopy = create_parcelsfig_axis(spherical, land, projection=projection)
    if plt is None:
        return None, None, None, None  # creating axes was not possible

    # Per-component (index 0 = scalar field or U, index 1 = V) data and coordinates
    data = {}
    plotlon = {}
    plotlat = {}
    for i, fld in enumerate(field):
        show_time = fld.grid.time[0] if show_time is None else show_time
        if fld.grid.defer_load:
            fld.fieldset.computeTimeChunk(show_time, 1)
        (idx, periods) = fld.time_index(show_time)
        # Shift show_time back by the number of whole periods reported by time_index
        show_time -= periods * (fld.grid.time[-1] - fld.grid.time[0])
        if show_time > fld.grid.time[-1] or show_time < fld.grid.time[0]:
            raise TimeExtrapolationError(show_time, field=fld, msg='show_time')

        latN, latS, lonE, lonW = parsedomain(domain, fld)
        if isinstance(fld.grid, CurvilinearGrid):
            plotlon[i] = fld.grid.lon[latS:latN, lonW:lonE]
            plotlat[i] = fld.grid.lat[latS:latN, lonW:lonE]
        else:
            plotlon[i] = fld.grid.lon[lonW:lonE]
            plotlat[i] = fld.grid.lat[latS:latN]
        if i > 0 and not np.allclose(plotlon[i], plotlon[0]):
            raise RuntimeError('VectorField needs to be on an A-grid for plotting')

        # Interpolate in time only when the grid holds more than one time step
        if fld.grid.time.size > 1:
            fulldata = np.squeeze(fld.temporal_interpolate_fullfield(idx, show_time))
        else:
            fulldata = np.squeeze(fld.data)
        if fld.grid.zdim > 1:
            data[i] = fulldata[depth_level, latS:latN, lonW:lonE]
        else:
            data[i] = fulldata[latS:latN, lonW:lonE]

    if plottype == 'vector':
        spd = data[0] ** 2 + data[1] ** 2
        # np.sqrt(spd, where=...) without out= leaves the masked-out entries
        # uninitialised; np.where gives a well-defined 0 wherever spd is not positive.
        speed = np.where(spd > 0, np.sqrt(spd), 0)
        vmin = speed.min() if vmin is None else vmin
        vmax = speed.max() if vmax is None else vmax
        if isinstance(field[0].grid, CurvilinearGrid):
            x, y = plotlon[0], plotlat[0]
        else:
            x, y = np.meshgrid(plotlon[0], plotlat[0])
        # Unit-length arrows; points with zero speed stay NaN
        nonzerospd = speed != 0
        u, v = (np.zeros_like(data[0]) * np.nan, np.zeros_like(data[1]) * np.nan)
        np.place(u, nonzerospd, data[0][nonzerospd] / speed[nonzerospd])
        np.place(v, nonzerospd, data[1][nonzerospd] / speed[nonzerospd])
        if cartopy:
            cs = ax.quiver(x, y, u, v, speed, cmap=plt.cm.gist_ncar, clim=[vmin, vmax], scale=50, transform=cartopy.crs.PlateCarree())
        else:
            cs = ax.quiver(x, y, u, v, speed, cmap=plt.cm.gist_ncar, clim=[vmin, vmax], scale=50)
    else:
        vmin = data[0].min() if vmin is None else vmin
        vmax = data[0].max() if vmax is None else vmax
        if cartopy:
            cs = ax.pcolormesh(plotlon[0], plotlat[0], data[0], transform=cartopy.crs.PlateCarree())
        else:
            cs = ax.pcolormesh(plotlon[0], plotlat[0], data[0])

    if cartopy is None or projection is None:
        ax.set_xlim(np.nanmin(plotlon[0]), np.nanmax(plotlon[0]))
        ax.set_ylim(np.nanmin(plotlat[0]), np.nanmax(plotlat[0]))
    elif domain is not None:
        ax.set_extent([np.nanmin(plotlon[0]), np.nanmax(plotlon[0]), np.nanmin(plotlat[0]), np.nanmax(plotlat[0])])
    # Values outside [vmin, vmax] are drawn black (above) or white (below)
    cs.cmap.set_over('k')
    cs.cmap.set_under('w')
    cs.set_clim(vmin, vmax)

    cartopy_colorbar(cs, plt, fig, ax)

    timestr = parsetimestr(field[0].grid.time_origin, show_time)
    titlestr = kwargs.pop('titlestr', '')
    if field[0].grid.zdim > 1:
        if field[0].grid.gtype in [GridCode.CurvilinearZGrid, GridCode.RectilinearZGrid]:
            gphrase = 'depth'
            depth_or_level = field[0].grid.depth[depth_level]
        else:
            gphrase = 'level'
            depth_or_level = depth_level
        depthstr = ' at %s %g ' % (gphrase, depth_or_level)
    else:
        depthstr = ''
    # Compare strings by value: identity ('is') only works through interning
    if plottype == 'vector':
        ax.set_title(titlestr + 'Velocity field' + depthstr + timestr)
    else:
        ax.set_title(titlestr + field[0].name + depthstr + timestr)

    if not spherical:
        ax.set_xlabel('Zonal distance [m]')
        ax.set_ylabel('Meridional distance [m]')

    plt.draw()

    if savefile:
        plt.savefig(savefile)
        # NOTE(review): '.png' is appended to the logged name unconditionally,
        # even when savefile already carries an extension.
        logger.info('Plot saved to ' + savefile + '.png')
        plt.close()

    return plt, fig, ax, cartopy
예제 #3
0
    def computeTimeChunk(self, time, dt):
        """Load a chunk of three data time steps into the FieldSet.

        This is used when FieldSet uses data imported from netcdf,
        with default option deferred_load. The loaded time steps are at or immediately before time
        and the two time steps immediately following time if dt is positive (and inversely for negative dt)

        :param time: Time around which the FieldSet chunks are to be loaded. Time is provided as a double, relatively to Fieldset.time_origin
        :param dt: time step of the integration scheme
        :returns: the next time reported by the deferred-load grids, truncated to a
                  whole number of ``dt`` steps counted from ``time``; returned
                  unchanged when it is infinite, NaN, or less than one ``dt`` away
        :raises TimeExtrapolationError: if a grid reports ``time`` itself as its
                  next time while ``dt`` is non-zero
        """
        signdt = np.sign(dt)
        # np.inf instead of the np.infty alias, which no longer exists in NumPy 2.0
        nextTime = np.inf if dt > 0 else -np.inf

        # Mark every grid as stale so that each one is advanced at most once below
        for g in self.gridset.grids:
            g.update_status = 'not_updated'
        for f in self.get_fields():
            if type(f) in [VectorField, NestedField, SummedField] or not f.grid.defer_load:
                continue
            if f.grid.update_status == 'not_updated':
                nextTime_loc = f.grid.computeTimeChunk(f, time, signdt)
                if time == nextTime_loc and signdt != 0:
                    raise TimeExtrapolationError(time, field=f, msg='In fset.computeTimeChunk')
            # Keep the earliest (forward) or latest (backward) next time. When the
            # grid was not recomputed for this field, nextTime_loc still holds the
            # value from a previous iteration.
            nextTime = min(nextTime, nextTime_loc) if signdt >= 0 else max(nextTime, nextTime_loc)

        for f in self.get_fields():
            if type(f) in [VectorField, NestedField, SummedField] or not f.grid.defer_load or f.dataFiles is None:
                continue
            g = f.grid
            if g.update_status == 'first_updated':  # First load of data
                data = da.empty((g.tdim, g.zdim, g.ydim-2*g.meridional_halo, g.xdim-2*g.zonal_halo), dtype=np.float32)
                f.loaded_time_indices = range(3)
                for tind in f.loaded_time_indices:
                    # Close any open file buffer before loading this time index
                    for fb in f.filebuffers:
                        if fb is not None:
                            fb.dataset.close()

                    data = f.computeTimeChunk(data, tind)
                data = f.rescale_and_set_minmax(data)
                f.data = f.reshape(data)
                if not f.chunk_set:
                    f.chunk_setup()
                if len(g.load_chunk) > 0:
                    # Remap chunk status codes: 2 -> 1 and 3 -> 0
                    g.load_chunk = np.where(g.load_chunk == 2, 1, g.load_chunk)
                    g.load_chunk = np.where(g.load_chunk == 3, 0, g.load_chunk)
            elif g.update_status == 'updated':
                data = da.empty((g.tdim, g.zdim, g.ydim-2*g.meridional_halo, g.xdim-2*g.zonal_halo), dtype=np.float32)
                if signdt >= 0:
                    # Forward: drop the oldest buffer and load the new last time index
                    f.loaded_time_indices = [2]
                    f.filebuffers[0].dataset.close()
                    f.filebuffers[:2] = f.filebuffers[1:]
                    data = f.computeTimeChunk(data, 2)
                else:
                    # Backward: drop the newest buffer and load the new first time index
                    f.loaded_time_indices = [0]
                    f.filebuffers[2].dataset.close()
                    f.filebuffers[1:] = f.filebuffers[:2]
                    data = f.computeTimeChunk(data, 0)
                data = f.rescale_and_set_minmax(data)
                if signdt >= 0:
                    data = f.reshape(data)[2:, :]
                    f.data = da.concatenate([f.data[1:, :], data], axis=0)
                else:
                    data = f.reshape(data)[0:1, :]
                    f.data = da.concatenate([data, f.data[:2, :]], axis=0)
                if len(g.load_chunk) > 0:
                    if signdt >= 0:
                        for block_id in range(len(g.load_chunk)):
                            if g.load_chunk[block_id] == 2:
                                if f.data_chunks[block_id] is None:
                                    # file chunks were never loaded.
                                    # happens when field not called by kernel, but shares a grid with another field called by kernel
                                    break
                                block = f.get_block(block_id)
                                # Shift the cached chunk one step in time and refresh its last slot
                                f.data_chunks[block_id][:2] = f.data_chunks[block_id][1:]
                                f.data_chunks[block_id][2] = np.array(f.data.blocks[(slice(3),)+block][2])
                    else:
                        for block_id in range(len(g.load_chunk)):
                            if g.load_chunk[block_id] == 2:
                                if f.data_chunks[block_id] is None:
                                    # file chunks were never loaded.
                                    # happens when field not called by kernel, but shares a grid with another field called by kernel
                                    break
                                block = f.get_block(block_id)
                                # Shift the cached chunk one step back in time and refresh its first slot
                                f.data_chunks[block_id][1:] = f.data_chunks[block_id][:2]
                                f.data_chunks[block_id][0] = np.array(f.data.blocks[(slice(3),)+block][0])
        # do user-defined computations on fieldset data
        if self.compute_on_defer:
            self.compute_on_defer(self)

        if abs(nextTime) == np.inf or np.isnan(nextTime):  # Second happens when dt=0
            return nextTime
        else:
            # Truncate to a whole number of dt steps from time
            nSteps = int((nextTime - time) / dt)
            if nSteps == 0:
                return nextTime
            else:
                return time + nSteps * dt