Example #1
0
def fill_masked_nearest(var):
    """Fill the masked entries of a 1D masked array with the nearest valid value.

    "Nearest" is measured in array-index space: each masked entry takes the
    value of the closest unmasked entry along the array (at an exact midpoint,
    scipy's ``'nearest'`` kind rounds down, i.e. picks the lower index).
    The array is modified in place (filled entries become unmasked) and
    returned.

    Parameters
    ----------
    var: 1D masked array

    Returns
    -------
    ``var`` itself. It is returned untouched when it has no masked entry, or
    when it has no valid entry to take values from.
    """
    mask = np.ma.getmaskarray(var)
    if not mask.any() or mask.all():
        return var
    good = (~mask).nonzero()[0]
    bad = mask.nonzero()[0]
    # Use the actual positions of the valid entries as abscissae so that each
    # masked position is matched with its real nearest valid neighbour.
    nearest = interp1d(good,
                       np.ma.getdata(var)[good],
                       'nearest',
                       fill_value='extrapolate')
    var[mask] = nearest(bad)
    return var


def extrapolate_pha(lon, lat, lonr, latr, maskr, arr):
    """Interpolate a phase field (radians) onto target points.

    The phase is split into its cosine and sine components, so that the
    interpolation is not disturbed by phase wrapping. Each component is
    filled with ``missing_interp`` and interpolated at (``lonr``, ``latr``)
    with ``grid2xy``. Target points that are still masked then take the
    value of their nearest valid target point (in array order, see
    :func:`fill_masked_nearest`). The components are finally recombined
    with ``arctan2``.

    ``maskr`` is accepted for interface compatibility but is not used.

    Returns
    -------
    Phase at the target points, in radians, in [-pi, pi].
    """
    cx, cy = np.cos(arr), np.sin(arr)
    cx = missing_interp(lon, lat, cx)
    cy = missing_interp(lon, lat, cy)
    # Fill each component independently: their masks need not be identical
    cx = fill_masked_nearest(grid2xy(cx, xo=lonr, yo=latr))
    cy = fill_masked_nearest(grid2xy(cy, xo=lonr, yo=latr))
    return np.arctan2(cy, cx)
Example #2
0
    def _create_nc_gr3(self, ncfile, var):
        """Build a constant gr3 field from variable ``var`` of file ``ncfile``.

        The variable is filled with ``fill2d`` and linearly interpolated with
        ``grid2xy`` at the nodes of ``self.hgrid``. The interpolation time is
        ``date2num(self.t0) + 1``, expressed in 'days since 1-1-1'. Nodes left
        masked by the linear interpolation are re-interpolated with the
        'nearest' method. The result is passed to
        ``self._create_constante_gr3``.
        """
        lon = self.hgrid.longitude
        lat = self.hgrid.latitude

        data = cdms2.open(ncfile)
        try:
            src = fill2d(data[var][:], method='carg')
            # Same target time repeated for every node
            tin = create_time(np.ones(len(lon)) * date2num(self.t0) + 1,
                              units='days since 1-1-1')
            tb = grid2xy(src, xo=lon, yo=lat, method='linear', to=tin)

            # Nodes the linear interpolation could not resolve (presumably
            # outside the source domain): retry them only, with 'nearest'
            bad = np.ma.getmaskarray(tb).nonzero()[0]
            if bad.size:
                tin_bad = create_time(
                    np.ones(bad.size) * date2num(self.t0) + 1,
                    units='days since 1-1-1')
                tb[bad] = grid2xy(src,
                                  xo=np.array(lon)[bad].tolist(),
                                  yo=np.array(lat)[bad].tolist(),
                                  method='nearest',
                                  to=tin_bad)
        finally:
            # Always release the file handle, even if interpolation fails
            data.close()

        self._create_constante_gr3(tb)
Example #3
0
def extrapolate_pha(lon, lat, lonr, latr, maskr, arr):
    """Interpolate a phase field (radians) onto target points.

    The phase is handled through its cosine and sine components. Both are
    filled with ``missing_interp``, interpolated at (``lonr``, ``latr``)
    with ``grid2xy`` and recombined with ``arctan2``.

    ``maskr`` is accepted for interface compatibility but is not used; an
    earlier masking by ``maskr == 0`` is disabled.
    """
    cos_part = np.cos(arr)
    sin_part = np.sin(arr)
    filled = [missing_interp(lon, lat, comp) for comp in (cos_part, sin_part)]
    on_targets = [grid2xy(comp, xo=lonr, yo=latr) for comp in filled]
    return np.arctan2(on_targets[1], on_targets[0])
Example #4
0
def fill_masked_nearest(var):
    """Fill the masked entries of a 1D masked array with the nearest valid value.

    "Nearest" is measured in array-index space: each masked entry takes the
    value of the closest unmasked entry along the array (at an exact midpoint,
    scipy's ``'nearest'`` kind rounds down, i.e. picks the lower index).
    The array is modified in place (filled entries become unmasked) and
    returned.

    Parameters
    ----------
    var: 1D masked array

    Returns
    -------
    ``var`` itself. It is returned untouched when it has no masked entry, or
    when it has no valid entry to take values from.
    """
    mask = np.ma.getmaskarray(var)
    if not mask.any() or mask.all():
        return var
    good = (~mask).nonzero()[0]
    bad = mask.nonzero()[0]
    # Use the actual positions of the valid entries as abscissae so that each
    # masked position is matched with its real nearest valid neighbour.
    nearest = interp1d(good,
                       np.ma.getdata(var)[good],
                       'nearest',
                       fill_value='extrapolate')
    var[mask] = nearest(bad)
    return var


def extrapolate_amp(lon, lat, lonr, latr, maskr, arr):
    """Interpolate an amplitude field onto target points.

    Missing values are first filled with ``missing_interp``. The field is
    then interpolated at (``lonr``, ``latr``) with ``grid2xy``. Target points
    that are still masked take the value of their nearest valid target point
    (in array order, see :func:`fill_masked_nearest`).

    ``maskr`` is accepted for interface compatibility but is not used.
    """
    arr = missing_interp(lon, lat, arr)
    arri = grid2xy(arr, xo=lonr, yo=latr)
    return fill_masked_nearest(arri)
Example #5
0
def plot_scattered_locs(lons,
                        lats,
                        depths,
                        slice_type=None,
                        interval=None,
                        plotter=None,
                        lon=None,
                        lat=None,
                        level=None,
                        label='',
                        lon_bounds_margin=.1,
                        lat_bounds_margin=.1,
                        data=None,
                        warn=True,
                        bathy=None,
                        xybathy=None,
                        secbathy=None,
                        size=20,
                        color='#2ca02c',
                        linewidth=0.4,
                        edgecolor='k',
                        add_profile_line=None,
                        add_bathy=True,
                        add_minimap=True,
                        add_section_bathy=True,
                        fig=None,
                        title="{long_name}",
                        register_sm=True,
                        depthshade=False,
                        legend=False,
                        colorbar=True,
                        **kwargs):
    """Plot scattered localisations

    Observations are either fully scattered (one depth per location) or
    profiles (several depths per location).

    Parameters
    ----------
    lons: n-D array
        Longitudes of the observations.
    lats: n-D array
        Latitudes of the observations.
    depths: n-D array, axis, "surf" or "bottom"
        Depths of the observations. Observations are treated as profiles when
        ``depths`` is an axis, when its shape differs from that of ``lons``,
        or when ``data`` is 2D. "surf" stands for zero depths and "bottom"
        for the opposite of the bathymetry at the observation locations
        (see ``xybathy``).
    slice_type: one of "3d"/None, "2d", "zonal", "merid", "horiz", "surf", "bottom"
        The way to slice the observations.
        "3d"/"2d" are 3D/2D view of all observations.
        Other slices make a selection with a range (``interval``).
    interval: None, tuple of float
        Interval for selecting valid data.
        Required if slice_type is not "3d"/None/"2d", unless slice_type and
        ``depths`` are both "surf" or both "bottom".
    plotter: None, matplotlib Axes or vacumm Plot instance
        Where to plot. A new map or section is created unless a Plot
        instance is given.
    lon, lat, level: None, tuple of float
        Plot bounds, guessed from the coordinates when not provided.
    label: str
        Label of the scatter plots.
    lon_bounds_margin, lat_bounds_margin: float
        The guessed longitude/latitude bounds are rescaled by a factor
        ``1 + margin`` with :func:`rescale_itv`.
    data: None, array
        Values used to color the points. A boolean array is used as a mask
        of the observations instead.
    warn: bool
        Emit a warning when the plot is skipped because of masked data or
        missing bathymetry.
    bathy: None, gridded array
        Bathymetry used on maps, to compute ``xybathy`` and the section
        bathymetry.
    xybathy: None, array
        Bathymetry at the observation locations; interpolated from ``bathy``
        when needed and not provided.
    secbathy: None, cdms2 variable
        Bathymetry along the section; interpolated from ``bathy`` when not
        provided.
    title: str, bool
        Plot title. A string is formatted with the local variables of this
        function (e.g. ``"{long_name}"``), ``True`` uses the long name and a
        false value disables the title.
    map_<param>:
        <param> is passed to :func:`create_map`
    section_<param>:
        <param> is passed to :func:`vacumm.misc.plot.section2`
    section_bathy_<param>:
        <param> is passed to :meth:`matplotlib.axes.Axes.fill_between` when
        drawing the section bathymetry
    minimap_<param>:
        <param> is passed to :func:`vacumm.misc.plot.add_map_box`
    plotter_<param>:
        Default <param> for both the map and the section plotters

    Returns
    -------
    The plotter, or None when nothing is plotted.

    Todo
    ----
    Add time support.
    """
    # Inits
    kwmap = kwfilter(kwargs, 'map_')
    kwminimap = kwfilter(kwargs, 'minimap_')
    kwsecbat = kwfilter(kwargs, 'section_bathy_')
    kwsection = kwfilter(kwargs, 'section_')
    kwplt = kwfilter(kwargs, 'plotter_')
    dict_check_defaults(kwmap, **kwplt)
    dict_check_defaults(kwsection, **kwplt)
    kwpf = kwfilter(kwargs, 'add_profile_line')
    kwleg = kwfilter(kwargs, 'legend')
    kwcb = kwfilter(kwargs, 'colorbar')
    # Read the metadata BEFORE converting a cdms2 variable to a pure masked
    # array, since asma() does not keep the attributes
    long_name = get_long_name(data)
    units = getattr(data, 'units', '')
    if cdms2.isVariable(data):
        data = data.asma()
    if long_name is None:
        long_name = "Locations"
    if not title:
        title = None
    elif title is True:
        title = long_name
    else:
        title = title.format(**locals())

    # Slice type
    if slice_type is None:
        slice_type = "3d"
    else:
        slice_type = str(slice_type).lower()
    valid_slice_types = [
        '3d', "2d", 'zonal', 'merid', 'horiz', 'bottom', 'surf'
    ]
    assert slice_type in valid_slice_types, ('Invalid slice type. '
                                             'It must be one of: ' +
                                             ', '.join(valid_slice_types))

    # Numeric horizontal coordinates
    xx = lons[:].copy()
    yy = lats[:].copy()

    # Special string depths. The type is tested first because comparing an
    # array with a string gives an array whose truth value is ambiguous.
    surf_depths = isinstance(depths, str) and depths == 'surf'
    bottom_depths = isinstance(depths, str) and depths == 'bottom'

    # Profiles?
    profiles = (not isinstance(depths, str)
                and (isaxis(depths) or N.shape(depths) != N.shape(xx) or
                     (data is not None and data.ndim == 2)))

    # Force some options
    if not profiles or slice_type not in ('3d', 'merid', 'zonal'):
        add_profile_line = False
    elif add_profile_line is None:
        add_profile_line = True

    # Bathymetry at observation locations:
    # 1 = wanted for profile lines only, 2 = required for "bottom" depths
    need_xybathy = int(add_profile_line)
    if bottom_depths and slice_type not in ('bottom', '2d'):
        need_xybathy = 2
    if need_xybathy and xybathy is None:  # compute it only if not provided

        if bathy is not None:
            xybathy = grid2xy(bathy, lons, lats)
            if xybathy.mask.all():
                if warn:
                    sonat_warn(
                        'Bathymetry is fully masked at bottom locs. Skipping...'
                    )
                if need_xybathy == 2:
                    return
                xybathy = None
        elif need_xybathy == 2:  # we really need it
            if warn:
                sonat_warn('Bathymetry is needed at obs locations. Skipping...')
            return
    if xybathy is None:
        add_profile_line = False

    # Special depths: surf and bottom
    if surf_depths and slice_type != 'surf':  # surface
        depths = N.zeros(len(lons))
    elif bottom_depths and slice_type not in ('bottom', '2d'):  # bottom
        depths = -xybathy
        if interval is not None and N.isscalar(interval[0]):
            # Scalar interval bounds are relative to the bottom
            interval = (depths + interval[0], depths + interval[1])

    # Numeric vertical coordinates
    strdepths = isinstance(depths, str)
    if not strdepths:
        # Plain masked copy: keeps the mask of masked depths so that it is
        # merged into the global mask below
        zz = N.ma.array(depths[:], copy=True, subok=False)

    # Shape
    if data is not None:
        dshape = data.shape
    elif not profiles or strdepths:
        dshape = xx.shape
    elif zz.ndim == 2:
        dshape = zz.shape
    else:
        dshape = zz.shape + xx.shape

    # Masking outside interval
    if (slice_type not in ('3d', '2d')
            and not (slice_type == 'surf' and surf_depths)
            and not (slice_type == 'bottom' and bottom_depths)):

        assert interval is not None, (
            'You must provide a valid '
            '"interval" for slicing scattered locations')
        stype = 'horiz' if slice_type in ('surf', 'bottom') else slice_type
        data = mask_scattered_locs(xx, yy, depths, stype, interval, data=data)
        if data is None:
            return

    # Get the full mask: (np), or (nz, np) for profiles
    # - mask with data (copied so that the in-place updates below do not
    #   alter the input data or its mask)
    if data is not None:
        if data.dtype.char == '?':
            mask = data.copy()
            data = None
        else:
            mask = N.ma.getmaskarray(data).copy()
    else:
        mask = N.zeros(dshape, dtype='?')
    # - mask with coordinates
    if N.ma.isMA(xx) or N.ma.isMA(yy):  # lons/lats
        xymask = N.ma.getmaskarray(xx) | N.ma.getmaskarray(yy)
        mask |= N.resize(xymask, dshape)
    if not strdepths and N.ma.isMA(zz):  # depths
        zmask = N.ma.getmaskarray(zz)
        if profiles and zz.ndim == 1:
            # Depths shared by all profiles: (nz,) -> (nz, np)
            zmask = N.repeat(zmask.reshape(-1, 1), xx.size, axis=1)
        mask |= zmask
    # - check
    if mask.all():
        if warn:
            sonat_warn('All your data are masked')
        return
    # - mask back
    xymask = mask if mask.ndim == 1 else mask.all(axis=0)
    xx = N.ma.masked_where(xymask, xx, copy=False)
    yy = N.ma.masked_where(xymask, yy, copy=False)
    if not strdepths:
        if mask.shape == zz.shape:
            zz = N.ma.masked_where(mask, zz, copy=False)
        elif zz.ndim == 1:
            zz = N.ma.masked_where(mask.all(axis=1), zz, copy=False)
    if data is not None:
        data = N.ma.masked_where(mask, data, copy=False)

    # Plotter as Axes
    if isinstance(plotter, P.Axes):
        ax = plotter
        fig = ax.get_figure()
        plotter = None
    elif plotter is None:
        ax = None
    elif isinstance(plotter, Plot):
        ax = plotter.axes
    else:
        raise SONATError('Plotter must be matplotlib Axes instance or '
                         'a vacumm Plot instance')
    if slice_type == '3d':
        if ax is None:
            ax = '3d'
        elif not isinstance(ax, Axes3D):
            # A 3D plot needs 3D axes (zaxis is used below): skip as announced
            sonat_warn("Requesting 3D plot but provided axes are not 3D."
                       " Skipping...")
            return

    # Coordinate bounds
    if level is None and slice_type in ['3d', 'zonal', 'merid']:
        if strdepths or zz.min() == 0:
            level_min = -200  # Fall back to this min depth
        else:
            level_min = 1.1 * zz.min()
        level = (level_min, 0)
    if (lon is None and slice_type
            in ['3d', "2d", "horiz", 'surf', 'bottom', 'zonal']):
        lon = rescale_itv((xx.min(), xx.max()), 1 + lon_bounds_margin)
    if (lat is None and slice_type
            in ['3d', "2d", "horiz", 'surf', 'bottom', 'merid']):
        lat = rescale_itv((yy.min(), yy.max()), 1 + lat_bounds_margin)

    # Get the plotter
    if slice_type in ['3d', "2d", "horiz", 'surf', 'bottom']:  # map

        # Map
        if plotter is None:
            plotter = create_map(lon,
                                 lat,
                                 level=level,
                                 bathy=bathy,
                                 add_bathy=add_bathy,
                                 fig=fig,
                                 axes=ax,
                                 **kwmap)
        ax = plotter.axes

        # Projection
        xx, yy = plotter(xx, yy)

    else:  # sections

        if plotter is None:

            # Base plot
            kwsection.update(fig=fig, axes=ax, show=False, close=False)
            if add_minimap:
                dict_check_defaults(kwsection, top=.9, right=.9)
            if slice_type == 'merid':
                xaxis = MV2.array(lat, id='lat')
            else:
                xaxis = MV2.array(lon, id='lon')
            plotter = section(data=None,
                              xaxis=xaxis,
                              yaxis=MV2.array(level, id='dep'),
                              **kwsection)

        ax = plotter.axes

        # Horizontal extents of the section: along-section range from the
        # plot, across-section range from the selection interval. Computed
        # here since both the minimap and the bathymetry profile need them.
        if slice_type == 'merid':
            xlim = interval
            ylim = ax.get_xlim()
        else:
            xlim = ax.get_xlim()
            ylim = interval

        # Add minimap
        if add_minimap:
            extents = dict(x=xlim, y=ylim)
            dict_check_defaults(kwminimap,
                                map_square=True,
                                map_zoom=.5,
                                map_res=None,
                                map_arcgisimage="ocean",
                                map_epsg=3395,
                                linewidth=.6)
            kwminimap['map_fig'] = ax.figure
            add_map_box((xx, yy), extents, **kwminimap)

        # Bathy profile
        # NOTE(review): a provided secbathy is drawn even when
        # add_section_bathy is False -- confirm this is intended.
        if (add_section_bathy and bathy is not None) or secbathy is not None:
            if secbathy is None:  # interpolate
                if slice_type == 'merid':
                    secbathy = transect(
                        bathy, [0.5 * (interval[0] + interval[1])] * 2,
                        ylim,
                        outaxis='lat')
                else:
                    secbathy = transect(
                        bathy,
                        xlim, [0.5 * (interval[0] + interval[1])] * 2,
                        outaxis='lon')
            tx = secbathy.getAxis(0)[:]
            tb = secbathy.asma()
            axis_bounds = ax.axis()
            dict_check_defaults(kwsecbat, facecolor="0.7")
            ax.fill_between(tx, ax.get_ylim()[0], tb, **kwsecbat)
            ax.axis(axis_bounds)

    axis_bounds = ax.axis()

    # Plot params for scatter
    kwargs.update(linewidth=linewidth, s=size, edgecolor=edgecolor)

    # Data kwargs
    if data is not None:
        dict_check_defaults(kwargs, vmin=data.min(), vmax=data.max())

    # 3D
    pp = []
    if slice_type == "3d":

        kwargs['depthshade'] = depthshade

        # Depth labels
        zfmtfunc = lambda x, pos: deplab(x, nosign=True)
        ax.zaxis.set_major_formatter(FuncFormatter(zfmtfunc))

        # Scatter plots
        if not profiles:  # fully scattered

            # Points
            if data is not None:
                kwargs['c'] = data
            else:
                kwargs['c'] = color
            pp.append(ax.scatter(xx, yy, depths, label=label, **kwargs))

            # Profile lines
            # NOTE(review): the profiles branch below passes -xybathy[ip];
            # confirm which sign plot_profile_line_3d expects.
            if add_profile_line:
                for ip, (x, y) in enumerate(zip(xx, yy)):
                    plot_profile_line_3d(ax,
                                         x,
                                         y,
                                         xybathy[ip],
                                         zorder=pp[-1].get_zorder() - 0.01,
                                         **kwpf)

        else:  # profiles

            for ip, (x, y) in enumerate(zip(xx, yy)):

                # Skip fully masked
                if mask[:, ip].all():
                    continue

                # Points
                if zz.ndim == 2:
                    z = zz[:, ip]
                else:
                    z = zz
                z = N.ma.masked_where(mask[:, ip], z, copy=False)
                if data is not None:
                    kwargs['c'] = data[..., ip]
                else:
                    kwargs['c'] = color
                pp.append(
                    ax.scatter([x] * len(z), [y] * len(z),
                               z,
                               label=label,
                               **kwargs))
                # Only the first scatter gets a legend entry
                label = '_' + str(label)

                # Profile line
                if add_profile_line:
                    plot_profile_line_3d(ax,
                                         x,
                                         y,
                                         -xybathy[ip],
                                         zorder=pp[-1].get_zorder() - 0.01,
                                         **kwpf)

    # Horizontal
    elif slice_type in ['2d', 'surf', 'bottom', 'horiz']:

        # Barotropic case
        if data is not None and data.ndim != 1:
            data = data.mean(axis=0)

        # Scatter plot
        if data is not None:
            kwargs['c'] = data
        else:
            kwargs['c'] = color
        pp.append(ax.scatter(xx, yy, label=label, **kwargs))

    # Sections
    else:

        # X axis data
        if slice_type == 'zonal':
            xdata = xx
        else:
            xdata = yy

        # Scatter plots
        if not profiles:  # scattered

            if data is not None:
                kwargs['c'] = data
            else:
                kwargs['c'] = color
            pp.append(ax.scatter(xdata, depths, label=label, **kwargs))

        else:  # profiles

            for ip, x in enumerate(xdata):

                # Skip fully masked
                if mask[:, ip].all():
                    continue

                # Points
                if zz.ndim == 2:
                    z = zz[:, ip]
                else:
                    z = zz
                z = N.ma.masked_where(mask[:, ip], z, copy=False)
                if data is not None:
                    kwargs['c'] = data[:, ip]
                else:
                    kwargs['c'] = color

                pp.append(ax.scatter([x] * len(z), z, label=label, **kwargs))
                # Only the first scatter gets a legend entry
                label = '_' + str(label)

                # Profile line
                # NOTE(review): unlike the 3D calls, no y coordinate is passed
                # here; confirm the expected signature (a 2D profile-line
                # helper may be intended).
                if add_profile_line:
                    plot_profile_line_3d(ax,
                                         x,
                                         -xybathy[ip],
                                         zorder=pp[-1].get_zorder() - 0.01,
                                         **kwpf)

    # Finalise
    ax.axis(axis_bounds)
    if title:
        ax.set_title(title)
    if legend:
        plotter.legend(**kwleg)
    if colorbar and data is not None:
        add_colorbar(plotter, pp, units=units, **kwcb)
    if data is not None and register_sm:
        register_scalar_mappable(ax, pp, units=units)
    register_scatter(ax, pp, label)
    return plotter
Example #6
0
    def create_Dthnc(self, fileout, TimeSeries):
        """Write a time-history NetCDF file summing tidal and residual signals.

        The file is 2D (one level) when ``'2D'`` appears in ``fileout``, and 3D
        (``self.zz.shape[1]`` levels) otherwise. For each time of
        ``TimeSeries``, the tidal signal (when ``self.tidal``) and the residual
        signal (when ``self.residual``) of each variable are summed at the
        nodes (``self.llon``, ``self.llat``) and written to the file.

        Parameters
        ----------
        fileout: str
            Output file name.
        TimeSeries: 1D array
            Output times in days. ``TimeSeries[n] + 1`` is used as
            'days since 1-1-1' for the interpolation. Record times are written
            in seconds relative to ``TimeSeries[0]``.

        Raises
        ------
        ValueError
            If NaN values are produced at a time step.
        """
        self.i23d = 2 if '2D' in fileout else 3

        # create file
        Nlev = self.zz.shape[1] if self.i23d == 3 else 1
        npts = len(self.llon)

        time_Series, nc = create_ncTH(
            fileout, npts, Nlev, self.ivs,
            np.round((TimeSeries - TimeSeries[0]) * 24 * 3600))

        if self.i23d > 2:
            # Coordinates of every (node, level) pair, flattened node-major to
            # match self.zz.flatten(). They are time independent: build once.
            llon = np.tile(self.llon, [Nlev, 1]).T.flatten()
            llat = np.tile(self.llat, [Nlev, 1]).T.flatten()
            zz = self.zz.flatten()

        try:
            for n in range(len(TimeSeries)):
                tin = create_time(np.ones(npts * Nlev) * (TimeSeries[n] + 1),
                                  units='days since 1-1-1')

                total = np.zeros(shape=(self.ivs, npts, Nlev))

                # get tide
                if self.tidal:
                    for i, v in enumerate(sorted(self.HC.keys())):
                        # horizontal interpolation
                        tmp = get_tide(self.constidx, self.tfreq, self.HC[v],
                                       np.array(TimeSeries[n]), self.lat0)

                        if self.i23d > 2:  # vertical interpolation
                            tmp = vertical_extrapolation(tmp, self.zz,
                                                         z0=self.z0)

                        total[i, :, :] = total[i, :, :] + tmp

                # get residual
                if self.residual:
                    for i, v in enumerate(sorted(self.res_vars)):
                        arri = self.res_file[v][:]
                        if self.i23d > 2:
                            # Add a singleton leading 'member' axis so that the
                            # data is (member, time, depth, y, x) for grid2xy
                            dep = create_depth(arri.getAxis(1)[:])
                            extra = create_axis(np.arange(1), id='member')
                            arri2 = np.tile(arri, [1, 1, 1, 1, 1])
                            arri3 = MV2.array(arri2,
                                              axes=[
                                                  extra,
                                                  arri.getAxis(0), dep,
                                                  arri.getAxis(2),
                                                  arri.getAxis(3)
                                              ],
                                              copy=False,
                                              fill_value=1e20)

                            # Source depths are made negative when they are
                            # positive on average
                            zi = arri.getAxis(1)[:]
                            if np.mean(zi) > 0:
                                zi = zi * -1
                            tb = grid2xy(arri3,
                                         xo=llon,
                                         yo=llat,
                                         zo=zz,
                                         method='linear',
                                         to=tin,
                                         zi=zi)

                        else:
                            tb = grid2xy(arri,
                                         xo=self.llon,
                                         yo=self.llat,
                                         method='linear',
                                         to=tin)

                        # Re-interpolate points masked by the linear method
                        # with the nearest method. Use indices (not a boolean
                        # mask) so that len(bad) is the number of bad points
                        # and matches the length of the target times.
                        bad = np.ma.getmaskarray(tb)
                        if bad.ndim > 1:
                            bad = bad[0, :]
                        bad = bad.nonzero()[0]
                        if bad.size:
                            tin_bad = create_time(np.ones(bad.size) *
                                                  (TimeSeries[n] + 1),
                                                  units='days since 1-1-1')

                            if self.i23d > 2:
                                tb[0, bad] = grid2xy(arri3,
                                                     xo=llon[bad],
                                                     yo=llat[bad],
                                                     zo=zz[bad],
                                                     method='nearest',
                                                     to=tin_bad,
                                                     zi=zi)

                            else:
                                tb[bad] = grid2xy(
                                    arri,
                                    xo=np.array(self.llon)[bad].tolist(),
                                    yo=np.array(self.llat)[bad].tolist(),
                                    method='nearest',
                                    to=tin_bad)

                        if np.ma.getmaskarray(tb).any():
                            self.logger.warning(
                                'Masked values remain for %s at time %s',
                                v, TimeSeries[n])

                        total[i, :, :] = total[i, :, :] + np.reshape(
                            tb, (npts, Nlev))

                total = np.transpose(total, (1, 2, 0))

                if np.isnan(total).any():
                    raise ValueError('NaN values found in %s at time %s' %
                                     (fileout, TimeSeries[n]))

                if n % 100 == 0:
                    self.logger.info(
                        'For timestep=%.f, max=%.4f, min=%.4f , max abs diff=%.4f'
                        % (TimeSeries[n], total.max(), total.min(),
                           abs(np.diff(total, n=1, axis=0)).max()))

                time_Series[n, :, :, :] = total
        finally:
            # Always close the output file, even when an error occurs
            nc.close()
Example #7
0
# Test field on (extra, time, dep, lat, lon) axes whose value everywhere is
# the latitude: N.resize fills the last axis (length ny) with lat[:], then
# N.moveaxis swaps the last two axes to get (..., ny, nx).
data = N.resize(lat[:], (ne, nt, nz, nx, ny)) # function of y
data = N.moveaxis(data, -1, -2)
# Alternative test field (disabled): a ramp over all grid points
#data = N.arange(nx*ny*nz*nt*ne, dtype='d').reshape(ne, nt, nz, ny, nx)
vi = MV2.array(data,
                 axes=[extra, time, dep, lat, lon], copy=False,
                 fill_value=1e20)
# Random target points, with a fixed seed for reproducibility.
# NOTE(review): "np" is the number of target points here (presumably defined
# earlier in the script, outside this view), not the numpy module: it shadows
# the usual numpy alias.
N.random.seed(0)
xo = N.random.uniform(lon0, lon1, np)
yo = N.random.uniform(lat0, lat1, np)
zo = N.random.uniform(dep0, dep1, np)
# Random relative times between time0 and time1 (in time.units), passed to
# comptime
to = comptime(N.random.uniform(reltime(time0, time.units).value,
                      reltime(time1, time.units).value, np),
                      time.units)

# Rectangular xyzt with 1d z
vo = grid2xy(vi, xo=xo, yo=yo, zo=zo, to=to, method='linear')
von = grid2xy(vi, xo=xo, yo=yo, zo=zo, to=to, method='nearest')
# One value per extra-axis member and per target point
assert vo.shape==(ne, np)
# The field equals the latitude, so linear interpolation must return yo
N.testing.assert_allclose(vo[0], yo)
# Side-by-side scatter plots of the linear and nearest results, with a
# shared color scale
kwp = dict(vmin=vi.min(), vmax=vi.max())
P.figure(figsize=(6, 3))
P.subplot(121)
P.scatter(xo, yo, c=vo[0],  cmap='jet', **kwp)
add_grid(vi.getGrid())
P.title('linear4d')
P.subplot(122)
P.scatter(xo, yo, c=von[0], cmap='jet', **kwp)
add_grid(vi.getGrid())
P.title('nearest4d')
P.figtext(.5, .98, 'grid2xy in 4D', va='top', ha='center', weight='bold')
P.tight_layout()
Example #8
0
def extrapolate_amp(lon, lat, lonr, latr, maskr, arr):
    """Fill an amplitude field and interpolate it onto target points.

    Missing values of ``arr`` are filled with ``missing_interp``. The field is
    then interpolated at (``lonr``, ``latr``) with ``grid2xy``, and that
    result is returned.

    ``maskr`` is accepted for interface compatibility but is not used.
    """
    filled = missing_interp(lon, lat, arr)
    return grid2xy(filled, xo=lonr, yo=latr)
Example #9
0
def slice_gridded_var(var, member=None, time=None, depth=None, lat=None, lon=None):
    """Make slices of a variable and squeeze out singletons to reduce it

    The "member" axis is considered here as a generic name for the first
    axis of unknown type.

    Each selector is either a :class:`slice` (plain subsetting) or
    coordinate value(s), which trigger a 1D interpolation with
    :func:`regrid1d`. Interpolating to a scalar value squeezes the
    corresponding axis. ``depth`` also accepts "bottom"
    (:func:`slice_bottom`) and "surf" (last level). When both ``lat`` and
    ``lon`` are given as non-slice values and the variable ends with (y, x)
    axes, the variable is interpolated at the (lon, lat) points with
    :func:`grid2xy`.

    .. warning:: All axes must be 1D
    """

    # Check order
    var = var(squeeze=1)
    order = var.getOrder()

    # Unknown axis
    if '-' in order and member is not None:
        iaxis = order.find('-')
        axis_id = var.getAxisIds()[iaxis]
        if isinstance(member, slice):
            var = var(**{axis_id: member})
        else:
            axo = create_axis(member)
            cp_atts(var.getAxis(iaxis), axo)
            var = regrid1d(var, axo, iaxi=iaxis)(squeeze=N.isscalar(member))

    # Time interpolation
    if 't' in order and time is not None:
        axi = var.getTime()
        if isinstance(time, slice):
            var = var(time=time)
        else:
            axo = create_time(time, axi.units)
            var = regrid1d(var, axo)(squeeze=N.isscalar(time))

    # Depth interpolation
    if 'z' in order and depth is not None:
        # Test the type first: comparing an array with a string is ambiguous
        if isinstance(depth, str) and depth == 'bottom':
            var = slice_bottom(var)
        else:
            if isinstance(depth, str) and depth == 'surf':
                depth = slice(-1, None)
            if isinstance(depth, slice):
                var = var(level=depth, squeeze=1) # z squeeze only?
            elif (N.isscalar(depth) and var.getLevel()[:].ndim == 1 and
                  depth in var.getLevel()):
                var = var(level=depth)
            else:
                axo = create_dep(depth)
                if axo[:].max() > 10:
                    sonat_warn('Interpolation depth is positive. Taking this opposite')
                    axo[:] *= -1
                var = regrid1d(var, axo)(squeeze=N.isscalar(depth))

    # Point
    if (order.endswith('yx') and lon is not None and lat is not None and
            not isinstance(lat, slice) and not isinstance(lon, slice)):

        var = grid2xy(var, lon, lat)(squeeze=N.isscalar(lon))

    else:

        # Latitude interpolation ("is not None" so that 0 and arrays work)
        if 'y' in order and lat is not None:
            if isinstance(lat, slice):
                var = var(lat=lat)
            else:
                axo = create_lat(lat)
                var = regrid1d(var, axo)(squeeze=N.isscalar(lat))

        # Longitude interpolation ("is not None" so that 0 and arrays work)
        if 'x' in order and lon is not None:
            if isinstance(lon, slice):
                var = var(lon=lon)
            else:
                axo = create_lon(lon)
                var = regrid1d(var, axo)(squeeze=N.isscalar(lon))

    return var
Example #10
0
def get_file_interpolator(sourcefile, var_res, lonr, latr):
    """Build time interpolators of the variables ``var_res`` of ``sourcefile``.

    NOTE(review): this function looks unfinished and cannot run as written.
    ``tempf`` (used to build ``Res_val``) is never defined, a ``pdb``
    breakpoint is left in the loop, and the ``grid2xy`` result ``tb`` is
    never used. The two file handles are also never closed. Confirm the
    intended data source before relying on it.

    For each variable, the values are reshaped to (time, nodes). This is done
    per level when the data has more than 3 dimensions. Only the nodes with
    a non-zero (unmasked) value at the first time step are kept, and only the
    time steps with a non-zero value at the middle kept node. A
    :func:`scipy.interpolate.interp1d` along time is then built.

    Returns
    -------
    f_out: dict
        Variable name -> interpolator, or list of interpolators (one per
        level with kept nodes) for data with more than 3 dimensions.
    LONR, LATR: dict
        Variable name -> ``lonr``/``latr`` values of the kept nodes (lists per
        level for data with more than 3 dimensions).
    """
    sourcedata = Dataset(sourcefile, 'r')
    sourcedata_dms2 = cdms2.open(sourcefile)
    f_out = {}
    LONR = {}
    LATR = {}
    # Times of the file, in days since 1-1-1
    time0 = [
        t.torelative('days since 1-1-1').value
        for t in sourcedata_dms2['time'].asRelativeTime()
    ]
    tres = np.array(time0)

    for varname in var_res:
        #	arri = griddata(lat.ravel(),lon.ravel(), tempf[0,:,:].compressed(), tempf.getGrid(), method='nearest')
        # NOTE(review): elsewhere missing_interp(lon, lat, arr) is used as
        # returning a single array; confirm it returns (lon, lat, arr) here.
        lon, lat, arri = missing_interp(sourcedata_dms2[varname].getAxis(2)[:],
                                        sourcedata_dms2[varname].getAxis(1)[:],
                                        sourcedata_dms2[varname][:])
        # NOTE(review): debugging breakpoint left in; remove before use.
        import pdb
        pdb.set_trace()
        tt = create_time(time0, units='days since 1-1-1')
        # NOTE(review): tb is computed but never used below.
        tb = grid2xy(arri, xo=lonr, yo=latr, method='nearest', to=tt)

        # NOTE(review): tempf is undefined here (NameError); presumably the
        # raw variable read from sourcedata or sourcedata_dms2 -- confirm.
        Res_val = np.ma.masked_where(tempf >= 1e36, tempf)
        if len(Res_val.shape) > 3:
            LONR[varname] = []
            LATR[varname] = []
            f_out[varname] = []

            # One interpolator per level (axis 1)
            for all_l in range(0, Res_val.shape[1]):

                test = Res_val[:, all_l, :, :].reshape(
                    Res_val.shape[0], Res_val.shape[2] * Res_val.shape[3])
                gd_node = test[0, :].nonzero()[0]
                test = np.array(test)
                test = np.ma.masked_where(test == Res_val.fill_value, test)
                test = test[:, gd_node]
                if test.shape[1] > 0:
                    gd_ts = test[:, int(test.shape[1] / 2)].nonzero()[0]
                    # NOTE(review): with the default bounds_error, interp1d
                    # raises on out-of-range times, so fill_value=np.nan has
                    # no effect.
                    f_out[varname].append(
                        interp1d(tres[gd_ts],
                                 test[gd_ts],
                                 axis=0,
                                 fill_value=np.nan))

                    LONR[varname].append(lonr[gd_node])
                    LATR[varname].append(latr[gd_node])
        else:
            test = Res_val[:, :, :].reshape(
                Res_val.shape[0], Res_val.shape[1] * Res_val.shape[2])
            gd_node = test[0, :].nonzero()[0]
            test = np.array(test)
            # NOTE(review): masking threshold differs from the 1e36 / fill_value
            # tests used above -- confirm which is intended.
            test = np.ma.masked_where(test >= 1e20, test)
            test = test[:, gd_node]
            gd_ts = test[:, int(test.shape[1] / 2)].nonzero()[0]
            LONR[varname] = lonr[gd_node]
            LATR[varname] = latr[gd_node]
            f_out[varname] = interp1d(tres[gd_ts],
                                      test[gd_ts],
                                      axis=0,
                                      fill_value=np.nan)  #'extrapolate')

    return f_out, LONR, LATR